# Source code for agentlego.tools.image_editing.replace
import numpy as np
from PIL import Image
from agentlego.types import Annotated, ImageIO, Info
from agentlego.utils import is_package_available, load_or_build_object, require
from ..base import BaseTool
if is_package_available('torch'):
import torch
# Module-level seed constant.
# NOTE(review): not referenced anywhere in the visible part of this module --
# confirm whether it is still needed (e.g. imported from elsewhere).
GLOBAL_SEED = 1912
class Inpainting:
    """Inpainting model.

    Thin wrapper around the ``runwayml/stable-diffusion-inpainting``
    pipeline from ``diffusers`` that always applies a fixed negative prompt.

    Refers to `TaskMatrix/visual_chatgpt.py
    <https://github.com/microsoft/TaskMatrix/blob/main/visual_chatgpt.py>`_.

    Args:
        device (str): The device to use.
    """

    # Terms are joined with an explicit separator. The previous implicit
    # string-literal concatenation silently fused neighbouring terms
    # ("low qualitybad lighting", "bad motion blur bad consistency ...").
    NEGATIVE_PROMPT = ', '.join((
        'longbody',
        'lowres',
        'bad anatomy',
        'bad hands',
        'missing fingers',
        'extra digit',
        'fewer digits',
        'cropped',
        'worst quality',
        'low quality',
        'bad lighting',
        'bad background',
        'bad color',
        'bad aliasing',
        'bad distortion',
        'bad motion blur',
        'bad consistency with the background',
    ))

    def __init__(self, device):
        # Imported lazily so the module can be imported without diffusers.
        from diffusers import StableDiffusionInpaintPipeline

        self.device = device

        # Half precision (and the matching 'fp16' revision) is only selected
        # when the device string names a CUDA device.
        on_cuda = 'cuda' in self.device
        self.revision = 'fp16' if on_cuda else None
        self.torch_dtype = torch.float16 if on_cuda else torch.float32

        self.inpaint = StableDiffusionInpaintPipeline.from_pretrained(
            'runwayml/stable-diffusion-inpainting',
            revision=self.revision,
            torch_dtype=self.torch_dtype).to(device)
        self.n_prompt = self.NEGATIVE_PROMPT

    def __call__(self,
                 prompt,
                 image,
                 mask_image,
                 height=512,
                 width=512,
                 num_inference_steps=20):
        """Inpaint the masked region of ``image`` according to ``prompt``.

        Args:
            prompt (str): Text describing the content to generate.
            image: The image to edit. Must provide ``resize((w, h))``
                (a PIL image in the visible caller).
            mask_image: The mask selecting the region to repaint. Must
                provide ``resize((w, h))``.
            height (int): Height passed to the pipeline. Defaults to 512.
            width (int): Width passed to the pipeline. Defaults to 512.
            num_inference_steps (int): Number of inference steps passed to
                the pipeline. Defaults to 20.

        Returns:
            The first image produced by the pipeline.
        """
        # Image and mask are resized to the same (width, height) that is
        # handed to the pipeline.
        target_size = (width, height)
        result = self.inpaint(
            prompt=prompt,
            negative_prompt=self.n_prompt,
            image=image.resize(target_size),
            mask_image=mask_image.resize(target_size),
            height=height,
            width=width,
            num_inference_steps=num_inference_steps)
        return result.images[0]
