SungBeom's picture
Upload folder using huggingface_hub
4a51346
raw
history blame
22.3 kB
# type: ignore
from chromadb.api.types import (
Documents,
Embeddings,
IDs,
Metadatas,
Where,
WhereDocument,
)
from chromadb.db import DB
from chromadb.db.index.hnswlib import Hnswlib, delete_all_indexes
import uuid
import json
from typing import Optional, Sequence, List, Tuple, cast
import clickhouse_connect
from clickhouse_connect.driver.client import Client
from clickhouse_connect import common
import logging
from uuid import UUID
from chromadb.config import System
from overrides import override
import numpy.typing as npt
from chromadb.api.types import Metadata
logger = logging.getLogger(__name__)
# Column layout for the `collections` table. A list of single-entry dicts is
# used (rather than one dict) to make column order explicit.
COLLECTION_TABLE_SCHEMA = [{"uuid": "UUID"}, {"name": "String"}, {"metadata": "String"}]
# Column layout for the `embeddings` table. metadata/document/id are stored as
# nullable strings; metadata is a JSON-serialized dict.
EMBEDDING_TABLE_SCHEMA = [
    {"collection_uuid": "UUID"},
    {"uuid": "UUID"},
    {"embedding": "Array(Float64)"},
    {"document": "Nullable(String)"},
    {"id": "Nullable(String)"},
    {"metadata": "Nullable(String)"},
]
def db_array_schema_to_clickhouse_schema(table_schema):
    """Render a schema (list of {column: type} dicts) as a ClickHouse column
    spec fragment, e.g. "uuid UUID, name String, ".

    Note: the trailing ", " of the original formatting is preserved.
    """
    fragments = [
        f"{column} {column_type}, "
        for entry in table_schema
        for column, column_type in entry.items()
    ]
    return "".join(fragments)
def db_schema_to_keys() -> List[str]:
    """Return the embeddings-table column names, in schema order."""
    return [next(iter(entry)) for entry in EMBEDDING_TABLE_SCHEMA]
class Clickhouse(DB):
    """DB implementation backed by a ClickHouse server, with Hnswlib indexes
    for nearest-neighbor search."""
    #
    # INIT METHODS
    #
    def __init__(self, system: System):
        """Validate required settings. The connection itself is created
        lazily by _get_conn()."""
        super().__init__(system)
        self._conn = None
        self._settings = system.settings
        # Fail fast if the ClickHouse endpoint is not configured.
        self._settings.require("clickhouse_host")
        self._settings.require("clickhouse_port")
def _init_conn(self):
    """Create the ClickHouse client and ensure both tables exist."""
    # NOTE(review): session-id autogeneration is disabled — presumably to
    # allow concurrent use across threads; confirm against the
    # clickhouse-connect docs before changing.
    common.set_setting("autogenerate_session_id", False)
    self._conn = clickhouse_connect.get_client(
        host=self._settings.clickhouse_host,
        port=int(self._settings.clickhouse_port),
    )
    self._create_table_collections(self._conn)
    self._create_table_embeddings(self._conn)
def _get_conn(self) -> Client:
    """Return the ClickHouse client, connecting lazily on first use."""
    if self._conn is not None:
        return self._conn
    self._init_conn()
    return self._conn
def _create_table_collections(self, conn):
    """Create the `collections` table if it does not exist (idempotent DDL)."""
    # NOTE(review): the generated column list ends with a trailing comma
    # (see db_array_schema_to_clickhouse_schema) — this relies on
    # ClickHouse accepting it; TODO confirm.
    conn.command(
        f"""CREATE TABLE IF NOT EXISTS collections (
        {db_array_schema_to_clickhouse_schema(COLLECTION_TABLE_SCHEMA)}
    ) ENGINE = MergeTree() ORDER BY uuid"""
    )
