ivnban27-ctl committed
Commit: 20b3b4a
Parent(s): 18f6362

v0.5 (#4)
- v0.5 utilities (87cfe08d079faab66e43edd83437c18cd8f6d2f7)

- app_config.py +8 -1
- convosim.py +9 -5
- models/custom_parsers.py +43 -3
- models/databricks/instruction_memory.py +129 -0
- models/databricks/scenario_sim.py +91 -0
- models/databricks/scenario_sim_biz.py +13 -8
- models/model_seeds.py +62 -77
- requirements.txt +2 -1
- utils/chain_utils.py +24 -2
- utils/memory_utils.py +3 -0
app_config.py
CHANGED
@@ -4,14 +4,21 @@ from models.model_seeds import seeds, seed2str
 ISSUES = [k for k,_ in seeds.items()]
 SOURCES = [
     "CTL_llama2",
+    "CTL_mistral",
     'OA_rolemodel',
     # 'OA_finetuned',
 ]
 SOURCES_LAB = {"OA_rolemodel":'OpenAI GPT3.5',
               "OA_finetuned":'Finetuned OpenAI',
-              "CTL_llama2": "
+              "CTL_llama2": "Llama",
+              "CTL_mistral": "Mistral",
               }
 
+ENDPOINT_NAMES = {
+    'CTL_llama2': "llama2_convo_sim",
+    "CTL_mistral": "convo_sim_mistral"
+}
+
 def source2label(source):
     return SOURCES_LAB[source]
 
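The new ENDPOINT_NAMES mapping is what the rest of this commit uses to resolve a UI source to a Databricks serving endpoint, with "conversation_simulator" as the fallback (see scenario_sim.py and scenario_sim_biz.py below). A minimal sketch of the lookup; the resolve_endpoint helper here is illustrative, not part of the commit:

ENDPOINT_NAMES = {
    'CTL_llama2': "llama2_convo_sim",
    "CTL_mistral": "convo_sim_mistral",
}

def resolve_endpoint(source: str) -> str:
    # Unknown sources fall back to the generic simulator endpoint.
    return ENDPOINT_NAMES.get(source, "conversation_simulator")

assert resolve_endpoint("CTL_mistral") == "convo_sim_mistral"
assert resolve_endpoint("OA_rolemodel") == "conversation_simulator"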
convosim.py
CHANGED
@@ -5,11 +5,13 @@ from langchain.schema.messages import HumanMessage
 from utils.mongo_utils import get_db_client
 from utils.app_utils import create_memory_add_initial_message, get_random_name, DEFAULT_NAMES_DF
 from utils.memory_utils import clear_memory, push_convo2db
-from utils.chain_utils import get_chain
+from utils.chain_utils import get_chain, custom_chain_predict
 from app_config import ISSUES, SOURCES, source2label, issue2label
 
 logger = get_logger(__name__)
 openai_api_key = os.environ['OPENAI_API_KEY']
+temperature = 0.8
+username = "barb-chase" #"ivnban-ctl"
 
 if "sent_messages" not in st.session_state:
     st.session_state['sent_messages'] = 0
@@ -28,8 +30,8 @@ if 'texter_name' not in st.session_state:
 memories = {'memory':{"issue": st.session_state['issue'], "source": st.session_state['previous_source']}}
 
 with st.sidebar:
-    username = st.text_input("Username", value='ivnban-ctl', max_chars=30)
-    temperature = st.slider("Temperature", 0., 1., value=0.8, step=0.1)
+    # username = st.text_input("Username", value='ivnban-ctl', max_chars=30)
+    # temperature = st.slider("Temperature", 0., 1., value=0.8, step=0.1)
     issue = st.selectbox("Select a Scenario", ISSUES, index=0, format_func=issue2label,
         on_change=clear_memory, kwargs={"memories":memories, "username":username, "language":"English"}
     )
@@ -76,6 +78,8 @@ if prompt := st.chat_input():
     push_convo2db(memories, username, language)
 
     st.chat_message("user").write(prompt)
-
+    responses = custom_chain_predict(llm_chain, prompt, stopper)
+    # responses = llm_chain.predict(input=prompt, stop=stopper)
     # response = update_memory_completion(prompt, st.session_state["memory"], OA_engine, temperature)
-
+    for response in responses:
+        st.chat_message("assistant").write(response)
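A note on the new chat loop: custom_chain_predict (added in utils/chain_utils.py below) returns a list of texter utterances rather than a single string, so the app now writes one assistant bubble per parsed utterance. A small sketch of that contract with a stubbed response list:

# Stand-in for the list custom_chain_predict returns (illustrative values).
responses = ["Hi", "I'm feeling really overwhelmed"]
for response in responses:
    # In the app this line is st.chat_message("assistant").write(response);
    # print() stands in for Streamlit here.
    print(f"assistant> {response}")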
models/custom_parsers.py
CHANGED
@@ -1,5 +1,7 @@
 from typing import List
+import re
 from langchain.schema import BaseOutputParser
+from itertools import chain
 
 class CustomStringOutputParser(BaseOutputParser[List[str]]):
     """Parse the output of an LLM call to a list."""
@@ -10,8 +12,46 @@ class CustomStringOutputParser(BaseOutputParser[List[str]]):
 
     def parse(self, text: str) -> str:
         """Parse the output of an LLM call."""
-        text = text.split("texter:")[0]
         text = text.split("helper")[0]
+        text = text.split("\nhelper")[0]
         text = text.rstrip("\n")
-
-
+        text_list = text.split("texter:")
+        text_list = [x.split("\ntexter") for x in text_list]
+        text_list = [x.strip() for x in list(chain.from_iterable(text_list))]
+        return text_list
+
+class CustomINSTOutputParser(BaseOutputParser[List[str]]):
+    """Parse the output of an LLM call to a list."""
+
+    name = "Kit"
+    name_rx = re.compile(r""+ name + r":|" + name.lower() + r":")
+    whispers = re.compile((r"([\(]).*?([\)])"))
+    reactions = re.compile(r"([\*]).*?([\*])")
+    double_spaces = re.compile(r"  ")
+    quotation_rx = re.compile('"')
+
+    @property
+    def _type(self) -> str:
+        return "str"
+
+    def parse_whispers(self, text: str) -> str:
+        text = self.name_rx.sub("", text).strip()
+        text = self.reactions.sub("", text).strip()
+        text = self.whispers.sub("", text).strip()
+        text = self.double_spaces.sub(r" ", text).strip()
+        text = self.quotation_rx.sub("", text).strip()
+        return text
+
+    def parse_split(self, text: str) -> str:
+        text = text.split("[INST]")[0]
+        text_list = text.split("[/INST]")
+        text_list = [x.split("</s>") for x in text_list]
+        text_list = [x.strip() for x in list(chain.from_iterable(text_list))]
+        text_list = [x.split("\n\n") for x in text_list]
+        text_list = [x.strip().rstrip("\n") for x in list(chain.from_iterable(text_list))]
+        return text_list
+
+    def parse(self, text: str) -> str:
+        text = self.parse_whispers(text)
+        text_list = self.parse_split(text)
+        return text_list
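To make the parser change concrete: CustomStringOutputParser.parse now returns a list of texter messages instead of a single string, splitting one completion into separate utterances. A self-contained sketch of the same splitting logic, with the expected result inferred from the code:

import re
from itertools import chain

def parse_texter_list(text: str) -> list:
    # Mirrors the updated CustomStringOutputParser.parse: cut the completion
    # at the next helper turn, then split the rest into texter utterances.
    text = text.split("helper")[0]
    text = text.split("\nhelper")[0]
    text = text.rstrip("\n")
    text_list = text.split("texter:")
    text_list = [x.split("\ntexter") for x in text_list]
    return [x.strip() for x in list(chain.from_iterable(text_list))]

completion = "Hi\ntexter: I'm so tired\nhelper: what happened?"
print(parse_texter_list(completion))  # inferred: ['Hi', "I'm so tired"]

CustomINSTOutputParser does the analogous job for the Mistral instruction format: parse_whispers strips the character's name prefix, stage directions in (...) and *...*, and quotation marks; parse_split then breaks the completion on [INST]/[/INST]/</s> markers and blank lines.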
models/databricks/instruction_memory.py
ADDED
@@ -0,0 +1,129 @@
+import os
+import sys
+from typing import Any, Dict, List, Sequence, Union
+
+from langchain.memory.chat_memory import BaseChatMemory
+from langchain.schema.messages import (
+    AIMessage,
+    BaseMessage,
+    ChatMessage,
+    FunctionMessage,
+    HumanMessage,
+    SystemMessage,
+)
+
+# Hacky way of importing other directory modules.
+# This is done this way instead of src.evaluation.postprocessing_utils
+# because this file is used for inference where there is not the same folder structure
+sys.path.append(os.path.join(os.path.dirname(__file__), "..", "evaluation"))
+
+def get_buffer_string_inst(
+    messages: Sequence[BaseMessage],
+    human_prefix: str = "[INST]",
+    ai_prefix: str = "[/INST]",
+) -> str:
+    """Convert sequence of Messages to strings and concatenate them into one string.
+
+    Args:
+        messages: Messages to be converted to strings.
+        human_prefix: The prefix to prepend to contents of HumanMessages.
+        ai_prefix: The prefix to prepend to contents of AIMessages.
+
+    Returns:
+        A single string concatenation of all input messages.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.schema import AIMessage, HumanMessage
+
+            messages = [
+                HumanMessage(content="Hi, how are you?"),
+                AIMessage(content="Good, how are you?"),
+            ]
+            get_buffer_string(messages)
+            # -> "Human: Hi, how are you?\nAI: Good, how are you?"
+    """
+    string_messages = []
+    for m in messages:
+        if isinstance(m, HumanMessage):
+            role = human_prefix
+        elif isinstance(m, AIMessage):
+            role = ai_prefix
+        elif isinstance(m, SystemMessage):
+            role = "System"
+        elif isinstance(m, FunctionMessage):
+            role = "Function"
+        elif isinstance(m, ChatMessage):
+            role = m.role
+        else:
+            raise ValueError(f"Got unsupported message type: {m}")
+        message = f"{role} {m.content}"
+        if isinstance(m, AIMessage) and "function_call" in m.additional_kwargs:
+            message += f"{m.additional_kwargs['function_call']}"
+        string_messages.append(message)
+
+    return " ".join(string_messages)
+
+
+class CustomBufferInstructionMemory(BaseChatMemory):
+    memory_key: str = "chat_history"
+    input_key: str = "input"
+    ai_prefix: str = "[/INST]"
+    human_prefix: str = "[INST]"
+    name: str = "Kit"
+
+    @property
+    def buffer(self) -> Union[str, List[BaseMessage]]:
+        """String buffer of memory."""
+        return self.buffer_as_messages if self.return_messages else self.buffer_as_str
+
+    @property
+    def buffer_as_str(self) -> str:
+        """Exposes the buffer as a string in case return_messages is True."""
+        messages = self.chat_memory.messages
+        return get_buffer_string_inst(
+            messages,
+            human_prefix=self.human_prefix,
+            ai_prefix=self.ai_prefix,
+        )
+
+    @property
+    def buffer_as_messages(self) -> List[BaseMessage]:
+        """Exposes the buffer as a list of messages in case return_messages is False."""
+        return self.chat_memory.messages
+
+    @property
+    def memory_variables(self) -> List[str]:
+        """Define the variables we are providing to the prompt."""
+        return [self.memory_key]
+
+    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
+        """Load the memory variables, in this case the entity key."""
+        # Return combined information about entities to put into context.
+        return {self.memory_key: self.buffer}
+
+    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
+        """Save context from this conversation to buffer."""
+        input_str, output_str = self._get_input_output(inputs, outputs)
+        self.chat_memory.add_user_message(input_str)
+        self.chat_memory.add_ai_message(output_str)
+
+
+class CustomBufferWindowInstructionMemory(CustomBufferInstructionMemory):
+    k: int = 5
+
+    @property
+    def buffer_as_str(self) -> str:
+        """Exposes the buffer as a string in case return_messages is True."""
+        messages = self.chat_memory.messages[-self.k * 2 :] if self.k > 0 else []
+        return get_buffer_string_inst(
+            messages,
+            human_prefix=self.human_prefix,
+            ai_prefix=self.ai_prefix,
+        )
+
+    @property
+    def buffer_as_messages(self) -> List[BaseMessage]:
+        """Exposes the buffer as a list of messages in case return_messages is False."""
+        return self.chat_memory.messages[-self.k * 2 :] if self.k > 0 else []
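The point of get_buffer_string_inst is to serialize chat history in the Mistral instruction format instead of LangChain's default "Human:"/"AI:" transcript (the docstring example above still shows the default format). A quick usage sketch; the printed result is inferred from the code, not taken from the repo:

from langchain.schema.messages import AIMessage, HumanMessage
from models.databricks.instruction_memory import get_buffer_string_inst

history = [
    HumanMessage(content="Hi, how are you?"),
    AIMessage(content="Good, how are you?"),
]
print(get_buffer_string_inst(history))
# inferred: "[INST] Hi, how are you? [/INST] Good, how are you?"

CustomBufferWindowInstructionMemory keeps only the last k exchanges (k * 2 messages), bounding the prompt length the same way LangChain's ConversationBufferWindowMemory does.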
models/databricks/scenario_sim.py
ADDED
@@ -0,0 +1,91 @@
+import os
+import re
+import logging
+from models.custom_parsers import CustomINSTOutputParser
+from utils.app_utils import get_random_name
+from app_config import ENDPOINT_NAMES
+from langchain.chains import ConversationChain
+from langchain_community.llms import Databricks
+from langchain.prompts import PromptTemplate
+
+from typing import Any, List, Mapping, Optional, Dict
+
+ISSUE_MAPPING = {
+    "anxiety": "issue_Anxiety",
+    "suicide": "issue_Suicide",
+    "safety_planning": "issue_Suicide",
+    "GCT": "issue_Gral",
+}
+
+_EN_INST_TEMPLATE_ = """<s> [INST] The following is a conversation between you and a crisis counselor.
+{current_issue}
+You are able to reply with what the character should say. You are able to reply with your character's dialogue inside and nothing else. Do not write explanations.
+Do not disclose your name unless asked.
+
+{history} </s> [INST] {input} [/INST]"""
+
+CURRENT_ISSUE_MAPPING = {
+    "issue_Suicide-en": "Your character, {texter_name}, has suicidal thoughts. Your character has a plan to end his life and has all the means and requirements to do so. {seed}",
+    "issue_Anxiety-en": "Your character, {texter_name}, is experiencing anxiety. Your character has suicide thoughts but no plan. {seed}",
+    "issue_Suicide-es": "Tu personaje, {texter_name}, tiene pensamientos suicidas. Tu personaje tiene un plan para terminar con su vida y tiene todos los medios y requerimientos para hacerlo. {seed}",
+    "issue_Anxiety-es": "Tu personaje, {texter_name}, experimenta ansiedad. Tu personaje tiene pensamientos suicidas pero ningun plan. {seed}",
+    "issue_Gral-en": "Your character {texter_name} is experiencing a mental health crisis. {seed}",
+    "issue_Gral-es": "Tu personaje {texter_name} esta experimentando una crisis de salud mental. {seed}",
+}
+
+def get_template_databricks_models(issue: str, language: str, texter_name: str = "", seed="") -> str:
+    """_summary_
+
+    Args:
+        issue (str): Issue for template, current options are ['issue_Suicide','issue_Anxiety']
+        language (str): Language for the template, current options are ['en','es']
+        texter_name (str): texter to apply to template, defaults to None
+
+    Returns:
+        str: template
+    """
+    current_issue = CURRENT_ISSUE_MAPPING.get(
+        f"{issue}-{language}", CURRENT_ISSUE_MAPPING[f"issue_Gral-{language}"]
+    )
+    default_name = get_random_name()
+    texter_name = default_name if not texter_name else texter_name
+    current_issue = current_issue.format(
+        texter_name=texter_name,
+        seed=seed
+    )
+
+    if language == "en":
+        template = _EN_INST_TEMPLATE_.format(current_issue=current_issue, history="{history}", input="{input}")
+    else:
+        raise Exception(f"Language not supported for Databricks: {language}")
+
+    return template, texter_name
+
+def get_databricks_chain(source, template, memory, temperature=0.8, texter_name="Kit"):
+
+    endpoint_name = ENDPOINT_NAMES.get(source, "conversation_simulator")
+
+    PROMPT = PromptTemplate(
+        input_variables=['history', 'input'],
+        template=template
+    )
+
+    def transform_output(response):
+        return response['candidates'][0]['text']
+
+    llm = Databricks(endpoint_name=endpoint_name,
+                     transform_output_fn=transform_output,
+                     temperature=temperature,
+                     max_tokens=256,
+    )
+
+    llm_chain = ConversationChain(
+        llm=llm,
+        prompt=PROMPT,
+        memory=memory,
+        output_parser=CustomINSTOutputParser(name=texter_name, name_rx=re.compile(r""+ texter_name + r":|" + texter_name.lower() + r":")),
+        verbose=True,
+    )
+
+    logging.debug(f"loaded Databricks model")
+    return llm_chain, ["[INST]"]
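How the pieces fit: get_template_databricks_models resolves the issue/language pair to a character blurb (falling back to the general "issue_Gral" blurb when there is no direct key), injects the texter name and scenario seed, and leaves {history} and {input} as literal placeholders for PromptTemplate. A hedged sketch, with behavior inferred from the code and no Databricks call involved:

from models.databricks.scenario_sim import get_template_databricks_models

template, texter_name = get_template_databricks_models(
    "issue_Anxiety", "en", texter_name="Jordan", seed="You are 19 years old."
)
# template starts with the _EN_INST_TEMPLATE_ text, with the anxiety blurb
# and "Jordan" substituted in, and still contains the literal {history} and
# {input} placeholders; a random name would have been drawn had texter_name
# been empty. Only language "en" is accepted; anything else raises.

Note that get_databricks_chain returns the stop token list ["[INST]"] alongside the chain, so generation halts before the model starts writing the counselor's next turn.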
models/databricks/scenario_sim_biz.py
CHANGED
@@ -3,6 +3,7 @@ import json
 import requests
 import logging
 from models.custom_parsers import CustomStringOutputParser
+from app_config import ENDPOINT_NAMES
 from langchain.chains import ConversationChain
 from langchain_core.callbacks.manager import CallbackManagerForLLMRun
 from langchain_core.language_models.llms import LLM
@@ -10,16 +11,17 @@ from langchain.prompts import PromptTemplate
 
 from typing import Any, List, Mapping, Optional, Dict
 
-class DatabricksCustomLLM(LLM):
+class DatabricksCustomBizLLM(LLM):
     issue:str
     language:str
     temperature:float = 0.8
-    db_url:str
+    max_tokens: int = 128
+    db_url:str
     headers:Mapping[str,str] = {'Authorization': f'Bearer {os.environ.get("DATABRICKS_TOKEN")}', 'Content-Type': 'application/json'}
 
     @property
     def _llm_type(self) -> str:
-        return "
+        return "custom_databricks_biz"
 
     def _call(
         self,
@@ -32,7 +34,8 @@ class DatabricksCustomLLM(LLM):
             'prompt': [prompt],
             'issue': [self.issue],
             'language': [self.language],
-            'temperature': [self.temperature]
+            'temperature': [self.temperature],
+            'max_tokens': [self.max_tokens],
         }}
         data_json = json.dumps(data_, allow_nan=True)
         response = requests.request(method='POST', headers=self.headers, url=self.db_url, data=data_json)
@@ -45,16 +48,18 @@ _DATABRICKS_TEMPLATE_ = """{history}
 helper: {input}
 texter:"""
 
-def get_databricks_chain(issue, language, memory, temperature=0.8):
+def get_databricks_biz_chain(source, issue, language, memory, temperature=0.8):
 
     PROMPT = PromptTemplate(
         input_variables=['history', 'input'],
         template=_DATABRICKS_TEMPLATE_
     )
-    llm = DatabricksCustomLLM(
+    llm = DatabricksCustomBizLLM(
         issue=issue,
         language=language,
-        temperature=temperature
+        temperature=temperature,
+        max_tokens=256,
+        db_url = os.environ['DATABRICKS_URL'].format(endpoint_name=ENDPOINT_NAMES.get(source, "conversation_simulator"))
     )
     llm_chain = ConversationChain(
         llm=llm,
@@ -62,5 +67,5 @@ def get_databricks_chain(issue, language, memory, temperature=0.8):
         memory=memory,
         output_parser=CustomStringOutputParser()
     )
-    logging.debug(f"loaded Databricks
+    logging.debug(f"loaded Databricks Biz model")
     return llm_chain, "helper:"
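One wiring detail worth calling out: DATABRICKS_URL is read as a format string with an {endpoint_name} placeholder, so a single environment variable serves every endpoint in ENDPOINT_NAMES. A sketch with a hypothetical workspace URL:

import os

# Hypothetical value; the real workspace URL is configured in the Space.
os.environ["DATABRICKS_URL"] = "https://example.cloud.databricks.com/serving-endpoints/{endpoint_name}/invocations"

db_url = os.environ["DATABRICKS_URL"].format(endpoint_name="convo_sim_mistral")
print(db_url)
# https://example.cloud.databricks.com/serving-endpoints/convo_sim_mistral/invocations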
models/model_seeds.py
CHANGED
@@ -1,93 +1,77 @@
 seeds = {
-    "GCT": {
-        "prompt": "",
-        "memory": "texter: Help"
-    },
 # "GCT__relationship": {
 #     "prompt": "Your character is having a hard time becuase a failed relationship.",
-#     "memory": "texter: Hi, I don't know what to do"
-#
-    "GCT__body_image": {
-
-
-
+#     "memory": "texter: Hi, I don't know what to do",
+# },
+# "GCT__body_image": {
+#     "prompt": "Your character has a low steem and struggles with body image.",
+#     "memory": "texter: I feel so dumb\ntexter: nobody loves me",
+# },
 # "GCT__sexuality": {
 #     "prompt": "Your character has a sexuality identity crisis.",
-#     "memory": "texter: Hi\ntexter:I'm not sure who I am anymore"
-#
+#     "memory": "texter: Hi\ntexter:I'm not sure who I am anymore",
+# },
 # "GCT__anxiety": {
 #     "prompt": "Your character is experiencing an anxiety crisis.",
-#     "memory": "texter: help!\ntexter: I'm feeling overwhelmed"
-#
+#     "memory": "texter: help!\ntexter: I'm feeling overwhelmed",
+# },
+    "GCT": {"prompt": "You are talkative, and you are open to talk with details about your problems.", "memory": "texter: Help\ntexter: I'm so overwhelmed"},
+    "GCT__seed2": {"prompt": "Your character is experiencing an anxiety crisis. You express openly and detailed about what you're going through.", "memory":"texter: Help\ntexter: I'm so overwhelmed"},
     "safety_planning": {
-        "prompt": "
-        "memory": """texter:
-helper: Hi, my name is {counselor_name} and I'm here to support you. It sounds like you are having a
+        "prompt": "You are talkative, and you are open to talk with details about your problems. When you are feeling down you like to listen to classical music. Your favorite composer is Beethoven.",
+        "memory": """texter: Hi, this is pointless
+helper: Hi, my name is {counselor_name} and I'm here to support you. It sounds like you are having a rough time. Do you want to share what is going on?
+texter: sure
 texter: nothing makes sense in my life, I see no future.
-helper: It takes courage to reach out. I'm here with you. Sounds like you are feeling defeated by how things are going in your life
-texter: I guess
+helper: It takes courage to reach out when you are im. I'm here with you. Sounds like you are feeling defeated by how things are going in your life
+texter: Yeah, I guess I'm better off dead
 helper: It's really brave of you to talk about this openly. No one deserves to feel like that. I'm wondering how long have you been feeling this way?
-texter: About
-helper:
-texter:
-helper:
+texter: About 1 week or so
+helper: You are so strong for dealing with this so long. I really appreciate your openess. Did something happened specifically today?
+texter: Well, finding a job is impossible, money is tight, nothing goes my way
+helper: I hear you are frustrated, and you are currently unemployed correct?
+texter: Yeah
+helper: Dealing with unemployment is hard and is normal to feel dissapointed
+texter: thanks I probably needed to hear that
+helper: If you are comfortable, is there a name I can call you by while we talk
 texter: call me {texter_name}
-helper: Nice to meet you {texter_name}.
-texter:
-helper:
-texter:
-helper:
-texter:
-
-
-
-texter:
-
-
-
-
-
-
-texter: I
-
-
-texter:
-helper:
-texter:
-texter: I'
-helper:
-texter:
-helper:
-texter:
-
-helper: I appreciate you telling me this. I know is not easy, especially over text. I hear you are under a lot of pressure from your mom
-texter: Yeah exactly!
-helper: Your self-awareness is inspiring. You mentioned earlier you do not want to live anymore. Your safety is my priority (1/2)
-helper: Do you have thoughts of suicide? (2/2)
-texter: Yeah constantly, like always always
-helper: Thanks for sharing that with me. You are very resilient. Do you have a plan to end your life?
-texter: I used to cut myself a few months ago
-texter: I still have the razor, sometimes the urge is so hard!
-helper: I really appreciate your strength in talking about this. I want to help you stay safe today. Just to be clear, are you cutting yourself now?
-texter: No, not now, but I want to soo bad.
-helper: Thanks for your honesty. Do you have access to the razor right now?
-texter: Yeah is in my drawer
-helper: You've been so strong so far {texter_name}. When do you plan to end your life
-texter: Today"""
-    },
+helper: Nice to meet you {texter_name}. You mentioned having thoughts of suicide, are you having those thoughts now?
+texter: Yes
+helper: I know this is thought to share. I'm wondering is there any plan to end your life?
+texter: I guess I'll just take lots of pills, that is a calm way to go out
+helper: I really appreciate your strength in talking about this. I want to help you stay safe today. Do you have the pills right now?
+texter: not really, I'll have to buy them or something""",
+    },
+# "safety_planning__selfharm": {
+#     "prompt": "",
+#     "memory": """texter: I need help
+# texter: I cut myself, I don't want to live anymore
+# helper: Hi, my name is {counselor_name}. It seems you are going through a lot. Are you self-harming right now?
+# texter: Not anymore
+# helper: Your safety is my priority number one. Thanks for being honest with me. Would you like to share a name I can call you?
+# texter: {texter_name}
+# helper: Nice to meet you {texter_name}. I'm glad you reach out this shows stregth in you. Would you like to share more on what is going on in your life?
+# texter: I just can't do it anymore
+# texter: Finding a job is impossible, money is tight, nothing goes my way
+# helper: I hear you are frustrated, and you are currently unemployed correct?
+# texter: Yeah
+# helper: Dealing with unemployment is hard and is normal to feel dissapointed. How long have you been feeling this way?
+# texter: a while now
+# texter: I've been unemployed 6 months
+# helper: You are so resilient for dealing with this so much time. You mentioned cutting yourself earlier. I want to check in your safety. Do you have suicide thoughts
+# texter: Definitely
+# helper: Do you have a plan?
+# texter: I'll just keep cutting myself""",
+# },
 # "safety_planning__overdose": {
-#     "prompt": "
+#     "prompt": "",
 #     "memory": """texter: I want to kms
 # helper: Hi there I'm {counselor_name}. I'm here to listen. It sounds like you're dealing with a lot right now. Can you tell me a little more what is going on?
 # texter: I feel like nobody loves me, not even me. I don't want to live anymore
 # helper: I can tell you are really going through a lot right now. Would you mind sharing a name with me?
 # texter: yeah, I'm {texter_name}
 # helper: Nice to meet you {texter_name}. Did something happened recently that intensified these feelings?
-# texter: I
-# texter: They took my bag and hide all my stuff, they told my crush I was in love with him
-# texter: I can't deal with all of that
-# helper: It sounds like you went through a lot. Bullying and pranks can be hurtful. I'm here for you
-# texter: Thank you it feels good to have someone in your sidw
+# texter: I dont know I'm just so done with life
 # helper: I can hear how much pain you are in {texter_name}. You are smart for reaching out. You mentioned don't wanting to live anymore, I want to check in your safety, does this means you have thoughts of suicide?
 # texter: Yeah, what else would it be
 # helper: Thanks for sharing that with me. It is not easy to accept those feelings specially with a stranger over text. Do you have a plan to end your life?
@@ -95,15 +79,16 @@ texter: Today"""
 # helper: Sounds like you've been contemplating this for a while. Would you mind sharing this plan with me?
 # texter: I thought about taking a bunch of benadryll and be done with it
 # helper: You've been so forthcoming with all this and I admire your stregth for holding on this long. Do you have those pills right now?
-# texter: They are at
+# texter: They are at the cabinet right now
 # helper: You been so strong so far {texter_name}. I'm here for you tonight. Your safety is really important to me. Do you have a date you are going to end your life?
-# texter: I was thinking tonight"""
-#
+# texter: I was thinking tonight""",
+# },
 }
 
 seed2str = {
     "GCT":"Good Contact Techniques",
-    "
+    "GCT__seed2": "Good Contact Techniques 2",
+    # "GCT__body_image": "GCT Body Image",
     "safety_planning": "Safety Planning",
-    "safety_planning__selfharm": "SP Self Harm"
+    # "safety_planning__selfharm": "SP Self Harm"
 }
requirements.txt
CHANGED
@@ -1,4 +1,5 @@
 scipy==1.11.1
 openai==1.7.0
 langchain==0.1.0
-pymongo==4.5.0
+pymongo==4.5.0
+mlflow==2.9.0
utils/chain_utils.py
CHANGED
@@ -1,7 +1,8 @@
 from models.model_seeds import seeds
 from models.openai.finetuned_models import finetuned_models, get_finetuned_chain
 from models.openai.role_models import get_role_chain, get_template_role_models
-from models.databricks.scenario_sim_biz import
+from models.databricks.scenario_sim_biz import get_databricks_biz_chain
+from models.databricks.scenario_sim import get_databricks_chain, get_template_databricks_models
 
 def get_chain(issue, language, source, memory, temperature, texter_name=""):
     if source in ("OA_finetuned"):
@@ -16,4 +17,25 @@ def get_chain(issue, language, source, memory, temperature, texter_name=""):
             language = "en"
         elif language == "Spanish":
             language = "es"
-        return
+        return get_databricks_biz_chain(source, issue, language, memory, temperature)
+    elif source in ('CTL_mistral'):
+        if language == "English":
+            language = "en"
+        elif language == "Spanish":
+            language = "es"
+        seed = seeds.get(issue, "GCT")['prompt']
+        template, texter_name = get_template_databricks_models(issue, language, texter_name=texter_name, seed=seed)
+        return get_databricks_chain(source, template, memory, temperature, texter_name)
+
+from typing import cast
+
+def custom_chain_predict(llm_chain, input, stop):
+
+    inputs = llm_chain.prep_inputs({"input":input, "stop":stop})
+    llm_chain._validate_inputs(inputs)
+    outputs = llm_chain._call(inputs)
+    llm_chain._validate_outputs(outputs)
+    llm_chain.memory.chat_memory.add_user_message(inputs['input'])
+    for out in outputs[llm_chain.output_key]:
+        llm_chain.memory.chat_memory.add_ai_message(out)
+    return outputs[llm_chain.output_key]
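custom_chain_predict re-implements the predict step by hand because the new parsers return a list of messages: ConversationChain's normal save_context path expects a single output string, so the function validates inputs and outputs the way Chain.__call__ does and then appends one AI message per parsed utterance. A sketch of the memory effect, using a bare ConversationBufferMemory in place of the full chain:

from langchain.memory import ConversationBufferMemory

memory = ConversationBufferMemory(ai_prefix='texter', human_prefix='helper')
outputs = ["Hi", "I'm not doing great"]  # stand-in for parsed texter utterances

# What the tail of custom_chain_predict does: one user message, then one
# AI message per utterance.
memory.chat_memory.add_user_message("Hi, how are you doing today?")
for out in outputs:
    memory.chat_memory.add_ai_message(out)

print(len(memory.chat_memory.messages))  # 3: one helper turn, two texter turns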
utils/memory_utils.py
CHANGED
@@ -2,6 +2,7 @@ import streamlit as st
 from streamlit.logger import get_logger
 
 from langchain.memory import ConversationBufferMemory
+from models.databricks.instruction_memory import CustomBufferInstructionMemory
 from utils.mongo_utils import new_convo
 
 logger = get_logger(__name__)
@@ -24,6 +25,8 @@ def change_memories(memories, language, changed_source=False):
         logger.info(f"Source for memory {memory} is {source}")
         if source in ('OA_rolemodel','OA_finetuned',"CTL_llama2"):
             st.session_state[memory] = ConversationBufferMemory(ai_prefix='texter', human_prefix='helper')
+        elif source in ('CTL_mistral'):
+            st.session_state[memory] = CustomBufferInstructionMemory(human_prefix="</s> [INST]", memory_key="history")
 
         if ("convo_id" in st.session_state) and changed_source:
             del st.session_state['convo_id']
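The human_prefix="</s> [INST]" choice meshes with _EN_INST_TEMPLATE_ in scenario_sim.py: the template itself opens the first [INST] block, so each stored human turn has to close the previous assistant span with </s> before opening the next instruction. A sketch of the serialized history, with output inferred from the code:

from models.databricks.instruction_memory import CustomBufferInstructionMemory

memory = CustomBufferInstructionMemory(human_prefix="</s> [INST]", memory_key="history")
memory.save_context({"input": "Hi, how are you?"}, {"output": "not great"})
print(memory.load_memory_variables({})["history"])
# inferred: "</s> [INST] Hi, how are you? [/INST] not great"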