|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
from ..models.speecht5 import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor |
|
from ..utils import is_datasets_available |
|
from .base import PipelineTool |
|
|
|
|
|
if is_datasets_available(): |
|
from datasets import load_dataset |
|
|
|
|
|
class TextToSpeechTool(PipelineTool):
    """Tool that reads an English text out loud with a SpeechT5 checkpoint.

    The text is tokenized by ``SpeechT5Processor``, speech is produced by
    ``SpeechT5ForTextToSpeech.generate_speech`` and the result is passed
    through ``SpeechT5HifiGan`` to obtain the returned audio.

    NOTE(review): ``encode``, ``forward`` and ``decode`` are presumably
    chained in that order by ``PipelineTool`` -- confirm in the base class.
    """

    default_checkpoint = "microsoft/speecht5_tts"
    description = (
        "This is a tool that reads an English text out loud. It takes an input named `text` which should contain the "
        "text to read (in English) and returns a waveform object containing the sound."
    )
    name = "text_reader"
    pre_processor_class = SpeechT5Processor
    model_class = SpeechT5ForTextToSpeech
    post_processor_class = SpeechT5HifiGan

    inputs = ["text"]
    outputs = ["audio"]

    # Post-processor checkpoint used when the caller did not provide one.
    default_post_processor_checkpoint = "microsoft/speecht5_hifigan"

    # Where the fallback speaker embedding comes from when ``encode`` is
    # called without ``speaker_embeddings``. Subclasses may override these.
    default_speaker_dataset = "Matthijs/cmu-arctic-xvectors"
    default_speaker_split = "validation"
    default_speaker_index = 7305

    # Per-instance cache of the fallback speaker embedding, filled lazily so
    # the dataset is loaded at most once per tool instead of on every call.
    _default_speaker_embeddings = None

    def setup(self):
        """Select the default post-processor checkpoint if none is set, then run the base setup."""
        if self.post_processor is None:
            self.post_processor = self.default_post_processor_checkpoint
        super().setup()

    def _load_default_speaker_embeddings(self):
        """Return the fallback speaker embedding, loading it on first use.

        Returns:
            A tensor built from the ``"xvector"`` field of the configured
            dataset entry, with an extra leading dimension added.

        Raises:
            ImportError: If the embedding is not cached yet and the
                ``datasets`` library is not available.
        """
        if self._default_speaker_embeddings is None:
            if not is_datasets_available():
                raise ImportError("Datasets needs to be installed if not passing speaker embeddings.")

            embeddings_dataset = load_dataset(self.default_speaker_dataset, split=self.default_speaker_split)
            xvector = embeddings_dataset[self.default_speaker_index]["xvector"]
            self._default_speaker_embeddings = torch.tensor(xvector).unsqueeze(0)

        # Hand out a copy so in-place changes made downstream cannot corrupt
        # the cached tensor.
        return self._default_speaker_embeddings.clone()

    def encode(self, text, speaker_embeddings=None):
        """Tokenize ``text`` and pair it with speaker embeddings.

        Args:
            text: The text to read; it is tokenized with truncation enabled
                and returned as PyTorch tensors.
            speaker_embeddings: Optional speaker embeddings. When ``None``,
                the fallback embedding from the configured dataset is used.

        Returns:
            A dict with the keys ``"input_ids"`` and ``"speaker_embeddings"``.

        Raises:
            ImportError: If ``speaker_embeddings`` is ``None`` and the
                ``datasets`` library is not available.
        """
        inputs = self.pre_processor(text=text, return_tensors="pt", truncation=True)

        if speaker_embeddings is None:
            speaker_embeddings = self._load_default_speaker_embeddings()

        return {"input_ids": inputs["input_ids"], "speaker_embeddings": speaker_embeddings}

    def forward(self, inputs):
        """Run ``generate_speech`` on the model with ``inputs``, without tracking gradients."""
        with torch.no_grad():
            return self.model.generate_speech(**inputs)

    def decode(self, outputs):
        """Apply the post-processor to ``outputs`` and return the result as a detached CPU tensor."""
        with torch.no_grad():
            return self.post_processor(outputs).cpu().detach()
|
|