Shortcuts

Source code for agentlego.tools.segmentation.segment_anything

import random
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Tuple

import numpy as np
from PIL import Image

from agentlego.types import Annotated, ImageIO, Info
from agentlego.utils import (download_checkpoint, download_url_to_file,
                             is_package_available, load_or_build_object, require)
from ..base import BaseTool

if is_package_available('torch'):
    import torch
    from torch import Tensor
else:
    assert not TYPE_CHECKING, 'torch is not installed'
    Tensor = None

GLOBAL_SEED = 1912


def load_sam_and_predictor(model, device=None, ckpt_path=None):

    def _load_sam(model, ckpt_path, device):
        try:
            from segment_anything import sam_model_registry
        except ImportError as e:
            raise ImportError(
                f'Failed to run the tool for {e}, please check if you have '
                'install `segment_anything` correctly')

        url = f'https://dl.fbaipublicfiles.com/segment_anything/{model}'
        if ckpt_path is not None:
            Path(ckpt_path).parent.mkdir(exist_ok=True, parents=True)
            download_url_to_file(url, ckpt_path)
        else:
            ckpt_path = download_checkpoint(url)

        sam = sam_model_registry['vit_h'](checkpoint=ckpt_path)
        sam.to(device=device)
        return sam

    def _load_sam_predictor(sam):
        return SamPredictor(sam)

    sam = load_or_build_object(_load_sam, model, ckpt_path, device)
    sam_predictor = load_or_build_object(_load_sam_predictor, sam)
    return sam, sam_predictor


