"""Streamlit page comparing two chat simulators (A and B) side by side.

Each user prompt is answered by both simulators. The user then votes with the
sidebar buttons; each vote is stored via ``new_battle_result`` and the chosen
answer is appended to the shared ``commonMemory`` history shown at the top.
"""
import os
import random
import datetime as dt

import streamlit as st
from streamlit.logger import get_logger
from langchain.schema.messages import HumanMessage

from mongo_utils import get_db_client, new_comparison, new_battle_result
from app_utils import create_memory_add_initial_message, clear_memory, get_chain, push_convo2db
from app_config import ISSUES, SOURCES, source2label

logger = get_logger(__name__)
# Read eagerly so the app fails fast when the key is missing.
openai_api_key = os.environ['OPENAI_API_KEY']

# Placeholder memory specs used by widget callbacks that are created before the
# sidebar selections are known; rebuilt below from the actual selections.
memories = {
    'memoryA': {"issue": ISSUES[0], "source": SOURCES[0]},
    'memoryB': {"issue": ISSUES[0], "source": SOURCES[1]},
    'commonMemory': {"issue": ISSUES[0], "source": SOURCES[0]},
}

if 'db_client' not in st.session_state:
    st.session_state["db_client"] = get_db_client()
if 'previous_sourceA' not in st.session_state:
    st.session_state['previous_sourceA'] = SOURCES[0]
if 'previous_sourceB' not in st.session_state:
    st.session_state['previous_sourceB'] = SOURCES[1]


def _has_exchange(memory):
    """Return True if *memory* holds at least two messages.

    The vote callbacks read ``messages[-2]`` (last prompt) and
    ``messages[-1]`` (last response). Before the first prompt the memory may
    contain only the initial message, so those callbacks must do nothing.
    """
    return len(memory.chat_memory.messages) >= 2


def delete_last_message(memory):
    """Remove the last prompt/response pair from *memory*; return the prompt text."""
    last_prompt = memory.chat_memory.messages[-2].content
    memory.chat_memory.messages = memory.chat_memory.messages[:-2]
    return last_prompt


def replace_last_message(memory, new_message):
    """Replace the last message of *memory* with *new_message* as an AI message."""
    memory.chat_memory.messages = memory.chat_memory.messages[:-1]
    memory.chat_memory.add_ai_message(new_message)


def _record_battle(winner):
    """Store the user's verdict (*winner*) for the current comparison."""
    new_battle_result(st.session_state['db_client'],
                      st.session_state['comparison_id'],
                      st.session_state['convo_id'],
                      username, sourceA, sourceB,
                      winner=winner)


def _regenerate(memory, chain, stopper, column):
    """Drop the last exchange from *memory*, re-ask *chain*, render it in *column*.

    Returns the new response text.
    """
    last_prompt = delete_last_message(memory)
    new_response = chain.predict(input=last_prompt, stop=stopper)
    column.chat_message("user").write(last_prompt)
    column.chat_message("assistant").write(new_response)
    return new_response


def regenerateA():
    """Regenerate simulator A's answer to the last prompt and return it."""
    return _regenerate(memoryA, llm_chainA, stopperA, col1)


def regenerateB():
    """Regenerate simulator B's answer to the last prompt and return it."""
    return _regenerate(memoryB, llm_chainB, stopperB, col2)


def _adopt_response(winner_memory, loser_memory, winner):
    """Copy the winner's last response into the loser's and the shared history.

    Records the vote as *winner*. Does nothing before the first prompt.
    """
    if not (_has_exchange(winner_memory) and _has_exchange(loser_memory)):
        return
    last_prompt = winner_memory.chat_memory.messages[-2].content
    new_message = winner_memory.chat_memory.messages[-1].content
    replace_last_message(loser_memory, new_message)
    st.session_state['commonMemory'].save_context({"inputs": last_prompt},
                                                  {"outputs": new_message})
    _record_battle(winner)


def replaceA():
    """B is better: replace A's last answer with B's and record 'model_two'."""
    _adopt_response(memoryB, memoryA, 'model_two')


def replaceB():
    """A is better: replace B's last answer with A's and record 'model_one'."""
    _adopt_response(memoryA, memoryB, 'model_one')


def regenerateBoth():
    """Both are bad: record 'both_bad', regenerate both answers, log the new pair."""
    if not (_has_exchange(memoryA) and _has_exchange(memoryB)):
        return
    # Take the prompt from memory instead of the module-level ``prompt``. That
    # name belongs to the script run that registered this callback, so it may
    # be None (e.g. a rerun triggered by another widget) or stale.
    last_prompt = memoryA.chat_memory.messages[-2].content
    prompt_ts = dt.datetime.now(tz=dt.timezone.utc)
    _record_battle('both_bad')
    responseA = regenerateA()
    responseB = regenerateB()
    completion_ts = dt.datetime.now(tz=dt.timezone.utc)
    new_comparison(st.session_state['db_client'], prompt_ts, completion_ts,
                   st.session_state['commonMemory'].buffer_as_str,
                   last_prompt, responseA, responseB)


