"""Streamlit app: abstractive question answering over earnings-call transcripts.

The user enters a query and picks a year, quarter and company. They also pick
an encoder (which selects the embedding model and Pinecone index used for
retrieval) and a decoder (which turns the retrieved passages into an answer).
The retrieved passages and the full transcript are shown below the answer.
"""

import pinecone
import streamlit as st
import streamlit_scrollable_textbox as stx
import openai

from utils import (
    get_data,
    get_mpnet_embedding_model,
    get_sgpt_embedding_model,
    get_flan_t5_model,
    get_t5_model,
    save_key,
)
from utils import (
    retrieve_transcript,
    query_pinecone,
    format_query,
    sentence_id_combine,
    text_lookup,
    gpt3,
)

# Per-encoder settings: (st.secrets key holding the Pinecone API key,
# Pinecone index name, loader for the matching embedding model).
ENCODER_CONFIG = {
    "MPNET": ("pinecone_mpnet", "week2-all-mpnet-base", get_mpnet_embedding_model),
    "SGPT": ("pinecone_sgpt", "week2-sgpt-125m", get_sgpt_embedding_model),
}

# Label shown for the GPT-3 decoder. It is used both in the option list and in
# the dispatch below, so the two can never drift apart.
GPT3_DECODER = "GPT3 - (text-davinci-003)"


def _summarize_contexts(summarizer, contexts):
    """Return one summary built from every passage in ``contexts``.

    ``summarizer`` is a pipeline-style callable whose output is indexed as
    ``[0]["summary_text"]``. Each passage is summarised on its own, the partial
    summaries are joined with ". ", and that combined text is summarised once
    more.
    """
    partial_summaries = [summarizer(text)[0]["summary_text"] for text in contexts]
    combined = ". ".join(partial_summaries)
    return summarizer(combined)[0]["summary_text"]


st.title("Abstractive Question Answering")
st.write(
    "The app uses the quarterly earnings call transcripts for 10 companies (Apple, AMD, Amazon, Cisco, Google, Microsoft, Nvidia, ASML, Intel, Micron) for the years 2016 to 2020."
)

query_text = st.text_input("Input Query", value="Who is the CEO of Apple?")

years_choice = ["2016", "2017", "2018", "2019", "2020"]
year = st.selectbox("Year", years_choice)

quarter = st.selectbox("Quarter", ["Q1", "Q2", "Q3", "Q4"])

ticker_choice = [
    "AAPL",
    "CSCO",
    "MSFT",
    "ASML",
    "NVDA",
    "GOOGL",
    "MU",
    "INTC",
    "AMZN",
    "AMD",
]
ticker = st.selectbox("Company", ticker_choice)

num_results = int(st.number_input("Number of Results to query", 1, 5, value=3))

# Choose encoder model
encoder_models_choice = ["SGPT", "MPNET"]
encoder_model = st.selectbox("Select Encoder Model", encoder_models_choice)

# Choose decoder model
decoder_models_choice = ["FLAN-T5", "T5", GPT3_DECODER]
decoder_model = st.selectbox("Select Decoder Model", decoder_models_choice)

# Connect to the Pinecone index that matches the chosen encoder and load that
# encoder's embedding model.
pinecone_secret_name, pinecone_index_name, load_retriever = ENCODER_CONFIG[
    encoder_model
]
pinecone.init(api_key=st.secrets[pinecone_secret_name], environment="us-east1-gcp")
pinecone_index = pinecone.Index(pinecone_index_name)
retriever_model = load_retriever()

window = int(st.number_input("Sentence Window Size", 0, 3, value=0))

threshold = float(
    st.number_input(
        label="Similarity Score Threshold", step=0.05, format="%.2f", value=0.55
    )
)

data = get_data()

query_results = query_pinecone(
    query_text,
    num_results,
    retriever_model,
    pinecone_index,
    year,
    quarter,
    ticker,
    threshold,
)

# With threshold <= 0.60 the matches go through sentence_id_combine with
# lag=window (presumably widening each hit with neighbouring sentences).
# Above 0.60 they are only formatted, so the window setting is ignored.
if threshold <= 0.60:
    context_list = sentence_id_combine(data, query_results, lag=window)
else:
    context_list = format_query(query_results)

st.subheader("Answer:")

if not context_list:
    # Nothing was retrieved. Running a decoder on empty input would at best
    # produce a meaningless answer, so tell the user instead.
    st.warning("No passages matched the query above the similarity threshold.")
elif decoder_model == GPT3_DECODER:
    # The widget starts empty and the app's own key is used when it is left
    # blank. Pre-filling the widget with the secret would ship it to every
    # visitor's browser; type="password" only masks the display.
    openai_key = (
        st.text_input("Enter OpenAI key", value="", type="password")
        or st.secrets["openai_key"]
    )
    api_key = save_key(openai_key)
    openai.api_key = api_key
    # gpt3 receives the query plus the retrieved passages; show its answer as-is.
    generated_text = gpt3(query_text, context_list)
    st.write(generated_text)
elif decoder_model == "T5":
    t5_pipeline = get_t5_model()
    st.write(_summarize_contexts(t5_pipeline, context_list))
elif decoder_model == "FLAN-T5":
    flan_t5_pipeline = get_flan_t5_model()
    st.write(_summarize_contexts(flan_t5_pipeline, context_list))

with st.expander("See Retrieved Text"):
    for context_text in context_list:
        st.markdown(f"- {context_text}")

file_text = retrieve_transcript(data, year, quarter, ticker)

with st.expander("See Transcript"):
    stx.scrollableTextbox(file_text, height=700, border=False, fontFamily="Helvetica")