convosim-ui / models /databricks /scenario_sim.py
ivnban27-ctl's picture
v0.5 (#4)
20b3b4a verified
raw
history blame
3.71 kB
import logging
import os
import re
from typing import Any, Dict, List, Mapping, Optional, Tuple

from langchain.chains import ConversationChain
from langchain.prompts import PromptTemplate
from langchain_community.llms import Databricks

from app_config import ENDPOINT_NAMES
from models.custom_parsers import CustomINSTOutputParser
from utils.app_utils import get_random_name
# Scenario key -> issue identifier. Every identifier is one of the issue
# prefixes used in the keys of CURRENT_ISSUE_MAPPING (which appends a
# "-<language>" suffix). Several scenarios may share one identifier.
ISSUE_MAPPING = dict(
    anxiety="issue_Anxiety",
    suicide="issue_Suicide",
    safety_planning="issue_Suicide",  # reuses the suicide issue description
    GCT="issue_Gral",  # maps to the general ("Gral") issue description
)
# English prompt template for the Databricks-served texter model.
# - {current_issue} is filled in by get_template_databricks_models with the
#   formatted issue description from CURRENT_ISSUE_MAPPING.
# - {history} and {input} are re-emitted unchanged by that function, so they
#   stay as placeholders for the LangChain PromptTemplate built in
#   get_databricks_chain (which declares input_variables ['history', 'input']).
# This string is sent to the model verbatim; editing it changes model input.
# NOTE(review): the "<s> [INST] ... [/INST]" markup looks like a
#   Llama-2/Mistral-style instruction format that presumably must match what
#   the served model expects. Note that the first [INST] is never closed before
#   {history}; confirm this is intentional before changing it.
_EN_INST_TEMPLATE_ = """<s> [INST] The following is a conversation between you and a crisis counselor.
{current_issue}
You are able to reply with what the character should say. You are able to reply with your character's dialogue inside and nothing else. Do not write explanations.
Do not disclose your name unless asked.
{history} </s> [INST] {input} [/INST]"""
# Issue descriptions keyed by "<issue>-<language>". Each value is a
# str.format template with {texter_name} and {seed} placeholders, filled in by
# get_template_databricks_models. That function uses the
# "issue_Gral-<language>" entry as the fallback for unknown issue keys.
# NOTE(review): the "-es" entries never end up in a returned template, because
#   get_template_databricks_models only builds templates for language "en".
# NOTE(review): wording such as "suicide thoughts" and "his life", and the
#   missing Spanish accents ("ningun", "esta"), look like typos. This text is
#   sent to the model verbatim, so confirm whether the endpoint depends on the
#   exact wording before changing it.
CURRENT_ISSUE_MAPPING = {
"issue_Suicide-en": "Your character, {texter_name}, has suicidal thoughts. Your character has a plan to end his life and has all the means and requirements to do so. {seed}",
"issue_Anxiety-en": "Your character, {texter_name}, is experiencing anxiety. Your character has suicide thoughts but no plan. {seed}",
"issue_Suicide-es": "Tu personaje, {texter_name}, tiene pensamientos suicidas. Tu personaje tiene un plan para terminar con su vida y tiene todos los medios y requerimientos para hacerlo. {seed}",
"issue_Anxiety-es": "Tu personaje, {texter_name}, experimenta ansiedad. Tu personaje tiene pensamientos suicidas pero ningun plan. {seed}",
"issue_Gral-en": "Your character {texter_name} is experiencing a mental health crisis. {seed}",
"issue_Gral-es": "Tu personaje {texter_name} esta experimentando una crisis de salud mental. {seed}",
}
def get_template_databricks_models(
    issue: str, language: str, texter_name: str = "", seed: str = ""
) -> Tuple[str, str]:
    """Build the prompt template for a Databricks-served texter simulation.

    The issue description for ``issue``/``language`` is looked up in
    ``CURRENT_ISSUE_MAPPING``. If that combination is unknown, the general
    ``issue_Gral`` description is used instead. The description is
    personalised with ``texter_name`` and ``seed`` and then embedded into
    ``_EN_INST_TEMPLATE_``.

    Args:
        issue (str): Issue key prefix, e.g. ``'issue_Suicide'``,
            ``'issue_Anxiety'`` or ``'issue_Gral'``.
        language (str): Template language. Only ``'en'`` is currently
            supported for Databricks models.
        texter_name (str): Name of the simulated texter. When empty (the
            default), a random name from ``get_random_name()`` is used.
        seed (str): Extra scenario text appended to the issue description.

    Returns:
        Tuple[str, str]: ``(template, texter_name)``.
            ``template`` still contains the ``{history}`` and ``{input}``
            placeholders for a LangChain ``PromptTemplate``.
            ``texter_name`` is the name actually used, either the one given
            or a randomly chosen one.

    Raises:
        Exception: If ``language`` is not supported for Databricks models.
    """
    # Validate the language up front. Previously an unknown language such as
    # "fr" failed with a KeyError on the fallback lookup below, and "es" did
    # all the formatting work only to be rejected at the end.
    if language != "en":
        raise Exception(f"Language not supported for Databricks: {language}")

    current_issue = CURRENT_ISSUE_MAPPING.get(
        f"{issue}-{language}", CURRENT_ISSUE_MAPPING[f"issue_Gral-{language}"]
    )

    # Only draw a random name when the caller did not provide one.
    if not texter_name:
        texter_name = get_random_name()

    def _escape_braces(text: str) -> str:
        # The returned template is later parsed as an f-string-style
        # PromptTemplate, so literal braces in user-supplied text must be
        # doubled. Otherwise they would be read as template variables.
        return text.replace("{", "{{").replace("}", "}}")

    current_issue = current_issue.format(
        texter_name=_escape_braces(str(texter_name)),
        seed=_escape_braces(str(seed)),
    )
    # Re-emit {history}/{input} so they survive as PromptTemplate placeholders.
    template = _EN_INST_TEMPLATE_.format(
        current_issue=current_issue, history="{history}", input="{input}"
    )
    return template, texter_name
def get_databricks_chain(source, template, memory, temperature=0.8, texter_name="Kit"):
    """Create a LangChain conversation chain backed by a Databricks endpoint.

    Args:
        source: Key looked up in ``ENDPOINT_NAMES`` to choose the serving
            endpoint. Unknown keys fall back to ``"conversation_simulator"``.
        template (str): Prompt template containing ``{history}`` and
            ``{input}`` placeholders, e.g. the first element returned by
            ``get_template_databricks_models``.
        memory: LangChain memory object passed to ``ConversationChain``.
        temperature (float): Sampling temperature forwarded to the endpoint.
        texter_name (str): Name of the simulated texter. The output parser
            uses it to recognise "<name>:" or "<name lowercased>:" speaker
            prefixes.

    Returns:
        tuple: ``(llm_chain, ["[INST]"])``, i.e. the configured
        ``ConversationChain`` plus a list containing ``"[INST]"``. The caller
        presumably uses that list as stop sequences; confirm against callers.
    """
    endpoint_name = ENDPOINT_NAMES.get(source, "conversation_simulator")
    prompt = PromptTemplate(
        input_variables=["history", "input"],
        template=template,
    )

    def transform_output(response):
        # Assumes the endpoint replies with {"candidates": [{"text": ...}, ...]}
        # and keeps only the first candidate. TODO confirm against the serving
        # endpoint's response schema.
        return response["candidates"][0]["text"]

    llm = Databricks(
        endpoint_name=endpoint_name,
        transform_output_fn=transform_output,
        temperature=temperature,
        max_tokens=256,
    )

    # Escape the name so that characters such as ".", "(" or "+" are matched
    # literally. Previously they produced a wrong pattern or a re.error.
    # For plain alphanumeric names the pattern is identical to before.
    name_rx = re.compile(
        re.escape(texter_name) + ":|" + re.escape(texter_name.lower()) + ":"
    )
    llm_chain = ConversationChain(
        llm=llm,
        prompt=prompt,
        memory=memory,
        output_parser=CustomINSTOutputParser(name=texter_name, name_rx=name_rx),
        verbose=True,
    )
    logging.debug("loaded Databricks model")
    return llm_chain, ["[INST]"]