Skip to content

Commit 453f695

Browse files
authored
feat(litestar): use property in SQLAlchemyDTO with MappedAsDataclass (#447)
Allow better compatibility with `MappedAsDataclass` and the `SQLALchemyDTO`
1 parent a22a285 commit 453f695

3 files changed

Lines changed: 251 additions & 6 deletions

File tree

advanced_alchemy/extensions/litestar/dto.py

Lines changed: 91 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
# ruff: noqa: C901
2+
import inspect as stdlib_inspect
3+
import logging
14
from collections.abc import Collection, Generator
25
from collections.abc import Set as AbstractSet
36
from dataclasses import asdict, dataclass, field, replace
4-
from functools import singledispatchmethod
7+
from functools import cached_property, singledispatchmethod
58
from typing import (
69
Any,
710
ClassVar,
@@ -44,6 +47,8 @@
4447

4548
__all__ = ("SQLAlchemyDTO",)
4649

50+
logger = logging.getLogger(__name__)
51+
4752
T = TypeVar("T", bound="Union[DeclarativeBase, Collection[DeclarativeBase]]")
4853

4954
ElementType: TypeAlias = Union[
@@ -231,7 +236,9 @@ def _(
231236
default=Empty,
232237
),
233238
default_factory=None,
234-
dto_field=orm_descriptor.info.get(DTO_FIELD_META_KEY, DTOField(mark=Mark.READ_ONLY)),
239+
dto_field=orm_descriptor.info.get(
240+
DTO_FIELD_META_KEY, DTOField(mark=Mark.READ_ONLY)
241+
), # Mark as read-only
235242
model_name=model_name,
236243
),
237244
]
@@ -260,7 +267,9 @@ def _(
260267
default=Empty,
261268
),
262269
default_factory=None,
263-
dto_field=orm_descriptor.info.get(DTO_FIELD_META_KEY, DTOField(mark=Mark.READ_ONLY)),
270+
dto_field=orm_descriptor.info.get(
271+
DTO_FIELD_META_KEY, DTOField(mark=Mark.READ_ONLY)
272+
), # Mark as read-only
264273
model_name=model_name,
265274
),
266275
]
@@ -275,13 +284,61 @@ def _(
275284
default=Empty,
276285
),
277286
default_factory=None,
278-
dto_field=orm_descriptor.info.get(DTO_FIELD_META_KEY, DTOField(mark=Mark.WRITE_ONLY)),
287+
dto_field=orm_descriptor.info.get(
288+
DTO_FIELD_META_KEY, DTOField(mark=Mark.WRITE_ONLY)
289+
), # Mark as read-only
279290
model_name=model_name,
280291
),
281292
)
282293

283294
return field_defs
284295

