jasspier commited on
Commit
e2bcfc6
1 Parent(s): 12ff2bb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -10
app.py CHANGED
@@ -3,35 +3,55 @@ import torch
3
  import torchaudio
4
  from torchaudio.transforms import Resample
5
 
6
- # 定义一个简化的模型类(假设模型是LSTM架构)
7
- class ASRModel(torch.nn.Module):
8
  def __init__(self):
9
- super(ASRModel, self).__init__()
10
- self.lstm = torch.nn.LSTM(input_size=160, hidden_size=256, num_layers=3, batch_first=True)
11
- self.linear = torch.nn.Linear(256, 29) # 假设29个输出类用于字符
 
12
 
13
  def forward(self, x):
14
- x, _ = self.lstm(x)
15
- x = self.linear(x)
16
  return x
17
 
18
  # 定义模型路径
19
  model_path = "https://huggingface.co/Tele-AI/TeleSpeech-ASR1.0/resolve/main/large.pt"
20
 
21
  # 下载模型文件
 
22
  torch.hub.download_url_to_file(model_path, 'large.pt')
 
23
 
24
  # 初始化模型
25
- model = ASRModel()
26
 
27
  # 加载模型参数
 
28
  checkpoint = torch.load('large.pt', map_location=torch.device('cpu'))
29
- state_dict = checkpoint['model'] # 假设模型权重保存在 'model' 键中
30
- model.load_state_dict(state_dict)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  model.eval()
32
 
33
  # 定义处理函数
34
  def transcribe(audio):
 
35
  waveform, sample_rate = torchaudio.load(audio)
36
  resample = Resample(orig_freq=sample_rate, new_freq=16000)
37
  waveform = resample(waveform)
@@ -41,6 +61,7 @@ def transcribe(audio):
41
  logits = model(input_values)
42
  predicted_ids = torch.argmax(logits, dim=-1)
43
  transcription = ''.join([chr(i) for i in predicted_ids[0].tolist()]) # 解码预测到字符
 
44
  return transcription
45
 
46
  # 创建 Gradio 界面
@@ -52,4 +73,5 @@ iface = gr.Interface(
52
  description="Upload an audio file or record your voice to transcribe speech to text using the TeleSpeech ASR model."
53
  )
54
 
 
55
  iface.launch()
 
3
  import torchaudio
4
  from torchaudio.transforms import Resample
5
 
6
+ # 使用一个假设的 Transformer ASR 模型结构
7
+ class TransformerASRModel(torch.nn.Module):
8
  def __init__(self):
9
+ super(TransformerASRModel, self).__init__()
10
+ # 定义模型架构,这里需要根据实际情况进行调整
11
+ self.encoder = torch.nn.TransformerEncoderLayer(d_model=512, nhead=8)
12
+ self.decoder = torch.nn.Linear(512, 29) # 假设29个输出类用于字符
13
 
14
  def forward(self, x):
15
+ x = self.encoder(x)
16
+ x = self.decoder(x)
17
  return x
18
 
19
  # 定义模型路径
20
  model_path = "https://huggingface.co/Tele-AI/TeleSpeech-ASR1.0/resolve/main/large.pt"
21
 
22
  # 下载模型文件
23
+ print("Downloading model file...")
24
  torch.hub.download_url_to_file(model_path, 'large.pt')
25
+ print("Model file downloaded.")
26
 
27
  # 初始化模型
28
+ model = TransformerASRModel()
29
 
30
  # 加载模型参数
31
+ print("Loading model checkpoint...")
32
  checkpoint = torch.load('large.pt', map_location=torch.device('cpu'))
33
+ print("Checkpoint keys:", checkpoint.keys())
34
+
35
+ # 打印模型参数中的键
36
+ if 'model' in checkpoint:
37
+ state_dict = checkpoint['model']
38
+ print("Model state_dict keys:", state_dict.keys())
39
+ else:
40
+ print("Key 'model' not found in checkpoint.")
41
+ state_dict = checkpoint
42
+
43
+ # 加载模型状态字典
44
+ try:
45
+ model.load_state_dict(state_dict)
46
+ print("Model state_dict loaded successfully.")
47
+ except Exception as e:
48
+ print("Error loading model state_dict:", str(e))
49
+
50
  model.eval()
51
 
52
  # 定义处理函数
53
  def transcribe(audio):
54
+ print("Transcribing audio...")
55
  waveform, sample_rate = torchaudio.load(audio)
56
  resample = Resample(orig_freq=sample_rate, new_freq=16000)
57
  waveform = resample(waveform)
 
61
  logits = model(input_values)
62
  predicted_ids = torch.argmax(logits, dim=-1)
63
  transcription = ''.join([chr(i) for i in predicted_ids[0].tolist()]) # 解码预测到字符
64
+ print("Transcription:", transcription)
65
  return transcription
66
 
67
  # 创建 Gradio 界面
 
73
  description="Upload an audio file or record your voice to transcribe speech to text using the TeleSpeech ASR model."
74
  )
75
 
76
+ print("Launching Gradio interface...")
77
  iface.launch()