# convosim-ui / pages/comparisor.py
# Last commit 975a927 by ivnban27-ctl: "Added MongoDB functionality (#1)"
# (Header recovered from a file-view page: 7.83 kB, virus scan: no virus.)
import os
import random
import datetime as dt
import streamlit as st
from streamlit.logger import get_logger
from langchain.schema.messages import HumanMessage
from mongo_utils import get_db_client, new_comparison, new_battle_result
from app_utils import create_memory_add_initial_message, clear_memory, get_chain, push_convo2db
from app_config import ISSUES, SOURCES, source2label
logger = get_logger(__name__)
# Looked up eagerly, so a missing key fails with KeyError as soon as the page
# loads; the value is not referenced again in the visible code.
openai_api_key = os.environ['OPENAI_API_KEY']

# Initial memory configuration: every memory starts on the first issue;
# simulator A and the shared history use the first source, simulator B the
# second. The sidebar callbacks below are wired with this dict.
memories = dict(
    memoryA={"issue": ISSUES[0], "source": SOURCES[0]},
    memoryB={"issue": ISSUES[0], "source": SOURCES[1]},
    commonMemory={"issue": ISSUES[0], "source": SOURCES[0]},
)

# Per-session initialisation: only performed when the key is absent, so the
# stored values survive Streamlit reruns.
if "db_client" not in st.session_state:
    st.session_state["db_client"] = get_db_client()
if "previous_sourceA" not in st.session_state:
    st.session_state["previous_sourceA"] = SOURCES[0]
if "previous_sourceB" not in st.session_state:
    st.session_state["previous_sourceB"] = SOURCES[1]
def delete_last_message(memory):
    """Remove the final prompt/response pair from *memory* and return the prompt.

    The second-to-last stored message is treated as the user's prompt and the
    last one as the reply to it. The message list is replaced by a shortened
    copy rather than mutated in place.

    Raises:
        IndexError: if fewer than two messages are stored.
    """
    history = memory.chat_memory
    prompt_message = history.messages[-2]
    history.messages = history.messages[:-2]
    return prompt_message.content
def replace_last_message(memory, new_message):
    """Swap the last stored message in *memory* for an AI message *new_message*.

    The list is replaced by a copy without its final entry (the original list
    object is left untouched), then the new text is appended via
    ``add_ai_message``.
    """
    history = memory.chat_memory
    history.messages = history.messages[:-1]
    history.add_ai_message(new_message)
def regenerateA():
    """Re-ask simulator A its last prompt and display the new exchange in column A.

    The previous prompt/response pair is removed from ``memoryA`` first, so the
    chain regenerates from the same history.

    Returns:
        The new response text produced by ``llm_chainA``.
    """
    user_prompt = delete_last_message(memoryA)
    reply = llm_chainA.predict(input=user_prompt, stop=stopperA)
    for role, text in (("user", user_prompt), ("assistant", reply)):
        col1.chat_message(role).write(text)
    return reply
def regenerateB():
    """Re-ask simulator B its last prompt and display the new exchange in column B.

    The previous prompt/response pair is removed from ``memoryB`` first, so the
    chain regenerates from the same history.

    Returns:
        The new response text produced by ``llm_chainB``.
    """
    user_prompt = delete_last_message(memoryB)
    reply = llm_chainB.predict(input=user_prompt, stop=stopperB)
    for role, text in (("user", user_prompt), ("assistant", reply)):
        col2.chat_message(role).write(text)
    return reply
def replaceA():
    """Handle a "B is better" vote.

    Copies simulator B's last reply over simulator A's, appends the winning
    prompt/reply pair to the shared history (``commonMemory``), and records
    the vote with ``winner='model_two'``.
    """
    winner_messages = memoryB.chat_memory.messages
    # Before the first exchange memory B holds only its initial message, so
    # there is no prompt/reply pair to copy (messages[-2] would raise
    # IndexError). Same "nothing to vote on yet" guard as bothGood.
    if len(winner_messages) < 2:
        return
    last_prompt = winner_messages[-2].content
    new_message = winner_messages[-1].content
    replace_last_message(memoryA, new_message)
    st.session_state['commonMemory'].save_context({"inputs": last_prompt}, {"outputs": new_message})
    new_battle_result(st.session_state['db_client'],
                      st.session_state['comparison_id'],
                      st.session_state['convo_id'],
                      username, sourceA, sourceB, winner='model_two'
                      )
