Spaces:
Runtime error
Runtime error
from chromadb.db.migrations import MigratableDB, Migration | |
from chromadb.config import System, Settings | |
import chromadb.db.base as base | |
from chromadb.db.mixins.embeddings_queue import SqlEmbeddingsQueue | |
from chromadb.db.mixins.sysdb import SqlSysDB | |
import sqlite3 | |
from overrides import override | |
import pypika | |
from typing import Sequence, cast, Optional, Type, Any | |
from typing_extensions import Literal | |
from types import TracebackType | |
import os | |
from uuid import UUID | |
from threading import local | |
class TxWrapper(base.TxWrapper):
    """Context manager that scopes a SQLite transaction to a ``with`` block.

    Wrappers that share the same ``stack`` object nest: only the outermost
    ``__enter__`` issues ``BEGIN`` and only the outermost ``__exit__`` commits
    or rolls back. Inner wrappers merely push and pop themselves.
    """

    def __init__(self, conn: sqlite3.Connection, stack: local) -> None:
        """Remember the connection and the holder of the open-wrapper stack.

        Args:
            conn: Connection on which BEGIN/COMMIT/ROLLBACK are issued.
            stack: Object whose ``stack`` attribute is the list of currently
                open wrappers; it must already exist when the wrapper is
                entered.
        """
        self._conn = conn
        self._tx_stack = stack

    def __enter__(self) -> base.Cursor:
        """Open a transaction if none is active and return a new cursor."""
        open_wrappers = self._tx_stack.stack
        if not open_wrappers:
            # Outermost wrapper: this is the one that starts the transaction.
            self._conn.execute("BEGIN;")
        open_wrappers.append(self)
        return self._conn.cursor()  # type: ignore

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> Literal[False]:
        """Pop this wrapper and, if it was the outermost, end the transaction.

        Commits when the block exited cleanly and rolls back when it raised.
        Always returns ``False`` so an in-flight exception keeps propagating.
        """
        open_wrappers = self._tx_stack.stack
        open_wrappers.pop()
        if open_wrappers:
            # Still nested inside an outer wrapper; it decides the outcome.
            return False
        finish = self._conn.commit if exc_type is None else self._conn.rollback
        finish()
        return False
