Spaces:
Running
Running
File size: 5,134 Bytes
5832f57 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
import openai
import os
import streamlit as st
from langchain.schema.messages import HumanMessage
import logging
from utils import create_memory_add_initial_message, clear_memory, get_chain
# Fetched eagerly: raises KeyError at import time if OPENAI_API_KEY is unset.
# NOTE(review): not referenced elsewhere in this file; presumably the OpenAI
# client picks the key up from the environment itself — confirm before removing.
openai_api_key = os.environ["OPENAI_API_KEY"]
# Session-state keys: Simulator A's memory, Simulator B's memory, and the
# shared history of rated exchanges (always in this order; indexed 0/1/2 below).
memories = "memoryA memoryB commonMemory".split()
def delete_last_message(memory):
    """Strip the trailing prompt/response pair from *memory* and return the prompt text.

    The second-to-last stored message is taken to be the user's prompt and the
    last one the reply to it. Both are removed.

    Raises:
        IndexError: if fewer than two messages are stored.
    """
    history = memory.chat_memory
    prompt_text = history.messages[-2].content
    history.messages = history.messages[:-2]
    return prompt_text
def replace_last_message(memory, new_message):
    """Overwrite the final stored message of *memory* with an AI message *new_message*.

    The last entry is dropped, then *new_message* is appended through the
    history's ``add_ai_message``.
    """
    history = memory.chat_memory
    history.messages = history.messages[:-1]
    history.add_ai_message(new_message)
def regenerateA():
    """Discard Simulator A's last reply and have Model A answer the same prompt again.

    The last prompt/response pair is removed from A's memory. The prompt is then
    re-sent through ``llm_chainA``, and both prompt and new reply are echoed into
    the left column (``col1``).

    ``llm_chainA`` was built with this same memory object. Presumably the chain
    stores the regenerated pair back into it; confirm in ``utils.get_chain``.

    Does nothing when A's memory holds fewer than two messages (e.g. only the
    initial message). In that case there is no prompt to regenerate, and
    ``delete_last_message`` would raise IndexError.
    """
    memory = st.session_state[memories[0]]
    if len(memory.chat_memory.messages) < 2:
        return
    last_prompt = delete_last_message(memory)
    new_response = llm_chainA.predict(input=last_prompt, stop="helper:")
    col1.chat_message("user").write(last_prompt)
    col1.chat_message("assistant").write(new_response)
def regenerateB():
    """Discard Simulator B's last reply and have Model B answer the same prompt again.

    The last prompt/response pair is removed from B's memory. The prompt is then
    re-sent through ``llm_chainB``, and both prompt and new reply are echoed into
    the right column (``col2``).

    ``llm_chainB`` was built with this same memory object. Presumably the chain
    stores the regenerated pair back into it; confirm in ``utils.get_chain``.

    Does nothing when B's memory holds fewer than two messages (e.g. only the
    initial message). In that case there is no prompt to regenerate, and
    ``delete_last_message`` would raise IndexError.
    """
    memory = st.session_state[memories[1]]
    if len(memory.chat_memory.messages) < 2:
        return
    last_prompt = delete_last_message(memory)
    new_response = llm_chainB.predict(input=last_prompt, stop="helper:")
    col2.chat_message("user").write(last_prompt)
    col2.chat_message("assistant").write(new_response)
def replaceA():
    """Handle "B is better": adopt Simulator B's last reply in both simulators.

    B's final response overwrites A's final response, so both conversations
    continue from the winning reply. B's last prompt/response pair is also
    appended to the shared ``commonMemory`` history.

    Does nothing until both simulators have completed an exchange (at least two
    messages each). Before that, indexing ``messages[-2]`` would raise
    IndexError, and A's initial message would otherwise be overwritten.
    """
    winner_messages = st.session_state[memories[1]].chat_memory.messages
    loser_memory = st.session_state[memories[0]]
    if len(winner_messages) < 2 or len(loser_memory.chat_memory.messages) < 2:
        return
    last_prompt = winner_messages[-2].content
    new_message = winner_messages[-1].content
    replace_last_message(loser_memory, new_message)
    st.session_state['commonMemory'].save_context({"inputs": last_prompt}, {"outputs": new_message})
def replaceB():
    """Handle "A is better": adopt Simulator A's last reply in both simulators.

    A's final response overwrites B's final response, so both conversations
    continue from the winning reply. A's last prompt/response pair is also
    appended to the shared ``commonMemory`` history.

    Does nothing until both simulators have completed an exchange (at least two
    messages each). Before that, indexing ``messages[-2]`` would raise
    IndexError, and B's initial message would otherwise be overwritten.
    """
    winner_messages = st.session_state[memories[0]].chat_memory.messages
    loser_memory = st.session_state[memories[1]]
    if len(winner_messages) < 2 or len(loser_memory.chat_memory.messages) < 2:
        return
    last_prompt = winner_messages[-2].content
    new_message = winner_messages[-1].content
    replace_last_message(loser_memory, new_message)
    st.session_state['commonMemory'].save_context({"inputs": last_prompt}, {"outputs": new_message})