class ObjectReplace(BaseTool):
    """A tool to replace certain objects in the image.

    Args:
        sam_model (str): The model name used to inference. Which can be found
            in the ``segment_anything`` repository.
            Defaults to ``sam_vit_h_4b8939.pth``.
        grounding_model (str): The model name used to inference. Which can be
            found in the ``MMdetection`` repository.
            Defaults to ``glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365``.
        device (str): The device to load the model. Defaults to 'cuda'.
        toolmeta (None | dict | ToolMeta): The additional info of the tool.
            Defaults to None.
    """

    default_desc = ('This tool can replace the specified object in the input '
                    'image with another object, like replacing a cat in an '
                    'image with a dog.')

    @require('mmdet')
    @require('segment_anything')
    @require('diffusers')
    def __init__(self,
                 sam_model: str = 'sam_vit_h_4b8939.pth',
                 grounding_model: str = (
                     'glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365'),
                 device: str = 'cuda',
                 toolmeta=None):
        super().__init__(toolmeta)
        # Only the configuration is stored here; the models themselves are
        # built in ``setup``.
        self.sam_model = sam_model
        self.grounding_model = grounding_model
        self.device = device

    def setup(self):
        """Build the grounding detector, the SAM predictor and the inpainter."""
        from mmdet.apis import DetInferencer

        from ..segmentation.segment_anything import load_sam_and_predictor

        self.grounding = load_or_build_object(
            DetInferencer, model=self.grounding_model, device=self.device)
        self.sam, self.sam_predictor = load_sam_and_predictor(
            self.sam_model, device=self.device)
        self.inpainting = load_or_build_object(Inpainting, device=self.device)

    def apply(
        self,
        image: Annotated[ImageIO, Info('The image to edit.')],
        text1: Annotated[str, Info('The object to be replaced.')],
        text2: Annotated[str, Info('The object to replace with.')],
    ) -> ImageIO:
        """Replace the object described by ``text1`` with ``text2``.

        The grounding model proposes boxes for ``text1``, SAM turns the boxes
        into masks, the merged mask is grown by 20 pixels and the masked
        region is repainted by the inpainting model using ``text2`` as the
        prompt. The result is resized back to the input image size.
        """
        image_path = image.to_path()
        image_pil = image.to_pil()

        detections = self.grounding(
            inputs=image_path,
            texts=[text1],
            no_save_vis=True,
            return_datasamples=True)
        # NOTE(review): every predicted box is used, no score filtering is
        # applied here -- confirm that this is intended.
        boxes_filt = detections['predictions'][0].pred_instances.bboxes

        self.sam_predictor.set_image(image.to_array())
        masks = self.get_mask_with_boxes(image_pil, image, boxes_filt)

        # Merge the per-box masks into a single boolean mask: a pixel is
        # selected when any mask covers it.
        # Presumably ``masks`` is (num_boxes, 1, H, W) -- TODO confirm.
        mask = (torch.sum(masks, dim=0) > 0).squeeze(0).cpu()
        mask = self.pad_edge(mask, padding=20)
        mask_img = Image.fromarray(mask)

        output_image = self.inpainting(
            prompt=text2, image=image_pil, mask_image=mask_img)
        # The inpainter works at its own fixed resolution; restore the
        # original image size.
        output_image = output_image.resize(image_pil.size)
        return ImageIO(output_image)

    def pad_edge(self, mask, padding):
        """Grow the foreground of ``mask`` by ``padding`` pixels per axis.

        Every non-zero element marks the block of elements within ``padding``
        positions of it along each axis (clipped at the array border).

        Args:
            mask (torch.Tensor): CPU tensor whose non-zero elements are the
                foreground.
            padding (int): Non-negative number of elements to grow by.

        Returns:
            np.ndarray: ``uint8`` array with the shape of ``mask``; 255 for
            foreground and 0 elsewhere.
        """
        grown = mask.numpy().astype(bool)

        if grown.size:
            # A box dilation is separable, so it is applied one axis at a
            # time with a vectorised sliding window instead of a Python loop
            # over every foreground pixel.
            window = 2 * padding + 1
            for axis in range(grown.ndim):
                pad_width = [(0, 0)] * grown.ndim
                pad_width[axis] = (padding, padding)
                # Padding with False reproduces the clipping at the borders.
                padded = np.pad(
                    grown, pad_width, mode='constant', constant_values=False)
                windows = np.lib.stride_tricks.sliding_window_view(
                    padded, window, axis=axis)
                grown = windows.any(axis=-1)

        return (grown * 255).astype(np.uint8)

    def get_mask_with_boxes(self, image_pil, image, boxes_filt):
        """Predict SAM masks for the given boxes.

        Args:
            image_pil: Unused; kept for interface compatibility.
            image: The input image; its ``shape[:2]`` is used to transform
                the boxes and it is passed to the predictor to obtain the
                image embedding.
            boxes_filt (torch.Tensor): The boxes used as SAM prompts.

        Returns:
            The masks returned by ``sam_predictor.predict_torch``.
        """
        boxes_filt = boxes_filt.cpu()
        transformed_boxes = self.sam_predictor.transform.apply_boxes_torch(
            boxes_filt, image.shape[:2]).to(self.device)

        features = self.sam_predictor.get_image_embedding(image)
        masks, _, _ = self.sam_predictor.predict_torch(
            features=features,
            point_coords=None,
            point_labels=None,
            # Already moved to ``self.device`` above.
            boxes=transformed_boxes,
            multimask_output=False,
        )
        return masks