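# asr_arena / app.py
# Source page metadata (copied from the Hugging Face file viewer):
#   author: jasspier · commit message: "Update app.py" · commit 31564d0 (verified)
#   size: 2.37 kB · scan result: No virus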
import os

import gradio as gr
import torch
import torchaudio
from torchaudio.transforms import Resample
# Hypothetical ASR model structure (a placeholder architecture; it is not
# guaranteed to match the layout of the downloaded checkpoint).
class ASRModel(torch.nn.Module):
    """Placeholder acoustic model: stacked LSTM followed by a linear classifier.

    Batched input is laid out as (batch, seq, 160), because the LSTM is built
    with ``input_size=160`` and ``batch_first=True``. The linear layer maps
    each step's 256-d hidden state to 29 class logits, so batched output is
    (batch, seq, 29).
    """

    def __init__(self):
        super().__init__()
        # Assumed architecture: a simple 3-layer LSTM.
        self.lstm = torch.nn.LSTM(
            input_size=160,
            hidden_size=256,
            num_layers=3,
            batch_first=True,
        )
        # Assumed 29 output classes, one per character symbol.
        self.linear = torch.nn.Linear(256, 29)

    def forward(self, x):
        """Return per-step class logits for the feature sequence *x*."""
        sequence_out = self.lstm(x)[0]  # drop the (h_n, c_n) state tuple
        return self.linear(sequence_out)
# Location of the TeleSpeech checkpoint on the Hugging Face Hub.
model_path = "https://huggingface.co/Tele-AI/TeleSpeech-ASR1.0/resolve/main/base.pt"
# Save the checkpoint under the URL's own file name ("base.pt"). It used to be
# saved as "large.pt", which misdescribed the file actually downloaded.
checkpoint_file = os.path.basename(model_path)

# Download the checkpoint only when it is not already on disk, so restarting
# the app does not re-fetch it every time.
if os.path.exists(checkpoint_file):
    print(f"Using cached model file {checkpoint_file}.")
else:
    print("Downloading model file...")
    torch.hub.download_url_to_file(model_path, checkpoint_file)
    print("Model file downloaded.")

# Initialise the (placeholder) model.
model = ASRModel()

# Load the checkpoint onto the CPU.
print("Loading model checkpoint...")
# SECURITY NOTE(review): torch.load unpickles arbitrary Python objects; only
# load checkpoints from a trusted source. weights_only=True would be safer but
# may reject checkpoints that store non-tensor metadata — confirm before
# enabling.
checkpoint = torch.load(checkpoint_file, map_location=torch.device("cpu"))

# Pick the parameter dict: prefer checkpoint["model"], otherwise treat the
# whole checkpoint as the state_dict. Non-dict checkpoints used to crash on
# .keys(); they are now passed through and reported instead.
if isinstance(checkpoint, dict):
    print("Checkpoint keys:", checkpoint.keys())
    if "model" in checkpoint:
        state_dict = checkpoint["model"]
        print("Model state_dict keys:", state_dict.keys())
    else:
        print("Key 'model' not found in checkpoint.")
        state_dict = checkpoint
else:
    print("Checkpoint is not a dict; got", type(checkpoint).__name__)
    state_dict = checkpoint

# Load the state_dict. A failure is deliberately non-fatal so the UI can still
# start, but it is now reported loudly: the model would otherwise run with
# random weights and silently produce meaningless output.
try:
    model.load_state_dict(state_dict)
except Exception as e:
    print("Error loading model state_dict:", str(e))
    print(
        "WARNING: continuing with randomly initialised weights; "
        "transcriptions will be meaningless."
    )
else:
    print("Model state_dict loaded successfully.")
model.eval()
# NOTE(review): assumed CTC label set for the 29 output classes of
# ASRModel.linear: index 0 is the CTC blank, followed by space, apostrophe
# and the lowercase letters a-z. The real TeleSpeech vocabulary is not
# visible here — TODO confirm. (The previous decoder used chr(i), which maps
# ids 0-28 to ASCII control characters and can never yield readable text.)
CTC_BLANK_ID = 0
VOCAB = ["", " ", "'"] + [chr(c) for c in range(ord("a"), ord("z") + 1)]


def _frame_waveform(waveform, frame_size):
    """Split a 1-D waveform into a (1, num_frames, frame_size) batch.

    The tail is zero-padded up to a whole frame so no samples are dropped.
    """
    remainder = waveform.numel() % frame_size
    if remainder:
        waveform = torch.nn.functional.pad(waveform, (0, frame_size - remainder))
    return waveform.reshape(1, -1, frame_size)


def _ctc_greedy_decode(ids, vocab=VOCAB, blank=CTC_BLANK_ID):
    """Collapse consecutive repeats, drop blanks, and map ids through *vocab*."""
    chars = []
    previous = None
    for idx in ids:
        if idx != previous and idx != blank:
            chars.append(vocab[idx])
        previous = idx
    return "".join(chars)


def transcribe(audio):
    """Transcribe the audio file at path *audio* with the global ``model``.

    Args:
        audio: Path to an audio file, as delivered by
            ``gr.Audio(type="filepath")``, or None when no audio was supplied.

    Returns:
        The decoded transcription string, or "" when there is no audio.
    """
    if audio is None:
        return ""
    print("Transcribing audio...")
    waveform, sample_rate = torchaudio.load(audio)
    # torchaudio.load returns (channels, samples). Average the channels to
    # mono; squeeze() left stereo input 2-D and broke the model input shape.
    waveform = waveform.mean(dim=0)
    if waveform.numel() == 0:
        return ""
    # Resample to 16 kHz.
    waveform = Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    # The LSTM is batch_first and consumes input_size features per step, so
    # cut the waveform into frames of exactly that many samples. Feeding the
    # raw (1, T) waveform fails unless T happens to equal input_size.
    frames = _frame_waveform(waveform, model.lstm.input_size)
    with torch.no_grad():
        logits = model(frames)
    predicted_ids = torch.argmax(logits, dim=-1)
    # Decode the predicted ids into characters.
    transcription = _ctc_greedy_decode(predicted_ids[0].tolist())
    print("Transcription:", transcription)
    return transcription
# Build the Gradio UI: one audio input, handed to transcribe() as a file
# path, and a plain-text output box.
iface = gr.Interface(
    fn=transcribe,
    inputs=gr.Audio(type="filepath"),
    outputs="text",
    title="TeleSpeech ASR",
    description=(
        "Upload an audio file or record your voice to transcribe speech "
        "to text using the TeleSpeech ASR model."
    ),
)

# Start the Gradio web server.
print("Launching Gradio interface...")
iface.launch()