|
from copy import deepcopy
from typing import Any, Dict, List

import faiss
import hydra
from langchain.docstore import InMemoryDocstore
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema import Document
from langchain.vectorstores import Chroma, FAISS
from langchain.vectorstores.base import VectorStoreRetriever

from aiflows.base_flows import AtomicFlow
from aiflows.messages import FlowMessage


class VectorStoreFlow(AtomicFlow):
    """ A flow that uses a vector store to write and read memories stored in a database
    (see VectorStoreFlow.yaml for the default configuration).

    *Configuration Parameters*:

    - `name` (str): The name of the flow. Default: "VectorStoreFlow"
    - `description` (str): A description of the flow. This description is used to generate the help message of the flow.
      Default: "VectorStoreFlow"
    - `backend` (Dict[str, Any]): The configuration of the backend, which is used to fetch API keys. Default: LiteLLMBackend with the
      default parameters of LiteLLMBackend (see flows.backends.LiteLLMBackend), except for the following parameters whose default values are overwritten:
        - `api_infos` (List[Dict[str, Any]]): The list of API infos. Default: no default value, this parameter is required.
        - `model_name` (str): The name of the model. Default: "". In the current implementation, this parameter is not used.
    - `type` (str): The type of the vector store. It can be "chroma" or "faiss". Default: "chroma"
    - `embedding_size` (int): The size of the embeddings (only used for faiss). Default: 1536
    - `retriever_config` (Dict[str, Any]): The configuration of the retriever. Default: empty dictionary
    - Other parameters are inherited from the default configuration of AtomicFlow (see AtomicFlow)
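
    A minimal illustrative override of these parameters (a sketch, not the shipped default; keys that
    are left out fall back to the defaults in VectorStoreFlow.yaml)::

        overrides = {
            "name": "memory_store",                           # used as the Chroma collection name
            "type": "chroma",                                 # or "faiss"
            "retriever_config": {"search_kwargs": {"k": 4}},  # forwarded to as_retriever()
        }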

    *Input Interface*:

    - `operation` (str): The operation to perform. It can be "write" or "read".
    - `content` (str or List[str]): The content to write or read. If the operation is "write", it must be a string or a list of strings. If the operation is "read", it must be a string.

    *Output Interface*:

    - `retrieved` (str or List[str]): The retrieved content. If the operation is "write", it is an empty string. If the operation is "read", it is a list of strings.
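
    For example (an illustrative sketch of the message payloads, following the interfaces above)::

        # write: store two memories; the flow replies with an empty "retrieved" field
        {"operation": "write", "content": ["Alice likes green tea", "Bob likes coffee"]}
        # -> {"retrieved": ""}

        # read: query the store; the flow replies with the page contents of the retrieved documents
        {"operation": "read", "content": "What does Alice like?"}
        # -> {"retrieved": ["Alice likes green tea", ...]}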

    :param backend: The backend of the flow (used to retrieve the API key)
    :type backend: LiteLLMBackend
    :param vector_db: The vector store retriever
    :type vector_db: VectorStoreRetriever
    :param type: The type of the vector store
    :type type: str
    :param \**kwargs: Additional arguments to pass to the flow. See :class:`aiflows.base_flows.AtomicFlow` for more details.
    """

    REQUIRED_KEYS_CONFIG = ["type"]

    vector_db: VectorStoreRetriever

    def __init__(self, backend, vector_db, **kwargs):
        super().__init__(**kwargs)
        self.vector_db = vector_db
        self.backend = backend

    @classmethod
    def _set_up_backend(cls, config):
        """ This method instantiates the backend of the flow from its configuration.

        :param config: The configuration of the flow (it contains the backend configuration under the "backend" key).
        :type config: Dict[str, Any]
        :return: A dictionary holding the instantiated backend under the "backend" key.
        :rtype: Dict[str, LiteLLMBackend]
        """
        kwargs = {}

        kwargs["backend"] = hydra.utils.instantiate(config["backend"], _convert_="partial")

        return kwargs

    @classmethod
    def _set_up_retriever(cls, api_information, config: Dict[str, Any]) -> Dict[str, Any]:
        """ This method sets up the vector store retriever.

        :param api_information: The API information used to authenticate with the embedding provider.
        :type api_information: ApiInfo
        :param config: The configuration of the vector store retriever.
        :type config: Dict[str, Any]
        :return: A dictionary holding the vector store retriever under the "vector_db" key.
        :rtype: Dict[str, VectorStoreRetriever]
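
        The values in `config["retriever_config"]` are forwarded verbatim to LangChain's
        `as_retriever`, e.g. (an illustrative setting, not a requirement)::

            config["retriever_config"] = {"search_type": "similarity", "search_kwargs": {"k": 4}}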
        """

        embeddings = OpenAIEmbeddings(openai_api_key=api_information.api_key)
        kwargs = {}

        vs_type = config["type"]

        if vs_type == "chroma":
            # Chroma keeps documents in a named collection; the flow's name is reused as the collection name.
            vectorstore = Chroma(config["name"], embedding_function=embeddings)
        elif vs_type == "faiss":
            # FAISS needs an explicit index; the embedding size must match the embedding model
            # (1536 is the dimensionality of OpenAI's text-embedding-ada-002).
            index = faiss.IndexFlatL2(config.get("embedding_size", 1536))
            vectorstore = FAISS(
                embedding_function=embeddings.embed_query,
                index=index,
                docstore=InMemoryDocstore({}),
                index_to_docstore_id={},
            )
        else:
            raise NotImplementedError(f"Vector store '{vs_type}' not implemented")

        kwargs["vector_db"] = vectorstore.as_retriever(**config.get("retriever_config", {}))

        return kwargs

    @classmethod
    def instantiate_from_config(cls, config: Dict[str, Any]):
        """ This method instantiates the flow from a configuration file

        :param config: The configuration of the flow.
        :type config: Dict[str, Any]
        :return: The instantiated flow.
        :rtype: VectorStoreFlow
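
        For example (illustrative; `config` is typically the loaded VectorStoreFlow.yaml, possibly with overrides)::

            flow = VectorStoreFlow.instantiate_from_config(config)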
        """
        flow_config = deepcopy(config)

        kwargs = {"flow_config": flow_config}

        kwargs.update(cls._set_up_backend(flow_config))
        api_information = kwargs["backend"].get_key()

        kwargs.update(cls._set_up_retriever(api_information, flow_config))

        return cls(**kwargs)

    @staticmethod
    def package_documents(documents: List[str]) -> List[Document]:
        """ This method packages a list of strings into a list of LangChain `Document` objects.

        :param documents: The documents to package.
        :type documents: List[str]
        :return: The packaged documents.
        :rtype: List[Document]
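
        For example (illustrative)::

            VectorStoreFlow.package_documents(["Alice likes green tea"])
            # -> [Document(page_content="Alice likes green tea", metadata={"": ""})]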
        """

        return [Document(page_content=doc, metadata={"": ""}) for doc in documents]

    def run(self, input_message: FlowMessage):
        """ This method runs the flow. Depending on the `operation` field of the input message,
        it either writes memories to or reads memories from the vector store.

        :param input_message: The input data of the flow.
        :type input_message: FlowMessage
        """
        response = {}
        input_data = input_message.data
        operation = input_data["operation"]
        assert operation in ["write", "read"], f"Operation '{operation}' not supported"

        content = input_data["content"]
        if operation == "read":
            assert isinstance(content, str), f"Content must be a string, got {type(content)}"
            query = content
            # Retrieve the documents most relevant to the query and return their page contents.
            retrieved_documents = self.vector_db.get_relevant_documents(query)
            response["retrieved"] = [doc.page_content for doc in retrieved_documents]
        elif operation == "write":
            if isinstance(content, str):
                content = [content]
            assert isinstance(content, list), f"Content must be a list of strings, got {type(content)}"
            # Package the strings as Documents and add them to the vector store.
            documents = self.package_documents(content)
            self.vector_db.add_documents(documents)
            response["retrieved"] = ""

        reply = self._package_output_message(
            input_message=input_message,
            response=response
        )
        self.reply_to_message(reply=reply, to=input_message)
|