def regenerateBoth():
    """Handle "Both are bad": regenerate Simulator A's reply, then Simulator B's."""
    for regenerate in (regenerateA, regenerateB):
        regenerate()
def bothGood():
    """Handle "Tie": append Simulator A's last prompt/response pair to ``commonMemory``.

    Both simulators receive the same prompt, so A's pair stands in for the tie.
    Nothing is recorded until at least one exchange has happened. With fewer
    than two messages (only the initial message, or none), indexing
    ``messages[-2]`` would raise IndexError.
    """
    # Guard and read the same list, so the length check matches the indexing.
    messages = st.session_state[memories[0]].chat_memory.messages
    if len(messages) < 2:
        return
    last_prompt = messages[-2].content
    last_response = messages[-1].content
    st.session_state['commonMemory'].save_context({"inputs": last_prompt}, {"outputs": last_response})
# Sidebar: scenario settings, per-model settings, and the rating controls.
# Changing the issue, language or a model source wipes all conversation memories.
_MODEL_SOURCES = ['OpenAI GPT3.5', 'Finetuned OpenAI']
_reset_on_change = dict(on_change=clear_memory, args=(memories,))

with st.sidebar:
    issue = st.selectbox("Select an Issue", ['Anxiety', 'Suicide'], index=0, **_reset_on_change)

    # Spanish is only offered for the Anxiety scenario.
    if issue == "Anxiety":
        supported_languages = ['English', "Spanish"]
    else:
        supported_languages = ['English']
    language = st.selectbox("Select a Language", supported_languages, index=0, **_reset_on_change)

    with st.expander("Model A"):
        temperatureA = st.slider("Temperature Model A", 0.0, 1.0, value=0.8, step=0.1)
        sourceA = st.selectbox("Select a source Model A", _MODEL_SOURCES, index=0, **_reset_on_change)

    with st.expander("Model B"):
        temperatureB = st.slider("Temperature Model B", 0.0, 1.0, value=0.8, step=0.1)
        sourceB = st.selectbox("Select a source Model B", _MODEL_SOURCES, index=1, **_reset_on_change)

    # Rating buttons. Note the cross-wiring: "A is better" copies A's reply into
    # B (replaceB), and "B is better" copies B's reply into A (replaceA).
    sbcol1, sbcol2 = st.columns(2)
    beta = sbcol1.button("A is better", on_click=replaceB)
    betb = sbcol2.button("B is better", on_click=replaceA)
    same = sbcol1.button("Tie", on_click=bothGood)
    bbad = sbcol2.button("Both are bad", on_click=regenerateBoth)
    clear = st.button("Clear History", on_click=clear_memory, args=(memories,))
# Make sure every memory in `memories` exists (seeded by the utils helper),
# then build one chain per simulator, bound to that simulator's memory.
create_memory_add_initial_message(memories, language)
_chain_specs = (
    (sourceA, memories[0], temperatureA),
    (sourceB, memories[1], temperatureB),
)
llm_chainA, llm_chainB = [
    get_chain(issue, language, source, st.session_state[memory_key], temperature)
    for source, memory_key, temperature in _chain_specs
]
# Main pane: replay the shared history of rated exchanges, then lay out one
# column per simulator. col1/col2 are module-level names that the regenerate
# callbacks also write into.
st.title("💬 History")
for msg in st.session_state['commonMemory'].buffer_as_messages:
    # isinstance (not `type(...) ==`) so subclasses of HumanMessage still
    # render as the user rather than as the assistant.
    role = "user" if isinstance(msg, HumanMessage) else "assistant"
    st.chat_message(role).write(msg.content)

col1, col2 = st.columns(2)
col1.title("💬 Simulator A")
col2.title("💬 Simulator B")
def reset_buttons():
    """Intentionally do nothing; kept so existing call sites keep working.

    No reset is required: ``st.button`` returns True only on the rerun triggered
    by its own click and False on every later rerun. Looping over the button
    results and assigning ``False`` to the loop variable would only rebind a
    local name and could never reset anything.
    """
    return None
def disable_chat():
    """Return True (keep the chat input disabled) unless a rating button was clicked this run.

    The four rating/regenerate buttons are True only on the rerun triggered by
    their click, so the input unlocks right after the user judges the last pair.
    NOTE(review): all four are False on first load, so the input also starts
    disabled until one is clicked; confirm this is intended.
    """
    clicked = (beta, betb, same, bbad)
    return not any(clicked)
# Chat input: send one prompt to both simulators and show the replies side by side.
prompt = st.chat_input(disabled=disable_chat())
if prompt:
    panes = ((col1, llm_chainA), (col2, llm_chainB))
    # Echo the prompt in both columns before either model is queried.
    for pane, _chain in panes:
        pane.chat_message("user").write(prompt)
    # Query A first, then B.
    responseA, responseB = (
        chain.predict(input=prompt, stop="helper:") for _pane, chain in panes
    )
    for (pane, _chain), reply in zip(panes, (responseA, responseB)):
        pane.chat_message("assistant").write(reply)
    reset_buttons()