Skip to content

Commit e53484d

Browse files
authored
feat: enable standard order by (#438)
Enables the standard `UnaryOperator` order by support in addition to the existing `OrderingPair`
1 parent ac58063 commit e53484d

7 files changed

Lines changed: 98 additions & 85 deletions

File tree

advanced_alchemy/repository/_async.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1600,7 +1600,7 @@ async def _list_and_count_window(
16001600
execution_options=execution_options,
16011601
)
16021602
if order_by is None:
1603-
order_by = self.order_by or []
1603+
order_by = self.order_by if self.order_by is not None else []
16041604
statement = self._apply_order_by(statement=statement, order_by=order_by)
16051605
statement = self._apply_filters(*filters, statement=statement)
16061606
statement = self._filter_select_by_kwargs(statement, kwargs)
@@ -1657,7 +1657,7 @@ async def _list_and_count_basic(
16571657
execution_options=execution_options,
16581658
)
16591659
if order_by is None:
1660-
order_by = self.order_by or []
1660+
order_by = self.order_by if self.order_by is not None else []
16611661
statement = self._apply_order_by(statement=statement, order_by=order_by)
16621662
statement = self._apply_filters(*filters, statement=statement)
16631663
statement = self._filter_select_by_kwargs(statement, kwargs)
@@ -1945,7 +1945,7 @@ async def list(
19451945
execution_options=execution_options,
19461946
)
19471947
if order_by is None:
1948-
order_by = self.order_by or []
1948+
order_by = self.order_by if self.order_by is not None else []
19491949
statement = self._apply_order_by(statement=statement, order_by=order_by)
19501950
statement = self._apply_filters(*filters, statement=statement)
19511951
statement = self._filter_select_by_kwargs(statement, kwargs)

advanced_alchemy/repository/_sync.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1601,7 +1601,7 @@ def _list_and_count_window(
16011601
execution_options=execution_options,
16021602
)
16031603
if order_by is None:
1604-
order_by = self.order_by or []
1604+
order_by = self.order_by if self.order_by is not None else []
16051605
statement = self._apply_order_by(statement=statement, order_by=order_by)
16061606
statement = self._apply_filters(*filters, statement=statement)
16071607
statement = self._filter_select_by_kwargs(statement, kwargs)
@@ -1656,7 +1656,7 @@ def _list_and_count_basic(
16561656
execution_options=execution_options,
16571657
)
16581658
if order_by is None:
1659-
order_by = self.order_by or []
1659+
order_by = self.order_by if self.order_by is not None else []
16601660
statement = self._apply_order_by(statement=statement, order_by=order_by)
16611661
statement = self._apply_filters(*filters, statement=statement)
16621662
statement = self._filter_select_by_kwargs(statement, kwargs)
@@ -1944,7 +1944,7 @@ def list(
19441944
execution_options=execution_options,
19451945
)
19461946
if order_by is None:
1947-
order_by = self.order_by or []
1947+
order_by = self.order_by if self.order_by is not None else []
19481948
statement = self._apply_order_by(statement=statement, order_by=order_by)
19491949
statement = self._apply_filters(*filters, statement=statement)
19501950
statement = self._filter_select_by_kwargs(statement, kwargs)

advanced_alchemy/repository/_util.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
Delete,
66
Dialect,
77
Select,
8+
UnaryExpression,
89
Update,
910
)
1011
from sqlalchemy.orm import (
@@ -294,8 +295,8 @@ def _apply_order_by(
294295
self,
295296
statement: StatementTypeT,
296297
order_by: Union[
297-
list[tuple[Union[str, InstrumentedAttribute[Any]], bool]],
298-
tuple[Union[str, InstrumentedAttribute[Any]], bool],
298+
OrderingPair,
299+
list[OrderingPair],
299300
],
300301
) -> StatementTypeT:
301302
"""Apply ordering to a SQL statement.
@@ -311,13 +312,16 @@ def _apply_order_by(
311312
"""
312313
if not isinstance(order_by, list):
313314
order_by = [order_by]
314-
for order_field, is_desc in order_by:
315-
field = get_instrumented_attr(self.model_type, order_field)
316-
statement = self._order_by_attribute(statement, field, is_desc)
315+
for order_field in order_by:
316+
if isinstance(order_field, UnaryExpression):
317+
statement = statement.order_by(order_field) # type: ignore
318+
else:
319+
field = get_instrumented_attr(self.model_type, order_field[0])
320+
statement = self._order_by_attribute(statement, field, order_field[1])
317321
return statement
318322

323+
@staticmethod
319324
def _order_by_attribute(
320-
self,
321325
statement: StatementTypeT,
322326
field: InstrumentedAttribute[Any],
323327
is_desc: bool,

advanced_alchemy/repository/memory/_async.py

Lines changed: 37 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,8 @@ def __init__(
8585
auto_expunge: bool = False,
8686
auto_refresh: bool = True,
8787
auto_commit: bool = False,
88-
order_by: Union[list[OrderingPair], OrderingPair, None] = None,
89-
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
88+
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
89+
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
9090
wrap_exceptions: bool = True,
9191
load: Optional[LoadSpec] = None,
9292
execution_options: Optional[dict[str, Any]] = None,
@@ -115,8 +115,8 @@ def __init_subclass__(cls) -> None:
115115

116116
@staticmethod
117117
def _get_error_messages(
118-
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
119-
default_messages: Union[ErrorMessages, None, EmptyType] = Empty,
118+
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
119+
default_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
120120
) -> Optional[ErrorMessages]:
121121
if error_messages == Empty:
122122
error_messages = None
@@ -204,19 +204,20 @@ def set_id_attribute_value(
204204
def _exclude_unused_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]:
205205
return {key: value for key, value in kwargs.items() if key not in self._exclude_kwargs}
206206

207-
def _apply_limit_offset_pagination(self, result: list[ModelT], limit: int, offset: int) -> list[ModelT]:
207+
@staticmethod
208+
def _apply_limit_offset_pagination(result: list[ModelT], limit: int, offset: int) -> list[ModelT]:
208209
return result[offset:limit]
209210

211+
@staticmethod
210212
def _filter_in_collection(
211-
self,
212213
result: list[ModelT],
213214
field_name: str,
214215
values: abc.Collection[Any],
215216
) -> list[ModelT]:
216217
return [item for item in result if getattr(item, field_name) in values]
217218

219+
@staticmethod
218220
def _filter_not_in_collection(
219-
self,
220221
result: list[ModelT],
221222
field_name: str,
222223
values: abc.Collection[Any],
@@ -225,8 +226,8 @@ def _filter_not_in_collection(
225226
return result
226227
return [item for item in result if getattr(item, field_name) not in values]
227228

229+
@staticmethod
228230
def _filter_on_datetime_field(
229-
self,
230231
result: list[ModelT],
231232
field_name: str,
232233
before: Optional[datetime.datetime] = None,
@@ -247,8 +248,8 @@ def _filter_on_datetime_field(
247248
result_.append(item)
248249
return result_
249250

251+
@staticmethod
250252
def _filter_by_like(
251-
self,
252253
result: list[ModelT],
253254
field_name: Union[str, set[str]],
254255
value: str,
@@ -267,8 +268,8 @@ def _filter_by_like(
267268
)
268269
return list(set(items))
269270

271+
@staticmethod
270272
def _filter_by_not_like(
271-
self,
272273
result: list[ModelT],
273274
field_name: Union[str, set[str]],
274275
value: str,
@@ -294,13 +295,14 @@ def _filter_result_by_kwargs(
294295
kwargs: Union[dict[Any, Any], Iterable[tuple[Any, Any]]],
295296
) -> list[ModelT]:
296297
kwargs_: dict[Any, Any] = kwargs if isinstance(kwargs, dict) else dict(*kwargs)
297-
kwargs_ = self._exclude_unused_kwargs(kwargs_)
298+
kwargs_ = self._exclude_unused_kwargs(kwargs_) # pyright: ignore
298299
try:
299-
return [item for item in result if all(getattr(item, field) == value for field, value in kwargs_.items())]
300+
return [item for item in result if all(getattr(item, field) == value for field, value in kwargs_.items())] # pyright: ignore
300301
except AttributeError as error:
301302
raise RepositoryError from error
302303

303-
def _order_by(self, result: list[ModelT], field_name: str, sort_desc: bool = False) -> list[ModelT]:
304+
@staticmethod
305+
def _order_by(result: list[ModelT], field_name: str, sort_desc: bool = False) -> list[ModelT]:
304306
return sorted(result, key=lambda item: getattr(item, field_name), reverse=sort_desc)
305307

306308
def _apply_filters(
@@ -418,7 +420,7 @@ async def get(
418420
auto_expunge: Optional[bool] = None,
419421
statement: Union[Select[tuple[ModelT]], StatementLambdaElement, None] = None,
420422
id_attribute: Union[str, InstrumentedAttribute[Any], None] = None,
421-
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
423+
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
422424
load: Optional[LoadSpec] = None,
423425
execution_options: Optional[dict[str, Any]] = None,
424426
uniquify: Optional[bool] = None,
@@ -430,7 +432,7 @@ async def get_one(
430432
*filters: Union[StatementFilter, ColumnElement[bool]],
431433
auto_expunge: Optional[bool] = None,
432434
statement: Union[Select[tuple[ModelT]], StatementLambdaElement, None] = None,
433-
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
435+
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
434436
load: Optional[LoadSpec] = None,
435437
execution_options: Optional[dict[str, Any]] = None,
436438
uniquify: Optional[bool] = None,
@@ -443,7 +445,7 @@ async def get_one_or_none(
443445
*filters: Union[StatementFilter, ColumnElement[bool]],
444446
auto_expunge: Optional[bool] = None,
445447
statement: Union[Select[tuple[ModelT]], StatementLambdaElement, None] = None,
446-
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
448+
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
447449
load: Optional[LoadSpec] = None,
448450
execution_options: Optional[dict[str, Any]] = None,
449451
uniquify: Optional[bool] = None,
@@ -465,7 +467,7 @@ async def get_or_upsert(
465467
auto_commit: Optional[bool] = None,
466468
auto_expunge: Optional[bool] = None,
467469
auto_refresh: Optional[bool] = None,
468-
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
470+
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
469471
load: Optional[LoadSpec] = None,
470472
execution_options: Optional[dict[str, Any]] = None,
471473
uniquify: Optional[bool] = None,
@@ -501,7 +503,7 @@ async def get_and_update(
501503
auto_commit: Optional[bool] = None,
502504
auto_expunge: Optional[bool] = None,
503505
auto_refresh: Optional[bool] = None,
504-
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
506+
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
505507
load: Optional[LoadSpec] = None,
506508
execution_options: Optional[dict[str, Any]] = None,
507509
uniquify: Optional[bool] = None,
@@ -552,7 +554,7 @@ async def add(
552554
auto_commit: Optional[bool] = None,
553555
auto_expunge: Optional[bool] = None,
554556
auto_refresh: Optional[bool] = None,
555-
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
557+
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
556558
) -> ModelT:
557559
try:
558560
self.__database__.add(self.model_type, data)
@@ -567,7 +569,7 @@ async def add_many(
567569
*,
568570
auto_commit: Optional[bool] = None,
569571
auto_expunge: Optional[bool] = None,
570-
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
572+
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
571573
) -> list[ModelT]:
572574
for obj in data:
573575
await self.add(obj) # pyright: ignore[reportCallIssue]
@@ -582,8 +584,8 @@ async def update(
582584
auto_commit: Optional[bool] = None,
583585
auto_expunge: Optional[bool] = None,
584586
auto_refresh: Optional[bool] = None,
585-
id_attribute: Union[str, InstrumentedAttribute[Any], None] = None,
586-
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
587+
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
588+
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
587589
load: Optional[LoadSpec] = None,
588590
execution_options: Optional[dict[str, Any]] = None,
589591
uniquify: Optional[bool] = None,
@@ -597,7 +599,7 @@ async def update_many(
597599
*,
598600
auto_commit: Optional[bool] = None,
599601
auto_expunge: Optional[bool] = None,
600-
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
602+
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
601603
load: Optional[LoadSpec] = None,
602604
execution_options: Optional[dict[str, Any]] = None,
603605
uniquify: Optional[bool] = None,
@@ -610,8 +612,8 @@ async def delete(
610612
*,
611613
auto_commit: Optional[bool] = None,
612614
auto_expunge: Optional[bool] = None,
613-
id_attribute: Union[str, InstrumentedAttribute[Any], None] = None,
614-
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
615+
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
616+
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
615617
load: Optional[LoadSpec] = None,
616618
execution_options: Optional[dict[str, Any]] = None,
617619
uniquify: Optional[bool] = None,
@@ -627,9 +629,9 @@ async def delete_many(
627629
*,
628630
auto_commit: Optional[bool] = None,
629631
auto_expunge: Optional[bool] = None,
630-
id_attribute: Union[str, InstrumentedAttribute[Any], None] = None,
632+
id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
631633
chunk_size: Optional[int] = None,
632-
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
634+
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
633635
load: Optional[LoadSpec] = None,
634636
execution_options: Optional[dict[str, Any]] = None,
635637
uniquify: Optional[bool] = None,
@@ -646,7 +648,7 @@ async def delete_where(
646648
*filters: Union[StatementFilter, ColumnElement[bool]],
647649
auto_commit: Optional[bool] = None,
648650
auto_expunge: Optional[bool] = None,
649-
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
651+
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
650652
sanity_check: bool = True,
651653
load: Optional[LoadSpec] = None,
652654
execution_options: Optional[dict[str, Any]] = None,
@@ -668,8 +670,8 @@ async def upsert(
668670
auto_expunge: Optional[bool] = None,
669671
auto_commit: Optional[bool] = None,
670672
auto_refresh: Optional[bool] = None,
671-
match_fields: Union[list[str], str, None] = None,
672-
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
673+
match_fields: Optional[Union[list[str], str]] = None,
674+
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
673675
load: Optional[LoadSpec] = None,
674676
execution_options: Optional[dict[str, Any]] = None,
675677
uniquify: Optional[bool] = None,
@@ -686,8 +688,8 @@ async def upsert_many(
686688
auto_expunge: Optional[bool] = None,
687689
auto_commit: Optional[bool] = None,
688690
no_merge: bool = False,
689-
match_fields: Union[list[str], str, None] = None,
690-
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
691+
match_fields: Optional[Union[list[str], str]] = None,
692+
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
691693
load: Optional[LoadSpec] = None,
692694
execution_options: Optional[dict[str, Any]] = None,
693695
uniquify: Optional[bool] = None,
@@ -700,8 +702,8 @@ async def list_and_count(
700702
statement: Union[Select[tuple[ModelT]], StatementLambdaElement, None] = None,
701703
auto_expunge: Optional[bool] = None,
702704
count_with_window_function: Optional[bool] = None,
703-
order_by: Union[list[OrderingPair], OrderingPair, None] = None,
704-
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
705+
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
706+
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
705707
load: Optional[LoadSpec] = None,
706708
execution_options: Optional[dict[str, Any]] = None,
707709
uniquify: Optional[bool] = None,
@@ -727,13 +729,12 @@ class SQLAlchemyAsyncMockSlugRepository(
727729
async def get_by_slug(
728730
self,
729731
slug: str,
730-
error_messages: Union[ErrorMessages, None, EmptyType] = Empty,
732+
error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
731733
load: Optional[LoadSpec] = None,
732734
execution_options: Optional[dict[str, Any]] = None,
733735
uniquify: Optional[bool] = None,
734736
**kwargs: Any,
735737
) -> Union[ModelT, None]:
736-
"""Select record by slug value."""
737738
return await self.get_one_or_none(slug=slug)
738739

739740
async def get_available_slug(

0 commit comments

Comments
 (0)