Ar4ikov committed on
Commit
1ebf1c8
1 Parent(s): 9585e6d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -7
app.py CHANGED
@@ -12,15 +12,15 @@ import numpy as np
12
  import subprocess
13
 
14
 
15
def speech_file_to_array_fn(path, sampling_rate):
    """Load an audio file from disk and resample it for the model.

    NOTE(review): ``Resample(_sampling_rate)`` sets only ``orig_freq``;
    ``new_freq`` keeps torchaudio's default of 16000 Hz — presumably the
    model's expected rate; confirm against the model card.
    NOTE(review): the ``sampling_rate`` parameter is never read — TODO
    confirm whether it was meant to be the target rate.

    :param path: filesystem path of the audio file to load
    :param sampling_rate: unused (see note above)
    :return: 1-D numpy array containing the resampled waveform
    """
    waveform, source_rate = torchaudio.load(path)
    to_target = torchaudio.transforms.Resample(source_rate)
    return to_target(waveform).squeeze().numpy()
20
 
21
 
22
- def predict(speech, sampling_rate):
23
- inputs = feature_extractor(speech, sampling_rate=sampling_rate, return_tensors="pt", padding=True)
 
24
  inputs = {key: inputs[key].to(device) for key in inputs}
25
 
26
  with torch.no_grad():
@@ -32,6 +32,7 @@ def predict(speech, sampling_rate):
32
 
33
 
34
  TRUST = True
 
35
 
36
  config = AutoConfig.from_pretrained('Aniemore/wav2vec2-xlsr-53-russian-emotion-recognition', trust_remote_code=TRUST)
37
  model = AutoModel.from_pretrained("Aniemore/wav2vec2-xlsr-53-russian-emotion-recognition", trust_remote_code=TRUST)
@@ -42,8 +43,8 @@ model.to(device)
42
 
43
 
44
def transcribe(audio):
    """Trace the raw input, then run prediction at a fixed 16 kHz rate.

    :param audio: audio payload handed in by the UI, forwarded unchanged
    :return: the result of ``predict(audio, 16000)``
    """
    print(audio)  # debug trace of the incoming payload
    outcome = predict(audio, 16000)
    return outcome
47
 
48
 
49
  def get_asr_interface():
 
12
  import subprocess
13
 
14
 
15
def resample(speech_array, sampling_rate):
    """Resample a loaded waveform to the model's target rate.

    NOTE(review): ``Resample(sampling_rate)`` sets only ``orig_freq``;
    ``new_freq`` stays at torchaudio's default of 16000 Hz — this matches
    ``SR = 16000`` elsewhere in the file, but confirm intent.

    :param speech_array: waveform tensor as loaded or recorded
    :param sampling_rate: the waveform's original sampling rate in Hz
    :return: 1-D numpy array containing the resampled waveform
    """
    converter = torchaudio.transforms.Resample(sampling_rate)
    return converter(speech_array).squeeze().numpy()
19
 
20
 
21
+ def predict(speech_array, sampling_rate):
22
+ speech = resample(speech_array, sampling_rate)
23
+ inputs = feature_extractor(speech, sampling_rate=SR, return_tensors="pt", padding=True)
24
  inputs = {key: inputs[key].to(device) for key in inputs}
25
 
26
  with torch.no_grad():
 
32
 
33
 
34
  TRUST = True
35
+ SR = 16000
36
 
37
  config = AutoConfig.from_pretrained('Aniemore/wav2vec2-xlsr-53-russian-emotion-recognition', trust_remote_code=TRUST)
38
  model = AutoModel.from_pretrained("Aniemore/wav2vec2-xlsr-53-russian-emotion-recognition", trust_remote_code=TRUST)
 
43
 
44
 
45
def transcribe(audio):
    """Run emotion prediction on a UI-provided audio input.

    :param audio: a (sampling_rate, waveform) pair — presumably the
        Gradio Audio component format; verify against the interface
    :return: the result of ``predict(waveform, sampling_rate)``
    """
    rate = audio[0]
    waveform = audio[1]
    return predict(waveform, rate)
48
 
49
 
50
  def get_asr_interface():