# Spaces: Running -- Hugging Face Spaces status banner captured with the file; not program code.
import os | |
import random | |
import datetime as dt | |
import streamlit as st | |
from streamlit.logger import get_logger | |
from langchain.schema.messages import HumanMessage | |
from utils.mongo_utils import get_db_client, new_comparison, new_battle_result | |
from utils.app_utils import create_memory_add_initial_message, get_random_name, DEFAULT_NAMES_DF | |
from utils.memory_utils import clear_memory, push_convo2db | |
from utils.chain_utils import get_chain | |
from app_config import ISSUES, SOURCES, source2label | |
logger = get_logger(__name__)

# Indexing os.environ (rather than .get) raises KeyError at startup when the
# key is missing.
openai_api_key = os.environ['OPENAI_API_KEY']

# Each session-state entry below is created only when absent, so values that
# already exist in st.session_state are left as they are.
if "sent_messages" not in st.session_state:
    st.session_state['sent_messages'] = 0
logger.info(f'sent messages {st.session_state["sent_messages"]}')

if "issue" not in st.session_state:
    st.session_state['issue'] = ISSUES[0]
if 'previous_sourceA' not in st.session_state:
    st.session_state['previous_sourceA'] = SOURCES[0]
if 'previous_sourceB' not in st.session_state:
    st.session_state['previous_sourceB'] = SOURCES[0]

# Descriptor for each conversation memory: one per model plus the shared
# history. The keys double as the st.session_state keys used further down.
memories = {
    'memoryA': {"issue": st.session_state['issue'], "source": st.session_state['previous_sourceA']},
    'memoryB': {"issue": st.session_state['issue'], "source": st.session_state['previous_sourceB']},
    'commonMemory': {"issue": st.session_state['issue'], "source": SOURCES[0]},
}

# These are kept as explicit membership checks so the helper calls only run
# when the entry is actually missing.
if 'db_client' not in st.session_state:
    st.session_state["db_client"] = get_db_client()
if 'counselor_name' not in st.session_state:
    st.session_state["counselor_name"] = get_random_name(names_df=DEFAULT_NAMES_DF)
if 'texter_name' not in st.session_state:
    st.session_state["texter_name"] = get_random_name(names_df=DEFAULT_NAMES_DF)
def delete_last_message(memory):
    """Remove the last two messages from a conversation memory.

    Args:
        memory: Object exposing ``chat_memory.messages``, a list of messages
            that each have a ``content`` attribute. The callers in this file
            treat the last two entries as a prompt and the reply to it.

    Returns:
        The ``content`` of the second-to-last message (the removed prompt).

    Raises:
        IndexError: If the history holds fewer than two messages. The history
            is left unmodified in that case.
    """
    history = memory.chat_memory.messages
    if len(history) < 2:
        # Same exception type the bare [-2] lookup would raise, but with a
        # message that says what went wrong.
        raise IndexError(
            "cannot delete the last exchange: history has fewer than two messages"
        )
    last_prompt = history[-2].content
    # Rebind to a trimmed copy instead of mutating the list in place.
    memory.chat_memory.messages = history[:-2]
    return last_prompt
def replace_last_message(memory, new_message):
    """Swap the final message of a conversation memory for a new AI message.

    Args:
        memory: Object exposing ``chat_memory.messages`` (a list) and
            ``chat_memory.add_ai_message``.
        new_message: Text handed to ``add_ai_message`` after the current last
            message has been dropped.
    """
    chat = memory.chat_memory
    # Slicing an empty history yields an empty list, so this step is safe even
    # when there is nothing to drop.
    chat.messages = chat.messages[:-1]
    chat.add_ai_message(new_message)
def regenerateA():
    """Drop model A's latest exchange and request a fresh reply for it.

    Uses the module-level ``memoryA``, ``llm_chainA``, ``stopperA`` and
    ``col1``.

    Returns:
        The response produced by ``llm_chainA.predict`` for the removed prompt.
    """
    # Take the last prompt/reply pair out of A's memory and keep the prompt so
    # it can be sent again.
    last_prompt = delete_last_message(memoryA)
    new_response = llm_chainA.predict(input=last_prompt, stop=stopperA)

    # Show the re-run exchange in model A's column.
    col1.chat_message("user").write(last_prompt)
    col1.chat_message("assistant").write(new_response)
    return new_response
