from typing import Sequence from typing_extensions import TypedDict import os import re import hashlib from chromadb.db.base import SqlDB, Cursor from abc import abstractmethod from chromadb.config import System, Settings class MigrationFile(TypedDict): dir: str filename: str version: int scope: str class Migration(MigrationFile): hash: str sql: str class UninitializedMigrationsError(Exception): def __init__(self) -> None: super().__init__("Migrations have not been initialized") class UnappliedMigrationsError(Exception): def __init__(self, dir: str, version: int): self.dir = dir self.version = version super().__init__( f"Unapplied migrations in {dir}, starting with version {version}" ) class InconsistentVersionError(Exception): def __init__(self, dir: str, db_version: int, source_version: int): super().__init__( f"Inconsistent migration versions in {dir}:" + f"db version was {db_version}, source version was {source_version}." + " Has the migration sequence been modified since being applied to the DB?" ) class InconsistentHashError(Exception): def __init__(self, path: str, db_hash: str, source_hash: str): super().__init__( f"Inconsistent MD5 hashes in {path}:" + f"db hash was {db_hash}, source has was {source_hash}." + " Was the migration file modified after being applied to the DB?" ) class InvalidMigrationFilename(Exception): pass class MigratableDB(SqlDB): """Simple base class for databases which support basic migrations. Migrations are SQL files stored in a project-relative directory. All migrations in the same directory are assumed to be dependent on previous migrations in the same directory, where "previous" is defined on lexographical ordering of filenames. Migrations have a ascending numeric version number and a hash of the file contents. When migrations are applied, the hashes of previous migrations are checked to ensure that the database is consistent with the source repository. If they are not, an error is thrown and no migrations will be applied. Migration files must follow the naming convention: ...sql, where is a 5-digit zero-padded integer, is a short textual description, and is a short string identifying the database implementation. """ _settings: Settings def __init__(self, system: System) -> None: self._settings = system.settings super().__init__(system) @staticmethod @abstractmethod def migration_scope() -> str: """The database implementation to use for migrations (e.g, sqlite, pgsql)""" pass @abstractmethod def migration_dirs(self) -> Sequence[str]: """Directories containing the migration sequences that should be applied to this DB.""" pass @abstractmethod def setup_migrations(self) -> None: """Idempotently creates the migrations table""" pass @abstractmethod def migrations_initialized(self) -> bool: """Return true if the migrations table exists""" pass @abstractmethod def db_migrations(self, dir: str) -> Sequence[Migration]: """Return a list of all migrations already applied to this database, from the given source directory, in ascending order.""" pass @abstractmethod def apply_migration(self, cur: Cursor, migration: Migration) -> None: """Apply a single migration to the database""" pass def initialize_migrations(self) -> None: """Initialize migrations for this DB""" migrate = self._settings.require("migrations") if migrate == "validate": self.validate_migrations() if migrate == "apply": self.apply_migrations() def validate_migrations(self) -> None: """Validate all migrations and throw an exception if there are any unapplied migrations in the source repo.""" if not self.migrations_initialized(): raise UninitializedMigrationsError() for dir in self.migration_dirs(): db_migrations = self.db_migrations(dir) source_migrations = find_migrations(dir, self.migration_scope()) unapplied_migrations = verify_migration_sequence( db_migrations, source_migrations ) if len(unapplied_migrations) > 0: version = unapplied_migrations[0]["version"] raise UnappliedMigrationsError(dir=dir, version=version) def apply_migrations(self) -> None: """Validate existing migrations, and apply all new ones.""" self.setup_migrations() for dir in self.migration_dirs(): db_migrations = self.db_migrations(dir) source_migrations = find_migrations(dir, self.migration_scope()) unapplied_migrations = verify_migration_sequence( db_migrations, source_migrations ) with self.tx() as cur: for migration in unapplied_migrations: self.apply_migration(cur, migration) # Format is -..sql # e.g, 00001-users.sqlite.sql filename_regex = re.compile(r"(\d+)-(.+)\.(.+)\.sql") def _parse_migration_filename(dir: str, filename: str) -> MigrationFile: """Parse a migration filename into a MigrationFile object""" match = filename_regex.match(filename) if match is None: raise InvalidMigrationFilename("Invalid migration filename: " + filename) version, _, scope = match.groups() return { "dir": dir, "filename": filename, "version": int(version), "scope": scope, } def verify_migration_sequence( db_migrations: Sequence[Migration], source_migrations: Sequence[Migration], ) -> Sequence[Migration]: """Given a list of migrations already applied to a database, and a list of migrations from the source code, validate that the applied migrations are correct and match the expected migrations. Throws an exception if any migrations are missing, out of order, or if the source hash does not match. Returns a list of all unapplied migrations, or an empty list if all migrations are applied and the database is up to date.""" for db_migration, source_migration in zip(db_migrations, source_migrations): if db_migration["version"] != source_migration["version"]: raise InconsistentVersionError( dir=db_migration["dir"], db_version=db_migration["version"], source_version=source_migration["version"], ) if db_migration["hash"] != source_migration["hash"]: raise InconsistentHashError( path=db_migration["dir"] + "/" + db_migration["filename"], db_hash=db_migration["hash"], source_hash=source_migration["hash"], ) return source_migrations[len(db_migrations) :] def find_migrations(dir: str, scope: str) -> Sequence[Migration]: """Return a list of all migration present in the given directory, in ascending order. Filter by scope.""" files = [ _parse_migration_filename(dir, filename) for filename in os.listdir(dir) if filename.endswith(".sql") ] files = list(filter(lambda f: f["scope"] == scope, files)) files = sorted(files, key=lambda f: f["version"]) return [_read_migration_file(f) for f in files] def _read_migration_file(file: MigrationFile) -> Migration: """Read a migration file""" sql = open(os.path.join(file["dir"], file["filename"])).read() hash = hashlib.md5(sql.encode("utf-8")).hexdigest() return { "hash": hash, "sql": sql, "dir": file["dir"], "filename": file["filename"], "version": file["version"], "scope": file["scope"], }