Shortcuts

Source code for agentlego.tools.speech_text.speech_to_text

from agentlego.types import AudioIO
from agentlego.utils import apply_to, is_package_available, load_or_build_object, require
from ..base import BaseTool

if is_package_available('torch'):
    import torch

if is_package_available('torchaudio'):
    import torchaudio


def resampling_audio(audio: AudioIO, new_rate):
    """Resample ``audio`` to ``new_rate`` and wrap the result in an AudioIO.

    Args:
        audio (AudioIO): The audio to convert. Its tensor form (via
            ``to_tensor()``) and its ``sampling_rate`` attribute are read.
        new_rate: The target sampling rate. It is handed to
            ``torchaudio.functional.resample`` and recorded on the result.

    Returns:
        AudioIO: A new object built from the resampled tensor with
        ``sampling_rate=new_rate``.
    """
    waveform = audio.to_tensor()
    source_rate = audio.sampling_rate
    resampled = torchaudio.functional.resample(
        waveform,
        source_rate,
        new_rate,
    )
    return AudioIO(resampled, sampling_rate=new_rate)


class SpeechToText(BaseTool):
    """A tool to recognize speech and convert it to text.

    Args:
        model (str): The model name used for inference, which can be found
            on the ``HuggingFace`` model page.
            Defaults to ``openai/whisper-base``.
        device (str): The device to load the model. Defaults to 'cuda'.
        toolmeta (None | dict | ToolMeta): The additional info of the tool.
            Defaults to None.
    """

    default_desc = 'The tool can translate spoken language audio into text.'

    @require(('torch', 'transformers', 'torchaudio'))
    def __init__(self,
                 model='openai/whisper-base',
                 device='cuda',
                 toolmeta=None):
        """Record the model name and target device.

        Nothing is loaded here; the processor and model are created in
        ``setup``.
        """
        super().__init__(toolmeta)
        self.model_name = model
        self.device = device

    def setup(self) -> None:
        """Create ``self.processor`` and ``self.model`` for ``model_name``.

        Both objects are obtained through ``load_or_build_object`` from the
        respective ``from_pretrained`` constructors, and the model is moved
        to ``self.device``.
        """
        # Imported inside the method rather than at module level; presumably
        # so this module stays importable without ``transformers`` installed.
        from transformers.models.whisper import (
            WhisperForConditionalGeneration, WhisperProcessor)

        self.processor = load_or_build_object(
            WhisperProcessor.from_pretrained, self.model_name)
        self.model = load_or_build_object(
            WhisperForConditionalGeneration.from_pretrained,
            self.model_name).to(self.device)

    def apply(self, audio: AudioIO) -> str:
        """Transcribe ``audio`` and return the decoded text.

        Args:
            audio (AudioIO): The speech to recognize. It is resampled first
                when its sampling rate differs from the one the processor's
                feature extractor reports.

        Returns:
            str: The first entry of the decoded batch, with special tokens
            skipped.
        """
        target_sampling_rate = self.processor.feature_extractor.sampling_rate
        if target_sampling_rate != audio.sampling_rate:
            audio = resampling_audio(audio, target_sampling_rate)

        # NOTE(review): ``reshape(-1)`` flattens every dimension into a
        # single 1-D array. This looks like it assumes single-channel audio;
        # confirm how ``AudioIO.to_tensor()`` lays out multi-channel input.
        encoded_inputs = self.processor(
            audio.to_tensor().numpy().reshape(-1),
            return_tensors='pt',
            sampling_rate=target_sampling_rate).input_features

        # Move the extracted features onto the device the model was sent to.
        encoded_inputs = apply_to(encoded_inputs,
                                  lambda x: isinstance(x, torch.Tensor),
                                  lambda x: x.to(self.device))
        outputs = self.model.generate(inputs=encoded_inputs)

        # Bring the generated tensors back to the CPU before decoding.
        outputs = apply_to(outputs, lambda x: isinstance(x, torch.Tensor),
                           lambda x: x.to('cpu'))
        return self.processor.batch_decode(
            outputs, skip_special_tokens=True)[0]