296+
@classmethod
297+
def get_property_fields(cls, model_type: "type[DeclarativeBase]") -> "dict[str, FieldDefinition]":
298+
"""Get fields defined as @property or @cached_property on the model.
299+
300+
Uses inspect.getmembers() to detect properties from the model class and mixins.
301+
Properties are marked read-only; setter support is not implemented.
302+
303+
Args:
304+
model_type: The SQLAlchemy model type to extract properties from.
305+
306+
Returns:
307+
A dictionary mapping property names to their field definitions.
308+
"""
309+
namespace = cls.get_model_namespace(model_type)
310+
sqla_internal_properties = {"awaitable_attrs", "registry", "metadata"}
311+
312+
properties: dict[str, FieldDefinition] = {}
313+
for name, member in stdlib_inspect.getmembers(
314+
model_type, predicate=lambda x: isinstance(x, (property, cached_property))
315+
):
316+
if name in sqla_internal_properties:
317+
continue
318+
319+
if isinstance(member, cached_property):
320+
func = member.func
321+
elif isinstance(member, property):
322+
if member.fget is None:
323+
continue
324+
func = member.fget
325+
else:
326+
continue
327+
328+
try:
329+
sig = ParsedSignature.from_fn(func, namespace)
330+
properties[name] = replace(sig.return_type, name=name)
331+
except (AttributeError, TypeError, ValueError) as e:
332+
logger.debug(
333+
"could not parse type hint for property %s.%s: %s, using Any type",
334+
model_type.__name__,
335+
name,
336+
e,
337+
)
338+
properties[name] = FieldDefinition.from_annotation(Any, name=name)
339+
340+
return properties
341+
285342
@classmethod
286343
def generate_field_definitions(cls, model_type: type[DeclarativeBase]) -> Generator[DTOFieldDefinition, None, None]:
287344
"""Generate DTO field definitions from a SQLAlchemy model.
@@ -315,6 +372,9 @@ def generate_field_definitions(cls, model_type: type[DeclarativeBase]) -> Genera
315372
skipped_descriptors.add(attr.name)
316373
elif isinstance(attr, str):
317374
skipped_descriptors.add(attr)
375+
376+
yielded_sqla_keys: set[str] = set() # Keep track of keys yielded by SQLAlchemy logic
377+
318378
for key, orm_descriptor in mapper.all_orm_descriptors.items():
319379
if is_hybrid_property := isinstance(orm_descriptor, hybrid_property):
320380
if orm_descriptor in seen_hybrid_descriptors:
@@ -328,7 +388,12 @@ def generate_field_definitions(cls, model_type: type[DeclarativeBase]) -> Genera
328388
should_skip_descriptor = False
329389
dto_field: Optional[DTOField] = None
330390
if hasattr(orm_descriptor, "property"): # pyright: ignore[reportUnknownArgumentType]
331-
dto_field = orm_descriptor.property.info.get(DTO_FIELD_META_KEY) # pyright: ignore
391+
# Access info safely, checking if property exists first
392+
prop = getattr(orm_descriptor, "property", None) # pyright: ignore[reportUnknownArgumentType]
393+
if prop and hasattr(prop, "info"):
394+
dto_field = prop.info.get(DTO_FIELD_META_KEY)
395+
elif hasattr(orm_descriptor, "info"): # pyright: ignore[reportUnknownArgumentType]
396+
dto_field = orm_descriptor.info.get(DTO_FIELD_META_KEY) # pyright: ignore[reportUnknownArgumentType,reportUnknownMemberType,reportAttributeAccessIssue,reportUnknownVariableType]
332397

333398
# Case 1
334399
is_field_marked_not_private = dto_field and dto_field.mark is not Mark.PRIVATE # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType]
@@ -353,13 +418,33 @@ def generate_field_definitions(cls, model_type: type[DeclarativeBase]) -> Genera
353418
if should_skip_descriptor:
354419
continue
355420

356-
yield from cls.handle_orm_descriptor(
421+
# Yield definitions from SQLAlchemy descriptor handling
422+
definitions = cls.handle_orm_descriptor(
357423
orm_descriptor.extension_type,
358424
key,
359425
orm_descriptor,
360426
model_type_hints,
361427
model_name,
362428
)
429+
for definition in definitions:
430+
yielded_sqla_keys.add(definition.name) # Track yielded key
431+
yield definition
432+
433+
property_fields = cls.get_property_fields(model_type)
434+
for key, property_field_definition in property_fields.items():
435+
if key.startswith("_") or key in yielded_sqla_keys:
436+
continue
437+
438+
yield DTOFieldDefinition.from_field_definition(
439+
field_definition=replace(
440+
property_field_definition,
441+
name=key,
442+
default=Empty,
443+
),
444+
model_name=model_name,
445+
default_factory=None,
446+
dto_field=DTOField(mark=Mark.READ_ONLY),
447+
)
363448

364449
@classmethod
365450
def detect_nested_field(cls, field_definition: FieldDefinition) -> bool:

docs/usage/frameworks/litestar.rst

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,47 @@ Define your SQLAlchemy models using Advanced Alchemy's enhanced base classes:
7070
author_id: Mapped[UUID] = mapped_column(ForeignKey("author.id"))
7171
author: Mapped[AuthorModel] = relationship(lazy="joined", innerjoin=True, viewonly=True)
7272
73+
Using Properties with DTOs
74+
---------------------------
75+
76+
SQLAlchemyDTO includes Python ``@property`` and ``@functools.cached_property`` decorated methods as read-only fields.
77+
78+
.. code-block:: python
79+
80+
from functools import cached_property
81+
from sqlalchemy.orm import Mapped, mapped_column, MappedAsDataclass
82+
from advanced_alchemy.extensions.litestar import base, SQLAlchemyDTO
83+
84+
class UserModel(base.UUIDAuditBase, MappedAsDataclass):
85+
__tablename__ = "user"
86+
87+
first_name: Mapped[str]
88+
last_name: Mapped[str]
89+
90+
@property
91+
def full_name(self) -> str:
92+
return f"{self.first_name} {self.last_name}"
93+
94+
@cached_property
95+
def name_length(self) -> int:
96+
return len(self.full_name)
97+
98+
# DTO includes: id, created_at, updated_at, first_name, last_name,
99+
# full_name (read-only), name_length (read-only)
100+
UserDTO = SQLAlchemyDTO[UserModel]
101+
102+
Property handling characteristics:
103+
104+
- Detected from model class and mixins
105+
- Marked as ``READ_ONLY`` (cannot be set via DTO)
106+
- Type inferred from return type annotations
107+
- Private properties (starting with ``_``) excluded
108+
- Skipped if already handled by SQLAlchemy descriptors (e.g., ``hybrid_property``)
109+
110+
.. note::
111+
112+
Properties with setters (``@property.setter``) are marked ``READ_ONLY``. Setter support is not implemented.
113+
73114
Pydantic Schemas
74115
----------------
75116

tests/unit/test_extensions/test_litestar/test_dto_integration.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from dataclasses import dataclass
4+
from functools import cached_property
45
from types import ModuleType
56
from typing import Annotated, Any, Callable, Optional
67
from uuid import UUID
@@ -230,6 +231,124 @@ async def get_handler() -> ModelWithFunc:
230231
assert response_callback
231232

232233

234+
def test_dto_includes_simple_property() -> None:
235+
"""Test that @property decorated methods appear in DTO as read-only fields."""
236+
237+
class ModelWithProperty(Base):
238+
__tablename__ = "model_with_property"
239+
240+
first_name: Mapped[str] = mapped_column()
241+
last_name: Mapped[str] = mapped_column()
242+
243+
@property
244+
def full_name(self) -> str:
245+
return f"{self.first_name} {self.last_name}"
246+
247+
config = SQLAlchemyDTOConfig()
248+
dto = SQLAlchemyDTO[Annotated[ModelWithProperty, config]]
249+
field_defs = list(dto.generate_field_definitions(ModelWithProperty))
250+
field_names = {f.name for f in field_defs}
251+
252+
assert "full_name" in field_names
253+
254+
full_name_field = next(f for f in field_defs if f.name == "full_name")
255+
assert full_name_field.dto_field.mark == Mark.READ_ONLY
256+
257+
258+
def test_dto_includes_cached_property() -> None:
259+
"""Test that @cached_property decorated methods appear in DTO as read-only fields."""
260+
261+
class ModelWithCachedProperty(Base):
262+
__tablename__ = "model_with_cached_property"
263+
264+
value: Mapped[int] = mapped_column()
265+
266+
@cached_property
267+
def expensive_calculation(self) -> int:
268+
return self.value * 2
269+
270+
config = SQLAlchemyDTOConfig()
271+
dto = SQLAlchemyDTO[Annotated[ModelWithCachedProperty, config]]
272+
field_defs = list(dto.generate_field_definitions(ModelWithCachedProperty))
273+
field_names = {f.name for f in field_defs}
274+
275+
assert "expensive_calculation" in field_names
276+
277+
field = next(f for f in field_defs if f.name == "expensive_calculation")
278+
assert field.dto_field.mark == Mark.READ_ONLY
279+
280+
281+
def test_dto_property_with_setter_is_read_only() -> None:
282+
"""Test that properties with setters are marked READ_ONLY (setter support not implemented)."""
283+
284+
class ModelWithPropertySetter(Base):
285+
__tablename__ = "model_with_property_setter"
286+
287+
_internal_value: Mapped[int] = mapped_column(default=0)
288+
289+
@property
290+
def value(self) -> int:
291+
return self._internal_value
292+
293+
@value.setter
294+
def value(self, new_value: int) -> None:
295+
self._internal_value = new_value
296+
297+
config = SQLAlchemyDTOConfig()
298+
dto = SQLAlchemyDTO[Annotated[ModelWithPropertySetter, config]]
299+
field_defs = list(dto.generate_field_definitions(ModelWithPropertySetter))
300+
field_names = {f.name for f in field_defs}
301+
302+
assert "value" in field_names
303+
304+
field = next(f for f in field_defs if f.name == "value")
305+
assert field.dto_field.mark == Mark.READ_ONLY
306+
307+
308+
def test_dto_skips_private_properties() -> None:
309+
"""Test that properties starting with _ are excluded from DTO."""
310+
311+
class ModelWithPrivateProperty(Base):
312+
__tablename__ = "model_with_private_property"
313+
314+
@property
315+
def public_prop(self) -> str:
316+
return "public"
317+
318+
@property
319+
def _private_prop(self) -> str:
320+
return "private"
321+
322+
config = SQLAlchemyDTOConfig()
323+
dto = SQLAlchemyDTO[Annotated[ModelWithPrivateProperty, config]]
324+
field_defs = list(dto.generate_field_definitions(ModelWithPrivateProperty))
325+
field_names = {f.name for f in field_defs}
326+
327+
assert "public_prop" in field_names
328+
assert "_private_prop" not in field_names
329+
330+
331+
def test_dto_handles_untyped_property() -> None:
332+
"""Test that properties without type hints are included with Any type."""
333+
334+
class ModelWithUntypedProperty(Base):
335+
__tablename__ = "model_with_untyped_property"
336+
337+
@property
338+
def untyped_prop(self): # type: ignore[no-untyped-def]
339+
return "value"
340+
341+
config = SQLAlchemyDTOConfig()
342+
dto = SQLAlchemyDTO[Annotated[ModelWithUntypedProperty, config]]
343+
field_defs = list(dto.generate_field_definitions(ModelWithUntypedProperty))
344+
field_names = {f.name for f in field_defs}
345+
346+
assert "untyped_prop" in field_names
347+
348+
field = next(f for f in field_defs if f.name == "untyped_prop")
349+
assert field is not None
350+
351+
233352
def test_dto_with_association_proxy(create_module: Callable[[str], ModuleType]) -> None:
234353
module = create_module(
235354
"""

0 commit comments

Comments
 (0)