def _create_table_embeddings(self, conn):
    """Create the `embeddings` table if it does not exist (idempotent DDL)."""
    # NOTE(review): same trailing-comma caveat as _create_table_collections.
    conn.command(
        f"""CREATE TABLE IF NOT EXISTS embeddings (
        {db_array_schema_to_clickhouse_schema(EMBEDDING_TABLE_SCHEMA)}
    ) ENGINE = MergeTree() ORDER BY collection_uuid"""
    )
# Cache of live Hnswlib index objects, keyed by collection uuid.
# NOTE(review): declared at class level, so it is shared by ALL instances
# until reset_indexes() rebinds it as an instance attribute.
index_cache = {}
def _index(self, collection_id):
    """Retrieve an HNSW index instance for the given collection"""
    if collection_id not in self.index_cache:
        coll = self.get_collection_by_id(collection_id)
        # get_collection_by_id returns [uuid, name, metadata].
        collection_metadata = coll[2]
        index = Hnswlib(
            collection_id,
            self._settings,
            collection_metadata,
            self.count(collection_id),
        )
        self.index_cache[collection_id] = index
    return self.index_cache[collection_id]
def _delete_index(self, collection_id):
    """Delete the collection's HNSW index and evict it from the cache."""
    self._index(collection_id).delete()
    # pop() raises KeyError if absent, matching the original `del`.
    self.index_cache.pop(collection_id)
#
# UTILITY METHODS
#
@override
def persist(self):
    """Always raises: ClickHouse persists data itself, so an explicit
    persist call is a programming error."""
    message = "Clickhouse is a persistent database, this method is not needed"
    raise NotImplementedError(message)
@override
def get_collection_uuid_from_name(self, collection_name: str) -> UUID:
    """Look up a collection's uuid by its name.

    Raises IndexError if no collection with that name exists.
    NOTE(review): the name is interpolated directly into SQL — unsafe for
    untrusted input.
    """
    res = self._get_conn().query(
        f"""
        SELECT uuid FROM collections WHERE name = '{collection_name}'
        """
    )
    return res.result_rows[0][0]
def _create_where_clause(
    self,
    collection_uuid: str,
    ids: Optional[List[str]] = None,
    where: Where = {},
    where_document: WhereDocument = {},
):
    """Build the SQL WHERE clause combining metadata filters, document
    filters, an optional id restriction, and the collection scope.

    Returns a string starting with "WHERE "; all sub-clauses are ANDed.
    """
    where_clauses: List[str] = []
    self._format_where(where, where_clauses)
    if len(where_document) > 0:
        where_document_clauses = []
        self._format_where_document(where_document, where_document_clauses)
        where_clauses.extend(where_document_clauses)
    if ids is not None:
        # BUG FIX: the original used f" id IN {tuple(ids)}", which renders
        # a trailing comma for a single id — "id IN ('x',)" — invalid SQL.
        # Build the IN list explicitly instead. (An empty ids list still
        # yields "id IN ()", as before.)
        id_list = ", ".join(f"'{id}'" for id in ids)
        where_clauses.append(f" id IN ({id_list})")
    where_clauses.append(f"collection_uuid = '{collection_uuid}'")
    where_str = " AND ".join(where_clauses)
    where_str = f"WHERE {where_str}"
    return where_str
#
# COLLECTION METHODS
#
@override
def create_collection(
    self,
    name: str,
    metadata: Optional[Metadata] = None,
    get_or_create: bool = False,
) -> Sequence:
    """Create a collection and return [[uuid, name, metadata]].

    If a collection named *name* already exists: with get_or_create=True
    the existing row is returned (its metadata is first replaced when it
    differs from *metadata*); otherwise a ValueError is raised.
    NOTE(review): the existence check and the insert are not atomic, so a
    concurrent creator can still produce duplicates.
    """
    # poor man's unique constraint
    dupe_check = self.get_collection(name)
    if len(dupe_check) > 0:
        if get_or_create:
            if dupe_check[0][2] != metadata:
                self.update_collection(
                    dupe_check[0][0], new_name=name, new_metadata=metadata
                )
                # Re-read so the returned row reflects the new metadata.
                dupe_check = self.get_collection(name)
            logger.info(
                f"collection with name {name} already exists, returning existing collection"
            )
            return dupe_check
        else:
            raise ValueError(f"Collection with name {name} already exists")
    collection_uuid = uuid.uuid4()
    # Metadata is stored JSON-serialized in a String column.
    data_to_insert = [[collection_uuid, name, json.dumps(metadata)]]
    self._get_conn().insert(
        "collections", data_to_insert, column_names=["uuid", "name", "metadata"]
    )
    return [[collection_uuid, name, metadata]]
