from transformers import (
    VitsModel,
    AutoTokenizer,
    WhisperForConditionalGeneration,
    WhisperProcessor,
    M2M100ForConditionalGeneration,
    M2M100Tokenizer,
)
import torch
import torchaudio
import numpy as np
import gradio as gr
from pydantic import BaseModel
from typing import ClassVar


class Config(BaseModel):
    arbitrary_types_allowed: ClassVar[bool] = True


# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load models and tokenizers
asr_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base").to(device)
asr_processor = WhisperProcessor.from_pretrained("openai/whisper-base")

mt_model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M").to(device)
mt_tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")

tts_model = VitsModel.from_pretrained("facebook/mms-tts-spa").to(device)
tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-spa")


def transcribe_and_detect_lang(audio):
    # Prepare log-mel input features for Whisper (expects 16 kHz mono audio)
    input_features = asr_processor(
        audio, sampling_rate=16000, return_tensors="pt"
    ).input_features.to(device)

    # Perform inference (language detection + transcription)
    generated_ids = asr_model.generate(input_features)
    transcription = asr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    # Whisper generates "<|startoftranscript|><|lang|><|transcribe|>..." at the start of
    # the sequence, so the language token is the second token (index 1)
    tokens = asr_processor.tokenizer.convert_ids_to_tokens(generated_ids[0])
    if len(tokens) > 2:
        detected_lang_token = tokens[1]  # e.g. "<|en|>"
        detected_lang = detected_lang_token.strip("<|>")  # -> "en"
    else:
        detected_lang = "en"  # Fall back to English if the language token is missing

    return transcription, detected_lang


def translate_text(text, source_lang, target_lang="es"):
    """
    Translate text from the source language to the target language using M2M100.
""" # Explicitly set source language for the tokenizer mt_tokenizer.src_lang = source_lang # Ensure this is set properly to detected_lang # Tokenize input text encoded_input = mt_tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device) # Generate translated tokens generated_tokens = mt_model.generate( **encoded_input, forced_bos_token_id=mt_tokenizer.get_lang_id(target_lang) # Corrected tokenizer usage here ) # Decode the tokens to get the translated text translated_text = mt_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0] return translated_text def translate(audio): # ASR: Transcribe and detect language transcription, detected_lang = transcribe_and_detect_lang(audio) print(f"Detected language: {detected_lang}") print(f"Transcription: {transcription}") # MT: Translate transcription to target language (Spanish) translated_text = translate_text(transcription, source_lang=detected_lang) print(f"Translated Text: {translated_text}") return translated_text def synthesise(text): # Properly tokenize the translated text for the VITS model inputs = tts_tokenizer(text, return_tensors="pt").to(device) # Run the model to generate the waveform with torch.no_grad(): output = tts_model(**inputs) # Check the output and access the waveform print(f"TTS Model Output: {output}") # Access the synthesized waveform from the model output speech = output.audio # The waveform is stored in the 'audio' key # Convert to numpy format suitable for audio output speech_numpy = (speech.squeeze().cpu().numpy() * 32767).astype(np.int16) return speech_numpy # Normalize audio def normalize_audio(audio): audio = audio.astype(np.float32) audio /= np.max(np.abs(audio)) return audio # Preprocess the audio by converting to mono and resampling to 16 kHz. def preprocess_audio(waveform, sample_rate): # Convert to mono if it's stereo if waveform.ndim > 1 and waveform.shape[0] > 1: # Stereo check waveform = waveform.mean(dim=0) # Resample to 16 kHz if necessary if sample_rate != 16000: resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000) waveform = resampler(waveform.unsqueeze(0)) # Ensure it's in the correct shape waveform = waveform.squeeze(0) # Remove extra batch dimension after resampling return waveform.numpy() # Main translation function def speech_to_speech_translation(audio): if isinstance(audio, str): # File path waveform, sample_rate = torchaudio.load(audio) else: # Microphone input as tuple sample_rate, waveform = audio waveform = torch.from_numpy(normalize_audio(waveform)) # Normalize the audio # Preprocess the audio processed_audio = preprocess_audio(waveform, sample_rate) # Step 1: Translate the input audio to text in the target language translated_text = translate(processed_audio) # Custom translate function print(f"Translated Text: {translated_text}") # Step 2: Synthesize speech from the translated text synthesised_speech = synthesise(translated_text) # Custom synthesis function print(f"Synthesised Speech Shape: {synthesised_speech.shape}") # Return synthesized speech in the required format return 16000, synthesised_speech.astype(np.int16) # Gradio UI Setup title = "Cascaded STST" description = """ Demo for cascaded speech-to-speech translation (STST), mapping from source speech in any language to target speech in English. Demo uses OpenAI's Whisper model for speech translation, and a TTS model for text-to-speech. 
""" demo = gr.Blocks() mic_translate = gr.Interface( fn=speech_to_speech_translation, inputs=gr.Audio(sources="microphone", type="numpy"), outputs=gr.Audio(label="Generated Speech", type="numpy"), title=title, description=description, ) file_translate = gr.Interface( fn=speech_to_speech_translation, inputs=gr.Audio(sources="upload", type="filepath"), outputs=gr.Audio(label="Generated Speech", type="numpy"), title=title, description=description, ) with demo: gr.TabbedInterface([mic_translate, file_translate], ["Microphone", "Audio File"]) demo.launch()