ivnban27-ctl commited on
Commit
0f381ca
1 Parent(s): 39d66f7

Added MongoDB functionality

Browse files
.gitignore CHANGED
@@ -177,7 +177,7 @@ cython_debug/
177
 
178
  # Jupyter NB Checkpoints
179
  .ipynb_checkpoints/
180
-
181
  # exclude data from source control by default
182
  /data/
183
 
 
177
 
178
  # Jupyter NB Checkpoints
179
  .ipynb_checkpoints/
180
+ *.ipynb
181
  # exclude data from source control by default
182
  /data/
183
 
app_config.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ISSUES = ['Anxiety','Suicide']
2
+ SOURCES = ['OA_rolemodel', 'OA_finetuned']
3
+ SOURCES_LAB = {"OA_rolemodel":'OpenAI GPT3.5',
4
+ "OA_finetuned":'Finetuned OpenAI'}
5
+
6
+ def source2label(source):
7
+ return SOURCES_LAB[source]
8
+
9
+ ENVIRON = "dev"
10
+
11
+ DB_SCHEMA = 'prod_db' if ENVIRON == 'prod' else 'test_db'
12
+ DB_CONVOS = 'conversations'
13
+ DB_COMPLETIONS = 'comparison_completions'
14
+ DB_BATTLES = 'battles'
15
+ DB_ERRORS = 'completion_errors'
app_utils.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime as dt
2
+ import streamlit as st
3
+ from streamlit.logger import get_logger
4
+ import langchain
5
+ from langchain.memory import ConversationBufferMemory
6
+
7
+ from app_config import ENVIRON
8
+ from models.openai.finetuned_models import finetuned_models, get_finetuned_chain
9
+ from models.openai.role_models import get_role_chain, role_templates
10
+ from mongo_utils import new_convo
11
+
12
+ langchain.verbose = ENVIRON=="dev"
13
+ logger = get_logger(__name__)
14
+
15
+ def add_initial_message(model_name, memory):
16
+ if "Spanish" in model_name:
17
+ memory.chat_memory.add_ai_message("Hola necesito ayuda")
18
+ else:
19
+ memory.chat_memory.add_ai_message("Hi I need help")
20
+
21
+
22
+ def push_convo2db(memories, username, language):
23
+ if len(memories) == 1:
24
+ issue = memories['memory']['issue']
25
+ model_one = memories['memory']['source']
26
+ new_convo(st.session_state['db_client'], issue, language, username, False, model_one)
27
+ else:
28
+ issue = memories['commonMemory']['issue']
29
+ model_one = memories['memoryA']['source']
30
+ model_two = memories['memoryB']['source']
31
+ new_convo(st.session_state['db_client'], issue, language, username, True, model_one, model_two)
32
+
33
+ def change_memories(memories, username, language, changed_source=False):
34
+ for memory, params in memories.items():
35
+ if (memory not in st.session_state) or changed_source:
36
+ source = params['source']
37
+ logger.info(f"Source for memory {memory} is {source}")
38
+ if source in ('OA_rolemodel','OA_finetuned'):
39
+ st.session_state[memory] = ConversationBufferMemory(ai_prefix='texter', human_prefix='helper')
40
+
41
+ if ("convo_id" in st.session_state) and changed_source:
42
+ del st.session_state['convo_id']
43
+
44
+
45
+ def clear_memory(memories, username, language):
46
+ for memory, _ in memories.items():
47
+ st.session_state[memory].clear()
48
+
49
+ if "convo_id" in st.session_state:
50
+ del st.session_state['convo_id']
51
+
52
+
53
+ def create_memory_add_initial_message(memories, username, language, changed_source=False):
54
+ change_memories(memories, username, language, changed_source=changed_source)
55
+ for memory, _ in memories.items():
56
+ if len(st.session_state[memory].buffer_as_messages) < 1:
57
+ add_initial_message(language, st.session_state[memory])
58
+
59
+
60
+ def get_chain(issue, language, source, memory, temperature):
61
+ if source in ("OA_finetuned"):
62
+ OA_engine = finetuned_models[f"{issue}-{language}"]
63
+ return get_finetuned_chain(OA_engine, memory, temperature)
64
+ elif source in ('OA_rolemodel'):
65
+ template = role_templates[f"{issue}-{language}"]
66
+ return get_role_chain(template, memory, temperature)
convosim.py CHANGED
@@ -1,38 +1,53 @@
1
- import openai
2
  import os
