Skip to content

Commit 3153b49

Browse files
authored
feat: add SQLModel compatibility (#686)
- **SQLModel `table=True` models** now work seamlessly with Advanced Alchemy repositories and services without requiring AA base classes - `ModelProtocol` no longer requires `to_dict()` — models only need `__mapper__`, `__table__`, and `__name__` (which all SQLAlchemy-mapped models have) - New `model_to_dict()` utility converts any mapped model to a dict using mapper column introspection - `is_schema()` family of functions now correctly **exclude** SQLModel table models (they are ORM models, not transfer schemas) - `schema_dump()` returns SQLModel table instances as-is instead of calling `model_dump()` - New `is_sqlmodel_table_model()` detection function
1 parent f5e02bb commit 3153b49

16 files changed

Lines changed: 1667 additions & 575 deletions

File tree

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ repos:
2222
- id: unasyncd
2323
additional_dependencies: ["ruff"]
2424
- repo: https://github.com/charliermarsh/ruff-pre-commit
25-
rev: "v0.15.2"
25+
rev: "v0.15.5"
2626
hooks:
2727
# Run the linter.
2828
- id: ruff
@@ -32,7 +32,7 @@ repos:
3232
- id: ruff-format
3333
types_or: [python, pyi]
3434
- repo: https://github.com/codespell-project/codespell
35-
rev: v2.4.1
35+
rev: v2.4.2
3636
hooks:
3737
- id: codespell
3838
additional_dependencies: [tomli]

README.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ offering:
3636
- Integration with major web frameworks including Litestar, Starlette, FastAPI, Sanic
3737
- Custom-built alembic configuration and CLI with optional framework integration
3838
- Utility base classes with audit columns, primary keys and utility functions
39+
- [SQLModel](https://sqlmodel.tiangolo.com/) compatibility — use `SQLModel` `table=True` models directly with repositories and services
40+
- Composite primary key support — work with multi-column primary keys across repositories, services, and bulk operations
41+
- Read/write replica routing with automatic query routing, round-robin/random replica selection, and sticky-primary mode
42+
- Dogpile caching integration for query result caching
3943
- Built in `File Object` data type for storing objects:
4044
- Unified interface for various storage backends ([`fsspec`](https://filesystem-spec.readthedocs.io/en/latest/) and [`obstore`](https://developmentseed.org/obstore/latest/))
4145
- Optional lifecycle event hooks integrated with SQLAlchemy's event system to automatically save and delete files as records are inserted, updated, or deleted.
@@ -49,7 +53,7 @@ offering:
4953
- Synchronous and asynchronous repositories featuring:
5054
- Common CRUD operations for SQLAlchemy models
5155
- Bulk inserts, updates, upserts, and deletes with dialect-specific enhancements
52-
- Integrated counts, pagination, sorting, filtering with `LIKE`, `IN`, and dates before and/or after.
56+
- Integrated counts, pagination, sorting, filtering with `LIKE`, `IN`, `IS NULL`/`IS NOT NULL`, and dates before and/or after.
5357
- Tested support for multiple database backends including:
5458
- SQLite via [aiosqlite](https://aiosqlite.omnilib.dev/en/stable/) or [sqlite](https://docs.python.org/3/library/sqlite3.html)
5559
- Postgres via [asyncpg](https://magicstack.github.io/asyncpg/current/) or [psycopg3 (async or sync)](https://www.psycopg.org/psycopg3/)

advanced_alchemy/_typing.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# ruff: noqa: RUF100
2+
"""Foundational type shims for optional dependencies.
3+
4+
Provides stub types used across the package when optional libraries
5+
(e.g. SQLModel) are not installed. This module is intentionally
6+
kept minimal and free of internal imports so that low-level modules
7+
like ``base`` can use it without reaching into higher-level packages.
8+
"""
9+
10+
from typing import TYPE_CHECKING, Any, ClassVar
11+
12+
if TYPE_CHECKING:
13+
from sqlalchemy.orm import Mapper
14+
from sqlalchemy.sql import FromClause
15+
16+
17+
class SQLModelBaseLike:
18+
"""Placeholder for sqlmodel.SQLModel when the package is not installed.
19+
20+
Declares the same structural attributes as :class:`ModelProtocol`
21+
so that type checkers can see SQLModel ``table=True`` models as
22+
protocol-compatible without requiring the real SQLModel package.
23+
"""
24+
25+
if TYPE_CHECKING:
26+
__table__: "FromClause"
27+
__mapper__: "Mapper[Any]"
28+
__name__: str
29+
30+
model_fields: ClassVar[dict[str, Any]] = {}
31+
32+
33+
try:
34+
from sqlmodel import SQLModel as SQLModelBase
35+
36+
SQLMODEL_INSTALLED: bool = True # pyright: ignore[reportConstantRedefinition]
37+
except ImportError:
38+
SQLModelBase = SQLModelBaseLike # type: ignore[assignment,misc]
39+
SQLMODEL_INSTALLED = False # pyright: ignore[reportConstantRedefinition]
40+
41+
__all__ = (
42+
"SQLMODEL_INSTALLED",
43+
"SQLModelBase",
44+
"SQLModelBaseLike",
45+
)

advanced_alchemy/alembic/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from sqlalchemy import Column, Engine, MetaData, String, Table
66
from typing_extensions import TypeIs
77

8+
from advanced_alchemy.base import model_to_dict
89
from advanced_alchemy.exceptions import MissingDependencyError
910
from advanced_alchemy.utils.sync_tools import async_
1011

@@ -114,7 +115,7 @@ def _dump_table_sync(session: "AbstractContextManager[Session]") -> None:
114115
(SQLAlchemySyncRepository,),
115116
exec_body=lambda ns, model=model: ns.setdefault("model_type", model), # type: ignore[misc]
116117
)
117-
json_path.write_text(encode_json([row.to_dict() for row in repo(session=_session).list()]))
118+
json_path.write_text(encode_json([model_to_dict(row) for row in repo(session=_session).list()]))
118119

119120
async def _dump_table_async(session: "AbstractAsyncContextManager[AsyncSession]") -> None:
120121
from advanced_alchemy.repository import SQLAlchemyAsyncRepository
@@ -132,7 +133,7 @@ async def _dump_table_async(session: "AbstractAsyncContextManager[AsyncSession]"
132133
(SQLAlchemyAsyncRepository,),
133134
exec_body=lambda ns, model=model: ns.setdefault("model_type", model), # type: ignore[misc]
134135
)
135-
json_path.write_text(encode_json([row.to_dict() for row in await repo(session=_session).list()]))
136+
json_path.write_text(encode_json([model_to_dict(row) for row in await repo(session=_session).list()]))
136137

137138
await async_(dump_dir.mkdir)(exist_ok=True)
138139

advanced_alchemy/base.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from sqlalchemy.orm import (
1313
DeclarativeBase,
1414
Mapper,
15+
class_mapper,
1516
declared_attr,
1617
)
1718
from sqlalchemy.orm import (
@@ -72,6 +73,7 @@
7273
"create_registry",
7374
"merge_table_arguments",
7475
"metadata_registry",
76+
"model_to_dict",
7577
"orm_registry",
7678
"table_name_regexp",
7779
)
@@ -143,6 +145,9 @@ def merge_table_arguments(cls: type[DeclarativeBase], table_args: Optional[Table
143145
class ModelProtocol(Protocol):
144146
"""The base SQLAlchemy model protocol.
145147
148+
Defines the minimal contract for a SQLAlchemy-mapped model. Any class with
149+
a mapper and table (including SQLModel ``table=True`` models) satisfies this protocol.
150+
146151
Attributes:
147152
__table__ (:class:`sqlalchemy.sql.FromClause`): The table associated with the model.
148153
__mapper__ (:class:`sqlalchemy.orm.Mapper`): The mapper for the model.
@@ -154,13 +159,37 @@ class ModelProtocol(Protocol):
154159
__mapper__: Mapper[Any]
155160
__name__: str
156161

157-
def to_dict(self, exclude: Optional[set[str]] = None) -> dict[str, Any]:
158-
"""Convert model to dictionary.
159162

160-
Returns:
161-
Dict[str, Any]: A dict representation of the model
162-
"""
163-
...
163+
def model_to_dict(instance: "ModelProtocol", exclude: Optional[set[str]] = None) -> dict[str, Any]:
164+
"""Convert a mapped model instance to a dictionary.
165+
166+
Works with any SQLAlchemy-mapped model, including:
167+
- Advanced Alchemy models (delegates to ``to_dict()``)
168+
- SQLModel ``table=True`` models (uses mapper-based column iteration)
169+
- Any other mapped class
170+
171+
Args:
172+
instance: A SQLAlchemy-mapped model instance.
173+
exclude: Optional set of field names to exclude from the output.
174+
175+
Returns:
176+
A dictionary of column names to values.
177+
"""
178+
to_dict_fn = getattr(instance, "to_dict", None)
179+
if to_dict_fn is not None and callable(to_dict_fn):
180+
return to_dict_fn(exclude=exclude) # type: ignore[no-any-return]
181+
182+
exclude_fields: set[str] = {"sa_orm_sentinel", "_sentinel"}
183+
with contextlib.suppress(AttributeError):
184+
exclude_fields = exclude_fields.union(cast("set[str]", instance._sa_instance_state.unloaded)) # type: ignore[attr-defined,union-attr] # noqa: SLF001
185+
if exclude:
186+
exclude_fields = exclude_fields.union(exclude)
187+
mapper = class_mapper(type(instance))
188+
return {
189+
field: getattr(instance, field)
190+
for field in mapper.columns.keys() # noqa: SIM118
191+
if field not in exclude_fields
192+
}
164193

165194

166195
class BasicAttributes:

advanced_alchemy/repository/_async.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from sqlalchemy.sql.dml import ReturningDelete, ReturningUpdate
4646
from sqlalchemy.sql.selectable import ForUpdateArg, ForUpdateParameter
4747

48+
from advanced_alchemy.base import model_to_dict
4849
from advanced_alchemy.exceptions import ErrorMessages, NotFoundError, RepositoryError, wrap_sqlalchemy_exception
4950
from advanced_alchemy.filters import StatementFilter, StatementTypeT
5051
from advanced_alchemy.repository._util import (
@@ -2351,10 +2352,7 @@ async def update_many(
23512352
supports_updated_at = hasattr(self.model_type, "updated_at")
23522353
data_to_update: List[dict[str, Any]] = []
23532354
for v in data:
2354-
if isinstance(v, self.model_type) or (hasattr(v, "to_dict") and callable(v.to_dict)):
2355-
update_payload = v.to_dict()
2356-
else:
2357-
update_payload = cast("dict[str, Any]", schema_dump(v))
2355+
update_payload = model_to_dict(v) if hasattr(v, "__mapper__") else schema_dump(cast("dict[str, Any]", v))
23582356

23592357
if supports_updated_at and (update_payload.get("updated_at") is None):
23602358
update_payload["updated_at"] = datetime.datetime.now(datetime.timezone.utc)
@@ -2824,7 +2822,7 @@ async def upsert(
28242822
else:
28252823
# Exclude all PK columns when matching by non-PK fields
28262824
exclude_cols = set(self._pk_attr_names) if self.has_composite_pk else {self.id_attribute}
2827-
match_filter = data.to_dict(exclude=exclude_cols)
2825+
match_filter = model_to_dict(data, exclude=exclude_cols)
28282826
existing = await self.get_one_or_none(
28292827
load=load, execution_options=execution_options, bind_group=bind_group, **match_filter
28302828
)
@@ -2841,7 +2839,7 @@ async def upsert(
28412839
):
28422840
# Exclude all PK columns when copying field values
28432841
exclude_cols = set(self._pk_attr_names) if self.has_composite_pk else {self.id_attribute}
2844-
for field_name, new_field_value in data.to_dict(exclude=exclude_cols).items():
2842+
for field_name, new_field_value in model_to_dict(data, exclude=exclude_cols).items():
28452843
field = getattr(existing, field_name, MISSING)
28462844
if field is not MISSING and not compare_values(field, new_field_value): # pragma: no cover
28472845
setattr(existing, field_name, new_field_value)

advanced_alchemy/repository/_sync.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from sqlalchemy.sql.dml import ReturningDelete, ReturningUpdate
4747
from sqlalchemy.sql.selectable import ForUpdateArg, ForUpdateParameter
4848

49+
from advanced_alchemy.base import model_to_dict
4950
from advanced_alchemy.exceptions import ErrorMessages, NotFoundError, RepositoryError, wrap_sqlalchemy_exception
5051
from advanced_alchemy.filters import StatementFilter, StatementTypeT
5152
from advanced_alchemy.repository._util import (
@@ -2346,10 +2347,7 @@ def update_many(
23462347
supports_updated_at = hasattr(self.model_type, "updated_at")
23472348
data_to_update: List[dict[str, Any]] = []
23482349
for v in data:
2349-
if isinstance(v, self.model_type) or (hasattr(v, "to_dict") and callable(v.to_dict)):
2350-
update_payload = v.to_dict()
2351-
else:
2352-
update_payload = cast("dict[str, Any]", schema_dump(v))
2350+
update_payload = model_to_dict(v) if hasattr(v, "__mapper__") else schema_dump(cast("dict[str, Any]", v))
23532351

23542352
if supports_updated_at and (update_payload.get("updated_at") is None):
23552353
update_payload["updated_at"] = datetime.datetime.now(datetime.timezone.utc)
@@ -2817,7 +2815,7 @@ def upsert(
28172815
else:
28182816
# Exclude all PK columns when matching by non-PK fields
28192817
exclude_cols = set(self._pk_attr_names) if self.has_composite_pk else {self.id_attribute}
2820-
match_filter = data.to_dict(exclude=exclude_cols)
2818+
match_filter = model_to_dict(data, exclude=exclude_cols)
28212819
existing = self.get_one_or_none(
28222820
load=load, execution_options=execution_options, bind_group=bind_group, **match_filter
28232821
)
@@ -2834,7 +2832,7 @@ def upsert(
28342832
):
28352833
# Exclude all PK columns when copying field values
28362834
exclude_cols = set(self._pk_attr_names) if self.has_composite_pk else {self.id_attribute}
2837-
for field_name, new_field_value in data.to_dict(exclude=exclude_cols).items():
2835+
for field_name, new_field_value in model_to_dict(data, exclude=exclude_cols).items():
28382836
field = getattr(existing, field_name, MISSING)
28392837
if field is not MISSING and not compare_values(field, new_field_value): # pragma: no cover
28402838
setattr(existing, field_name, new_field_value)

advanced_alchemy/service/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
is_schema_or_dict_without_field,
5151
is_schema_with_field,
5252
is_schema_without_field,
53+
is_sqlmodel_table_model,
5354
schema_dump,
5455
)
5556

@@ -99,6 +100,7 @@
99100
"is_schema_or_dict_without_field",
100101
"is_schema_with_field",
101102
"is_schema_without_field",
103+
"is_sqlmodel_table_model",
102104
"model_from_dict",
103105
"schema_dump",
104106
)

advanced_alchemy/service/_async.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from sqlalchemy.sql.selectable import ForUpdateParameter
1919
from typing_extensions import Self
2020

21+
from advanced_alchemy.base import ModelProtocol, model_to_dict
2122
from advanced_alchemy.config.asyncio import SQLAlchemyAsyncConfig
2223
from advanced_alchemy.exceptions import AdvancedAlchemyError, ErrorMessages, ImproperConfigurationError, RepositoryError
2324
from advanced_alchemy.filters import StatementFilter
@@ -39,6 +40,7 @@
3940
is_dto_data,
4041
is_msgspec_struct,
4142
is_pydantic_model,
43+
is_sqlmodel_table_model,
4244
)
4345
from advanced_alchemy.utils.dataclass import Empty, EmptyType
4446

@@ -477,8 +479,12 @@ async def to_model(
477479
}
478480
if operation and (op := operation_map.get(operation)):
479481
data = await op(data)
482+
if isinstance(data, self.model_type):
483+
return data
480484
if is_dict(data):
481485
return model_from_dict(self.model_type, **data)
486+
if is_sqlmodel_table_model(data):
487+
return model_from_dict(self.model_type, **model_to_dict(cast("ModelProtocol", data)))
482488
if is_pydantic_model(data):
483489
return model_from_dict(
484490
self.model_type,
@@ -1092,7 +1098,7 @@ async def get_or_upsert(
10921098
execution_options=execution_options,
10931099
uniquify=self._get_uniquify(uniquify),
10941100
bind_group=bind_group,
1095-
**validated_model.to_dict(),
1101+
**model_to_dict(validated_model),
10961102
),
10971103
)
10981104

@@ -1154,7 +1160,7 @@ async def get_and_update(
11541160
execution_options=execution_options,
11551161
uniquify=self._get_uniquify(uniquify),
11561162
bind_group=bind_group,
1157-
**validated_model.to_dict(),
1163+
**model_to_dict(validated_model),
11581164
),
11591165
)
11601166

advanced_alchemy/service/_sync.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from sqlalchemy.sql.selectable import ForUpdateParameter
2020
from typing_extensions import Self
2121

22+
from advanced_alchemy.base import ModelProtocol, model_to_dict
2223
from advanced_alchemy.config.sync import SQLAlchemySyncConfig
2324
from advanced_alchemy.exceptions import AdvancedAlchemyError, ErrorMessages, ImproperConfigurationError, RepositoryError
2425
from advanced_alchemy.filters import StatementFilter
@@ -38,6 +39,7 @@
3839
is_dto_data,
3940
is_msgspec_struct,
4041
is_pydantic_model,
42+
is_sqlmodel_table_model,
4143
)
4244
from advanced_alchemy.utils.dataclass import Empty, EmptyType
4345

@@ -476,8 +478,12 @@ def to_model(
476478
}
477479
if operation and (op := operation_map.get(operation)):
478480
data = op(data)
481+
if isinstance(data, self.model_type):
482+
return data
479483
if is_dict(data):
480484
return model_from_dict(self.model_type, **data)
485+
if is_sqlmodel_table_model(data):
486+
return model_from_dict(self.model_type, **model_to_dict(cast("ModelProtocol", data)))
481487
if is_pydantic_model(data):
482488
return model_from_dict(
483489
self.model_type,
@@ -1091,7 +1097,7 @@ def get_or_upsert(
10911097
execution_options=execution_options,
10921098
uniquify=self._get_uniquify(uniquify),
10931099
bind_group=bind_group,
1094-
**validated_model.to_dict(),
1100+
**model_to_dict(validated_model),
10951101
),
10961102
)
10971103

@@ -1153,7 +1159,7 @@ def get_and_update(
11531159
execution_options=execution_options,
11541160
uniquify=self._get_uniquify(uniquify),
11551161
bind_group=bind_group,
1156-
**validated_model.to_dict(),
1162+
**model_to_dict(validated_model),
11571163
),
11581164
)
11591165

0 commit comments

Comments
 (0)