# Spaces: Runtime error (page status banner captured with the source; not part of the program)
import importlib.util
import os
import sys

import gradio as gr
import torch
import torchaudio
from torchaudio.transforms import Resample
# Dynamically import the wav2vec2 module, reusing an already-registered copy
# so the module body is never executed (and its classes registered) twice.
def import_wav2vec2(module_path="wav2vec2.py"):
    """Return ``(Wav2Vec2Model, Wav2Vec2Config)`` from the wav2vec2 module.

    The module is looked up in ``sys.modules`` under the name ``wav2vec2``.
    If it is not there yet, it is loaded from ``module_path``, registered
    under that name and executed.

    Args:
        module_path: Path of the Python source file to load when the module
            is not registered yet. Ignored when ``sys.modules`` already
            holds a ``wav2vec2`` entry.

    Returns:
        A tuple ``(Wav2Vec2Model, Wav2Vec2Config)`` taken from the module.

    Raises:
        ImportError: If no import spec/loader can be created for
            ``module_path``.
        Exception: Whatever executing the module raises (for example
            ``FileNotFoundError`` when the file does not exist).
    """
    module = sys.modules.get('wav2vec2')
    if module is None:
        spec = importlib.util.spec_from_file_location("wav2vec2", module_path)
        if spec is None or spec.loader is None:
            raise ImportError(
                f"Cannot create an import spec for wav2vec2 from {module_path!r}"
            )
        module = importlib.util.module_from_spec(spec)
        # Register before executing so that imports triggered while the
        # module body runs resolve to this same module object.
        sys.modules['wav2vec2'] = module
        try:
            spec.loader.exec_module(module)
        except BaseException:
            # Do not leave a half-initialised module behind: a later call
            # would otherwise return it instead of retrying the import.
            sys.modules.pop('wav2vec2', None)
            raise
    return module.Wav2Vec2Model, module.Wav2Vec2Config
# Resolve the model and config classes from the local wav2vec2 module.
Wav2Vec2Model, Wav2Vec2Config = import_wav2vec2()

# Remote location of the checkpoint and the local file it is stored in.
model_path = "https://huggingface.co/Tele-AI/TeleSpeech-ASR1.0/resolve/main/finetune_large_kespeech.pt"
checkpoint_file = 'large.pt'

# Download the checkpoint only when it is not on disk yet, so that a restart
# does not fetch the whole file again.
if os.path.exists(checkpoint_file):
    print("Model file already present, skipping download.")
else:
    print("Downloading model file...")
    torch.hub.download_url_to_file(model_path, checkpoint_file)
    print("Model file downloaded.")

# Build the model from a default configuration.
# NOTE(review): the default Wav2Vec2Config may not match the architecture
# stored in the checkpoint -- confirm against the checkpoint's own config.
config = Wav2Vec2Config()
model = Wav2Vec2Model.build_model(config)

# Load the checkpoint onto the CPU.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code; only load checkpoints obtained from a trusted source.
print("Loading model checkpoint...")
checkpoint = torch.load(checkpoint_file, map_location=torch.device('cpu'))
print("Checkpoint keys:", checkpoint.keys())

# The parameters are expected under the 'model' key; otherwise the whole
# checkpoint object is treated as the state dict.
if 'model' in checkpoint:
    state_dict = checkpoint['model']
    print("Model state_dict keys:", state_dict.keys())
else:
    print("Key 'model' not found in checkpoint.")
    state_dict = checkpoint

# Load the parameters into the model. A failure is reported but does not stop
# the app (best-effort, as before); the extra warning makes it explicit that
# the model is then running without properly loaded weights.
try:
    model.load_state_dict(state_dict)
except Exception as e:
    print("Error loading model state_dict:", str(e))
    print("WARNING: model weights were not loaded correctly; transcriptions will be unreliable.")
else:
    print("Model state_dict loaded successfully.")

model.eval()
# Request handler used by the Gradio interface.
def transcribe(audio):
    """Transcribe an audio file with the module-level ``model``.

    Args:
        audio: Path of the audio file to read (the Gradio input is created
            with ``type="filepath"``). ``None`` or an empty value means no
            audio was provided.

    Returns:
        The decoded text, or an empty string when no audio was provided.
    """
    # Submitting the form without a recording/upload passes no file path.
    if not audio:
        return ""

    print("Transcribing audio...")
    waveform, sample_rate = torchaudio.load(audio)

    # Mix all channels down to one. Squeezing alone leaves stereo input as
    # (channels, frames), which would turn into a 3-D model input below.
    # Assumes torchaudio.load returned its default (channels, frames) layout.
    waveform = waveform.mean(dim=0)

    # The model input is resampled to 16 kHz when the file uses another rate.
    if sample_rate != 16000:
        resample = Resample(orig_freq=sample_rate, new_freq=16000)
        waveform = resample(waveform)

    # Add the batch dimension expected by the model.
    input_values = waveform.unsqueeze(0)  # (batch_size, seq_len)

    with torch.no_grad():
        outputs = model.extract_features(input_values, padding_mask=None)
    logits = outputs["x"]
    predicted_ids = torch.argmax(logits, dim=-1)

    # NOTE(review): each arg-max index is mapped directly to a Unicode code
    # point. This presumably should go through the model's vocabulary (and a
    # CTC-style collapse) instead -- confirm against the model's dictionary.
    transcription = ''.join(chr(i) for i in predicted_ids[0].tolist())
    print("Transcription:", transcription)
    return transcription
# Assemble the Gradio interface: one audio input handed to `transcribe` as a
# file path, and one text output.
interface_options = {
    "fn": transcribe,
    "inputs": gr.Audio(type="filepath"),
    "outputs": "text",
    "title": "TeleSpeech ASR",
    "description": "Upload an audio file or record your voice to transcribe speech to text using the TeleSpeech ASR model.",
}
iface = gr.Interface(**interface_options)

# Start serving the interface.
print("Launching Gradio interface...")
iface.launch()