# Hugging Face Spaces page status when this file was captured: "Build error"
import pinecone
import streamlit as st

# Configure the page immediately after importing streamlit, ahead of the
# remaining imports and every other st.* call (Streamlit documents
# set_page_config as needing to be the first Streamlit command).
st.set_page_config(layout="wide")

import streamlit_scrollable_textbox as stx
import openai

# Project helpers: data/model loaders, retrieval, prompt building and the
# OpenAI completion wrapper.
from utils import (
    get_data,
    get_mpnet_embedding_model,
    get_sgpt_embedding_model,
    get_flan_t5_model,
    get_t5_model,
    save_key,
    retrieve_transcript,
    query_pinecone,
    format_query,
    sentence_id_combine,
    text_lookup,
    generate_prompt,
    gpt_model,
)

st.title("Abstractive Question Answering")
st.write(
    "The app uses the quarterly earnings call transcripts for 10 companies "
    "(Apple, AMD, Amazon, Cisco, Google, Microsoft, Nvidia, ASML, Intel, Micron) "
    "for the years 2016 to 2020."
)
# Two equal-width columns: the left one holds the question/filter inputs and
# the retrieved-text views, the right one shows the generated answer.
col1, col2 = st.columns([3, 3], gap="medium")

# Filter options offered to the user.
years_choice = ["2020", "2019", "2018", "2017", "2016", "All"]
ticker_choice = [
    "AAPL",
    "CSCO",
    "MSFT",
    "ASML",
    "NVDA",
    "GOOGL",
    "MU",
    "INTC",
    "AMZN",
    "AMD",
]

with col1:
    st.subheader("Question")
    query_text = st.text_input(
        "Input Query",
        value="What was discussed regarding Wearables revenue performance?",
    )
    year = st.selectbox("Year", years_choice)
    quarter = st.selectbox("Quarter", ["Q1", "Q2", "Q3", "Q4", "All"])
    participant_type = st.selectbox("Speaker", ["Company Speaker", "Analyst"])
    ticker = st.selectbox("Company", ticker_choice)
# Encoder (retriever embedding model) and decoder (answer generator) options.
encoder_models_choice = ["MPNET", "SGPT"]
decoder_models_choice = [
    "GPT3 - (text-davinci-003)",
    "T5",
    "FLAN-T5",
    "GPT-J",
]

with st.sidebar:
    st.subheader("Select Options:")
    # How many matches to request from the vector index (1-15).
    num_results = int(
        st.number_input(
            "Number of Results to query", min_value=1, max_value=15, value=6
        )
    )
    encoder_model = st.selectbox("Select Encoder Model", encoder_models_choice)
    decoder_model = st.selectbox("Select Decoder Model", decoder_models_choice)
# Per-encoder settings: (st.secrets key holding the Pinecone API key,
# Pinecone index name, loader for the matching embedding model).
# Only the selected encoder's secret is read and only its model is loaded.
_ENCODER_BACKENDS = {
    "MPNET": ("pinecone_mpnet", "week2-all-mpnet-base", get_mpnet_embedding_model),
    "SGPT": ("pinecone_sgpt", "week2-sgpt-125m", get_sgpt_embedding_model),
}

_encoder_backend = _ENCODER_BACKENDS.get(encoder_model)
if _encoder_backend is not None:
    _secret_name, pinecone_index_name, _load_retriever = _encoder_backend
    # Connect to the Pinecone environment with this encoder's key.
    pinecone.init(api_key=st.secrets[_secret_name], environment="us-east1-gcp")
    pinecone_index = pinecone.Index(pinecone_index_name)
    retriever_model = _load_retriever()
with st.sidebar:
    # Integer 0-10; used as the `lag` argument to sentence_id_combine when
    # building the context passages.
    window = int(
        st.number_input("Sentence Window Size", min_value=0, max_value=10, value=1)
    )
    # Passed to query_pinecone; also selects how the context passages are
    # built (see the 0.90 comparison applied to the query results).
    threshold = float(
        st.number_input(
            label="Similarity Score Threshold",
            value=0.25,
            step=0.05,
            format="%.2f",
        )
    )
