jasspier committed on
Commit
71b0e8e
1 Parent(s): 7cc4a72
Files changed (1) hide show
  1. app.py +19 -11
app.py CHANGED
@@ -1,22 +1,30 @@
1
  import gradio as gr
2
- from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
3
  import torch
4
- import librosa
 
5
 
6
- # 加载模型和处理器
7
- model_name = "Tele-AI/TeleSpeech-ASR1.0"
8
- processor = Wav2Vec2Processor.from_pretrained(model_name)
9
- model = Wav2Vec2ForCTC.from_pretrained(model_name)
 
 
 
 
 
10
 
11
  # 定义处理函数
12
  def transcribe(audio):
13
- waveform, rate = librosa.load(audio, sr=16000)
14
- input_values = processor(waveform, return_tensors="pt", padding="longest").input_values
 
 
 
15
  with torch.no_grad():
16
- logits = model(input_values).logits
17
  predicted_ids = torch.argmax(logits, dim=-1)
18
- transcription = processor.batch_decode(predicted_ids)
19
- return transcription[0]
20
 
21
  # 创建 Gradio 界面
22
  iface = gr.Interface(
 
1
  import gradio as gr
 
2
  import torch
3
+ import torchaudio
4
+ from torchaudio.transforms import Resample
5
 
6
+ # 定义模型路径
7
+ model_path = "https://huggingface.co/Tele-AI/TeleSpeech-ASR1.0/resolve/main/large.pt"
8
+
9
+ # 下载模型文件
10
+ torch.hub.download_url_to_file(model_path, 'large.pt')
11
+
12
+ # 加载模型
13
+ model = torch.jit.load('large.pt')
14
+ model.eval()
15
 
16
  # 定义处理函数
17
  def transcribe(audio):
18
+ waveform, sample_rate = torchaudio.load(audio)
19
+ resample = Resample(orig_freq=sample_rate, new_freq=16000)
20
+ waveform = resample(waveform)
21
+
22
+ input_values = waveform.unsqueeze(0)
23
  with torch.no_grad():
24
+ logits = model(input_values)
25
  predicted_ids = torch.argmax(logits, dim=-1)
26
+ transcription = tokenizer.decode(predicted_ids[0])
27
+ return transcription
28
 
29
  # 创建 Gradio 界面
30
  iface = gr.Interface(