Spaces:
Runtime error
Runtime error
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import nltk
|
2 |
+
import whisper
|
3 |
+
from pytube import YouTube
|
4 |
+
import streamlit as st
|
5 |
+
from sentence_transformers import SentenceTransformer, util
|
6 |
+
|
7 |
+
# Download NLTK's 'punkt' resource; as a top-level statement this runs each
# time the script is executed.
# NOTE(review): nothing in the visible code calls into nltk after this line —
# confirm the download is still needed.
nltk.download('punkt')
|
8 |
+
|
9 |
+
|
10 |
+
@st.experimental_singleton
def init_sentence_model(embedding_model):
    """Construct the sentence-embedding model identified by ``embedding_model``.

    Args:
        embedding_model: Model name or path handed straight to
            ``SentenceTransformer``.

    Returns:
        The constructed ``SentenceTransformer`` instance. The function is
        wrapped in ``st.experimental_singleton``, so Streamlit decides when
        the body actually runs versus when a stored instance is reused.
    """
    sentence_model = SentenceTransformer(embedding_model)
    return sentence_model
|
13 |
+
|
14 |
+
@st.experimental_singleton
def init_whisper(whisper_size):
    """Load the Whisper speech-to-text model of the requested size.

    Args:
        whisper_size: Model identifier passed unchanged to
            ``whisper.load_model`` (the form in this script offers
            "small", "base" and "large").

    Returns:
        Whatever ``whisper.load_model`` returns for that identifier. The
        function is wrapped in ``st.experimental_singleton``, so Streamlit
        controls reuse of the loaded model.
    """
    loaded_model = whisper.load_model(whisper_size)
    return loaded_model
|
17 |
+
|
18 |
+
@st.experimental_memo
def inference(link):
    """Download a YouTube video's audio and transcribe it.

    Args:
        link: URL of the YouTube video, passed to ``pytube.YouTube``.

    Returns:
        The ``'segments'`` entry of the result produced by
        ``whisper_model.transcribe``.

    Raises:
        IndexError: if the video exposes no audio-only stream (the first
            match of the filtered stream list is taken by index).

    NOTE(review): this reads the module-level ``whisper_model``, which is
    only bound once the search flow below has started. The memo decorator
    receives only ``link`` as an argument, so presumably switching the
    Whisper model size does not trigger a new transcription for a link that
    was already processed — confirm against Streamlit's caching rules.
    """
    yt = YouTube(link)
    # Take the first audio-only stream and save it under a fixed filename.
    audio_stream = yt.streams.filter(only_audio=True)[0]
    path = audio_stream.download(filename="audio.mp4")
    # The previously constructed-but-unused DecodingOptions object was
    # removed; transcribe() was never given it, so behavior is unchanged.
    results = whisper_model.transcribe(path)
    return results['segments']
|
25 |
+
|
26 |
+
@st.experimental_memo
def get_embeddings(segments):
    """Encode the windowed transcript texts with the sentence model.

    Args:
        segments: Mapping whose ``"text"`` entry holds the strings to embed
            (the shape produced by ``format_segments`` in this file).

    Returns:
        The result of ``model.encode`` applied to ``segments["text"]``.
        ``model`` is the module-level sentence model bound by the script's
        search flow.
    """
    texts = segments["text"]
    return model.encode(texts)
|
29 |
+
|
30 |
+
def format_segments(segments, window=10):
    """Merge consecutive transcript segments into windows of ``window`` items.

    Each window's texts are joined with a single space and the window is
    stamped with the ``'start'`` value of its first segment. The final
    window may contain fewer than ``window`` segments.

    Args:
        segments: Sequence of mappings, each with ``'text'`` and ``'start'``
            keys.
        window: Number of consecutive segments merged into one result.
            Must be at least 1.

    Returns:
        A dict with two parallel lists: ``'text'`` (joined window texts) and
        ``'start'`` (start value of each window's first segment).

    Raises:
        ValueError: if ``window`` is smaller than 1.
    """
    if window < 1:
        raise ValueError("window must be a positive integer")

    new_segments = {'text': [], 'start': []}
    for i in range(0, len(segments), window):
        # Slice by the full window size; a hard-coded slice of 5 used to
        # drop every segment past the fifth in each window.
        chunk = segments[i:i + window]
        new_segments['text'].append(" ".join(seg['text'] for seg in chunk))
        new_segments['start'].append(chunk[0]['start'])

    return new_segments
|
36 |
+
|
37 |
+
# Input form: collects the video link and the model/search settings. The
# values are only delivered to the script when "Submit" is pressed.
with st.form("transcribe"):
    yt_link = st.text_input("Youtube link")
    whisper_size = st.selectbox("Whisper model size", ("small", "base", "large"))
    embedding_model = st.text_input("Embedding model name", value='all-mpnet-base-v2')
    top_k = st.number_input("Number of query results", value=5)
    window = st.number_input("Number of segments per result", value=10)

    transcribe_submit = st.form_submit_button("Submit")

# Remember in session state that the form was submitted at least once, so the
# search section below stays active on later reruns (e.g. when a query is typed).
if transcribe_submit and 'start_search' not in st.session_state:
    st.session_state.start_search = True

if 'start_search' in st.session_state:
    # `model` and `whisper_model` are module-level names that the functions
    # get_embeddings() and inference() read directly.
    model = init_sentence_model(embedding_model)
    whisper_model = init_whisper(whisper_size)

    # Transcribe, merge segments into windows, then embed each window.
    segments = inference(yt_link)
    segments = format_segments(segments, window)
    embeddings = get_embeddings(segments)

    query = st.text_input('Enter a query')

    if query:
        query_embedding = model.encode(query)
        results = util.semantic_search(query_embedding, embeddings, top_k=top_k)

        # Append the timestamp as a proper query parameter: use '?' when the
        # link has no query string yet (e.g. short youtu.be links), else '&'.
        separator = '&' if '?' in yt_link else '?'

        entries = []
        for result in results[0]:
            corpus_id = result['corpus_id']
            # Truncate the start time to whole seconds for the `t=` parameter.
            start_seconds = int(segments['start'][corpus_id])
            entries.append(
                segments['text'][corpus_id]
                + f"... [Watch at timestamp]({yt_link}{separator}t={start_seconds}s)"
            )

        st.markdown("\n\n".join(entries), unsafe_allow_html=True)
|