Spaces:
Runtime error
Runtime error
from chromadb.segment import ( | |
SegmentImplementation, | |
SegmentManager, | |
MetadataReader, | |
VectorReader, | |
) | |
from chromadb.config import System, get_class | |
from chromadb.db.system import SysDB | |
from overrides import override | |
from enum import Enum | |
from chromadb.types import Collection, Segment, SegmentScope | |
from typing import Dict, Set, Type, TypeVar | |
from uuid import UUID, uuid4 | |
from collections import defaultdict | |
import re | |
class SegmentType(Enum): | |
SQLITE = "urn:chroma:segment/metadata/sqlite" | |
HNSW_LOCAL_MEMORY = "urn:chroma:segment/vector/hnsw-local-memory" | |
SEGMENT_TYPE_IMPLS = { | |
SegmentType.SQLITE: "chromadb.segment.impl.sqlite.SqliteMetadataReader", | |
SegmentType.HNSW_LOCAL_MEMORY: "chromadb.segment.impl.vector.local_hnsw.LocalHnswSegment", | |
} | |
PROPAGATE_METADATA = { | |
SegmentType.HNSW_LOCAL_MEMORY: [r"^hnsw:.*"], | |
} | |
class LocalSegmentManager(SegmentManager): | |
_sysdb: SysDB | |
_system: System | |
_instances: Dict[UUID, SegmentImplementation] | |
_segment_cache: Dict[UUID, Dict[SegmentScope, Segment]] | |
def __init__(self, system: System): | |
self._sysdb = self.require(SysDB) | |
self._system = system | |
self._instances = {} | |
self._segment_cache = defaultdict(dict) | |
super().__init__(system) | |
def start(self) -> None: | |
super().start() | |
def stop(self) -> None: | |
super().stop() | |
def reset(self) -> None: | |
self._instances = {} | |
self._segment_cache = defaultdict(dict) | |
super().reset() | |
def create_segments(self, collection: Collection) -> Set[Segment]: | |
vector_segment = _segment( | |
SegmentType.HNSW_LOCAL_MEMORY, SegmentScope.VECTOR, collection | |
) | |
metadata_segment = _segment( | |
SegmentType.SQLITE, SegmentScope.METADATA, collection | |
) | |
self._sysdb.create_segment(vector_segment) | |
self._sysdb.create_segment(metadata_segment) | |
return {vector_segment, metadata_segment} | |
def delete_segments(self, collection_id: UUID) -> None: | |
segments = self._sysdb.get_segments(collection=collection_id) | |
for segment in segments: | |
self._sysdb.delete_segment(segment["id"]) | |
del self._instances[segment["id"]] | |
del self._segment_cache[collection_id][segment["scope"]] | |
del self._segment_cache[collection_id] | |
T = TypeVar("T", bound="SegmentImplementation") | |
def get_segment(self, collection_id: UUID, type: Type[T]) -> SegmentImplementation: | |
if type == Type[MetadataReader]: | |
scope = SegmentScope.METADATA | |
elif type == Type[VectorReader]: | |
scope = SegmentScope.VECTOR | |
else: | |
raise ValueError(f"Invalid segment type: {type}") | |
if scope not in self._segment_cache[collection_id]: | |
segments = self._sysdb.get_segments(collection=collection_id, scope=scope) | |
known_types = set([k.value for k in SEGMENT_TYPE_IMPLS.keys()]) | |
# Get the first segment of a known type | |
segment = next(filter(lambda s: s["type"] in known_types, segments)) | |
self._segment_cache[collection_id][scope] = segment | |
return self._instance(self._segment_cache[collection_id][scope]) | |
def _instance(self, segment: Segment) -> SegmentImplementation: | |
if segment["id"] not in self._instances: | |
classname = SEGMENT_TYPE_IMPLS[SegmentType(segment["type"])] | |
cls = get_class(classname, SegmentImplementation) | |
self._instances[segment["id"]] = cls(self._system, segment) | |
return self._instances[segment["id"]] | |
def _segment(type: SegmentType, scope: SegmentScope, collection: Collection) -> Segment: | |
"""Create a metadata dict, propagating metadata correctly for the given segment type.""" | |
metadata = {} | |
regexes = PROPAGATE_METADATA.get(type, []) | |
if collection["metadata"]: | |
for key, value in collection["metadata"].items(): | |
for regex in regexes: | |
if re.match(regex, key): | |
metadata[key] = value | |
break | |
return Segment( | |
id=uuid4(), | |
type=type.value, | |
scope=scope, | |
topic=collection["topic"], | |
collection=collection["id"], | |
metadata=metadata, | |
) | |