3
  import streamlit as st
 
4
  from langchain.schema.messages import HumanMessage
 
 
 
5
 
6
- from utils import create_memory_add_initial_message, clear_memory, get_chain
7
-
8
  openai_api_key = os.environ['OPENAI_API_KEY']
9
- memories = ['memory']
 
 
 
 
 
10
 
11
  with st.sidebar:
 
12
  temperature = st.slider("Temperature", 0., 1., value=0.8, step=0.1)
13
- issue = st.selectbox("Select an Issue", ['Anxiety','Suicide'], index=0,
14
- on_change=clear_memory, args=(memories,)
15
  )
16
  supported_languages = ['English', "Spanish"] if issue == "Anxiety" else ['English']
17
  language = st.selectbox("Select a Language", supported_languages, index=0,
18
- on_change=clear_memory, args=(memories,)
19
  )
20
 
21
- source = st.selectbox("Select a source Model A", ['OpenAI GPT3.5','Finetuned OpenAI'], index=1,
22
- on_change=clear_memory, args=(memories,)
23
  )
24
 
25
- create_memory_add_initial_message(memories, language)
26
- llm_chain = get_chain(issue, language, source, st.session_state[memories[0]], temperature)
 
 
 
 
27
 
28
  st.title("💬 Simulator")
29
 
30
- for msg in st.session_state[memories[0]].buffer_as_messages:
31
  role = "user" if type(msg) == HumanMessage else "assistant"
32
  st.chat_message(role).write(msg.content)
33
 
34
  if prompt := st.chat_input():
 
 
 
35
  st.chat_message("user").write(prompt)
36
- response = llm_chain.predict(input=prompt, stop="helper:")
37
  # response = update_memory_completion(prompt, st.session_state["memory"], OA_engine, temperature)
38
  st.chat_message("assistant").write(response)
 
 
1
  import os
2
  import streamlit as st
3
+ from streamlit.logger import get_logger
4
  from langchain.schema.messages import HumanMessage
5
+ from mongo_utils import get_db_client
6
+ from app_utils import create_memory_add_initial_message, clear_memory, get_chain, push_convo2db
7
+ from app_config import ISSUES, SOURCES, source2label
8
 
9
+ logger = get_logger(__name__)
 
10
  openai_api_key = os.environ['OPENAI_API_KEY']
11
+ memories = {'memory':{"issue": ISSUES[0], "source": SOURCES[0]}}
12
+
13
+ if 'previous_source' not in st.session_state:
14
+ st.session_state['previous_source'] = SOURCES[0]
15
+ if 'db_client' not in st.session_state:
16
+ st.session_state["db_client"] = get_db_client()
17
 
18
  with st.sidebar:
19
+ username = st.text_input("Username", value='ivnban-ctl', max_chars=30)
20
  temperature = st.slider("Temperature", 0., 1., value=0.8, step=0.1)
21
+ issue = st.selectbox("Select an Issue", ISSUES, index=0,
22
+ on_change=clear_memory, kwargs={"memories":memories, "username":username, "language":"English"}
23
  )
24
  supported_languages = ['English', "Spanish"] if issue == "Anxiety" else ['English']
25
  language = st.selectbox("Select a Language", supported_languages, index=0,
26
+ on_change=clear_memory, kwargs={"memories":memories, "username":username, "language":"English"}
27
  )
28
 
29
+ source = st.selectbox("Select a source Model A", SOURCES, index=1,
30
+ format_func=source2label,
31
  )
32
 
33
+ memories = {'memory':{"issue":issue, "source":source}}
34
+ changed_source = st.session_state['previous_source'] != source
35
+ create_memory_add_initial_message(memories, username, language, changed_source=changed_source)
36
+ st.session_state['previous_source'] = source
37
+ memoryA = st.session_state[list(memories.keys())[0]]
38
+ llm_chain, stopper = get_chain(issue, language, source, memoryA, temperature)
39
 
40
  st.title("💬 Simulator")
41
 
42
+ for msg in memoryA.buffer_as_messages:
43
  role = "user" if type(msg) == HumanMessage else "assistant"
44
  st.chat_message(role).write(msg.content)
45
 
46
  if prompt := st.chat_input():
47
+ if 'convo_id' not in st.session_state:
48
+ push_convo2db(memories, username, language)
49
+
50
  st.chat_message("user").write(prompt)