@override
def get_collection(self, name: str) -> Sequence:
    """Return collections matching *name* as [uuid, name, metadata-dict]
    rows (empty sequence if none exist)."""
    res = (
        self._get_conn()
        .query(
            f"""
            SELECT * FROM collections WHERE name = '{name}'
            """
        )
        .result_rows
    )
    # json.loads the metadata
    return [[x[0], x[1], json.loads(x[2])] for x in res]
def get_collection_by_id(self, collection_uuid: str):
    """Return the single collection row [uuid, name, metadata-dict] for
    *collection_uuid*.

    Raises IndexError if no such collection exists.
    """
    res = (
        self._get_conn()
        .query(
            f"""
            SELECT * FROM collections WHERE uuid = '{collection_uuid}'
            """
        )
        .result_rows
    )
    # json.loads the metadata
    return [[x[0], x[1], json.loads(x[2])] for x in res][0]
@override
def list_collections(self) -> Sequence:
    """Return every collection as a [uuid, name, metadata-dict] row."""
    rows = self._get_conn().query("SELECT * FROM collections").result_rows
    collections = []
    for row in rows:
        # Metadata is stored JSON-serialized; decode it on the way out.
        collections.append([row[0], row[1], json.loads(row[2])])
    return collections
@override
def update_collection(
    self,
    id: UUID,
    new_name: Optional[str] = None,
    new_metadata: Optional[Metadata] = None,
):
    """Rename and/or replace the metadata of collection *id*.

    Raises ValueError when *new_name* is already used by a different
    collection. Both updates use parameterized ALTER TABLE mutations.
    """
    if new_name is not None:
        # Reject renames that would collide with another collection.
        dupe_check = self.get_collection(new_name)
        if len(dupe_check) > 0 and dupe_check[0][0] != id:
            raise ValueError(f"Collection with name {new_name} already exists")
        self._get_conn().command(
            "ALTER TABLE collections UPDATE name = %(new_name)s WHERE uuid = %(uuid)s",
            parameters={"new_name": new_name, "uuid": id},
        )
    if new_metadata is not None:
        self._get_conn().command(
            "ALTER TABLE collections UPDATE metadata = %(new_metadata)s WHERE uuid = %(uuid)s",
            parameters={"new_metadata": json.dumps(new_metadata), "uuid": id},
        )
@override
def delete_collection(self, name: str):
    """Drop a collection: its embedding rows, its HNSW index, and finally
    the collections row itself.

    Raises IndexError (via get_collection_uuid_from_name) if *name* does
    not exist.
    """
    collection_uuid = self.get_collection_uuid_from_name(name)
    self._get_conn().command(
        f"""
        DELETE FROM embeddings WHERE collection_uuid = '{collection_uuid}'
        """
    )
    self._delete_index(collection_uuid)
    self._get_conn().command(
        f"""
        DELETE FROM collections WHERE name = '{name}'
        """
    )
