lfoppiano committed
Commit a2bcc71
1 Parent(s): d2299da

disable conversational memory with zephyr

Files changed (1)
  1. streamlit_app.py +12 -7
streamlit_app.py CHANGED
@@ -54,7 +54,7 @@ if 'uploaded' not in st.session_state:
     st.session_state['uploaded'] = False
 
 if 'memory' not in st.session_state:
-    st.session_state['memory'] = ConversationBufferWindowMemory(k=4)
+    st.session_state['memory'] = None
 
 if 'binary' not in st.session_state:
     st.session_state['binary'] = None
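For reference, the object removed from the global setup above, ConversationBufferWindowMemory(k=4), keeps only the four most recent exchanges. A minimal sketch of that behavior, assuming the classic langchain.memory API this app imports (the sketch itself is not part of the commit):

# Standalone illustration of the k=4 window, not code from streamlit_app.py.
from langchain.memory import ConversationBufferWindowMemory

memory = ConversationBufferWindowMemory(k=4)
for i in range(6):
    memory.save_context({"input": f"question {i}"}, {"output": f"answer {i}"})

# Only the last four exchanges (2 through 5) are returned; 0 and 1 were dropped.
print(memory.load_memory_variables({}))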
@@ -117,12 +117,14 @@ def clear_memory():
 def init_qa(model, api_key=None):
     ## For debug add: callbacks=[PromptLayerCallbackHandler(pl_tags=["langchain", "chatgpt", "document-qa"])])
     if model == 'chatgpt-3.5-turbo':
+        st.session_state['memory'] = ConversationBufferWindowMemory(k=4)
         if api_key:
             chat = ChatOpenAI(model_name="gpt-3.5-turbo",
                               temperature=0,
                               openai_api_key=api_key,
                               frequency_penalty=0.1)
             embeddings = OpenAIEmbeddings(openai_api_key=api_key)
+
         else:
             chat = ChatOpenAI(model_name="gpt-3.5-turbo",
                               temperature=0,
@@ -134,11 +136,13 @@ def init_qa(model, api_key=None):
                               model_kwargs={"temperature": 0.01, "max_length": 4096, "max_new_tokens": 2048})
         embeddings = HuggingFaceEmbeddings(
             model_name="all-MiniLM-L6-v2")
+        st.session_state['memory'] = ConversationBufferWindowMemory(k=4)
 
     elif model == 'zephyr-7b-beta':
         chat = HuggingFaceHub(repo_id="HuggingFaceH4/zephyr-7b-beta",
                               model_kwargs={"temperature": 0.01, "max_length": 4096, "max_new_tokens": 2048})
         embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
+        st.session_state['memory'] = None
     else:
         st.error("The model was not loaded properly. Try reloading. ")
         st.stop()
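Taken together, the two init_qa hunks above replace the single global memory with a per-model policy: the chatgpt-3.5-turbo branch and the preceding HuggingFace branch each get a 4-message window, while zephyr-7b-beta gets None. A hypothetical helper distilling that policy (memory_for is an illustrative name, not a function in the app):

from langchain.memory import ConversationBufferWindowMemory

def memory_for(model):
    # zephyr-7b-beta runs without conversational memory; the other
    # supported models keep a window of the four most recent exchanges.
    if model == 'zephyr-7b-beta':
        return None
    return ConversationBufferWindowMemory(k=4)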
@@ -255,7 +259,8 @@ with st.sidebar:
         'Reset chat memory.',
         key="reset-memory-button",
         on_click=clear_memory,
-        help="Clear the conversational memory. Currently implemented to retrain the 4 most recent messages.")
+        help="Clear the conversational memory. Currently implemented to retrain the 4 most recent messages.",
+        disabled=model in st.session_state['rqa'] and st.session_state['rqa'][model].memory is None)
 
 left_column, right_column = st.columns([1, 1])
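The new disabled= expression relies on short-circuit evaluation: until a pipeline for the selected model exists in st.session_state['rqa'], the membership test is False and .memory is never dereferenced, so the button stays enabled. A standalone sketch of the same predicate (reset_button_disabled and FakePipeline are hypothetical names for illustration):

def reset_button_disabled(model, rqa):
    # Mirrors the button logic: disabled only when the model's pipeline
    # exists and was built with memory turned off (e.g. zephyr-7b-beta).
    return model in rqa and rqa[model].memory is None

class FakePipeline:
    def __init__(self, memory):
        self.memory = memory

rqa = {'zephyr-7b-beta': FakePipeline(memory=None)}
assert reset_button_disabled('zephyr-7b-beta', rqa)         # memory off -> disabled
assert not reset_button_disabled('chatgpt-3.5-turbo', rqa)  # not built yet -> enabled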
 
@@ -267,8 +272,8 @@ with right_column:
         ":warning: Do not upload sensitive data. We **temporarily** store text from the uploaded PDF documents solely for the purpose of processing your request, and we **do not assume responsibility** for any subsequent use or handling of the data submitted to third parties LLMs.")
 
     uploaded_file = st.file_uploader("Upload an article",
-                                     type=("pdf", "txt"),
-                                     on_change=new_file,
+                                     type=("pdf", "txt"),
+                                     on_change=new_file,
                                      disabled=st.session_state['model'] is not None and st.session_state['model'] not in
                                               st.session_state['api_keys'],
                                      help="The full-text is extracted using Grobid. ")
@@ -335,8 +340,8 @@ if uploaded_file and not st.session_state.loaded_embeddings:
 
     st.session_state['doc_id'] = hash = st.session_state['rqa'][model].create_memory_embeddings(tmp_file.name,
                                                                                                 chunk_size=chunk_size,
-                                                                                                perc_overlap=0.1,
-                                                                                                include_biblio=True)
+                                                                                                perc_overlap=0.1,
+                                                                                                include_biblio=True)
     st.session_state['loaded_embeddings'] = True
     st.session_state.messages = []
 
@@ -389,7 +394,7 @@ with right_column:
     elif mode == "LLM":
         with st.spinner("Generating response..."):
             _, text_response = st.session_state['rqa'][model].query_document(question, st.session_state.doc_id,
-                                                                             context_size=context_size)
+                                                                             context_size=context_size)
 
     if not text_response:
         st.error("Something went wrong. Contact Luca Foppiano ([email protected]) to report the issue.")
 