import gradio as gr
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, AutoTokenizer, BartForConditionalGeneration
import torch
import torchaudio

# Load the BART summarization model
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
summarizer = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")

# Load the Wav2Vec2 speech-to-text model
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
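# Note: the first run downloads both checkpoints from the Hugging Face Hub
# (roughly 2 GB combined), so startup needs network access.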

# Check for CUDA
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
summarizer.to(device)

# Dynamic int8 quantization shrinks the Linear layers and speeds up CPU
# inference, but quantized modules cannot run on CUDA, so only apply it on CPU.
if device == "cpu":
    model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
    summarizer = torch.quantization.quantize_dynamic(summarizer, {torch.nn.Linear}, dtype=torch.qint8)


def transcribe_and_summarize(audio_file):
    audio, sampling_rate = torchaudio.load(audio_file)

    # Wav2Vec2 expects 16 kHz input, so resample anything else.
    if sampling_rate != 16000:
        resample_transform = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=16000)
        audio = resample_transform(audio)
    # Downmix to mono: torchaudio returns (channels, samples), and squeeze()
    # alone would leave stereo input 2-D and break the chunking below.
    audio = audio.mean(dim=0)

    chunk_size = int(16000 * 30)  # 30-second chunks at 16 kHz keep memory bounded
    transcription = ""

    for i in range(0, len(audio), chunk_size):
        chunk = audio[i:i+chunk_size].numpy()
        inputs = processor(chunk, sampling_rate=16000, return_tensors="pt").input_values.to(device)

        # Greedy CTC decoding: pick the most likely token per frame; batch_decode
        # collapses repeats and blanks into text.
        with torch.no_grad():
            logits = model(inputs).logits
            predicted_ids = torch.argmax(logits, dim=-1)
            transcription += processor.batch_decode(predicted_ids, skip_special_tokens=True)[0] + " "

    # BART's encoder is capped at 1024 tokens, so very long transcripts are truncated here.
    inputs = tokenizer(transcription, return_tensors="pt", truncation=True, max_length=1024).to(device)
    
    # Beam search with n-gram blocking and a repetition penalty to curb repetitive output.
    result = summarizer.generate(
        inputs["input_ids"],
        min_length=10,
        max_length=1024,
        no_repeat_ngram_size=2,
        encoder_no_repeat_ngram_size=2,
        repetition_penalty=2.0,
        num_beams=2, 
        early_stopping=True,
    )
    summary = tokenizer.decode(result[0], skip_special_tokens=True)

    return transcription.strip(), summary.strip()
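
# Quick check without the Gradio UI (hypothetical file name, for illustration):
# text, summary = transcribe_and_summarize("sample.wav")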

iface = gr.Interface(
    fn=transcribe_and_summarize,
    inputs=gr.Audio(type="filepath", label="Upload Audio"),
    outputs=[gr.Textbox(label="Transcription"), gr.Textbox(label="Summary")],
    title="Audio Transcription and Summarization",
    description="Upload an audio file to get a Wav2Vec2 transcription and a BART-generated summary.",
)

iface.launch()
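
# Tip: iface.launch(share=True) also serves a temporary public URL, which is
# handy for testing from another machine.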