erastorgueva-nv committed
Commit a28041c
1 Parent(s): 9379375

allow buffered inference up to 10 mins

Files changed (2)
  1. app.py +53 -18
  2. requirements.txt +1 -1
app.py CHANGED
@@ -6,16 +6,39 @@ import soundfile as sf
 import tempfile
 import uuid
 
+import torch
+
 from nemo.collections.asr.models import ASRModel
+from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchMultiTaskAED
+from nemo.collections.asr.parts.utils.transcribe_utils import get_buffered_pred_feat_multitaskAED
 
 SAMPLE_RATE = 16000 # Hz
+MAX_AUDIO_MINUTES = 10 # won't try to transcribe if longer than this
 
 model = ASRModel.from_pretrained("nvidia/canary-1b")
 model.eval()
 
+# make sure beam size always 1 for consistency
+model.change_decoding_strategy(None)
+decoding_cfg = model.cfg.decoding
+decoding_cfg.beam.beam_size = 1
+model.change_decoding_strategy(decoding_cfg)
+
+# setup for buffered inference
+model.cfg.preprocessor.dither = 0.0
+model.cfg.preprocessor.pad_to = 0
+
+feature_stride = model.cfg.preprocessor['window_stride']
+model_stride_in_secs = feature_stride * 8 # 8 = model stride, which is 8 for FastConformer
 
-MAX_AUDIO_SECONDS = 40
+frame_asr = FrameBatchMultiTaskAED(
+    asr_model=model,
+    frame_len=40.0,
+    total_buffer=40.0,
+    batch_size=16,
+)
 
+amp_dtype = torch.float16
 
 def convert_audio(audio_filepath, tmpdir, utt_id):
     """
@@ -24,21 +47,20 @@ def convert_audio(audio_filepath, tmpdir, utt_id):
     Returns output filename and duration.
     """
 
-    data, sr = librosa.load(audio_filepath)
+    data, sr = librosa.load(audio_filepath, sr=None, mono=False)
 
     duration = librosa.get_duration(y=data, sr=sr)
 
-    if duration > MAX_AUDIO_SECONDS:
+    if duration / 60.0 > MAX_AUDIO_MINUTES:
         raise gr.Error(
-            f"This demo can transcribe up to {MAX_AUDIO_SECONDS} seconds of audio."
+            f"This demo can transcribe up to {MAX_AUDIO_MINUTES} minutes of audio. "
+            "If you wish, you may trim the audio using the Audio viewer in Step 1 "
+            "(click on the scissors icon to start trimming audio)."
         )
 
     if sr != SAMPLE_RATE:
         data = librosa.resample(data, orig_sr=sr, target_sr=SAMPLE_RATE)
 
-    # monochannel
-    data = librosa.to_mono(data)
-
     out_filename = os.path.join(tmpdir, utt_id + '.wav')
 
     # save output audio
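Worth noting about the convert_audio changes (annotation, not part of the commit): loading with sr=None, mono=False keeps the original sample rate and channel layout, and the explicit librosa.to_mono downmix is dropped. If mono output were still needed before writing the .wav, the downmix could be reinstated along these lines (a sketch; downmix_to_mono is a hypothetical helper):

import numpy as np

def downmix_to_mono(data: np.ndarray) -> np.ndarray:
    # librosa.load(..., mono=False) returns shape (channels, samples) for
    # multi-channel audio; averaging across channels mirrors librosa.to_mono.
    return np.mean(data, axis=0) if data.ndim > 1 else data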
@@ -54,7 +76,6 @@ def transcribe(audio_filepath, src_lang, tgt_lang, pnc):
 
     utt_id = uuid.uuid4()
     with tempfile.TemporaryDirectory() as tmpdir:
-
         converted_audio_filepath, duration = convert_audio(audio_filepath, tmpdir, str(utt_id))
 
         # map src_lang and tgt_lang from long versions to short
@@ -102,9 +123,23 @@ def transcribe(audio_filepath, src_lang, tgt_lang, pnc):
             fout.write(line + '\n')
 
         # call transcribe, passing in manifest filepath
-        model_output = model.transcribe(manifest_filepath)
-
-    return model_output[0]
+        if duration < 40:
+            output_text = model.transcribe(manifest_filepath)[0]
+        else: # do buffered inference
+            with torch.cuda.amp.autocast(dtype=amp_dtype): # TODO: make it work if no cuda
+                with torch.no_grad():
+                    hyps = get_buffered_pred_feat_multitaskAED(
+                        frame_asr,
+                        model.cfg.preprocessor,
+                        model_stride_in_secs,
+                        model.device,
+                        manifest=manifest_filepath,
+                        filepaths=None,
+                    )
+
+            output_text = hyps[0].text
+
+    return output_text
 
 # add logic to make sure dropdown menus only suggest valid combos
 def on_src_or_tgt_lang_change(src_lang_value, tgt_lang_value, pnc_value):
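On the TODO above ("make it work if no cuda"): one possible shape for a device-agnostic version, sketched under the assumption that full precision on CPU is acceptable (autocast_ctx is a hypothetical name, not from the commit):

import contextlib
import torch

# Use fp16 autocast only when CUDA is present; otherwise fall back to
# an inert context manager and run in full precision on CPU.
if torch.cuda.is_available():
    autocast_ctx = torch.cuda.amp.autocast(dtype=torch.float16)
else:
    autocast_ctx = contextlib.nullcontext()

with autocast_ctx, torch.no_grad():
    pass  # buffered inference call goes here, as in the hunk above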
@@ -124,15 +159,15 @@ def on_src_or_tgt_lang_change(src_lang_value, tgt_lang_value, pnc_value):
    tgt_lang_value, and then which states you can go to from there.
 
         tgt lang
-        - |EN |ES |FR |DE
-        ------------------
-      EN| Y | Y | Y | Y
-        ------------------
+        - |EN |ES |FR |DE
+        ------------------
+      EN| Y | Y | Y | Y
+        ------------------
 src   ES| Y | Y |   |
 lang    ------------------
-      FR| Y |   | Y |
-        ------------------
-      DE| Y |   |   | Y
+      FR| Y |   | Y |
+        ------------------
+      DE| Y |   |   | Y
    """
 
    if src_lang_value == "English" and tgt_lang_value == "English":
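The docstring table above encodes which source/target pairs the demo accepts: same-language transcription for all four languages, translation only to or from English. A compact way to express the same matrix (hypothetical helper, not in the commit):

# Valid target languages per source language, per the docstring table.
VALID_TGT = {
    "EN": {"EN", "ES", "FR", "DE"},
    "ES": {"EN", "ES"},
    "FR": {"EN", "FR"},
    "DE": {"EN", "DE"},
}

def pair_is_valid(src: str, tgt: str) -> bool:
    return tgt in VALID_TGT.get(src, set())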
 
requirements.txt CHANGED
@@ -1 +1 @@
-git+https://github.com/NVIDIA/NeMo.git@r1.23.0#egg=nemo_toolkit[all]
+git+https://github.com/NVIDIA/NeMo.git@61325fe0c70ef4294d8562991f6841d26b238e85#egg=nemo_toolkit[all] # commit from canary_buffer_infer branch
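The dependency moves from the r1.23.0 release tag to a pinned commit (from the canary_buffer_infer branch, per the inline comment), so installs stay reproducible until that work lands in a release. A minimal post-install check (sketch; the exact version string reported depends on the pinned commit):

# Verify the pinned NeMo source installed and is importable.
import nemo
print(nemo.__version__)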