Afrinetwork7 committed
Commit 488f2ba
Parent: 521a4ba

Update asr.py

Files changed (1)
  asr.py +38 -97
asr.py CHANGED
@@ -3,76 +3,19 @@ from transformers import Wav2Vec2ForCTC, AutoProcessor
 import torch
 import numpy as np
 from pathlib import Path
-
-from huggingface_hub import hf_hub_download
-from torchaudio.models.decoder import ctc_decoder
+import concurrent.futures

 ASR_SAMPLING_RATE = 16_000
-
-ASR_LANGUAGES = {}
-with open(f"data/asr/all_langs.tsv") as f:
-    for line in f:
-        iso, name = line.split(" ", 1)
-        ASR_LANGUAGES[iso.strip()] = name.strip()
+CHUNK_LENGTH_S = 60  # Increased to 60 seconds per chunk
+MAX_CONCURRENT_CHUNKS = 4  # Adjust based on VRAM availability

 MODEL_ID = "facebook/mms-1b-all"

 processor = AutoProcessor.from_pretrained(MODEL_ID)
 model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)

-
-# lm_decoding_config = {}
-# lm_decoding_configfile = hf_hub_download(
-#     repo_id="facebook/mms-cclms",
-#     filename="decoding_config.json",
-#     subfolder="mms-1b-all",
-# )
-
-# with open(lm_decoding_configfile) as f:
-#     lm_decoding_config = json.loads(f.read())
-
-# # allow language model decoding for "eng"
-
-# decoding_config = lm_decoding_config["eng"]
-
-# lm_file = hf_hub_download(
-#     repo_id="facebook/mms-cclms",
-#     filename=decoding_config["lmfile"].rsplit("/", 1)[1],
-#     subfolder=decoding_config["lmfile"].rsplit("/", 1)[0],
-# )
-# token_file = hf_hub_download(
-#     repo_id="facebook/mms-cclms",
-#     filename=decoding_config["tokensfile"].rsplit("/", 1)[1],
-#     subfolder=decoding_config["tokensfile"].rsplit("/", 1)[0],
-# )
-# lexicon_file = None
-# if decoding_config["lexiconfile"] is not None:
-#     lexicon_file = hf_hub_download(
-#         repo_id="facebook/mms-cclms",
-#         filename=decoding_config["lexiconfile"].rsplit("/", 1)[1],
-#         subfolder=decoding_config["lexiconfile"].rsplit("/", 1)[0],
-#     )
-
-# beam_search_decoder = ctc_decoder(
-#     lexicon=lexicon_file,
-#     tokens=token_file,
-#     lm=lm_file,
-#     nbest=1,
-#     beam_size=500,
-#     beam_size_token=50,
-#     lm_weight=float(decoding_config["lmweight"]),
-#     word_score=float(decoding_config["wordscore"]),
-#     sil_score=float(decoding_config["silweight"]),
-#     blank_token="<s>",
-# )
-
-
-def transcribe(audio_data=None, lang="eng (English)"):
-    if audio_data is None or (isinstance(audio_data, np.ndarray) and audio_data.size == 0):
-        return "<<ERROR: Empty Audio Input>>"
-
+def load_audio(audio_data):
     if isinstance(audio_data, tuple):
-        # microphone
         sr, audio_samples = audio_data
         audio_samples = (audio_samples / 32768.0).astype(np.float32)
         if sr != ASR_SAMPLING_RATE:
@@ -80,59 +23,57 @@ def transcribe(audio_data=None, lang="eng (English)"):
             audio_samples, orig_sr=sr, target_sr=ASR_SAMPLING_RATE
         )
     elif isinstance(audio_data, np.ndarray):
-        # Assuming audio_data is already in the correct format
         audio_samples = audio_data
     elif isinstance(audio_data, str):
-        # file upload
         audio_samples = librosa.load(audio_data, sr=ASR_SAMPLING_RATE, mono=True)[0]
     else:
-        return f"<<ERROR: Invalid Audio Input Instance: {type(audio_data)}>>"
-        audio_samples = librosa.load(audio_data, sr=ASR_SAMPLING_RATE, mono=True)[0]
+        raise ValueError(f"Invalid Audio Input Instance: {type(audio_data)}")
+    return audio_samples
+
+def process_chunk(chunk, device):
+    inputs = processor(chunk, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt").to(device)
+    with torch.no_grad():
+        outputs = model(**inputs).logits
+    ids = torch.argmax(outputs, dim=-1)[0]
+    return processor.decode(ids)
+
+def transcribe(audio_data=None, lang="eng (English)"):
+    if audio_data is None or (isinstance(audio_data, np.ndarray) and audio_data.size == 0):
+        return "<<ERROR: Empty Audio Input>>"
+
+    try:
+        audio_samples = load_audio(audio_data)
+    except Exception as e:
+        return f"<<ERROR: {str(e)}>>"

     lang_code = lang.split()[0]
     processor.tokenizer.set_target_lang(lang_code)
     model.load_adapter(lang_code)

-    inputs = processor(
-        audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt"
-    )
-
-    # set device
-    if torch.cuda.is_available():
-        device = torch.device("cuda")
-    elif (
-        hasattr(torch.backends, "mps")
-        and torch.backends.mps.is_available()
-        and torch.backends.mps.is_built()
-    ):
-        device = torch.device("mps")
-    else:
-        device = torch.device("cpu")
-
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model.to(device)
-    inputs = inputs.to(device)

-    with torch.no_grad():
-        outputs = model(**inputs).logits
+    chunk_length = int(CHUNK_LENGTH_S * ASR_SAMPLING_RATE)
+    chunks = [audio_samples[i:i+chunk_length] for i in range(0, len(audio_samples), chunk_length)]

-    if lang_code != "eng" or True:
-        ids = torch.argmax(outputs, dim=-1)[0]
-        transcription = processor.decode(ids)
-    else:
-        assert False
-        # beam_search_result = beam_search_decoder(outputs.to("cpu"))
-        # transcription = " ".join(beam_search_result[0][0].words).strip()
-
-    return transcription
+    transcriptions = []
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_CONCURRENT_CHUNKS) as executor:
+        future_to_chunk = {executor.submit(process_chunk, chunk, device): chunk for chunk in chunks}
+        for future in concurrent.futures.as_completed(future_to_chunk):
+            transcriptions.append(future.result())

+    return " ".join(transcriptions)

+# Example usage
 ASR_EXAMPLES = [
     ["upload/english.mp3", "eng (English)"],
     # ["upload/tamil.mp3", "tam (Tamil)"],
     # ["upload/burmese.mp3", "mya (Burmese)"],
 ]

-ASR_NOTE = """
-The above demo doesn't use beam-search decoding using a language model.
-Checkout the instructions [here](https://huggingface.co/facebook/mms-1b-all) on how to run LM decoding for better accuracy.
-"""
+if __name__ == "__main__":
+    for audio_file, language in ASR_EXAMPLES:
+        print(f"Transcribing {audio_file} in {language}")
+        transcription = transcribe(audio_file, language)
+        print(f"Transcription: {transcription}\n")