File size: 8,096 Bytes
4a51346
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
from typing import Sequence
from typing_extensions import TypedDict
import os
import re
import hashlib
from chromadb.db.base import SqlDB, Cursor
from abc import abstractmethod
from chromadb.config import System, Settings


class MigrationFile(TypedDict):
    """Metadata parsed from a migration filename, before the file is read."""

    # Directory the migration file lives in.
    dir: str
    # Bare filename, e.g. "00001-users.sqlite.sql".
    filename: str
    # Numeric version parsed from the leading digits of the filename.
    version: int
    # Database implementation the migration targets (e.g. "sqlite").
    scope: str


class Migration(MigrationFile):
    """A migration file's metadata together with its SQL text and content hash."""

    # Hex MD5 digest of the SQL text (see _read_migration_file).
    hash: str
    # Full SQL text of the migration.
    sql: str


class UninitializedMigrationsError(Exception):
    """Raised when migrations are validated before the migrations table exists."""

    def __init__(self) -> None:
        message = "Migrations have not been initialized"
        super().__init__(message)


class UnappliedMigrationsError(Exception):
    """Raised when the source tree contains migrations not yet applied to the DB.

    The directory and the first missing version are kept as attributes so callers
    can report exactly where the database is behind.
    """

    def __init__(self, dir: str, version: int):
        message = f"Unapplied migrations in {dir}, starting with version {version}"
        super().__init__(message)
        self.dir = dir
        self.version = version


class InconsistentVersionError(Exception):
    """Raised when an applied migration's version differs from the source migration
    at the same position in the sequence.

    Args:
        dir: migration directory in which the mismatch was found.
        db_version: version recorded in the database.
        source_version: version found in the source tree at the same position.
    """

    def __init__(self, dir: str, db_version: int, source_version: int):
        # Keep the details available for programmatic handling.
        self.dir = dir
        self.db_version = db_version
        self.source_version = source_version
        super().__init__(
            # Trailing space after the colon keeps the two sentences readable.
            f"Inconsistent migration versions in {dir}: "
            + f"db version was {db_version}, source version was {source_version}."
            + " Has the migration sequence been modified since being applied to the DB?"
        )


class InconsistentHashError(Exception):
    """Raised when an applied migration's stored hash differs from the hash of the
    corresponding source file.

    Args:
        path: path of the migration file whose contents changed.
        db_hash: hash recorded in the database when the migration was applied.
        source_hash: hash of the file currently in the source tree.
    """

    def __init__(self, path: str, db_hash: str, source_hash: str):
        # Keep the details available for programmatic handling.
        self.path = path
        self.db_hash = db_hash
        self.source_hash = source_hash
        super().__init__(
            f"Inconsistent MD5 hashes in {path}: "
            + f"db hash was {db_hash}, source hash was {source_hash}."
            + " Was the migration file modified after being applied to the DB?"
        )


class InvalidMigrationFilename(Exception):
    """Raised when a ``.sql`` file in a migration directory does not follow the
    ``<version>-<description>.<scope>.sql`` naming convention."""


class MigratableDB(SqlDB):
    """Simple base class for databases which support basic migrations.

    Migrations are SQL files stored in a project-relative directory. All migrations in
    the same directory are assumed to be dependent on previous migrations in the same
    directory, where "previous" is defined by ascending numeric version number.

    Each migration has an ascending numeric version number and a hash of the file
    contents. When migrations are applied, the versions and hashes of previously
    applied migrations are checked to ensure that the database is consistent with the
    source repository. If they are not, an error is thrown and no migrations will be
    applied.

    Migration files must follow the naming convention:
    <version>-<description>.<scope>.sql, where <version> is a zero-padded integer
    (e.g. 00001), <description> is a short textual description, and <scope> is a short
    string identifying the database implementation. Example: 00001-users.sqlite.sql.
    """

    _settings: Settings

    def __init__(self, system: System) -> None:
        self._settings = system.settings
        super().__init__(system)

    @staticmethod
    @abstractmethod
    def migration_scope() -> str:
        """The database implementation to use for migrations (e.g, sqlite, pgsql)"""
        pass

    @abstractmethod
    def migration_dirs(self) -> Sequence[str]:
        """Directories containing the migration sequences that should be applied to this
        DB."""
        pass

    @abstractmethod
    def setup_migrations(self) -> None:
        """Idempotently creates the migrations table"""
        pass

    @abstractmethod
    def migrations_initialized(self) -> bool:
        """Return true if the migrations table exists"""
        pass

    @abstractmethod
    def db_migrations(self, dir: str) -> Sequence[Migration]:
        """Return a list of all migrations already applied to this database, from the
        given source directory, in ascending order."""
        pass

    @abstractmethod
    def apply_migration(self, cur: Cursor, migration: Migration) -> None:
        """Apply a single migration to the database"""
        pass

    def initialize_migrations(self) -> None:
        """Initialize migrations for this DB according to the "migrations" setting.

        "validate" checks that every source migration has already been applied;
        "apply" applies any outstanding migrations. Any other value does nothing here.
        """
        migrate = self._settings.require("migrations")

        if migrate == "validate":
            self.validate_migrations()
        elif migrate == "apply":
            self.apply_migrations()

    def validate_migrations(self) -> None:
        """Validate all migrations and throw an exception if there are any unapplied
        migrations in the source repo.

        Raises:
            UninitializedMigrationsError: if the migrations table does not exist.
            UnappliedMigrationsError: if any directory has unapplied migrations.
            InconsistentVersionError, InconsistentHashError: if applied migrations
                disagree with the source tree.
        """
        if not self.migrations_initialized():
            raise UninitializedMigrationsError()
        scope = self.migration_scope()
        for migration_dir in self.migration_dirs():
            unapplied_migrations = verify_migration_sequence(
                self.db_migrations(migration_dir),
                find_migrations(migration_dir, scope),
            )
            if len(unapplied_migrations) > 0:
                version = unapplied_migrations[0]["version"]
                raise UnappliedMigrationsError(dir=migration_dir, version=version)

    def apply_migrations(self) -> None:
        """Validate existing migrations, and apply all new ones.

        Every migration directory is verified before anything is applied, so an
        inconsistency in one directory cannot leave earlier directories already
        migrated. Each directory's outstanding migrations are then applied inside
        one transaction per directory.

        Raises:
            InconsistentVersionError, InconsistentHashError: if applied migrations
                disagree with the source tree (nothing is applied in that case).
        """
        self.setup_migrations()
        scope = self.migration_scope()
        # Phase 1: verify every directory up front.
        pending = [
            verify_migration_sequence(
                self.db_migrations(migration_dir),
                find_migrations(migration_dir, scope),
            )
            for migration_dir in self.migration_dirs()
        ]
        # Phase 2: only once everything verified, apply the outstanding migrations.
        for unapplied_migrations in pending:
            with self.tx() as cur:
                for migration in unapplied_migrations:
                    self.apply_migration(cur, migration)