def regenerateB():
    """Drop model B's latest exchange and request a fresh reply for it.

    Uses the module-level ``memoryB``, ``llm_chainB``, ``stopperB`` and
    ``col2``.

    Returns:
        The response produced by ``llm_chainB.predict`` for the removed prompt.
    """
    # Take the last prompt/reply pair out of B's memory and keep the prompt so
    # it can be sent again.
    last_prompt = delete_last_message(memoryB)
    new_response = llm_chainB.predict(input=last_prompt, stop=stopperB)

    # Show the re-run exchange in model B's column.
    col2.chat_message("user").write(last_prompt)
    col2.chat_message("assistant").write(new_response)
    return new_response
def replaceA():
    """Handle "B is better": adopt model B's last reply as the shared turn.

    Copies B's last reply over A's last message, appends the prompt/reply pair
    to the common history, and stores a battle result with
    ``winner='model_two'``.

    Does nothing while no prompt has been sent in this session, matching the
    guard in ``bothGood``.
    """
    if st.session_state['sent_messages'] == 0:
        # No comparison exists yet: there is no exchange to copy.
        # NOTE(review): 'comparison_id' / 'convo_id' are presumably unset at
        # this point as well -- confirm against new_comparison/push_convo2db.
        return

    winning_messages = memoryB.chat_memory.messages
    last_prompt = winning_messages[-2].content
    new_message = winning_messages[-1].content

    replace_last_message(memoryA, new_message)
    st.session_state['commonMemory'].save_context(
        {"inputs": last_prompt}, {"outputs": new_message}
    )
    new_battle_result(
        st.session_state['db_client'],
        st.session_state['comparison_id'],
        st.session_state['convo_id'],
        username, sourceA, sourceB, winner='model_two',
    )
def replaceB():
    """Handle "A is better": adopt model A's last reply as the shared turn.

    Copies A's last reply over B's last message, appends the prompt/reply pair
    to the common history, and stores a battle result with
    ``winner='model_one'``.

    Does nothing while no prompt has been sent in this session, matching the
    guard in ``bothGood``.
    """
    if st.session_state['sent_messages'] == 0:
        # No comparison exists yet: there is no exchange to copy.
        # NOTE(review): 'comparison_id' / 'convo_id' are presumably unset at
        # this point as well -- confirm against new_comparison/push_convo2db.
        return

    winning_messages = memoryA.chat_memory.messages
    last_prompt = winning_messages[-2].content
    new_message = winning_messages[-1].content

    replace_last_message(memoryB, new_message)
    st.session_state['commonMemory'].save_context(
        {"inputs": last_prompt}, {"outputs": new_message}
    )
    new_battle_result(
        st.session_state['db_client'],
        st.session_state['comparison_id'],
        st.session_state['convo_id'],
        username, sourceA, sourceB, winner='model_one',
    )
def regenerateBoth():
    """Handle "Both are bad": log the verdict and regenerate both replies.

    Stores a battle result with ``winner='both_bad'``, re-runs the last prompt
    through both chains, and records the new pair of responses as a fresh
    comparison.

    Does nothing while no prompt has been sent in this session, matching the
    guard in ``bothGood``.
    """
    if st.session_state['sent_messages'] == 0:
        # Nothing has been generated yet, so there is nothing to regenerate.
        return

    # Read the prompt from memory *before* regenerateA() removes it. The
    # module-level `prompt` is the return value of st.chat_input for a single
    # script run and is None on any run where nothing was submitted, so it is
    # not a reliable record of the prompt being regenerated.
    last_prompt = memoryA.chat_memory.messages[-2].content

    prompt_ts = dt.datetime.now(tz=dt.timezone.utc)
    new_battle_result(
        st.session_state['db_client'],
        st.session_state['comparison_id'],
        st.session_state['convo_id'],
        username, sourceA, sourceB, winner='both_bad',
    )

    responseA = regenerateA()
    responseB = regenerateB()
    completion_ts = dt.datetime.now(tz=dt.timezone.utc)

    new_comparison(
        st.session_state['db_client'], prompt_ts, completion_ts,
        st.session_state['commonMemory'].buffer_as_str,
        last_prompt, responseA, responseB,
    )
