from transformers import (
    VitsModel,
    AutoTokenizer,
    WhisperForConditionalGeneration,
    WhisperProcessor,
    M2M100ForConditionalGeneration,
    M2M100Tokenizer,
)
import torch
import torchaudio
import numpy as np
import gradio as gr

device = "cuda" if torch.cuda.is_available() else "cpu"

# ASR: Whisper transcribes the input speech and identifies its language.
asr_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base").to(device)
asr_processor = WhisperProcessor.from_pretrained("openai/whisper-base")

# MT: M2M100 translates the transcription into the target language.
mt_model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M").to(device)
mt_tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")

# TTS: MMS (VITS) synthesises Spanish speech from the translated text.
tts_model = VitsModel.from_pretrained("facebook/mms-tts-spa").to(device)
tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-spa")
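
# Note: the MMS TTS checkpoint above is Spanish-only. To target another
# language, swap the checkpoint suffix (e.g. "facebook/mms-tts-fra" for
# French) and change target_lang in translate_text to match.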


def transcribe_and_detect_lang(audio):
    # Whisper expects 16 kHz mono float audio.
    input_features = asr_processor(audio, sampling_rate=16000, return_tensors="pt").input_features.to(device)

    generated_ids = asr_model.generate(input_features)
    transcription = asr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    # By default the generated sequence begins with
    # <|startoftranscript|><|xx|><|transcribe|>..., so the second token
    # carries the detected language code.
    tokens = asr_processor.tokenizer.convert_ids_to_tokens(generated_ids[0])

    if len(tokens) > 2:
        detected_lang_token = tokens[1]
        detected_lang = detected_lang_token.strip('<|>')
    else:
        # No language token found; fall back to English so the translator
        # still receives a valid M2M100 language code.
        detected_lang = "en"

    return transcription, detected_lang
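
# Illustrative check (assumes "audio" is a 16 kHz mono float32 array):
#   text, lang = transcribe_and_detect_lang(audio)
#   print(lang)  # e.g. "en" or "fr"; Whisper's two-letter codes match
#   # M2M100's codes for most shared languages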


def translate_text(text, source_lang, target_lang="es"):
    """
    Translates text from the source language into the target language using M2M100.
    """
    mt_tokenizer.src_lang = source_lang

    encoded_input = mt_tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)

    # Force the decoder to start with the target-language token.
    generated_tokens = mt_model.generate(
        **encoded_input,
        forced_bos_token_id=mt_tokenizer.get_lang_id(target_lang),
    )

    translated_text = mt_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
    return translated_text
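
# Illustrative usage (exact output may vary between model versions):
#   translate_text("Bonjour le monde", source_lang="fr")  # -> e.g. "Hola mundo"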


def translate(audio):
    transcription, detected_lang = transcribe_and_detect_lang(audio)
    print(f"Detected language: {detected_lang}")
    print(f"Transcription: {transcription}")

    translated_text = translate_text(transcription, source_lang=detected_lang)
    print(f"Translated Text: {translated_text}")

    return translated_text


def synthesise(text):
    inputs = tts_tokenizer(text, return_tensors="pt").to(device)

    with torch.no_grad():
        output = tts_model(**inputs)

    # VitsModel exposes the generated audio as `waveform`, a float tensor
    # of shape (batch, samples) in [-1, 1].
    speech = output.waveform
    print(f"TTS waveform shape: {speech.shape}")

    # Convert to 16-bit PCM for Gradio playback.
    speech_numpy = (speech.squeeze().cpu().numpy() * 32767).astype(np.int16)

    return speech_numpy
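
# To inspect the synthesised audio offline (illustrative; assumes the
# optional soundfile package is installed):
#   import soundfile as sf
#   sf.write("out.wav", synthesise("Hola mundo"), tts_model.config.sampling_rate)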


def normalize_audio(audio):
    # Convert integer PCM to float32 and scale the peak to 1.0; guard
    # against silent input to avoid dividing by zero.
    audio = audio.astype(np.float32)
    peak = np.max(np.abs(audio))
    if peak > 0:
        audio /= peak
    return audio
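
# Example: 16-bit PCM spans [-32768, 32767], so
#   normalize_audio(np.array([16384, -32768], dtype=np.int16))
# returns array([ 0.5, -1. ], dtype=float32).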


def preprocess_audio(waveform, sample_rate):
    # Expects waveform in torchaudio's (channels, samples) layout.
    # Downmix multi-channel audio to mono (this also flattens a
    # single-channel (1, samples) tensor to 1-D).
    if waveform.ndim > 1:
        waveform = waveform.mean(dim=0)

    # Whisper requires 16 kHz input.
    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
        waveform = resampler(waveform.unsqueeze(0))
        waveform = waveform.squeeze(0)

    return waveform.numpy()
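
# Example (illustrative): one second of stereo 44.1 kHz audio comes out as
# 16 000 mono samples:
#   preprocess_audio(torch.zeros(2, 44100), 44100).shape  # -> (16000,)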


def speech_to_speech_translation(audio):
    # Gradio passes a filepath for uploads and a (rate, array) tuple for
    # microphone input.
    if isinstance(audio, str):
        waveform, sample_rate = torchaudio.load(audio)  # (channels, samples)
    else:
        sample_rate, waveform = audio
        waveform = torch.from_numpy(normalize_audio(waveform))
        # Gradio delivers stereo as (samples, channels); transpose to the
        # (channels, samples) layout preprocess_audio expects.
        if waveform.ndim > 1:
            waveform = waveform.T

    processed_audio = preprocess_audio(waveform, sample_rate)

    translated_text = translate(processed_audio)

    synthesised_speech = synthesise(translated_text)
    print(f"Synthesised Speech Shape: {synthesised_speech.shape}")

    # synthesise already returns 16-bit PCM at the model's 16 kHz rate.
    return 16000, synthesised_speech
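
# End-to-end example (illustrative; assumes "speech.wav" exists locally):
#   rate, pcm = speech_to_speech_translation("speech.wav")
#   # rate == 16000; pcm is an int16 NumPy array of Spanish speech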


title = "Cascaded STST"
description = """
Demo for cascaded speech-to-speech translation (STST), mapping from source speech in any language to target speech in Spanish. The demo uses OpenAI's Whisper model for transcription and language detection, Meta's M2M100 for text translation, and an MMS (VITS) model for text-to-speech.
"""


demo = gr.Blocks()

mic_translate = gr.Interface(
    fn=speech_to_speech_translation,
    inputs=gr.Audio(sources=["microphone"], type="numpy"),
    outputs=gr.Audio(label="Generated Speech", type="numpy"),
    title=title,
    description=description,
)

file_translate = gr.Interface(
    fn=speech_to_speech_translation,
    inputs=gr.Audio(sources=["upload"], type="filepath"),
    outputs=gr.Audio(label="Generated Speech", type="numpy"),
    title=title,
    description=description,
)

with demo:
    gr.TabbedInterface([mic_translate, file_translate], ["Microphone", "Audio File"])

demo.launch()