Source code for agentlego.tools.object_detection.text_to_bbox
from agentlego.types import Annotated, ImageIO, Info
from agentlego.utils import load_or_build_object, require
from ..base import BaseTool
class TextToBbox(BaseTool):
    """A tool to detect the objects that match a text description.

    Args:
        model (str): The model name used to detect objects from text.
            Which can be found in the ``MMDetection`` repository.
            Defaults to ``glip_atss_swin-t_b_fpn_dyhead_pretrain_obj365``.
        device (str): The device to load the model. Defaults to 'cuda'.
        toolmeta (None | dict | ToolMeta): The additional info of the tool.
            Defaults to None.
    """

    default_desc = ('The tool can detect the object location according to '
                    'description.')

    @require('mmdet>=3.1.0')
    def __init__(self,
                 model: str = 'glip_atss_swin-t_b_fpn_dyhead_pretrain_obj365',
                 device: str = 'cuda',
                 toolmeta=None):
        super().__init__(toolmeta=toolmeta)
        # Only the configuration is stored here; the detector itself is
        # created in ``setup``.
        self.model = model
        self.device = device

    def setup(self):
        """Build (or reuse) the MMDetection inferencer for this tool."""
        # Imported lazily so that the module can be imported without mmdet.
        from mmdet.apis import DetInferencer

        self._inferencer = load_or_build_object(
            DetInferencer, model=self.model, device=self.device)
        self._visualizer = self._inferencer.visualizer

    def apply(
        self,
        image: ImageIO,
        text: Annotated[str, Info('The object description in English.')],
        top1: Annotated[bool,
                        Info('If true, return the object with highest score. '
                             'If false, return all detected objects.')] = True,
    ) -> Annotated[str,
                   Info('Detected objects, include bbox in '
                        '(x1, y1, x2, y2) format, and detection score.')]:
        """Detect the objects described by ``text`` in ``image``.

        Args:
            image (ImageIO): The image to run the detection on.
            text (str): The object description in English.
            top1 (bool): If True, keep only the detection with the highest
                score. If False, keep every detection whose score is
                above 0.5. Defaults to True.

        Returns:
            str: One line per kept detection, formatted as
            ``(x1, y1, x2, y2), score S`` where ``S`` is the score scaled
            by 100, or ``'No object found.'`` when nothing is kept.
        """
        from mmdet.structures import DetDataSample

        results = self._inferencer(
            # Reverse the channel axis before inference.
            # NOTE(review): presumably RGB -> BGR, as ``to_array`` is
            # assumed to return RGB - confirm.
            image.to_array()[:, :, ::-1],
            texts=text,
            return_datasamples=True,
        )
        data_sample: DetDataSample = results['predictions'][0]
        preds = data_sample.pred_instances

        if len(preds) == 0:
            return 'No object found.'

        if top1:
            preds = preds[preds.scores.topk(1).indices]
        else:
            preds = preds[preds.scores > 0.5]
            # The score filter may discard every detection; report that
            # explicitly instead of returning an empty string.
            if len(preds) == 0:
                return 'No object found.'

        pred_tmpl = '({:.0f}, {:.0f}, {:.0f}, {:.0f}), score {:.0f}'
        return '\n'.join(
            pred_tmpl.format(*bbox, score * 100)
            for bbox, score in zip(preds.bboxes, preds.scores))