|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from ..models.whisper import WhisperForConditionalGeneration, WhisperProcessor |
|
from .base import PipelineTool |
|
|
|
|
|
class SpeechToTextTool(PipelineTool):
    """Transcribe audio to text with a Whisper checkpoint.

    The tool runs in three stages:

    1. ``encode`` passes ``audio`` through the configured pre-processor and
       keeps only its ``input_features`` (PyTorch tensors).
    2. ``forward`` hands those features to ``model.generate``.
    3. ``decode`` turns the generated output back into text with special
       tokens removed, and returns only the first decoded string.
    """

    # Tool identity as exposed to the agent.
    name = "transcriber"
    description = (
        "This is a tool that transcribes an audio into text. It takes an input named `audio` and returns the "
        "transcribed text."
    )

    # Model wiring: which checkpoint to load, and which classes load it.
    default_checkpoint = "openai/whisper-base"
    model_class = WhisperForConditionalGeneration
    pre_processor_class = WhisperProcessor

    # Declared I/O names for the tool interface.
    inputs = ["audio"]
    outputs = ["text"]

    def encode(self, audio):
        """Pre-process ``audio`` and return the processor's ``input_features``.

        The expected format of ``audio`` is whatever ``WhisperProcessor``
        accepts. No sampling rate is passed here; presumably the caller
        supplies audio at the rate the processor expects (TODO: confirm).
        """
        processed = self.pre_processor(audio, return_tensors="pt")
        return processed.input_features

    def forward(self, inputs):
        """Run generation on the encoded features and return the raw output."""
        generated = self.model.generate(inputs=inputs)
        return generated

    def decode(self, outputs):
        """Decode ``outputs`` to strings, skipping special tokens, and return the first one."""
        transcripts = self.pre_processor.batch_decode(outputs, skip_special_tokens=True)
        return transcripts[0]
|
|