@@ -203,15 +203,14 @@ def create_service_provider(
203203 A dependency provider function suitable for Litestar's DI system.
204204 """
205205
206- # Determine if the service is async or sync
207206 session_dependency_key = config .session_dependency_key if config else "db_session"
208207
209208 if issubclass (service_class , SQLAlchemyAsyncRepositoryService ) or service_class is SQLAlchemyAsyncRepositoryService : # type: ignore[comparison-overlap]
210209 session_type_annotation = "Optional[AsyncSession]"
211210 return_type_annotation = AsyncGenerator [service_class , None ] # type: ignore[valid-type]
212211
213- async def provide_service_async (** kwargs : Any ) -> "AsyncGenerator[AsyncServiceT_co, None]" :
214- db_session = cast ("Optional[AsyncSession]" , kwargs .get (session_dependency_key ))
212+ async def provide_service_async (* args : Any , * *kwargs : Any ) -> "AsyncGenerator[AsyncServiceT_co, None]" :
213+ db_session = cast ("Optional[AsyncSession]" , args [ 0 ] if args else kwargs .get (session_dependency_key ))
215214 async with service_class .new ( # type: ignore[union-attr]
216215 session = db_session , # type: ignore[arg-type]
217216 statement = statement ,
@@ -231,7 +230,6 @@ async def provide_service_async(**kwargs: Any) -> "AsyncGenerator[AsyncServiceT_
231230 annotation = session_type_annotation ,
232231 )
233232
234- # Create the full signature for the provider function
235233 provider_signature = inspect .Signature (
236234 parameters = [session_param ],
237235 return_annotation = return_type_annotation ,
@@ -245,10 +243,8 @@ async def provide_service_async(**kwargs: Any) -> "AsyncGenerator[AsyncServiceT_
245243 session_type_annotation = "Optional[Session]"
246244 return_type_annotation = Generator [service_class , None , None ] # type: ignore[misc,assignment,valid-type]
247245
248- def provide_service_sync (** kwargs : Any ) -> "Generator[SyncServiceT_co, None, None]" :
249- # Extract the session using the dynamic key
250- db_session = cast ("Optional[Session]" , kwargs .get (session_dependency_key ))
251- # Instantiate and yield the service
246+ def provide_service_sync (* args : Any , ** kwargs : Any ) -> "Generator[SyncServiceT_co, None, None]" :
247+ db_session = cast ("Optional[Session]" , args [0 ] if args else kwargs .get (session_dependency_key ))
252248 with service_class .new (
253249 session = db_session ,
254250 statement = statement ,
@@ -261,16 +257,13 @@ def provide_service_sync(**kwargs: Any) -> "Generator[SyncServiceT_co, None, Non
261257 ) as service :
262258 yield service
263259
264- # Create the signature parameter for the session dependency
265-
266260 session_param = inspect .Parameter (
267261 name = session_dependency_key ,
268262 kind = inspect .Parameter .POSITIONAL_OR_KEYWORD ,
269263 default = Dependency (skip_validation = True ),
270264 annotation = session_type_annotation ,
271265 )
272266
273- # Create the full signature for the provider function
274267 provider_signature = inspect .Signature (
275268 parameters = [session_param ],
276269 return_annotation = return_type_annotation ,
0 commit comments