class SamPredictor:

    @require(('torch', 'segment_anything'))
    def __init__(
        self,
        sam_model,
    ) -> None:
        """Uses SAM to calculate the image embedding for an image, and then
        allow repeated, efficient mask prediction given prompts.

        Arguments:
          sam_model (Sam): The model to use for mask prediction.
        """
        super().__init__()
        self.model = sam_model

        from segment_anything.utils.transforms import ResizeLongestSide

        self.transform = ResizeLongestSide(sam_model.image_encoder.img_size)

    def set_image(
        self,
        image: np.ndarray,
        image_format: str = 'RGB',
    ) -> None:
        """Calculates the image embeddings for the provided image, allowing
        masks to be predicted with the 'predict' method.

        Arguments:
          image (np.ndarray): The image for calculating masks. Expects an
            image in HWC uint8 format, with pixel values in [0, 255].
          image_format (str): The color format of the image, in ['RGB', 'BGR'].
        """
        assert image_format in [
            'RGB',
            'BGR',
        ], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
        if image_format != self.model.image_format:
            image = image[..., ::-1]

        # Transform the image to the form expected by the model
        input_image = self.transform.apply_image(image)
        input_image_torch = torch.as_tensor(input_image, device=self.device)
        input_image_torch = input_image_torch.permute(2, 0,
                                                      1).contiguous()[None, :, :, :]

        return self.set_torch_image(input_image_torch, image.shape[:2])

    def set_torch_image(
        self,
        transformed_image: Tensor,
        original_image_size: Tuple[int, ...],
    ) -> None:
        """Calculates the image embeddings for the provided image, allowing
        masks to be predicted with the 'predict' method. Expects the input
        image to be already transformed to the format expected by the model.

        Arguments:
          transformed_image (Tensor): The input image, with shape
            1x3xHxW, which has been transformed with ResizeLongestSide.
          original_image_size (tuple(int, int)): The size of the image
            before transformation, in (H, W) format.
        """
        assert (len(transformed_image.shape) == 4 and transformed_image.shape[1] == 3
                and max(*transformed_image.shape[2:])
                == self.model.image_encoder.img_size), (
                    'set_torch_image input must be BCHW with long side'
                    f' {self.model.image_encoder.img_size}.')

        original_size = original_image_size
        input_size = tuple(transformed_image.shape[-2:])
        input_image = self.model.preprocess(transformed_image)
        features = self.model.image_encoder(input_image)

        res = {
            'features': features,
            'original_size': original_size,
            'input_size': input_size
        }
        return res

    def predict(
        self,
        features,
        point_coords: Optional[np.ndarray] = None,
        point_labels: Optional[np.ndarray] = None,
        box: Optional[np.ndarray] = None,
        mask_input: Optional[np.ndarray] = None,
        multimask_output: bool = True,
        return_logits: bool = False,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Predict masks for the given input prompts, using the currently set
        image.

        Arguments:
          point_coords (np.ndarray or None): A Nx2 array of point prompts to
            the model. Each point is in (X,Y) in pixels.
          point_labels (np.ndarray or None): A length N array of labels for the
            point prompts. 1 indicates a foreground point and 0 indicates a
            background point.
          box (np.ndarray or None): A length 4 array given a box prompt to the
            model, in XYXY format.
          mask_input (np.ndarray): A low resolution mask input to the model,
            typically coming from a previous prediction iteration. Has form
            1xHxW, where for SAM, H=W=256.
          multimask_output (bool): If true, the model will return three masks.
            For ambiguous input prompts (such as a single click), this will
            often produce better masks than a single prediction. If only a
            single mask is needed, the model's predicted quality score can be
            used to select the best mask. For non-ambiguous prompts, such as
            multiple input prompts, multimask_output=False can give better
            results.
          return_logits (bool): If true, returns un-thresholded masks logits
            instead of a binary mask.

        Returns:
          (np.ndarray): The output masks in CxHxW format, where C is the
            number of masks, and (H, W) is the original image size.
          (np.ndarray): An array of length C containing the model's
            predictions for the quality of each mask.
          (np.ndarray): An array of shape CxHxW, where C is the number
            of masks and H=W=256. These low resolution logits can be passed to
            a subsequent iteration as mask input.
        """
        if features.get('features', None) is None:
            raise RuntimeError('An image must be set with .set_image(...)'
                               ' before mask prediction.')

        # Transform input prompts
        coords_torch, labels_torch = None, None
        box_torch, mask_input_torch = None, None
        if point_coords is not None:
            assert (point_labels is not None
                    ), 'point_labels must be supplied if point_coords is supplied.'
            point_coords = self.transform.apply_coords(point_coords,
                                                       features['original_size'])
            coords_torch = torch.as_tensor(
                point_coords, dtype=torch.float, device=self.device)
            labels_torch = torch.as_tensor(
                point_labels, dtype=torch.int, device=self.device)
            coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
        if box is not None:
            box = self.transform.apply_boxes(box, features['original_size'])
            box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
            box_torch = box_torch[None, :]
        if mask_input is not None:
            mask_input_torch = torch.as_tensor(
                mask_input, dtype=torch.float, device=self.device)
            mask_input_torch = mask_input_torch[None, :, :, :]

        masks, iou_predictions, low_res_masks = self.predict_torch(
            features,
            coords_torch,
            labels_torch,
            box_torch,
            mask_input_torch,
            multimask_output,
            return_logits=return_logits,
        )

        masks_np = masks[0].detach().cpu().numpy()
        iou_predictions_np = iou_predictions[0].detach().cpu().numpy()
        low_res_masks_np = low_res_masks[0].detach().cpu().numpy()
        return masks_np, iou_predictions_np, low_res_masks_np

    def predict_torch(
        self,
        features,
        point_coords: Optional[Tensor],
        point_labels: Optional[Tensor],
        boxes: Optional[Tensor] = None,
        mask_input: Optional[Tensor] = None,
        multimask_output: bool = True,
        return_logits: bool = False,
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Predict masks for the given input prompts, using the currently set
        image. Input prompts are batched torch tensors and are expected to
        already be transformed to the input frame using ResizeLongestSide.

        Arguments:
          point_coords (Tensor or None): A BxNx2 array of point prompts
            to the model. Each point is in (X,Y) in pixels.
          point_labels (Tensor or None): A BxN array of labels for the
            point prompts. 1 indicates a foreground point and 0 indicates a
            background point.
          boxes (np.ndarray or None): A Bx4 array given a box prompt to the
            model, in XYXY format.
          mask_input (np.ndarray): A low resolution mask input to the model,
            typically coming from a previous prediction iteration. Has form
            Bx1xHxW, where for SAM, H=W=256. Masks returned by a previous
            iteration of the predict method do not need further transformation.
          multimask_output (bool): If true, the model will return three masks.
            For ambiguous input prompts (such as a single click), this will
            often produce better masks than a single prediction. If only a
            single mask is needed, the model's predicted quality score can be
            used to select the best mask. For non-ambiguous prompts, such as
            multiple input prompts, multimask_output=False can give better
            results.
          return_logits (bool): If true, returns un-thresholded masks logits
            instead of a binary mask.

        Returns:
          (Tensor): The output masks in BxCxHxW format, where C is the
            number of masks, and (H, W) is the original image size.
          (Tensor): An array of shape BxC containing the model's
            predictions for the quality of each mask.
          (Tensor): An array of shape BxCxHxW, where C is the number
            of masks and H=W=256. These low res logits can be passed to
            a subsequent iteration as mask input.
        """
        if features.get('features', None) is None:
            raise RuntimeError('An image must be set with .set_image(...)'
                               ' before mask prediction.')

        if point_coords is not None:
            points = (point_coords, point_labels)
        else:
            points = None

        # Embed prompts
        sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
            points=points,
            boxes=boxes,
            masks=mask_input,
        )

        # Predict masks
        low_res_masks, iou_predictions = self.model.mask_decoder(
            image_embeddings=features['features'],
            image_pe=self.model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
        )

        # Upscale the masks to the original image resolution
        masks = self.model.postprocess_masks(low_res_masks, features['input_size'],
                                             features['original_size'])

        if not return_logits:
            masks = masks > self.model.mask_threshold

        return masks, iou_predictions, low_res_masks

    def get_image_embedding(self, image) -> Tensor:
        return self.set_image(image)

    @property
    def device(self):
        return self.model.device


