import os

import streamlit as st
from langchain.schema.messages import HumanMessage

from utils import create_memory_add_initial_message, clear_memory, get_chain

# Fail fast if OPENAI_API_KEY is missing; the chains built in utils are
# assumed to read it from the environment.
openai_api_key = os.environ['OPENAI_API_KEY']

# Session-state keys: one conversation buffer per model, plus the shared
# history of accepted (prompt, response) pairs.
memories = ['memoryA', 'memoryB', 'commonMemory']
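# Assumed contract for the utils helpers (inferred from their use here, not
# from their source): create_memory_add_initial_message seeds each key in
# st.session_state with a buffer memory holding one greeting message,
# clear_memory resets those keys, and get_chain returns an LLMChain-style
# object exposing predict(input=..., stop=...).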


def delete_last_message(memory):
    """Drop the last (user, assistant) pair from `memory` and return the prompt."""
    last_prompt = memory.chat_memory.messages[-2].content
    memory.chat_memory.messages = memory.chat_memory.messages[:-2]
    return last_prompt

def replace_last_message(memory, new_message):
    """Swap the last assistant message in `memory` for `new_message`."""
    memory.chat_memory.messages = memory.chat_memory.messages[:-1]
    memory.chat_memory.add_ai_message(new_message)
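# These helpers index from the end of the buffer, so every exchange is assumed
# to append exactly one HumanMessage followed by one AIMessage, e.g. (taking
# the initial message to be the assistant's greeting):
#   [AIMessage(greeting), HumanMessage(p1), AIMessage(r1), HumanMessage(p2), AIMessage(r2)]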

def regenerateA():
    # Re-ask model A with the prompt it just answered. As an on_click callback,
    # this uses the bindings of llm_chainA/col1 from the run that drew the button.
    last_prompt = delete_last_message(st.session_state[memories[0]])
    new_response = llm_chainA.predict(input=last_prompt, stop="helper:")
    col1.chat_message("user").write(last_prompt)
    col1.chat_message("assistant").write(new_response)

def regenerateB():
    # Mirror of regenerateA for model B.
    last_prompt = delete_last_message(st.session_state[memories[1]])
    new_response = llm_chainB.predict(input=last_prompt, stop="helper:")
    col2.chat_message("user").write(last_prompt)
    col2.chat_message("assistant").write(new_response)

def replaceA():
    # "B is better": overwrite A's last answer with B's winning answer and
    # commit the (prompt, answer) pair to the shared history.
    last_prompt = st.session_state[memories[1]].chat_memory.messages[-2].content
    new_message = st.session_state[memories[1]].chat_memory.messages[-1].content
    replace_last_message(st.session_state[memories[0]], new_message)
    st.session_state['commonMemory'].save_context({"inputs": last_prompt}, {"outputs": new_message})

def replaceB():
    # "A is better": overwrite B's last answer with A's winning answer.
    last_prompt = st.session_state[memories[0]].chat_memory.messages[-2].content
    new_message = st.session_state[memories[0]].chat_memory.messages[-1].content
    replace_last_message(st.session_state[memories[1]], new_message)
    st.session_state['commonMemory'].save_context({"inputs": last_prompt}, {"outputs": new_message})

def regenerateBoth():
    # "Both are bad": throw away both answers and sample fresh ones.
    regenerateA()
    regenerateB()

def bothGood():
    # "Tie": keep A's answer as the canonical one. Skip when the buffer holds
    # only the initial greeting, i.e. nothing has been asked yet.
    if len(st.session_state['memoryA'].buffer_as_messages) == 1:
        return
    last_prompt = st.session_state[memories[0]].chat_memory.messages[-2].content
    last_response = st.session_state[memories[0]].chat_memory.messages[-1].content
    st.session_state['commonMemory'].save_context({"inputs": last_prompt}, {"outputs": last_response})
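# Assuming ConversationBufferMemory-style memories, save_context appends a
# HumanMessage/AIMessage pair, so commonMemory ends up holding exactly the
# exchanges the rater approved; the History pane below iterates over it.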

with st.sidebar:
    # Changing the issue, language, or model source invalidates the ongoing
    # conversation, so every control clears all three memories.
    issue = st.selectbox("Select an Issue", ['Anxiety', 'Suicide'], index=0,
                         on_change=clear_memory, args=(memories,))
    # Spanish is only available for the Anxiety issue.
    supported_languages = ['English', 'Spanish'] if issue == 'Anxiety' else ['English']
    language = st.selectbox("Select a Language", supported_languages, index=0,
                            on_change=clear_memory, args=(memories,))
    
    with st.expander("Model A"):
        temperatureA = st.slider("Temperature Model A", 0.0, 1.0, value=0.8, step=0.1)
        sourceA = st.selectbox("Select a source Model A", ['OpenAI GPT3.5', 'Finetuned OpenAI'],
                               index=0, on_change=clear_memory, args=(memories,))
    with st.expander("Model B"):
        temperatureB = st.slider("Temperature Model B", 0.0, 1.0, value=0.8, step=0.1)
        sourceB = st.selectbox("Select a source Model B", ['OpenAI GPT3.5', 'Finetuned OpenAI'],
                               index=1, on_change=clear_memory, args=(memories,))
        
    # Pairwise feedback: the winner's answer replaces the loser's, a tie keeps
    # A's answer, and "Both are bad" regenerates both responses.
    sbcol1, sbcol2 = st.columns(2)
    beta = sbcol1.button("A is better", on_click=replaceB)
    betb = sbcol2.button("B is better", on_click=replaceA)
    same = sbcol1.button("Tie", on_click=bothGood)
    bbad = sbcol2.button("Both are bad", on_click=regenerateBoth)

    # regenA = sbcol1.button("Regenerate A", on_click=regenerateA)
    # regenB = sbcol2.button("Regenerate B", on_click=regenerateB)
    clear = st.button("Clear History", on_click=clear_memory, args=(memories,))

# (Re)create the memories if they are missing, then build one chain per model
# configuration.
create_memory_add_initial_message(memories, language)
llm_chainA = get_chain(issue, language, sourceA, st.session_state[memories[0]], temperatureA)
llm_chainB = get_chain(issue, language, sourceB, st.session_state[memories[1]], temperatureB)
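# The "helper:" stop string presumably matches a speaker label in the prompt
# template, preventing the model from generating past its own turn.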

st.title(f"💬 History") 
for msg in st.session_state['commonMemory'].buffer_as_messages:
    role = "user" if type(msg) == HumanMessage else "assistant"
    st.chat_message(role).write(msg.content)


col1, col2 = st.columns(2)
col1.title(f"💬 Simulator A") 
col2.title(f"💬 Simulator B") 

def reset_buttons():
    # Streamlit button values are ephemeral: they are True only for the single
    # rerun triggered by the click, so there is nothing to reset explicitly.
    pass

def disable_chat():
    # The chat box is enabled only on the rerun triggered by one of the
    # feedback buttons, forcing a rating (or regeneration) before the next
    # prompt can be sent.
    return not any([beta, betb, same, bbad])

# Each prompt goes to both models side by side; the rater then scores the
# pair with the sidebar buttons before the input unlocks again.
if prompt := st.chat_input(disabled=disable_chat()):
    col1.chat_message("user").write(prompt)
    col2.chat_message("user").write(prompt)

    responseA = llm_chainA.predict(input=prompt, stop="helper:")
    responseB = llm_chainB.predict(input=prompt, stop="helper:")

    col1.chat_message("assistant").write(responseA)
    col2.chat_message("assistant").write(responseB)

    reset_buttons()