# Hugging Face Space: audio transcription and summarization demo (Wav2Vec2 + BART).
import gradio as gr
import torch
import torchaudio  # used instead of librosa for faster audio loading/resampling
from transformers import (
    AutoTokenizer,
    BartForConditionalGeneration,
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
)
# Load BART tokenizer and model for summarization.
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
summarizer = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")

# Load Wav2Vec2 processor and model for transcription.
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

# Run on GPU when available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
summarizer.to(device)

# Inference only: disable dropout and other train-time behavior.
model.eval()
summarizer.eval()

# Dynamic int8 quantization of Linear layers speeds up CPU inference, but
# quantized modules are CPU-only — applying them unconditionally would break
# the CUDA path, so quantize only when we stayed on CPU.
if device == "cpu":
    model = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )
    summarizer = torch.quantization.quantize_dynamic(
        summarizer, {torch.nn.Linear}, dtype=torch.qint8
    )
def transcribe_and_summarize(audioFile):
    """Transcribe an audio file with Wav2Vec2, then summarize the text with BART.

    Args:
        audioFile: Path to an audio file in any format torchaudio can decode.

    Returns:
        Tuple of (transcription, summary), both stripped strings. Returns
        ("", "") when no speech was recognized.
    """
    # torchaudio returns a (channels, samples) tensor plus the sample rate.
    audio, sampling_rate = torchaudio.load(audioFile)

    # Wav2Vec2-base-960h expects 16 kHz input — resample if necessary.
    if sampling_rate != 16000:
        resampler = torchaudio.transforms.Resample(
            orig_freq=sampling_rate, new_freq=16000
        )
        audio = resampler(audio)

    # Down-mix to mono by averaging channels. (A bare squeeze() would leave
    # stereo input as a 2-D tensor and break the processor call below.)
    if audio.dim() > 1:
        audio = audio.mean(dim=0)

    # Transcribe in 30-second chunks so long recordings fit in memory.
    # NOTE: words straddling a chunk boundary may be clipped — acceptable
    # trade-off for a demo; proper streaming ASR would need overlap handling.
    chunk_size = 16000 * 30
    pieces = []
    for start in range(0, len(audio), chunk_size):
        chunk = audio[start:start + chunk_size].numpy()
        inputs = processor(
            chunk, sampling_rate=16000, return_tensors="pt"
        ).input_values.to(device)
        with torch.no_grad():
            logits = model(inputs).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        pieces.append(
            processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
        )
    transcription = " ".join(pieces).strip()

    # Nothing recognizable in the audio — skip summarization entirely
    # rather than feeding an empty string to BART.
    if not transcription:
        return "", ""

    # Summarize; truncate to BART's 1024-token encoder limit.
    encoded = tokenizer(
        transcription, return_tensors="pt", truncation=True, max_length=1024
    ).to(device)
    with torch.no_grad():
        result = summarizer.generate(
            encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            min_length=10,
            max_length=1024,
            no_repeat_ngram_size=2,
            encoder_no_repeat_ngram_size=2,
            repetition_penalty=2.0,
            num_beams=2,  # fewer beams for faster inference
            early_stopping=True,
        )
    summary = tokenizer.decode(result[0], skip_special_tokens=True)
    return transcription, summary.strip()
# Gradio UI: one audio-file input, two text outputs (transcript + summary).
iface = gr.Interface(
    transcribe_and_summarize,
    inputs=gr.Audio(type="filepath", label="Upload Audio"),
    outputs=[
        gr.Textbox(label="Transcription"),
        gr.Textbox(label="Summary"),
    ],
    title="Audio Transcription and Summarization",
    description="Transcribe and summarize audio using Wav2Vec2 and BART.",
)
iface.launch()