def replaceB():
    """Handle an "A is better" vote.

    Copies simulator A's last reply over simulator B's, appends the winning
    prompt/reply pair to the shared history (``commonMemory``), and records
    the vote with ``winner='model_one'``.
    """
    winner_messages = memoryA.chat_memory.messages
    # Before the first exchange memory A holds only its initial message, so
    # there is no prompt/reply pair to copy (messages[-2] would raise
    # IndexError). Same "nothing to vote on yet" guard as bothGood.
    if len(winner_messages) < 2:
        return
    last_prompt = winner_messages[-2].content
    new_message = winner_messages[-1].content
    replace_last_message(memoryB, new_message)
    st.session_state['commonMemory'].save_context({"inputs": last_prompt}, {"outputs": new_message})
    new_battle_result(st.session_state['db_client'],
                      st.session_state['comparison_id'],
                      st.session_state['convo_id'],
                      username, sourceA, sourceB, winner='model_one'
                      )
def regenerateBoth():
    """Handle a "Both are bad" vote.

    Records the vote with ``winner='both_bad'``, regenerates both simulators'
    replies to the same last prompt, and logs the new pair via
    ``new_comparison``.
    """
    # Before the first exchange there is nothing to regenerate:
    # delete_last_message would raise IndexError on a single-message memory.
    # Same guard as bothGood.
    if min(len(memoryA.chat_memory.messages), len(memoryB.chat_memory.messages)) < 2:
        return
    # Capture the prompt being regenerated before regenerateA removes it. The
    # module-level `prompt` only holds text on the run that submitted the chat
    # input and is None otherwise (e.g. on a second "Both are bad" click).
    last_prompt = memoryA.chat_memory.messages[-2].content
    new_battle_result(st.session_state['db_client'],
                      st.session_state['comparison_id'],
                      st.session_state['convo_id'],
                      username, sourceA, sourceB, winner='both_bad'
                      )
    # Timestamp taken after the DB write so it brackets only the generation.
    prompt_ts = dt.datetime.now(tz=dt.timezone.utc)
    responseA = regenerateA()
    responseB = regenerateB()
    completion_ts = dt.datetime.now(tz=dt.timezone.utc)
    new_comparison(st.session_state['db_client'], prompt_ts, completion_ts,
                   st.session_state['commonMemory'].buffer_as_str, last_prompt, responseA, responseB)
def bothGood():
    """Handle a "Tie" vote.

    Picks one of the two simulators at random, appends its last prompt/reply
    pair to the shared history (``commonMemory``), and records the vote with
    ``winner='tie'``.
    """
    # Before the first exchange memory A only holds its initial message (or
    # nothing), so there is no prompt/reply pair to record yet.
    if len(memoryA.buffer_as_messages) < 2:
        return
    chosen = random.choice([memoryA, memoryB])
    last_prompt = chosen.chat_memory.messages[-2].content
    last_response = chosen.chat_memory.messages[-1].content
    # NOTE(review): unlike replaceA/replaceB, the other simulator keeps its own
    # reply in memory, so its history can diverge from commonMemory. Confirm
    # this is intended.
    st.session_state['commonMemory'].save_context({"inputs": last_prompt}, {"outputs": last_response})
    new_battle_result(st.session_state['db_client'],
                      st.session_state['comparison_id'],
                      st.session_state['convo_id'],
                      username, sourceA, sourceB, winner='tie'
                      )
