"""Streamlit chat simulator: the user (counselor) chats with an LLM-chain-driven texter.

Streamlit re-executes this script top to bottom on every interaction, so
conversation counters and the selected scenario/source live in
``st.session_state`` to survive reruns.
"""
import os

import streamlit as st
from streamlit.logger import get_logger
from langchain.schema.messages import HumanMessage

from utils.mongo_utils import get_db_client
from utils.app_utils import create_memory_add_initial_message, get_random_name, DEFAULT_NAMES_DF
from utils.memory_utils import clear_memory, push_convo2db
from utils.chain_utils import get_chain, custom_chain_predict
from app_config import ISSUES, SOURCES, source2label, issue2label, MAX_MSG_COUNT, WARN_MSG_COUT

logger = get_logger(__name__)

# Read eagerly so a missing key fails immediately with KeyError.
# NOTE(review): the value is not referenced below; presumably the LLM client
# inside utils.chain_utils picks OPENAI_API_KEY up from the environment.
openai_api_key = os.environ["OPENAI_API_KEY"]
temperature = 0.8

# ---------------------------------------------------------------------------
# Session initialisation (only keys missing from st.session_state are set).
# ---------------------------------------------------------------------------
_SESSION_DEFAULTS = {
    "sent_messages": 0,   # number of prompts the user has submitted
    "total_messages": 0,  # len(memory.chat_memory.messages), refreshed below
    "issue": ISSUES[0],
    "previous_source": SOURCES[0],
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default

# Created lazily (outside the table above) so these calls only happen when
# the key is missing, not on every rerun.
if "db_client" not in st.session_state:
    st.session_state["db_client"] = get_db_client()
if "texter_name" not in st.session_state:
    st.session_state["texter_name"] = get_random_name(names_df=DEFAULT_NAMES_DF)
    logger.debug("texter name is %s", st.session_state["texter_name"])

# The single key ("memory") is the st.session_state slot holding the chat
# memory; its value is scenario metadata handed to the memory helpers.
memories = {"memory": {"issue": st.session_state["issue"], "source": st.session_state["previous_source"]}}

# ---------------------------------------------------------------------------
# Sidebar controls.
# ---------------------------------------------------------------------------
with st.sidebar:
    username = st.text_input("Username", value="Dani", max_chars=30)
    if "counselor_name" not in st.session_state:
        st.session_state["counselor_name"] = username

    # Changing the scenario or the language triggers clear_memory.
    # NOTE(review): "language" is always passed as "English" here, whereas
    # push_convo2db below receives the language code ("en"/"es") — confirm
    # which form clear_memory expects.
    clear_memory_kwargs = {"memories": memories, "username": username, "language": "English"}

    issue = st.selectbox(
        "Select a Scenario",
        ISSUES,
        index=0,
        format_func=issue2label,
        on_change=clear_memory,
        kwargs=clear_memory_kwargs,
    )
    # Spanish is only offered for the "Anxiety" scenario.
    supported_languages = ["en", "es"] if issue == "Anxiety" else ["en"]
    language = st.selectbox(
        "Select a Language",
        supported_languages,
        index=0,
        format_func=lambda code: "English" if code == "en" else "Spanish",
        on_change=clear_memory,
        kwargs=clear_memory_kwargs,
    )
    source = st.selectbox(
        "Select a source Model A",
        SOURCES,
        index=0,
        format_func=source2label,
    )

# A different source model, scenario or counselor name resets the counters,
# draws a new texter name, and is forwarded to
# create_memory_add_initial_message as `changed_source`.
# NOTE(review): a language change is not part of this check; it relies only
# on the clear_memory callback above.
changed_source = any((
    st.session_state["previous_source"] != source,
    st.session_state["issue"] != issue,
    st.session_state["counselor_name"] != username,
))
if changed_source:
    st.session_state["counselor_name"] = username
    st.session_state["texter_name"] = get_random_name(names_df=DEFAULT_NAMES_DF)
    logger.debug("texter name is %s", st.session_state["texter_name"])
    # Updated before create_memory_add_initial_message runs, as before; no
    # second assignment is needed afterwards because in the unchanged case
    # previous_source already equals source.
    st.session_state["previous_source"] = source
    st.session_state["issue"] = issue
    st.session_state["sent_messages"] = 0
    st.session_state["total_messages"] = 0

create_memory_add_initial_message(
    memories,
    issue,
    language,
    changed_source=changed_source,
    counselor_name=st.session_state["counselor_name"],
    texter_name=st.session_state["texter_name"],
)
memoryA = st.session_state[next(iter(memories))]
# issue only without "." marker for model compatibility
llm_chain, stopper = get_chain(
    issue, language, source, memoryA, temperature,
    texter_name=st.session_state["texter_name"],
)

# ---------------------------------------------------------------------------
# Chat area.
# ---------------------------------------------------------------------------
st.title("💬 Simulator")

st.session_state["total_messages"] = len(memoryA.chat_memory.messages)
for msg in memoryA.buffer_as_messages:
    # isinstance (not an exact type check) so HumanMessage subclasses are
    # also shown as the user's side of the conversation.
    role = "user" if isinstance(msg, HumanMessage) else "assistant"
    st.chat_message(role).write(msg.content)

# Input is disabled near the limit; the "- 4" margin accounts for the next
# interaction.
if prompt := st.chat_input(disabled=st.session_state["total_messages"] > MAX_MSG_COUNT - 4):
    st.session_state["sent_messages"] += 1
    st.chat_message("user").write(prompt)
    if "convo_id" not in st.session_state:
        # NOTE(review): presumably creates the DB record for this conversation
        # and sets st.session_state["convo_id"] — confirm in utils.memory_utils.
        push_convo2db(memories, username, language)

    # custom_chain_predict returns an iterable; each item is rendered as its
    # own assistant bubble.
    responses = custom_chain_predict(llm_chain, prompt, stopper)
    for response in responses:
        st.chat_message("assistant").write(response)

    st.session_state["total_messages"] = len(memoryA.chat_memory.messages)
    if st.session_state["total_messages"] >= MAX_MSG_COUNT:
        st.toast(f"Total of {MAX_MSG_COUNT} Messages reached. Conversation Ended", icon=":material/verified:")
    elif st.session_state["total_messages"] >= WARN_MSG_COUT:
        st.toast(f"The conversation will end at {MAX_MSG_COUNT} Total Messages ", icon=":material/warning:")

with st.sidebar:
    st.markdown(f"### Total Sent Messages: :red[**{st.session_state['sent_messages']}**]")
    st.markdown(f"### Total Messages: :red[**{st.session_state['total_messages']}**]")