# Spaces: Running  (non-code status text captured with this file; kept as a comment so the module parses)
import logging
import os
import re
from typing import Any, Dict, List, Mapping, Optional, Tuple

from langchain.chains import ConversationChain
from langchain.prompts import PromptTemplate
from langchain_community.llms import Databricks

from app_config import ENDPOINT_NAMES
from models.custom_parsers import CustomINSTOutputParser
from utils.app_utils import get_random_name
# Maps scenario names to the issue-key prefixes used in CURRENT_ISSUE_MAPPING
# (a prefix plus "-<language>" forms the full key, e.g. "issue_Suicide-en").
# "safety_planning" deliberately reuses the suicide scenario; "GCT" uses the
# general mental-health-crisis description.
ISSUE_MAPPING = dict(
    anxiety="issue_Anxiety",
    suicide="issue_Suicide",
    safety_planning="issue_Suicide",
    GCT="issue_Gral",
)

_EN_INST_TEMPLATE_ = """<s> [INST] The following is a conversation between you and a crisis counselor. | |
{current_issue} | |
You are able to reply with what the character should say. You are able to reply with your character's dialogue inside and nothing else. Do not write explanations. | |
Do not disclose your name unless asked. | |
{history} </s> [INST] {input} [/INST]""" | |
# Issue descriptions keyed by "<issue>-<language>". Each value is a str.format
# template with {texter_name} and {seed} placeholders, filled in by
# get_template_databricks_models before being embedded in the full prompt.
# NOTE(review): the "issue_Suicide-en" text says "his life" regardless of the
# texter's name; confirm whether names from get_random_name() are always male.
CURRENT_ISSUE_MAPPING = {
    "issue_Suicide-en": "Your character, {texter_name}, has suicidal thoughts. Your character has a plan to end his life and has all the means and requirements to do so. {seed}",
    "issue_Anxiety-en": "Your character, {texter_name}, is experiencing anxiety. Your character has suicide thoughts but no plan. {seed}",
    "issue_Suicide-es": "Tu personaje, {texter_name}, tiene pensamientos suicidas. Tu personaje tiene un plan para terminar con su vida y tiene todos los medios y requerimientos para hacerlo. {seed}",
    # Spanish accents fixed: "ningún" (was "ningun").
    "issue_Anxiety-es": "Tu personaje, {texter_name}, experimenta ansiedad. Tu personaje tiene pensamientos suicidas pero ningún plan. {seed}",
    "issue_Gral-en": "Your character {texter_name} is experiencing a mental health crisis. {seed}",
    # "está" (is) rather than "esta" (this).
    "issue_Gral-es": "Tu personaje {texter_name} está experimentando una crisis de salud mental. {seed}",
}

def get_template_databricks_models(
    issue: str, language: str, texter_name: str = "", seed: str = ""
) -> Tuple[str, str]:
    """Build the prompt template for a Databricks-served conversation model.

    The issue description from CURRENT_ISSUE_MAPPING is filled in with the
    texter name and seed, then embedded in _EN_INST_TEMPLATE_. The
    {history} and {input} placeholders are left in the result for the
    PromptTemplate that consumes it.

    Args:
        issue (str): Issue key, e.g. 'issue_Suicide' or 'issue_Anxiety'.
            Issues with no entry for the language fall back to the general
            'issue_Gral' description.
        language (str): Template language. Only 'en' is currently supported.
        texter_name (str): Name of the simulated texter. When empty (or None),
            a name from get_random_name() is used instead.
        seed (str): Extra scenario text appended to the issue description.

    Returns:
        Tuple[str, str]: (template, texter_name), i.e. the prompt template and
        the texter name actually used.

    Raises:
        ValueError: If ``language`` is not supported for Databricks models.
    """
    # Validate before the issue lookup. The fallback lookup below would
    # otherwise raise a bare KeyError for languages without an 'issue_Gral'
    # entry, hiding this error message.
    if language != "en":
        raise ValueError(f"Language not supported for Databricks: {language}")

    # Only draw a random name when the caller did not supply one.
    texter_name = texter_name or get_random_name()

    current_issue = CURRENT_ISSUE_MAPPING.get(f"{issue}-{language}")
    if current_issue is None:
        current_issue = CURRENT_ISSUE_MAPPING[f"issue_Gral-{language}"]
    current_issue = current_issue.format(texter_name=texter_name, seed=seed)

    # get_databricks_chain wraps the result in a PromptTemplate, which uses
    # f-string syntax by default. Escape braces coming from texter_name/seed so
    # they are rendered literally instead of being parsed as template variables.
    current_issue = current_issue.replace("{", "{{").replace("}", "}}")

    template = _EN_INST_TEMPLATE_.format(
        current_issue=current_issue, history="{history}", input="{input}"
    )
    return template, texter_name


def get_databricks_chain(source, template, memory, temperature=0.8, texter_name="Kit"):
    """Create a ConversationChain backed by a Databricks serving endpoint.

    Args:
        source: Key looked up in ENDPOINT_NAMES to choose the endpoint; keys
            not present fall back to "conversation_simulator".
        template: Prompt template containing the {history} and {input}
            placeholders (e.g. the template returned by
            get_template_databricks_models).
        memory: LangChain memory object handed to the ConversationChain.
        temperature: Sampling temperature forwarded to the Databricks LLM.
        texter_name: Name of the simulated texter. It is given to the output
            parser together with a regex matching "<name>:" or
            "<lowercased name>:".

    Returns:
        tuple: (llm_chain, ["[INST]"]). The second element is presumably used
        by callers as stop sequences; confirm against call sites.
    """
    endpoint_name = ENDPOINT_NAMES.get(source, "conversation_simulator")
    prompt = PromptTemplate(input_variables=["history", "input"], template=template)

    def transform_output(response):
        # Extract the text of the first candidate from the endpoint's response dict.
        return response["candidates"][0]["text"]

    llm = Databricks(
        endpoint_name=endpoint_name,
        transform_output_fn=transform_output,
        temperature=temperature,
        max_tokens=256,
    )

    # Escape the name before building the regex. Names containing metacharacters
    # such as "." or "(" would otherwise match the wrong text or raise re.error.
    name_rx = re.compile(
        re.escape(texter_name) + ":|" + re.escape(texter_name.lower()) + ":"
    )
    llm_chain = ConversationChain(
        llm=llm,
        prompt=prompt,
        memory=memory,
        output_parser=CustomINSTOutputParser(name=texter_name, name_rx=name_rx),
        verbose=True,
    )
    logging.debug("loaded Databricks model")
    return llm_chain, ["[INST]"]