ivnban27-ctl committed
Commit eda0ce6 · 1 parent: 59d5667
changes to comparisor on new role models GCT and SP

Files changed:
- convosim.py (+6 -6)
- pages/comparisor.py (+23 -9)
convosim.py CHANGED

@@ -42,12 +42,12 @@ changed_source = st.session_state['previous_source'] != source
 if changed_source:
     st.session_state["counselor_name"] = get_random_name()
     st.session_state["texter_name"] = get_random_name()
-
-
-
-
-
-
+create_memory_add_initial_message(memories,
+                                  issue,
+                                  language,
+                                  changed_source=changed_source,
+                                  counselor_name=st.session_state["counselor_name"],
+                                  texter_name=st.session_state["texter_name"])
 st.session_state['previous_source'] = source
 memoryA = st.session_state[list(memories.keys())[0]]
 llm_chain, stopper = get_chain(issue, language, source, memoryA, temperature, texter_name=st.session_state["texter_name"])
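The convosim.py hunk above follows a simple Streamlit pattern: persona names are cached in st.session_state, re-rolled only when the selected model source changes, and then passed to the memory and chain helpers. The sketch below isolates that pattern so it runs on its own; get_random_name and the "GCT"/"SP" source options are hypothetical stand-ins for the repo's utils.app_utils helper and its SOURCES config, which are not shown in this diff.

```python
import random

import streamlit as st


def get_random_name() -> str:
    # Hypothetical stand-in for utils.app_utils.get_random_name.
    return random.choice(["Alex", "Sam", "Jordan", "Riley"])


# Placeholder source options echoing the commit message ("GCT and SP").
source = st.selectbox("Source", ["GCT", "SP"])

# Initialize persistent per-session state once.
if "previous_source" not in st.session_state:
    st.session_state["previous_source"] = source
if "counselor_name" not in st.session_state:
    st.session_state["counselor_name"] = get_random_name()
if "texter_name" not in st.session_state:
    st.session_state["texter_name"] = get_random_name()

# Re-roll the personas only when the source actually changed,
# mirroring the `if changed_source:` block in the diff.
changed_source = st.session_state["previous_source"] != source
if changed_source:
    st.session_state["counselor_name"] = get_random_name()
    st.session_state["texter_name"] = get_random_name()
st.session_state["previous_source"] = source

st.write(f"Counselor: {st.session_state['counselor_name']} · "
         f"Texter: {st.session_state['texter_name']}")
```

Run with `streamlit run sketch.py`: changing the Source selectbox regenerates both names, while reruns triggered by other widgets leave them untouched.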
pages/comparisor.py CHANGED

@@ -6,7 +6,9 @@ import streamlit as st
 from streamlit.logger import get_logger
 from langchain.schema.messages import HumanMessage
 from utils.mongo_utils import get_db_client, new_comparison, new_battle_result
-from utils.app_utils import create_memory_add_initial_message,
+from utils.app_utils import create_memory_add_initial_message, get_random_name
+from utils.memory_utils import clear_memory, push_convo2db
+from utils.chain_utils import get_chain
 from app_config import ISSUES, SOURCES, source2label

 logger = get_logger(__name__)
@@ -21,7 +23,11 @@ if 'db_client' not in st.session_state:
 if 'previous_sourceA' not in st.session_state:
     st.session_state['previous_sourceA'] = SOURCES[0]
 if 'previous_sourceB' not in st.session_state:
-    st.session_state['previous_sourceB'] = SOURCES[
+    st.session_state['previous_sourceB'] = SOURCES[0]
+if 'counselor_name' not in st.session_state:
+    st.session_state["counselor_name"] = get_random_name()
+if 'texter_name' not in st.session_state:
+    st.session_state["texter_name"] = get_random_name()

 def delete_last_message(memory):
     last_prompt = memory.chat_memory.messages[-2].content
@@ -104,11 +110,11 @@ with st.sidebar:
     issue = st.selectbox("Select an Issue", ISSUES, index=0,
                 on_change=clear_memory, kwargs={"memories":memories, "username":username, "language":"English"}
             )
-    supported_languages = ['
+    supported_languages = ['en', "es"] if issue == "Anxiety" else ['en']
     language = st.selectbox("Select a Language", supported_languages, index=0,
+                format_func=lambda x: "English" if x=="en" else "Spanish",
                 on_change=clear_memory, kwargs={"memories":memories, "username":username, "language":"English"}
-            )
-
+            )
     with st.expander("Model A"):
         temperatureA = st.slider("Temperature Model A", 0., 1., value=0.8, step=0.1)
         sourceA = st.selectbox("Select a source Model A", SOURCES, index=0,
@@ -116,7 +122,7 @@ with st.sidebar:
                 )
     with st.expander("Model B"):
         temperatureB = st.slider("Temperature Model B", 0., 1., value=0.8, step=0.1)
-        sourceB = st.selectbox("Select a source Model B", SOURCES, index=
+        sourceB = st.selectbox("Select a source Model B", SOURCES, index=0,
                     format_func=source2label
                 )

@@ -140,11 +146,19 @@ changed_source = any([
     st.session_state['previous_sourceA'] != sourceA,
     st.session_state['previous_sourceB'] != sourceB
 ])
-
+if changed_source:
+    st.session_state["counselor_name"] = get_random_name()
+    st.session_state["texter_name"] = get_random_name()
+create_memory_add_initial_message(memories,
+                                  issue,
+                                  language,
+                                  changed_source=changed_source,
+                                  counselor_name=st.session_state["counselor_name"],
+                                  texter_name=st.session_state["texter_name"])
 memoryA = st.session_state[list(memories.keys())[0]]
 memoryB = st.session_state[list(memories.keys())[1]]
-llm_chainA, stopperA = get_chain(issue, language, sourceA, memoryA, temperatureA)
-llm_chainB, stopperB = get_chain(issue, language, sourceB, memoryB, temperatureB)
+llm_chainA, stopperA = get_chain(issue, language, sourceA, memoryA, temperatureA, texter_name=st.session_state["texter_name"])
+llm_chainB, stopperB = get_chain(issue, language, sourceB, memoryB, temperatureB, texter_name=st.session_state["texter_name"])

 st.title(f"💬 History")
 for msg in st.session_state['commonMemory'].buffer_as_messages:
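In pages/comparisor.py the same persona/session-state pattern is applied to both model slots, and the language selector added here becomes issue-dependent: the widget stores a language code ('en'/'es') while format_func only controls the label the user sees. Below is a minimal sketch of that selector on its own; the ISSUES list is a placeholder for app_config.ISSUES, of which only "Anxiety" is visible in this diff.

```python
import streamlit as st

# Placeholder standing in for app_config.ISSUES; only "Anxiety" appears in the diff.
ISSUES = ["Anxiety", "Other"]

issue = st.selectbox("Select an Issue", ISSUES, index=0)

# Spanish is offered only for the Anxiety issue, as in the diff above.
supported_languages = ["en", "es"] if issue == "Anxiety" else ["en"]
language = st.selectbox(
    "Select a Language",
    supported_languages,
    index=0,
    # The stored value stays "en"/"es"; format_func changes only the displayed label.
    format_func=lambda x: "English" if x == "en" else "Spanish",
)

st.write(f"Selected language code: {language}")
```

Keeping the short code as the stored value means downstream calls such as get_chain(issue, language, ...) receive 'en'/'es' regardless of how the option is displayed.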