Spaces:
Runtime error
Runtime error
File size: 2,373 Bytes
1869d7a 71b0e8e 31564d0 1869d7a 31564d0 1869d7a e2bcfc6 71b0e8e 31564d0 eb24c35 31564d0 1869d7a 31564d0 f8ebe93 31564d0 e2bcfc6 71b0e8e 1869d7a f8ebe93 31564d0 1869d7a e2bcfc6 1869d7a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import os

import gradio as gr
import torch
import torchaudio
from torchaudio.transforms import Resample
# Define a hypothetical ASR model architecture
class ASRModel(torch.nn.Module):
    """Placeholder acoustic model: a stacked LSTM followed by a linear layer.

    The defaults reproduce the originally hard-coded architecture, so
    ``ASRModel()`` builds the same network (and the same ``state_dict`` keys)
    as before.

    Args:
        input_size: Number of features in each input frame.
        hidden_size: Width of every LSTM layer.
        num_layers: Number of stacked LSTM layers.
        num_classes: Number of output classes per frame (assumed to be
            characters).
    """

    def __init__(
        self,
        input_size: int = 160,
        hidden_size: int = 256,
        num_layers: int = 3,
        num_classes: int = 29,
    ) -> None:
        super().__init__()
        # The architecture is assumed to be a plain LSTM; batch_first=True
        # means batched inputs are laid out as (batch, time, input_size).
        self.lstm = torch.nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        # Projects every LSTM output frame onto the output classes.
        self.linear = torch.nn.Linear(hidden_size, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-frame class logits for ``x``.

        Args:
            x: Input frames whose last dimension equals ``input_size``.

        Returns:
            A tensor shaped like ``x`` except that the last dimension is
            ``num_classes``.
        """
        # The LSTM's final (h, c) state is not needed, only the frame outputs.
        frame_outputs, _ = self.lstm(x)
        return self.linear(frame_outputs)
# URL of the pretrained checkpoint
model_path = "https://huggingface.co/Tele-AI/TeleSpeech-ASR1.0/resolve/main/base.pt"
# Local file the checkpoint is stored in. Note that it is called 'large.pt'
# even though the URL above points at base.pt.
checkpoint_file = 'large.pt'
# Download the checkpoint only when no local copy exists, so that a restart
# does not fetch the whole file again.
if os.path.isfile(checkpoint_file):
    print("Model file already present, skipping download.")
else:
    print("Downloading model file...")
    torch.hub.download_url_to_file(model_path, checkpoint_file)
    print("Model file downloaded.")
# Instantiate the model
model = ASRModel()
# Load the checkpoint
print("Loading model checkpoint...")
# SECURITY: torch.load unpickles the file, and unpickling can execute
# arbitrary code. Only point model_path at a source you trust.
# NOTE(review): consider weights_only=True if the checkpoint holds nothing
# but tensors and plain containers -- confirm against the actual file.
checkpoint = torch.load(checkpoint_file, map_location=torch.device('cpu'))
print("Checkpoint keys:", checkpoint.keys())
# Print the keys of the model parameters; fall back to treating the whole
# checkpoint as the state dict when it has no 'model' entry.
if 'model' in checkpoint:
    state_dict = checkpoint['model']
    print("Model state_dict keys:", state_dict.keys())
else:
    print("Key 'model' not found in checkpoint.")
    state_dict = checkpoint
# Load the state dict into the model. Failure is reported but deliberately
# not fatal, so the app still starts.
try:
    model.load_state_dict(state_dict)
except Exception as e:
    print("Error loading model state_dict:", str(e))
    print("Continuing without a fully loaded checkpoint; transcriptions may be meaningless.")
else:
    print("Model state_dict loaded successfully.")
# Switch to inference mode
model.eval()
# Define the inference handler
def transcribe(audio):
    """Transcribe an audio file with the module-level ``model``.

    The recording is mixed down to one channel, resampled to 16 kHz and cut
    into frames whose length equals the LSTM's ``input_size``. That produces
    the ``(batch, time, features)`` layout a ``batch_first`` LSTM expects;
    passing the raw ``(1, samples)`` waveform instead is read by the LSTM as
    a single time step with ``samples`` features and fails.

    Args:
        audio: Path to an audio file, or ``None`` when nothing was submitted.

    Returns:
        The decoded string, or an empty string when there is no audio.
    """
    print("Transcribing audio...")
    if audio is None:
        # Nothing was uploaded or recorded.
        return ""

    waveform, sample_rate = torchaudio.load(audio)
    # Average over the channel dimension so multi-channel input becomes a
    # single stream instead of adding an extra tensor dimension.
    mono = waveform.mean(dim=0, keepdim=True)
    resample = Resample(orig_freq=sample_rate, new_freq=16000)
    samples = resample(mono).reshape(-1)
    if samples.numel() == 0:
        # An RNN cannot process a zero-length sequence.
        return ""

    # Zero-pad the tail so the samples split evenly into whole frames.
    frame_size = model.lstm.input_size
    padding = (-samples.numel()) % frame_size
    if padding:
        samples = torch.nn.functional.pad(samples, (0, padding))
    input_values = samples.reshape(1, -1, frame_size)

    with torch.no_grad():
        logits = model(input_values)
    predicted_ids = torch.argmax(logits, dim=-1)
    # Decode the predictions into characters.
    # NOTE(review): chr() turns each class index directly into a Unicode code
    # point, so small indices become control characters rather than readable
    # text. A real vocabulary lookup is needed once the label set is known.
    transcription = ''.join(chr(i) for i in predicted_ids[0].tolist())
    print("Transcription:", transcription)
    return transcription
# Build the Gradio interface: one audio input handed to transcribe() as a
# file path, and one text output.
iface = gr.Interface(
    transcribe,
    gr.Audio(type="filepath"),
    "text",
    title="TeleSpeech ASR",
    description=(
        "Upload an audio file or record your voice to transcribe "
        "speech to text using the TeleSpeech ASR model."
    ),
)
# Start serving the UI.
print("Launching Gradio interface...")
iface.launch()
|