File size: 2,531 Bytes
1869d7a
 
71b0e8e
31564d0
d3defc4
4ae1112
 
 
 
 
d3defc4
4ae1112
 
 
31564d0
 
ac47f83
31564d0
 
 
 
 
 
ca625d0
960bf82
 
31564d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1869d7a
 
 
e2bcfc6
71b0e8e
960bf82
 
 
 
 
eb24c35
04a460e
ca625d0
960bf82
1869d7a
960bf82
ca625d0
f8ebe93
31564d0
e2bcfc6
71b0e8e
1869d7a
 
 
f8ebe93
31564d0
1869d7a
 
 
 
 
e2bcfc6
1869d7a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import os

import gradio as gr
import torch
import torchaudio
from torchaudio.transforms import Resample

# Helper that hands back the wav2vec2 classes; per the original note, the
# explicit import is skipped when the module is already loaded so that it is
# not registered twice.
def import_wav2vec2():
    """Return the ``Wav2Vec2Model`` and ``Wav2Vec2Config`` classes.

    The ``wav2vec2`` module is imported explicitly only when it is not yet
    listed in ``sys.modules``; the two classes are then pulled from it.

    Returns:
        A ``(Wav2Vec2Model, Wav2Vec2Config)`` tuple, in that order.
    """
    import sys

    module_cached = 'wav2vec2' in sys.modules
    if not module_cached:
        import wav2vec2  # noqa: F401 -- imported only to get the module loaded
    from wav2vec2 import Wav2Vec2Model as model_cls, Wav2Vec2Config as config_cls

    return model_cls, config_cls

Wav2Vec2Model, Wav2Vec2Config = import_wav2vec2()

# Remote location of the checkpoint and the local file it is stored in.
model_path = "https://huggingface.co/Tele-AI/TeleSpeech-ASR1.0/resolve/main/finetune_large_kespeech.pt"
local_model_file = 'large.pt'

# Download the checkpoint only when no local copy exists yet, so restarting
# the app does not fetch the whole file again.
if os.path.isfile(local_model_file):
    print("Model file already present, skipping download.")
else:
    print("Downloading model file...")
    torch.hub.download_url_to_file(model_path, local_model_file)
    print("Model file downloaded.")

# Build the model from a default-constructed configuration.
# NOTE(review): the config uses the library defaults while the checkpoint is
# named "large" -- confirm that the architectures actually match.
config = Wav2Vec2Config()
model = Wav2Vec2Model.build_model(config)

# Load the checkpoint onto the CPU.
# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints that come from a trusted source.
print("Loading model checkpoint...")
checkpoint = torch.load(local_model_file, map_location=torch.device('cpu'))
print("Checkpoint keys:", checkpoint.keys())

# Use the weights stored under 'model' when that key exists; otherwise treat
# the whole checkpoint as the state dict.
if 'model' in checkpoint:
    state_dict = checkpoint['model']
    print("Model state_dict keys:", state_dict.keys())
else:
    print("Key 'model' not found in checkpoint.")
    state_dict = checkpoint

# Copy the weights into the model. A failure is reported but deliberately does
# not stop the app (best-effort start-up, as in the original script); the
# extra warning makes it obvious that the results cannot be trusted then.
try:
    model.load_state_dict(state_dict)
except Exception as e:
    print("Error loading model state_dict:", str(e))
    print("WARNING: model weights may be missing or only partially loaded; "
          "transcriptions will not be reliable.")
else:
    print("Model state_dict loaded successfully.")

model.eval()

# Transcription handler used by the Gradio interface
def transcribe(audio):
    """Transcribe an audio file with the module-level ``model``.

    Args:
        audio: Path of the audio file to transcribe, or ``None`` when no
            audio was supplied.

    Returns:
        The decoded text as a string. An empty string is returned when no
        audio was supplied or the file contains no samples.
    """
    # Nothing to do when the caller did not provide a recording or upload.
    if audio is None:
        return ""

    print("Transcribing audio...")
    waveform, sample_rate = torchaudio.load(audio)
    if sample_rate != 16000:
        # Bring the signal to the 16 kHz rate used for the model input.
        resample = Resample(orig_freq=sample_rate, new_freq=16000)
        waveform = resample(waveform)

    # Average over the leading (channel) axis so that multi-channel audio
    # becomes a single 1-D signal. A plain squeeze() only removes that axis
    # for mono files and leaves stereo input with one dimension too many.
    if waveform.dim() > 1:
        waveform = waveform.mean(dim=0)
    if waveform.numel() == 0:
        return ""

    # Add the batch axis expected by the model call below.
    input_values = waveform.unsqueeze(0)  # (batch_size, seq_len)

    with torch.no_grad():
        outputs = model.extract_features(input_values, padding_mask=None)
        logits = outputs["x"]
    predicted_ids = torch.argmax(logits, dim=-1)
    # NOTE(review): this treats outputs["x"] as per-frame logits and maps each
    # argmax index straight to a Unicode code point via chr(). Verify against
    # the model's actual vocabulary/decoding scheme -- TODO confirm.
    transcription = ''.join(chr(i) for i in predicted_ids[0].tolist())
    print("Transcription:", transcription)
    return transcription

# Create the Gradio interface: one audio input (handed to ``transcribe`` as a
# file path) and one text output.
_interface_options = dict(
    fn=transcribe,
    inputs=gr.Audio(type="filepath"),
    outputs="text",
    title="TeleSpeech ASR",
    description="Upload an audio file or record your voice to transcribe speech to text using the TeleSpeech ASR model.",
)
iface = gr.Interface(**_interface_options)

# Start serving the interface.
print("Launching Gradio interface...")
iface.launch()