Spaces:
Running
Running
import datetime as dt | |
import streamlit as st | |
from streamlit.logger import get_logger | |
import langchain | |
from langchain.memory import ConversationBufferMemory | |
from app_config import ENVIRON | |
from models.openai.finetuned_models import finetuned_models, get_finetuned_chain | |
from models.openai.role_models import get_role_chain, role_templates | |
from mongo_utils import new_convo | |
# LangChain's global verbose tracing is switched on only in the "dev" environment.
langchain.verbose = (ENVIRON == "dev")
# Module-level logger obtained through Streamlit's logging helper.
logger = get_logger(__name__)
def add_initial_message(model_name, memory):
    """Seed *memory* with the texter's opening line.

    Args:
        model_name: label checked for the substring ``"Spanish"``; the visible
            caller (``create_memory_add_initial_message``) passes the
            conversation language here.
        memory: object exposing ``chat_memory.add_ai_message(str)``, such as a
            LangChain ``ConversationBufferMemory``.
    """
    greeting = "Hi I need help"
    if "Spanish" in model_name:
        greeting = "Hola necesito ayuda"
    memory.chat_memory.add_ai_message(greeting)
def push_convo2db(memories, username, language):
    """Register a new conversation in the database through ``new_convo``.

    A single-entry *memories* mapping is read from its ``'memory'`` entry
    (issue and source). Any other size is treated as a two-model setup: the
    issue comes from ``'commonMemory'`` and the sources from ``'memoryA'`` and
    ``'memoryB'``. The database client is read from
    ``st.session_state['db_client']``.

    Args:
        memories: mapping of memory name -> params dict holding ``'issue'``
            and/or ``'source'`` entries, as described above.
        username: forwarded to ``new_convo``.
        language: forwarded to ``new_convo``.
    """
    if len(memories) == 1:
        single = memories['memory']
        issue = single['issue']
        sources = (single['source'],)
        two_models = False
    else:
        issue = memories['commonMemory']['issue']
        sources = (memories['memoryA']['source'], memories['memoryB']['source'])
        two_models = True
    # NOTE(review): the boolean presumably tells new_convo whether two models
    # are being compared; confirm against mongo_utils.new_convo.
    new_convo(st.session_state['db_client'], issue, language, username, two_models, *sources)
def change_memories(memories, username, language, changed_source=False):
    """Create conversation memories in Streamlit session state for the given sources.

    Each key of *memories* is a ``st.session_state`` key. An entry is
    (re)considered when it is absent from session state or when
    *changed_source* is true; its source is logged at INFO level. Only the
    ``'OA_rolemodel'`` and ``'OA_finetuned'`` sources receive a new
    ``ConversationBufferMemory`` (AI prefix ``'texter'``, human prefix
    ``'helper'``). Entries with any other source are left untouched.

    Args:
        memories: mapping of session-state key -> params dict with a
            ``'source'`` entry.
        username: unused here; kept for interface compatibility.
        language: unused here; kept for interface compatibility.
        changed_source: when true, rebuild every supported memory even if it
            already exists, and remove ``'convo_id'`` from session state.
    """
    supported = ('OA_rolemodel', 'OA_finetuned')
    for memory, cfg in memories.items():
        # Keep an existing memory unless the caller asked for a rebuild.
        if memory in st.session_state and not changed_source:
            continue
        source = cfg['source']
        logger.info(f"Source for memory {memory} is {source}")
        if source not in supported:
            continue
        st.session_state[memory] = ConversationBufferMemory(ai_prefix='texter', human_prefix='helper')

    # Drop the current conversation id whenever the model source changed.
    if "convo_id" in st.session_state and changed_source:
        del st.session_state['convo_id']
def clear_memory(memories, username, language):
    """Clear every session-state memory named in *memories* and forget the conversation id.

    Args:
        memories: mapping whose keys are ``st.session_state`` keys of objects
            with a ``clear()`` method; the values are not used.
        username: unused; kept for interface compatibility.
        language: unused; kept for interface compatibility.
    """
    # Only the keys are needed, so iterate the mapping directly rather than
    # unpacking .items() and discarding the values.
    for memory in memories:
        st.session_state[memory].clear()
    if "convo_id" in st.session_state:
        del st.session_state['convo_id']
def create_memory_add_initial_message(memories, username, language, changed_source=False):
    """Ensure each memory in *memories* exists, then seed empty ones with an opening line.

    Delegates creation to ``change_memories`` and then, for every memory whose
    ``buffer_as_messages`` is empty, calls ``add_initial_message`` with
    *language* so the AI side opens the conversation.

    Args:
        memories: mapping of ``st.session_state`` key -> params dict (see
            ``change_memories``).
        username: forwarded to ``change_memories``.
        language: forwarded to ``change_memories`` and used to pick the opening
            line.
        changed_source: forwarded to ``change_memories``; forces a rebuild.
    """
    change_memories(memories, username, language, changed_source=changed_source)
    # NOTE(review): a memory whose source change_memories does not support is
    # never created, so the lookup below raises for it; confirm callers only
    # pass supported sources.
    for memory in memories:
        mem = st.session_state[memory]
        if not mem.buffer_as_messages:
            add_initial_message(language, mem)
def get_chain(issue, language, source, memory, temperature):
    """Build the LangChain chain that plays the texter for one conversation.

    Args:
        issue: issue key; combined with *language* as ``f"{issue}-{language}"``
            to look up the fine-tuned model (``finetuned_models``) or the role
            prompt template (``role_templates``).
        language: language key; see *issue*.
        source: ``'OA_finetuned'`` or ``'OA_rolemodel'``; selects the builder.
        memory: conversation memory passed through to the chain builder.
        temperature: sampling temperature passed through to the chain builder.

    Returns:
        Whatever ``get_finetuned_chain`` / ``get_role_chain`` returns, or
        ``None`` when *source* is not one of the two supported values.

    Raises:
        The lookup error from ``finetuned_models`` / ``role_templates``
        (``KeyError`` if they are plain dicts) when the issue/language
        combination is unknown.
    """
    # Compare exactly. ``source in ("OA_finetuned")`` has no tuple comma, so it
    # is a substring test on a str and matched '', 'OA', 'tuned', etc.
    if source == "OA_finetuned":
        OA_engine = finetuned_models[f"{issue}-{language}"]
        return get_finetuned_chain(OA_engine, memory, temperature)
    if source == "OA_rolemodel":
        template = role_templates[f"{issue}-{language}"]
        return get_role_chain(template, memory, temperature)
    # Unsupported source: keep the previous implicit ``None`` result.
    return None