jasspier commited on
Commit
f7e26e2
1 Parent(s): 746e04b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -2
app.py CHANGED
@@ -3,14 +3,29 @@ import torch
3
  import torchaudio
4
  from torchaudio.transforms import Resample
5
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  # 定义模型路径
7
  model_path = "https://huggingface.co/Tele-AI/TeleSpeech-ASR1.0/resolve/main/large.pt"
8
 
9
  # 下载模型文件
10
  torch.hub.download_url_to_file(model_path, 'large.pt')
11
 
 
 
 
12
  # 加载模型参数
13
- model = torch.load('large.pt', map_location=torch.device('cpu'))
14
  model.eval()
15
 
16
  # 定义处理函数
@@ -23,7 +38,7 @@ def transcribe(audio):
23
  with torch.no_grad():
24
  logits = model(input_values)
25
  predicted_ids = torch.argmax(logits, dim=-1)
26
- transcription = tokenizer.decode(predicted_ids[0])
27
  return transcription
28
 
29
  # 创建 Gradio 界面
 
3
  import torchaudio
4
  from torchaudio.transforms import Resample
5
 
6
+ # 定义一个简化的模型类(假设模型是LSTM架构)
7
+ class ASRModel(torch.nn.Module):
8
+ def __init__(self):
9
+ super(ASRModel, self).__init__()
10
+ self.lstm = torch.nn.LSTM(input_size=160, hidden_size=256, num_layers=3, batch_first=True)
11
+ self.linear = torch.nn.Linear(256, 29) # 假设29个输出类用于字符
12
+
13
+ def forward(self, x):
14
+ x, _ = self.lstm(x)
15
+ x = self.linear(x)
16
+ return x
17
+
18
  # 定义模型路径
19
  model_path = "https://huggingface.co/Tele-AI/TeleSpeech-ASR1.0/resolve/main/large.pt"
20
 
21
  # 下载模型文件
22
  torch.hub.download_url_to_file(model_path, 'large.pt')
23
 
24
+ # 初始化模型
25
+ model = ASRModel()
26
+
27
  # 加载模型参数
28
+ model.load_state_dict(torch.load('large.pt', map_location=torch.device('cpu')))
29
  model.eval()
30
 
31
  # 定义处理函数
 
38
  with torch.no_grad():
39
  logits = model(input_values)
40
  predicted_ids = torch.argmax(logits, dim=-1)
41
+ transcription = ''.join([chr(i) for i in predicted_ids[0].tolist()]) # 解码预测到字符
42
  return transcription
43
 
44
  # 创建 Gradio 界面