import itertools
from pathlib import Path

import gradio as gr
import torch
import torchaudio
from torchaudio.transforms import Resample

# Sample rate (Hz) that input audio is resampled to before framing.
TARGET_SAMPLE_RATE = 16000
# Number of raw samples grouped into one model input frame; must equal the
# LSTM input_size in ASRModel.
FRAME_SIZE = 160

# NOTE(review): placeholder output vocabulary. The original decoded class ids
# with chr(i), which maps ids 0-28 to ASCII control characters, so its output
# was never readable text. This assumes a CTC-style layout (blank, space,
# apostrophe, a-z = 29 symbols) -- replace it with the checkpoint's real
# dictionary before trusting any transcription.
BLANK_ID = 0
VOCAB = ["<blank>", " ", "'"] + [chr(c) for c in range(ord("a"), ord("z") + 1)]
NUM_CLASSES = len(VOCAB)  # 29, the class count the original hard-coded


class ASRModel(torch.nn.Module):
    """Hypothetical ASR model: a 3-layer LSTM over raw-sample frames plus a
    per-frame linear classifier.

    NOTE(review): the original labels this architecture as hypothetical; its
    parameter names are unlikely to match the downloaded TeleSpeech checkpoint,
    so loading that checkpoint into it is expected to fail -- confirm.
    """

    def __init__(self):
        super().__init__()
        self.lstm = torch.nn.LSTM(
            input_size=FRAME_SIZE, hidden_size=256, num_layers=3, batch_first=True
        )
        self.linear = torch.nn.Linear(256, NUM_CLASSES)

    def forward(self, x):
        """Map (batch, seq_len, FRAME_SIZE) input to (batch, seq_len, NUM_CLASSES) logits."""
        x, _ = self.lstm(x)
        return self.linear(x)


# Remote location of the model checkpoint.
model_path = "https://huggingface.co/Tele-AI/TeleSpeech-ASR1.0/resolve/main/finetune_large_kespeech.pt"
# Local file the checkpoint is downloaded to.
CHECKPOINT_FILE = "large.pt"


def _download_checkpoint(url, dest):
    """Download ``url`` to ``dest`` unless ``dest`` already exists; return ``dest``.

    The original re-downloaded the (large) checkpoint on every start.
    """
    if Path(dest).exists():
        print(f"Using existing model file {dest}.")
        return dest
    print("Downloading model file...")
    torch.hub.download_url_to_file(url, dest)
    print("Model file downloaded.")
    return dest


def _load_model(checkpoint_file):
    """Build an ASRModel, try to load weights from ``checkpoint_file``, and
    return it in eval mode.

    A state_dict that does not fit the model is reported rather than raised so
    the app still starts -- but then the model has randomly initialised weights.
    """
    model = ASRModel()
    print("Loading model checkpoint...")
    # SECURITY NOTE: torch.load unpickles the file, which can execute arbitrary
    # code; only load checkpoints from a trusted source. On PyTorch >= 2.6 the
    # default weights_only=True may reject checkpoints that pickle non-tensor
    # objects -- passing weights_only=False is only safe for trusted files.
    checkpoint = torch.load(checkpoint_file, map_location=torch.device("cpu"))
    is_dict = isinstance(checkpoint, dict)
    if is_dict:
        print("Checkpoint keys:", checkpoint.keys())

    if is_dict and "model" in checkpoint:
        state_dict = checkpoint["model"]
        print("Model state_dict keys:", state_dict.keys())
    else:
        print("Key 'model' not found in checkpoint.")
        state_dict = checkpoint

    try:
        model.load_state_dict(state_dict)
    except (RuntimeError, TypeError) as e:
        # RuntimeError: missing/unexpected keys or shape mismatches;
        # TypeError: the checkpoint is not a mapping at all.
        print("Error loading model state_dict:", str(e))
        print(
            "WARNING: continuing with randomly initialised weights; "
            "transcriptions will be meaningless."
        )
    else:
        print("Model state_dict loaded successfully.")

    model.eval()
    return model


model = _load_model(_download_checkpoint(model_path, CHECKPOINT_FILE))


def _greedy_ctc_decode(ids):
    """Collapse consecutive repeated ids, drop blanks, and map the rest through VOCAB.

    NOTE(review): assumes CTC-style per-frame outputs with BLANK_ID as the
    blank symbol -- confirm against the real model.
    """
    return "".join(VOCAB[i] for i, _ in itertools.groupby(ids) if i != BLANK_ID)


def transcribe(audio):
    """Transcribe an audio file with the ASR model.

    Args:
        audio: Path to an audio file (Gradio ``type="filepath"``), or None when
            the user submits without uploading or recording anything.

    Returns:
        The decoded transcription ("" when no audio, or an empty clip, is given).
    """
    if audio is None:
        return ""
    print("Transcribing audio...")
    waveform, sample_rate = torchaudio.load(audio)  # (channels, samples)

    # Down-mix to mono. The original used squeeze(), which leaves multi-channel
    # audio 2-D and breaks the frame reshape below.
    waveform = waveform.mean(dim=0)
    if sample_rate != TARGET_SAMPLE_RATE:
        waveform = Resample(orig_freq=sample_rate, new_freq=TARGET_SAMPLE_RATE)(waveform)
    if waveform.numel() == 0:
        return ""

    # Zero-pad so the sample count is a whole number of frames.
    remainder = waveform.size(0) % FRAME_SIZE
    if remainder:
        waveform = torch.nn.functional.pad(waveform, (0, FRAME_SIZE - remainder))
    # Shape (batch_size=1, seq_len=num_frames, input_size=FRAME_SIZE).
    input_values = waveform.view(-1, FRAME_SIZE).unsqueeze(0)

    with torch.no_grad():
        logits = model(input_values)
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = _greedy_ctc_decode(predicted_ids[0].tolist())
    print("Transcription:", transcription)
    return transcription


# Gradio UI: upload or record audio, get text back.
iface = gr.Interface(
    fn=transcribe,
    inputs=gr.Audio(type="filepath"),
    outputs="text",
    title="TeleSpeech ASR",
    description="Upload an audio file or record your voice to transcribe speech to text using the TeleSpeech ASR model.",
)

if __name__ == "__main__":
    print("Launching Gradio interface...")
    iface.launch()