SungBeom's picture
Upload folder using huggingface_hub
4a51346
raw
history blame
8.1 kB
from typing import Sequence
from typing_extensions import TypedDict
import os
import re
import hashlib
from chromadb.db.base import SqlDB, Cursor
from abc import abstractmethod
from chromadb.config import System, Settings
class MigrationFile(TypedDict):
    """Metadata describing one migration file, as derived from its filename."""

    # Directory the migration file was found in.
    dir: str
    # Name of the file inside that directory.
    filename: str
    # Numeric version taken from the filename's leading digits.
    version: int
    # Scope segment of the filename; compared against the DB's migration scope.
    scope: str
class Migration(MigrationFile):
    """A migration file's metadata together with its SQL text and content hash."""

    # Hash of the migration's contents (an MD5 hex digest when read from disk).
    hash: str
    # SQL text of the migration.
    sql: str
class UninitializedMigrationsError(Exception):
    """Raised when migrations are validated on a DB that has not been set up."""

    def __init__(self) -> None:
        message = "Migrations have not been initialized"
        super().__init__(message)
class UnappliedMigrationsError(Exception):
    """Raised when a source directory holds migrations the DB has not applied.

    Attributes:
        dir: Directory containing the unapplied migrations.
        version: Version of the first unapplied migration.
    """

    def __init__(self, dir: str, version: int):
        message = f"Unapplied migrations in {dir}, starting with version {version}"
        super().__init__(message)
        # Exposed so callers can report or act on the exact starting point.
        self.dir = dir
        self.version = version
class InconsistentVersionError(Exception):
    """Raised when an applied migration's version differs from the source's.

    Attributes:
        dir: Directory of the migration sequence being compared.
        db_version: Version recorded in the database at the mismatching position.
        source_version: Version found in the source files at the same position.
    """

    def __init__(self, dir: str, db_version: int, source_version: int):
        # Keep the details on the instance, mirroring UnappliedMigrationsError,
        # so callers do not have to parse the message.
        self.dir = dir
        self.db_version = db_version
        self.source_version = source_version
        # Note the space after the colon: the pieces are concatenated verbatim.
        super().__init__(
            f"Inconsistent migration versions in {dir}: "
            f"db version was {db_version}, source version was {source_version}."
            " Has the migration sequence been modified since being applied to the DB?"
        )
class InconsistentHashError(Exception):
    """Raised when an applied migration's hash differs from the source file's.

    Attributes:
        path: Path of the migration whose contents no longer match.
        db_hash: Hash recorded in the database for that migration.
        source_hash: Hash computed from the current source file.
    """

    def __init__(self, path: str, db_hash: str, source_hash: str):
        # Keep the details on the instance, mirroring UnappliedMigrationsError,
        # so callers do not have to parse the message.
        self.path = path
        self.db_hash = db_hash
        self.source_hash = source_hash
        # Note the space after the colon: the pieces are concatenated verbatim.
        super().__init__(
            f"Inconsistent MD5 hashes in {path}: "
            f"db hash was {db_hash}, source hash was {source_hash}."
            " Was the migration file modified after being applied to the DB?"
        )
class InvalidMigrationFilename(Exception):
    """Raised when a migration filename does not match the expected pattern."""
class MigratableDB(SqlDB):
    """Base class for SQL databases that support simple file-based migrations.

    Migrations are SQL files kept in the directories returned by
    ``migration_dirs``. The files in one directory form a single sequence,
    ordered by ascending numeric version, and each file is additionally
    identified by a hash of its contents.

    Before a directory's migrations are reported or applied, the migrations
    already recorded in the database are compared with the source files. If a
    version or hash does not match, an error is raised and no further
    migrations are applied.

    Migration files are named ``<version>-<name>.<scope>.sql`` (for example
    ``00001-users.sqlite.sql``), where ``<version>`` is a run of digits,
    ``<name>`` is a short description and ``<scope>`` identifies the database
    implementation the file is written for.
    """

    _settings: Settings

    def __init__(self, system: System) -> None:
        # Settings are captured before delegating to the parent initializer.
        self._settings = system.settings
        super().__init__(system)

    @staticmethod
    @abstractmethod
    def migration_scope() -> str:
        """Return the scope of this database implementation (e.g. sqlite, pgsql).

        Only migration files whose scope equals this value are considered.
        """

    @abstractmethod
    def migration_dirs(self) -> Sequence[str]:
        """Return the directories whose migration sequences apply to this DB."""

    @abstractmethod
    def setup_migrations(self) -> None:
        """Create the migrations table; must be safe to call repeatedly."""

    @abstractmethod
    def migrations_initialized(self) -> bool:
        """Return True if the migrations table exists."""

    @abstractmethod
    def db_migrations(self, dir: str) -> Sequence[Migration]:
        """Return the migrations already applied from ``dir``, in ascending order."""

    @abstractmethod
    def apply_migration(self, cur: Cursor, migration: Migration) -> None:
        """Apply one migration to the database using the given cursor."""

    def initialize_migrations(self) -> None:
        """Validate or apply migrations, as selected by the "migrations" setting.

        "validate" runs ``validate_migrations`` and "apply" runs
        ``apply_migrations``; any other value does nothing.
        """
        mode = self._settings.require("migrations")
        if mode == "validate":
            self.validate_migrations()
        elif mode == "apply":
            self.apply_migrations()

    def validate_migrations(self) -> None:
        """Check that the database is up to date with the source migrations.

        Raises:
            UninitializedMigrationsError: if ``migrations_initialized`` is False.
            UnappliedMigrationsError: for the first directory that still has
                migrations the database has not applied.
        """
        if not self.migrations_initialized():
            raise UninitializedMigrationsError()

        for migration_dir in self.migration_dirs():
            applied = self.db_migrations(migration_dir)
            available = find_migrations(migration_dir, self.migration_scope())
            pending = verify_migration_sequence(applied, available)
            if not pending:
                continue
            # Report the earliest migration that still needs to be applied.
            raise UnappliedMigrationsError(
                dir=migration_dir, version=pending[0]["version"]
            )

    def apply_migrations(self) -> None:
        """Verify the applied migrations, then apply every outstanding one."""
        self.setup_migrations()

        for migration_dir in self.migration_dirs():
            applied = self.db_migrations(migration_dir)
            available = find_migrations(migration_dir, self.migration_scope())
            pending = verify_migration_sequence(applied, available)
            # All pending migrations of one directory share the cursor obtained
            # from a single self.tx() context.
            with self.tx() as cur:
                for migration in pending:
                    self.apply_migration(cur, migration)