# Retrieval: query the selected Pinecone index with the user's filters, turn
# the hits into context passages, and build the prompt from them.
data = get_data()
query_results = query_pinecone(
    query_text,
    num_results,
    retriever_model,
    pinecone_index,
    year,
    quarter,
    ticker,
    participant_type,
    threshold,
)

# Thresholds up to 0.90 go through sentence_id_combine (with lag=window);
# higher thresholds use format_query on the raw results.
# NOTE(review): the reason for the 0.90 cutoff is not visible here — confirm.
context_list = (
    sentence_id_combine(data, query_results, lag=window)
    if threshold <= 0.90
    else format_query(query_results)
)
prompt = generate_prompt(query_text, context_list)
# --- Answer generation for the selected decoder ---------------------------


def _summarize_contexts(summarizer, contexts, output_key="summary_text"):
    # Summarize each retrieved passage with `summarizer`, join the partial
    # summaries with ". ", then summarize that joined text once more and
    # return the result. Assumes (as the T5 / FLAN-T5 calls always have) that
    # the pipeline returns a list whose first item is a dict keyed by
    # `output_key`.
    partial_summaries = [summarizer(text)[0][output_key] for text in contexts]
    return summarizer(". ".join(partial_summaries))[0][output_key]


def _show_answer(answer):
    # Render `answer` under an "Answer:" heading in the right-hand column.
    with col2:
        st.subheader("Answer:")
        st.write(answer)


def _answer_with_summarizer(load_summarizer, contexts):
    # Load a summarization pipeline via `load_summarizer` and show the
    # two-stage summary of `contexts`. With nothing retrieved there is no
    # text to ground an answer in, so warn instead of summarizing "".
    if not contexts:
        with col2:
            st.warning("No retrieved text to answer from for the current selection.")
        return
    _show_answer(_summarize_contexts(load_summarizer(), contexts))


if decoder_model == "GPT3 - (text-davinci-003)":
    with col2:
        # The prompt is editable; gpt_model is only called after Submit.
        with st.form("my_form"):
            edited_prompt = st.text_area(label="Model Prompt", value=prompt, height=270)
            openai_key = st.text_input(
                "Enter OpenAI key",
                value="",
                type="password",
            )
            submitted = st.form_submit_button("Submit")
    if submitted:
        api_key = save_key(openai_key)
        openai.api_key = api_key
        generated_text = gpt_model(edited_prompt)
        _show_answer(generated_text)
elif decoder_model == "T5":
    _answer_with_summarizer(get_t5_model, context_list)
elif decoder_model == "FLAN-T5":
    _answer_with_summarizer(get_flan_t5_model, context_list)
elif decoder_model == "GPT-J":
    # Must match the sidebar label exactly; it previously compared against
    # "GPTJ", so this branch could never run.
    # get_gptj_model is not part of the top-level utils import, so import it
    # here and report it clearly if this utils module does not provide it.
    try:
        from utils import get_gptj_model
    except ImportError:
        with col2:
            st.error("The GPT-J decoder is not available in this deployment.")
    else:
        # NOTE(review): the "summary_text" key mirrors the T5 branches; a
        # GPT-J text-generation pipeline may return "generated_text" instead
        # — confirm against utils.get_gptj_model.
        _answer_with_summarizer(get_gptj_model, context_list)
# Left column: the passages used as context, then the full transcript for the
# selected year / quarter / company.
with col1:
    with st.expander("See Retrieved Text"):
        for passage in context_list:
            st.markdown(f"- {passage}")

    file_text = retrieve_transcript(data, year, quarter, ticker)
    with st.expander("See Transcript"):
        stx.scrollableTextbox(
            file_text,
            height=700,
            border=False,
            fontFamily="Helvetica",
        )