JuanjoJ55's picture
doc: removed comments
e2374dd verified
import gradio as gr
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, AutoTokenizer, BartForConditionalGeneration
import torch
import torchaudio
# Load BART (abstractive summarization model and its tokenizer).
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
summarizer = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
# Load Wav2Vec2 (CTC speech-recognition model and its processor).
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
# Run on the GPU when one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
summarizer.to(device)
# Dynamic int8 quantization swaps nn.Linear for quantized modules whose
# kernels run on CPU backends only; applying it to CUDA-resident models
# breaks inference on GPU hosts. Quantize only when running on the CPU.
if device == "cpu":
    model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
    summarizer = torch.quantization.quantize_dynamic(summarizer, {torch.nn.Linear}, dtype=torch.qint8)
def transcribe_and_summarize(audioFile):
audio, sampling_rate = torchaudio.load(audioFile)
if sampling_rate != 16000:
resample_transform = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=16000)
audio = resample_transform(audio)
audio = audio.squeeze()
chunk_size = int(16000 * 30)
transcription = ""
for i in range(0, len(audio), chunk_size):
chunk = audio[i:i+chunk_size].numpy()
inputs = processor(chunk, sampling_rate=16000, return_tensors="pt").input_values.to(device)
with torch.no_grad():
logits = model(inputs).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription += processor.batch_decode(predicted_ids, skip_special_tokens=True)[0] + " "
inputs = tokenizer(transcription, return_tensors="pt", truncation=True, max_length=1024).to(device)
result = summarizer.generate(
inputs["input_ids"],
min_length=10,
max_length=1024,
no_repeat_ngram_size=2,
encoder_no_repeat_ngram_size=2,
repetition_penalty=2.0,
num_beams=2,
early_stopping=True,
)
summary = tokenizer.decode(result[0], skip_special_tokens=True)
return transcription.strip(), summary.strip()
# Web UI: one audio upload in, transcription and summary text boxes out.
iface = gr.Interface(
    fn=transcribe_and_summarize,
    inputs=gr.Audio(type="filepath", label="Upload Audio"),
    outputs=[gr.Textbox(label="Transcription"), gr.Textbox(label="Summary")],
    title="Audio Transcription and Summarization",
    description="Transcribe and summarize audio using Audio Summarizer.",
)

# Start the server only when run as a script (e.g. `python app.py`), so that
# importing this module does not block on a running web server.
if __name__ == "__main__":
    iface.launch()