Spaces:
Runtime error
Runtime error
File size: 8,096 Bytes
4a51346 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 |
from typing import Sequence
from typing_extensions import TypedDict
import os
import re
import hashlib
from chromadb.db.base import SqlDB, Cursor
from abc import abstractmethod
from chromadb.config import System, Settings
class MigrationFile(TypedDict):
    """Metadata describing a migration file, derived from its location and name."""

    # Directory in which the migration file lives.
    dir: str
    # Name of the migration file within ``dir``.
    filename: str
    # Numeric version taken from the leading digits of the filename.
    version: int
    # Scope component of the filename (e.g. "sqlite").
    scope: str
class Migration(MigrationFile):
    """A migration file together with its SQL text and content hash."""

    # Hex digest computed over the SQL text of the migration.
    hash: str
    # Full SQL text of the migration.
    sql: str
class UninitializedMigrationsError(Exception):
    """Raised when migrations are validated before they have been initialized."""

    def __init__(self) -> None:
        message = "Migrations have not been initialized"
        super().__init__(message)
class UnappliedMigrationsError(Exception):
    """Raised when a migration directory contains migrations not yet applied.

    Attributes:
        dir: The migration directory that has unapplied migrations.
        version: The version of the first unapplied migration.
    """

    def __init__(self, dir: str, version: int):
        message = f"Unapplied migrations in {dir}, starting with version {version}"
        self.dir, self.version = dir, version
        super().__init__(message)
class InconsistentVersionError(Exception):
    """Raised when an applied migration's version differs from the source's.

    Args:
        dir: The migration directory in which the mismatch was found.
        db_version: The version recorded in the database.
        source_version: The version found at the same position in the source.
    """

    def __init__(self, dir: str, db_version: int, source_version: int):
        # Each fragment ends with a space so the joined message reads as
        # separate sentences (previously the colon ran straight into "db").
        super().__init__(
            f"Inconsistent migration versions in {dir}: "
            f"db version was {db_version}, source version was {source_version}. "
            "Has the migration sequence been modified since being applied to the DB?"
        )
class InconsistentHashError(Exception):
    """Raised when an applied migration's hash differs from the source file's.

    Args:
        path: Path of the migration file whose hash does not match.
        db_hash: The hash recorded in the database.
        source_hash: The hash computed from the source file.
    """

    def __init__(self, path: str, db_hash: str, source_hash: str):
        # Fixes the "source has was" typo and the missing space after the colon.
        super().__init__(
            f"Inconsistent MD5 hashes in {path}: "
            f"db hash was {db_hash}, source hash was {source_hash}. "
            "Was the migration file modified after being applied to the DB?"
        )
class InvalidMigrationFilename(Exception):
    """Raised when a migration filename does not follow the expected format."""
class MigratableDB(SqlDB):
    """Simple base class for databases which support basic migrations.

    Migrations are SQL files stored in the directories returned by
    ``migration_dirs``. All migrations in the same directory are assumed to
    depend on the previous migrations in that directory, where "previous" means
    a lower numeric version.

    Migrations have an ascending numeric version number and a hash of the file
    contents. Before migrations from a directory are applied, the versions and
    hashes of the migrations already recorded in the database are checked
    against the source files. If they are inconsistent, an error is raised and
    no migrations from that directory are applied.

    Migration files must follow the naming convention
    ``<version>-<description>.<scope>.sql`` (e.g. ``00001-users.sqlite.sql``),
    where ``<version>`` is an integer (by convention zero-padded to 5 digits),
    ``<description>`` is a short textual description, and ``<scope>`` is a short
    string identifying the database implementation.
    """

    _settings: Settings

    def __init__(self, system: System) -> None:
        # Assigned before the base initializer runs.
        # NOTE(review): presumably so base-class setup can rely on it; confirm.
        self._settings = system.settings
        super().__init__(system)

    @staticmethod
    @abstractmethod
    def migration_scope() -> str:
        """The database implementation to use for migrations (e.g, sqlite, pgsql)"""
        pass

    @abstractmethod
    def migration_dirs(self) -> Sequence[str]:
        """Directories containing the migration sequences that should be applied
        to this DB."""
        pass

    @abstractmethod
    def setup_migrations(self) -> None:
        """Idempotently creates the migrations table"""
        pass

    @abstractmethod
    def migrations_initialized(self) -> bool:
        """Return true if the migrations table exists"""
        pass

    @abstractmethod
    def db_migrations(self, dir: str) -> Sequence[Migration]:
        """Return a list of all migrations already applied to this database, from
        the given source directory, in ascending order."""
        pass

    @abstractmethod
    def apply_migration(self, cur: Cursor, migration: Migration) -> None:
        """Apply a single migration to the database"""
        pass

    def initialize_migrations(self) -> None:
        """Initialize migrations for this DB.

        Reads the required ``migrations`` setting: ``"validate"`` runs
        ``validate_migrations`` and ``"apply"`` runs ``apply_migrations``. Any
        other value results in no migration work being done here.
        """
        migrate = self._settings.require("migrations")

        if migrate == "validate":
            self.validate_migrations()
        elif migrate == "apply":
            self.apply_migrations()

    def validate_migrations(self) -> None:
        """Validate all migrations and raise if any are unapplied.

        Raises:
            UninitializedMigrationsError: If the migrations table does not exist.
            UnappliedMigrationsError: If a directory has source migrations that
                have not been applied to the database.
            InconsistentVersionError, InconsistentHashError: Propagated from
                ``verify_migration_sequence`` when the applied migrations do not
                match the source files.
        """
        if not self.migrations_initialized():
            raise UninitializedMigrationsError()

        for dir in self.migration_dirs():
            unapplied_migrations = self._unapplied_migrations(dir)
            if unapplied_migrations:
                # Report the first migration that still needs to be applied.
                version = unapplied_migrations[0]["version"]
                raise UnappliedMigrationsError(dir=dir, version=version)

    def apply_migrations(self) -> None:
        """Validate existing migrations, and apply all new ones.

        Each directory's unapplied migrations are applied inside a single
        transaction obtained from ``self.tx()``.
        """
        self.setup_migrations()

        for dir in self.migration_dirs():
            # Verification happens before the transaction is opened, so an
            # inconsistency aborts before anything from this directory runs.
            unapplied_migrations = self._unapplied_migrations(dir)
            with self.tx() as cur:
                for migration in unapplied_migrations:
                    self.apply_migration(cur, migration)

    def _unapplied_migrations(self, dir: str) -> Sequence[Migration]:
        """Verify ``dir`` against the database and return its unapplied migrations."""
        db_migrations = self.db_migrations(dir)
        source_migrations = find_migrations(dir, self.migration_scope())
        return verify_migration_sequence(db_migrations, source_migrations)
