# Source code for agentlego.tools.image_editing.remove
import numpy as np
from PIL import Image
from agentlego.types import Annotated, ImageIO, Info
from agentlego.utils import load_or_build_object, require
from ..base import BaseTool
# Module-level seed constant. Nothing in the code visible in this module
# reads it; presumably kept for reproducibility or for external importers --
# TODO confirm before removing.
GLOBAL_SEED = 1912
class ObjectRemove(BaseTool):
    """A tool to remove certain objects from an image.

    The object named by the text prompt is located with a grounding
    detector, segmented with SAM using the detected boxes, and the padded
    mask region is then inpainted with the prompt ``'background'``.

    Args:
        sam_model (str): The model name used to inference. Which can be found
            in the ``segment_anything`` repository.
            Defaults to ``sam_vit_h_4b8939.pth``.
        grounding_model (str): The model name used to inference. Which can be
            found in the ``MMdetection`` repository.
            Defaults to ``glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365``.
        device (str): The device to load the model. Defaults to 'cuda'.
        toolmeta (None | dict | ToolMeta): The additional info of the tool.
            Defaults to None.
    """

    default_desc = 'This tool can remove the specified object in the image.'

    @require('mmdet')
    @require('segment_anything')
    @require('diffusers')
    def __init__(self,
                 sam_model: str = 'sam_vit_h_4b8939.pth',
                 grounding_model: str = 'glip_atss_swin-t_a'
                 '_fpn_dyhead_pretrain_obj365',
                 device: str = 'cuda',
                 toolmeta=None):
        super().__init__(toolmeta)
        self.grounding_model = grounding_model
        self.sam_model = sam_model
        self.device = device

    def setup(self):
        """Build (or fetch cached) detector, SAM predictor and inpainter."""
        # Heavy optional dependencies are imported lazily so that importing
        # this module does not require them.
        from mmdet.apis import DetInferencer

        from ..segmentation.segment_anything import load_sam_and_predictor
        from .replace import Inpainting

        self.grounding = load_or_build_object(
            DetInferencer, model=self.grounding_model, device=self.device)
        self.sam, self.sam_predictor = load_sam_and_predictor(
            self.sam_model, device=self.device)
        self.inpainting = load_or_build_object(Inpainting, device=self.device)

    def apply(
        self,
        image: ImageIO,
        text: Annotated[str, Info('The object to remove.')],
    ) -> ImageIO:
        """Remove the object described by ``text`` from ``image``.

        Args:
            image (ImageIO): The input image.
            text (str): Description of the object to remove.

        Returns:
            ImageIO: The inpainted image, resized to the input image size.
            If the detector returns no boxes, the input image is returned
            without running segmentation or inpainting.
        """
        import torch

        image_path = image.to_path()
        image_pil = image.to_pil()

        detections = self.grounding(
            inputs=image_path,
            texts=[text],
            no_save_vis=True,
            return_datasamples=True)
        boxes_filt = detections['predictions'][0].pred_instances.bboxes

        # Nothing was detected, so there is nothing to remove.
        if len(boxes_filt) == 0:
            return ImageIO(image_pil)

        self.sam_predictor.set_image(image.to_array())
        masks = self.get_mask_with_boxes(image_pil, image, boxes_filt)

        # Merge the per-box masks into one boolean mask: a pixel is selected
        # when the masks summed over the first dimension are positive there.
        merged = torch.sum(masks, dim=0) > 0
        merged = merged.squeeze(0).cpu()

        # Grow the mask so the inpainting also covers the object's border.
        mask_image = Image.fromarray(self.pad_edge(merged, padding=20))

        output_image = self.inpainting(
            prompt='background', image=image_pil, mask_image=mask_image)
        output_image = output_image.resize(image_pil.size)
        return ImageIO(output_image)

    def pad_edge(self, mask, padding):
        """Grow the non-zero region of ``mask`` by ``padding`` pixels.

        Every non-zero element marks the window
        ``[i - padding, i + padding]`` around it along each axis (clipped to
        the array bounds) as selected.

        Args:
            mask: Object exposing ``.numpy()`` (a CPU tensor in ``apply``);
                non-zero entries are treated as selected.
            padding (int): Number of elements to grow by on each side along
                every axis. Expected to be non-negative.

        Returns:
            np.ndarray: ``uint8`` array with the same shape as ``mask``,
            holding 255 for selected elements and 0 elsewhere.
        """
        grown = mask.numpy().astype(bool)
        if grown.size == 0:
            return grown.astype(np.uint8)

        # A square window is separable, so grow along one axis at a time.
        # For each position, count the selected elements inside its window
        # with a prefix sum; any positive count selects the position.
        for axis in range(grown.ndim):
            length = grown.shape[axis]
            prefix = np.cumsum(grown, axis=axis, dtype=np.int64)
            # Prepend a zero slice so prefix[k] counts the first k elements.
            leading_zero = np.zeros_like(np.take(prefix, [0], axis=axis))
            prefix = np.concatenate([leading_zero, prefix], axis=axis)

            positions = np.arange(length)
            start = np.clip(positions - padding, 0, length)
            stop = np.clip(positions + padding + 1, 0, length)
            window_count = (np.take(prefix, stop, axis=axis) -
                            np.take(prefix, start, axis=axis))
            grown = window_count > 0

        return grown.astype(np.uint8) * 255

    def get_mask_with_boxes(self, image_pil, image, boxes_filt):
        """Predict one SAM mask per detected box.

        Args:
            image_pil: Unused; kept so the signature stays compatible with
                existing callers.
            image: The image object; its ``.shape[:2]`` is used as the
                original size for the box transform and it is passed to
                ``sam_predictor.get_image_embedding``.
            boxes_filt: Detected boxes to use as SAM box prompts.

        Returns:
            The masks returned by ``sam_predictor.predict_torch``.
        """
        boxes_filt = boxes_filt.cpu()
        # NOTE(review): ``apply`` passes the ``ImageIO`` object here, so this
        # relies on it exposing ``.shape`` -- confirm, or pass the array form.
        transformed_boxes = self.sam_predictor.transform.apply_boxes_torch(
            boxes_filt, image.shape[:2]).to(self.device)

        features = self.sam_predictor.get_image_embedding(image)

        masks, _, _ = self.sam_predictor.predict_torch(
            features=features,
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False,
        )
        return masks