# asr.py — speech-to-text for the Edmond98/sts Hugging Face Space,
# built on facebook/mms-1b-all (commit 10fc892).
import librosa
from transformers import Wav2Vec2ForCTC, AutoProcessor
import torch
import numpy as np
from pathlib import Path
import concurrent.futures
ASR_SAMPLING_RATE = 16_000
CHUNK_LENGTH_S = 60 # Increased to 60 seconds per chunk
MAX_CONCURRENT_CHUNKS = 4 # Adjust based on VRAM availability

# Map ISO 639-3 codes to display names, e.g. {"eng": "English"}
ASR_LANGUAGES = {}
with open("data/asr/all_langs.tsv") as f:
    for line in f:
        # Split on the first whitespace run so both tab- and
        # space-delimited rows parse correctly
        iso, name = line.split(None, 1)
        ASR_LANGUAGES[iso.strip()] = name.strip()

MODEL_ID = "facebook/mms-1b-all"
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)

def load_audio(audio_data):
    """Normalize any supported input into a float32 mono array at 16 kHz."""
    if isinstance(audio_data, tuple):
        # Gradio-style (sample_rate, int16 samples): scale into [-1.0, 1.0]
        sr, audio_samples = audio_data
        audio_samples = (audio_samples / 32768.0).astype(np.float32)
        if sr != ASR_SAMPLING_RATE:
            audio_samples = librosa.resample(
                audio_samples, orig_sr=sr, target_sr=ASR_SAMPLING_RATE
            )
    elif isinstance(audio_data, np.ndarray):
        # Assume the array is already float mono at the target rate
        audio_samples = audio_data
    elif isinstance(audio_data, str):
        # Treat a string as a file path; librosa resamples on load
        audio_samples = librosa.load(audio_data, sr=ASR_SAMPLING_RATE, mono=True)[0]
    else:
        raise ValueError(f"Invalid audio input type: {type(audio_data)}")
    return audio_samples

def process_chunk(chunk, device):
    """Run one CTC forward pass over a single audio chunk and decode it."""
    inputs = processor(chunk, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs).logits
    # Greedy CTC decoding: take the highest-scoring token at each frame
    ids = torch.argmax(outputs, dim=-1)[0]
    return processor.decode(ids)
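
# A quick, hypothetical self-check of process_chunk on one second of silence.
# This is a sketch only: it assumes the model's default adapter is loaded and
# will typically decode to empty text.
# sample = np.zeros(ASR_SAMPLING_RATE, dtype=np.float32)
# print(process_chunk(sample, torch.device("cpu")))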

def transcribe(audio_data=None, lang="eng (English)"):
    if audio_data is None or (isinstance(audio_data, np.ndarray) and audio_data.size == 0):
        return "<<ERROR: Empty Audio Input>>"
    try:
        audio_samples = load_audio(audio_data)
    except Exception as e:
        return f"<<ERROR: {str(e)}>>"

    # "eng (English)" -> "eng": point the tokenizer and adapter at that language
    lang_code = lang.split()[0]
    processor.tokenizer.set_target_lang(lang_code)
    model.load_adapter(lang_code)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Split the audio into fixed-length chunks and transcribe them concurrently
    chunk_length = int(CHUNK_LENGTH_S * ASR_SAMPLING_RATE)
    chunks = [audio_samples[i : i + chunk_length] for i in range(0, len(audio_samples), chunk_length)]

    with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_CONCURRENT_CHUNKS) as executor:
        # executor.map returns results in chunk order, so the transcript
        # pieces rejoin in the right sequence
        transcriptions = list(executor.map(lambda c: process_chunk(c, device), chunks))
    return " ".join(transcriptions)

# Example usage
ASR_EXAMPLES = [
    ["upload/english.mp3", "eng (English)"],
    # ["upload/tamil.mp3", "tam (Tamil)"],
    # ["upload/burmese.mp3", "mya (Burmese)"],
]
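
# Minimal usage sketch (an assumption, not part of the Space's UI wiring):
# transcribe() accepts a file path, a (sample_rate, samples) tuple, or a raw
# NumPy array, per load_audio above.
if __name__ == "__main__":
    # Assumes upload/english.mp3 from ASR_EXAMPLES exists locally
    print(transcribe(ASR_EXAMPLES[0][0], ASR_EXAMPLES[0][1]))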