#
# ITEM METHODS
#
@override
def add(self, collection_uuid, embeddings, metadatas, documents, ids) -> List[UUID]:
    """Insert one row per embedding and return the generated row uuids.

    metadatas/documents may be None (falsy), in which case NULL is stored
    for those columns.
    """
    rows = []
    for i, embedding in enumerate(embeddings):
        rows.append(
            [
                collection_uuid,
                uuid.uuid4(),
                embedding,
                json.dumps(metadatas[i]) if metadatas else None,
                documents[i] if documents else None,
                ids[i],
            ]
        )
    self._get_conn().insert(
        "embeddings",
        rows,
        column_names=[
            "collection_uuid",
            "uuid",
            "embedding",
            "metadata",
            "document",
            "id",
        ],
    )
    # The generated uuid sits at position 1 of each row.
    return [row[1] for row in rows]
def _update(
    self,
    collection_uuid,
    ids: IDs,
    embeddings: Optional[Embeddings],
    metadatas: Optional[Metadatas],
    documents: Optional[Documents],
):
    """Apply per-id column updates in a single ALTER TABLE statement.

    One UPDATE clause is generated per id; bind-parameter names are
    suffixed with the row index (e{i}/m{i}/d{i}/i{i}) to stay unique
    across clauses. Only the columns whose inputs are non-None are set.
    """
    updates = []
    parameters = {}
    for i in range(len(ids)):
        update_fields = []
        parameters[f"i{i}"] = ids[i]
        if embeddings is not None:
            update_fields.append(f"embedding = %(e{i})s")
            parameters[f"e{i}"] = embeddings[i]
        if metadatas is not None:
            update_fields.append(f"metadata = %(m{i})s")
            parameters[f"m{i}"] = json.dumps(metadatas[i])
        if documents is not None:
            update_fields.append(f"document = %(d{i})s")
            parameters[f"d{i}"] = documents[i]
        # Clauses are comma-separated; only the last clause omits the
        # trailing comma, so the concatenation below forms valid SQL.
        update_statement = f"""
        UPDATE
            {",".join(update_fields)}
        WHERE
            id = %(i{i})s AND
            collection_uuid = '{collection_uuid}'{"" if i == len(ids) - 1 else ","}
        """
        updates.append(update_statement)
    update_clauses = ("").join(updates)
    self._get_conn().command(
        f"ALTER TABLE embeddings {update_clauses}", parameters=parameters
    )
@override
def update(
    self,
    collection_uuid,
    ids: IDs,
    embeddings: Optional[Embeddings] = None,
    metadatas: Optional[Metadatas] = None,
    documents: Optional[Documents] = None,
) -> bool:
    """Update existing items by id, then refresh the HNSW index when
    embeddings changed. Returns True on success.

    Raises ValueError if any of *ids* is not already present in the
    collection.
    """
    # Verify all IDs exist
    existing_items = self.get(collection_uuid=collection_uuid, ids=ids)
    if len(existing_items) != len(ids):
        raise ValueError(
            f"Could not find {len(ids) - len(existing_items)} items for update"
        )
    # Update the db
    self._update(collection_uuid, ids, embeddings, metadatas, documents)
    # Update the index
    if embeddings is not None:
        # `get` currently returns items in arbitrary order, so map each
        # user id (column 4) to its row uuid (column 1) explicitly.
        # TODO if we fix `get`, we can remove this explicit mapping.
        uuid_mapping = {r[4]: r[1] for r in existing_items}
        update_uuids = [uuid_mapping[id] for id in ids]
        index = self._index(collection_uuid)
        index.add(update_uuids, embeddings, update=True)
    # BUG FIX: the signature declares `-> bool` but the function previously
    # fell off the end and implicitly returned None.
    return True
def _get(self, where={}, columns: Optional[List] = None):
    """Run a SELECT over the embeddings table.

    *where* is a pre-rendered clause string (may include ORDER BY/LIMIT/
    OFFSET); *columns* defaults to the full schema. Rows are returned as
    lists, with the metadata column JSON-decoded (None when NULL).
    """
    select_columns = db_schema_to_keys() if columns is None else columns
    rows = (
        self._get_conn()
        .query(f"""SELECT {",".join(select_columns)} FROM embeddings {where}""")
        .result_rows
    )
    metadata_idx = (
        select_columns.index("metadata") if "metadata" in select_columns else None
    )
    records = []
    for row in rows:
        record = list(row)
        if metadata_idx is not None:
            raw = record[metadata_idx]
            record[metadata_idx] = json.loads(raw) if raw else None
        records.append(record)
    return records
