# Spaces:
# Running
# Running
import os | |
import random | |
import datetime as dt | |
import streamlit as st | |
from streamlit.logger import get_logger | |
from langchain.schema.messages import HumanMessage | |
from mongo_utils import get_db_client, new_battle_result, get_non_assesed_comparison, new_completion_error | |
from app_config import ISSUES, SOURCES | |
# Module logger obtained through Streamlit's logger helper.
logger = get_logger(__name__)

# Read eagerly: a missing OPENAI_API_KEY raises KeyError as soon as the
# script runs.
openai_api_key = os.environ['OPENAI_API_KEY']

# Streamlit re-runs this script on each interaction; cache the database client
# in session state so get_db_client() is only called once per session.
_DB_CLIENT_KEY = "db_client"
if _DB_CLIENT_KEY not in st.session_state:
    st.session_state[_DB_CLIENT_KEY] = get_db_client()
def disable_buttons():
    """Return True when there is no pending comparison to judge.

    Used as the ``disabled=`` flag of the sidebar voting/error buttons.
    Reads the module-level ``comparison`` list.
    """
    return not len(comparison)
def replaceA():
    """Callback for the "B is better" button: store a battle result with winner='model_two'.

    NOTE(review): Simulator A/B are randomly mapped onto the comparison's
    model_one/model_two, so 'model_two' here presumably means "the second
    source argument" (sourceB) — confirm against mongo_utils.new_battle_result.
    """
    state = st.session_state
    new_battle_result(
        state['db_client'],
        state['comparison_id'],
        state['convo_id'],
        username,
        sourceA,
        sourceB,
        winner='model_two',
    )
def replaceB():
    """Callback for the "A is better" button: store a battle result with winner='model_one'.

    NOTE(review): 'model_one' presumably refers to the first source argument
    (sourceA), since A/B are randomly assigned — confirm in mongo_utils.
    """
    state = st.session_state
    new_battle_result(
        state['db_client'],
        state['comparison_id'],
        state['convo_id'],
        username,
        sourceA,
        sourceB,
        winner='model_one',
    )
def regenerateBoth():
    """Callback for the "Both are bad" button: store a battle result with winner='both_bad'."""
    state = st.session_state
    new_battle_result(
        state['db_client'],
        state['comparison_id'],
        state['convo_id'],
        username,
        sourceA,
        sourceB,
        winner='both_bad',
    )
def bothGood():
    """Callback for the "Tie" button: store a battle result with winner='tie'."""
    state = st.session_state
    new_battle_result(
        state['db_client'],
        state['comparison_id'],
        state['convo_id'],
        username,
        sourceA,
        sourceB,
        winner='tie',
    )
def error2db(model):
    """Log and persist a completion error reported for ``model``.

    Args:
        model: model identifier as taken from the comparison's ``convo_info``
            (the caller passes ``sourceA`` or ``sourceB``).

    The record is attached to the comparison id stored in session state and
    attributed to the module-level ``username``.
    """
    logger.info(f"error logged for {model}")
    state = st.session_state
    new_completion_error(state['db_client'], state['comparison_id'], username, model)
def error2dbA():
    """Callback for "Error in A": report an error for the model shown as Simulator A."""
    error2db(model=sourceA)
def error2dbB():
    """Callback for "Error in B": report an error for the model shown as Simulator B."""
    error2db(model=sourceB)
# Name of the person judging; it is passed along with every stored vote/error.
username = st.sidebar.text_input("Username", value='ivnban-ctl', max_chars=30)

# Pending comparison(s) for this judge — presumably a list of comparison
# documents not yet assessed by `username`; an empty list means none remain.
comparison = get_non_assesed_comparison(st.session_state["db_client"], username)
# Every sidebar button is disabled when there is nothing left to judge.
# disable_buttons() only reads `comparison`, so evaluating it once is equivalent.
_buttons_disabled = disable_buttons()

with st.sidebar:
    sbcol1, sbcol2 = st.columns(2)
    # Callbacks fire at the start of the next rerun, before this script body.
    beta = sbcol1.button("A is better", on_click=replaceB, disabled=_buttons_disabled)
    betb = sbcol2.button("B is better", on_click=replaceA, disabled=_buttons_disabled)
    same = sbcol1.button("Tie", on_click=bothGood, disabled=_buttons_disabled)
    bbad = sbcol2.button("Both are bad", on_click=regenerateBoth, disabled=_buttons_disabled)
    errorA = sbcol1.button("Error in A", on_click=error2dbA, disabled=_buttons_disabled)
    errorB = sbcol2.button("Error in B", on_click=error2dbB, disabled=_buttons_disabled)
if len(comparison) > 0:
    current = comparison[0]
    convo_info = current['convo_info'][0]

    # The vote/error callbacks read these keys when they fire on the next
    # rerun, so they must describe the comparison currently on screen.
    st.session_state['comparison_id'] = current["_id"]
    st.session_state['convo_id'] = current["convo_id"]
    st.session_state["disabled_buttons"] = False
    st.sidebar.text_input("Issue", value=convo_info['issue'], disabled=True)

    st.title("💬 History")
    # chat_history is newline-separated "speaker:text" lines. Split on the
    # FIRST colon only, so colons inside the message text are not lost.
    # Lines without any colon are skipped (same as before).
    for line in current['chat_history'].split("\n"):
        speaker, sep, text = line.partition(":")
        if sep:
            # 'helper' turns are rendered on the user side; everything else
            # as the assistant.
            role = "user" if speaker == 'helper' else "assistant"
            st.chat_message(role).write(text)

    col1, col2 = st.columns(2)
    col1.title("💬 Simulator A")
    col2.title("💬 Simulator B")

    # Randomly assign the two models to the A/B columns (presumably to keep
    # the judging blind). sourceA/sourceB are module globals that the vote and
    # error callbacks read.
    selectedA = random.choice(['model_one', 'model_two'])
    selectedB = "model_two" if selectedA == "model_one" else "model_one"
    sourceA = convo_info[selectedA]
    sourceB = convo_info[selectedB]
    logger.info(f"selected A is {sourceA} and B is {sourceB}")

    col1.chat_message("user").write(current["prompt"])
    col2.chat_message("user").write(current["prompt"])
    # "compeltion_" is the key spelling used by the stored documents; keep it.
    col1.chat_message("assistant").write(current[f"compeltion_{selectedA}"])
    col2.chat_message("assistant").write(current[f"compeltion_{selectedB}"])
else:
    st.write("No Comparisons left to Check")