|
11 | 11 | from sqlalchemy.inspection import inspect |
12 | 12 |
|
13 | 13 | if TYPE_CHECKING: |
14 | | - from sqlalchemy.orm import Session, UOWTransaction |
| 14 | + from sqlalchemy.ext.asyncio import AsyncSession, async_scoped_session |
| 15 | + from sqlalchemy.orm import Session, UOWTransaction, scoped_session |
15 | 16 | from sqlalchemy.orm.state import InstanceState |
16 | 17 |
|
| 18 | + from advanced_alchemy.cache import CacheManager |
17 | 19 | from advanced_alchemy.types.file_object import FileObjectSessionTracker, StorageRegistry |
18 | 20 |
|
19 | 21 | _active_file_operations: set[asyncio.Task[Any]] = set() |
20 | 22 | """Stores active file operations to prevent them from being garbage collected.""" |
| 23 | +_active_cache_operations: set[asyncio.Task[Any]] = set() |
| 24 | +"""Stores active cache invalidation operations to prevent them from being garbage collected.""" |
21 | 25 | # Context variable to hold the session tracker instance for the current session context |
22 | 26 | _current_session_tracker: contextvars.ContextVar[Optional["FileObjectSessionTracker"]] = contextvars.ContextVar( |
23 | 27 | "_current_session_tracker", |
@@ -444,6 +448,217 @@ def setup_file_object_listeners(registry: Optional["StorageRegistry"] = None) -> |
444 | 448 | set_async_context(False) |
445 | 449 |
|
446 | 450 |
|
| 451 | +# Cache invalidation support |
| 452 | +_CACHE_TRACKER_KEY = "_aa_cache_tracker" |
| 453 | + |
| 454 | + |
| 455 | +class CacheInvalidationTracker: |
| 456 | + """Tracks pending cache invalidations for a session transaction. |
| 457 | +
|
| 458 | + This tracker collects entity invalidations during a transaction and |
| 459 | + processes them only after a successful commit. On rollback, the |
| 460 | + pending invalidations are discarded. |
| 461 | +
|
| 462 | + Note: |
| 463 | + Model version bumps are also deferred to commit to ensure rollbacks |
| 464 | + don't invalidate list caches when no DB change occurred. |
| 465 | + """ |
| 466 | + |
| 467 | + __slots__ = ("_cache_manager", "_pending_invalidations", "_pending_model_bumps") |
| 468 | + |
| 469 | + def __init__(self, cache_manager: "CacheManager") -> None: |
| 470 | + self._cache_manager = cache_manager |
| 471 | + self._pending_invalidations: list[tuple[str, Any, Optional[str]]] = [] |
| 472 | + self._pending_model_bumps: set[str] = set() |
| 473 | + |
| 474 | + def add_invalidation(self, model_name: str, entity_id: Any, bind_group: Optional[str] = None) -> None: |
| 475 | + """Queue an entity for cache invalidation. |
| 476 | +
|
| 477 | + The actual invalidation and model version bump are deferred until |
| 478 | + commit() is called, ensuring rollbacks don't affect the cache. |
| 479 | +
|
| 480 | + Args: |
| 481 | + model_name: The model/table name. |
| 482 | + entity_id: The entity's primary key value. |
| 483 | + bind_group: Optional routing group for multi-master configurations. |
| 484 | + When provided, only the cache entry for that bind_group is |
| 485 | + invalidated. |
| 486 | + """ |
| 487 | + self._pending_invalidations.append((model_name, entity_id, bind_group)) |
| 488 | + # Queue model version bump for list query invalidation (deferred to commit) |
| 489 | + self._pending_model_bumps.add(model_name) |
| 490 | + |
| 491 | + def commit(self) -> None: |
| 492 | + """Process all pending invalidations after successful commit.""" |
| 493 | + # First bump model versions for list query invalidation |
| 494 | + for model_name in self._pending_model_bumps: |
| 495 | + self._cache_manager.bump_model_version_sync(model_name) |
| 496 | + self._pending_model_bumps.clear() |
| 497 | + |
| 498 | + # Then invalidate individual entities |
| 499 | + for model_name, entity_id, bind_group in self._pending_invalidations: |
| 500 | + self._cache_manager.invalidate_entity_sync(model_name, entity_id, bind_group) |
| 501 | + self._pending_invalidations.clear() |
| 502 | + |
| 503 | + def rollback(self) -> None: |
| 504 | + """Discard pending invalidations on rollback.""" |
| 505 | + self._pending_invalidations.clear() |
| 506 | + self._pending_model_bumps.clear() |
| 507 | + |
| 508 | + async def commit_async(self) -> None: |
| 509 | + """Process all pending invalidations after successful commit (async-safe). |
| 510 | +
|
| 511 | + This method performs cache I/O using the CacheManager async APIs so that |
| 512 | + dogpile backends (often sync network clients) never block the event loop. |
| 513 | + """ |
| 514 | + # First bump model versions for list query invalidation |
| 515 | + for model_name in self._pending_model_bumps: |
| 516 | + await self._cache_manager.bump_model_version_async(model_name) |
| 517 | + self._pending_model_bumps.clear() |
| 518 | + |
| 519 | + # Then invalidate individual entities |
| 520 | + for model_name, entity_id, bind_group in self._pending_invalidations: |
| 521 | + await self._cache_manager.invalidate_entity_async(model_name, entity_id, bind_group) |
| 522 | + self._pending_invalidations.clear() |
| 523 | + |
| 524 | + |
| 525 | +def get_cache_tracker( |
| 526 | + session: "Union[Session, AsyncSession, scoped_session[Session], async_scoped_session[AsyncSession]]", |
| 527 | + cache_manager: Optional["CacheManager"] = None, |
| 528 | + create: bool = True, |
| 529 | +) -> Optional["CacheInvalidationTracker"]: |
| 530 | + """Get or create a cache invalidation tracker for the session. |
| 531 | +
|
| 532 | + The tracker is stored on session.info to ensure proper scoping |
| 533 | + per session instance and avoid ContextVar collisions. |
| 534 | +
|
| 535 | + Args: |
| 536 | + session: The SQLAlchemy session instance (sync or async, including scoped sessions). |
| 537 | + cache_manager: The CacheManager instance (required if create=True). |
| 538 | + create: Whether to create a new tracker if one doesn't exist. |
| 539 | +
|
| 540 | + Returns: |
| 541 | + The cache tracker or None if not available. |
| 542 | + """ |
| 543 | + tracker: Optional[CacheInvalidationTracker] = session.info.get(_CACHE_TRACKER_KEY) |
| 544 | + if tracker is None and create and cache_manager is not None: |
| 545 | + tracker = CacheInvalidationTracker(cache_manager) |
| 546 | + session.info[_CACHE_TRACKER_KEY] = tracker |
| 547 | + return tracker |
| 548 | + |
| 549 | + |
| 550 | +class CacheInvalidationListener: |
| 551 | + """Manages cache invalidation during SQLAlchemy Session transactions. |
| 552 | +
|
| 553 | + This listener hooks into the SQLAlchemy Session event lifecycle to |
| 554 | + handle cache invalidation in a transaction-safe manner. |
| 555 | +
|
| 556 | + How it Works: |
| 557 | +
|
| 558 | + 1. **Event Registration (`setup_cache_listeners`):** |
| 559 | + Registers `after_commit` and `after_rollback` listeners globally |
| 560 | + on the Session class. |
| 561 | +
|
| 562 | + 2. **Tracking Changes:** |
| 563 | + During mutations (add, update, delete), repositories call |
| 564 | + `get_cache_tracker()` and add invalidations via `add_invalidation()`. |
| 565 | +
|
| 566 | + 3. **Processing (`after_commit`):** |
| 567 | + After successful commit, all pending invalidations are processed |
| 568 | + and the tracker is cleared. |
| 569 | +
|
| 570 | + 4. **Discarding (`after_rollback`):** |
| 571 | + On rollback, pending invalidations are discarded without processing. |
| 572 | + """ |
| 573 | + |
| 574 | + @classmethod |
| 575 | + def _is_listener_enabled(cls, session: "Session") -> bool: |
| 576 | + """Check if cache listener is enabled for this session.""" |
| 577 | + enable_listener = True |
| 578 | + |
| 579 | + session_info = getattr(session, "info", {}) |
| 580 | + if "enable_cache_listener" in session_info: |
| 581 | + return bool(session_info["enable_cache_listener"]) |
| 582 | + |
| 583 | + options_sources: list[Optional[Union[Callable[[], dict[str, Any]], dict[str, Any]]]] = [] |
| 584 | + if session.bind: |
| 585 | + options_sources.append(getattr(session.bind, "execution_options", None)) |
| 586 | + sync_engine = getattr(session.bind, "sync_engine", None) |
| 587 | + if sync_engine: |
| 588 | + options_sources.append(getattr(sync_engine, "execution_options", None)) |
| 589 | + options_sources.append(getattr(session, "execution_options", None)) |
| 590 | + |
| 591 | + for options_source in options_sources: |
| 592 | + if options_source is None: |
| 593 | + continue |
| 594 | + |
| 595 | + options: Optional[dict[str, Any]] = None |
| 596 | + if callable(options_source): |
| 597 | + try: |
| 598 | + result = options_source() |
| 599 | + if isinstance(result, dict): # pyright: ignore |
| 600 | + options = result |
| 601 | + except Exception as e: |
| 602 | + logger.debug("Error calling execution_options source: %s", e) |
| 603 | + else: |
| 604 | + options = options_source |
| 605 | + |
| 606 | + if options is not None and "enable_cache_listener" in options: |
| 607 | + enable_listener = bool(options["enable_cache_listener"]) |
| 608 | + break |
| 609 | + |
| 610 | + return enable_listener |
| 611 | + |
| 612 | + @classmethod |
| 613 | + def after_commit(cls, session: "Session") -> None: |
| 614 | + """Process cache invalidations after a successful commit.""" |
| 615 | + if not cls._is_listener_enabled(session): |
| 616 | + return |
| 617 | + |
| 618 | + tracker = get_cache_tracker(session, create=False) |
| 619 | + if tracker: |
| 620 | + try: |
| 621 | + asyncio.get_running_loop() |
| 622 | + except RuntimeError: |
| 623 | + # No running loop: sync usage, perform invalidation inline. |
| 624 | + tracker.commit() |
| 625 | + else: |
| 626 | + # Running loop: schedule async invalidation so commit doesn't block. |
| 627 | + task = asyncio.create_task(tracker.commit_async()) |
| 628 | + _active_cache_operations.add(task) |
| 629 | + task.add_done_callback(_active_cache_operations.discard) |
| 630 | + session.info.pop(_CACHE_TRACKER_KEY, None) |
| 631 | + |
| 632 | + @classmethod |
| 633 | + def after_rollback(cls, session: "Session") -> None: |
| 634 | + """Discard pending cache invalidations after a rollback.""" |
| 635 | + tracker = get_cache_tracker(session, create=False) |
| 636 | + if tracker: |
| 637 | + tracker.rollback() |
| 638 | + session.info.pop(_CACHE_TRACKER_KEY, None) |
| 639 | + |
| 640 | + |
| 641 | +def setup_cache_listeners() -> None: |
| 642 | + """Register cache invalidation event listeners globally. |
| 643 | +
|
| 644 | + This should be called once during application initialization to enable |
| 645 | + automatic cache invalidation for repositories using a CacheManager. |
| 646 | + """ |
| 647 | + from sqlalchemy.event import contains |
| 648 | + from sqlalchemy.orm import Session |
| 649 | + |
| 650 | + listeners = { |
| 651 | + "after_commit": CacheInvalidationListener.after_commit, |
| 652 | + "after_rollback": CacheInvalidationListener.after_rollback, |
| 653 | + } |
| 654 | + |
| 655 | + for event_name, listener_func in listeners.items(): |
| 656 | + if not contains(Session, event_name, listener_func): |
| 657 | + event.listen(Session, event_name, listener_func) |
| 658 | + |
| 659 | + logger.debug("Cache invalidation listeners registered") |
| 660 | + |
| 661 | + |
447 | 662 | # Existing listener (keep it) |
448 | 663 | def touch_updated_timestamp(session: "Session", *_: Any) -> None: # pragma: no cover |
449 | 664 | """Set timestamp on update. |
|
0 commit comments