Ar4ikov committed on
Commit
e8122f3
1 Parent(s): 833d68e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -28
app.py CHANGED
@@ -12,15 +12,16 @@ import numpy as np
12
  import subprocess
13
 
14
 
15
- def resample(speech_array, sampling_rate):
16
- resampler = torchaudio.transforms.Resample(sampling_rate)
17
- speech = resampler(speech_array).squeeze().astype("double")
 
18
  return speech
19
 
20
 
21
- def predict(speech_array, sampling_rate):
22
- speech = resample(speech_array, sampling_rate)
23
- inputs = feature_extractor(speech, sampling_rate=SR, return_tensors="pt", padding=True)
24
  inputs = {key: inputs[key].to(device) for key in inputs}
25
 
26
  with torch.no_grad():
@@ -41,27 +42,15 @@ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("Aniemore/wav2vec2-
41
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42
 
43
 
44
- def transcribe(audio):
45
- sr, audio = audio[0], audio[1]
46
- return predict(audio, sr)
47
 
48
 
49
- def get_asr_interface():
50
- return gr.Interface(
51
- fn=transcribe,
52
- inputs=[
53
- gr.inputs.Audio(source="upload", type="numpy")
54
- ],
55
- outputs=[
56
- "json"
57
- ])
58
-
59
- interfaces = [
60
- get_asr_interface()
61
- ]
62
-
63
- names = [
64
- "Russian Emotion Recognition"
65
- ]
66
-
67
- gr.TabbedInterface(interfaces, names).launch(server_name = "0.0.0.0", enable_queue=False)
 
12
  import subprocess
13
 
14
 
15
def resample(path, sampling_rate):
    """Load an audio file and resample it to *sampling_rate*.

    Parameters
    ----------
    path : str
        Filesystem path to the audio file (any format torchaudio can load).
    sampling_rate : int
        Target sample rate in Hz.

    Returns
    -------
    numpy.ndarray
        1-D float waveform at *sampling_rate* (channel dim squeezed out).
    """
    speech_array, _sampling_rate = torchaudio.load(path)
    # BUG FIX: the target rate was previously never passed, so Resample fell
    # back to its default new_freq=16000 and the `sampling_rate` argument was
    # silently ignored. Pass both the source and target rates explicitly.
    resampler = torchaudio.transforms.Resample(
        orig_freq=_sampling_rate, new_freq=sampling_rate
    )
    speech = resampler(speech_array).squeeze().numpy()
    return speech
20
 
21
 
22
+ def predict(path, sampling_rate=SR):
23
+ speech = resample(path, sampling_rate)
24
+ inputs = feature_extractor(speech, sampling_rate=sampling_rate, return_tensors="pt", padding=True)
25
  inputs = {key: inputs[key].to(device) for key in inputs}
26
 
27
  with torch.no_grad():
 
42
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
43
 
44
 
45
def recognize(audio_path):
    """Gradio callback: run emotion prediction on the recorded audio file."""
    # Thin wrapper so the UI wiring stays decoupled from the model code.
    result = predict(audio_path)
    return result
 
47
 
48
 
49
# Build the web UI: microphone input -> recognize() -> JSON emotion scores.
with gr.Blocks() as demo:
    audio_input = gr.Audio(source="microphone", type="filepath", label="Скажите что-нибудь...")
    recognize_button = gr.Button('Распознать эмоции')
    emotions_json = gr.JSON(label="Эмоции")

    # Clicking the button feeds the recorded file path into the recognizer.
    recognize_button.click(fn=recognize, inputs=[audio_input], outputs=[emotions_json])

demo.launch(enable_queue=True, debug=True)