Spaces:
Runtime error
Runtime error
from typing import Optional, Sequence, Any, Tuple, cast, Dict, Union, Set | |
from uuid import UUID | |
from overrides import override | |
from pypika import Table, Column | |
from itertools import groupby | |
from chromadb.config import System | |
from chromadb.db.base import ( | |
Cursor, | |
SqlDB, | |
ParameterValue, | |
get_sql, | |
NotFoundError, | |
UniqueConstraintError, | |
) | |
from chromadb.db.system import SysDB | |
from chromadb.types import ( | |
OptionalArgument, | |
Segment, | |
Metadata, | |
Collection, | |
SegmentScope, | |
Unspecified, | |
UpdateMetadata, | |
) | |
class SqlSysDB(SqlDB, SysDB):
    """SysDB implementation that keeps segments and collections in SQL tables.

    All statements are built with ``self.querybuilder()`` and executed on the
    cursor yielded by ``self.tx()``.
    """

    def __init__(self, system: System) -> None:
        """Forward ``system`` to the base-class initialisers."""
        super().__init__(system)
def create_segment(self, segment: Segment) -> None:
    """Insert one row into ``segments`` and, if present, its metadata.

    The row insert and the metadata insert run on the same cursor inside a
    single ``self.tx()`` block.

    Args:
        segment: Mapping with the keys ``id``, ``type``, ``scope``, ``topic``,
            ``collection`` and ``metadata``.

    Raises:
        UniqueConstraintError: If executing the insert raises the backend's
            ``self.unique_constraint_error()`` exception type.
    """
    segments_t = Table("segments")
    with self.tx() as cur:
        insert_q = (
            self.querybuilder()
            .into(segments_t)
            .columns(
                segments_t.id,
                segments_t.type,
                segments_t.scope,
                segments_t.topic,
                segments_t.collection,
            )
        )
        # Values are listed in the same order as the columns above.
        row_values = (
            self.uuid_to_db(segment["id"]),
            segment["type"],
            segment["scope"].value,
            segment["topic"],
            self.uuid_to_db(segment["collection"]),
        )
        insert_q = insert_q.insert(*(ParameterValue(v) for v in row_values))
        sql, params = get_sql(insert_q, self.parameter_format())
        try:
            cur.execute(sql, params)
        except self.unique_constraint_error() as e:
            raise UniqueConstraintError(
                f"Segment {segment['id']} already exists"
            ) from e

        metadata = segment["metadata"]
        if metadata:
            metadata_t = Table("segment_metadata")
            self._insert_metadata(
                cur,
                metadata_t,
                metadata_t.segment_id,
                segment["id"],
                metadata,
            )
def create_collection(self, collection: Collection) -> None:
    """Insert one row into ``collections`` and, if present, its metadata.

    The row insert and the metadata insert run on the same cursor inside a
    single ``self.tx()`` block.

    Args:
        collection: Mapping with the keys ``id``, ``topic``, ``name`` and
            ``metadata``.

    Raises:
        UniqueConstraintError: If executing the insert raises the backend's
            ``self.unique_constraint_error()`` exception type.
    """
    collections_t = Table("collections")
    with self.tx() as cur:
        insert_q = (
            self.querybuilder()
            .into(collections_t)
            .columns(collections_t.id, collections_t.topic, collections_t.name)
        )
        # Values are listed in the same order as the columns above.
        row_values = (
            self.uuid_to_db(collection["id"]),
            collection["topic"],
            collection["name"],
        )
        insert_q = insert_q.insert(*(ParameterValue(v) for v in row_values))
        sql, params = get_sql(insert_q, self.parameter_format())
        try:
            cur.execute(sql, params)
        except self.unique_constraint_error() as e:
            raise UniqueConstraintError(
                f"Collection {collection['id']} already exists"
            ) from e

        metadata = collection["metadata"]
        if metadata:
            metadata_t = Table("collection_metadata")
            self._insert_metadata(
                cur,
                metadata_t,
                metadata_t.collection_id,
                collection["id"],
                metadata,
            )
