Source code for agentlego.tools.segmentation.semantic_segmentation
from agentlego.types import ImageIO
from agentlego.utils import load_or_build_object, require
from ..base import BaseTool
class SemanticSegmentation(BaseTool):
    """A tool to conduct semantic segmentation on an image.

    Args:
        seg_model (str): The name of the model used for inference, which can
            be found in the ``MMSegmentation`` repository.
            Defaults to ``mask2former_r50_8xb2-90k_cityscapes-512x1024``.
        device (str): The device to load the model on. Defaults to 'cuda'.
        toolmeta (None | dict | ToolMeta): The additional info of the tool.
            Defaults to None.
    """

    default_desc = ('This tool can segment all items in the input image and '
                    'return a segmentation result image. '
                    'It focuses on urban scene images.')

    # NOTE(review): ``require`` presumably verifies that the
    # ``mmsegmentation`` package is installed before construction — confirm.
    @require('mmsegmentation')
    def __init__(self,
                 seg_model: str = 'mask2former_r50_8xb2-90k_cityscapes-512x1024',
                 device: str = 'cuda',
                 toolmeta=None):
        super().__init__(toolmeta=toolmeta)
        # Only the configuration is stored here; the inferencer itself is
        # created in ``setup``.
        self.seg_model = seg_model
        self.device = device

    def setup(self):
        """Create the MMSegmentation inferencer used by :meth:`apply`.

        The inferencer is obtained through ``load_or_build_object`` with the
        model name and device given at construction time, and stored on
        ``self._inferencer``.
        """
        # Imported inside the method so that merely importing this module
        # does not require ``mmseg`` to be importable.
        from mmseg.apis import MMSegInferencer
        self._inferencer = load_or_build_object(
            MMSegInferencer, model=self.seg_model, device=self.device)

    def apply(self, image: ImageIO) -> ImageIO:
        """Segment ``image`` and return the rendered segmentation result.

        Args:
            image (ImageIO): The input image. It is converted to a file path
                before being passed to the inferencer.

        Returns:
            ImageIO: The ``'visualization'`` entry of the inferencer output,
            wrapped as an image.
        """
        # Use a separate name instead of rebinding ``image``: the inferencer
        # receives a path, not the ``ImageIO`` object.
        image_path = image.to_path()
        results = self._inferencer(image_path, return_vis=True)
        # NOTE(review): for a single input, MMSegInferencer is expected to
        # return one visualization array rather than a list — confirm.
        return ImageIO(results['visualization'])