def _format_where(self, where, result):
    """Translate a Chroma `where` metadata filter into ClickHouse SQL
    clauses, appending one clause per filter key to *result*.

    Bare str/int/float values are shorthand for $eq; dict values carry a
    single comparison operator; list values under $and/$or combine
    sub-filters. Raises ValueError for unknown operators/combinators.
    """
    for key, value in where.items():
        def has_key_and(clause):
            # Only match rows whose metadata actually contains the key.
            return f"(JSONHas(metadata,'{key}') = 1 AND {clause})"
        # Shortcut for $eq
        if type(value) == str:
            result.append(
                has_key_and(f" JSONExtractString(metadata,'{key}') = '{value}'")
            )
        elif type(value) == int:
            result.append(
                has_key_and(f" JSONExtractInt(metadata,'{key}') = {value}")
            )
        elif type(value) == float:
            result.append(
                has_key_and(f" JSONExtractFloat(metadata,'{key}') = {value}")
            )
        # Operator expression
        elif type(value) == dict:
            operator, operand = list(value.items())[0]
            # BUG FIX: the original wrote `return result.append(...)` in
            # every operator branch, which exited the function after the
            # first operator-style condition and silently dropped any
            # remaining keys in `where`. Append without returning.
            if operator == "$gt":
                result.append(
                    has_key_and(f" JSONExtractFloat(metadata,'{key}') > {operand}")
                )
            elif operator == "$lt":
                result.append(
                    has_key_and(f" JSONExtractFloat(metadata,'{key}') < {operand}")
                )
            elif operator == "$gte":
                result.append(
                    has_key_and(f" JSONExtractFloat(metadata,'{key}') >= {operand}")
                )
            elif operator == "$lte":
                result.append(
                    has_key_and(f" JSONExtractFloat(metadata,'{key}') <= {operand}")
                )
            elif operator == "$ne":
                if type(operand) == str:
                    result.append(
                        has_key_and(
                            f" JSONExtractString(metadata,'{key}') != '{operand}'"
                        )
                    )
                else:
                    # Non-string operands are compared as floats.
                    result.append(
                        has_key_and(
                            f" JSONExtractFloat(metadata,'{key}') != {operand}"
                        )
                    )
            elif operator == "$eq":
                if type(operand) == str:
                    result.append(
                        has_key_and(
                            f" JSONExtractString(metadata,'{key}') = '{operand}'"
                        )
                    )
                else:
                    result.append(
                        has_key_and(
                            f" JSONExtractFloat(metadata,'{key}') = {operand}"
                        )
                    )
            else:
                raise ValueError(
                    f"Expected one of $gt, $lt, $gte, $lte, $ne, $eq, got {operator}"
                )
        elif type(value) == list:
            # Logical combinators: value is a list of sub-filters; each
            # sub-filter contributes exactly one clause.
            all_subresults = []
            for subwhere in value:
                subresults = []
                self._format_where(subwhere, subresults)
                all_subresults.append(subresults[0])
            if key == "$or":
                result.append(f"({' OR '.join(all_subresults)})")
            elif key == "$and":
                result.append(f"({' AND '.join(all_subresults)})")
            else:
                raise ValueError(f"Expected one of $or, $and, got {key}")
def _format_where_document(self, where_document, results):
    """Translate a `where_document` filter into a SQL clause, appended to
    *results*. Supports $contains plus $and/$or combinators; only the
    first key of the dict is considered. Raises ValueError otherwise."""
    operator = list(where_document.keys())[0]
    if operator == "$contains":
        results.append(f"position(document, '{where_document[operator]}') > 0")
    elif operator in ("$and", "$or"):
        sub_clauses = []
        for sub_filter in where_document[operator]:
            collected = []
            self._format_where_document(sub_filter, collected)
            sub_clauses.append(collected[0])
        joiner = " OR " if operator == "$or" else " AND "
        results.append(f"({joiner.join(sub_clauses)})")
    else:
        raise ValueError(f"Expected one of $contains, $and, $or, got {operator}")