def get_segments(
    self,
    id: Optional[UUID] = None,
    type: Optional[str] = None,
    scope: Optional[SegmentScope] = None,
    topic: Optional[str] = None,
    collection: Optional[UUID] = None,
) -> Sequence[Segment]:
    """Return the segments matching every filter that was supplied.

    ``segments`` is left-joined to ``segment_metadata``, so a segment yields
    one result row per metadata key (or a single row when it has none). Rows
    are ordered by segment id, which is what lets ``groupby`` fold them back
    into one ``Segment`` each.

    Args:
        id: Restrict to this segment id.
        type: Restrict to this segment type.
        scope: Restrict to this scope (compared via ``scope.value``).
        topic: Restrict to this topic.
        collection: Restrict to this collection id.

        A filter is applied only when its argument is truthy.

    Returns:
        One ``Segment`` per distinct segment id in the result set.
    """
    seg_t = Table("segments")
    meta_t = Table("segment_metadata")

    # The last four selected columns must stay key/str/int/float: that is
    # the layout _metadata_from_rows reads from the end of each row.
    selected = (
        seg_t.id,
        seg_t.type,
        seg_t.scope,
        seg_t.topic,
        seg_t.collection,
        meta_t.key,
        meta_t.str_value,
        meta_t.int_value,
        meta_t.float_value,
    )
    query = self.querybuilder().from_(seg_t).select(*selected)
    query = query.left_join(meta_t).on(seg_t.id == meta_t.segment_id)
    query = query.orderby(seg_t.id)

    criteria = []
    if id:
        criteria.append(seg_t.id == ParameterValue(self.uuid_to_db(id)))
    if type:
        criteria.append(seg_t.type == ParameterValue(type))
    if scope:
        criteria.append(seg_t.scope == ParameterValue(scope.value))
    if topic:
        criteria.append(seg_t.topic == ParameterValue(topic))
    if collection:
        criteria.append(
            seg_t.collection == ParameterValue(self.uuid_to_db(collection))
        )
    for criterion in criteria:
        query = query.where(criterion)

    with self.tx() as cur:
        sql, params = get_sql(query, self.parameter_format())
        fetched = cur.execute(sql, params).fetchall()

        found = []
        for raw_id, group in groupby(fetched, lambda r: cast(object, r[0])):
            segment_uuid = self.uuid_from_db(str(raw_id))
            group_rows = list(group)
            # Segment-level columns repeat on every row of the group, so the
            # first row is enough to read them.
            head = group_rows[0]
            found.append(
                Segment(
                    id=cast(UUID, segment_uuid),
                    type=str(head[1]),
                    scope=SegmentScope(str(head[2])),
                    topic=str(head[3]) if head[3] else None,
                    collection=self.uuid_from_db(head[4]) if head[4] else None,
                    metadata=self._metadata_from_rows(group_rows),
                )
            )
        return found
def get_collections(
    self,
    id: Optional[UUID] = None,
    topic: Optional[str] = None,
    name: Optional[str] = None,
) -> Sequence[Collection]:
    """Return the collections matching every filter that was supplied.

    ``collections`` is left-joined to ``collection_metadata``, so a
    collection yields one result row per metadata key (or a single row when
    it has none). Rows are ordered by collection id, which is what lets
    ``groupby`` fold them back into one ``Collection`` each.

    Args:
        id: Restrict to this collection id.
        topic: Restrict to this topic.
        name: Restrict to this name.

        A filter is applied only when its argument is truthy.

    Returns:
        One ``Collection`` per distinct collection id in the result set.
    """
    coll_t = Table("collections")
    meta_t = Table("collection_metadata")

    # The last four selected columns must stay key/str/int/float: that is
    # the layout _metadata_from_rows reads from the end of each row.
    selected = (
        coll_t.id,
        coll_t.name,
        coll_t.topic,
        meta_t.key,
        meta_t.str_value,
        meta_t.int_value,
        meta_t.float_value,
    )
    query = self.querybuilder().from_(coll_t).select(*selected)
    query = query.left_join(meta_t).on(coll_t.id == meta_t.collection_id)
    query = query.orderby(coll_t.id)

    criteria = []
    if id:
        criteria.append(coll_t.id == ParameterValue(self.uuid_to_db(id)))
    if topic:
        criteria.append(coll_t.topic == ParameterValue(topic))
    if name:
        criteria.append(coll_t.name == ParameterValue(name))
    for criterion in criteria:
        query = query.where(criterion)

    with self.tx() as cur:
        sql, params = get_sql(query, self.parameter_format())
        fetched = cur.execute(sql, params).fetchall()

        found = []
        for raw_id, group in groupby(fetched, lambda r: cast(object, r[0])):
            collection_uuid = self.uuid_from_db(str(raw_id))
            group_rows = list(group)
            # Collection-level columns repeat on every row of the group, so
            # the first row is enough to read them.
            head = group_rows[0]
            found.append(
                Collection(
                    id=cast(UUID, collection_uuid),
                    topic=str(head[2]),
                    name=str(head[1]),
                    metadata=self._metadata_from_rows(group_rows),
                )
            )
        return found