def bothGood():
    """Handle "Tie": keep one of the two replies at random as the shared turn.

    When at least one prompt has been sent, picks model A's or model B's
    memory at random, appends its last prompt/reply pair to the common
    history, and stores a battle result with ``winner='tie'``.

    With no prompt sent yet the function returns without doing anything.
    """
    if st.session_state['sent_messages'] == 0:
        return

    # NOTE(review): the source's indentation was lost; the battle-result write
    # is placed on the guarded path together with the history update, since
    # there is no comparison to score before the first prompt -- confirm.
    chosen_memory = random.choice([memoryA, memoryB])
    last_prompt = chosen_memory.chat_memory.messages[-2].content
    last_response = chosen_memory.chat_memory.messages[-1].content

    st.session_state['commonMemory'].save_context(
        {"inputs": last_prompt}, {"outputs": last_response}
    )
    new_battle_result(
        st.session_state['db_client'],
        st.session_state['comparison_id'],
        st.session_state['convo_id'],
        username, sourceA, sourceB, winner='tie',
    )
# Sidebar: session settings, per-model settings, and the voting buttons.
# NOTE(review): nesting reconstructed from a copy whose indentation was
# stripped; every widget up to "Clear History" is assumed to live in the
# sidebar (the `sbcol*` names suggest sidebar columns) -- confirm.
with st.sidebar:
    username = st.text_input("Username", value='ivnban-ctl', max_chars=30)

    # Changing the issue or the language clears the conversation memories.
    # NOTE(review): both on_change handlers pass language "English" while the
    # "Clear History" button passes the selected language code -- confirm
    # which form clear_memory expects.
    issue = st.selectbox(
        "Select an Issue", ISSUES, index=0,
        on_change=clear_memory,
        kwargs={"memories": memories, "username": username, "language": "English"},
    )

    # Spanish is only offered for the "Anxiety" issue.
    supported_languages = ['en', "es"] if issue == "Anxiety" else ['en']
    language = st.selectbox(
        "Select a Language", supported_languages, index=0,
        format_func=lambda x: "English" if x == "en" else "Spanish",
        on_change=clear_memory,
        kwargs={"memories": memories, "username": username, "language": "English"},
    )

    with st.expander("Model A"):
        temperatureA = st.slider("Temperature Model A", 0., 1., value=0.8, step=0.1)
        sourceA = st.selectbox(
            "Select a source Model A", SOURCES, index=0,
            format_func=source2label,
        )

    with st.expander("Model B"):
        temperatureB = st.slider("Temperature Model B", 0., 1., value=0.8, step=0.1)
        sourceB = st.selectbox(
            "Select a source Model B", SOURCES, index=0,
            format_func=source2label,
        )

    st.markdown(f"### Previous Prompt Count: :red[**{st.session_state['sent_messages']}**]")

    # Voting buttons. Note the crossover: "A is better" overwrites B's reply
    # (replaceB) and "B is better" overwrites A's reply (replaceA).
    sbcol1, sbcol2 = st.columns(2)
    beta = sbcol1.button("A is better", on_click=replaceB)
    betb = sbcol2.button("B is better", on_click=replaceA)
    same = sbcol1.button("Tie", on_click=bothGood)
    bbad = sbcol2.button("Both are bad", on_click=regenerateBoth)
    # regenA = sbcol1.button("Regenerate A", on_click=regenerateA)
    # regenB = sbcol2.button("Regenerate B", on_click=regenerateB)

    clear = st.button(
        "Clear History", on_click=clear_memory,
        kwargs={"memories": memories, "username": username, "language": language},
    )
