# Streamlit "arena" app: two chatbot models answer the same prompt side by
# side and the user votes on which reply is better.
# Standard library.
import logging
import os

# Third party.
import openai
import streamlit as st
from langchain.schema.messages import HumanMessage

# Project-local helpers.
from utils import create_memory_add_initial_message, clear_memory, get_chain

# Credentials for the OpenAI-backed chains (raises KeyError if unset).
openai_api_key = os.environ['OPENAI_API_KEY']

# Session-state keys: one chat memory per model plus the shared voted history.
memories = ['memoryA', 'memoryB', 'commonMemory']
def delete_last_message(memory):
    """Drop the most recent user/assistant exchange from *memory*.

    Returns the user prompt of the removed exchange so the caller can
    re-submit it (used by the regenerate handlers).
    Assumes the history ends with a (user, assistant) pair — TODO confirm.
    """
    history = memory.chat_memory.messages
    removed_prompt = history[-2].content  # second-to-last turn is the user prompt
    memory.chat_memory.messages = history[:-2]
    return removed_prompt
def replace_last_message(memory, new_message):
    """Swap the most recent AI message in *memory* for *new_message*."""
    # Drop the old AI reply, then append the replacement through the
    # memory's own API so it is stored as an AI message.
    trimmed = memory.chat_memory.messages[:-1]
    memory.chat_memory.messages = trimmed
    memory.chat_memory.add_ai_message(new_message)
def regenerateA():
    """Re-run model A on its last prompt and redraw the exchange in column A."""
    last_prompt = delete_last_message(st.session_state[memories[0]])
    # NOTE(review): presumably the chain's attached memory records the new
    # reply on predict() — verify against get_chain in utils.
    answer = llm_chainA.predict(input=last_prompt, stop="helper:")
    col1.chat_message("user").write(last_prompt)
    col1.chat_message("assistant").write(answer)
def regenerateB():
    """Re-run model B on its last prompt and redraw the exchange in column B."""
    last_prompt = delete_last_message(st.session_state[memories[1]])
    # NOTE(review): presumably the chain's attached memory records the new
    # reply on predict() — verify against get_chain in utils.
    answer = llm_chainB.predict(input=last_prompt, stop="helper:")
    col2.chat_message("user").write(last_prompt)
    col2.chat_message("assistant").write(answer)
def replaceA():
    """Handler for "B is better": copy model B's latest answer over model A's.

    The winning exchange is also committed to the shared voted history.
    """
    b_history = st.session_state[memories[1]].chat_memory.messages
    winning_prompt = b_history[-2].content
    winning_answer = b_history[-1].content
    replace_last_message(st.session_state[memories[0]], winning_answer)
    st.session_state['commonMemory'].save_context(
        {"inputs": winning_prompt}, {"outputs": winning_answer}
    )
def replaceB():
    """Handler for "A is better": copy model A's latest answer over model B's.

    The winning exchange is also committed to the shared voted history.
    """
    a_history = st.session_state[memories[0]].chat_memory.messages
    winning_prompt = a_history[-2].content
    winning_answer = a_history[-1].content
    replace_last_message(st.session_state[memories[1]], winning_answer)
    st.session_state['commonMemory'].save_context(
        {"inputs": winning_prompt}, {"outputs": winning_answer}
    )
def regenerateBoth():
    """Handler for "Both are bad": regenerate both models' last answers."""
    regenerateA()
    regenerateB()
def bothGood():
    """Handler for "Tie": commit model A's latest exchange to the shared history.

    When both answers are judged equally good either column could be
    promoted; model A's exchange is used. Does nothing while the memory
    holds only the single initial greeting (no user exchange to save yet).
    """
    memory_a = st.session_state[memories[0]]  # was st.session_state['memoryA']; use the shared key list
    # Guard clause replaces the original `if ...: pass / else:` inversion.
    if len(memory_a.buffer_as_messages) == 1:
        return
    last_prompt = memory_a.chat_memory.messages[-2].content
    last_response = memory_a.chat_memory.messages[-1].content  # fixed typo: last_reponse
    st.session_state['commonMemory'].save_context(
        {"inputs": last_prompt}, {"outputs": last_response}
    )