@override
def get(
    self,
    where: Where = {},
    collection_name: Optional[str] = None,
    collection_uuid: Optional[UUID] = None,
    ids: Optional[IDs] = None,
    sort: Optional[str] = None,
    limit: Optional[int] = None,
    offset: Optional[int] = None,
    where_document: WhereDocument = {},
    columns: Optional[List[str]] = None,
) -> Sequence:
    """Fetch embedding rows for a collection, filtered by metadata
    (*where*), document content (*where_document*), and/or *ids*.

    Exactly one of collection_name / collection_uuid must be provided
    (name wins and is resolved to a uuid). Raises TypeError when both
    are None.
    """
    if collection_name is None and collection_uuid is None:
        raise TypeError(
            "Arguments collection_name and collection_uuid cannot both be None"
        )
    if collection_name is not None:
        collection_uuid = self.get_collection_uuid_from_name(collection_name)
    where_str = self._create_where_clause(
        # collection_uuid must be defined at this point, cast it for typechecker
        cast(str, collection_uuid),
        ids=ids,
        where=where,
        where_document=where_document,
    )
    if sort is not None:
        where_str += f" ORDER BY {sort}"
    else:
        where_str += " ORDER BY collection_uuid"  # stable ordering
    # BUG FIX: the original conditions were
    # `x is not None or isinstance(x, int)`; since isinstance(None, int)
    # is always False, the isinstance half was dead code — `is not None`
    # alone is equivalent and clearer.
    if limit is not None:
        where_str += f" LIMIT {limit}"
    if offset is not None:
        where_str += f" OFFSET {offset}"
    val = self._get(where=where_str, columns=columns)
    return val
@override
def count(self, collection_id: UUID) -> int:
    """Return the number of embedding rows stored for *collection_id*."""
    query = f"SELECT COUNT() FROM embeddings WHERE collection_uuid = '{collection_id}'"
    return self._get_conn().query(query).result_rows[0][0]
def _delete(self, where_str: Optional[str] = None) -> List:
    """Delete embeddings matching *where_str* and return their row uuids.

    The uuids are captured with a SELECT first because the DELETE itself
    reports nothing back.
    """
    deleted_uuids = (
        self._get_conn()
        .query(f"""SELECT uuid FROM embeddings {where_str}""")
        .result_rows
    )
    self._get_conn().command(
        f"""
        DELETE FROM
            embeddings
        {where_str}
        """
    )
    return [res[0] for res in deleted_uuids] if len(deleted_uuids) > 0 else []
@override
def delete(
    self,
    where: Where = {},
    collection_uuid: Optional[UUID] = None,
    ids: Optional[IDs] = None,
    where_document: WhereDocument = {},
) -> List[str]:
    """Delete matching embeddings (and their index entries); returns the
    row uuids that were removed.

    Raises TypeError when collection_uuid is None.
    """
    # BUG FIX: collection_uuid was blindly cast from Optional; passing
    # None produced a bogus "collection_uuid = 'None'" filter. Fail fast
    # instead, mirroring the guards in `get`/`get_nearest_neighbors`.
    if collection_uuid is None:
        raise TypeError("Argument collection_uuid cannot be None")
    where_str = self._create_where_clause(
        cast(str, collection_uuid),
        ids=ids,
        where=where,
        where_document=where_document,
    )
    deleted_uuids = self._delete(where_str)
    index = self._index(collection_uuid)
    index.delete_from_index(deleted_uuids)
    return deleted_uuids