[docs]class SegmentAnything(BaseTool): """A tool to segment all objects on an image. Args: sam_model (str): The model name used to inference. Which can be found in the ``segment_anything`` repository. Defaults to ``sam_vit_h_4b8939.pth``. device (str): The device to load the model. Defaults to 'cuda'. toolmeta (None | dict | ToolMeta): The additional info of the tool. Defaults to None. """ default_desc = ('This tool can segment all items in the image and ' 'return a segmentation result image.') @require('segment_anything') def __init__(self, sam_model: str = 'sam_vit_h_4b8939.pth', device: str = 'cuda', toolmeta=None): super().__init__(toolmeta=toolmeta) self.sam_model = sam_model self.device = device def setup(self): self.sam, self.sam_predictor = load_sam_and_predictor( self.sam_model, device=self.device) def apply(self, image: ImageIO ) -> Annotated[ImageIO, Info('The segmentation result image.')]: annos = self.segment_anything(image.to_array()) full_img, _ = self.show_annos(annos) return ImageIO(full_img) def segment_anything(self, img): if not self._is_setup: self.setup() self._is_setup = True from segment_anything import SamAutomaticMaskGenerator mask_generator = SamAutomaticMaskGenerator(self.sam) annos = mask_generator.generate(img) return annos def segment_by_mask(self, mask, features): random.seed(GLOBAL_SEED) idxs = np.nonzero(mask) num_points = min(max(1, int(len(idxs[0]) * 0.01)), 16) sampled_idx = random.sample(range(0, len(idxs[0])), num_points) new_mask = [] for i in range(len(idxs)): new_mask.append(idxs[i][sampled_idx]) points = np.array(new_mask).reshape(2, -1).transpose(1, 0)[:, ::-1] labels = np.array([1] * num_points) res_masks, scores, _ = self.sam_predictor.predict( features=features, point_coords=points, point_labels=labels, multimask_output=True, ) return res_masks[np.argmax(scores), :, :] def get_detection_map(self, img): annos = self.segment_anything(img) _, detection_map = self.show_anns(annos) return detection_map def show_annos(self, anns): if len(anns) == 0: return sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True) full_img = None # for ann in sorted_anns: for i in range(len(sorted_anns)): ann = anns[i] m = ann['segmentation'] if full_img is None: full_img = np.zeros((m.shape[0], m.shape[1], 3)) map = np.zeros((m.shape[0], m.shape[1]), dtype=np.uint16) map[m != 0] = i + 1 color_mask = np.random.random((1, 3)).tolist()[0] full_img[m != 0] = color_mask full_img = full_img * 255 res = np.zeros((map.shape[0], map.shape[1], 3)) res[:, :, 0] = map % 256 res[:, :, 1] = map // 256 res.astype(np.float32) full_img = Image.fromarray(np.uint8(full_img)) return full_img, res
def get_image_embedding(self, img): if not self._is_setup: self.setup() self._is_setup = True embedding = self.sam_predictor.set_image(img) return embedding
[docs]class SegmentObject(BaseTool): """A tool to segment all objects on an image. Args: sam_model (str): The model name used to inference. Which can be found in the ``segment_anything`` repository. Defaults to ``sam_vit_h_4b8939.pth``. grounding_model (str): The model name used to grounding. Which can be found in the ``MMDetection`` repository. Defaults to ``glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365``. device (str): The device to load the model. Defaults to 'cpu'. toolmeta (None | dict | ToolMeta): The additional info of the tool. Defaults to None. """ default_desc = ('This tool can segment the specified kind of objects in ' 'the input image, and return the segmentation ' 'result image.') @require('segment_anything') @require('mmdet>=3.1.0') def __init__( self, sam_model: str = 'sam_vit_h_4b8939.pth', grounding_model: str = ('glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365'), device: str = 'cuda', toolmeta=None): super().__init__(toolmeta=toolmeta) self.sam_model = sam_model self.grounding_model = grounding_model self.device = device def setup(self): from mmdet.apis import DetInferencer self.grounding = load_or_build_object( DetInferencer, model=self.grounding_model, device=self.device) self.sam, self.sam_predictor = load_sam_and_predictor( self.sam_model, device=self.device) def apply( self, image: ImageIO, text: Annotated[str, Info('The object to segment.')], ) -> Annotated[ImageIO, Info('The segmentation result image.')]: results = self.grounding( inputs=image.to_array()[:, :, ::-1], # Input BGR texts=text, no_save_vis=True, return_datasamples=True) results = results['predictions'][0].pred_instances boxes_filt = results.bboxes pred_phrases = results.label_names output_image = self.segment_image_with_boxes(image.to_array(), boxes_filt, pred_phrases) return ImageIO(output_image) def get_mask_with_boxes(self, image, boxes_filt): if not self._is_setup: self.setup() self._is_setup = True boxes_filt = boxes_filt.cpu() transformed_boxes = self.sam_predictor.transform.apply_boxes_torch( boxes_filt, image.shape[:2]).to(self.device) features = self.sam_predictor.get_image_embedding(image) masks, _, _ = self.sam_predictor.predict_torch( features=features, point_coords=None, point_labels=None, boxes=transformed_boxes.to(self.device), multimask_output=False, ) return masks def segment_image_with_boxes(self, image, boxes_filt, pred_phrases): if not self._is_setup: self.setup() self._is_setup = True masks = self.get_mask_with_boxes(image, boxes_filt) # draw output image for mask in masks: image = self.show_mask( mask[0].cpu().numpy(), image, random_color=True, transparency=0.3) return image def show_mask(self, mask: np.ndarray, image: np.ndarray, random_color: bool = False, transparency=1) -> np.ndarray: """Visualize a mask on top of an image. Args: mask (np.ndarray): A 2D array of shape (H, W). image (np.ndarray): A 3D array of shape (H, W, 3). random_color (bool): Whether to use a random color for the mask. Outputs: np.ndarray: A 3D array of shape (H, W, 3) with the mask visualized on top of the image. transparenccy: the transparency of the segmentation mask """ import cv2 if random_color: color = np.concatenate([np.random.random(3)], axis=0) else: color = np.array([30 / 255, 144 / 255, 255 / 255]) h, w = mask.shape[-2:] mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) * 255 image = cv2.addWeighted(image, 0.7, mask_image.astype('uint8'), transparency, 0) return image def show_box(self, box, ax, label): import matplotlib.pyplot as plt x0, y0 = box[0], box[1] w, h = box[2] - box[0], box[3] - box[1] ax.add_patch( plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2)) ax.text(x0, y0, label)