44import string
55from collections import abc
66from collections .abc import Iterable
7- from typing import Any , Optional , Union , cast , overload
7+ from typing import Any , List , Optional , Union , cast , overload
88from unittest .mock import create_autospec
99
1010from sqlalchemy import (
@@ -58,7 +58,7 @@ class SQLAlchemyAsyncMockRepository(SQLAlchemyAsyncRepositoryProtocol[ModelT]):
5858 """Default execution options for the repository."""
5959 model_type : type [ModelT ]
6060 id_attribute : Any = "id"
61- match_fields : Optional [Union [list [str ], str ]] = None
61+ match_fields : Optional [Union [List [str ], str ]] = None
6262 uniquify : bool = False
6363 _exclude_kwargs : set [str ] = {
6464 "statement" ,
@@ -87,7 +87,7 @@ def __init__(
8787 auto_expunge : bool = False ,
8888 auto_refresh : bool = True ,
8989 auto_commit : bool = False ,
90- order_by : Optional [Union [list [OrderingPair ], OrderingPair ]] = None ,
90+ order_by : Optional [Union [List [OrderingPair ], OrderingPair ]] = None ,
9191 error_messages : Optional [Union [ErrorMessages , EmptyType ]] = Empty ,
9292 wrap_exceptions : bool = True ,
9393 load : Optional [LoadSpec ] = None ,
@@ -210,7 +210,7 @@ def _exclude_unused_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]:
210210 return {key : value for key , value in kwargs .items () if key not in self ._exclude_kwargs }
211211
212212 @staticmethod
213- def _apply_limit_offset_pagination (result : list [ModelT ], limit : int , offset : int ) -> list [ModelT ]:
213+ def _apply_limit_offset_pagination (result : List [ModelT ], limit : int , offset : int ) -> List [ModelT ]:
214214 return result [offset :limit ]
215215
216216 def _extract_field_name (self , field : "Union[str, ColumnElement[Any], InstrumentedAttribute[Any]]" ) -> str :
@@ -234,35 +234,35 @@ def _extract_field_name(self, field: "Union[str, ColumnElement[Any], Instrumente
234234
235235 def _filter_in_collection (
236236 self ,
237- result : list [ModelT ],
237+ result : List [ModelT ],
238238 field_name : "Union[str, ColumnElement[Any], InstrumentedAttribute[Any]]" ,
239239 values : abc .Collection [Any ],
240- ) -> list [ModelT ]:
240+ ) -> List [ModelT ]:
241241 field_str = self ._extract_field_name (field_name )
242242 return [item for item in result if getattr (item , field_str ) in values ]
243243
244244 def _filter_not_in_collection (
245245 self ,
246- result : list [ModelT ],
246+ result : List [ModelT ],
247247 field_name : "Union[str, ColumnElement[Any], InstrumentedAttribute[Any]]" ,
248248 values : abc .Collection [Any ],
249- ) -> list [ModelT ]:
249+ ) -> List [ModelT ]:
250250 if not values :
251251 return result
252252 field_str = self ._extract_field_name (field_name )
253253 return [item for item in result if getattr (item , field_str ) not in values ]
254254
255255 def _filter_on_datetime_field (
256256 self ,
257- result : list [ModelT ],
257+ result : List [ModelT ],
258258 field_name : "Union[str, ColumnElement[Any], InstrumentedAttribute[Any]]" ,
259259 before : Optional [datetime .datetime ] = None ,
260260 after : Optional [datetime .datetime ] = None ,
261261 on_or_before : Optional [datetime .datetime ] = None ,
262262 on_or_after : Optional [datetime .datetime ] = None ,
263- ) -> list [ModelT ]:
263+ ) -> List [ModelT ]:
264264 field_str = self ._extract_field_name (field_name )
265- result_ : list [ModelT ] = []
265+ result_ : List [ModelT ] = []
266266 for item in result :
267267 attr : datetime .datetime = getattr (item , field_str )
268268 if before is not None and attr < before :
@@ -277,14 +277,14 @@ def _filter_on_datetime_field(
277277
278278 @staticmethod
279279 def _filter_by_like (
280- result : list [ModelT ],
280+ result : List [ModelT ],
281281 field_name : Union [str , set [str ]],
282282 value : str ,
283283 ignore_case : bool ,
284- ) -> list [ModelT ]:
284+ ) -> List [ModelT ]:
285285 pattern = re .compile (rf".*{ value } .*" , re .IGNORECASE ) if ignore_case else re .compile (rf".*{ value } .*" )
286286 fields = {field_name } if isinstance (field_name , str ) else field_name
287- items : list [ModelT ] = []
287+ items : List [ModelT ] = []
288288 for field in fields :
289289 items .extend (
290290 [
@@ -297,14 +297,14 @@ def _filter_by_like(
297297
298298 @staticmethod
299299 def _filter_by_not_like (
300- result : list [ModelT ],
300+ result : List [ModelT ],
301301 field_name : Union [str , set [str ]],
302302 value : str ,
303303 ignore_case : bool ,
304- ) -> list [ModelT ]:
304+ ) -> List [ModelT ]:
305305 pattern = re .compile (rf".*{ value } .*" , re .IGNORECASE ) if ignore_case else re .compile (rf".*{ value } .*" )
306306 fields = {field_name } if isinstance (field_name , str ) else field_name
307- items : list [ModelT ] = []
307+ items : List [ModelT ] = []
308308 for field in fields :
309309 items .extend (
310310 [
@@ -320,7 +320,7 @@ def _filter_result_by_kwargs(
320320 result : Iterable [ModelT ],
321321 / ,
322322 kwargs : Union [dict [Any , Any ], Iterable [tuple [Any , Any ]]],
323- ) -> list [ModelT ]:
323+ ) -> List [ModelT ]:
324324 kwargs_ : dict [Any , Any ] = kwargs if isinstance (kwargs , dict ) else dict (* kwargs ) # pyright: ignore
325325 kwargs_ = self ._exclude_unused_kwargs (kwargs_ ) # pyright: ignore
326326 try :
@@ -330,18 +330,18 @@ def _filter_result_by_kwargs(
330330
331331 def _order_by (
332332 self ,
333- result : list [ModelT ],
333+ result : List [ModelT ],
334334 field_name : "Union[str, ColumnElement[Any], InstrumentedAttribute[Any]]" ,
335335 sort_desc : bool = False ,
336- ) -> list [ModelT ]:
336+ ) -> List [ModelT ]:
337337 return sorted (result , key = lambda item : getattr (item , self ._extract_field_name (field_name )), reverse = sort_desc )
338338
339339 def _apply_filters (
340340 self ,
341- result : list [ModelT ],
341+ result : List [ModelT ],
342342 * filters : Union [StatementFilter , ColumnElement [bool ]],
343343 apply_pagination : bool = True ,
344- ) -> list [ModelT ]:
344+ ) -> List [ModelT ]:
345345 for filter_ in filters :
346346 if isinstance (filter_ , LimitOffset ):
347347 if apply_pagination :
@@ -394,9 +394,9 @@ def _apply_filters(
394394
395395 def _get_match_fields (
396396 self ,
397- match_fields : Union [list [str ], str , None ],
397+ match_fields : Union [List [str ], str , None ],
398398 id_attribute : Optional [str ] = None ,
399- ) -> Optional [list [str ]]:
399+ ) -> Optional [List [str ]]:
400400 id_attribute = id_attribute or self .id_attribute
401401 match_fields = match_fields or self .match_fields
402402 if isinstance (match_fields , str ):
@@ -407,22 +407,22 @@ async def _list_and_count_basic(
407407 self ,
408408 * filters : Union [StatementFilter , ColumnElement [bool ]],
409409 ** kwargs : Any ,
410- ) -> tuple [list [ModelT ], int ]:
410+ ) -> tuple [List [ModelT ], int ]:
411411 result = await self .list (* filters , ** kwargs )
412412 return result , len (result )
413413
414414 async def _list_and_count_window (
415415 self ,
416416 * filters : Union [StatementFilter , ColumnElement [bool ]],
417417 ** kwargs : Any ,
418- ) -> tuple [list [ModelT ], int ]:
418+ ) -> tuple [List [ModelT ], int ]:
419419 return await self ._list_and_count_basic (* filters , ** kwargs )
420420
421421 def _find_or_raise_not_found (self , id_ : Any ) -> ModelT :
422422 return self .check_not_found (self .__collection__ ().get_or_none (id_ ))
423423
424424 @staticmethod
425- def _find_one_or_raise_error (result : list [ModelT ]) -> ModelT :
425+ def _find_one_or_raise_error (result : List [ModelT ]) -> ModelT :
426426 if not result :
427427 msg = "No item found when one was expected"
428428 raise IntegrityError (msg )
@@ -435,7 +435,7 @@ def _get_update_many_statement(
435435 self ,
436436 model_type : type [ModelT ],
437437 supports_returning : bool ,
438- loader_options : Optional [list [_AbstractLoad ]],
438+ loader_options : Optional [List [_AbstractLoad ]],
439439 execution_options : Optional [dict [str , Any ]],
440440 ) -> Union [Update , ReturningUpdate [tuple [ModelT ]]]:
441441 return self .statement # type: ignore[no-any-return] # pyright: ignore[reportReturnType]
@@ -496,7 +496,7 @@ async def get_one_or_none(
496496 async def get_or_upsert (
497497 self ,
498498 * filters : Union [StatementFilter , ColumnElement [bool ]],
499- match_fields : Union [list [str ], str , None ] = None ,
499+ match_fields : Union [List [str ], str , None ] = None ,
500500 upsert : bool = True ,
501501 attribute_names : Optional [Iterable [str ]] = None ,
502502 with_for_update : ForUpdateParameter = None ,
@@ -534,7 +534,7 @@ async def get_or_upsert(
534534 async def get_and_update (
535535 self ,
536536 * filters : Union [StatementFilter , ColumnElement [bool ]],
537- match_fields : Union [list [str ], str , None ] = None ,
537+ match_fields : Union [List [str ], str , None ] = None ,
538538 attribute_names : Optional [Iterable [str ]] = None ,
539539 with_for_update : ForUpdateParameter = None ,
540540 auto_commit : Optional [bool ] = None ,
@@ -606,13 +606,13 @@ async def add(
606606
607607 async def add_many (
608608 self ,
609- data : list [ModelT ],
609+ data : List [ModelT ],
610610 * ,
611611 auto_commit : Optional [bool ] = None ,
612612 auto_expunge : Optional [bool ] = None ,
613613 error_messages : Optional [Union [ErrorMessages , EmptyType ]] = Empty ,
614614 bind_group : Optional [str ] = None ,
615- ) -> list [ModelT ]:
615+ ) -> List [ModelT ]:
616616 for obj in data :
617617 await self .add (obj ) # pyright: ignore[reportCallIssue]
618618 return data
@@ -638,7 +638,7 @@ async def update(
638638
639639 async def update_many (
640640 self ,
641- data : list [ModelT ],
641+ data : List [ModelT ],
642642 * ,
643643 auto_commit : Optional [bool ] = None ,
644644 auto_expunge : Optional [bool ] = None ,
@@ -647,7 +647,7 @@ async def update_many(
647647 execution_options : Optional [dict [str , Any ]] = None ,
648648 uniquify : Optional [bool ] = None ,
649649 bind_group : Optional [str ] = None ,
650- ) -> list [ModelT ]:
650+ ) -> List [ModelT ]:
651651 return [self .__collection__ ().update (obj ) for obj in data if obj in self .__collection__ ()]
652652
653653 async def delete (
@@ -670,7 +670,7 @@ async def delete(
670670
671671 async def delete_many (
672672 self ,
673- item_ids : list [Any ],
673+ item_ids : List [Any ],
674674 * ,
675675 auto_commit : Optional [bool ] = None ,
676676 auto_expunge : Optional [bool ] = None ,
@@ -681,8 +681,8 @@ async def delete_many(
681681 execution_options : Optional [dict [str , Any ]] = None ,
682682 uniquify : Optional [bool ] = None ,
683683 bind_group : Optional [str ] = None ,
684- ) -> list [ModelT ]:
685- deleted : list [ModelT ] = []
684+ ) -> List [ModelT ]:
685+ deleted : List [ModelT ] = []
686686 for id_ in item_ids :
687687 if obj := self .__collection__ ().get_or_none (id_ ):
688688 deleted .append (obj )
@@ -701,7 +701,7 @@ async def delete_where(
701701 uniquify : Optional [bool ] = None ,
702702 bind_group : Optional [str ] = None ,
703703 ** kwargs : Any ,
704- ) -> list [ModelT ]:
704+ ) -> List [ModelT ]:
705705 result = self .__collection__ ().list ()
706706 result = self ._apply_filters (result , * filters )
707707 models = self ._filter_result_by_kwargs (result , kwargs )
@@ -717,7 +717,7 @@ async def upsert(
717717 auto_expunge : Optional [bool ] = None ,
718718 auto_commit : Optional [bool ] = None ,
719719 auto_refresh : Optional [bool ] = None ,
720- match_fields : Optional [Union [list [str ], str ]] = None ,
720+ match_fields : Optional [Union [List [str ], str ]] = None ,
721721 error_messages : Optional [Union [ErrorMessages , EmptyType ]] = Empty ,
722722 load : Optional [LoadSpec ] = None ,
723723 execution_options : Optional [dict [str , Any ]] = None ,
@@ -731,18 +731,18 @@ async def upsert(
731731
732732 async def upsert_many (
733733 self ,
734- data : list [ModelT ],
734+ data : List [ModelT ],
735735 * ,
736736 auto_expunge : Optional [bool ] = None ,
737737 auto_commit : Optional [bool ] = None ,
738738 no_merge : bool = False ,
739- match_fields : Optional [Union [list [str ], str ]] = None ,
739+ match_fields : Optional [Union [List [str ], str ]] = None ,
740740 error_messages : Optional [Union [ErrorMessages , EmptyType ]] = Empty ,
741741 load : Optional [LoadSpec ] = None ,
742742 execution_options : Optional [dict [str , Any ]] = None ,
743743 uniquify : Optional [bool ] = None ,
744744 bind_group : Optional [str ] = None ,
745- ) -> list [ModelT ]:
745+ ) -> List [ModelT ]:
746746 return [await self .upsert (item ) for item in data ]
747747
748748 async def list_and_count (
@@ -751,15 +751,15 @@ async def list_and_count(
751751 statement : Union [Select [tuple [ModelT ]], StatementLambdaElement , None ] = None ,
752752 auto_expunge : Optional [bool ] = None ,
753753 count_with_window_function : Optional [bool ] = None ,
754- order_by : Optional [Union [list [OrderingPair ], OrderingPair ]] = None ,
754+ order_by : Optional [Union [List [OrderingPair ], OrderingPair ]] = None ,
755755 error_messages : Optional [Union [ErrorMessages , EmptyType ]] = Empty ,
756756 load : Optional [LoadSpec ] = None ,
757757 execution_options : Optional [dict [str , Any ]] = None ,
758758 uniquify : Optional [bool ] = None ,
759759 use_cache : bool = True ,
760760 bind_group : Optional [str ] = None ,
761761 ** kwargs : Any ,
762- ) -> tuple [list [ModelT ], int ]:
762+ ) -> tuple [List [ModelT ], int ]:
763763 return await self ._list_and_count_basic (* filters , ** kwargs )
764764
765765 async def list (
@@ -769,7 +769,7 @@ async def list(
769769 use_cache : bool = True ,
770770 bind_group : Optional [str ] = None ,
771771 ** kwargs : Any ,
772- ) -> list [ModelT ]:
772+ ) -> List [ModelT ]:
773773 result = self .__collection__ ().list ()
774774 result = self ._apply_filters (result , * filters )
775775 return self ._filter_result_by_kwargs (result , kwargs )
0 commit comments