# convosim-ui / app_utils.py
# Last commit: 975a927 by ivnban27-ctl — "Added MongoDB functionality (#1)"
import datetime as dt
import streamlit as st
from streamlit.logger import get_logger
import langchain
from langchain.memory import ConversationBufferMemory
from app_config import ENVIRON
from models.openai.finetuned_models import finetuned_models, get_finetuned_chain
from models.openai.role_models import get_role_chain, role_templates
from mongo_utils import new_convo
# Turn on LangChain's global verbose flag only when the app runs with ENVIRON == "dev".
langchain.verbose = ENVIRON=="dev"
# Module-level logger obtained through Streamlit's logging helper.
logger = get_logger(__name__)
def add_initial_message(model_name, memory):
    """Seed ``memory`` with the texter's opening line.

    The Spanish greeting is used whenever the substring "Spanish" occurs in
    ``model_name``; otherwise the English greeting is used. Callers in this
    module pass the conversation language string as ``model_name``.

    Args:
        model_name: String inspected for the substring "Spanish".
        memory: Object exposing ``chat_memory.add_ai_message(str)``.
    """
    opening_line = "Hola necesito ayuda" if "Spanish" in model_name else "Hi I need help"
    memory.chat_memory.add_ai_message(opening_line)
def push_convo2db(memories, username, language):
    """Register a new conversation in the database through ``new_convo``.

    A ``memories`` mapping with exactly one entry is treated as a single-model
    setup and must contain the key ``'memory'``. Any other size is treated as a
    two-model setup and must contain ``'commonMemory'`` (for the issue) plus
    ``'memoryA'`` and ``'memoryB'`` (for the two sources).

    The boolean passed to ``new_convo`` is False for the single-model case and
    True otherwise (presumably a "comparison conversation" flag — confirm in
    ``mongo_utils``). The DB client is read from
    ``st.session_state['db_client']``.
    """
    single_model = len(memories) == 1
    if single_model:
        issue = memories['memory']['issue']
        model_sources = (memories['memory']['source'],)
    else:
        issue = memories['commonMemory']['issue']
        model_sources = (
            memories['memoryA']['source'],
            memories['memoryB']['source'],
        )
    new_convo(
        st.session_state['db_client'],
        issue,
        language,
        username,
        not single_model,
        *model_sources,
    )
def change_memories(memories, username, language, changed_source=False):
    """Create (or recreate) the conversation memories held in session state.

    For each ``key, params`` in ``memories``, a fresh
    ``ConversationBufferMemory(ai_prefix='texter', human_prefix='helper')`` is
    stored at ``st.session_state[key]``. This happens when the key is not yet
    in session state, or for every key when ``changed_source`` is True, and
    only if ``params['source']`` is 'OA_rolemodel' or 'OA_finetuned'.

    NOTE(review): a key with any other source gets no memory, so later
    ``st.session_state[key]`` lookups would fail.

    When ``changed_source`` is True, any stored ``convo_id`` is dropped.
    ``username`` and ``language`` are accepted but not used here.
    """
    supported_sources = ('OA_rolemodel', 'OA_finetuned')
    for key, params in memories.items():
        needs_fresh_memory = (key not in st.session_state) or changed_source
        if not needs_fresh_memory:
            continue
        source = params['source']
        logger.info(f"Source for memory {key} is {source}")
        if source in supported_sources:
            st.session_state[key] = ConversationBufferMemory(ai_prefix='texter', human_prefix='helper')

    # A source switch starts a new conversation, so forget the old id.
    if ("convo_id" in st.session_state) and changed_source:
        del st.session_state['convo_id']
def clear_memory(memories, username, language):
    """Empty every session-state memory named in ``memories`` and drop ``convo_id``.

    Each key of ``memories`` must already exist in ``st.session_state``.
    ``username`` and ``language`` are accepted for signature parity with the
    sibling helpers but are not used.
    """
    # Only the keys matter here; the per-memory params are irrelevant.
    for key in memories:
        st.session_state[key].clear()
    if "convo_id" in st.session_state:
        del st.session_state['convo_id']
def create_memory_add_initial_message(memories, username, language, changed_source=False):
    """Ensure each memory exists and seed every empty one with the opening message.

    Delegates creation to ``change_memories`` (forwarding ``changed_source``).
    Then, for each key in ``memories``, it calls ``add_initial_message`` with
    ``language`` when that memory's ``buffer_as_messages`` is empty.
    """
    change_memories(memories, username, language, changed_source=changed_source)
    for key in memories:
        conversation = st.session_state[key]
        if not conversation.buffer_as_messages:
            add_initial_message(language, conversation)
def get_chain(issue, language, source, memory, temperature):
    """Build the LLM chain for ``source`` using the ``"{issue}-{language}"`` entry.

    Args:
        issue: Issue name; combined with ``language`` into the lookup key.
        language: Language name; combined with ``issue`` into the lookup key.
        source: ``"OA_finetuned"`` selects ``finetuned_models`` and
            ``get_finetuned_chain``. ``"OA_rolemodel"`` selects
            ``role_templates`` and ``get_role_chain``.
        memory: Conversation memory forwarded to the chain factory.
        temperature: Sampling temperature forwarded to the chain factory.

    Returns:
        The chain built by the selected factory, or ``None`` when ``source`` is
        not one of the two supported values.

    Raises:
        KeyError: presumably, when no entry exists for ``"{issue}-{language}"``
            (assumes ``finetuned_models`` / ``role_templates`` are dicts).
    """
    key = f"{issue}-{language}"
    # Compare exactly: `source in ("OA_finetuned")` is a *substring* test,
    # because parentheses around a single string do not make a tuple. That
    # form would accept "", "OA", "finetuned", etc.
    if source == "OA_finetuned":
        engine = finetuned_models[key]
        return get_finetuned_chain(engine, memory, temperature)
    if source == "OA_rolemodel":
        template = role_templates[key]
        return get_role_chain(template, memory, temperature)
    # Unsupported source: keep the historical fall-through result.
    return None