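"""Gradio Space: transcribe uploaded audio with Wav2Vec2 and summarize the transcript with BART."""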
import gradio as gr
import torch
import torchaudio
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, AutoTokenizer, BartForConditionalGeneration
# Load BART for summarization
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
summarizer = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")

# Load Wav2Vec2 for speech-to-text
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
# Use the GPU if one is available
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
summarizer.to(device)

# Dynamic int8 quantization shrinks the Linear layers and speeds up inference,
# but it is CPU-only in PyTorch, so apply it only when no GPU is in use.
if device == "cpu":
    model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
    summarizer = torch.quantization.quantize_dynamic(summarizer, {torch.nn.Linear}, dtype=torch.qint8)
def transcribe_and_summarize(audio_file):
    # Load the audio and resample to the 16 kHz rate Wav2Vec2 was trained on
    audio, sampling_rate = torchaudio.load(audio_file)
    if sampling_rate != 16000:
        resample_transform = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=16000)
        audio = resample_transform(audio)
    # Average the channels to mono; squeeze() alone would leave stereo as shape (2, N)
    audio = audio.mean(dim=0)

    # Transcribe in 30-second chunks to keep memory bounded on long files
    chunk_size = int(16000 * 30)
    transcription = ""
    for i in range(0, len(audio), chunk_size):
        chunk = audio[i:i + chunk_size]
        if chunk.shape[0] < 400:
            # A tail chunk shorter than Wav2Vec2's conv receptive field
            # (~400 samples, 25 ms) would produce no frames and error out
            continue
        inputs = processor(chunk.numpy(), sampling_rate=16000, return_tensors="pt").input_values.to(device)
        with torch.no_grad():
            logits = model(inputs).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription += processor.batch_decode(predicted_ids, skip_special_tokens=True)[0] + " "
    # Summarize the full transcript (BART's encoder is capped at 1024 tokens)
    inputs = tokenizer(transcription, return_tensors="pt", truncation=True, max_length=1024).to(device)
    result = summarizer.generate(
        inputs["input_ids"],
        min_length=10,
        max_length=1024,
        no_repeat_ngram_size=2,
        encoder_no_repeat_ngram_size=2,
        repetition_penalty=2.0,
        num_beams=2,
        early_stopping=True,
    )
    summary = tokenizer.decode(result[0], skip_special_tokens=True)
    return transcription.strip(), summary.strip()
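
# A minimal direct call outside the Gradio UI ("sample.wav" is a hypothetical
# local file path):
#
#   transcript, summary = transcribe_and_summarize("sample.wav")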
iface = gr.Interface(
    fn=transcribe_and_summarize,
    inputs=gr.Audio(type="filepath", label="Upload Audio"),
    outputs=[gr.Textbox(label="Transcription"), gr.Textbox(label="Summary")],
    title="Audio Transcription and Summarization",
    description="Upload an audio file to transcribe it with Wav2Vec2 and summarize the transcript with BART.",
)

iface.launch()
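
# Assumed requirements.txt for this Space, inferred from the imports above
# (versions unpinned):
#   gradio
#   torch
#   torchaudio
#   transformers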