import librosa
from transformers import Wav2Vec2ForCTC, AutoProcessor
import torch
import json
from huggingface_hub import hf_hub_download
from torchaudio.models.decoder import ctc_decoder
ASR_SAMPLING_RATE = 16_000
ASR_LANGUAGES = {}
with open(f"data/asr/all_langs.tsv") as f:
for line in f:
iso, name = line.split(" ", 1)
ASR_LANGUAGES[iso] = name
MODEL_ID = "facebook/mms-1b-all"
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
lm_decoding_config = {}
lm_decoding_configfile = hf_hub_download(
    repo_id="facebook/mms-cclms",
    filename="decoding_config.json",
    subfolder="mms-1b-all",
)
with open(lm_decoding_configfile) as f:
    lm_decoding_config = json.loads(f.read())
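# Each entry in decoding_config.json is keyed by language code; the fields used
# below are "lmfile", "tokensfile", "lexiconfile", "lmweight", "wordscore",
# and "silweight", with paths given relative to the facebook/mms-cclms repo.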
# allow language model decoding for specific languages
lm_decode_isos = ["eng"]

def transcribe(
    audio_source=None, microphone=None, file_upload=None, lang="eng (English)"
):
    if type(microphone) is dict:
        # HACK: microphone variable is a dict when running on examples
        microphone = microphone["name"]
    audio_fp = (
        file_upload if "upload" in str(audio_source or "").lower() else microphone
    )
    if audio_fp is None:
        return "ERROR: You have to either use the microphone or upload an audio file"
    audio_samples = librosa.load(audio_fp, sr=ASR_SAMPLING_RATE, mono=True)[0]

    # switch the tokenizer vocabulary and model adapter to the selected language
    lang_code = lang.split()[0]
    processor.tokenizer.set_target_lang(lang_code)
    model.load_adapter(lang_code)

    inputs = processor(
        audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt"
    )
    # set device
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif (
        hasattr(torch.backends, "mps")
        and torch.backends.mps.is_available()
        and torch.backends.mps.is_built()
    ):
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    model.to(device)
    inputs = inputs.to(device)

    with torch.no_grad():
        outputs = model(**inputs).logits
    if lang_code not in lm_decoding_config or lang_code not in lm_decode_isos:
        # greedy CTC decoding: take the highest-scoring token at each frame
        ids = torch.argmax(outputs, dim=-1)[0]
        transcription = processor.decode(ids)
    else:
        # beam-search decoding with an n-gram language model
        decoding_config = lm_decoding_config[lang_code]

        lm_file = hf_hub_download(
            repo_id="facebook/mms-cclms",
            filename=decoding_config["lmfile"].rsplit("/", 1)[1],
            subfolder=decoding_config["lmfile"].rsplit("/", 1)[0],
        )
        token_file = hf_hub_download(
            repo_id="facebook/mms-cclms",
            filename=decoding_config["tokensfile"].rsplit("/", 1)[1],
            subfolder=decoding_config["tokensfile"].rsplit("/", 1)[0],
        )
        lexicon_file = None
        if decoding_config["lexiconfile"] is not None:
            lexicon_file = hf_hub_download(
                repo_id="facebook/mms-cclms",
                filename=decoding_config["lexiconfile"].rsplit("/", 1)[1],
                subfolder=decoding_config["lexiconfile"].rsplit("/", 1)[0],
            )
        beam_search_decoder = ctc_decoder(
            lexicon=lexicon_file,
            tokens=token_file,
            lm=lm_file,
            nbest=1,
            beam_size=500,
            beam_size_token=50,
            lm_weight=float(decoding_config["lmweight"]),
            word_score=float(decoding_config["wordscore"]),
            sil_score=float(decoding_config["silweight"]),
            blank_token="<s>",
        )

        # the torchaudio decoder expects CPU emissions of shape (batch, frame, vocab)
        beam_search_result = beam_search_decoder(outputs.to("cpu"))
        transcription = " ".join(beam_search_result[0][0].words).strip()

    return transcription
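
# Each example row corresponds to transcribe()'s arguments in order:
# [audio_source, microphone, file_upload, lang].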
ASR_EXAMPLES = [
    [None, "assets/english.mp3", None, "eng (English)"],
    # [None, "assets/tamil.mp3", None, "tam (Tamil)"],
    # [None, "assets/burmese.mp3", None, "mya (Burmese)"],
]
ASR_NOTE = """
The above demo uses beam-search decoding with LM for English and greedy decoding results for all other languages.
Checkout the instructions [here](https://huggingface.co/facebook/mms-1b-all) on how to run LM decoding for other languages.
"""