def delete_segment(self, id: UUID) -> None:
    """Delete the ``segments`` row with the given id.

    Args:
        id: Id of the segment to delete.

    Raises:
        NotFoundError: If the DELETE returns no row, i.e. nothing matched.
    """
    segments_t = Table("segments")
    id_matches = segments_t.id == ParameterValue(self.uuid_to_db(id))
    delete_q = self.querybuilder().from_(segments_t).where(id_matches).delete()

    with self.tx() as cur:
        # Only the segments table is touched here; the metadata rows are
        # expected to go away through ON DELETE CASCADE (schema not shown in
        # this file).
        sql, params = get_sql(delete_q, self.parameter_format())
        # RETURNING makes the statement yield the deleted id, so an empty
        # result means no row matched.
        deleted = cur.execute(sql + " RETURNING id", params).fetchone()
        if not deleted:
            raise NotFoundError(f"Segment {id} not found")
def delete_collection(self, id: UUID) -> None:
    """Delete the ``collections`` row with the given id.

    Args:
        id: Id of the collection to delete.

    Raises:
        NotFoundError: If the DELETE returns no row, i.e. nothing matched.
    """
    collections_t = Table("collections")
    id_matches = collections_t.id == ParameterValue(self.uuid_to_db(id))
    delete_q = self.querybuilder().from_(collections_t).where(id_matches).delete()

    with self.tx() as cur:
        # Only the collections table is touched here; the metadata rows are
        # expected to go away through ON DELETE CASCADE (schema not shown in
        # this file).
        sql, params = get_sql(delete_q, self.parameter_format())
        # RETURNING makes the statement yield the deleted id, so an empty
        # result means no row matched.
        deleted = cur.execute(sql + " RETURNING id", params).fetchone()
        if not deleted:
            raise NotFoundError(f"Collection {id} not found")
def update_segment(
    self,
    id: UUID,
    topic: OptionalArgument[Optional[str]] = Unspecified(),
    collection: OptionalArgument[Optional[UUID]] = Unspecified(),
    metadata: OptionalArgument[Optional[UpdateMetadata]] = Unspecified(),
) -> None:
    """Update the topic, collection and/or metadata of one segment.

    An argument left at ``Unspecified()`` is not touched. The column update
    and the metadata change run on the same cursor inside one ``self.tx()``
    block.

    Args:
        id: Id of the segment to update.
        topic: New value for the ``topic`` column.
        collection: New value for the ``collection`` column.
        metadata: ``None`` deletes every metadata row of the segment. A
            mapping replaces the rows for exactly the keys it contains; what
            happens per value is decided by ``_insert_metadata``.
    """
    seg_t = Table("segments")
    meta_t = Table("segment_metadata")

    update_q = (
        self.querybuilder()
        .update(seg_t)
        .where(seg_t.id == ParameterValue(self.uuid_to_db(id)))
    )
    if not topic == Unspecified():
        update_q = update_q.set(seg_t.topic, ParameterValue(topic))
    if not collection == Unspecified():
        new_collection = cast(Optional[UUID], collection)
        update_q = update_q.set(
            seg_t.collection, ParameterValue(self.uuid_to_db(new_collection))
        )

    with self.tx() as cur:
        sql, params = get_sql(update_q, self.parameter_format())
        if sql:  # pypika emits a blank string if nothing to do
            cur.execute(sql, params)

        if metadata is None:
            # Explicit None: drop all metadata of this segment.
            clear_q = (
                self.querybuilder()
                .from_(meta_t)
                .where(meta_t.segment_id == ParameterValue(self.uuid_to_db(id)))
                .delete()
            )
            sql, params = get_sql(clear_q, self.parameter_format())
            cur.execute(sql, params)
            return

        if metadata != Unspecified():
            changes = cast(UpdateMetadata, metadata)
            # Passing the keys as clear_keys makes _insert_metadata delete
            # the existing rows for those keys before inserting.
            self._insert_metadata(
                cur,
                meta_t,
                meta_t.segment_id,
                id,
                changes,
                set(changes.keys()),
            )
