Skip to content

Commit 17fa2a0

Browse files
authored
fix(dto): properly serialize Relationship type hints (#422)
Adds `sqlalchemy.orm.Relationship` to the supported type hints for the `SQLAlchemyDTO`
1 parent 47639a0 commit 17fa2a0

2 files changed

Lines changed: 66 additions & 9 deletions

File tree

  • advanced_alchemy/extensions/litestar
  • tests/unit/test_extensions/test_litestar

advanced_alchemy/extensions/litestar/dto.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
MappedColumn,
3333
NotExtension,
3434
QueryableAttribute,
35+
Relationship,
3536
RelationshipDirection,
3637
RelationshipProperty,
3738
WriteOnlyMapped,
@@ -145,28 +146,39 @@ def _(
145146
model_type_hints: dict[str, FieldDefinition],
146147
model_name: str,
147148
) -> list[DTOFieldDefinition]:
148-
if not isinstance(orm_descriptor, QueryableAttribute):
149+
if not isinstance(orm_descriptor, QueryableAttribute): # pragma: no cover
149150
msg = f"Unexpected descriptor type for '{extension_type}': '{orm_descriptor}'"
150151
raise NotImplementedError(msg)
151152

152153
elem: ElementType
153-
if isinstance(orm_descriptor.property, ColumnProperty): # pyright: ignore[reportUnknownMemberType]
154-
if not isinstance(orm_descriptor.property.expression, (Column, ColumnClause, Label)): # pyright: ignore[reportUnknownMemberType]
154+
if isinstance(
155+
orm_descriptor.property, # pyright: ignore[reportUnknownMemberType]
156+
ColumnProperty, # pragma: no cover
157+
):
158+
if not isinstance(
159+
orm_descriptor.property.expression, # pyright: ignore[reportUnknownMemberType]
160+
(Column, ColumnClause, Label),
161+
):
155162
msg = f"Expected 'Column', got: '{orm_descriptor.property.expression}, {type(orm_descriptor.property.expression)}'" # pyright: ignore[reportUnknownMemberType]
156163
raise NotImplementedError(msg)
157164
elem = orm_descriptor.property.expression # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
158165
elif isinstance(orm_descriptor.property, (RelationshipProperty, CompositeProperty)): # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
159166
elem = orm_descriptor.property # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
160-
else:
167+
else: # pragma: no cover
161168
msg = f"Unhandled property type: '{orm_descriptor.property}'" # pyright: ignore[reportUnknownMemberType]
162169
raise NotImplementedError(msg)
163170

164171
default, default_factory = _detect_defaults(elem)
165172

166173
try:
167-
if (field_definition := model_type_hints[key]).origin in {Mapped, WriteOnlyMapped, DynamicMapped}:
174+
if (field_definition := model_type_hints[key]).origin in {
175+
Mapped,
176+
WriteOnlyMapped,
177+
DynamicMapped,
178+
Relationship,
179+
}:
168180
(field_definition,) = field_definition.inner_types
169-
else:
181+
else: # pragma: no cover
170182
msg = f"Expected 'Mapped' origin, got: '{field_definition.origin}'"
171183
raise NotImplementedError(msg)
172184
except KeyError:
@@ -201,13 +213,13 @@ def _(
201213
model_type_hints: dict[str, FieldDefinition],
202214
model_name: str,
203215
) -> list[DTOFieldDefinition]:
204-
if not isinstance(orm_descriptor, AssociationProxy):
216+
if not isinstance(orm_descriptor, AssociationProxy): # pragma: no cover
205217
msg = f"Unexpected descriptor type '{orm_descriptor}' for '{extension_type}'"
206218
raise NotImplementedError(msg)
207219

208220
if (field_definition := model_type_hints[key]).origin is AssociationProxy:
209221
(field_definition,) = field_definition.inner_types
210-
else:
222+
else: # pragma: no cover
211223
msg = f"Expected 'AssociationProxy' origin, got: '{field_definition.origin}'"
212224
raise NotImplementedError(msg)
213225

@@ -299,7 +311,7 @@ def generate_field_definitions(cls, model_type: type[DeclarativeBase]) -> Genera
299311
# for each method name it is bound to. We only need to see it once, so track views of it here.
300312
seen_hybrid_descriptors: set[hybrid_property] = set() # pyright: ignore[reportUnknownVariableType,reportMissingTypeArgument]
301313
skipped_descriptors: set[str] = set()
302-
for composite_property in mapper.composites:
314+
for composite_property in mapper.composites: # pragma: no cover
303315
for attr in composite_property.attrs:
304316
if isinstance(attr, (MappedColumn, Column)):
305317
skipped_descriptors.add(attr.name)

tests/unit/test_extensions/test_litestar/test_dto.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,51 @@ class B(Base):
313313
assert all(isinstance(val, module.A) for val in model.a)
314314

315315

316+
async def test_to_mapped_model_with_relationship_type_hint(
317+
base: type[DeclarativeBase],
318+
create_module: Callable[[str], ModuleType],
319+
asgi_connection: Request[Any, Any, Any],
320+
) -> None:
321+
"""Test building a DTO with collection relationship, and parsing data."""
322+
323+
module = create_module(
324+
"""
325+
from __future__ import annotations
326+
327+
from typing import Dict, List, Set, Tuple, Type, List
328+
329+
from sqlalchemy import ForeignKey, Integer
330+
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship, Relationship
331+
from typing_extensions import Annotated
332+
333+
from advanced_alchemy.extensions.litestar.dto import SQLAlchemyDTO, SQLAlchemyDTOConfig
334+
335+
class Base(DeclarativeBase):
336+
id: Mapped[int] = mapped_column(primary_key=True)
337+
338+
class A(Base):
339+
__tablename__ = "a"
340+
b_id: Mapped[int] = mapped_column(ForeignKey("b.id"))
341+
342+
class B(Base):
343+
__tablename__ = "b"
344+
a: Relationship[List[A]] = relationship("A")
345+
346+
dto_type = SQLAlchemyDTO[Annotated[B, SQLAlchemyDTOConfig()]]
347+
""",
348+
)
349+
350+
model = await get_model_from_dto(
351+
module.dto_type,
352+
module.B,
353+
asgi_connection,
354+
b'{"id": 1, "a": [{"id": 2, "b_id": 1}, {"id": 3, "b_id": 1}]}',
355+
)
356+
assert isinstance(model, module.B)
357+
assert len(model.a) == 2
358+
assert all(isinstance(val, module.A) for val in model.a)
359+
360+
316361
async def test_to_mapped_model_with_scalar_relationship(
317362
create_module: Callable[[str], ModuleType],
318363
asgi_connection: Request[Any, Any, Any],

0 commit comments

Comments
 (0)