class SqliteDB(MigratableDB, SqlEmbeddingsQueue, SqlSysDB):
    """SQLite implementation of the migratable SQL database component.

    Holds one ``sqlite3`` connection to the file named by the
    ``sqlite_database`` setting, hands out nestable transactions via
    :meth:`tx`, and records applied migrations in a ``migrations`` table.

    NOTE(review): ``override`` is imported by this module but no ``@override``
    decorators are visible on these methods; they may have been dropped
    together with the ``@staticmethod`` markers restored below. Confirm
    against the base classes before re-adding them.
    """

    _conn: sqlite3.Connection
    _settings: Settings
    _migration_dirs: Sequence[str]
    _db_file: str
    # threading.local whose ``stack`` attribute lists the open TxWrappers.
    _tx_stack: local

    def __init__(self, system: System):
        """Read settings and prepare state; the connection opens in start()."""
        self._settings = system.settings
        self._migration_dirs = [
            "migrations/embeddings_queue",
            "migrations/sysdb",
            "migrations/metadb",
        ]
        self._db_file = self._settings.require("sqlite_database")
        self._tx_stack = local()
        super().__init__(system)

    def start(self) -> None:
        """Open the connection and run migration initialization."""
        super().start()
        self._conn = sqlite3.connect(self._db_file)
        self._conn.isolation_level = None  # Handle commits explicitly
        # NOTE(review): SQLite documents ``PRAGMA foreign_keys`` as a no-op
        # while a transaction is open, and tx() issues BEGIN before this
        # statement runs, so enforcement is probably not being enabled here.
        # Left unchanged because turning enforcement on would alter runtime
        # behaviour; confirm intent before moving it outside the transaction.
        with self.tx() as cur:
            cur.execute("PRAGMA foreign_keys = ON")
        self.initialize_migrations()

    def stop(self) -> None:
        """Stop the component and close the connection."""
        super().stop()
        self._conn.close()

    @staticmethod
    def querybuilder() -> Type[pypika.Query]:
        """Return the pypika query class used to build SQL for this backend."""
        return pypika.Query  # type: ignore

    @staticmethod
    def parameter_format() -> str:
        """Return the placeholder used for bound parameters (``?``)."""
        return "?"

    @staticmethod
    def migration_scope() -> str:
        """Return the scope name attached to this backend's migrations."""
        return "sqlite"

    def migration_dirs(self) -> Sequence[str]:
        """Return the migration directories configured in ``__init__``."""
        return self._migration_dirs

    def tx(self) -> TxWrapper:
        """Return a transaction context manager for the calling thread.

        The open-wrapper stack lives on a ``threading.local``, so it is
        created lazily the first time each thread asks for a transaction.
        """
        if not hasattr(self._tx_stack, "stack"):
            self._tx_stack.stack = []
        return TxWrapper(self._conn, stack=self._tx_stack)

    def reset(self) -> None:
        """Delete the database file (unless in-memory) and start afresh.

        Raises:
            ValueError: If the ``allow_reset`` setting is falsy.
        """
        if not self._settings.require("allow_reset"):
            raise ValueError(
                "Resetting the database is not allowed. Set `allow_reset` to true in the config in tests or other non-production environments where reset should be permitted."
            )
        self._conn.close()
        # Remove the same path that start() reconnects to, so the file that
        # is deleted is always the one that was open.
        if self._db_file != ":memory:":
            os.remove(self._db_file)
        self.start()
        super().reset()

    def setup_migrations(self) -> None:
        """Create the ``migrations`` bookkeeping table if it is missing."""
        with self.tx() as cur:
            cur.execute(
                """
                CREATE TABLE IF NOT EXISTS migrations (
                    dir TEXT NOT NULL,
                    version INTEGER NOT NULL,
                    filename TEXT NOT NULL,
                    sql TEXT NOT NULL,
                    hash TEXT NOT NULL,
                    PRIMARY KEY (dir, version)
                )
                """
            )

    def migrations_initialized(self) -> bool:
        """Return True when the ``migrations`` table exists."""
        with self.tx() as cur:
            cur.execute(
                """SELECT count(*) FROM sqlite_master
                WHERE type='table' AND name='migrations'"""
            )
            return cur.fetchone()[0] != 0

    def db_migrations(self, dir: str) -> Sequence[Migration]:
        """Return the migrations recorded for ``dir``, ordered by version.

        Args:
            dir: Migration directory whose recorded rows should be loaded.
        """
        with self.tx() as cur:
            cur.execute(
                """
                SELECT dir, version, filename, sql, hash
                FROM migrations
                WHERE dir = ?
                ORDER BY version ASC
                """,
                (dir,),
            )
            scope = self.migration_scope()
            # Built without rebinding ``dir`` or shadowing the ``hash`` builtin.
            return [
                Migration(
                    dir=cast(str, row[0]),
                    version=cast(int, row[1]),
                    filename=cast(str, row[2]),
                    sql=cast(str, row[3]),
                    hash=cast(str, row[4]),
                    scope=scope,
                )
                for row in cur.fetchall()
            ]

    def apply_migration(self, cur: base.Cursor, migration: Migration) -> None:
        """Run a migration's SQL script and record it in ``migrations``.

        Args:
            cur: Cursor on which the script and the bookkeeping insert run.
            migration: Migration whose ``sql`` is executed and whose fields
                are stored.
        """
        # NOTE(review): if ``cur`` is a sqlite3.Cursor, executescript()
        # commits any pending transaction before running the script, so the
        # script may not be atomic with the caller's transaction - confirm.
        cur.executescript(migration["sql"])
        cur.execute(
            """
            INSERT INTO migrations (dir, version, filename, sql, hash)
            VALUES (?, ?, ?, ?, ?)
            """,
            (
                migration["dir"],
                migration["version"],
                migration["filename"],
                migration["sql"],
                migration["hash"],
            ),
        )

    @staticmethod
    def uuid_from_db(value: Optional[Any]) -> Optional[UUID]:
        """Convert a stored value to a ``UUID``, passing ``None`` through."""
        return UUID(value) if value is not None else None

    @staticmethod
    def uuid_to_db(uuid: Optional[UUID]) -> Optional[Any]:
        """Convert a ``UUID`` to its string form, passing ``None`` through."""
        return str(uuid) if uuid is not None else None

    @staticmethod
    def unique_constraint_error() -> Type[BaseException]:
        """Return the exception type reported for this backend's unique-constraint violations."""
        return sqlite3.IntegrityError