# whisperX / app.py
import json
import os
import uuid

import numpy as np
import streamlit as st
import torch
import whisperx as whisper
from pydub import AudioSegment

from constants import WHISPER_MODELS, language_dict
from utils import (
    translate_to_english,
    detect_language,
    write,
    read,
    get_key,
)

# Must be the first Streamlit call in the script.
st.set_page_config(page_title="Whisper-X", layout="wide")
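# Session-state flags that persist across reruns; the disable_* callbacks below
# use them to (optionally) make the two uploaders mutually exclusive.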
if "btn1" not in st.session_state:
st.session_state["btn1"] = False
if "btn2" not in st.session_state:
st.session_state["btn2"] = False
class ByteEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, bytes):
return obj.hex()
return json.JSONEncoder.default(self, obj)
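# Callbacks (currently unused; see the commented-out uploader arguments below)
# that would disable one uploader once the other receives a file.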
def disable_btn2():
st.session_state["btn2"] = True
def disable_btn1():
st.session_state["btn1"] = True
# whisperx expects a torch device string: "cuda" when a GPU is available, else "cpu".
device = "cuda" if torch.cuda.is_available() else "cpu"
input_col, output_col = st.columns(2, gap="medium")
with input_col:
st.header("Input")
    # Bundled sample clip, shown as a preview and reused as the fallback input.
    with open("audio.wav", "rb") as sample_file:
        audio_bytes = sample_file.read()
    st.audio(audio_bytes, format="audio/wav")
# st.markdown("""**your audio file**""", unsafe_allow_html=True)
audio_uploaded = st.file_uploader(
label="Upload your file",
type=["mp3", "wav"],
help="Your input file",
# on_change=disable_btn2,
# disabled=st.session_state["btn1"],
)
text_json = st.file_uploader(
label="Aligned JSON",
type=["json"],
help="Your aligned json file (Only if you need to skip transcribe)",
# disabled=st.session_state["btn2"],
# on_change=disable_btn1,
)
# text_json = None
# st.markdown("""**model**""", unsafe_allow_html=True)
model_name = st.selectbox(
label="Choose your model",
options=WHISPER_MODELS,
help="Choose a Whisper model.",
)
model_name = "base" if model_name == "" else model_name
# st.markdown("**transcription**", unsafe_allow_html=True)
transcription = st.selectbox(
"transcription",
options=["plain text", "srt", "vtt", "ass", "tsv"],
help="Choose the format for the transcription",
)
translate = st.checkbox(
"translate", help="Translate the text to English when set to True"
)
    language = st.selectbox(
        label="language",
        options=list(language_dict.keys()) + list(language_dict.values()),
        help="Language spoken in the audio; leave empty to auto-detect",
    )
patience = st.number_input(
label="patience",
step=0.01,
value=1.0,
help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search",
)
temperature = st.number_input(
label="temperature",
step=0.01,
value=1.0,
help="temperature to use for sampling",
)
suppress_tokens = st.text_input(
"suppress_tokens",
value="-1",
help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations",
)
initial_prompt = st.text_area(
label="initial_prompt",
help="optional text to provide as a prompt for the first window.",
)
condition_on_previous_text = st.checkbox(
"condition_on_previous_text",
help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop",
)
temperature_increment_on_fallback = st.number_input(
label="temperature_increment_on_fallback",
step=0.01,
value=0.2,
help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below",
)
compression_ratio_threshold = st.number_input(
label="compression_ratio_threshold",
value=2.4,
step=0.01,
help="if the gzip compression ratio is higher than this value, treat the decoding as failed",
)
logprob_threshold = st.number_input(
label="logprob_threshold",
value=-1.0,
step=0.01,
help="if the average log probability is lower than this value, treat the decoding as failed",
)
no_speech_threshold = st.number_input(
label="no_speech_threshold",
value=0.6,
step=0.01,
help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence",
)
    # Build the fallback temperature schedule, mirroring the whisper CLI.
    # st.number_input never returns None, so guard on a positive step instead
    # (a zero step would make np.arange raise).
    if temperature_increment_on_fallback > 0:
        temperature = tuple(
            np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback)
        )
    else:
        temperature = [temperature]
submit = st.button("Submit", type="primary")
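# Empty placeholders so results can be rendered into fixed slots once processing finishes.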
with output_col:
st.header("Output")
segments_pre = st.empty()
segments_post = st.empty()
segments_post_json = st.empty()
segments_post2 = st.empty()
trans = st.empty()
lang = st.empty()
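# Unique basename for the temporary WAV (and JSON) files created on each run.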
name = str(uuid.uuid1())
if submit:
    if audio_uploaded is None:
        # No upload: fall back to the bundled sample clip.
        audio_uploaded = open("audio.wav", "rb")
    # Normalize the input to a WAV file on disk for whisperx.
    if audio_uploaded.name.endswith(".wav"):
        temp = AudioSegment.from_wav(audio_uploaded)
        temp.export(f"{name}.wav", format="wav")
    if audio_uploaded.name.endswith(".mp3"):
        # Some "mp3" uploads are really mp4/aac containers, so retry on failure.
        try:
            temp = AudioSegment.from_file(audio_uploaded, format="mp3")
        except Exception:
            temp = AudioSegment.from_file(audio_uploaded, format="mp4")
        temp.export(f"{name}.wav", format="wav")
if language == "":
model = whisper.load_model(model_name)
with st.spinner("Detecting language..."):
detection = detect_language(f"{name}.wav", model)
language = detection.get("detected_language")
del model
    # Map a full language name back to its two-letter code.
    if len(language) > 2:
        language = get_key(language)
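    # Branch 1: no aligned JSON was uploaded, so transcribe first, then align.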
if text_json is None:
with st.spinner("Running ... "):
decode = {"suppress_tokens": suppress_tokens, "beam_size": 5}
model = whisper.load_model(model_name)
with st.container():
with st.spinner(f"Running with {model_name} model"):
result = model.transcribe(
f"{name}.wav",
language=language,
patience=patience,
initial_prompt=initial_prompt,
condition_on_previous_text=condition_on_previous_text,
temperature=temperature,
compression_ratio_threshold=compression_ratio_threshold,
logprob_threshold=logprob_threshold,
no_speech_threshold=no_speech_threshold,
**decode,
)
if translate:
result = translate_to_english(result, json=False)
with open("transcription.json", "w") as f:
json.dump(result["segments"], f, indent=4, cls=ByteEncoder)
with st.spinner("Running alignment model ..."):
model_a, metadata = whisper.load_align_model(
language_code=result["language"], device=device
)
result_aligned = whisper.align(
result["segments"],
model_a,
metadata,
f"{name}.wav",
device=device,
)
write(
f"{name}.wav",
dtype=transcription,
result_aligned=result_aligned,
)
trans_text = read(f"{name}.wav", transcription)
trans.text_area(
"transcription", trans_text, height=None, max_chars=None, key=None
)
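                # Flatten the per-segment pandas DataFrames into JSON-friendly
                # dicts so they can be shown in the text areas below.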
char_segments = []
word_segments = []
for x in range(len(result_aligned["segments"])):
word_segments.append(
{
"word-segments": result_aligned["segments"][x][
"word-segments"
]
.fillna("")
.to_dict(orient="records")
}
)
char_segments.append(
{
"char-segments": result_aligned["segments"][x][
"char-segments"
]
.fillna("")
.to_dict(orient="records")
}
)
for x in range(len(result_aligned["segments"])):
result_aligned["segments"][x]["word-segments"] = word_segments[x]
result_aligned["segments"][x]["char-segments"] = char_segments[x]
segments_pre.text_area(
"Segments before alignment",
result["segments"],
height=None,
max_chars=None,
key=None,
)
segments_post.text_area(
"Word Segments after alignment",
result_aligned["word_segments"],
height=None,
max_chars=None,
key=None,
)
segments_post2.text_area(
"Segments after alignment",
result_aligned["segments"],
height=None,
max_chars=None,
key=None,
)
lang.text_input(
"detected language", language_dict.get(language), disabled=True
)
os.remove(f"{name}.wav")
if text_json is not None:
with st.spinner("Running ... "):
            # No transcription needed here, so no Whisper model is loaded;
            # persist the uploaded segments JSON to disk, then read it back.
            json_filename = str(uuid.uuid1())
            data = json.load(text_json)
            text_json.close()
            with open(f"{json_filename}.json", "w", encoding="utf-8") as outfile:
                json.dump(data, outfile)
            with open(f"{json_filename}.json", "r", encoding="utf-8") as f:
                cont = json.load(f)
with st.spinner("Running alignment model ..."):
model_a, metadata = whisper.load_align_model(
language_code=language, device=device
)
result_aligned = whisper.align(
cont,
model_a,
metadata,
f"{name}.wav",
device=device,
)
write(
f"{name}.wav",
dtype=transcription,
result_aligned=result_aligned,
)
trans_text = read(f"{name}.wav", transcription)
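                # Flatten the per-segment DataFrames, as in the transcription branch above.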
char_segments = []
word_segments = []
for x in range(len(result_aligned["segments"])):
word_segments.append(
{
"word-segments": result_aligned["segments"][x][
"word-segments"
]
.fillna("")
.to_dict(orient="records")
}
)
char_segments.append(
{
"char-segments": result_aligned["segments"][x][
"char-segments"
]
.fillna("")
.to_dict(orient="records")
}
)
for x in range(len(result_aligned["segments"])):
result_aligned["segments"][x]["word-segments"] = word_segments[x]
result_aligned["segments"][x]["char-segments"] = char_segments[x]
trans.text_area(
"transcription", trans_text, height=None, max_chars=None, key=None
)
segments_pre.text_area(
"Segments before alignment",
cont,
height=None,
max_chars=None,
key=None,
)
segments_post.text_area(
"Word Segments after alignment",
result_aligned["word_segments"],
height=None,
max_chars=None,
key=None,
)
                segments_post2.text_area(
                    "Segments after alignment",
                    result_aligned["segments"],
                    height=None,
                    max_chars=None,
                    key=None,
                )
lang.text_input(
"detected language", language_dict.get(language), disabled=True
)
os.remove(f"{name}.wav")
os.remove(f"{json_filname}.json")