convosim-ui / pages /comparisor.py
ivnban27-ctl's picture
first commit openai simulators
5832f57
raw
history blame
No virus
5.13 kB
import openai
import os
import streamlit as st
from langchain.schema.messages import HumanMessage
import logging
from utils import create_memory_add_initial_message, clear_memory, get_chain
# Raises KeyError at import time if OPENAI_API_KEY is unset.
# NOTE(review): the value itself is not referenced anywhere else in this file;
# presumably the OpenAI/LangChain clients read the key from the environment
# directly — confirm whether this variable is still needed.
openai_api_key = os.environ["OPENAI_API_KEY"]

# st.session_state keys: [0] is Simulator A's memory, [1] is Simulator B's,
# and "commonMemory" is the shared transcript rendered under "History".
memories = ["memoryA", "memoryB", "commonMemory"]
def delete_last_message(memory):
    """Remove the newest prompt/response pair from ``memory`` and return the prompt text.

    Assumes the last two entries of ``memory.chat_memory.messages`` are the
    user's prompt followed by the model's reply. Raises ``IndexError`` when
    fewer than two messages are stored.
    """
    history = memory.chat_memory.messages
    prompt_message = history[-2]
    # Both the prompt and the reply are dropped; callers in this file re-send
    # the returned prompt to the model.
    memory.chat_memory.messages = history[:-2]
    return prompt_message.content
def replace_last_message(memory, new_message):
    """Swap the newest message in ``memory`` for an AI message containing ``new_message``.

    The final entry is dropped unconditionally, whatever its type, and
    ``new_message`` is appended through ``add_ai_message``. With an empty
    history nothing is dropped and the AI message is simply appended.
    """
    history = memory.chat_memory
    history.messages = history.messages[:-1]
    history.add_ai_message(new_message)
def regenerateA():
    """Discard Simulator A's last reply and ask model A to answer the same prompt again.

    Runs as part of the "Both are bad" button (through ``regenerateBoth``).
    The last prompt/response pair is removed from A's memory, the prompt is
    re-sent to ``llm_chainA``, and both are written into column ``col1``.
    ``llm_chainA`` and ``col1`` are module-level objects created later in the
    script and are looked up when the callback runs.

    Does nothing when A's memory holds fewer than two messages, presumably just
    the initial message, because there is no prompt to re-send yet. Without
    this guard ``delete_last_message`` raises ``IndexError``.
    """
    memory = st.session_state[memories[0]]
    if len(memory.chat_memory.messages) < 2:
        return
    last_prompt = delete_last_message(memory)
    # NOTE(review): "helper:" presumably marks the counterpart's turn in the
    # prompt template; confirm against get_chain's prompt.
    new_response = llm_chainA.predict(input=last_prompt, stop="helper:")
    col1.chat_message("user").write(last_prompt)
    col1.chat_message("assistant").write(new_response)
def regenerateB():
    """Discard Simulator B's last reply and ask model B to answer the same prompt again.

    Runs as part of the "Both are bad" button (through ``regenerateBoth``).
    The last prompt/response pair is removed from B's memory, the prompt is
    re-sent to ``llm_chainB``, and both are written into column ``col2``.
    ``llm_chainB`` and ``col2`` are module-level objects created later in the
    script and are looked up when the callback runs.

    Does nothing when B's memory holds fewer than two messages, presumably just
    the initial message, because there is no prompt to re-send yet. Without
    this guard ``delete_last_message`` raises ``IndexError``.
    """
    memory = st.session_state[memories[1]]
    if len(memory.chat_memory.messages) < 2:
        return
    last_prompt = delete_last_message(memory)
    # NOTE(review): "helper:" presumably marks the counterpart's turn in the
    # prompt template; confirm against get_chain's prompt.
    new_response = llm_chainB.predict(input=last_prompt, stop="helper:")
    col2.chat_message("user").write(last_prompt)
    col2.chat_message("assistant").write(new_response)
def replaceA():
    """Handle "B is better": adopt Simulator B's last reply in A and the shared transcript.

    A's newest message is overwritten with B's last reply, so A continues from
    the preferred answer. B's last prompt/reply pair is then appended to the
    shared ``commonMemory`` transcript.

    Does nothing when B's memory holds fewer than two messages, because there
    is no prompt/reply pair to copy yet. Without this guard, indexing
    ``messages[-2]`` raises ``IndexError``. This mirrors the guard in
    ``bothGood``.
    """
    source_messages = st.session_state[memories[1]].chat_memory.messages
    if len(source_messages) < 2:
        return
    last_prompt = source_messages[-2].content
    new_message = source_messages[-1].content
    replace_last_message(st.session_state[memories[0]], new_message)
    st.session_state['commonMemory'].save_context({"inputs": last_prompt}, {"outputs": new_message})
def replaceB():
    """Handle "A is better": adopt Simulator A's last reply in B and the shared transcript.

    B's newest message is overwritten with A's last reply, so B continues from
    the preferred answer. A's last prompt/reply pair is then appended to the
    shared ``commonMemory`` transcript.

    Does nothing when A's memory holds fewer than two messages, because there
    is no prompt/reply pair to copy yet. Without this guard, indexing
    ``messages[-2]`` raises ``IndexError``. This mirrors the guard in
    ``bothGood``.
    """
    source_messages = st.session_state[memories[0]].chat_memory.messages
    if len(source_messages) < 2:
        return
    last_prompt = source_messages[-2].content
    new_message = source_messages[-1].content
    replace_last_message(st.session_state[memories[1]], new_message)
    st.session_state['commonMemory'].save_context({"inputs": last_prompt}, {"outputs": new_message})