# Format is <version>-<name>.<scope>.sql
# e.g, 00001-users.sqlite.sql
filename_regex = re.compile(r"(\d+)-(.+)\.(.+)\.sql")


def _parse_migration_filename(dir: str, filename: str) -> MigrationFile:
    """Parse a migration filename into a MigrationFile object.

    Args:
        dir: Directory containing the file; copied into the result unchanged.
        filename: Bare filename, expected to look like
            ``<version>-<name>.<scope>.sql``.

    Returns:
        A MigrationFile with the integer version and the scope taken from the
        filename.

    Raises:
        InvalidMigrationFilename: If the filename does not match the format.
    """
    # fullmatch rather than match: match() only anchors the start, so a name
    # with extra text after ".sql" (e.g. "...sql.bak") would be accepted.
    match = filename_regex.fullmatch(filename)
    if match is None:
        raise InvalidMigrationFilename("Invalid migration filename: " + filename)

    # The middle group is the descriptive name, which is not stored.
    version, _, scope = match.groups()
    return {
        "dir": dir,
        "filename": filename,
        "version": int(version),
        "scope": scope,
    }
def verify_migration_sequence(
    db_migrations: Sequence[Migration],
    source_migrations: Sequence[Migration],
) -> Sequence[Migration]:
    """Check applied migrations against the source and return what is unapplied.

    The two sequences are compared pairwise, position by position, up to the
    length of the shorter one.

    Args:
        db_migrations: Migrations already applied to the database.
        source_migrations: Migrations found in the source code.

    Returns:
        The source migrations beyond the first ``len(db_migrations)`` entries;
        empty when there is nothing left to apply.

    Raises:
        InconsistentVersionError: If a pair has different version numbers.
        InconsistentHashError: If a pair has the same version but different
            hashes.
    """
    applied_count = len(db_migrations)

    for applied, expected in zip(db_migrations, source_migrations):
        applied_version = applied["version"]
        expected_version = expected["version"]
        if applied_version != expected_version:
            raise InconsistentVersionError(
                dir=applied["dir"],
                db_version=applied_version,
                source_version=expected_version,
            )

        applied_hash = applied["hash"]
        expected_hash = expected["hash"]
        if applied_hash != expected_hash:
            raise InconsistentHashError(
                path=applied["dir"] + "/" + applied["filename"],
                db_hash=applied_hash,
                source_hash=expected_hash,
            )

    return source_migrations[applied_count:]
def find_migrations(dir: str, scope: str) -> Sequence[Migration]:
    """List the migrations in a directory for one scope, ordered by version.

    Every ``.sql`` entry in ``dir`` is parsed (so an invalid ``.sql`` filename
    raises even if it would belong to another scope); only those whose scope
    equals ``scope`` are kept, sorted by ascending version, and then read.

    Args:
        dir: Directory to scan for migration files.
        scope: Scope to keep (compared against the filename's scope part).

    Returns:
        The matching migrations, in ascending version order.
    """
    in_scope = []
    for entry in os.listdir(dir):
        if not entry.endswith(".sql"):
            continue
        parsed = _parse_migration_filename(dir, entry)
        if parsed["scope"] == scope:
            in_scope.append(parsed)

    in_scope.sort(key=lambda parsed: parsed["version"])
    return [_read_migration_file(parsed) for parsed in in_scope]
def _read_migration_file(file: MigrationFile) -> Migration:
    """Read a migration file and compute the hash of its contents.

    Args:
        file: Parsed filename metadata; ``dir`` and ``filename`` locate the file.

    Returns:
        A Migration holding the input metadata plus the file's SQL text and the
        MD5 hex digest of that text encoded as UTF-8.
    """
    path = os.path.join(file["dir"], file["filename"])
    # Use a context manager so the file handle is closed deterministically.
    # The file is still opened in default text mode, so the text that gets
    # hashed is the same as before.
    with open(path) as sql_file:
        sql = sql_file.read()

    digest = hashlib.md5(sql.encode("utf-8")).hexdigest()
    return {
        "hash": digest,
        "sql": sql,
        "dir": file["dir"],
        "filename": file["filename"],
        "version": file["version"],
        "scope": file["scope"],
    }
|