Spaces:
Build error
Build error
Upload 2 files
Browse files
app.py
CHANGED
@@ -1,17 +1,22 @@
|
|
1 |
import openai
|
|
|
2 |
import streamlit_scrollable_textbox as stx
|
3 |
|
4 |
-
import pinecone
|
5 |
import streamlit as st
|
6 |
from utils import (
|
|
|
7 |
create_dense_embeddings,
|
8 |
create_sparse_embeddings,
|
|
|
9 |
format_query,
|
10 |
-
|
|
|
|
|
11 |
get_data,
|
12 |
get_flan_t5_model,
|
13 |
get_mpnet_embedding_model,
|
14 |
get_sgpt_embedding_model,
|
|
|
15 |
get_splade_sparse_embedding_model,
|
16 |
get_t5_model,
|
17 |
gpt_model,
|
@@ -24,7 +29,7 @@ from utils import (
|
|
24 |
text_lookup,
|
25 |
)
|
26 |
|
27 |
-
st.set_page_config(layout="wide")
|
28 |
|
29 |
|
30 |
st.title("Abstractive Question Answering")
|
@@ -36,21 +41,31 @@ st.write(
|
|
36 |
|
37 |
col1, col2 = st.columns([3, 3], gap="medium")
|
38 |
|
|
|
|
|
|
|
39 |
with col1:
|
40 |
st.subheader("Question")
|
41 |
query_text = st.text_input(
|
42 |
"Input Query",
|
43 |
-
value="What was discussed regarding Wearables revenue performance?",
|
44 |
)
|
45 |
|
|
|
|
|
|
|
|
|
|
|
46 |
with col1:
|
47 |
years_choice = ["2020", "2019", "2018", "2017", "2016", "All"]
|
48 |
|
49 |
with col1:
|
50 |
-
year = st.selectbox("Year", years_choice)
|
51 |
|
52 |
with col1:
|
53 |
-
quarter = st.selectbox(
|
|
|
|
|
54 |
|
55 |
with col1:
|
56 |
participant_type = st.selectbox("Speaker", ["Company Speaker", "Analyst"])
|
@@ -69,7 +84,7 @@ ticker_choice = [
|
|
69 |
]
|
70 |
|
71 |
with col1:
|
72 |
-
ticker = st.selectbox("Company", ticker_choice)
|
73 |
|
74 |
with st.sidebar:
|
75 |
st.subheader("Select Options:")
|
@@ -189,9 +204,8 @@ else:
|
|
189 |
context_list = format_query(query_results)
|
190 |
|
191 |
|
192 |
-
prompt = generate_prompt(query_text, context_list)
|
193 |
-
|
194 |
if decoder_model == "GPT3 - (text-davinci-003)":
|
|
|
195 |
with col2:
|
196 |
with st.form("my_form"):
|
197 |
edited_prompt = st.text_area(
|
@@ -208,29 +222,57 @@ if decoder_model == "GPT3 - (text-davinci-003)":
|
|
208 |
api_key = save_key(openai_key)
|
209 |
openai.api_key = api_key
|
210 |
generated_text = gpt_model(edited_prompt)
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
|
215 |
elif decoder_model == "T5":
|
|
|
216 |
t5_pipeline = get_t5_model()
|
217 |
output_text = []
|
218 |
-
for context_text in context_list:
|
219 |
-
output_text.append(t5_pipeline(context_text)[0]["summary_text"])
|
220 |
with col2:
|
221 |
-
st.
|
222 |
-
|
223 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
224 |
|
225 |
elif decoder_model == "FLAN-T5":
|
|
|
226 |
flan_t5_pipeline = get_flan_t5_model()
|
227 |
output_text = []
|
228 |
-
for context_text in context_list:
|
229 |
-
output_text.append(flan_t5_pipeline(context_text)[0]["summary_text"])
|
230 |
with col2:
|
231 |
-
st.
|
232 |
-
|
233 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
234 |
|
235 |
with col1:
|
236 |
with st.expander("See Retrieved Text"):
|
|
|
1 |
import openai
|
2 |
+
import pinecone
|
3 |
import streamlit_scrollable_textbox as stx
|
4 |
|
|
|
5 |
import streamlit as st
|
6 |
from utils import (
|
7 |
+
clean_entities,
|
8 |
create_dense_embeddings,
|
9 |
create_sparse_embeddings,
|
10 |
+
extract_entities,
|
11 |
format_query,
|
12 |
+
generate_flant5_prompt,
|
13 |
+
generate_gpt_prompt,
|
14 |
+
get_context_list_prompt,
|
15 |
get_data,
|
16 |
get_flan_t5_model,
|
17 |
get_mpnet_embedding_model,
|
18 |
get_sgpt_embedding_model,
|
19 |
+
get_spacy_model,
|
20 |
get_splade_sparse_embedding_model,
|
21 |
get_t5_model,
|
22 |
gpt_model,
|
|
|
29 |
text_lookup,
|
30 |
)
|
31 |
|
32 |
+
st.set_page_config(layout="wide") # isort: skip
|
33 |
|
34 |
|
35 |
st.title("Abstractive Question Answering")
|
|
|
41 |
|
42 |
col1, col2 = st.columns([3, 3], gap="medium")
|
43 |
|
44 |
+
|
45 |
+
spacy_model = get_spacy_model()
|
46 |
+
|
47 |
with col1:
|
48 |
st.subheader("Question")
|
49 |
query_text = st.text_input(
|
50 |
"Input Query",
|
51 |
+
value="What was discussed regarding Wearables revenue performance in Q1 2020?",
|
52 |
)
|
53 |
|
54 |
+
company_ent, quarter_ent, year_ent = extract_entities(query_text, spacy_model)
|
55 |
+
ticker_index, quarter_index, year_index = clean_entities(
|
56 |
+
company_ent, quarter_ent, year_ent
|
57 |
+
)
|
58 |
+
|
59 |
with col1:
|
60 |
years_choice = ["2020", "2019", "2018", "2017", "2016", "All"]
|
61 |
|
62 |
with col1:
|
63 |
+
year = st.selectbox("Year", years_choice, index=year_index)
|
64 |
|
65 |
with col1:
|
66 |
+
quarter = st.selectbox(
|
67 |
+
"Quarter", ["Q1", "Q2", "Q3", "Q4", "All"], index=quarter_index
|
68 |
+
)
|
69 |
|
70 |
with col1:
|
71 |
participant_type = st.selectbox("Speaker", ["Company Speaker", "Analyst"])
|
|
|
84 |
]
|
85 |
|
86 |
with col1:
|
87 |
+
ticker = st.selectbox("Company", ticker_choice, ticker_index)
|
88 |
|
89 |
with st.sidebar:
|
90 |
st.subheader("Select Options:")
|
|
|
204 |
context_list = format_query(query_results)
|
205 |
|
206 |
|
|
|
|
|
207 |
if decoder_model == "GPT3 - (text-davinci-003)":
|
208 |
+
prompt = generate_gpt_prompt(query_text, context_list)
|
209 |
with col2:
|
210 |
with st.form("my_form"):
|
211 |
edited_prompt = st.text_area(
|
|
|
222 |
api_key = save_key(openai_key)
|
223 |
openai.api_key = api_key
|
224 |
generated_text = gpt_model(edited_prompt)
|
225 |
+
st.subheader("Answer:")
|
226 |
+
st.write(generated_text)
|
227 |
+
|
228 |
|
229 |
elif decoder_model == "T5":
|
230 |
+
prompt = generate_flant5_prompt(query_text, context_list)
|
231 |
t5_pipeline = get_t5_model()
|
232 |
output_text = []
|
|
|
|
|
233 |
with col2:
|
234 |
+
with st.form("my_form"):
|
235 |
+
edited_prompt = st.text_area(
|
236 |
+
label="Model Prompt", value=prompt, height=270
|
237 |
+
)
|
238 |
+
context_list = get_context_list_prompt(edited_prompt)
|
239 |
+
submitted = st.form_submit_button("Submit")
|
240 |
+
if submitted:
|
241 |
+
for context_text in context_list:
|
242 |
+
output_text.append(
|
243 |
+
t5_pipeline(context_text)[0]["summary_text"]
|
244 |
+
)
|
245 |
+
st.subheader("Answer:")
|
246 |
+
for text in output_text:
|
247 |
+
st.markdown(f"- {text}")
|
248 |
|
249 |
elif decoder_model == "FLAN-T5":
|
250 |
+
prompt = generate_flant5_prompt(query_text, context_list)
|
251 |
flan_t5_pipeline = get_flan_t5_model()
|
252 |
output_text = []
|
|
|
|
|
253 |
with col2:
|
254 |
+
with st.form("my_form"):
|
255 |
+
edited_prompt = st.text_area(
|
256 |
+
label="Model Prompt", value=prompt, height=270
|
257 |
+
)
|
258 |
+
context_list = get_context_list_prompt(edited_prompt)
|
259 |
+
submitted = st.form_submit_button("Submit")
|
260 |
+
if submitted:
|
261 |
+
for context_text in context_list:
|
262 |
+
output_text.append(
|
263 |
+
flan_t5_pipeline(
|
264 |
+
"Question:"
|
265 |
+
+ query_text
|
266 |
+
+ "\nContext:"
|
267 |
+
+ context_text
|
268 |
+
+ "\nAnswer?"
|
269 |
+
)[0]["summary_text"]
|
270 |
+
)
|
271 |
+
st.subheader("Answer:")
|
272 |
+
for text in output_text:
|
273 |
+
if "(iii)" not in text:
|
274 |
+
st.markdown(f"- {text}")
|
275 |
+
|
276 |
|
277 |
with col1:
|
278 |
with st.expander("See Retrieved Text"):
|
utils.py
CHANGED
@@ -1,5 +1,9 @@
|
|
|
|
|
|
1 |
import openai
|
2 |
import pandas as pd
|
|
|
|
|
3 |
import streamlit_scrollable_textbox as stx
|
4 |
import torch
|
5 |
from sentence_transformers import SentenceTransformer
|
@@ -11,7 +15,6 @@ from transformers import (
|
|
11 |
pipeline,
|
12 |
)
|
13 |
|
14 |
-
import pinecone
|
15 |
import streamlit as st
|
16 |
|
17 |
|
@@ -21,6 +24,14 @@ def get_data():
|
|
21 |
return data
|
22 |
|
23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
# Initialize models from HuggingFace
|
25 |
|
26 |
|
@@ -33,8 +44,8 @@ def get_t5_model():
|
|
33 |
def get_flan_t5_model():
|
34 |
return pipeline(
|
35 |
"summarization",
|
36 |
-
model="google/flan-t5-
|
37 |
-
tokenizer="google/flan-t5-
|
38 |
max_length=512,
|
39 |
# length_penalty = 0
|
40 |
)
|
@@ -320,7 +331,7 @@ def text_lookup(data, sentence_ids):
|
|
320 |
return context
|
321 |
|
322 |
|
323 |
-
def
|
324 |
context = " ".join(context_list)
|
325 |
prompt = f"""Answer the question in 6 long detailed points as accurately as possible using the provided context. Include as many key details as possible.
|
326 |
Context: {context}
|
@@ -329,7 +340,7 @@ Answer:"""
|
|
329 |
return prompt
|
330 |
|
331 |
|
332 |
-
def
|
333 |
context = " ".join(context_list)
|
334 |
prompt = f"""
|
335 |
Context information is below:
|
@@ -342,6 +353,24 @@ def generate_prompt_2(query_text, context_list):
|
|
342 |
return prompt
|
343 |
|
344 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
345 |
def gpt_model(prompt):
|
346 |
response = openai.Completion.create(
|
347 |
model="text-davinci-003",
|
@@ -355,6 +384,98 @@ def gpt_model(prompt):
|
|
355 |
return response.choices[0].text
|
356 |
|
357 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
358 |
# Transcript Retrieval
|
359 |
|
360 |
|
|
|
1 |
+
import re
|
2 |
+
|
3 |
import openai
|
4 |
import pandas as pd
|
5 |
+
import pinecone
|
6 |
+
import spacy
|
7 |
import streamlit_scrollable_textbox as stx
|
8 |
import torch
|
9 |
from sentence_transformers import SentenceTransformer
|
|
|
15 |
pipeline,
|
16 |
)
|
17 |
|
|
|
18 |
import streamlit as st
|
19 |
|
20 |
|
|
|
24 |
return data
|
25 |
|
26 |
|
27 |
+
# Initialize Spacy Model
|
28 |
+
|
29 |
+
|
30 |
+
@st.experimental_singleton
|
31 |
+
def get_spacy_model():
|
32 |
+
return spacy.load("en_core_web_sm")
|
33 |
+
|
34 |
+
|
35 |
# Initialize models from HuggingFace
|
36 |
|
37 |
|
|
|
44 |
def get_flan_t5_model():
|
45 |
return pipeline(
|
46 |
"summarization",
|
47 |
+
model="google/flan-t5-xl",
|
48 |
+
tokenizer="google/flan-t5-xl",
|
49 |
max_length=512,
|
50 |
# length_penalty = 0
|
51 |
)
|
|
|
331 |
return context
|
332 |
|
333 |
|
334 |
+
def generate_gpt_prompt(query_text, context_list):
|
335 |
context = " ".join(context_list)
|
336 |
prompt = f"""Answer the question in 6 long detailed points as accurately as possible using the provided context. Include as many key details as possible.
|
337 |
Context: {context}
|
|
|
340 |
return prompt
|
341 |
|
342 |
|
343 |
+
def generate_gpt_prompt_2(query_text, context_list):
|
344 |
context = " ".join(context_list)
|
345 |
prompt = f"""
|
346 |
Context information is below:
|
|
|
353 |
return prompt
|
354 |
|
355 |
|
356 |
+
def generate_flant5_prompt(query_text, context_list):
|
357 |
+
context = " \n".join(context_list)
|
358 |
+
prompt = f"""Given the context information and prior knowledge, answer this question:
|
359 |
+
{query_text}
|
360 |
+
Context information is below:
|
361 |
+
---------------------
|
362 |
+
{context}
|
363 |
+
---------------------"""
|
364 |
+
return prompt
|
365 |
+
|
366 |
+
|
367 |
+
def get_context_list_prompt(prompt):
|
368 |
+
prompt_list = prompt.split("---------------------")
|
369 |
+
context = prompt_list[-2].strip()
|
370 |
+
context_list = context.split(" \n")
|
371 |
+
return context_list
|
372 |
+
|
373 |
+
|
374 |
def gpt_model(prompt):
|
375 |
response = openai.Completion.create(
|
376 |
model="text-davinci-003",
|
|
|
384 |
return response.choices[0].text
|
385 |
|
386 |
|
387 |
+
# Entity Extraction
|
388 |
+
|
389 |
+
|
390 |
+
def extract_quarter_year(string):
|
391 |
+
# Extract year from string
|
392 |
+
year_match = re.search(r"\d{4}", string)
|
393 |
+
if year_match:
|
394 |
+
year = year_match.group()
|
395 |
+
else:
|
396 |
+
return None, None
|
397 |
+
|
398 |
+
# Extract quarter from string
|
399 |
+
quarter_match = re.search(r"Q\d", string)
|
400 |
+
if quarter_match:
|
401 |
+
quarter = "Q" + quarter_match.group()[1]
|
402 |
+
else:
|
403 |
+
return None, None
|
404 |
+
|
405 |
+
return quarter, year
|
406 |
+
|
407 |
+
|
408 |
+
def extract_entities(query, model):
|
409 |
+
doc = model(query)
|
410 |
+
entities = {ent.label_: ent.text for ent in doc.ents}
|
411 |
+
if "ORG" in entities.keys():
|
412 |
+
company = entities["ORG"].lower()
|
413 |
+
if "DATE" in entities.keys():
|
414 |
+
quarter, year = extract_quarter_year(entities["DATE"])
|
415 |
+
return company, quarter, year
|
416 |
+
else:
|
417 |
+
return company, None, None
|
418 |
+
else:
|
419 |
+
if "DATE" in entities.keys():
|
420 |
+
quarter, year = extract_quarter_year(entities["DATE"])
|
421 |
+
return None, quarter, year
|
422 |
+
else:
|
423 |
+
return None, None, None
|
424 |
+
|
425 |
+
|
426 |
+
def clean_entities(company, quarter, year):
|
427 |
+
company_ticker_map = {
|
428 |
+
"apple": "AAPL",
|
429 |
+
"amd": "AMD",
|
430 |
+
"amazon": "AMZN",
|
431 |
+
"cisco": "CSCO",
|
432 |
+
"google": "GOOGL",
|
433 |
+
"microsoft": "MSFT",
|
434 |
+
"nvidia": "NVDA",
|
435 |
+
"asml": "ASML",
|
436 |
+
"intel": "INTC",
|
437 |
+
"micron": "MU",
|
438 |
+
}
|
439 |
+
|
440 |
+
ticker_choice = [
|
441 |
+
"AAPL",
|
442 |
+
"CSCO",
|
443 |
+
"MSFT",
|
444 |
+
"ASML",
|
445 |
+
"NVDA",
|
446 |
+
"GOOGL",
|
447 |
+
"MU",
|
448 |
+
"INTC",
|
449 |
+
"AMZN",
|
450 |
+
"AMD",
|
451 |
+
]
|
452 |
+
year_choice = ["2020", "2019", "2018", "2017", "2016", "All"]
|
453 |
+
quarter_choice = ["Q1", "Q2", "Q3", "Q4", "All"]
|
454 |
+
if company is not None:
|
455 |
+
if company in company_ticker_map.keys():
|
456 |
+
ticker = company_ticker_map[company]
|
457 |
+
ticker_index = ticker_choice.index(ticker)
|
458 |
+
else:
|
459 |
+
ticker_index = 0
|
460 |
+
else:
|
461 |
+
ticker_index = 0
|
462 |
+
if quarter is not None:
|
463 |
+
if quarter in quarter_choice:
|
464 |
+
quarter_index = quarter_choice.index(quarter)
|
465 |
+
else:
|
466 |
+
quarter_index = len(quarter_choice) - 1
|
467 |
+
else:
|
468 |
+
quarter_index = len(quarter_choice) - 1
|
469 |
+
if year is not None:
|
470 |
+
if year in year_choice:
|
471 |
+
year_index = year_choice.index(year)
|
472 |
+
else:
|
473 |
+
year_index = len(year_choice) - 1
|
474 |
+
else:
|
475 |
+
year_index = len(year_choice) - 1
|
476 |
+
return ticker_index, quarter_index, year_index
|
477 |
+
|
478 |
+
|
479 |
# Transcript Retrieval
|
480 |
|
481 |
|