Shortcuts

Source code for agentlego.tools.object_detection.object_detection

from agentlego.types import Annotated, ImageIO, Info
from agentlego.utils import load_or_build_object, require
from ..base import BaseTool


class ObjectDetection(BaseTool):
    """A tool to detect all objects defined in the COCO 80 classes.

    Args:
        model (str): The model name used to detect objects. Which can be
            found in the ``MMDetection`` repository.
            Defaults to ``rtmdet_l_8xb32-300e_coco``.
        device (str): The device to load the model. Defaults to 'cuda'.
        toolmeta (None | dict | ToolMeta): The additional info of the tool.
            Defaults to None.
    """

    default_desc = 'The tool can detect all common objects in the picture.'

    @require('mmdet>=3.1.0')
    def __init__(self,
                 model: str = 'rtmdet_l_8xb32-300e_coco',
                 device: str = 'cuda',
                 toolmeta=None):
        super().__init__(toolmeta=toolmeta)
        self.model = model
        self.device = device

    def setup(self):
        """Create the inferencer and cache the class names of its model."""
        # Imported lazily so that mmdet is only needed once the tool is set up.
        from mmdet.apis import DetInferencer

        self._inferencer = load_or_build_object(
            DetInferencer, model=self.model, device=self.device)
        # Class names are later indexed by the predicted label ids.
        self.classes = self._inferencer.model.dataset_meta['classes']

    def apply(
        self,
        image: ImageIO,
    ) -> Annotated[str,
                   Info('All detected objects, include object name, '
                        'bbox in (x1, y1, x2, y2) format, '
                        'and detection score.')]:
        """Detect objects in ``image`` and describe them as text.

        Args:
            image (ImageIO): The image to run detection on.

        Returns:
            str: One line per detection whose score is above 0.5, holding
            the class name, the four bbox values and the score scaled by
            100; or ``'No object found.'`` when nothing passes the
            threshold.
        """
        from mmdet.structures import DetDataSample

        # Reverse the last axis of the image array before inference
        # (presumably RGB -> BGR for the inferencer -- TODO confirm).
        results = self._inferencer(
            image.to_array()[:, :, ::-1],
            return_datasamples=True,
        )
        # A single image is passed in, so only the first prediction is read.
        data_sample: DetDataSample = results['predictions'][0]
        preds = data_sample.pred_instances
        # Keep only the detections scoring above 0.5.
        preds = preds[preds.scores > 0.5]

        pred_tmpl = '{} ({:.0f}, {:.0f}, {:.0f}, {:.0f}), score {:.0f}'
        pred_descs = [
            pred_tmpl.format(self.classes[label], *bbox, score * 100)
            for label, bbox, score in zip(preds.labels, preds.bboxes,
                                          preds.scores)
        ]

        if not pred_descs:
            return 'No object found.'
        return '\n'.join(pred_descs)