# True when either model's source or the issue differs from what the session
# last recorded.
changed_source = (
    st.session_state['previous_sourceA'] != sourceA
    or st.session_state['previous_sourceB'] != sourceB
    or st.session_state['issue'] != issue
)

if changed_source:
    logger.info("changed something")
    # A new configuration starts a new conversation: draw fresh names, record
    # the new selection, and reset the prompt counter.
    st.session_state["counselor_name"] = get_random_name(names_df=DEFAULT_NAMES_DF)
    st.session_state["texter_name"] = get_random_name(names_df=DEFAULT_NAMES_DF)
    st.session_state['previous_sourceA'] = sourceA
    st.session_state['previous_sourceB'] = sourceB
    st.session_state['issue'] = issue
    st.session_state['sent_messages'] = 0

# Called on every run; `changed_source` is forwarded so the helper can tell
# whether the configuration changed.
create_memory_add_initial_message(
    memories,
    issue,
    language,
    changed_source=changed_source,
    counselor_name=st.session_state["counselor_name"],
    texter_name=st.session_state["texter_name"],
)
# The per-model memories live in session state under the keys of `memories`.
# They are addressed by name here, the same way 'commonMemory' is addressed
# throughout the file, rather than by position in the dict.
memoryA = st.session_state['memoryA']
memoryB = st.session_state['memoryB']

# One chain per model; both share the issue, language and texter name and
# differ in source, memory and temperature.
llm_chainA, stopperA = get_chain(issue, language, sourceA, memoryA, temperatureA, texter_name=st.session_state["texter_name"])
llm_chainB, stopperB = get_chain(issue, language, sourceB, memoryB, temperatureB, texter_name=st.session_state["texter_name"])

# Replay the shared conversation history.
st.title("💬 History")
for msg in st.session_state['commonMemory'].buffer_as_messages:
    # isinstance also accepts subclasses of HumanMessage, unlike an exact
    # type comparison.
    role = "user" if isinstance(msg, HumanMessage) else "assistant"
    st.chat_message(role).write(msg.content)

# Side-by-side columns for the two models' latest exchange.
col1, col2 = st.columns(2)
col1.title("💬 Simulator A")
col2.title("💬 Simulator B")
def reset_buttons():
    """Clear the four vote-button flags for the rest of the current run.

    The previous implementation looped over a list of the flags and assigned
    ``False`` to the loop variable, which rebinds only that local name and
    leaves the flags untouched. The module-level names are now reset directly.
    """
    global beta, betb, same, bbad
    beta = betb = same = bbad = False
def disable_chat():
    """Tell whether the chat input should be disabled.

    Returns:
        ``False`` when at least one of the vote-button flags (``beta``,
        ``betb``, ``same``, ``bbad``) is truthy, ``True`` otherwise.
    """
    return not any((beta, betb, same, bbad))
# Main interaction: send a submitted prompt to both models and record the
# resulting comparison.
# NOTE(review): nesting reconstructed from a copy whose indentation was
# stripped; the whole sequence is assumed to run only when a prompt was
# submitted -- confirm.
if prompt := st.chat_input(disabled=disable_chat()):
    st.session_state['sent_messages'] += 1

    # Register the conversation the first time a prompt is sent.
    if 'convo_id' not in st.session_state:
        push_convo2db(memories, username, language)

    promt_ts = dt.datetime.now(tz=dt.timezone.utc)

    # Echo the prompt in both columns before the (slower) model calls.
    col1.chat_message("user").write(prompt)
    col2.chat_message("user").write(prompt)

    responseA = llm_chainA.predict(input=prompt, stop=stopperA)
    responseB = llm_chainB.predict(input=prompt, stop=stopperB)
    completion_ts = dt.datetime.now(tz=dt.timezone.utc)

    # Store the prompt, both replies, and the shared history so far.
    new_comparison(
        st.session_state['db_client'], promt_ts, completion_ts,
        st.session_state['commonMemory'].buffer_as_str, prompt, responseA, responseB,
    )

    col1.chat_message("assistant").write(responseA)
    col2.chat_message("assistant").write(responseB)
    reset_buttons()