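# Gradio demo comparing Ukrainian speech-to-text models: wav2vec2
# (robinhad/wav2vec2-xls-r-300m-uk) and Coqui STT, the latter with and
# without a KenLM language-model scorer.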
from io import BytesIO
from typing import Tuple
import wave
import gradio as gr
import numpy as np
from pydub.audio_segment import AudioSegment
import requests
from os.path import exists
from stt import Model
from datetime import datetime
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import torch
# download model
version = "v0.4"
storage_url = f"https://github.com/robinhad/voice-recognition-ua/releases/download/{version}"
model_name = "uk.tflite"
scorer_name = "kenlm.scorer"
model_link = f"{storage_url}/{model_name}"
scorer_link = f"{storage_url}/{scorer_name}"
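# load the wav2vec2 model and processor from the Hugging Face Hub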
model = Wav2Vec2ForCTC.from_pretrained("robinhad/wav2vec2-xls-r-300m-uk")  # append .to("cuda") to run on GPU
processor = Wav2Vec2Processor.from_pretrained("robinhad/wav2vec2-xls-r-300m-uk")
# TODO: download config.json, pytorch_model.bin, preprocessor_config.json, tokenizer_config.json, vocab.json, added_tokens.json, special_tokens.json
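# fetch a release asset from GitHub unless it is already cached locally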
def download(url, file_name):
    if not exists(file_name):
        print(f"Downloading {file_name}")
        r = requests.get(url, allow_redirects=True)
        with open(file_name, 'wb') as file:
            file.write(r.content)
    else:
        print(f"Found {file_name}. Skipping download...")
def deepspeech(audio: np.array, use_scorer=False):
    ds = Model(model_name)
    if use_scorer:
        ds.enableExternalScorer(scorer_name)
    result = ds.stt(audio)
    return result
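# transcribe with wav2vec2: run the CTC model and greedy-decode the logits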
def wav2vec2(audio: np.array):
    input_dict = processor(audio, sampling_rate=16000, return_tensors="pt", padding=True)
    with torch.no_grad():
        output = model(input_dict.input_values.float())
    logits = output.logits
    pred_ids = torch.argmax(logits, dim=-1)[0]
    return processor.decode(pred_ids)
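# Gradio callback: takes (sample_rate, samples), normalizes the audio to
# 16 kHz mono 16-bit PCM, and returns one transcript per model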
def inference(audio: Tuple[int, np.array]):
    print("=============================")
    print(f"Time: {datetime.utcnow()}")
    output_audio = _convert_audio(audio[1], audio[0])
    fin = wave.open(output_audio, 'rb')
    audio = np.frombuffer(fin.readframes(fin.getnframes()), np.int16)
    fin.close()
    transcripts = []
    transcripts.append(wav2vec2(audio))
    print(f"Wav2Vec2: `{transcripts[-1]}`")
    transcripts.append(deepspeech(audio, use_scorer=True))
    print(f"Deepspeech with LM: `{transcripts[-1]}`")
    transcripts.append(deepspeech(audio))
    print(f"Deepspeech: `{transcripts[-1]}`")
    return tuple(transcripts)
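# clip the recording to two minutes and re-encode it as 16 kHz mono 16-bit PCM WAV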
def _convert_audio(audio_data: np.array, sample_rate: int):
    audio_limit = sample_rate * 60 * 2  # limit audio to 2 minutes max
    if audio_data.shape[0] > audio_limit:
        audio_data = audio_data[0:audio_limit]
    source_audio = BytesIO()
    source_audio.write(audio_data)
    source_audio.seek(0)
    output_audio = BytesIO()
    wav_file: AudioSegment = AudioSegment.from_raw(
        source_audio,
        channels=1,
        sample_width=audio_data.dtype.itemsize,
        frame_rate=sample_rate
    )
    wav_file.export(output_audio, "wav", codec="pcm_s16le", parameters=["-ar", "16k"])
    output_audio.seek(0)
    return output_audio
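# use the README, minus its YAML front matter, as the Gradio article text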
with open("README.md") as file:
article = file.read()
article = article[article.find("---\n", 4) + 5::]
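# demo UI: a single audio input and one output textbox per model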
iface = gr.Interface(
    fn=inference,
    inputs=[
        gr.inputs.Audio(type="numpy", label="Аудіо", optional=False),  # label: "Audio"
    ],
    outputs=[
        gr.outputs.Textbox(label="Wav2Vec2"),
        gr.outputs.Textbox(label="DeepSpeech with LM"),
        gr.outputs.Textbox(label="DeepSpeech"),
    ],
    title="🇺🇦 Ukrainian Speech-to-Text models",
    theme="huggingface",
    # i.e. "Ukrainian-language 🇺🇦 Speech-to-Text powered by Coqui STT"
    description="Україномовний🇺🇦 Speech-to-Text за допомогою Coqui STT",
    article=article,
)
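# fetch the Coqui STT model and scorer, then start the demo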
download(model_link, model_name)
download(scorer_link, scorer_name)
iface.launch()
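# A minimal sketch of invoking the pipeline without the UI (hypothetical
# file name; assumes a mono 16-bit recording loadable by pydub):
# seg = AudioSegment.from_file("sample.wav")
# samples = np.array(seg.get_array_of_samples(), dtype=np.int16)
# print(inference((seg.frame_rate, samples)))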