"""Gradio app that transcribes an audio file with Wav2Vec2 and summarizes it with BART."""

from typing import Optional, Tuple

import gradio as gr
import torch
import torchaudio  # Used instead of librosa for faster audio processing
from transformers import (
    AutoTokenizer,
    BartForConditionalGeneration,
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
)

# Sample rate the audio is resampled to before being handed to the processor.
TARGET_SAMPLE_RATE = 16000
# Long recordings are transcribed in fixed-size windows to bound memory use.
CHUNK_SECONDS = 30
CHUNK_SIZE = TARGET_SAMPLE_RATE * CHUNK_SECONDS
# Chunks shorter than this are skipped: the Wav2Vec2 convolutional feature
# extractor cannot produce a frame from so few samples and would raise.
# NOTE(review): 400 samples (25 ms) is the receptive field of the base
# architecture -- confirm if the checkpoint is ever swapped.
MIN_CHUNK_SAMPLES = 400

# Load BART tokenizer and model for summarization
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
summarizer = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")

# Load Wav2Vec2 processor and model for transcription
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

# Check if CUDA is available
device = "cuda" if torch.cuda.is_available() else "cpu"

if device == "cpu":
    # Dynamic int8 quantization only has CPU kernels, so it is applied only
    # when inference actually runs on the CPU.
    model = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )
    summarizer = torch.quantization.quantize_dynamic(
        summarizer, {torch.nn.Linear}, dtype=torch.qint8
    )
else:
    # On GPU the models stay in floating point; quantized Linear layers
    # cannot execute on CUDA tensors.
    model.to(device)
    summarizer.to(device)


def _load_mono_waveform(path: str) -> torch.Tensor:
    """Load *path* and return a 1-D mono waveform at TARGET_SAMPLE_RATE."""
    audio, sampling_rate = torchaudio.load(path)
    if sampling_rate != TARGET_SAMPLE_RATE:
        resample_transform = torchaudio.transforms.Resample(
            orig_freq=sampling_rate, new_freq=TARGET_SAMPLE_RATE
        )
        audio = resample_transform(audio)
    # torchaudio returns (channels, frames). Averaging over channels downmixes
    # stereo to mono and is an exact no-op for single-channel input, whereas
    # squeeze() would leave a 2-D tensor for multi-channel files.
    return audio.mean(dim=0)


def _transcribe(waveform: torch.Tensor) -> str:
    """Transcribe a 1-D waveform chunk by chunk and return the joined text."""
    pieces = []
    for start in range(0, waveform.shape[0], CHUNK_SIZE):
        chunk = waveform[start:start + CHUNK_SIZE]
        if chunk.shape[0] < MIN_CHUNK_SAMPLES:
            # Too short for the model to process; carries no usable speech.
            continue
        input_values = processor(
            chunk.numpy(), sampling_rate=TARGET_SAMPLE_RATE, return_tensors="pt"
        ).input_values.to(device)
        with torch.no_grad():
            logits = model(input_values).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        text = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
        if text:
            pieces.append(text)
    return " ".join(pieces).strip()


def _summarize(text: str) -> str:
    """Return a BART summary of *text* (input truncated to 1024 tokens)."""
    encoded = tokenizer(
        text, return_tensors="pt", truncation=True, max_length=1024
    ).to(device)
    result = summarizer.generate(
        encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        min_length=10,
        max_length=1024,
        no_repeat_ngram_size=2,
        encoder_no_repeat_ngram_size=2,
        repetition_penalty=2.0,
        num_beams=2,  # Reduced beams for faster inference
        early_stopping=True,
    )
    return tokenizer.decode(result[0], skip_special_tokens=True).strip()


def transcribe_and_summarize(audioFile: Optional[str]) -> Tuple[str, str]:
    """Transcribe an audio file and summarize the transcript.

    Args:
        audioFile: Path to the audio file, as supplied by the Gradio
            ``Audio(type="filepath")`` input. ``None`` or an empty value
            (nothing uploaded) yields two empty strings.

    Returns:
        A ``(transcription, summary)`` tuple of stripped strings. Both are
        empty when no file was given or no speech could be transcribed.
    """
    if not audioFile:
        return "", ""

    transcription = _transcribe(_load_mono_waveform(audioFile))
    if not transcription:
        # Summarizing an empty transcript would only produce made-up text.
        return "", ""

    return transcription, _summarize(transcription)


# Gradio interface
iface = gr.Interface(
    fn=transcribe_and_summarize,
    inputs=gr.Audio(type="filepath", label="Upload Audio"),
    outputs=[gr.Textbox(label="Transcription"), gr.Textbox(label="Summary")],
    title="Audio Transcription and Summarization",
    description="Transcribe and summarize audio using Wav2Vec2 and BART.",
)

iface.launch()