"""Gradio demo that transcribes recorded speech with a TeleSpeech checkpoint.

The script downloads the checkpoint (once), loads it into a simplified LSTM
model, and serves a microphone-to-text interface.
"""

import os

import gradio as gr
import torch
import torchaudio
from torchaudio.transforms import Resample

# Sample rate every clip is resampled to before inference.
TARGET_SAMPLE_RATE = 16000

# Local file name the checkpoint is stored under.
LOCAL_MODEL_FILE = 'large.pt'


# A simplified model class (assumes the model is an LSTM architecture).
class ASRModel(torch.nn.Module):
    """Three-layer LSTM followed by a linear projection to 29 classes."""

    def __init__(self):
        super().__init__()
        self.lstm = torch.nn.LSTM(
            input_size=160,
            hidden_size=256,
            num_layers=3,
            batch_first=True,
        )
        # Assumes 29 output classes used for characters.
        self.linear = torch.nn.Linear(256, 29)

    def forward(self, x):
        """Return per-frame logits for a (batch, frames, 160) input."""
        x, _ = self.lstm(x)
        x = self.linear(x)
        return x


# Model location.
model_path = "https://huggingface.co/Tele-AI/TeleSpeech-ASR1.0/resolve/main/large.pt"

# Download the model file only when it is not already present, so restarting
# the app does not fetch the checkpoint again.
if not os.path.exists(LOCAL_MODEL_FILE):
    torch.hub.download_url_to_file(model_path, LOCAL_MODEL_FILE)

# Initialise the model.
model = ASRModel()

# Load the model parameters.
# NOTE(security): torch.load unpickles the file; only point model_path at a
# source you trust.
checkpoint = torch.load(LOCAL_MODEL_FILE, map_location=torch.device('cpu'))
# Assumes the weights are stored under the 'model' key; fall back to treating
# the checkpoint itself as the state dict when that key is absent.
if isinstance(checkpoint, dict) and 'model' in checkpoint:
    state_dict = checkpoint['model']
else:
    state_dict = checkpoint
# NOTE(review): ASRModel is a placeholder architecture; load_state_dict raises
# if the checkpoint's parameter names/shapes do not match it - confirm.
model.load_state_dict(state_dict)
model.eval()


def transcribe(audio):
    """Transcribe the audio file at path *audio* and return the text.

    Args:
        audio: Path to an audio file, or ``None`` when nothing was recorded.

    Returns:
        The decoded string; empty when there is no audio to process.
    """
    if audio is None:
        return ""

    waveform, sample_rate = torchaudio.load(audio)

    # Mix multi-channel recordings down to one channel so that channels are
    # not mistaken for time steps by the LSTM.
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    if sample_rate != TARGET_SAMPLE_RATE:
        resample = Resample(orig_freq=sample_rate, new_freq=TARGET_SAMPLE_RATE)
        waveform = resample(waveform)

    samples = waveform.reshape(-1)
    if samples.numel() == 0:
        return ""

    # The LSTM requires the last dimension to equal its input_size, so cut
    # the waveform into fixed-size frames, zero-padding the final one.
    # NOTE(review): raw-sample framing is an assumption that makes the input
    # shape-consistent with ASRModel; the real model may expect other
    # features - confirm.
    frame_size = model.lstm.input_size
    remainder = samples.numel() % frame_size
    if remainder:
        samples = torch.nn.functional.pad(samples, (0, frame_size - remainder))
    input_values = samples.reshape(1, -1, frame_size)

    with torch.no_grad():
        logits = model(input_values)
    predicted_ids = torch.argmax(logits, dim=-1)

    # Decode predictions to characters.
    # NOTE(review): chr() on class ids 0..28 yields control characters; the
    # model's real token table should replace this mapping.
    transcription = ''.join(chr(i) for i in predicted_ids[0].tolist())
    return transcription


def _microphone_input():
    """Build the microphone audio input across Gradio API versions."""
    try:
        return gr.Audio(source="microphone", type="filepath")
    except TypeError:
        # Newer Gradio releases renamed the keyword to the plural `sources`.
        return gr.Audio(sources=["microphone"], type="filepath")


# Create the Gradio interface.
iface = gr.Interface(
    fn=transcribe,
    inputs=_microphone_input(),
    outputs="text",
    title="TeleSpeech ASR",
    description="Upload an audio file or record your voice to transcribe speech to text using the TeleSpeech ASR model.",
)

if __name__ == "__main__":
    iface.launch()