Spaces:
Running
Running
fixed rewritten query error
Browse files- app.py +18 -13
- finetune_backend.py +4 -2
app.py
CHANGED
@@ -462,15 +462,23 @@ def main():
|
|
462 |
# with st.spinner('Generating Response...'):
|
463 |
|
464 |
with col1:
|
465 |
-
|
466 |
-
|
467 |
-
|
468 |
-
|
469 |
-
|
470 |
-
|
471 |
-
|
472 |
-
new_query =
|
473 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
474 |
|
475 |
# we can arrive here only if a guest was selected
|
476 |
where_filter = WhereFilter(path=['guest'], operator='Equal', valueText=guest).todict() \
|
@@ -509,10 +517,7 @@ def main():
|
|
509 |
|
510 |
|
511 |
# I jump out of col1 to get all page width, so need to retest query
|
512 |
-
if query
|
513 |
-
show_query = st.toggle('Show rewritten query', True)
|
514 |
-
if show_query: # or reworded_query['changed']:
|
515 |
-
st.write(f"Rewritten query: {query}")
|
516 |
|
517 |
# creates container for LLM response to position it above search results
|
518 |
chat_container, response_box = [], st.empty()
|
|
|
462 |
# with st.spinner('Generating Response...'):
|
463 |
|
464 |
with col1:
|
465 |
+
|
466 |
+
use_reworded_query = st.toggle('Use rewritten query', True)
|
467 |
+
if use_reworded_query:
|
468 |
+
|
469 |
+
# let's use Llama2, and fall back on GPT3.5 if it fails
|
470 |
+
reworded_query = reword_query(query, guest,
|
471 |
+
model_name='llama2-13b-chat')
|
472 |
+
new_query = reworded_query['rewritten_question']
|
473 |
+
|
474 |
+
if reworded_query['status'] != 'error': # or reworded_query['changed']:
|
475 |
+
guest_lastname = guest.split(' ')[1]
|
476 |
+
if guest_lastname not in new_query:
|
477 |
+
# if the guest name is not in the rewritten question, we add it
|
478 |
+
new_query = f"About {guest}, " + new_query
|
479 |
+
|
480 |
+
query = new_query
|
481 |
+
st.write(f"Rewritten query: {query}")
|
482 |
|
483 |
# we can arrive here only if a guest was selected
|
484 |
where_filter = WhereFilter(path=['guest'], operator='Equal', valueText=guest).todict() \
|
|
|
517 |
|
518 |
|
519 |
# I jump out of col1 to get all page width, so need to retest query
|
520 |
+
if query:
|
|
|
|
|
|
|
521 |
|
522 |
# creates container for LLM response to position it above search results
|
523 |
chat_container, response_box = [], st.empty()
|
finetune_backend.py
CHANGED
@@ -16,7 +16,9 @@ valid_path = 'data/validation_data_100.json'
|
|
16 |
training_set = EmbeddingQAFinetuneDataset.from_json(training_path)
|
17 |
valid_set = EmbeddingQAFinetuneDataset.from_json(valid_path)
|
18 |
|
19 |
-
def finetune(model='all-mpnet-base-v2',
|
|
|
|
|
20 |
""" Finetunes a model on Modal GPU A100.
|
21 |
The model is saved in /root/models on a Modal volume
|
22 |
and can be stored locally.
|
@@ -34,7 +36,7 @@ def finetune(model='all-mpnet-base-v2', savemodel=False, outpath='.'):
|
|
34 |
model = model.replace('/','')
|
35 |
model = f"sentence-transformers/{model}"
|
36 |
|
37 |
-
fullpath = os.path.join(outpath, f"finetuned-{model}-300")
|
38 |
st.sidebar.write(f"Model will be saved in {fullpath}")
|
39 |
|
40 |
if os.path.exists(fullpath):
|
|
|
16 |
training_set = EmbeddingQAFinetuneDataset.from_json(training_path)
|
17 |
valid_set = EmbeddingQAFinetuneDataset.from_json(valid_path)
|
18 |
|
19 |
+
def finetune(model: str='sentence-transformers/all-mpnet-base-v2',
|
20 |
+
savemodel: bool =False,
|
21 |
+
outpath: str='.'):
|
22 |
""" Finetunes a model on Modal GPU A100.
|
23 |
The model is saved in /root/models on a Modal volume
|
24 |
and can be stored locally.
|
|
|
36 |
model = model.replace('/','')
|
37 |
model = f"sentence-transformers/{model}"
|
38 |
|
39 |
+
fullpath = os.path.join(outpath, f"finetuned-{model.strip('/')[-1]}-300")
|
40 |
st.sidebar.write(f"Model will be saved in {fullpath}")
|
41 |
|
42 |
if os.path.exists(fullpath):
|