@override
def get_by_ids(
    self, uuids: List[UUID], columns: Optional[List[str]] = None
) -> Sequence:
    """Fetch embedding rows by their internal row uuids.

    The `uuid` column is always appended as the last selected column so
    the results can be re-sorted into the order of *uuids*.
    """
    columns = columns + ["uuid"] if columns else ["uuid"]
    # NOTE(review): `columns` can no longer be None at this point, so the
    # fallback to the full schema below is dead code.
    select_columns = db_schema_to_keys() if columns is None else columns
    # NOTE(review): this interpolates a Python list repr (['…', '…'])
    # into the IN clause and uses `.hex` (dash-less) values against a
    # UUID column — confirm this actually matches rows in ClickHouse.
    response = (
        self._get_conn()
        .query(
            f"""
        SELECT {",".join(select_columns)} FROM embeddings WHERE uuid IN ({[id.hex for id in uuids]})
        """
        )
        .result_rows
    )
    # sort db results by the order of the uuids
    response = sorted(response, key=lambda obj: uuids.index(obj[len(columns) - 1]))
    return response
@override
def get_nearest_neighbors(
    self,
    collection_uuid: UUID,
    where: Where = {},
    embeddings: Optional[Embeddings] = None,
    n_results: int = 10,
    where_document: WhereDocument = {},
) -> Tuple[List[List[UUID]], npt.NDArray]:
    """Return the n_results nearest neighbors for each query embedding.

    When where/where_document filters are given, the candidate set is
    first restricted to the uuids of matching rows.
    NOTE(review): `embeddings` is typed Optional but is passed to len()
    and to the index below — None would raise; confirm callers always
    supply it.
    """
    # Either the collection name or the collection uuid must be provided
    if collection_uuid is None:
        raise TypeError("Argument collection_uuid cannot be None")
    if len(where) != 0 or len(where_document) != 0:
        results = self.get(
            collection_uuid=collection_uuid,
            where=where,
            where_document=where_document,
        )
        if len(results) > 0:
            # Row layout from _get: the row uuid is the second column.
            ids = [x[1] for x in results]
        else:
            # No results found, return empty lists
            # NOTE(review): the second element here is a list of lists,
            # not an NDArray as the return annotation claims.
            return [[] for _ in range(len(embeddings))], [
                [] for _ in range(len(embeddings))
            ]
    else:
        ids = None
    index = self._index(collection_uuid)
    uuids, distances = index.get_nearest_neighbors(embeddings, n_results, ids)
    return uuids, distances
@override
def create_index(self, collection_uuid: UUID):
    """(Re)build the HNSW index for a collection from every embedding
    currently stored for it.

    Args:
        collection_uuid: The collection to build an index for.

    Returns:
        None
    """
    get = self.get(collection_uuid=collection_uuid)
    # Row layout from _get: uuid is column 1, embedding is column 2.
    uuids = [x[1] for x in get]
    embeddings = [x[2] for x in get]
    index = self._index(collection_uuid)
    index.add(uuids, embeddings)
@override
def add_incremental(
    self, collection_uuid: UUID, ids: List[UUID], embeddings: Embeddings
) -> None:
    """Add newly inserted embeddings to the collection's existing index."""
    self._index(collection_uuid).add(ids, embeddings)
def reset_indexes(self):
    """Delete every HNSW index from disk and clear the in-memory cache."""
    delete_all_indexes(self._settings)
    # Rebinding (rather than clearing) gives this instance its own dict,
    # detaching it from the shared class-level cache.
    self.index_cache = {}
@override
def reset(self):
    """Wipe all data: drop and recreate both tables and delete all HNSW
    indexes."""
    conn = self._get_conn()
    conn.command("DROP TABLE collections")
    conn.command("DROP TABLE embeddings")
    self._create_table_collections(conn)
    self._create_table_embeddings(conn)
    self.reset_indexes()
@override
def raw_sql(self, raw_sql):
    """Execute an arbitrary SQL string and return the raw result rows.

    NOTE: executes the caller's SQL verbatim — for trusted input only.
    """
    result = self._get_conn().query(raw_sql)
    return result.result_rows