JPBianchi committed on
Commit 88b4a61
1 Parent(s): a6d45c6

fixed rewritten query error

Files changed (2)
  1. app.py +18 -13
  2. finetune_backend.py +4 -2
app.py CHANGED
@@ -462,15 +462,23 @@ def main():
     # with st.spinner('Generating Response...'):

     with col1:
-
-        # let's use Llama2 here
-        reworded_query = reword_query(query, guest,
-                                      model_name='llama2-13b-chat')
-        new_query = reworded_query['rewritten_question']
-        if guest.split(' ')[1] not in new_query and guest.split(' ')[0] not in new_query:
-            # if the guest name is not in the rewritten question, we add it
-            new_query = f"About {guest}, " + new_query
-        query = new_query
+
+        use_reworded_query = st.toggle('Use rewritten query', True)
+        if use_reworded_query:
+
+            # let's use Llama2, and fall back on GPT3.5 if it fails
+            reworded_query = reword_query(query, guest,
+                                          model_name='llama2-13b-chat')
+            new_query = reworded_query['rewritten_question']
+
+            if reworded_query['status'] != 'error': # or reworded_query['changed']:
+                guest_lastname = guest.split(' ')[1]
+                if guest_lastname not in new_query:
+                    # if the guest name is not in the rewritten question, we add it
+                    new_query = f"About {guest}, " + new_query
+
+                query = new_query
+                st.write(f"Rewritten query: {query}")

         # we can arrive here only if a guest was selected
         where_filter = WhereFilter(path=['guest'], operator='Equal', valueText=guest).todict() \
@@ -509,10 +517,7 @@ def main():


     # I jump out of col1 to get all page width, so need to retest query
-    if query is not None and reworded_query['status'] != 'error':
-        show_query = st.toggle('Show rewritten query', True)
-        if show_query: # or reworded_query['changed']:
-            st.write(f"Rewritten query: {query}")
+    if query:

     # creates container for LLM response to position it above search results
     chat_container, response_box = [], st.empty()
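
Note on the new guard in app.py: the toggled block assumes reword_query() returns a dict with at least 'status' and 'rewritten_question' keys (the commented-out check also hints at a 'changed' flag). The sketch below is a hypothetical minimal implementation of that contract, not the repository's actual helper; the prompt text and the injected llm callable are assumptions made for illustration. It shows how a failed rewrite can degrade to the original question instead of raising.

from typing import Callable, Dict, Optional

def reword_query(query: str,
                 guest: str,
                 model_name: str = 'llama2-13b-chat',
                 llm: Optional[Callable[[str, str], str]] = None) -> Dict[str, object]:
    """Hypothetical sketch: rewrite `query` to mention `guest`, falling back to the original on failure."""
    prompt = f"Rewrite this question about {guest} so it stands on its own: {query}"
    try:
        if llm is None:
            # placeholder for the real model call (Llama2, with a GPT-3.5 fallback in the app)
            raise RuntimeError("no LLM client configured")
        rewritten = llm(prompt, model_name)
        return {'status': 'success',
                'rewritten_question': rewritten,
                'changed': rewritten.strip() != query.strip()}
    except Exception:
        # Degrade gracefully: the caller keeps the original question and knows the rewrite failed.
        return {'status': 'error', 'rewritten_question': query, 'changed': False}

Under a contract like this, the toggled block above only overwrites query and displays the rewrite when the call succeeded, which is what the status check guards against.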
finetune_backend.py CHANGED
@@ -16,7 +16,9 @@ valid_path = 'data/validation_data_100.json'
 training_set = EmbeddingQAFinetuneDataset.from_json(training_path)
 valid_set = EmbeddingQAFinetuneDataset.from_json(valid_path)

-def finetune(model='all-mpnet-base-v2', savemodel=False, outpath='.'):
+def finetune(model: str='sentence-transformers/all-mpnet-base-v2',
+             savemodel: bool =False,
+             outpath: str='.'):
     """ Finetunes a model on Modal GPU A100.
     The model is saved in /root/models on a Modal volume
     and can be stored locally.
@@ -34,7 +36,7 @@ def finetune(model='all-mpnet-base-v2', savemodel=False, outpath='.'):
     model = model.replace('/','')
     model = f"sentence-transformers/{model}"

-    fullpath = os.path.join(outpath, f"finetuned-{model}-300")
+    fullpath = os.path.join(outpath, f"finetuned-{model.strip('/')[-1]}-300")
     st.sidebar.write(f"Model will be saved in {fullpath}")

     if os.path.exists(fullpath):
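
One detail worth flagging in the new fullpath line: str.strip('/') only removes slashes from the ends of the string, so indexing the result with [-1] yields the last character of the model id rather than its short name. If the intent is the final path component, str.split('/')[-1] would produce it. A minimal illustration (standard Python behaviour, not code from the repo):

model = 'sentence-transformers/all-mpnet-base-v2'

print(model.strip('/')[-1])   # '2'                 -> basename becomes 'finetuned-2-300'
print(model.split('/')[-1])   # 'all-mpnet-base-v2' -> 'finetuned-all-mpnet-base-v2-300'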