# Migration filenames look like <version>-<name>.<scope>.sql,
# for example 00001-users.sqlite.sql. The three capture groups hold the
# numeric version, the free-form name and the database scope.
filename_regex = re.compile(
    r"(\d+)"  # version
    r"-(.+)"  # name
    r"\.(.+)"  # scope
    r"\.sql"
)


def _parse_migration_filename(dir: str, filename: str) -> MigrationFile:
    """Parse a migration filename into a MigrationFile object.

    Args:
        dir: directory the file lives in (copied into the result unchanged).
        filename: bare filename; the whole name must match ``filename_regex``.

    Returns:
        A MigrationFile with the parsed integer version and scope.

    Raises:
        InvalidMigrationFilename: if the name does not follow the convention.
    """
    # fullmatch rather than match: match only anchors at the start, so a name with
    # trailing text after ".sql" (e.g. "00001-a.b.sqlx") would otherwise be accepted.
    match = filename_regex.fullmatch(filename)
    if match is None:
        raise InvalidMigrationFilename("Invalid migration filename: " + filename)
    version, _description, scope = match.groups()
    return {
        "dir": dir,
        "filename": filename,
        "version": int(version),
        "scope": scope,
    }


def verify_migration_sequence(
    db_migrations: Sequence[Migration],
    source_migrations: Sequence[Migration],
) -> Sequence[Migration]:
    """Check applied migrations against the source sequence and return the rest.

    The two sequences are compared position by position over their common prefix:
    each applied migration must have the same version and the same content hash as
    the source migration at that position. Migrations recorded in the database
    beyond the end of the source sequence are not examined.

    Raises:
        InconsistentVersionError: if versions differ at some position.
        InconsistentHashError: if hashes differ at some position.

    Returns:
        The source migrations that come after the last applied one; empty when the
        database is up to date.
    """
    for applied, expected in zip(db_migrations, source_migrations):
        applied_version = applied["version"]
        expected_version = expected["version"]
        if applied_version != expected_version:
            raise InconsistentVersionError(
                dir=applied["dir"],
                db_version=applied_version,
                source_version=expected_version,
            )

        applied_hash, expected_hash = applied["hash"], expected["hash"]
        if applied_hash != expected_hash:
            raise InconsistentHashError(
                path=applied["dir"] + "/" + applied["filename"],
                db_hash=applied_hash,
                source_hash=expected_hash,
            )

    applied_count = len(db_migrations)
    return source_migrations[applied_count:]


def find_migrations(dir: str, scope: str) -> Sequence[Migration]:
    """Return the migrations in ``dir`` whose scope equals ``scope``, sorted by
    ascending version number.

    Every ``.sql`` file in the directory is parsed before the scope filter runs, so a
    badly named ``.sql`` file raises InvalidMigrationFilename even if it would have
    belonged to another scope. Each kept file is then read and hashed.
    """
    parsed = [
        _parse_migration_filename(dir, name)
        for name in os.listdir(dir)
        if name.endswith(".sql")
    ]
    in_scope = [entry for entry in parsed if entry["scope"] == scope]
    # list.sort is stable, like sorted(): equal versions keep os.listdir order.
    in_scope.sort(key=lambda entry: entry["version"])
    return [_read_migration_file(entry) for entry in in_scope]


def _read_migration_file(file: MigrationFile) -> Migration:
    """Read a migration file"""
    sql = open(os.path.join(file["dir"], file["filename"])).read()
    hash = hashlib.md5(sql.encode("utf-8")).hexdigest()
    return {
        "hash": hash,
        "sql": sql,
        "dir": file["dir"],
        "filename": file["filename"],
        "version": file["version"],
        "scope": file["scope"],
    }