Skip to content

Commit 43586d8

Browse files
authored
feat: allows positional args for session (#455)
This change allows for arguments to also be matched when generating a service provider closure.
1 parent 4b8e1b0 commit 43586d8

2 files changed

Lines changed: 34 additions & 11 deletions

File tree

advanced_alchemy/extensions/litestar/providers.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -203,15 +203,14 @@ def create_service_provider(
203203
A dependency provider function suitable for Litestar's DI system.
204204
"""
205205

206-
# Determine if the service is async or sync
207206
session_dependency_key = config.session_dependency_key if config else "db_session"
208207

209208
if issubclass(service_class, SQLAlchemyAsyncRepositoryService) or service_class is SQLAlchemyAsyncRepositoryService: # type: ignore[comparison-overlap]
210209
session_type_annotation = "Optional[AsyncSession]"
211210
return_type_annotation = AsyncGenerator[service_class, None] # type: ignore[valid-type]
212211

213-
async def provide_service_async(**kwargs: Any) -> "AsyncGenerator[AsyncServiceT_co, None]":
214-
db_session = cast("Optional[AsyncSession]", kwargs.get(session_dependency_key))
212+
async def provide_service_async(*args: Any, **kwargs: Any) -> "AsyncGenerator[AsyncServiceT_co, None]":
213+
db_session = cast("Optional[AsyncSession]", args[0] if args else kwargs.get(session_dependency_key))
215214
async with service_class.new( # type: ignore[union-attr]
216215
session=db_session, # type: ignore[arg-type]
217216
statement=statement,
@@ -231,7 +230,6 @@ async def provide_service_async(**kwargs: Any) -> "AsyncGenerator[AsyncServiceT_
231230
annotation=session_type_annotation,
232231
)
233232

234-
# Create the full signature for the provider function
235233
provider_signature = inspect.Signature(
236234
parameters=[session_param],
237235
return_annotation=return_type_annotation,
@@ -245,10 +243,8 @@ async def provide_service_async(**kwargs: Any) -> "AsyncGenerator[AsyncServiceT_
245243
session_type_annotation = "Optional[Session]"
246244
return_type_annotation = Generator[service_class, None, None] # type: ignore[misc,assignment,valid-type]
247245

248-
def provide_service_sync(**kwargs: Any) -> "Generator[SyncServiceT_co, None, None]":
249-
# Extract the session using the dynamic key
250-
db_session = cast("Optional[Session]", kwargs.get(session_dependency_key))
251-
# Instantiate and yield the service
246+
def provide_service_sync(*args: Any, **kwargs: Any) -> "Generator[SyncServiceT_co, None, None]":
247+
db_session = cast("Optional[Session]", args[0] if args else kwargs.get(session_dependency_key))
252248
with service_class.new(
253249
session=db_session,
254250
statement=statement,
@@ -261,16 +257,13 @@ def provide_service_sync(**kwargs: Any) -> "Generator[SyncServiceT_co, None, Non
261257
) as service:
262258
yield service
263259

264-
# Create the signature parameter for the session dependency
265-
266260
session_param = inspect.Parameter(
267261
name=session_dependency_key,
268262
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
269263
default=Dependency(skip_validation=True),
270264
annotation=session_type_annotation,
271265
)
272266

273-
# Create the full signature for the provider function
274267
provider_signature = inspect.Signature(
275268
parameters=[session_param],
276269
return_annotation=return_type_annotation,

tests/unit/test_extensions/test_litestar/test_providers.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,16 @@ def test_create_sync_service_provider_custom() -> None:
184184
assert isinstance(svc, TestSyncService)
185185

186186

187+
def test_create_sync_service_provider_positional() -> None:
188+
"""Test creating an async service provider."""
189+
provider = create_service_provider(TestSyncService, config=MagicMock(session_dependency_key="testing_session"))
190+
191+
# Ensure the provider is callable
192+
assert callable(provider)
193+
svc = next(provider(MagicMock()))
194+
assert isinstance(svc, TestSyncService)
195+
196+
187197
async def test_create_async_service_provider() -> None:
188198
"""Test creating an async service provider."""
189199
provider = create_service_provider(TestAsyncService)
@@ -194,6 +204,26 @@ async def test_create_async_service_provider() -> None:
194204
assert isinstance(svc, TestAsyncService)
195205

196206

207+
async def test_create_async_service_provider_custom() -> None:
208+
"""Test creating an async service provider."""
209+
provider = create_service_provider(TestAsyncService, config=MagicMock(session_dependency_key="testing_session"))
210+
211+
# Ensure the provider is callable
212+
assert callable(provider)
213+
svc = await anext_(provider(testing_session=MagicMock()))
214+
assert isinstance(svc, TestAsyncService)
215+
216+
217+
async def test_create_async_service_provider_positional() -> None:
218+
"""Test creating an async service provider."""
219+
provider = create_service_provider(TestAsyncService, config=MagicMock(session_dependency_key="testing_session"))
220+
221+
# Ensure the provider is callable
222+
assert callable(provider)
223+
svc = await anext_(provider(MagicMock()))
224+
assert isinstance(svc, TestAsyncService)
225+
226+
197227
def test_create_async_service_dependencies() -> None:
198228
"""Test creating async service dependencies."""
199229
with patch("advanced_alchemy.extensions.litestar.providers.create_service_provider") as mock_create_provider:

0 commit comments

Comments
 (0)