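# Streamlit app for abstractive question answering over quarterly earnings-call transcripts.
# A retriever (MPNET or SGPT) pulls relevant passages from Pinecone, and a decoder
# (FLAN-T5, T5, or GPT-3) generates the answer from the retrieved context.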
import pinecone
import streamlit as st
import streamlit_scrollable_textbox as stx
import openai
from utils import (
    get_data,
    get_mpnet_embedding_model,
    get_sgpt_embedding_model,
    get_flan_t5_model,
    get_t5_model,
    save_key,
)
from utils import (
    retrieve_transcript,
    query_pinecone,
    format_query,
    sentence_id_combine,
    text_lookup,
    gpt3,
)
st.title("Abstractive Question Answering")
st.write(
    "The app uses the quarterly earnings call transcripts for 10 companies "
    "(Apple, AMD, Amazon, Cisco, Google, Microsoft, Nvidia, ASML, Intel, Micron) "
    "for the years 2016 to 2020."
)
query_text = st.text_input("Input Query", value="Who is the CEO of Apple?")
years_choice = ["2016", "2017", "2018", "2019", "2020"]
year = st.selectbox("Year", years_choice)
quarter = st.selectbox("Quarter", ["Q1", "Q2", "Q3", "Q4"])
ticker_choice = [
    "AAPL",
    "CSCO",
    "MSFT",
    "ASML",
    "NVDA",
    "GOOGL",
    "MU",
    "INTC",
    "AMZN",
    "AMD",
]
ticker = st.selectbox("Company", ticker_choice)
num_results = int(st.number_input("Number of Results to query", 1, 5, value=3))
# Choose encoder model
encoder_models_choice = ["SGPT", "MPNET"]
encoder_model = st.selectbox("Select Encoder Model", encoder_models_choice)
# Choose decoder model
decoder_models_choice = ["FLAN-T5", "T5", "GPT3 - (text-davinci-003)"]
decoder_model = st.selectbox("Select Decoder Model", decoder_models_choice)
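# Each encoder has its own Pinecone index; initialize the matching index and load its embedding model.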
if encoder_model == "MPNET":
    # Connect to pinecone environment
    pinecone.init(api_key=st.secrets["pinecone_mpnet"], environment="us-east1-gcp")
    pinecone_index_name = "week2-all-mpnet-base"
    pinecone_index = pinecone.Index(pinecone_index_name)
    retriever_model = get_mpnet_embedding_model()
elif encoder_model == "SGPT":
    # Connect to pinecone environment
    pinecone.init(api_key=st.secrets["pinecone_sgpt"], environment="us-east1-gcp")
    pinecone_index_name = "week2-sgpt-125m"
    pinecone_index = pinecone.Index(pinecone_index_name)
    retriever_model = get_sgpt_embedding_model()
window = int(st.number_input("Sentence Window Size", 0, 3, value=0))
threshold = float(
    st.number_input(
        label="Similarity Score Threshold", step=0.05, format="%.2f", value=0.55
    )
)
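# Load the transcript data and query Pinecone with the filters selected above.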
data = get_data()
query_results = query_pinecone(
    query_text,
    num_results,
    retriever_model,
    pinecone_index,
    year,
    quarter,
    ticker,
    threshold,
)
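# For lower thresholds, combine each match with neighbouring sentences (sentence window)
# for extra context; otherwise use the retrieved passages as-is.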
if threshold <= 0.60:
    context_list = sentence_id_combine(data, query_results, lag=window)
else:
    context_list = format_query(query_results)
st.subheader("Answer:")
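# Generate the answer with the selected decoder model.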
if decoder_model == "GPT3 - (text-davinci-003)":
    openai_key = st.text_input(
        "Enter OpenAI key",
        value=st.secrets["openai_key"],
        type="password",
    )
    api_key = save_key(openai_key)
    openai.api_key = api_key
    generated_text = gpt3(query_text, context_list)
    st.write(generated_text)
elif decoder_model == "T5":
t5_pipeline = get_t5_model()
output_text = []
for context_text in context_list:
output_text.append(t5_pipeline(context_text)[0]["summary_text"])
generated_text = ". ".join(output_text)
st.write(t5_pipeline(generated_text)[0]["summary_text"])
elif decoder_model == "FLAN-T5":
flan_t5_pipeline = get_flan_t5_model()
output_text = []
for context_text in context_list:
output_text.append(flan_t5_pipeline(context_text)[0]["summary_text"])
generated_text = ". ".join(output_text)
st.write(flan_t5_pipeline(generated_text)[0]["summary_text"])
with st.expander("See Retrieved Text"):
for context_text in context_list:
st.markdown(f"- {context_text}")
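# Show the full transcript for the selected company, year, and quarter.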
file_text = retrieve_transcript(data, year, quarter, ticker)
with st.expander("See Transcript"):
stx.scrollableTextbox(file_text, height=700, border=False, fontFamily="Helvetica")