asr_arena / app.py
jasspier's picture
Update app.py
ca625d0 verified
raw
history blame
No virus
2.23 kB
import os

import gradio as gr
import torch
import torchaudio
from torchaudio.transforms import Resample

from data2vec2 import Data2VecMultiModel, Data2VecMultiConfig, Modality
# Remote location of the fine-tuned TeleSpeech-ASR checkpoint, and the local
# file it is cached in.
model_path = "https://huggingface.co/Tele-AI/TeleSpeech-ASR1.0/resolve/main/finetune_large_kespeech.pt"
local_checkpoint = 'large.pt'

# Download the checkpoint only when it is not already on disk, so restarts do
# not re-fetch the (large) file. torch.hub.download_url_to_file writes to a
# temporary file and moves it into place once complete, so an existing file
# should be a finished download rather than a partial one.
if os.path.exists(local_checkpoint):
    print(f"Using cached model file {local_checkpoint!r}.")
else:
    print("Downloading model file...")
    torch.hub.download_url_to_file(model_path, local_checkpoint)
    print("Model file downloaded.")

# Build the model from the default Data2VecMulti configuration, audio only.
# NOTE(review): the default config may not match the architecture of the
# large fine-tuned checkpoint — confirm against the checkpoint's saved cfg.
config = Data2VecMultiConfig()
model = Data2VecMultiModel(config, modalities=[Modality.AUDIO])

print("Loading model checkpoint...")
# NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
# NOTE(review): on newer torch releases the default weights_only=True may
# reject checkpoints that pickle config objects — confirm with the installed
# torch version before changing this call.
checkpoint = torch.load(local_checkpoint, map_location=torch.device('cpu'))
print("Checkpoint keys:", checkpoint.keys())

# If the weights are wrapped under a 'model' key, unwrap them; otherwise treat
# the whole loaded object as the state_dict.
if 'model' in checkpoint:
    state_dict = checkpoint['model']
    print("Model state_dict keys:", state_dict.keys())
else:
    print("Key 'model' not found in checkpoint.")
    state_dict = checkpoint

# Load the weights. A failure is only printed, not raised, so the app still
# starts — but the model may then be partially or entirely un-loaded and its
# transcriptions meaningless.
try:
    model.load_state_dict(state_dict)
    print("Model state_dict loaded successfully.")
except Exception as e:
    print("Error loading model state_dict:", str(e))
model.eval()
# Gradio callback: turn one audio file into text.
def transcribe(audio):
    """Transcribe a single audio file with the module-level ``model``.

    Args:
        audio: Path to the audio file, as delivered by
            ``gr.Audio(type="filepath")``; ``None`` when the user submits
            without uploading or recording anything.

    Returns:
        The decoded transcription string, or ``""`` if no audio was given.
    """
    if audio is None:
        # An empty Audio input arrives as None; torchaudio.load(None) would
        # raise instead of producing a result.
        return ""
    print("Transcribing audio...")
    # torchaudio.load returns a (channels, frames) tensor by default.
    waveform, sample_rate = torchaudio.load(audio)
    # Down-mix to mono so stereo input also yields a 1-D signal; squeeze()
    # alone would leave (2, frames) and produce a 3-D batch below.
    waveform = waveform.mean(dim=0)
    target_rate = 16000
    if sample_rate != target_rate:
        waveform = Resample(orig_freq=sample_rate, new_freq=target_rate)(waveform)
    # Add a batch dimension: (1, frames).
    input_values = waveform.unsqueeze(0)
    with torch.no_grad():
        outputs = model.extract_features(input_values, mode='AUDIO')
    # NOTE(review): outputs["x"] looks like encoder features, presumably
    # (batch, time, dim), not vocabulary logits. argmax over the last axis
    # mapped through chr() does not correspond to any tokenizer visible here;
    # a CTC head plus token dictionary is likely required — TODO confirm.
    logits = outputs["x"]
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = ''.join(chr(i) for i in predicted_ids[0].tolist())
    print("Transcription:", transcription)
    return transcription
# Gradio UI: a single audio widget (handed to `transcribe` as a file path)
# feeding a plain-text output box.
iface = gr.Interface(
    transcribe,
    gr.Audio(type="filepath"),
    "text",
    title="TeleSpeech ASR",
    description=(
        "Upload an audio file or record your voice to transcribe speech "
        "to text using the TeleSpeech ASR model."
    ),
)

# Start the Gradio web server.
print("Launching Gradio interface...")
iface.launch()