Skip to content

Commit 8916358

Browse files
authored
feat: initial support for dogpile caching (#636)
Introduce support for dogpile caching, including a new caching configuration and a null region implementation for scenarios where caching is disabled or `dogpile.cache` is not installed. Add unit tests to validate the caching behavior and configuration options.
1 parent 298b4ce commit 8916358

24 files changed

Lines changed: 5820 additions & 134 deletions

advanced_alchemy/_listeners.py

Lines changed: 216 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,17 @@
1111
from sqlalchemy.inspection import inspect
1212

1313
if TYPE_CHECKING:
14-
from sqlalchemy.orm import Session, UOWTransaction
14+
from sqlalchemy.ext.asyncio import AsyncSession, async_scoped_session
15+
from sqlalchemy.orm import Session, UOWTransaction, scoped_session
1516
from sqlalchemy.orm.state import InstanceState
1617

18+
from advanced_alchemy.cache import CacheManager
1719
from advanced_alchemy.types.file_object import FileObjectSessionTracker, StorageRegistry
1820

1921
# Module-level registry of in-flight file-operation tasks; holding them here keeps
# a strong reference so they are not garbage collected before they finish.
_active_file_operations: set[asyncio.Task[Any]] = set()
2022
"""Stores active file operations to prevent them from being garbage collected."""
23+
# Module-level registry of in-flight cache invalidation tasks. The after-commit
# listener adds each task it schedules and removes it again from the task's
# done-callback, so a task stays strongly referenced for exactly as long as it runs.
_active_cache_operations: set[asyncio.Task[Any]] = set()
24+
"""Stores active cache invalidation operations to prevent them from being garbage collected."""
2125
# Context variable to hold the session tracker instance for the current session context
2226
_current_session_tracker: contextvars.ContextVar[Optional["FileObjectSessionTracker"]] = contextvars.ContextVar(
2327
"_current_session_tracker",
@@ -444,6 +448,217 @@ def setup_file_object_listeners(registry: Optional["StorageRegistry"] = None) ->
444448
set_async_context(False)
445449

446450

451+
# Cache invalidation support
452+
# Key under which the per-session CacheInvalidationTracker is stored in ``session.info``.
_CACHE_TRACKER_KEY = "_aa_cache_tracker"
453+
454+
455+
class CacheInvalidationTracker:
    """Per-transaction queue of cache invalidations.

    Entities touched during a transaction are recorded here and only pushed
    to the cache manager once the transaction commits. A rollback simply
    empties the queue, so the cache is never touched for work that did not
    reach the database.

    Note:
        Model version bumps (used to invalidate list queries) are queued
        alongside the entity invalidations and are likewise applied only on
        commit.
    """

    __slots__ = ("_cache_manager", "_pending_invalidations", "_pending_model_bumps")

    def __init__(self, cache_manager: "CacheManager") -> None:
        # Each queued entry is a (model_name, entity_id, bind_group) triple.
        self._pending_invalidations: list[tuple[str, Any, Optional[str]]] = []
        # Distinct model names whose version must be bumped on commit.
        self._pending_model_bumps: set[str] = set()
        self._cache_manager = cache_manager

    def add_invalidation(self, model_name: str, entity_id: Any, bind_group: Optional[str] = None) -> None:
        """Record an entity whose cache entry must be dropped on commit.

        Nothing is sent to the cache manager here; both the entity
        invalidation and the model version bump wait for ``commit()`` or
        ``commit_async()``.

        Args:
            model_name: The model/table name.
            entity_id: The entity's primary key value.
            bind_group: Optional routing group for multi-master configurations.
                It is forwarded unchanged to the cache manager.
        """
        # The set de-duplicates, so each model is bumped once per commit.
        self._pending_model_bumps.add(model_name)
        self._pending_invalidations.append((model_name, entity_id, bind_group))

    def rollback(self) -> None:
        """Drop everything queued so far without touching the cache."""
        self._pending_model_bumps.clear()
        self._pending_invalidations.clear()

    def commit(self) -> None:
        """Apply all queued work through the cache manager's sync methods.

        Model versions are bumped first, then individual entities are
        invalidated. Each queue is emptied only after its loop completes.
        """
        manager = self._cache_manager

        for queued_model in self._pending_model_bumps:
            manager.bump_model_version_sync(queued_model)
        self._pending_model_bumps.clear()

        for queued_entry in self._pending_invalidations:
            # queued_entry is (model_name, entity_id, bind_group).
            manager.invalidate_entity_sync(*queued_entry)
        self._pending_invalidations.clear()

    async def commit_async(self) -> None:
        """Apply all queued work by awaiting the cache manager's async methods.

        Same ordering as ``commit()``: version bumps first, then entity
        invalidations, with each queue emptied after its loop completes.
        """
        manager = self._cache_manager

        for queued_model in self._pending_model_bumps:
            await manager.bump_model_version_async(queued_model)
        self._pending_model_bumps.clear()

        for queued_entry in self._pending_invalidations:
            # queued_entry is (model_name, entity_id, bind_group).
            await manager.invalidate_entity_async(*queued_entry)
        self._pending_invalidations.clear()
523+
524+
525+
def get_cache_tracker(
    session: "Union[Session, AsyncSession, scoped_session[Session], async_scoped_session[AsyncSession]]",
    cache_manager: Optional["CacheManager"] = None,
    create: bool = True,
) -> Optional["CacheInvalidationTracker"]:
    """Return the session's cache invalidation tracker, creating it on demand.

    The tracker lives in ``session.info`` under ``_CACHE_TRACKER_KEY``, so
    each session instance has its own tracker.

    Args:
        session: The SQLAlchemy session (sync or async, including scoped sessions).
        cache_manager: The CacheManager to bind a newly created tracker to.
            No tracker is created when this is ``None``.
        create: Whether a missing tracker may be created.

    Returns:
        The existing or newly created tracker, or ``None`` when there is no
        tracker and one could not (or should not) be created.
    """
    existing: Optional[CacheInvalidationTracker] = session.info.get(_CACHE_TRACKER_KEY)

    # Nothing to build: a tracker already exists, creation was not requested,
    # or there is no manager to attach a new tracker to.
    if existing is not None or not create or cache_manager is None:
        return existing

    fresh = CacheInvalidationTracker(cache_manager)
    session.info[_CACHE_TRACKER_KEY] = fresh
    return fresh
548+
549+
550+
class CacheInvalidationListener:
    """Manages cache invalidation during SQLAlchemy Session transactions.

    This listener hooks into the SQLAlchemy Session event lifecycle to
    handle cache invalidation in a transaction-safe manner.

    How it Works:

    1. **Event Registration (`setup_cache_listeners`):**
       Registers `after_commit` and `after_rollback` listeners globally
       on the Session class.

    2. **Tracking Changes:**
       During mutations (add, update, delete), repositories call
       `get_cache_tracker()` and add invalidations via `add_invalidation()`.

    3. **Processing (`after_commit`):**
       After a successful commit the pending invalidations are processed,
       inline when no event loop is running and as a background task
       otherwise, and the tracker is removed from `session.info`.

    4. **Discarding (`after_rollback`):**
       On rollback, pending invalidations are discarded without processing.
    """

    @staticmethod
    def _read_execution_options(owner: Any) -> "Optional[dict[str, Any]]":
        """Return the execution options exposed by ``owner``, if any.

        ``get_execution_options()`` is preferred when present. On SQLAlchemy
        engines, ``execution_options()`` returns a new engine copy rather
        than the options, so calling it can never yield a mapping. Objects
        that only expose ``execution_options`` (as a dict or as a callable
        returning one) are still supported.

        Args:
            owner: A session, engine, or any object that may carry options.

        Returns:
            The options, or ``None`` when none could be read.
        """
        getter = getattr(owner, "get_execution_options", None)
        source = getter if callable(getter) else getattr(owner, "execution_options", None)
        if source is None:
            return None
        if not callable(source):
            return source

        try:
            result = source()
        except Exception as e:
            # Best effort: an unreadable source must not break the commit hook.
            logger.debug("Error calling execution_options source: %s", e)
            return None
        return result if isinstance(result, dict) else None

    @staticmethod
    def _finalize_invalidation_task(task: "asyncio.Task[Any]") -> None:
        """Release a finished invalidation task and report its failure, if any.

        Retrieving the exception here stops asyncio from emitting
        "Task exception was never retrieved" when the task is collected.
        """
        _active_cache_operations.discard(task)
        if task.cancelled():
            return
        error = task.exception()
        if error is not None:
            logger.warning("Cache invalidation failed after commit: %s", error, exc_info=error)

    @classmethod
    def _is_listener_enabled(cls, session: "Session") -> bool:
        """Check if cache listener is enabled for this session.

        ``session.info["enable_cache_listener"]`` wins when set. Otherwise
        the first of the bind, the bind's ``sync_engine`` and the session
        itself whose execution options contain ``enable_cache_listener``
        decides. The listener is enabled when nothing says otherwise.
        """
        session_info = getattr(session, "info", {})
        if "enable_cache_listener" in session_info:
            return bool(session_info["enable_cache_listener"])

        # Candidates in precedence order.
        owners: list[Any] = []
        if session.bind:
            owners.append(session.bind)
            sync_engine = getattr(session.bind, "sync_engine", None)
            if sync_engine:
                owners.append(sync_engine)
        owners.append(session)

        for owner in owners:
            options = cls._read_execution_options(owner)
            if options is not None and "enable_cache_listener" in options:
                return bool(options["enable_cache_listener"])

        return True

    @classmethod
    def after_commit(cls, session: "Session") -> None:
        """Process cache invalidations after a successful commit."""
        if not cls._is_listener_enabled(session):
            return

        tracker = get_cache_tracker(session, create=False)
        if tracker is None:
            return

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No running loop: sync usage, perform invalidation inline.
            tracker.commit()
        else:
            # Running loop: schedule async invalidation so commit doesn't block.
            task = asyncio.create_task(tracker.commit_async())
            # Hold a strong reference until the done-callback releases it.
            _active_cache_operations.add(task)
            task.add_done_callback(cls._finalize_invalidation_task)
        session.info.pop(_CACHE_TRACKER_KEY, None)

    @classmethod
    def after_rollback(cls, session: "Session") -> None:
        """Discard pending cache invalidations after a rollback."""
        tracker = get_cache_tracker(session, create=False)
        if tracker is None:
            return
        tracker.rollback()
        session.info.pop(_CACHE_TRACKER_KEY, None)
639+
640+
641+
def setup_cache_listeners() -> None:
    """Attach the cache invalidation handlers to the ``Session`` class.

    Intended to be called during application start-up so that repositories
    using a CacheManager get automatic invalidation. Any handler that
    SQLAlchemy reports as already registered is skipped.
    """
    from sqlalchemy.event import contains
    from sqlalchemy.orm import Session

    handlers = (
        ("after_commit", CacheInvalidationListener.after_commit),
        ("after_rollback", CacheInvalidationListener.after_rollback),
    )

    for hook_name, handler in handlers:
        if contains(Session, hook_name, handler):
            continue
        event.listen(Session, hook_name, handler)

    logger.debug("Cache invalidation listeners registered")
660+
661+
447662
# Existing listener (keep it)
448663
def touch_updated_timestamp(session: "Session", *_: Any) -> None: # pragma: no cover
449664
"""Set timestamp on update.

advanced_alchemy/cache/__init__.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
"""Dogpile.cache integration for Advanced Alchemy.
2+
3+
This module provides optional caching support for SQLAlchemy repositories
4+
using dogpile.cache. It supports multiple cache backends (Redis, Memcached,
5+
file, memory) and provides automatic cache invalidation on model changes.
6+
7+
Features:
8+
- Multiple backend support (Redis, Memcached, file, memory, null)
9+
- Commit-aware cache invalidation via SQLAlchemy events
10+
- Version-based invalidation for list queries
11+
- JSON serialization for cached models (configurable)
12+
- Graceful degradation when dogpile.cache is not installed
13+
- Per-process singleflight to reduce stampedes on cache miss
14+
15+
Example:
16+
Using the config system (recommended)::
17+
18+
from advanced_alchemy.cache import CacheConfig
19+
from advanced_alchemy.config import SQLAlchemyAsyncConfig
20+
from advanced_alchemy.repository import (
21+
SQLAlchemyAsyncRepository,
22+
)
23+
24+
# Configure database with caching
25+
db_config = SQLAlchemyAsyncConfig(
26+
connection_string="sqlite+aiosqlite:///app.db",
27+
cache_config=CacheConfig(
28+
backend="dogpile.cache.memory",
29+
expiration_time=300,
30+
),
31+
)
32+
33+
# Cache listeners are auto-registered, cache_manager is stored
34+
# in session.info and auto-retrieved by repositories.
35+
36+
37+
class UserRepository(SQLAlchemyAsyncRepository[User]):
38+
model_type = User
39+
40+
41+
async with db_config.get_session() as session:
42+
repo = UserRepository(session=session)
43+
user = await repo.get(
44+
1
45+
) # First call hits DB and caches
46+
user = await repo.get(
47+
1
48+
) # Second call returns cached result
49+
50+
Redis configuration::
51+
52+
cache_config = CacheConfig(
53+
backend="dogpile.cache.redis",
54+
expiration_time=3600,
55+
arguments={
56+
"host": "localhost",
57+
"port": 6379,
58+
"db": 0,
59+
"distributed_lock": True,
60+
},
61+
)
62+
63+
Note:
64+
This module requires the optional ``dogpile.cache`` dependency.
65+
Install with: ``pip install advanced-alchemy[dogpile]``
66+
67+
Without dogpile.cache installed, the cache manager will use a
68+
NullRegion that provides the same interface but doesn't cache.
69+
70+
Manual Setup:
71+
If not using the config system, call ``setup_cache_listeners()``
72+
during application initialization and pass cache_manager explicitly::
73+
74+
from advanced_alchemy.cache import (
75+
CacheConfig,
76+
CacheManager,
77+
setup_cache_listeners,
78+
)
79+
80+
cache_manager = CacheManager(
81+
CacheConfig(backend="dogpile.cache.memory")
82+
)
83+
setup_cache_listeners()
84+
85+
repo = UserRepository(
86+
session=session, cache_manager=cache_manager
87+
)
88+
"""
89+
90+
from advanced_alchemy._listeners import setup_cache_listeners
91+
from advanced_alchemy.cache.config import CacheConfig
92+
from advanced_alchemy.cache.manager import DOGPILE_CACHE_INSTALLED, CacheManager
93+
from advanced_alchemy.cache.serializers import default_deserializer, default_serializer
94+
95+
# Public API of the cache package: the names re-exported by the imports above.
__all__ = (
    "DOGPILE_CACHE_INSTALLED",
    "CacheConfig",
    "CacheManager",
    "default_deserializer",
    "default_serializer",
    "setup_cache_listeners",
)

0 commit comments

Comments
 (0)