import openai
import os
import streamlit as st
from langchain.schema.messages import HumanMessage
import logging
from utils import create_memory_add_initial_message, clear_memory, get_chain

# Read eagerly so a missing key fails immediately with KeyError.
openai_api_key = os.environ['OPENAI_API_KEY']

# Session-state keys: Simulator A's memory, Simulator B's memory, and the shared
# history that receives an exchange once the user rates it.
memories = ['memoryA', 'memoryB', 'commonMemory']

# Stop sequence passed to every prediction of both simulators.
STOP_SEQUENCE = "helper:"


def delete_last_message(memory):
    """Remove the last prompt/response pair from *memory*.

    Returns the text of the removed prompt (the second-to-last message).
    Raises IndexError if *memory* holds fewer than two messages; callers in
    this file check ``_has_exchange`` first.
    """
    last_prompt = memory.chat_memory.messages[-2].content
    memory.chat_memory.messages = memory.chat_memory.messages[:-2]
    return last_prompt


def replace_last_message(memory, new_message):
    """Replace the final message of *memory* with an AI message *new_message*."""
    memory.chat_memory.messages = memory.chat_memory.messages[:-1]
    memory.chat_memory.add_ai_message(new_message)


def _has_exchange(memory):
    """Return True when *memory* holds at least one prompt/response pair.

    A fresh memory may contain only a single initial message. In that case
    there is nothing to rate, replace or regenerate, and indexing
    ``messages[-2]`` would raise IndexError.
    """
    return len(memory.chat_memory.messages) >= 2


def _regenerate(memory_key, chain, column):
    """Drop one simulator's last exchange, re-ask *chain*, and show the result."""
    memory = st.session_state[memory_key]
    if not _has_exchange(memory):
        return
    last_prompt = delete_last_message(memory)
    new_response = chain.predict(input=last_prompt, stop=STOP_SEQUENCE)
    column.chat_message("user").write(last_prompt)
    column.chat_message("assistant").write(new_response)


def _copy_last_response(source_key, target_key):
    """Overwrite the target simulator's last response with the source's.

    The winning exchange is also saved to the shared history. Does nothing
    if either simulator has no exchange yet.
    """
    source = st.session_state[source_key]
    target = st.session_state[target_key]
    if not (_has_exchange(source) and _has_exchange(target)):
        return
    last_prompt = source.chat_memory.messages[-2].content
    new_message = source.chat_memory.messages[-1].content
    replace_last_message(target, new_message)
    st.session_state[memories[2]].save_context({"inputs": last_prompt}, {"outputs": new_message})


def regenerateA():
    """Regenerate Simulator A's last response (button callback)."""
    # llm_chainA and col1 are module globals assigned further down the script;
    # they are looked up when the callback runs.
    _regenerate(memories[0], llm_chainA, col1)


def regenerateB():
    """Regenerate Simulator B's last response (button callback)."""
    # llm_chainB and col2 are module globals assigned further down the script.
    _regenerate(memories[1], llm_chainB, col2)


def replaceA():
    """"B is better": copy B's last response into A and record it in history."""
    _copy_last_response(memories[1], memories[0])


def replaceB():
    """"A is better": copy A's last response into B and record it in history."""
    _copy_last_response(memories[0], memories[1])


def regenerateBoth():
    """"Both are bad": regenerate the last response of both simulators."""
    regenerateA()
    regenerateB()


def bothGood():
    """"Tie": save Simulator A's last exchange to the shared history unchanged."""
    memory = st.session_state[memories[0]]
    if not _has_exchange(memory):
        return
    last_prompt = memory.chat_memory.messages[-2].content
    last_response = memory.chat_memory.messages[-1].content
    st.session_state[memories[2]].save_context({"inputs": last_prompt}, {"outputs": last_response})


with st.sidebar:
    # Changing any of these selections wipes all three memories.
    issue = st.selectbox("Select an Issue", ['Anxiety', 'Suicide'], index=0,
                         on_change=clear_memory, args=(memories,))
    supported_languages = ['English', "Spanish"] if issue == "Anxiety" else ['English']
    language = st.selectbox("Select a Language", supported_languages, index=0,
                            on_change=clear_memory, args=(memories,))

    with st.expander("Model A"):
        temperatureA = st.slider("Temperature Model A", 0., 1., value=0.8, step=0.1)
        sourceA = st.selectbox("Select a source Model A", ['OpenAI GPT3.5', 'Finetuned OpenAI'], index=0,
                               on_change=clear_memory, args=(memories,))

    with st.expander("Model B"):
        temperatureB = st.slider("Temperature Model B", 0., 1., value=0.8, step=0.1)
        sourceB = st.selectbox("Select a source Model B", ['OpenAI GPT3.5', 'Finetuned OpenAI'], index=1,
                               on_change=clear_memory, args=(memories,))

    # Rating buttons: each rating resolves the last A/B exchange.
    sbcol1, sbcol2 = st.columns(2)
    beta = sbcol1.button("A is better", on_click=replaceB)
    betb = sbcol2.button("B is better", on_click=replaceA)
    same = sbcol1.button("Tie", on_click=bothGood)
    bbad = sbcol2.button("Both are bad", on_click=regenerateBoth)
    # regenA = sbcol1.button("Regenerate A", on_click=regenerateA)
    # regenB = sbcol2.button("Regenerate B", on_click=regenerateB)
    clear = st.button("Clear History", on_click=clear_memory, args=(memories,))

create_memory_add_initial_message(memories, language)
llm_chainA = get_chain(issue, language, sourceA, st.session_state[memories[0]], temperatureA)
llm_chainB = get_chain(issue, language, sourceB, st.session_state[memories[1]], temperatureB)

st.title("💬 History")
for msg in st.session_state[memories[2]].buffer_as_messages:
    role = "user" if isinstance(msg, HumanMessage) else "assistant"
    st.chat_message(role).write(msg.content)

col1, col2 = st.columns(2)
col1.title("💬 Simulator A")
col2.title("💬 Simulator B")


def disable_chat():
    """Return True (chat disabled) unless a rating button was clicked this run.

    A button's value is True only on the rerun triggered by its click. The user
    therefore has to rate the previous exchange (or press Tie) before the chat
    input is enabled again.
    """
    return not any([beta, betb, same, bbad])


# No explicit button reset is needed afterwards: st.button values revert to
# False on the next rerun.
if prompt := st.chat_input(disabled=disable_chat()):
    col1.chat_message("user").write(prompt)
    col2.chat_message("user").write(prompt)
    responseA = llm_chainA.predict(input=prompt, stop=STOP_SEQUENCE)
    responseB = llm_chainB.predict(input=prompt, stop=STOP_SEQUENCE)
    col1.chat_message("assistant").write(responseA)
    col2.chat_message("assistant").write(responseB)