# --- Sidebar: scenario/model configuration and feedback buttons. ---
with st.sidebar:
    # Changing any of these selections wipes all three memories.
    issue = st.selectbox(
        "Select an Issue", ['Anxiety', 'Suicide'], index=0,
        on_change=clear_memory, args=(memories,),
    )
    supported_languages = ['English', "Spanish"] if issue == "Anxiety" else ['English']
    language = st.selectbox(
        "Select a Language", supported_languages, index=0,
        on_change=clear_memory, args=(memories,),
    )

    with st.expander("Model A"):
        temperatureA = st.slider("Temperature Model A", 0., 1., value=0.8, step=0.1)
        sourceA = st.selectbox(
            "Select a source Model A", ['OpenAI GPT3.5', 'Finetuned OpenAI'], index=0,
            on_change=clear_memory, args=(memories,),
        )
    with st.expander("Model B"):
        temperatureB = st.slider("Temperature Model B", 0., 1., value=0.8, step=0.1)
        sourceB = st.selectbox(
            "Select a source Model B", ['OpenAI GPT3.5', 'Finetuned OpenAI'], index=1,
            on_change=clear_memory, args=(memories,),
        )

    # Arena-style voting buttons, laid out in two columns.
    sbcol1, sbcol2 = st.columns(2)
    beta = sbcol1.button("A is better", on_click=replaceB)
    betb = sbcol2.button("B is better", on_click=replaceA)
    same = sbcol1.button("Tie", on_click=bothGood)
    bbad = sbcol2.button("Both are bad", on_click=regenerateBoth)
    # regenA = sbcol1.button("Regenerate A", on_click=regenerateA)
    # regenB = sbcol2.button("Regenerate B", on_click=regenerateB)
    clear = st.button("Clear History", on_click=clear_memory, args=(memories,))
# Ensure all memories exist in session state (adds the initial message on first run).
create_memory_add_initial_message(memories, language)

# One chain per model: same issue/language, independent memory and settings.
llm_chainA = get_chain(issue, language, sourceA, st.session_state[memories[0]], temperatureA)
llm_chainB = get_chain(issue, language, sourceB, st.session_state[memories[1]], temperatureB)

# Shared voted history above the two per-model columns.
st.title("💬 History")  # repaired mojibake; dropped pointless f-string
for msg in st.session_state['commonMemory'].buffer_as_messages:
    # isinstance() instead of type()==: also matches HumanMessage subclasses.
    role = "user" if isinstance(msg, HumanMessage) else "assistant"
    st.chat_message(role).write(msg.content)

col1, col2 = st.columns(2)
col1.title("💬 Simulator A")
col2.title("💬 Simulator B")
def reset_buttons():
    """Intended to clear the feedback buttons; effectively a no-op.

    The original body iterated `[beta, betb, same, bbad]` and assigned
    `but = False`, which only rebinds the loop variable and never changes
    the button values — so the function never had any effect. Kept (as an
    explicit no-op) for call-site compatibility.
    """
    # NOTE(review): Streamlit button return values are recomputed on every
    # rerun; per-click state that must persist belongs in st.session_state.
    return None
def disable_chat():
    """Return True (chat box disabled) unless a feedback button was clicked on this rerun."""
    return not any((beta, betb, same, bbad))
# Main loop: one shared prompt, answered independently by both models.
if prompt := st.chat_input(disabled=disable_chat()):
    col1.chat_message("user").write(prompt)
    col2.chat_message("user").write(prompt)
    answer_a = llm_chainA.predict(input=prompt, stop="helper:")
    answer_b = llm_chainB.predict(input=prompt, stop="helper:")
    col1.chat_message("assistant").write(answer_a)
    col2.chat_message("assistant").write(answer_b)
    reset_buttons()