|
1 | | -import operator |
2 | 1 | from typing import TYPE_CHECKING, cast |
3 | 2 |
|
4 | | -from sqlalchemy import Column, Engine, engine_from_config, pool |
| 3 | +from sqlalchemy import Engine, engine_from_config, pool |
5 | 4 |
|
6 | 5 | from advanced_alchemy.base import metadata_registry |
7 | 6 | from alembic import context |
8 | 7 | from alembic.autogenerate import rewriter |
9 | | -from alembic.operations import ops |
10 | 8 |
|
11 | 9 | if TYPE_CHECKING: |
12 | 10 | from sqlalchemy.engine import Connection |
13 | 11 |
|
14 | 12 | from advanced_alchemy.alembic.commands import AlembicCommandConfig |
15 | | - from alembic.runtime.environment import EnvironmentContext |
16 | 13 |
|
17 | | -__all__ = ["do_run_migrations", "run_migrations_offline", "run_migrations_online"] |
| 14 | +__all__ = ("do_run_migrations", "run_migrations_offline", "run_migrations_online") |
18 | 15 |
|
19 | 16 |
|
20 | 17 | # this is the Alembic Config object, which provides |
|
23 | 20 | writer = rewriter.Rewriter() |
24 | 21 |
|
25 | 22 |
|
26 | | -@writer.rewrites(ops.CreateTableOp) |
27 | | -def order_columns( |
28 | | - context: "EnvironmentContext", # noqa: ARG001 |
29 | | - revision: tuple[str, ...], # noqa: ARG001 |
30 | | - op: ops.CreateTableOp, |
31 | | -) -> ops.CreateTableOp: |
32 | | - """Orders ID first and the audit columns at the end. |
33 | | -
|
34 | | - Args: |
35 | | - context: The context of the environment. |
36 | | - revision: The revision of the environment. |
37 | | - op: The operation to create the table. |
38 | | -
|
39 | | - Returns: |
40 | | - The operation to create the table. |
41 | | - """ |
42 | | - special_names = {"id": -100, "sa_orm_sentinel": 3001, "created_at": 3002, "updated_at": 3003} |
43 | | - cols_by_key = [ # pyright: ignore[reportUnknownVariableType] |
44 | | - ( |
45 | | - special_names.get(col.key, index) if isinstance(col, Column) else 2000, |
46 | | - col.copy(), # type: ignore[attr-defined] |
47 | | - ) |
48 | | - for index, col in enumerate(op.columns) |
49 | | - ] |
50 | | - columns = [col for _, col in sorted(cols_by_key, key=operator.itemgetter(0))] # pyright: ignore[reportUnknownVariableType,reportUnknownArgumentType,reportUnknownLambdaType] |
51 | | - return ops.CreateTableOp( |
52 | | - op.table_name, |
53 | | - columns, # pyright: ignore[reportUnknownArgumentType] |
54 | | - schema=op.schema, |
55 | | - # TODO: Remove when https://github.com/sqlalchemy/alembic/issues/1193 is fixed # noqa: FIX002 |
56 | | - _namespace_metadata=op._namespace_metadata, # noqa: SLF001 # noqa: SLF001 # pyright: ignore[reportPrivateUsage] |
57 | | - **op.kw, |
58 | | - ) |
59 | | - |
60 | | - |
61 | 23 | def run_migrations_offline() -> None: |
62 | 24 | """Run migrations in 'offline' mode. |
63 | 25 |
|
|
0 commit comments