def update_collection(
    self,
    id: UUID,
    topic: OptionalArgument[Optional[str]] = Unspecified(),
    name: OptionalArgument[str] = Unspecified(),
    metadata: OptionalArgument[Optional[UpdateMetadata]] = Unspecified(),
) -> None:
    """Update the topic, name and/or metadata of one collection.

    An argument left at ``Unspecified()`` is not touched. The column update
    and the metadata change run on the same cursor inside one ``self.tx()``
    block.

    Args:
        id: Id of the collection to update.
        topic: New value for the ``topic`` column.
        name: New value for the ``name`` column.
        metadata: ``None`` deletes every metadata row of the collection. A
            mapping replaces the rows for exactly the keys it contains; what
            happens per value is decided by ``_insert_metadata``.
    """
    coll_t = Table("collections")
    meta_t = Table("collection_metadata")

    update_q = (
        self.querybuilder()
        .update(coll_t)
        .where(coll_t.id == ParameterValue(self.uuid_to_db(id)))
    )
    if not topic == Unspecified():
        update_q = update_q.set(coll_t.topic, ParameterValue(topic))
    if not name == Unspecified():
        update_q = update_q.set(coll_t.name, ParameterValue(name))

    with self.tx() as cur:
        sql, params = get_sql(update_q, self.parameter_format())
        if sql:  # pypika emits a blank string if nothing to do
            cur.execute(sql, params)

        if metadata is None:
            # Explicit None: drop all metadata of this collection.
            clear_q = (
                self.querybuilder()
                .from_(meta_t)
                .where(meta_t.collection_id == ParameterValue(self.uuid_to_db(id)))
                .delete()
            )
            sql, params = get_sql(clear_q, self.parameter_format())
            cur.execute(sql, params)
            return

        if metadata != Unspecified():
            changes = cast(UpdateMetadata, metadata)
            # Passing the keys as clear_keys makes _insert_metadata delete
            # the existing rows for those keys before inserting.
            self._insert_metadata(
                cur,
                meta_t,
                meta_t.collection_id,
                id,
                changes,
                set(changes.keys()),
            )
def _metadata_from_rows(
    self, rows: Sequence[Tuple[Any, ...]]
) -> Optional[Metadata]:
    """Build a metadata map from joined SQL rows.

    The last four columns of every row must be, in order: key, str_value,
    int_value and float_value. The first of the three value columns that is
    not NULL supplies the value for that key.

    Value columns are compared against ``None`` rather than tested for
    truthiness, so stored values such as ``""``, ``0`` and ``0.0`` are kept
    instead of being silently dropped.

    Args:
        rows: Result rows for a single segment or collection.

    Returns:
        The key/value map, or ``None`` when no row carried a value (for
        example the all-NULL row a left join yields for an entity without
        metadata).
    """
    metadata: Dict[str, Union[str, int, float]] = {}
    for row in rows:
        key = str(row[-4])
        str_value, int_value, float_value = row[-3], row[-2], row[-1]
        if str_value is not None:
            metadata[key] = str(str_value)
        elif int_value is not None:
            metadata[key] = int(int_value)
        elif float_value is not None:
            metadata[key] = float(float_value)
        # All three NULL: the row carries no metadata value; skip it.
    return metadata or None
def _insert_metadata(
    self,
    cur: Cursor,
    table: Table,
    id_col: Column,
    id: UUID,
    metadata: UpdateMetadata,
    clear_keys: Optional[Set[str]] = None,
) -> None:
    """Write metadata rows for one owner id, optionally clearing keys first.

    Args:
        cur: Cursor on which the statements are executed.
        table: Metadata table to write to.
        id_col: Column of ``table`` that holds the owner id.
        id: Owner id the rows belong to.
        metadata: Key/value pairs to insert. ``str``, ``int`` and ``float``
            values go into ``str_value``, ``int_value`` and ``float_value``
            respectively; any other value (including ``None``) produces no
            row.
        clear_keys: Keys whose existing rows for ``id`` are deleted before
            inserting. Nothing is deleted when this is ``None`` or empty.
    """
    # Delete-then-insert is used instead of something like ON CONFLICT UPDATE
    # because that is very difficult to do in a portable way (e.g. sqlite and
    # postgres have completely different syntax).
    if clear_keys:
        owner_matches = id_col == ParameterValue(self.uuid_to_db(id))
        key_matches = table.key.isin([ParameterValue(k) for k in clear_keys])
        delete_q = (
            self.querybuilder()
            .from_(table)
            .where(owner_matches)
            .where(key_matches)
            .delete()
        )
        sql, params = get_sql(delete_q, self.parameter_format())
        cur.execute(sql, params)

    insert_q = (
        self.querybuilder()
        .into(table)
        .columns(
            id_col, table.key, table.str_value, table.int_value, table.float_value
        )
    )
    sql_id = self.uuid_to_db(id)
    for key, value in metadata.items():
        # Position of the value among (str_value, int_value, float_value).
        if isinstance(value, str):
            slot = 0
        elif isinstance(value, int):
            slot = 1
        elif isinstance(value, float):
            slot = 2
        else:
            continue
        row = [ParameterValue(sql_id), ParameterValue(key), None, None, None]
        row[2 + slot] = ParameterValue(value)
        insert_q = insert_q.insert(*row)

    sql, params = get_sql(insert_q, self.parameter_format())
    # The rendered SQL is empty when no row was added above; skip execution.
    if sql:
        cur.execute(sql, params)