File size: 2,373 Bytes
1869d7a
 
71b0e8e
31564d0
1869d7a
31564d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1869d7a
 
 
e2bcfc6
71b0e8e
31564d0
eb24c35
 
31564d0
1869d7a
31564d0
f8ebe93
31564d0
e2bcfc6
71b0e8e
1869d7a
 
 
f8ebe93
31564d0
1869d7a
 
 
 
 
e2bcfc6
1869d7a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os

import gradio as gr
import torch
import torchaudio
from torchaudio.transforms import Resample

# A placeholder ASR network (the checkpoint's real architecture is not defined
# in this file): a 3-layer LSTM over 160-dimensional input frames followed by a
# per-timestep linear projection onto 29 output classes.
class ASRModel(torch.nn.Module):
    """Stacked LSTM encoder with a linear character classifier on top."""

    def __init__(self):
        super().__init__()
        # Registered first so its parameters lead the state_dict ("lstm.*").
        self.lstm = torch.nn.LSTM(
            input_size=160,
            hidden_size=256,
            num_layers=3,
            batch_first=True,
        )
        # 29 output classes, assumed to be one per character symbol.
        self.linear = torch.nn.Linear(256, 29)

    def forward(self, x):
        """Map batch-first input of shape (batch, time, 160) to (batch, time, 29) logits."""
        hidden_states = self.lstm(x)[0]  # drop the (h_n, c_n) state tuple
        return self.linear(hidden_states)

# Remote location of the checkpoint, and the local file it is cached in.
# NOTE(review): the URL points at base.pt but the file is cached as "large.pt";
# the local name is kept so existing caches keep working — consider renaming.
model_path = "https://huggingface.co/Tele-AI/TeleSpeech-ASR1.0/resolve/main/base.pt"
model_file = "large.pt"

# Download only when no cached copy exists, so restarts do not re-fetch the
# whole checkpoint every time.
if os.path.exists(model_file):
    print(f"Using cached model file {model_file}.")
else:
    print("Downloading model file...")
    torch.hub.download_url_to_file(model_path, model_file)
    print("Model file downloaded.")

# Build the (untrained) network that the checkpoint weights are loaded into.
model = ASRModel()

print("Loading model checkpoint...")
# SECURITY NOTE: torch.load unpickles the file, which can run arbitrary code.
# The file comes from a remote URL, so only point model_path at a trusted source
# (or pass weights_only=True if the checkpoint format and torch version allow it).
checkpoint = torch.load(model_file, map_location=torch.device("cpu"))

# Some checkpoints wrap the weights under a 'model' key; otherwise the whole
# loaded object is treated as the state dict. Guard against non-dict checkpoints,
# which have no .keys() and would otherwise crash here.
if isinstance(checkpoint, dict):
    print("Checkpoint keys:", checkpoint.keys())
if isinstance(checkpoint, dict) and "model" in checkpoint:
    state_dict = checkpoint["model"]
    print("Model state_dict keys:", state_dict.keys())
else:
    print("Key 'model' not found in checkpoint.")
    state_dict = checkpoint

# Loading is deliberately best-effort so the UI still starts, but a failure is
# reported loudly: with strict loading, some or all parameters may be left at
# their random initial values, making transcriptions meaningless.
try:
    model.load_state_dict(state_dict)
except Exception as e:
    print("Error loading model state_dict:", str(e))
    print(
        "WARNING: continuing with a model whose weights may be partly or wholly "
        "randomly initialized; transcriptions will not be meaningful."
    )
else:
    print("Model state_dict loaded successfully.")

# Inference mode (affects dropout/batch-norm style layers, if any).
model.eval()

# Audio preprocessing / decoding constants used by transcribe().
# FRAME_SIZE must equal ASRModel's LSTM input_size (160); at the 16 kHz target
# rate one frame therefore covers 10 ms of audio.
TARGET_SAMPLE_RATE = 16000
FRAME_SIZE = 160

# Labels for the 29 output classes of ASRModel.linear, decoded greedily as CTC
# (index 0 = blank).
# NOTE(review): the checkpoint's real label set is not visible in this file;
# this blank/space/a-z/apostrophe layout is an assumption — confirm it.
VOCAB = ["<blank>", " "] + [chr(c) for c in range(ord("a"), ord("z") + 1)] + ["'"]


def _to_frames(waveform, sample_rate):
    """Mix *waveform* to mono, resample to 16 kHz and cut it into model frames.

    Returns a (1, num_frames, FRAME_SIZE) tensor; trailing samples that do not
    fill a whole frame are dropped, so num_frames may be 0 for very short audio.
    """
    if waveform.dim() > 1:
        # torchaudio.load yields (channels, time); average channels to mono.
        waveform = waveform.mean(dim=0)
    if sample_rate != TARGET_SAMPLE_RATE:
        waveform = Resample(orig_freq=sample_rate, new_freq=TARGET_SAMPLE_RATE)(waveform)
    num_frames = waveform.shape[-1] // FRAME_SIZE
    return waveform[: num_frames * FRAME_SIZE].reshape(1, num_frames, FRAME_SIZE)


def _ctc_greedy_decode(ids):
    """Merge repeated class ids, drop blanks (id 0) and map the rest through VOCAB."""
    chars = []
    previous = None
    for token in ids:
        if token != previous and token != 0:
            chars.append(VOCAB[token])
        previous = token
    return "".join(chars)


def transcribe(audio):
    """Transcribe one audio file with the module-level ``model``.

    Args:
        audio: Path to an audio file (the Gradio input uses type="filepath").
            None is treated as "no audio".

    Returns:
        The greedily decoded transcription, or an empty string when *audio* is
        None or shorter than one model frame.
    """
    print("Transcribing audio...")
    if audio is None:
        return ""
    waveform, sample_rate = torchaudio.load(audio)
    input_values = _to_frames(waveform, sample_rate)
    if input_values.shape[1] == 0:
        # Too short to form a single frame; the LSTM cannot run on it.
        return ""
    with torch.no_grad():
        logits = model(input_values)  # (1, num_frames, num_classes)
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = _ctc_greedy_decode(predicted_ids[0].tolist())
    print("Transcription:", transcription)
    return transcription

# Assemble the web UI: a single audio widget that hands transcribe() the path
# of the uploaded or recorded file, and a plain-text box for the result.
iface = gr.Interface(
    title="TeleSpeech ASR",
    description=(
        "Upload an audio file or record your voice to transcribe speech "
        "to text using the TeleSpeech ASR model."
    ),
    fn=transcribe,
    inputs=gr.Audio(type="filepath"),
    outputs="text",
)

# Start the web app.
print("Launching Gradio interface...")
iface.launch()