def regenerateBoth():
    """Handle "Both are bad": re-ask model A, then model B, for their last prompt."""
    for regenerate in (regenerateA, regenerateB):
        regenerate()
def bothGood():
    """Handle "Tie": append Simulator A's last prompt/reply pair to the shared transcript.

    Neither simulator's own memory is modified; only ``commonMemory`` grows.

    Does nothing when A's memory holds fewer than two messages, because nothing
    has been exchanged yet. The guard checks the same ``chat_memory.messages``
    list that is indexed below. The previous check (``== 1`` on
    ``buffer_as_messages``) let an empty history fall through to an
    ``IndexError``.
    """
    messages = st.session_state[memories[0]].chat_memory.messages
    if len(messages) < 2:
        return
    last_prompt = messages[-2].content
    last_response = messages[-1].content
    st.session_state['commonMemory'].save_context({"inputs": last_prompt}, {"outputs": last_response})
def _model_settings(label, default_source_index):
    """Render one simulator's temperature slider and source selector.

    Call it inside that simulator's expander. It returns
    ``(temperature, source)``. Changing the source triggers
    ``clear_memory(memories)``.
    """
    temperature = st.slider(f"Temperature {label}", 0., 1., value=0.8, step=0.1)
    source = st.selectbox(f"Select a source {label}", ['OpenAI GPT3.5', 'Finetuned OpenAI'],
                          index=default_source_index,
                          on_change=clear_memory, args=(memories,))
    return temperature, source


# Sidebar: scenario selection, per-model settings and rating buttons.
# Changing the issue, language or a model source calls clear_memory(memories).
with st.sidebar:
    issue = st.selectbox("Select an Issue", ['Anxiety', 'Suicide'], index=0,
                         on_change=clear_memory, args=(memories,))
    # Spanish is only offered for the Anxiety scenario.
    if issue == "Anxiety":
        supported_languages = ['English', "Spanish"]
    else:
        supported_languages = ['English']
    language = st.selectbox("Select a Language", supported_languages, index=0,
                            on_change=clear_memory, args=(memories,))

    with st.expander("Model A"):
        temperatureA, sourceA = _model_settings("Model A", 0)
    with st.expander("Model B"):
        temperatureB, sourceB = _model_settings("Model B", 1)

    # Rating buttons. Each returns True only on the rerun triggered by its own
    # click; "X is better" copies X's reply into the other simulator.
    sbcol1, sbcol2 = st.columns(2)
    beta = sbcol1.button("A is better", on_click=replaceB)
    betb = sbcol2.button("B is better", on_click=replaceA)
    same = sbcol1.button("Tie", on_click=bothGood)
    bbad = sbcol2.button("Both are bad", on_click=regenerateBoth)
    # regenA = sbcol1.button("Regenerate A", on_click=regenerateA)
    # regenB = sbcol2.button("Regenerate B", on_click=regenerateB)

    clear = st.button("Clear History", on_click=clear_memory, args=(memories,))

# Must run before the chains are built: get_chain below reads the memories
# from st.session_state.
create_memory_add_initial_message(memories, language)
llm_chainA = get_chain(issue, language, sourceA, st.session_state[memories[0]], temperatureA)
llm_chainB = get_chain(issue, language, sourceB, st.session_state[memories[1]], temperatureB)
# Shared transcript: prompt/reply pairs accepted through the rating buttons.
# The title previously contained mis-encoded bytes ("πŸ’¬") instead of the
# speech-balloon emoji.
st.title("💬 History")
for msg in st.session_state['commonMemory'].buffer_as_messages:
    # isinstance (rather than type() ==) also treats HumanMessage subclasses as user turns.
    role = "user" if isinstance(msg, HumanMessage) else "assistant"
    st.chat_message(role).write(msg.content)

# Side-by-side panes for the two simulators; new turns are written into them below.
col1, col2 = st.columns(2)
col1.title("💬 Simulator A")
col2.title("💬 Simulator B")
def reset_buttons():
    """Clear the module-level rating-button flags after a chat turn.

    The flags must be assigned through ``global``. Rebinding a loop variable,
    as in ``for but in buttons: but = False``, never touches the module names.
    Streamlit recreates these flags from the widgets on every rerun, so this
    only affects code that runs later in the same script run.
    """
    global beta, betb, same, bbad
    # regenA / regenB belong here too if those buttons are re-enabled.
    beta = betb = same = bbad = False
def disable_chat():
    """Return True (chat input disabled) unless a rating button was clicked on this run.

    The chat input is enabled only on the rerun triggered by "A is better",
    "B is better", "Tie" or "Both are bad". The user therefore has to rate the
    previous turn before sending the next prompt.
    """
    return not any((beta, betb, same, bbad))
# One chat turn: echo the prompt in both panes, query model A then model B,
# then show each reply in its own pane.
if prompt := st.chat_input(disabled=disable_chat()):
    panes = (col1, col2)
    for pane in panes:
        pane.chat_message("user").write(prompt)
    responseA, responseB = (
        chain.predict(input=prompt, stop="helper:") for chain in (llm_chainA, llm_chainB)
    )
    for pane, reply in zip(panes, (responseA, responseB)):
        pane.chat_message("assistant").write(reply)
    reset_buttons()