Skip to content

Commit 33599a6

Browse files
committed
chore: cleanup and split up tests
1 parent b003373 commit 33599a6

12 files changed

Lines changed: 938 additions & 815 deletions

File tree

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ repos:
2222
- id: unasyncd
2323
additional_dependencies: ["ruff"]
2424
- repo: https://github.com/charliermarsh/ruff-pre-commit
25-
rev: "v0.14.14"
25+
rev: "v0.15.0"
2626
hooks:
2727
# Run the linter.
2828
- id: ruff

advanced_alchemy/alembic/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing_extensions import TypeIs
77

88
from advanced_alchemy.exceptions import MissingDependencyError
9+
from advanced_alchemy.utils.sync_tools import async_
910

1011
if TYPE_CHECKING:
1112
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
@@ -133,8 +134,8 @@ async def _dump_table_async(session: "AbstractAsyncContextManager[AsyncSession]"
133134
)
134135
json_path.write_text(encode_json([row.to_dict() for row in await repo(session=_session).list()]))
135136

136-
dump_dir.mkdir(exist_ok=True)
137+
await async_(dump_dir.mkdir)(exist_ok=True)
137138

138139
if _is_sync(session):
139-
return _dump_table_sync(session)
140+
return await async_(_dump_table_sync)(session)
140141
return await _dump_table_async(session)

advanced_alchemy/repository/memory/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ class Child(Base):
313313
continue
314314
remote_mapper = mappers[next(iter(column.foreign_keys))._table_key()] # noqa: SLF001 # pyright: ignore[reportPrivateUsage]
315315
try:
316-
obj = self.store(remote_mapper.class_).get(new_attrs.get(column.key, None))
316+
obj = self.store(remote_mapper.class_).get(new_attrs.get(column.key))
317317
except KeyError:
318318
continue
319319

advanced_alchemy/types/file_object/backends/obstore.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def save_object(
142142
"max_concurrency": max_concurrency,
143143
}
144144
if not isinstance(self.fs, LocalStore):
145-
put_params["attributes"] = attributes if attributes else None
145+
put_params["attributes"] = attributes or None
146146

147147
_ = self.fs.put(file_object.path, data, **put_params)
148148
info = self.fs.head(file_object.path)
@@ -198,7 +198,7 @@ async def save_object_async(
198198
"max_concurrency": max_concurrency,
199199
}
200200
if not isinstance(self.fs, LocalStore):
201-
put_params["attributes"] = attributes if attributes else None
201+
put_params["attributes"] = attributes or None
202202

203203
_ = await self.fs.put_async(file_object.path, data, **put_params)
204204
info = await self.fs.head_async(file_object.path)

