# Spaces:
# Build error
# Build error
import pinecone | |
import streamlit as st | |
import streamlit_scrollable_textbox as stx | |
import openai | |
from utils import ( | |
get_data, | |
get_mpnet_embedding_model, | |
get_sgpt_embedding_model, | |
get_flan_t5_model, | |
get_t5_model, | |
) | |
from utils import ( | |
retrieve_transcript, | |
query_pinecone, | |
format_query, | |
sentence_id_combine, | |
text_lookup, | |
gpt3_qa, | |
gpt3_summary, | |
) | |
# --- Page header ---------------------------------------------------------
st.title("Abstractive Question Answering")
st.write(
    "The app uses the quarterly earnings call transcripts for 10 companies (Apple, AMD, Amazon, Cisco, Google, Microsoft, Nvidia, ASML, Intel, Micron) for the years 2016 to 2020."
)

# --- Query and transcript filters ----------------------------------------
query_text = st.text_input("Input Query", value="Who is the CEO of Apple?")

# Years are offered as strings, 2016 through 2020 inclusive.
years_choice = [str(calendar_year) for calendar_year in range(2016, 2021)]
year = st.selectbox("Year", years_choice)

quarter = st.selectbox("Quarter", [f"Q{number}" for number in range(1, 5)])

ticker_choice = [
    "AAPL", "CSCO", "MSFT", "ASML", "NVDA",
    "GOOGL", "MU", "INTC", "AMZN", "AMD",
]
ticker = st.selectbox("Company", ticker_choice)

# The widget is bounded to 1..5 and defaults to 3; the value is coerced to int
# before it is handed to the query code.
num_results = int(
    st.number_input("Number of Results to query", min_value=1, max_value=5, value=3)
)

# --- Model selection -----------------------------------------------------
# Encoder: embeds the query for retrieval.
encoder_models_choice = ["SGPT", "MPNET"]
encoder_model = st.selectbox("Select Encoder Model", encoder_models_choice)

# Decoder: produces the final answer text from the retrieved context.
decoder_models_choice = [
    "FLAN-T5",
    "T5",
    "GPT3 (QA_davinci)",
    "GPT3 (summary_davinci)",
]
decoder_model = st.selectbox("Select Decoder Model", decoder_models_choice)
# Per-encoder settings: (name of the secret holding the Pinecone API key,
# Pinecone index name, loader for the matching embedding model).
_encoder_settings = {
    "MPNET": ("pinecone_mpnet", "week2-all-mpnet-base", get_mpnet_embedding_model),
    "SGPT": ("pinecone_sgpt", "week2-sgpt-125m", get_sgpt_embedding_model),
}

# An unrecognised encoder name is skipped silently, as with an if/elif chain
# that has no else branch.
if encoder_model in _encoder_settings:
    _secret_name, pinecone_index_name, _load_retriever = _encoder_settings[encoder_model]
    # Connect to pinecone environment
    pinecone.init(api_key=st.secrets[_secret_name], environment="us-east1-gcp")
    pinecone_index = pinecone.Index(pinecone_index_name)
    retriever_model = _load_retriever()
# Sentence window (0..3, default 0); passed as `lag` to sentence_id_combine below.
window = int(
    st.number_input("Sentence Window Size", min_value=0, max_value=3, value=0)
)

# Similarity threshold: shown with two decimals, adjustable in 0.05 steps.
_threshold_input = st.number_input(
    label="Similarity Score Threshold", step=0.05, format="%.2f", value=0.55
)
threshold = float(_threshold_input)

data = get_data()

# Retrieve matches for the query, restricted by the selected year, quarter and
# ticker (the filtering itself happens inside query_pinecone).
query_results = query_pinecone(
    query_text, num_results, retriever_model, pinecone_index,
    year, quarter, ticker, threshold,
)

# At thresholds up to 0.60 the contexts are built with sentence_id_combine
# (using the window above); above that the raw matches are only formatted.
context_list = (
    sentence_id_combine(data, query_results, lag=window)
    if threshold <= 0.60
    else format_query(query_results)
)
# --- Answer generation ---------------------------------------------------
# Every decoder follows the same two-pass scheme: run the model on each
# retrieved context separately, join the partial outputs with ". ", then run
# the model once more on the joined text and display that final result.
st.subheader("Answer:")

if decoder_model in ("GPT3 (summary_davinci)", "GPT3 (QA_davinci)"):
    # Both GPT-3 decoders need an OpenAI key; it is pre-filled from the app
    # secrets and can be overridden by the user (input is masked).
    openai_key = st.text_input(
        "Enter OpenAI key",
        value=st.secrets["openai_key"],
        type="password",
    )
    # The key is used directly. The previous `save_key(openai_key)` call
    # referenced a name that is neither imported nor defined in this file and
    # raised NameError as soon as a GPT-3 decoder was selected.
    api_key = openai_key
    openai.api_key = api_key

if decoder_model == "GPT3 (summary_davinci)":
    output_text = [gpt3_summary(context_text) for context_text in context_list]
    generated_text = ". ".join(output_text)
    st.write(gpt3_summary(generated_text))

elif decoder_model == "GPT3 (QA_davinci)":
    output_text = [gpt3_qa(query_text, context_text) for context_text in context_list]
    generated_text = ". ".join(output_text)
    st.write(gpt3_qa(query_text, generated_text))

elif decoder_model == "T5":
    t5_pipeline = get_t5_model()
    # The pipeline returns a list of dicts; the text is under "summary_text".
    output_text = [
        t5_pipeline(context_text)[0]["summary_text"] for context_text in context_list
    ]
    generated_text = ". ".join(output_text)
    st.write(t5_pipeline(generated_text)[0]["summary_text"])

elif decoder_model == "FLAN-T5":
    flan_t5_pipeline = get_flan_t5_model()
    # Same output shape as the T5 pipeline above.
    output_text = [
        flan_t5_pipeline(context_text)[0]["summary_text"]
        for context_text in context_list
    ]
    generated_text = ". ".join(output_text)
    st.write(flan_t5_pipeline(generated_text)[0]["summary_text"])
# Collapsible list of the retrieved contexts, one markdown bullet per entry.
with st.expander("See Retrieved Text"):
    for passage in context_list:
        st.markdown(f"- {passage}")

# Collapsible, scrollable view of the transcript looked up for the selected
# year, quarter and ticker.
file_text = retrieve_transcript(data, year, quarter, ticker)
with st.expander("See Transcript"):
    stx.scrollableTextbox(
        file_text,
        height=700,
        border=False,
        fontFamily="Helvetica",
    )