Spaces:
Build error
Build error
File size: 3,873 Bytes
8cd1f1e f9da573 8cd1f1e f9da573 a64e1a1 f9da573 8cd1f1e f9da573 a64e1a1 f9da573 8cd1f1e a7b0635 b19bb41 8cd1f1e c5f41e6 e514fa8 8cd1f1e e514fa8 8cd1f1e a64e1a1 8cd1f1e b19bb41 8cd1f1e b19bb41 8cd1f1e 0ba41da fbd690d e514fa8 fbd690d e514fa8 c5f41e6 e514fa8 0ba41da e514fa8 8cd1f1e e514fa8 8cd1f1e 12db858 8cd1f1e a64e1a1 8cd1f1e 40eb760 8cd1f1e 40eb760 8cd1f1e f9da573 8cd1f1e f9da573 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
import pinecone
import streamlit as st
import streamlit_scrollable_textbox as stx
import openai
from utils import (
get_data,
get_mpnet_embedding_model,
get_sgpt_embedding_model,
get_flan_t5_model,
get_t5_model,
save_key,
)
from utils import (
retrieve_transcript,
query_pinecone,
format_query,
sentence_id_combine,
text_lookup,
gpt3,
)
# --- Page header -------------------------------------------------------------
st.title("Abstractive Question Answering")
st.write(
    "The app uses the quarterly earnings call transcripts for 10 companies "
    "(Apple, AMD, Amazon, Cisco, Google, Microsoft, Nvidia, ASML, Intel, Micron) "
    "for the years 2016 to 2020."
)

# --- Question and transcript filters -----------------------------------------
query_text = st.text_input("Input Query", value="Who is the CEO of Apple?")

years_choice = ["2016", "2017", "2018", "2019", "2020"]
year = st.selectbox("Year", years_choice)

quarter = st.selectbox("Quarter", ["Q1", "Q2", "Q3", "Q4"])

ticker_choice = [
    "AAPL", "CSCO", "MSFT", "ASML", "NVDA",
    "GOOGL", "MU", "INTC", "AMZN", "AMD",
]
ticker = st.selectbox("Company", ticker_choice)

# The widget is limited to 1..5 and starts at 3; the value is cast to int
# before it is used further down the script.
num_results = int(
    st.number_input("Number of Results to query", min_value=1, max_value=5, value=3)
)

# --- Model selection ---------------------------------------------------------
# Encoder options; the first entry is the selectbox default.
encoder_models_choice = ["SGPT", "MPNET"]
encoder_model = st.selectbox("Select Encoder Model", encoder_models_choice)

# Decoder options; the first entry is the selectbox default.
decoder_models_choice = ["FLAN-T5", "T5", "GPT3 - (text-davinci-003)"]
decoder_model = st.selectbox("Select Decoder Model", decoder_models_choice)
# Per-encoder settings: the name of the Streamlit secret that holds the
# Pinecone API key, the Pinecone index name, and the loader for the embedding
# model. Loaders are stored uncalled so that only the selected one is run.
_encoder_settings = {
    "MPNET": ("pinecone_mpnet", "week2-all-mpnet-base", get_mpnet_embedding_model),
    "SGPT": ("pinecone_sgpt", "week2-sgpt-125m", get_sgpt_embedding_model),
}

# An encoder name missing from the table is skipped without any setup.
if encoder_model in _encoder_settings:
    _secret_name, pinecone_index_name, _load_retriever = _encoder_settings[
        encoder_model
    ]
    # Connect to pinecone environment
    pinecone.init(api_key=st.secrets[_secret_name], environment="us-east1-gcp")
    pinecone_index = pinecone.Index(pinecone_index_name)
    retriever_model = _load_retriever()
# Sentence window (0..3, default 0); handed to sentence_id_combine as `lag`.
window = int(
    st.number_input("Sentence Window Size", min_value=0, max_value=3, value=0)
)

# Similarity score threshold (default 0.55, adjusted in steps of 0.05); it is
# passed to query_pinecone and also selects how the context is built below.
threshold = float(
    st.number_input(
        label="Similarity Score Threshold",
        value=0.55,
        step=0.05,
        format="%.2f",
    )
)

data = get_data()

# Query the selected Pinecone index using the question, the chosen
# year/quarter/company and the threshold.
query_results = query_pinecone(
    query_text,
    num_results,
    retriever_model,
    pinecone_index,
    year,
    quarter,
    ticker,
    threshold,
)

# Thresholds up to and including 0.60 build the context with
# sentence_id_combine (which receives the window as `lag`); higher thresholds
# use format_query on the raw query results instead.
context_list = (
    sentence_id_combine(data, query_results, lag=window)
    if threshold <= 0.60
    else format_query(query_results)
)
st.subheader("Answer:")

# The decoder selectbox offers the GPT-3 option under the label
# "GPT3 - (text-davinci-003)". The previous comparison used a different
# string ("GPT3 (summary_davinci)"), so this branch could never run and
# selecting GPT-3 produced no answer. Matching on the "GPT3" prefix keeps the
# branch reachable; neither "T5" nor "FLAN-T5" starts with that prefix.
if decoder_model.startswith("GPT3"):
    # Let the user supply an OpenAI key; the field is pre-filled from the
    # app's secrets and masked as a password input.
    openai_key = st.text_input(
        "Enter OpenAI key",
        value=st.secrets["openai_key"],
        type="password",
    )
    api_key = save_key(openai_key)
    openai.api_key = api_key
    generated_text = gpt3(query_text, context_list)
    # NOTE(review): gpt3 is called a second time on its own output and with a
    # different number of arguments than the call above - confirm against
    # utils.gpt3 that this second pass is intended and that the signature
    # accepts a single argument.
    st.write(gpt3(generated_text))
elif decoder_model in ("T5", "FLAN-T5"):
    # Both T5 variants follow the same two-pass procedure and differ only in
    # which pipeline is loaded.
    summarizer = get_t5_model() if decoder_model == "T5" else get_flan_t5_model()
    # First pass: run the pipeline on every retrieved context individually.
    output_text = [
        summarizer(context_text)[0]["summary_text"] for context_text in context_list
    ]
    # Second pass: join the per-context outputs and run the pipeline once more
    # on the combined text to produce the displayed answer.
    generated_text = ". ".join(output_text)
    st.write(summarizer(generated_text)[0]["summary_text"])
# Collapsible list of the entries in context_list, one markdown bullet each.
retrieved_panel = st.expander("See Retrieved Text")
with retrieved_panel:
    for context_text in context_list:
        st.markdown(f"- {context_text}")

# Transcript text for the selected year, quarter and company.
file_text = retrieve_transcript(data, year, quarter, ticker)

# Collapsible, scrollable view of that transcript.
transcript_panel = st.expander("See Transcript")
with transcript_panel:
    stx.scrollableTextbox(
        file_text,
        height=700,
        border=False,
        fontFamily="Helvetica",
    )
|