Skip to content

Commit b433058

Browse files
authored
fix: InvalidRequestError when calling create_session_maker (#701)
## Summary - Fixes `sqlalchemy.exc.InvalidRequestError: No such event 'before_flush' for target 'async_sessionmaker(...)'` raised when calling `SQLAlchemyAsyncConfig.create_session_maker()` ## Root Cause SQLAlchemy's `async_sessionmaker` does not support session-level events (`before_flush`, `after_commit`, `after_rollback`) only the sync `sessionmaker` does. The code was registering these events directly on the `async_sessionmaker` instance, which SQLAlchemy rejects. ## Fix Instead of attaching events to the `async_sessionmaker` 1. Creates a sync `sessionmaker()` 2. Registers all session-level events (`before_flush`, `after_commit`, `after_rollback`) on it 3. Injects it into the `async_sessionmaker` via `session_maker.configure(sync_session_class=sync_maker)` This ensures every `AsyncSession` created by the factory uses this sync sessionmaker internally, and all event listeners fire correctly on the underlying sync `Session`.
1 parent 04c2898 commit b433058

11 files changed

Lines changed: 636 additions & 534 deletions

File tree

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ repos:
2222
- id: unasyncd
2323
additional_dependencies: ["ruff"]
2424
- repo: https://github.com/charliermarsh/ruff-pre-commit
25-
rev: "v0.15.7"
25+
rev: "v0.15.9"
2626
hooks:
2727
# Run the linter.
2828
- id: ruff

advanced_alchemy/config/asyncio.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import TYPE_CHECKING, Callable, Optional, Union, cast
55

66
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
7+
from sqlalchemy.orm import sessionmaker as sync_sessionmaker
78

89
from advanced_alchemy._listeners import set_async_context
910
from advanced_alchemy.config.common import (
@@ -162,14 +163,20 @@ def create_session_maker(self) -> "Callable[[], AsyncSession]":
162163
"async_sessionmaker[AsyncSession]",
163164
self.session_maker, # pyright: ignore[reportUnknownMemberType]
164165
)
166+
167+
# async_sessionmaker does not support Session-level events directly.
168+
# Create a sync sessionmaker, register events on it, and inject it
169+
# as sync_session_class so events fire on the underlying sync Session.
170+
sync_maker = sync_sessionmaker()
165171
if self.enable_file_object_listener:
166-
event.listen(session_maker, "before_flush", AsyncFileObjectListener.before_flush)
167-
event.listen(session_maker, "after_commit", AsyncFileObjectListener.after_commit)
168-
event.listen(session_maker, "after_rollback", AsyncFileObjectListener.after_rollback)
172+
event.listen(sync_maker, "before_flush", AsyncFileObjectListener.before_flush)
173+
event.listen(sync_maker, "after_commit", AsyncFileObjectListener.after_commit)
174+
event.listen(sync_maker, "after_rollback", AsyncFileObjectListener.after_rollback)
169175
if self.enable_touch_updated_timestamp_listener:
170-
event.listen(session_maker, "before_flush", touch_updated_timestamp)
171-
event.listen(session_maker, "after_commit", AsyncCacheListener.after_commit)
172-
event.listen(session_maker, "after_rollback", AsyncCacheListener.after_rollback)
176+
event.listen(sync_maker, "before_flush", touch_updated_timestamp)
177+
event.listen(sync_maker, "after_commit", AsyncCacheListener.after_commit)
178+
event.listen(sync_maker, "after_rollback", AsyncCacheListener.after_rollback)
179+
session_maker.configure(sync_session_class=sync_maker)
173180

174181
if self.session_maker is None: # pyright: ignore
175182
msg = "Session maker was not initialized." # type: ignore[unreachable]

advanced_alchemy/repository/_async.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,7 @@ def _get_error_messages(
678678
if default_messages and isinstance(default_messages, dict):
679679
messages.update(default_messages)
680680
if error_messages:
681-
messages.update(cast("ErrorMessages", error_messages))
681+
messages.update(cast("ErrorMessages", error_messages)) # type: ignore[unused-ignore,redundant-cast]
682682
return messages
683683

684684
@classmethod

advanced_alchemy/repository/_sync.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -679,7 +679,7 @@ def _get_error_messages(
679679
if default_messages and isinstance(default_messages, dict):
680680
messages.update(default_messages)
681681
if error_messages:
682-
messages.update(cast("ErrorMessages", error_messages))
682+
messages.update(cast("ErrorMessages", error_messages)) # type: ignore[unused-ignore,redundant-cast]
683683
return messages
684684

685685
@classmethod

advanced_alchemy/repository/memory/_async.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,9 +228,9 @@ def _get_error_messages(
228228
default_messages = None
229229
messages = cast("ErrorMessages", dict(DEFAULT_ERROR_MESSAGE_TEMPLATES))
230230
if default_messages:
231-
messages.update(cast("ErrorMessages", default_messages))
231+
messages.update(cast("ErrorMessages", default_messages)) # type: ignore[unused-ignore,redundant-cast]
232232
if error_messages:
233-
messages.update(cast("ErrorMessages", error_messages))
233+
messages.update(cast("ErrorMessages", error_messages)) # type: ignore[unused-ignore,redundant-cast]
234234
return messages
235235

236236
@classmethod

advanced_alchemy/repository/memory/_sync.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,9 +229,9 @@ def _get_error_messages(
229229
default_messages = None
230230
messages = cast("ErrorMessages", dict(DEFAULT_ERROR_MESSAGE_TEMPLATES))
231231
if default_messages:
232-
messages.update(cast("ErrorMessages", default_messages))
232+
messages.update(cast("ErrorMessages", default_messages)) # type: ignore[unused-ignore,redundant-cast]
233233
if error_messages:
234-
messages.update(cast("ErrorMessages", error_messages))
234+
messages.update(cast("ErrorMessages", error_messages)) # type: ignore[unused-ignore,redundant-cast]
235235
return messages
236236

237237
@classmethod

advanced_alchemy/service/_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def to_schema(
215215
)
216216
if MSGSPEC_INSTALLED and issubclass(schema_type, Struct):
217217
if not isinstance(data, Sequence):
218-
return cast( # type: ignore[redundant-cast]
218+
return cast(
219219
"ModelDTOT",
220220
convert(
221221
obj=data,

pyproject.toml

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -462,23 +462,14 @@ warn_unused_ignores = false
462462
ignore_missing_imports = true
463463
module = [
464464
"asyncmy",
465-
"pyodbc",
466465
"greenlet",
467466
"google.auth.*",
468467
"google.cloud.*",
469-
"google.protobuf.*",
470468
"pyarrow.*",
471-
"pytest_docker.*",
472-
"googleapiclient",
473-
"googleapiclient.*",
474469
"uuid_utils",
475-
"uuid_utils.*",
476470
"fsspec",
477471
"fsspec.*",
478-
"gcsfs",
479-
"fastnanoid",
480472
"s3fs",
481-
"s3fs.*",
482473
"argon2",
483474
"argon2.*",
484475
]
@@ -491,12 +482,12 @@ module = ["dishka", "dishka.*"]
491482
[[tool.mypy.overrides]]
492483
follow_imports = "skip"
493484
ignore_missing_imports = true
494-
module = ["pytest", "pytest.*", "_pytest", "_pytest.*"]
485+
module = ["pytest", "_pytest.*"]
495486

496487
[[tool.mypy.overrides]]
497488
follow_imports = "skip"
498489
ignore_missing_imports = true
499-
module = ["sphinx", "sphinx.*"]
490+
module = ["sphinx.*"]
500491

501492
[[tool.mypy.overrides]]
502493
module = "advanced_alchemy._serialization"

tests/unit/test_config/test_async_config.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,10 @@ def test_create_session_maker_standard_path() -> None:
3333
return_value=mock_session_maker,
3434
),
3535
patch("sqlalchemy.event.listen") as mock_listen,
36+
patch("advanced_alchemy.config.asyncio.sync_sessionmaker") as mock_sync_maker_factory,
3637
):
38+
mock_sync_maker = MagicMock()
39+
mock_sync_maker_factory.return_value = mock_sync_maker
3740
result = config.create_session_maker()
3841

3942
assert result is mock_session_maker
@@ -42,6 +45,14 @@ def test_create_session_maker_standard_path() -> None:
4245
# after_commit, after_rollback for cache
4346
assert mock_listen.call_count == 6
4447

48+
# Verify session_maker.configure was called with the sync_maker
49+
mock_session_maker.configure.assert_called_once_with(sync_session_class=mock_sync_maker)
50+
51+
# Verify listeners are attached to the sync_maker, not the async session_maker
52+
for call in mock_listen.call_args_list:
53+
assert call.args[0] is mock_sync_maker
54+
assert call.args[0] is not mock_session_maker
55+
4556
# Verify file object listeners
4657
listener_events = [c.args[1] for c in mock_listen.call_args_list]
4758
assert "before_flush" in listener_events

tests/unit/test_extensions/test_fastapi/test_providers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def test_add_get_dependencies_cache() -> None:
6060
assert cache.get_dependencies(key) == deps1 # type: ignore
6161

6262
# Test retrieving non-existent key
63-
assert cache.get_dependencies(hash("nonexistent")) is None
63+
assert cache.get_dependencies(hash("nonexistent")) is None # type: ignore[unreachable]
6464

6565

6666
def test_create_filter_dependencies_cache_hit() -> None:
@@ -84,7 +84,7 @@ def test_create_filter_dependencies_cache_hit() -> None:
8484
assert deps == mock_deps # type: ignore
8585

8686
# Verify aggregate function builder was NOT called
87-
mock_create.assert_not_called()
87+
mock_create.assert_not_called() # type: ignore[unreachable]
8888

8989
# Verify cache wasn't updated again
9090
mock_add.assert_not_called()

0 commit comments

Comments
 (0)