# Sidebar: user/issue/language selection, per-model settings and vote buttons.
with st.sidebar:
    username = st.text_input("Username", value='ivnban-ctl', max_chars=30)
    # NOTE(review): both selectbox callbacks receive the default `memories`
    # defined at the top of the file and a hard-coded "English", even when the
    # user switches language. Confirm clear_memory does not rely on these.
    issue = st.selectbox("Select an Issue", ISSUES, index=0,
                         on_change=clear_memory, kwargs={"memories": memories, "username": username, "language": "English"}
                         )
    supported_languages = ['English', "Spanish"] if issue == "Anxiety" else ['English']
    language = st.selectbox("Select a Language", supported_languages, index=0,
                            on_change=clear_memory, kwargs={"memories": memories, "username": username, "language": "English"}
                            )
    with st.expander("Model A"):
        temperatureA = st.slider("Temperature Model A", 0., 1., value=0.8, step=0.1)
        sourceA = st.selectbox("Select a source Model A", SOURCES, index=0,
                               format_func=source2label
                               )
    with st.expander("Model B"):
        temperatureB = st.slider("Temperature Model B", 0., 1., value=0.8, step=0.1)
        sourceB = st.selectbox("Select a source Model B", SOURCES, index=1,
                               format_func=source2label
                               )
    # Vote buttons: each on_click callback updates memories and logs the vote;
    # the returned flags are read by disable_chat().
    sbcol1, sbcol2 = st.columns(2)
    beta = sbcol1.button("A is better", on_click=replaceB)
    betb = sbcol2.button("B is better", on_click=replaceA)
    same = sbcol1.button("Tie", on_click=bothGood)
    bbad = sbcol2.button("Both are bad", on_click=regenerateBoth)
    clear = st.button("Clear History", on_click=clear_memory, args=(memories,))

# Memory configuration for this run, reflecting the current selections.
memories = {
    'memoryA': {"issue": issue, "source": sourceA},
    'memoryB': {"issue": issue, "source": sourceB},
    'commonMemory': {"issue": issue, "source": SOURCES[0]},
}
changed_source = (
    st.session_state['previous_sourceA'] != sourceA
    or st.session_state['previous_sourceB'] != sourceB
)
create_memory_add_initial_message(memories, username, language, changed_source=changed_source)
# Remember the sources just used. Nothing else in this file updates these keys,
# so without this, once either source differs from its initial default, every
# rerun would report changed_source=True.
st.session_state['previous_sourceA'] = sourceA
st.session_state['previous_sourceB'] = sourceB

memoryA = st.session_state['memoryA']
memoryB = st.session_state['memoryB']
llm_chainA, stopperA = get_chain(issue, language, sourceA, memoryA, temperatureA)
llm_chainB, stopperB = get_chain(issue, language, sourceB, memoryB, temperatureB)

# Shared (agreed-upon) conversation history.
st.title("💬 History")
for msg in st.session_state['commonMemory'].buffer_as_messages:
    role = "user" if isinstance(msg, HumanMessage) else "assistant"
    st.chat_message(role).write(msg.content)

# Side-by-side panels for the current turn of each simulator.
col1, col2 = st.columns(2)
col1.title("💬 Simulator A")
col2.title("💬 Simulator B")
def reset_buttons():
    """Clear the module-level vote-button flags (beta, betb, same, bbad).

    The names must be declared ``global``: assigning to them (or to a loop
    variable holding their values) inside the function would otherwise only
    bind locals and leave the flags unchanged. The flags are reassigned from
    the button widgets each time the script runs, so this only affects code
    that reads them later in the current run.
    """
    global beta, betb, same, bbad
    beta = betb = same = bbad = False
def disable_chat():
    """Return True when the chat input should be disabled on this run.

    The input is enabled (False) only when at least one vote-button flag
    (beta, betb, same, bbad) is truthy.

    NOTE(review): on a fresh page no button has been pressed, so the chat
    starts disabled until a vote button is clicked. Confirm this is intended.
    """
    voted = beta or betb or same or bbad
    return not voted
# One chat turn: send the prompt to both simulators side by side and log the
# pair of responses as a comparison.
prompt = st.chat_input(disabled=disable_chat())
if prompt:
    # The conversation record is pushed once, on the first turn of the session.
    if "convo_id" not in st.session_state:
        push_convo2db(memories, username, language)
    promt_ts = dt.datetime.now(tz=dt.timezone.utc)
    col1.chat_message("user").write(prompt)
    col2.chat_message("user").write(prompt)
    responseA, responseB = (
        llm_chainA.predict(input=prompt, stop=stopperA),
        llm_chainB.predict(input=prompt, stop=stopperB),
    )
    completion_ts = dt.datetime.now(tz=dt.timezone.utc)
    new_comparison(
        st.session_state["db_client"],
        promt_ts,
        completion_ts,
        st.session_state["commonMemory"].buffer_as_str,
        prompt,
        responseA,
        responseB,
    )
    col1.chat_message("assistant").write(responseA)
    col2.chat_message("assistant").write(responseB)
    reset_buttons()