def bothGood():
    """Tie: save one randomly chosen answer to the shared history and record 'tie'."""
    if not (_has_exchange(memoryA) and _has_exchange(memoryB)):
        return
    chosen = random.choice([memoryA, memoryB])
    last_prompt = chosen.chat_memory.messages[-2].content
    last_response = chosen.chat_memory.messages[-1].content
    st.session_state['commonMemory'].save_context({"inputs": last_prompt},
                                                  {"outputs": last_response})
    _record_battle('tie')


with st.sidebar:
    username = st.text_input("Username", value='ivnban-ctl', max_chars=30)
    # NOTE(review): callback kwargs are evaluated when the widget is created,
    # so these on_change handlers (and "Clear History" below) receive the
    # placeholder ``memories`` and a fixed "English", not the new selection.
    issue = st.selectbox("Select an Issue", ISSUES, index=0,
                         on_change=clear_memory,
                         kwargs={"memories": memories, "username": username,
                                 "language": "English"})
    supported_languages = ['English', "Spanish"] if issue == "Anxiety" else ['English']
    language = st.selectbox("Select a Language", supported_languages, index=0,
                            on_change=clear_memory,
                            kwargs={"memories": memories, "username": username,
                                    "language": "English"})

    with st.expander("Model A"):
        temperatureA = st.slider("Temperature Model A", 0., 1., value=0.8, step=0.1)
        sourceA = st.selectbox("Select a source Model A", SOURCES, index=0,
                               format_func=source2label)
    with st.expander("Model B"):
        temperatureB = st.slider("Temperature Model B", 0., 1., value=0.8, step=0.1)
        sourceB = st.selectbox("Select a source Model B", SOURCES, index=1,
                               format_func=source2label)

    sbcol1, sbcol2 = st.columns(2)
    beta = sbcol1.button("A is better", on_click=replaceB)
    betb = sbcol2.button("B is better", on_click=replaceA)
    same = sbcol1.button("Tie", on_click=bothGood)
    bbad = sbcol2.button("Both are bad", on_click=regenerateBoth)
    clear = st.button("Clear History", on_click=clear_memory, args=(memories,))

# Memory specs reflecting the current sidebar selections.
memories = {
    'memoryA': {"issue": issue, "source": sourceA},
    'memoryB': {"issue": issue, "source": sourceB},
    'commonMemory': {"issue": issue, "source": SOURCES[0]},
}

changed_source = (st.session_state['previous_sourceA'] != sourceA
                  or st.session_state['previous_sourceB'] != sourceB)
create_memory_add_initial_message(memories, username, language,
                                  changed_source=changed_source)
# Remember the sources just used. Without this, ``changed_source`` stays True
# on every rerun after the first change.
st.session_state['previous_sourceA'] = sourceA
st.session_state['previous_sourceB'] = sourceB

memoryA = st.session_state['memoryA']
memoryB = st.session_state['memoryB']
llm_chainA, stopperA = get_chain(issue, language, sourceA, memoryA, temperatureA)
llm_chainB, stopperB = get_chain(issue, language, sourceB, memoryB, temperatureB)

st.title("💬 History")
for msg in st.session_state['commonMemory'].buffer_as_messages:
    role = "user" if isinstance(msg, HumanMessage) else "assistant"
    st.chat_message(role).write(msg.content)

col1, col2 = st.columns(2)
col1.title("💬 Simulator A")
col2.title("💬 Simulator B")


def disable_chat():
    """Return True (chat disabled) unless a vote button was clicked on this run.

    This makes the user vote on the previous pair of answers before sending
    the next prompt.
    """
    return not any([beta, betb, same, bbad])


if prompt := st.chat_input(disabled=disable_chat()):
    if 'convo_id' not in st.session_state:
        push_convo2db(memories, username, language)

    prompt_ts = dt.datetime.now(tz=dt.timezone.utc)
    col1.chat_message("user").write(prompt)
    col2.chat_message("user").write(prompt)
    responseA = llm_chainA.predict(input=prompt, stop=stopperA)
    responseB = llm_chainB.predict(input=prompt, stop=stopperB)
    completion_ts = dt.datetime.now(tz=dt.timezone.utc)
    new_comparison(st.session_state['db_client'], prompt_ts, completion_ts,
                   st.session_state['commonMemory'].buffer_as_str,
                   prompt, responseA, responseB)
    col1.chat_message("assistant").write(responseA)
    col2.chat_message("assistant").write(responseB)
    # No manual button reset is needed: st.button returns True only on the
    # rerun triggered by its own click.