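"""Cascaded speech-to-speech translation (STST) demo.

Pipeline: Whisper (speech recognition + language detection) -> M2M100 (text translation
to Spanish) -> MMS-TTS / VITS (Spanish speech synthesis), wrapped in a Gradio UI.
"""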
from transformers import VitsModel, AutoTokenizer, WhisperForConditionalGeneration, WhisperProcessor, M2M100ForConditionalGeneration, M2M100Tokenizer
import torch, torchaudio
import numpy as np
import gradio as gr
from pydantic import BaseModel
from typing import ClassVar
class Config(BaseModel):
    arbitrary_types_allowed: ClassVar[bool] = True
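# Note: the Config class above is not referenced elsewhere in this script; it appears
# intended to let pydantic models accept arbitrary (non-pydantic) types such as torch tensors.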
# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load models and tokenizers
asr_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base").to(device)
asr_processor = WhisperProcessor.from_pretrained("openai/whisper-base")
mt_model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M").to(device)
mt_tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
tts_model = VitsModel.from_pretrained("facebook/mms-tts-spa").to(device)
tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-spa")
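# Note: the MMS-TTS VITS model synthesizes audio at tts_model.config.sampling_rate
# (16 kHz for mms-tts-spa), which is the rate returned to Gradio at the end of the pipeline.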
def transcribe_and_detect_lang(audio):
    # Prepare input features
    input_features = asr_processor(audio, sampling_rate=16000, return_tensors="pt").input_features.to(device)
    # Perform inference (language detection + transcription)
    generated_ids = asr_model.generate(input_features)
    transcription = asr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    # Recover the language token: Whisper emits <|startoftranscript|> followed by a
    # language token such as <|en|>, so the language code sits at index 1.
    tokens = asr_processor.tokenizer.convert_ids_to_tokens(generated_ids[0])
    if len(tokens) > 2:
        detected_lang_token = tokens[1]  # Language token, e.g. "<|en|>"
        detected_lang = detected_lang_token.strip('<|>')
    else:
        detected_lang = "Unknown"  # Fallback if there are not enough tokens
    return transcription, detected_lang
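# The two-letter code returned by transcribe_and_detect_lang (e.g. "en", "fr") matches the
# language codes M2M100's tokenizer expects, so it can be passed directly to translate_text
# as the source language.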
def translate_text(text, source_lang, target_lang="es"):
    """
    Translates text from the source language to the target language using M2M100.
    """
    # Tell the tokenizer which language the input text is in
    mt_tokenizer.src_lang = source_lang
    # Tokenize input text
    encoded_input = mt_tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
    # Generate translated tokens, forcing the target-language token as the first output token
    generated_tokens = mt_model.generate(
        **encoded_input,
        forced_bos_token_id=mt_tokenizer.get_lang_id(target_lang)
    )
    # Decode the tokens to get the translated text
    translated_text = mt_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
    return translated_text
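# Hypothetical usage example:
#   translate_text("Hello, how are you?", source_lang="en")
# would return a Spanish string along the lines of "Hola, ¿cómo estás?".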
def translate(audio):
    # ASR: Transcribe and detect language
    transcription, detected_lang = transcribe_and_detect_lang(audio)
    print(f"Detected language: {detected_lang}")
    print(f"Transcription: {transcription}")
    # MT: Translate transcription to target language (Spanish)
    translated_text = translate_text(transcription, source_lang=detected_lang)
    print(f"Translated Text: {translated_text}")
    return translated_text
def synthesise(text):
    # Tokenize the translated text for the VITS model
    inputs = tts_tokenizer(text, return_tensors="pt").to(device)
    # Run the model to generate the waveform
    with torch.no_grad():
        output = tts_model(**inputs)
    # Inspect the raw model output (debugging aid)
    print(f"TTS Model Output: {output}")
    # The synthesized waveform is exposed as the `waveform` attribute of the VITS output
    speech = output.waveform
    # Convert to 16-bit PCM suitable for audio output
    speech_numpy = (speech.squeeze().cpu().numpy() * 32767).astype(np.int16)
    return speech_numpy
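# Note: VitsModel returns a float waveform in roughly [-1, 1] with shape (batch, num_samples);
# synthesise rescales it to 16-bit PCM so Gradio can play it directly.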
# Normalize audio
def normalize_audio(audio):
    audio = audio.astype(np.float32)
    peak = np.max(np.abs(audio))
    if peak > 0:  # Avoid dividing by zero on silent input
        audio /= peak
    return audio
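# Gradio's microphone input arrives as 16-bit integer PCM, while Whisper's feature extractor
# expects float audio in roughly [-1, 1], hence the peak normalization above.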
# Preprocess the audio by converting to mono and resampling to 16 kHz.
def preprocess_audio(waveform, sample_rate):
    # Convert to mono if it's stereo (expects a channels-first tensor)
    if waveform.ndim > 1 and waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0)
    # Resample to 16 kHz if necessary
    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
        waveform = resampler(waveform.unsqueeze(0))  # Add a channel dimension for the resampler
        waveform = waveform.squeeze(0)  # Remove it again after resampling
    return waveform.numpy()
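# Whisper's feature extractor assumes 16 kHz mono audio, so the downmix and resampling above
# are required before transcription.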
# Main translation function
def speech_to_speech_translation(audio):
    if isinstance(audio, str):  # File path: torchaudio returns a channels-first float tensor
        waveform, sample_rate = torchaudio.load(audio)
    else:  # Microphone input as a (sample_rate, numpy array) tuple
        sample_rate, waveform = audio
        if waveform.ndim > 1:  # Gradio delivers (samples, channels); transpose to channels-first
            waveform = waveform.T
        waveform = torch.from_numpy(normalize_audio(waveform))  # Normalize the audio
    # Preprocess the audio
    processed_audio = preprocess_audio(waveform, sample_rate)
    # Step 1: Translate the input audio to text in the target language
    translated_text = translate(processed_audio)  # Custom translate function
    print(f"Translated Text: {translated_text}")
    # Step 2: Synthesize speech from the translated text
    synthesised_speech = synthesise(translated_text)  # Custom synthesis function
    print(f"Synthesised Speech Shape: {synthesised_speech.shape}")
    # Return synthesized speech in the required format (16 kHz, 16-bit PCM)
    return 16000, synthesised_speech.astype(np.int16)
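# Minimal local sanity check (sketch): run the pipeline on a WAV file without the UI.
# The file name "sample.wav" is only an example; any WAV file readable by torchaudio works.
# from scipy.io import wavfile
# rate, speech = speech_to_speech_translation("sample.wav")
# wavfile.write("translated.wav", rate, speech)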
# Gradio UI Setup
title = "Cascaded STST"
description = """
Demo for cascaded speech-to-speech translation (STST), mapping from source speech in any language to target speech in Spanish. The demo uses OpenAI's Whisper model for speech recognition and language detection, Meta's M2M100 model for text translation, and an MMS-TTS (VITS) model for text-to-speech.
"""
demo = gr.Blocks()
mic_translate = gr.Interface(
fn=speech_to_speech_translation,
inputs=gr.Audio(sources="microphone", type="numpy"),
outputs=gr.Audio(label="Generated Speech", type="numpy"),
title=title,
description=description,
)
file_translate = gr.Interface(
fn=speech_to_speech_translation,
inputs=gr.Audio(sources="upload", type="filepath"),
outputs=gr.Audio(label="Generated Speech", type="numpy"),
title=title,
description=description,
)
with demo:
    gr.TabbedInterface([mic_translate, file_translate], ["Microphone", "Audio File"])
demo.launch()