tests/integration/test_alembic_commands.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -223,23 +223,27 @@ def tmp_project_dir(monkeypatch: MonkeyPatch, tmp_path: Path) -> Generator[Path,
223223

224224

225225
async def test_alembic_init(alembic_commands: commands.AlembicCommands, tmp_project_dir: Path) -> None:
226+
from advanced_alchemy.utils.sync_tools import async_
227+
226228
alembic_commands.init(directory=f"{tmp_project_dir}/migrations/")
227229
expected_dirs = [f"{tmp_project_dir}/migrations/", f"{tmp_project_dir}/migrations/versions"]
228230
expected_files = [f"{tmp_project_dir}/migrations/env.py", f"{tmp_project_dir}/migrations/script.py.mako"]
229231
for dir in expected_dirs:
230-
assert Path(dir).is_dir()
232+
assert await async_(Path(dir).is_dir)()
231233
for file in expected_files:
232-
assert Path(file).is_file()
234+
assert await async_(Path(file).is_file)()
233235

234236

235237
async def test_alembic_init_already(alembic_commands: commands.AlembicCommands, tmp_project_dir: Path) -> None:
238+
from advanced_alchemy.utils.sync_tools import async_
239+
236240
alembic_commands.init(directory=f"{tmp_project_dir}/migrations/")
237241
expected_dirs = [f"{tmp_project_dir}/migrations/", f"{tmp_project_dir}/migrations/versions"]
238242
expected_files = [f"{tmp_project_dir}/migrations/env.py", f"{tmp_project_dir}/migrations/script.py.mako"]
239243
for dir in expected_dirs:
240-
assert Path(dir).is_dir()
244+
assert await async_(Path(dir).is_dir)()
241245
for file in expected_files:
242-
assert Path(file).is_file()
246+
assert await async_(Path(file).is_file)()
243247
with pytest.raises(CommandError):
244248
alembic_commands.init(directory=f"{tmp_project_dir}/migrations/")
245249

tests/integration/test_file_object.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -959,7 +959,9 @@ async def test_file_object_save_with_different_data_types(storage_registry: Stor
959959
assert await obj2.get_content_async() == test_content
960960
assert obj2.get_content() == test_content
961961
# Cleanup
962-
temp_path.unlink()
962+
from advanced_alchemy.utils.sync_tools import async_
963+
964+
await async_(temp_path.unlink)()
963965

964966

965967
@pytest.mark.xdist_group("file_object")

tests/integration/test_operations.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ async def test_table_async(
120120
yield cached_test_table
121121

122122
# Fast data-only cleanup between tests
123-
if dialect_name not in {"mock"} and "spanner" not in dialect_name:
123+
if dialect_name != "mock" and "spanner" not in dialect_name:
124124
async with async_engine.begin() as conn:
125125
await conn.execute(cached_test_table.delete())
126126
await conn.commit()
@@ -255,7 +255,7 @@ def store_model_sync(
255255
yield cached_store_model
256256

257257
# Fast data-only cleanup between tests
258-
if dialect_name not in {"mock"} and "spanner" not in dialect_name:
258+
if dialect_name != "mock" and "spanner" not in dialect_name:
259259
clean_tables(engine, cached_store_model.metadata)
260260

261261
# Drop table at session end (handled by teardown)
@@ -296,7 +296,7 @@ async def store_model_async(
296296
yield cached_store_model
297297

298298
# Fast data-only cleanup between tests
299-
if dialect_name not in {"mock"} and "spanner" not in dialect_name:
299+
if dialect_name != "mock" and "spanner" not in dialect_name:
300300
await async_clean_tables(async_engine, cached_store_model.metadata)
301301

302302
# Drop table at session end (handled by teardown)
Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
"""Unit tests for cache-related listeners in advanced_alchemy._listeners."""
2+
3+
from unittest.mock import AsyncMock, MagicMock, patch
4+
5+
import pytest
6+
from sqlalchemy.orm import Session
7+
8+
from advanced_alchemy._listeners import (
9+
AsyncCacheListener,
10+
BaseCacheListener,
11+
CacheInvalidationListener,
12+
CacheInvalidationTracker,
13+
SyncCacheListener,
14+
get_cache_tracker,
15+
setup_cache_listeners,
16+
)
17+
18+
# --- CacheInvalidationTracker Tests ---
19+
20+
21+
def test_cache_invalidation_tracker_add_invalidation() -> None:
22+
mock_manager = MagicMock()
23+
tracker = CacheInvalidationTracker(mock_manager)
24+
tracker.add_invalidation("User", 1, "group1")
25+
26+
assert ("User", 1, "group1") in tracker._pending_invalidations
27+
assert "User" in tracker._pending_model_bumps
28+
29+
30+
def test_cache_invalidation_tracker_commit() -> None:
31+
mock_manager = MagicMock()
32+
tracker = CacheInvalidationTracker(mock_manager)
33+
tracker.add_invalidation("User", 1, "group1")
34+
35+
tracker.commit()
36+
37+
mock_manager.bump_model_version_sync.assert_called_with("User")
38+
mock_manager.invalidate_entity_sync.assert_called_with("User", 1, "group1")
39+
assert not tracker._pending_invalidations
40+
assert not tracker._pending_model_bumps
41+
42+
43+
def test_cache_invalidation_tracker_rollback() -> None:
44+
mock_manager = MagicMock()
45+
tracker = CacheInvalidationTracker(mock_manager)
46+
tracker.add_invalidation("User", 1, "group1")
47+
48+
tracker.rollback()
49+
50+
mock_manager.bump_model_version_sync.assert_not_called()
51+
assert not tracker._pending_invalidations
52+
assert not tracker._pending_model_bumps
53+
54+
55+
@pytest.mark.asyncio
56+
async def test_cache_invalidation_tracker_commit_async() -> None:
57+
mock_manager = AsyncMock()
58+
tracker = CacheInvalidationTracker(mock_manager)
59+
tracker.add_invalidation("User", 1, "group1")
60+
61+
await tracker.commit_async()
62+
63+
mock_manager.bump_model_version_async.assert_called_with("User")
64+
mock_manager.invalidate_entity_async.assert_called_with("User", 1, "group1")
65+
assert not tracker._pending_invalidations
66+
assert not tracker._pending_model_bumps
67+
68+
69+
# --- get_cache_tracker Tests ---
70+
71+
72+
def test_get_cache_tracker_existing() -> None:
73+
session = MagicMock(spec=Session)
74+
tracker = MagicMock()
75+
session.info = {"_aa_cache_tracker": tracker}
76+
77+
result = get_cache_tracker(session)
78+
assert result is tracker
79+
80+
81+
def test_get_cache_tracker_create() -> None:
82+
session = MagicMock(spec=Session)
83+
session.info = {}
84+
mock_manager = MagicMock()
85+
86+
result = get_cache_tracker(session, cache_manager=mock_manager, create=True)
87+
assert isinstance(result, CacheInvalidationTracker)
88+
assert session.info["_aa_cache_tracker"] is result
89+
90+
91+
def test_get_cache_tracker_no_create() -> None:
92+
session = MagicMock(spec=Session)
93+
session.info = {}
94+
95+
result = get_cache_tracker(session, create=False)
96+
assert result is None
97+
98+
99+
# --- SyncCacheListener Tests ---
100+
101+
102+
def test_sync_cache_listener_after_commit() -> None:
103+
session = MagicMock(spec=Session)
104+
tracker = MagicMock()
105+
session.info = {"_aa_cache_tracker": tracker, "enable_cache_listener": True}
106+
107+
SyncCacheListener.after_commit(session)
108+
109+
tracker.commit.assert_called_once()
110+
assert "_aa_cache_tracker" not in session.info
111+
112+
113+
def test_sync_cache_listener_after_rollback() -> None:
114+
session = MagicMock(spec=Session)
115+
tracker = MagicMock()
116+
session.info = {"_aa_cache_tracker": tracker, "enable_cache_listener": True}
117+
118+
SyncCacheListener.after_rollback(session)
119+
120+
tracker.rollback.assert_called_once()
121+
assert "_aa_cache_tracker" not in session.info
122+
123+
124+
def test_sync_cache_listener_disabled() -> None:
125+
session = MagicMock(spec=Session)
126+
tracker = MagicMock()
127+
session.info = {"_aa_cache_tracker": tracker, "enable_cache_listener": False}
128+
129+
SyncCacheListener.after_commit(session)
130+
131+
tracker.commit.assert_not_called()
132+
133+
134+
# --- AsyncCacheListener Tests ---
135+
136+
137+
@pytest.mark.asyncio
138+
async def test_async_cache_listener_after_commit() -> None:
139+
session = MagicMock(spec=Session)
140+
tracker = MagicMock()
141+
# Mock commit_async to verify it's scheduled
142+
tracker.commit_async = AsyncMock()
143+
session.info = {"_aa_cache_tracker": tracker, "enable_cache_listener": True}
144+
145+
AsyncCacheListener.after_commit(session)
146+
147+
# Since it creates a task, we verify it popped the tracker
148+
assert "_aa_cache_tracker" not in session.info
149+
150+
151+
def test_async_cache_listener_after_rollback() -> None:
152+
session = MagicMock(spec=Session)
153+
tracker = MagicMock()
154+
session.info = {"_aa_cache_tracker": tracker, "enable_cache_listener": True}
155+
156+
AsyncCacheListener.after_rollback(session)
157+
158+
tracker.rollback.assert_called_once()
159+
assert "_aa_cache_tracker" not in session.info
160+
161+
162+
# --- CacheInvalidationListener Tests ---
163+
164+
165+
def test_cache_invalidation_listener_after_commit_sync_context() -> None:
166+
session = MagicMock(spec=Session)
167+
tracker = MagicMock()
168+
session.info = {"_aa_cache_tracker": tracker, "enable_cache_listener": True}
169+
170+
try:
171+
CacheInvalidationListener.after_commit(session)
172+
except RuntimeError:
173+
# If no loop, it calls tracker.commit()
174+
pass
175+
176+
177+
# --- BaseCacheListener Tests ---
178+
179+
180+
class TestBaseCacheListener(BaseCacheListener):
181+
pass
182+
183+
184+
def test_base_cache_listener_is_listener_enabled() -> None:
185+
session = MagicMock(spec=Session)
186+
session.info = {}
187+
session.bind = None
188+
session.execution_options = None
189+
assert TestBaseCacheListener._is_listener_enabled(session) is True
190+
191+
session.info = {"enable_cache_listener": False}
192+
assert TestBaseCacheListener._is_listener_enabled(session) is False
193+
194+
195+
def test_base_cache_listener_execution_options() -> None:
196+
session = MagicMock(spec=Session)
197+
session.info = {}
198+
session.bind = None
199+
200+
# Test execution_options via session
201+
session.execution_options = {"enable_cache_listener": False}
202+
assert TestBaseCacheListener._is_listener_enabled(session) is False
203+
204+
session.execution_options = {"enable_cache_listener": True}
205+
assert TestBaseCacheListener._is_listener_enabled(session) is True
206+
207+
208+
# --- Setup Tests ---
209+
210+
211+
def test_setup_cache_listeners() -> None:
212+
with patch("sqlalchemy.event.listen") as mock_listen:
213+
setup_cache_listeners()
214+
assert mock_listen.called

0 commit comments

Comments
 (0)