import os
import sys
from typing import Any, Dict, List, Sequence, Union

from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema.messages import (
    AIMessage,
    BaseMessage,
    ChatMessage,
    FunctionMessage,
    HumanMessage,
    SystemMessage,
)
# Hacky way of importing modules from a sibling directory.
# This is done instead of importing from src.evaluation.postprocessing_utils
# because this file is also used for inference, where that folder structure
# does not exist.
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "evaluation"))

def get_buffer_string_inst(
    messages: Sequence[BaseMessage],
    human_prefix: str = "[INST]",
    ai_prefix: str = "[/INST]",
) -> str:
    """Convert a sequence of messages to strings and concatenate them into one string.

    Args:
        messages: Messages to be converted to strings.
        human_prefix: The prefix to prepend to the contents of HumanMessages.
        ai_prefix: The prefix to prepend to the contents of AIMessages.

    Returns:
        A single string concatenation of all input messages.

    Example:
        .. code-block:: python

            from langchain.schema import AIMessage, HumanMessage

            messages = [
                HumanMessage(content="Hi, how are you?"),
                AIMessage(content="Good, how are you?"),
            ]
            get_buffer_string_inst(messages)
            # -> "[INST] Hi, how are you? [/INST] Good, how are you?"
    """
    string_messages = []
    for m in messages:
        if isinstance(m, HumanMessage):
            role = human_prefix
        elif isinstance(m, AIMessage):
            role = ai_prefix
        elif isinstance(m, SystemMessage):
            role = "System"
        elif isinstance(m, FunctionMessage):
            role = "Function"
        elif isinstance(m, ChatMessage):
            role = m.role
        else:
            raise ValueError(f"Got unsupported message type: {m}")
        message = f"{role} {m.content}"
        # Append the function call payload for AI messages that carry one.
        if isinstance(m, AIMessage) and "function_call" in m.additional_kwargs:
            message += f"{m.additional_kwargs['function_call']}"
        string_messages.append(message)
    return " ".join(string_messages)


class CustomBufferInstructionMemory(BaseChatMemory):
    """Chat memory that renders history with instruction-style [INST]/[/INST] prefixes."""

    memory_key: str = "chat_history"
    input_key: str = "input"
    ai_prefix: str = "[/INST]"
    human_prefix: str = "[INST]"
    name: str = "Kit"

    @property
    def buffer(self) -> Union[str, List[BaseMessage]]:
        """String buffer of memory."""
        return self.buffer_as_messages if self.return_messages else self.buffer_as_str

    @property
    def buffer_as_str(self) -> str:
        """Exposes the buffer as a string in case return_messages is True."""
        messages = self.chat_memory.messages
        return get_buffer_string_inst(
            messages,
            human_prefix=self.human_prefix,
            ai_prefix=self.ai_prefix,
        )

    @property
    def buffer_as_messages(self) -> List[BaseMessage]:
        """Exposes the buffer as a list of messages in case return_messages is False."""
        return self.chat_memory.messages

    @property
    def memory_variables(self) -> List[str]:
        """Define the variables we are providing to the prompt."""
        return [self.memory_key]

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Load the memory variables, in this case the chat history buffer."""
        # Return the rendered conversation history to put into context.
        return {self.memory_key: self.buffer}

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Save context from this conversation to buffer."""
        input_str, output_str = self._get_input_output(inputs, outputs)
        self.chat_memory.add_user_message(input_str)
        self.chat_memory.add_ai_message(output_str)
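
# Usage sketch (assumed example, not part of the original module):
#     memory = CustomBufferInstructionMemory()
#     memory.save_context({"input": "Hi"}, {"output": "Hello"})
#     memory.load_memory_variables({})
#     # -> {"chat_history": "[INST] Hi [/INST] Hello"}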


class CustomBufferWindowInstructionMemory(CustomBufferInstructionMemory):
    """Windowed variant that only keeps the last k human/AI exchanges."""

    k: int = 5

    @property
    def buffer_as_str(self) -> str:
        """Exposes the buffer as a string in case return_messages is True."""
        messages = self.chat_memory.messages[-self.k * 2 :] if self.k > 0 else []
        return get_buffer_string_inst(
            messages,
            human_prefix=self.human_prefix,
            ai_prefix=self.ai_prefix,
        )

    @property
    def buffer_as_messages(self) -> List[BaseMessage]:
        """Exposes the buffer as a list of messages in case return_messages is False."""
        return self.chat_memory.messages[-self.k * 2 :] if self.k > 0 else []
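

if __name__ == "__main__":
    # Quick illustration of the windowing behaviour. This block is an added
    # sketch, not part of the original module, and the example inputs are made up.
    window_memory = CustomBufferWindowInstructionMemory(k=1)
    window_memory.save_context({"input": "First question"}, {"output": "First answer"})
    window_memory.save_context({"input": "Second question"}, {"output": "Second answer"})
    # With k=1 only the most recent human/AI exchange is rendered into the buffer.
    print(window_memory.load_memory_variables({}))
    # -> {"chat_history": "[INST] Second question [/INST] Second answer"}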