Build error
Upload app.py
app.py CHANGED
@@ -59,11 +59,17 @@ def save_key(api_key):
     return api_key


-def query_pinecone(query, top_k, model, index):
+def query_pinecone(query, top_k, model, index, threshold=0.5):
     # generate embeddings for the query
     xq = model.encode([query]).tolist()
     # search pinecone index for context passage with the answer
     xc = index.query(xq, top_k=top_k, include_metadata=True)
+    # filter the context passages based on the score threshold
+    filtered_matches = []
+    for match in xc["matches"]:
+        if match["score"] >= threshold:
+            filtered_matches.append(match)
+    xc["matches"] = filtered_matches
     return xc


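The new threshold parameter prunes low-scoring matches from the Pinecone response before it is returned. Below is a minimal sketch of that filtering step in isolation, using a hard-coded dictionary shaped like a Pinecone query result with include_metadata=True; the ids, scores, and metadata text are illustrative values, not data from the app's index.

# Stand-in for the response returned by index.query(..., include_metadata=True).
response = {
    "matches": [
        {"id": "doc1", "score": 0.72, "metadata": {"text": "Tim Cook is the CEO of Apple."}},
        {"id": "doc2", "score": 0.41, "metadata": {"text": "Apple reported quarterly revenue."}},
    ]
}

threshold = 0.5
# Keep only the matches whose similarity score clears the threshold, as query_pinecone now does.
response["matches"] = [m for m in response["matches"] if m["score"] >= threshold]
print([m["id"] for m in response["matches"]])  # ['doc1']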
@@ -127,19 +133,19 @@ st.title("Abstractive Question Answering - APPL")

 query_text = st.text_input("Input Query", value="Who is the CEO of Apple?")

-num_results = int(st.number_input("Number of Results to query", 1, 5, value=…
+num_results = int(st.number_input("Number of Results to query", 1, 5, value=3))


 # Choose encoder model

-encoder_models_choice = ["…
+encoder_models_choice = ["SGPT", "MPNET"]

 encoder_model = st.selectbox("Select Encoder Model", encoder_models_choice)


 # Choose decoder model

-decoder_models_choice = ["GPT3 (QA_davinci)", "GPT3 (…
+decoder_models_choice = ["GPT3 (QA_davinci)", "GPT3 (summary_davinci)", "T5", "FLAN-T5"]

 decoder_model = st.selectbox("Select Decoder Model", decoder_models_choice)

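For readers unfamiliar with Streamlit's widget signatures, the positional 1 and 5 passed to st.number_input above are its min_value and max_value bounds. Written out with keyword arguments, the same widget looks like this; the snippet is only an illustration, not a further change to app.py.

import streamlit as st

# Equivalent to st.number_input("Number of Results to query", 1, 5, value=3)
num_results = int(
    st.number_input(
        label="Number of Results to query", min_value=1, max_value=5, value=3
    )
)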
@@ -163,23 +169,33 @@ elif encoder_model == "SGPT":
     retriever_model = get_sgpt_embedding_model()


-query_results = query_pinecone(query_text, num_results, retriever_model, pinecone_index)
-
 window = int(st.number_input("Sentence Window Size", 1, 3, value=1))

+threshold = float(
+    st.number_input(
+        label="Similarity Score Threshold", step=0.05, format="%.2f", value=0.55
+    )
+)
+
 data = get_data()

-
-
+query_results = query_pinecone(
+    query_text, num_results, retriever_model, pinecone_index, threshold
+)
+
+if threshold <= 0.65:
+    context_list = sentence_id_combine(data, query_results, lag=window)
+else:
+    context_list = format_query(query_results)


 st.subheader("Answer:")


-if decoder_model == "GPT3 (…
+if decoder_model == "GPT3 (summary_davinci)":
     openai_key = st.text_input(
         "Enter OpenAI key",
-        value="sk-…
+        value="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",
         type="password",
     )
     api_key = save_key(openai_key)
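The same threshold now also decides how the retrieved passages become context for the decoder: at or below 0.65 each match is widened by a sentence window via sentence_id_combine, otherwise the matches are used as formatted by format_query. Both helpers are defined elsewhere in app.py and do not appear in this diff, so the stubs below are hypothetical stand-ins (as is the build_context wrapper name), included only to make the branch runnable on mock query results.

# Hypothetical stand-ins for helpers defined elsewhere in app.py (not shown in this diff).
def sentence_id_combine(data, query_results, lag=1):
    # Placeholder: the real helper presumably widens each match with `lag` neighbouring sentences.
    return [m["metadata"]["text"] for m in query_results["matches"]]

def format_query(query_results):
    # Placeholder: the real helper presumably returns the matched passages as-is.
    return [m["metadata"]["text"] for m in query_results["matches"]]

def build_context(data, query_results, threshold, window):
    # Same branch as in the hunk above: looser thresholds get the sentence window, stricter ones do not.
    if threshold <= 0.65:
        return sentence_id_combine(data, query_results, lag=window)
    return format_query(query_results)

mock_results = {"matches": [{"score": 0.70, "metadata": {"text": "Tim Cook is the CEO of Apple."}}]}
print(build_context(data=None, query_results=mock_results, threshold=0.55, window=1))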
@@ -193,7 +209,7 @@ if decoder_model == "GPT3 (text_davinci)":
 elif decoder_model == "GPT3 (QA_davinci)":
     openai_key = st.text_input(
         "Enter OpenAI key",
-        value="sk-…
+        value="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",
         type="password",
     )
     api_key = save_key(openai_key)
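Both GPT-3 branches pass the entered key through save_key, of which only the closing return api_key line is visible in the first hunk's context. Below is a minimal sketch of what such a helper typically looks like, assuming it registers the key with the pre-1.0 openai client before handing it back; the body is an assumption, not the app's actual implementation.

import openai

def save_key(api_key):
    # Assumed body: make the key available to later Completion calls, then return it unchanged.
    openai.api_key = api_key
    return api_key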
|