Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 21 additions & 9 deletions advanced_alchemy/extensions/litestar/dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
MappedColumn,
NotExtension,
QueryableAttribute,
Relationship,
RelationshipDirection,
RelationshipProperty,
WriteOnlyMapped,
Expand Down Expand Up @@ -145,28 +146,39 @@ def _(
model_type_hints: dict[str, FieldDefinition],
model_name: str,
) -> list[DTOFieldDefinition]:
if not isinstance(orm_descriptor, QueryableAttribute):
if not isinstance(orm_descriptor, QueryableAttribute): # pragma: no cover
msg = f"Unexpected descriptor type for '{extension_type}': '{orm_descriptor}'"
raise NotImplementedError(msg)

elem: ElementType
if isinstance(orm_descriptor.property, ColumnProperty): # pyright: ignore[reportUnknownMemberType]
if not isinstance(orm_descriptor.property.expression, (Column, ColumnClause, Label)): # pyright: ignore[reportUnknownMemberType]
if isinstance(
orm_descriptor.property, # pyright: ignore[reportUnknownMemberType]
ColumnProperty, # pragma: no cover
):
if not isinstance(
orm_descriptor.property.expression, # pyright: ignore[reportUnknownMemberType]
(Column, ColumnClause, Label),
):
msg = f"Expected 'Column', got: '{orm_descriptor.property.expression}, {type(orm_descriptor.property.expression)}'" # pyright: ignore[reportUnknownMemberType]
raise NotImplementedError(msg)
elem = orm_descriptor.property.expression # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
elif isinstance(orm_descriptor.property, (RelationshipProperty, CompositeProperty)): # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
elem = orm_descriptor.property # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
else:
else: # pragma: no cover
msg = f"Unhandled property type: '{orm_descriptor.property}'" # pyright: ignore[reportUnknownMemberType]
raise NotImplementedError(msg)

default, default_factory = _detect_defaults(elem)

try:
if (field_definition := model_type_hints[key]).origin in {Mapped, WriteOnlyMapped, DynamicMapped}:
if (field_definition := model_type_hints[key]).origin in {
Mapped,
WriteOnlyMapped,
DynamicMapped,
Relationship,
}:
(field_definition,) = field_definition.inner_types
else:
else: # pragma: no cover
msg = f"Expected 'Mapped' origin, got: '{field_definition.origin}'"
raise NotImplementedError(msg)
except KeyError:
Expand Down Expand Up @@ -201,13 +213,13 @@ def _(
model_type_hints: dict[str, FieldDefinition],
model_name: str,
) -> list[DTOFieldDefinition]:
if not isinstance(orm_descriptor, AssociationProxy):
if not isinstance(orm_descriptor, AssociationProxy): # pragma: no cover
msg = f"Unexpected descriptor type '{orm_descriptor}' for '{extension_type}'"
raise NotImplementedError(msg)

if (field_definition := model_type_hints[key]).origin is AssociationProxy:
(field_definition,) = field_definition.inner_types
else:
else: # pragma: no cover
msg = f"Expected 'AssociationProxy' origin, got: '{field_definition.origin}'"
raise NotImplementedError(msg)

Expand Down Expand Up @@ -299,7 +311,7 @@ def generate_field_definitions(cls, model_type: type[DeclarativeBase]) -> Genera
# for each method name it is bound to. We only need to see it once, so track views of it here.
seen_hybrid_descriptors: set[hybrid_property] = set() # pyright: ignore[reportUnknownVariableType,reportMissingTypeArgument]
skipped_descriptors: set[str] = set()
for composite_property in mapper.composites:
for composite_property in mapper.composites: # pragma: no cover
for attr in composite_property.attrs:
if isinstance(attr, (MappedColumn, Column)):
skipped_descriptors.add(attr.name)
Expand Down
45 changes: 45 additions & 0 deletions tests/unit/test_extensions/test_litestar/test_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,51 @@ class B(Base):
assert all(isinstance(val, module.A) for val in model.a)


async def test_to_mapped_model_with_relationship_type_hint(
base: type[DeclarativeBase],
create_module: Callable[[str], ModuleType],
asgi_connection: Request[Any, Any, Any],
) -> None:
"""Test building a DTO with collection relationship, and parsing data."""

module = create_module(
"""
from __future__ import annotations

from typing import Dict, List, Set, Tuple, Type, List

from sqlalchemy import ForeignKey, Integer
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship, Relationship
from typing_extensions import Annotated

from advanced_alchemy.extensions.litestar.dto import SQLAlchemyDTO, SQLAlchemyDTOConfig

class Base(DeclarativeBase):
id: Mapped[int] = mapped_column(primary_key=True)

class A(Base):
__tablename__ = "a"
b_id: Mapped[int] = mapped_column(ForeignKey("b.id"))

class B(Base):
__tablename__ = "b"
a: Relationship[List[A]] = relationship("A")

dto_type = SQLAlchemyDTO[Annotated[B, SQLAlchemyDTOConfig()]]
""",
)

model = await get_model_from_dto(
module.dto_type,
module.B,
asgi_connection,
b'{"id": 1, "a": [{"id": 2, "b_id": 1}, {"id": 3, "b_id": 1}]}',
)
assert isinstance(model, module.B)
assert len(model.a) == 2
assert all(isinstance(val, module.A) for val in model.a)


async def test_to_mapped_model_with_scalar_relationship(
create_module: Callable[[str], ModuleType],
asgi_connection: Request[Any, Any, Any],
Expand Down