51
+ response = llm_chain.predict(input=prompt, stop=stopper)
52
  # response = update_memory_completion(prompt, st.session_state["memory"], OA_engine, temperature)
53
  st.chat_message("assistant").write(response)
models/openai/finetuned_models.py CHANGED
@@ -1,9 +1,11 @@
1
- import openai
2
  from models.custom_parsers import CustomStringOutputParser
3
  from langchain.chains import LLMChain
4
  from langchain.llms import OpenAI
5
  from langchain.prompts import PromptTemplate
6
- import logging
 
 
7
 
8
  finetuned_models = {
9
  # "olivia_babbage_engine": "babbage:ft-crisis-text-line:exp-olivia-babbage-2023-02-23-19-57-19",
@@ -67,9 +69,8 @@ def get_finetuned_chain(model_name, memory, temperature=0.8):
67
  llm_chain = LLMChain(
68
  llm=llm,
69
  prompt=PROMPT,
70
- verbose=True,
71
  memory=memory,
72
  output_parser = CustomStringOutputParser()
73
  )
74
- logging.debug(f"loaded fine tuned model {model_name}")
75
- return llm_chain
 
1
+ # from streamlit.logger import get_logger
2
  from models.custom_parsers import CustomStringOutputParser
3
  from langchain.chains import LLMChain
4
  from langchain.llms import OpenAI
5
  from langchain.prompts import PromptTemplate
6
+
7
+ # logger = get_logger(__name__)
8
+ # logger.debug("START APP")
9
 
10
  finetuned_models = {
11
  # "olivia_babbage_engine": "babbage:ft-crisis-text-line:exp-olivia-babbage-2023-02-23-19-57-19",
 
69
  llm_chain = LLMChain(
70
  llm=llm,
71
  prompt=PROMPT,
 
72
  memory=memory,
73
  output_parser = CustomStringOutputParser()
74
  )
75
+ # logger.debug(f"{__name__}: loaded fine tuned model {model_name}")
76
+ return llm_chain, "helper:"
models/openai/role_models.py CHANGED
@@ -49,9 +49,8 @@ def get_role_chain(template, memory, temperature=0.8):
49
  llm_chain = ConversationChain(
50
  llm=llm,
51
  prompt=PROMPT,
52
- verbose=True,
53
  memory=memory,
54
  output_parser=CustomStringOutputParser()
55
  )
56
  logging.debug(f"loaded GPT3.5 model")
57
- return llm_chain
 
49
  llm_chain = ConversationChain(
50
  llm=llm,
51
  prompt=PROMPT,
 
52
  memory=memory,
53
  output_parser=CustomStringOutputParser()
54
  )
55
  logging.debug(f"loaded GPT3.5 model")
56
+ return llm_chain, "helper:"
mongo_utils.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import datetime as dt
3
+ import streamlit as st
4
+ from streamlit.logger import get_logger
5
+ from pymongo.mongo_client import MongoClient
6
+ from pymongo.server_api import ServerApi
7
+ from app_config import DB_SCHEMA, DB_COMPLETIONS, DB_CONVOS, DB_BATTLES, DB_ERRORS
8
+
9
+ DB_URL = os.environ['MONGO_URL']
10
+ DB_USR = os.environ['MONGO_USR']
11
+ DB_PWD = os.environ['MONGO_PWD']
12
+
13
+ logger = get_logger(__name__)
14
+
15
+ def get_db_client():
16
+ uri = f"mongodb+srv://{DB_USR}:{DB_PWD}@{DB_URL}/?retryWrites=true&w=majority"
17
+ # Create a new client and connect to the server
18
+ client = MongoClient(uri, server_api=ServerApi('1'))
19
+ # Send a ping to confirm a successful connection
20
+ try:
21
+ client.admin.command('ping')
22
+ logger.info(f"DBUTILS: Pinged your deployment. You successfully connected to MongoDB!")
23
+ return client
24
+ except Exception as e:
25
+ logger.error(e)
26
+
27
+ def new_convo(client, issue, language, username, is_comparison, model_one, model_two=None):
28
+ convo = {
29
+ "start_timestamp": dt.datetime.now(tz=dt.timezone.utc),
30
+ "issue": issue,
31
+ "language": language,
32
+ "username": username,
33
+ "is_comparison": is_comparison,
34
+ "model_one": model_one,
35
+ "model_two": model_two,
36
+ }
37
+
38
+ db = client[DB_SCHEMA]
39
+ convos = db[DB_CONVOS]
40
+ convo_id = convos.insert_one(convo).inserted_id
41
+ logger.info(f"DBUTILS: new convo id is {convo_id}")
42
+ st.session_state['convo_id'] = convo_id
43
+
44
+ def new_comparison(client, prompt_timestamp, completion_timestamp,
45
+ chat_history, prompt, completionA, completionB,
46
+ source="webapp", subset=None
47
+ ):
48
+ comparison = {
49
+ "prompt_timestamp": prompt_timestamp,
50
+ "completion_timestamp": completion_timestamp,
51
+ "source": source,
52
+ "subset": subset,
53
+ "model_one_args": {
54
+ 'temperature':0.8
55
+ },
56
+ "model_two_args": {
57
+ 'temperature':0.8
58
+ },
59
+ "convo_id": st.session_state['convo_id'],
60
+ "chat_history": chat_history,
61
+ "prompt": prompt,
62
+ "compeltion_model_one": completionA,
63
+ "compeltion_model_two": completionB,
64
+ }
65
+
66
+ db = client[DB_SCHEMA]
67
+ comparisons = db[DB_COMPLETIONS]
68
+ comparison_id = comparisons.insert_one(comparison).inserted_id
69
+ logger.info(f"DBUTILS: new comparison id is {comparison_id}")
70
+ st.session_state['comparison_id'] = comparison_id
71
+
72
+ def new_battle_result(client, comparison_id, convo_id, username, model_one, model_two, winner):
73
+ battle = {
74
+ "battle_timestamp": dt.datetime.now(tz=dt.timezone.utc),
75
+ "comparison_id": comparison_id,
76
+ "convo_id": convo_id,
77
+ "username": username,
78
+ "model_one": model_one,
79
+ "model_two": model_two,
80
+ "winner": winner,
81
+
82
+ }
83
+
84
+ db = client[DB_SCHEMA]
85
+ battles = db[DB_BATTLES]
86
+ battle_id = battles.insert_one(battle).inserted_id
87
+ logger.info(f"DBUTILS: new battle id is {battle_id}")
88
+
89
+ def new_completion_error(client, comparison_id, username, model):
90
+ error = {
91
+ "error_timestamp": dt.datetime.now(tz=dt.timezone.utc),
92
+ "comparison_id": comparison_id,
93
+ "username": username,
94
+ "model": model,
95
+ }
96
+
97
+ db = client[DB_SCHEMA]
98
+ errors = db[DB_ERRORS]
99
+ error_id = errors.insert_one(error).inserted_id
100
+ logger.info(f"DBUTILS: new error id is {error_id}")
101
+
102
+ def get_non_assesed_comparison(client, username):
103
+ from bson.son import SON
104
+ pipeline = [
105
+ {'$lookup': {
106
+ 'from': DB_BATTLES,
107
+ 'localField': '_id',
108
+ 'foreignField': 'comparison_id',
109
+ "pipeline": [
110
+ {"$match": {"username":username}},
111
+ ],
112
+ 'as': 'battles'
113
+ }},
114
+ {'$lookup': {
115
+ 'from': DB_CONVOS,
116
+ 'localField': 'convo_id',
117
+ 'foreignField': '_id',
118
+ 'as': 'convo_info'
119
+ }},
120
+ {"$match":{
121
+ "battles": {"$size":0},
122
+ }},
123
+ {"$addFields": {
124
+ "priority": {
125
+ "$cond":[
126
+ {"$eq": ["$source","manual"]},
127
+ 1,
128
+ 0
129
+ ]
130
+ },
131
+ }},
132
+ {"$sort": SON([
133
+ ("priority", -1),
134
+ ("prompt_timestamp", 1),
135
+ ("convo_id", 1),
136
+ ])
137
+ },
138
+ {"$limit": 1}
139
+ ]
140
+
141
+ db = client[DB_SCHEMA]
142
+ return list(db[DB_COMPLETIONS].aggregate(pipeline))
143
+
pages/comparisor.py CHANGED
@@ -1,14 +1,27 @@
1
- import openai
2
  import os
 
 
3
  import streamlit as st
 
4
  from langchain.schema.messages import HumanMessage
5
- import logging
6
-
7
- from utils import create_memory_add_initial_message, clear_memory, get_chain
8
 
 
9
  openai_api_key = os.environ['OPENAI_API_KEY']
10
- memories = ['memoryA', 'memoryB', 'commonMemory']
11
-
 
 
 
 
 
 
 
 
 
12
 
13
  def delete_last_message(memory):
14
  last_prompt = memory.chat_memory.messages[-2].content
@@ -20,59 +33,91 @@ def replace_last_message(memory, new_message):
20
  memory.chat_memory.add_ai_message(new_message)
21
 
22
  def regenerateA():
23
- last_prompt = delete_last_message(st.session_state[memories[0]])
24
- new_response = llm_chainA.predict(input=last_prompt, stop="helper:")
25
  col1.chat_message("user").write(last_prompt)
26
  col1.chat_message("assistant").write(new_response)
 
27
 
28
  def regenerateB():
29
- last_prompt = delete_last_message(st.session_state[memories[1]])
30
- new_response = llm_chainB.predict(input=last_prompt, stop="helper:")
31
  col2.chat_message("user").write(last_prompt)
32
  col2.chat_message("assistant").write(new_response)
 
33
 
34
  def replaceA():
35
- last_prompt = st.session_state[memories[1]].chat_memory.messages[-2].content
36
- new_message = st.session_state[memories[1]].chat_memory.messages[-1].content
37
- replace_last_message(st.session_state[memories[0]], new_message)
38
  st.session_state['commonMemory'].save_context({"inputs":last_prompt}, {"outputs":new_message})
39
 
 
 
 
 
 
 
40
  def replaceB():
41
- last_prompt = st.session_state[memories[0]].chat_memory.messages[-2].content
42
- new_message = st.session_state[memories[0]].chat_memory.messages[-1].content
43
- replace_last_message(st.session_state[memories[1]], new_message)
44
  st.session_state['commonMemory'].save_context({"inputs":last_prompt}, {"outputs":new_message})
45
 
 
 
 
 
 
 
46
  def regenerateBoth():
47
- regenerateA()
48
- regenerateB()
 
 
 
 
 
 
 
 
 
 
49
 
50
  def bothGood():
51
- if len(st.session_state['memoryA'].buffer_as_messages) == 1:
52
  pass
53
  else:
54
- last_prompt = st.session_state[memories[0]].chat_memory.messages[-2].content
55
- last_reponse = st.session_state[memories[0]].chat_memory.messages[-1].content
 
56
  st.session_state['commonMemory'].save_context({"inputs":last_prompt}, {"outputs":last_reponse})
 
 
 
 
 
 
57
 
58
  with st.sidebar:
59
- issue = st.selectbox("Select an Issue", ['Anxiety','Suicide'], index=0,
60
- on_change=clear_memory, args=(memories,)
 
61
  )
62
  supported_languages = ['English', "Spanish"] if issue == "Anxiety" else ['English']
63
  language = st.selectbox("Select a Language", supported_languages, index=0,
64
- on_change=clear_memory, args=(memories,)
65
  )
66
 
67
  with st.expander("Model A"):
68
  temperatureA = st.slider("Temperature Model A", 0., 1., value=0.8, step=0.1)
69
- sourceA = st.selectbox("Select a source Model A", ['OpenAI GPT3.5','Finetuned OpenAI'], index=0,
70
- on_change=clear_memory, args=(memories,)
71
  )
72
  with st.expander("Model B"):
73
  temperatureB = st.slider("Temperature Model B", 0., 1., value=0.8, step=0.1)
74
- sourceB = st.selectbox("Select a source Model B", ['OpenAI GPT3.5','Finetuned OpenAI'], index=1,
75
- on_change=clear_memory, args=(memories,)
76
  )
77
 
78
  sbcol1, sbcol2 = st.columns(2)
@@ -86,9 +131,20 @@ with st.sidebar:
86
  # regenB = sbcol2.button("Regenerate B", on_click=regenerateB)
87
  clear = st.button("Clear History", on_click=clear_memory, args=(memories,))
88
 
89
- create_memory_add_initial_message(memories, language)
90
- llm_chainA = get_chain(issue, language, sourceA, st.session_state[memories[0]], temperatureA)
91
- llm_chainB = get_chain(issue, language, sourceB, st.session_state[memories[1]], temperatureB)
 
 
 
 
 
 
 
 
 
 
 
92
 
93
  st.title(f"💬 History")
94
  for msg in st.session_state['commonMemory'].buffer_as_messages:
@@ -115,12 +171,20 @@ def disable_chat():
115
  return True
116
 
117
  if prompt := st.chat_input(disabled=disable_chat()):
 
 
 
 
118
  col1.chat_message("user").write(prompt)
119
  col2.chat_message("user").write(prompt)
120
 
121
- responseA = llm_chainA.predict(input=prompt, stop="helper:")
122
- responseB = llm_chainB.predict(input=prompt, stop="helper:")
 
123
 
 
 
 
124
  col1.chat_message("assistant").write(responseA)
125
  col2.chat_message("assistant").write(responseB)
126
 
 
1
+
2
  import os
3
+ import random
4
+ import datetime as dt
5
  import streamlit as st
6
+ from streamlit.logger import get_logger
7
  from langchain.schema.messages import HumanMessage
8
+ from mongo_utils import get_db_client, new_comparison, new_battle_result
9
+ from app_utils import create_memory_add_initial_message, clear_memory, get_chain, push_convo2db
10
+ from app_config import ISSUES, SOURCES, source2label
11
 
12
+ logger = get_logger(__name__)
13
  openai_api_key = os.environ['OPENAI_API_KEY']
14
+ memories = {
15
+ 'memoryA': {"issue": ISSUES[0], "source": SOURCES[0]},
16
+ 'memoryB': {"issue": ISSUES[0], "source": SOURCES[1]},
17
+ 'commonMemory': {"issue": ISSUES[0], "source": SOURCES[0]}
18
+ }
19
+ if 'db_client' not in st.session_state:
20
+ st.session_state["db_client"] = get_db_client()
21
+ if 'previous_sourceA' not in st.session_state:
22
+ st.session_state['previous_sourceA'] = SOURCES[0]
23
+ if 'previous_sourceB' not in st.session_state:
24
+ st.session_state['previous_sourceB'] = SOURCES[1]
25
 
26
  def delete_last_message(memory):
27
  last_prompt = memory.chat_memory.messages[-2].content
 
33
  memory.chat_memory.add_ai_message(new_message)
34
 
35
  def regenerateA():
36
+ last_prompt = delete_last_message(memoryA)
37
+ new_response = llm_chainA.predict(input=last_prompt, stop=stopperA)
38
  col1.chat_message("user").write(last_prompt)
39
  col1.chat_message("assistant").write(new_response)
40
+ return new_response
41
 
42
  def regenerateB():
43
+ last_prompt = delete_last_message(memoryB)
44
+ new_response = llm_chainB.predict(input=last_prompt, stop=stopperB)
45
  col2.chat_message("user").write(last_prompt)
46
  col2.chat_message("assistant").write(new_response)
47
+ return new_response
48
 
49
  def replaceA():
50
+ last_prompt = memoryB.chat_memory.messages[-2].content
51
+ new_message = memoryB.chat_memory.messages[-1].content
52
+ replace_last_message(memoryA, new_message)
53
  st.session_state['commonMemory'].save_context({"inputs":last_prompt}, {"outputs":new_message})
54
 
55
+ new_battle_result(st.session_state['db_client'],
56
+ st.session_state['comparison_id'],
57
+ st.session_state['convo_id'],
58
+ username, sourceA, sourceB, winner='model_two'
59
+ )
60
+
61
  def replaceB():
62
+ last_prompt = memoryA.chat_memory.messages[-2].content
63
+ new_message = memoryA.chat_memory.messages[-1].content
64
+ replace_last_message(memoryB, new_message)
65
  st.session_state['commonMemory'].save_context({"inputs":last_prompt}, {"outputs":new_message})
66
 
67
+ new_battle_result(st.session_state['db_client'],
68
+ st.session_state['comparison_id'],
69
+ st.session_state['convo_id'],
70
+ username, sourceA, sourceB, winner='model_one'
71
+ )
72
+
73
  def regenerateBoth():
74
+ promt_ts = dt.datetime.now(tz=dt.timezone.utc)
75
+ new_battle_result(st.session_state['db_client'],
76
+ st.session_state['comparison_id'],
77
+ st.session_state['convo_id'],
78
+ username, sourceA, sourceB, winner='both_bad'
79
+ )
80
+
81
+ responseA = regenerateA()
82
+ responseB = regenerateB()
83
+ completion_ts = dt.datetime.now(tz=dt.timezone.utc)
84
+ new_comparison(st.session_state['db_client'], promt_ts, completion_ts,
85
+ st.session_state['commonMemory'].buffer_as_str, prompt, responseA, responseB)
86
 
87
  def bothGood():
88
+ if len(memoryA.buffer_as_messages) == 1:
89
  pass
90
  else:
91
+ i = random.choice([memoryA, memoryB])
92
+ last_prompt = i.chat_memory.messages[-2].content
93
+ last_reponse = i.chat_memory.messages[-1].content
94
  st.session_state['commonMemory'].save_context({"inputs":last_prompt}, {"outputs":last_reponse})
95
+
96
+ new_battle_result(st.session_state['db_client'],
97
+ st.session_state['comparison_id'],
98
+ st.session_state['convo_id'],
99
+ username, sourceA, sourceB, winner='tie'
100
+ )
101
 
102
  with st.sidebar:
103
+ username = st.text_input("Username", value='ivnban-ctl', max_chars=30)
104
+ issue = st.selectbox("Select an Issue", ISSUES, index=0,
105
+ on_change=clear_memory, kwargs={"memories":memories, "username":username, "language":"English"}
106
  )
107
  supported_languages = ['English', "Spanish"] if issue == "Anxiety" else ['English']
108
  language = st.selectbox("Select a Language", supported_languages, index=0,
109
+ on_change=clear_memory, kwargs={"memories":memories, "username":username, "language":"English"}
110
  )
111
 
112
  with st.expander("Model A"):
113
  temperatureA = st.slider("Temperature Model A", 0., 1., value=0.8, step=0.1)
114
+ sourceA = st.selectbox("Select a source Model A", SOURCES, index=0,
115
+ format_func=source2label
116
  )
117
  with st.expander("Model B"):
118
  temperatureB = st.slider("Temperature Model B", 0., 1., value=0.8, step=0.1)
119
+ sourceB = st.selectbox("Select a source Model B", SOURCES, index=1,
120
+ format_func=source2label
121
  )
122
 
123
  sbcol1, sbcol2 = st.columns(2)
 
131
  # regenB = sbcol2.button("Regenerate B", on_click=regenerateB)
132
  clear = st.button("Clear History", on_click=clear_memory, args=(memories,))
133
 
134
+ memories = {
135
+ 'memoryA': {"issue": issue, "source": sourceA},
136
+ 'memoryB': {"issue": issue, "source": sourceB},
137
+ 'commonMemory': {"issue": issue, "source": SOURCES[0]}
138
+ }
139
+ changed_source = any([
140
+ st.session_state['previous_sourceA'] != sourceA,
141
+ st.session_state['previous_sourceB'] != sourceB
142
+ ])
143
+ create_memory_add_initial_message(memories, username, language, changed_source=changed_source)
144
+ memoryA = st.session_state[list(memories.keys())[0]]
145
+ memoryB = st.session_state[list(memories.keys())[1]]
146
+ llm_chainA, stopperA = get_chain(issue, language, sourceA, memoryA, temperatureA)
147
+ llm_chainB, stopperB = get_chain(issue, language, sourceB, memoryB, temperatureB)
148
 
149
  st.title(f"💬 History")
150
  for msg in st.session_state['commonMemory'].buffer_as_messages:
 
171
  return True
172
 
173
  if prompt := st.chat_input(disabled=disable_chat()):
174
+ if 'convo_id' not in st.session_state:
175
+ push_convo2db(memories, username, language)
176
+
177
+ promt_ts = dt.datetime.now(tz=dt.timezone.utc)
178
  col1.chat_message("user").write(prompt)
179
  col2.chat_message("user").write(prompt)
180
 
181
+ responseA = llm_chainA.predict(input=prompt, stop=stopperA)
182
+ responseB = llm_chainB.predict(input=prompt, stop=stopperB)
183
+ completion_ts = dt.datetime.now(tz=dt.timezone.utc)
184
 
185
+ new_comparison(st.session_state['db_client'], promt_ts, completion_ts,
186
+ st.session_state['commonMemory'].buffer_as_str, prompt, responseA, responseB)
187
+
188
  col1.chat_message("assistant").write(responseA)
189
  col2.chat_message("assistant").write(responseB)
190
 
pages/manual_comparisor.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import random
4
+ import datetime as dt
5
+ import streamlit as st
6
+ from streamlit.logger import get_logger
7
+ from langchain.schema.messages import HumanMessage
8
+ from mongo_utils import get_db_client, new_battle_result, get_non_assesed_comparison, new_completion_error
9
+ from app_config import ISSUES, SOURCES
10
+
11
+ logger = get_logger(__name__)
12
+ openai_api_key = os.environ['OPENAI_API_KEY']
13
+ if 'db_client' not in st.session_state:
14
+ st.session_state["db_client"] = get_db_client()
15
+
16
+ def disable_buttons():
17
+ return len(comparison) == 0
18
+
19
+ def replaceA():
20
+ new_battle_result(st.session_state['db_client'],
21
+ st.session_state['comparison_id'],
22
+ st.session_state['convo_id'],
23
+ username, sourceA, sourceB, winner='model_two'
24
+ )
25
+
26
+ def replaceB():
27
+ new_battle_result(st.session_state['db_client'],
28
+ st.session_state['comparison_id'],
29
+ st.session_state['convo_id'],
30
+ username, sourceA, sourceB, winner='model_one'
31
+ )
32
+
33
+ def regenerateBoth():
34
+ new_battle_result(st.session_state['db_client'],
35
+ st.session_state['comparison_id'],
36
+ st.session_state['convo_id'],
37
+ username, sourceA, sourceB, winner='both_bad'
38
+ )
39
+
40
+ def bothGood():
41
+ new_battle_result(st.session_state['db_client'],
42
+ st.session_state['comparison_id'],
43
+ st.session_state['convo_id'],
44
+ username, sourceA, sourceB, winner='tie'
45
+ )
46
+
47
+ def error2db(model):
48
+ new_completion_error(st.session_state['db_client'],
49
+ st.session_state['comparison_id'],
50
+ username, model
51
+ )
52
+
53
+ def error2dbA():
54
+ error2db(sourceA)
55
+
56
+ def error2dbA():
57
+ error2db(sourceB)
58
+
59
+ with st.sidebar:
60
+ username = st.text_input("Username", value='ivnban-ctl', max_chars=30)
61
+
62
+ comparison = get_non_assesed_comparison(st.session_state["db_client"], username)
63
+
64
+ with st.sidebar:
65
+
66
+ sbcol1, sbcol2 = st.columns(2)
67
+ beta = sbcol1.button("A is better", on_click=replaceB, disabled=disable_buttons())
68
+ betb = sbcol2.button("B is better", on_click=replaceA, disabled=disable_buttons())
69
+
70
+ same = sbcol1.button("Tie", on_click=bothGood, disabled=disable_buttons())
71
+ bbad = sbcol2.button("Both are bad", on_click=regenerateBoth, disabled=disable_buttons())
72
+
73
+ errorA = sbcol1.button("Error in A", on_click=error2dbA, disabled=disable_buttons())
74
+ errorB = sbcol2.button("Error in B", on_click=error2dbA, disabled=disable_buttons())
75
+
76
+ if len(comparison) > 0:
77
+
78
+ st.session_state['comparison_id'] = comparison[0]["_id"]
79
+ st.session_state['convo_id'] = comparison[0]["convo_id"]
80
+ st.session_state["disabled_buttons"] = False
81
+
82
+ st.sidebar.text_input("Issue", value=comparison[0]['convo_info'][0]['issue'], disabled=True)
83
+
84
+ st.title(f"💬 History")
85
+
86
+ for msg in comparison[0]['chat_history'].split("\n"):
87
+ parts = msg.split(":")
88
+ if len(parts) > 1:
89
+ role = "user" if parts[0] == 'helper' else "assistant"
90
+ st.chat_message(role).write(parts[1])
91
+
92
+ col1, col2 = st.columns(2)
93
+ col1.title(f"💬 Simulator A")
94
+ col2.title(f"💬 Simulator B")
95
+
96
+ selectedA = random.choice(['model_one', 'model_two'])
97
+ selectedB = "model_two" if selectedA == "model_one" else "model_one"
98
+ logger.info(f"selected A is {selectedA} and B is {selectedB}")
99
+ sourceA = comparison[0]['convo_info'][0][selectedA]
100
+ sourceB = comparison[0]['convo_info'][0][selectedB]
101
+ col1.chat_message("user").write(comparison[0]["prompt"])
102
+ col2.chat_message("user").write(comparison[0]["prompt"])
103
+
104
+ col1.chat_message("assistant").write(comparison[0][f"compeltion_{selectedA}"])
105
+ col2.chat_message("assistant").write(comparison[0][f"compeltion_{selectedB}"])
106
+
107
+ else:
108
+ st.write("No Comparisons left to Check")
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
  scipy==1.11.1
2
  openai==0.28.0
3
- langchain==0.0.281
 
 
1
  scipy==1.11.1
2
  openai==0.28.0
3
+ langchain==0.0.281
4
+ pymongo==4.5.0