# Migration filenames have the form <version>-<name>.<scope>.sql
# e.g. 00001-users.sqlite.sql
# Capture groups: (1) the version digits, (2) the name, (3) the scope.
# NOTE: the pattern has no end anchor, so used with `match` it only checks that
# the filename starts with this shape.
filename_regex = re.compile(r"(\d+)-(.+)\.(.+)\.sql")
def _parse_migration_filename(dir: str, filename: str) -> MigrationFile:
    """Build a MigrationFile from a directory and a migration filename.

    The version and scope are extracted from ``filename`` with
    ``filename_regex``; the descriptive name segment is not kept.

    Raises:
        InvalidMigrationFilename: if ``filename`` does not match the pattern.
    """
    parsed = filename_regex.match(filename)
    if parsed is None:
        raise InvalidMigrationFilename("Invalid migration filename: " + filename)

    # Group 2 (the descriptive name) is intentionally ignored.
    version_digits = parsed.group(1)
    scope = parsed.group(3)
    return {
        "dir": dir,
        "filename": filename,
        "version": int(version_digits),
        "scope": scope,
    }
def verify_migration_sequence(
    db_migrations: Sequence[Migration],
    source_migrations: Sequence[Migration],
) -> Sequence[Migration]:
    """Check applied migrations against the source and return the unapplied rest.

    The two sequences are compared position by position. Only positions present
    in both are checked, so applied migrations beyond the end of the source
    sequence are not reported.

    Raises:
        InconsistentVersionError: if the versions at some position differ.
        InconsistentHashError: if the versions agree but the hashes differ.

    Returns:
        The source migrations that follow the already-applied ones; empty when
        the database is up to date.
    """
    for applied, source in zip(db_migrations, source_migrations):
        applied_version = applied["version"]
        source_version = source["version"]
        if applied_version != source_version:
            raise InconsistentVersionError(
                dir=applied["dir"],
                db_version=applied_version,
                source_version=source_version,
            )

        applied_hash = applied["hash"]
        source_hash = source["hash"]
        if applied_hash != source_hash:
            raise InconsistentHashError(
                path="/".join((applied["dir"], applied["filename"])),
                db_hash=applied_hash,
                source_hash=source_hash,
            )

    # Everything past the applied prefix still needs to be run.
    applied_count = len(db_migrations)
    return source_migrations[applied_count:]
def find_migrations(dir: str, scope: str) -> Sequence[Migration]:
    """Load the migrations in ``dir`` that belong to ``scope``, by ascending version.

    Every ``.sql`` filename in the directory is parsed first, so an invalid
    name raises InvalidMigrationFilename even if it would not match ``scope``.
    Only files of the requested scope are then read from disk.
    """
    parsed = []
    for name in os.listdir(dir):
        if name.endswith(".sql"):
            parsed.append(_parse_migration_filename(dir, name))

    in_scope = [entry for entry in parsed if entry["scope"] == scope]
    in_scope.sort(key=lambda entry: entry["version"])
    return [_read_migration_file(entry) for entry in in_scope]
def _read_migration_file(file: MigrationFile) -> Migration:
    """Read a migration file from disk and return it with its SQL and hash.

    Args:
        file: Parsed filename metadata; ``dir`` and ``filename`` locate the file.

    Returns:
        A Migration carrying the metadata from ``file`` plus the file's text
        (``sql``) and the MD5 hex digest of that text encoded as UTF-8 (``hash``).

    Raises:
        OSError: if the file cannot be opened or read.
    """
    path = os.path.join(file["dir"], file["filename"])
    # The context manager closes the handle promptly instead of leaving it to
    # the garbage collector. The encoding is left at the platform default, as
    # before, so hashes recorded by earlier runs remain comparable.
    with open(path) as handle:
        sql = handle.read()
    digest = hashlib.md5(sql.encode("utf-8")).hexdigest()
    return {
        "hash": digest,
        "sql": sql,
        "dir": file["dir"],
        "filename": file["filename"],
        "version": file["version"],
        "scope": file["scope"],
    }