diff --git a/CHANGELOG.md b/CHANGELOG.md index 9fb6bb2e517..348c8bc49f7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,7 +19,12 @@ Changes can also be flagged with a GitHub label for tracking purposes. The URL o - https://github.com/ethyca/fides/labels/high-risk: to indicate that a change is a "high-risk" change that could potentially lead to unanticipated regressions or degradations - https://github.com/ethyca/fides/labels/db-migration: to indicate that a given change includes a DB migration -## [Unreleased](https://github.com/ethyca/fides/compare/2.86.1..main) +## [Unreleased](https://github.com/ethyca/fides/compare/2.86.2..main) + +## [2.86.2](https://github.com/ethyca/fides/compare/2.86.1..2.86.2) + +### Changed +- Reduced per-request overhead in logging and JWE token decryption to improve API performance [#8284](https://github.com/ethyca/fides/pull/8284) ## [2.86.1](https://github.com/ethyca/fides/compare/2.86.0..2.86.1) diff --git a/src/fides/api/asgi_middleware.py b/src/fides/api/asgi_middleware.py index 194e5ca912a..8daa52f6088 100644 --- a/src/fides/api/asgi_middleware.py +++ b/src/fides/api/asgi_middleware.py @@ -199,23 +199,26 @@ async def handle_http(self, scope: Scope, receive: Receive, send: Send) -> None: status_code = 500 - try: - await self.app(scope, receive, wrapped_send) - status_code = get_status() - except Exception as e: - logger.exception(f"Unhandled exception processing request: '{e}'") - await self.send_response(send_with_headers, 500, b"Internal Server Error") - status_code = 500 - - handler_time = round((perf_counter() - start_time) * 1000, 3) - - logger.bind( - method=method, - status_code=status_code, - handler_time=f"{handler_time}ms", - path=path, - fides_client=fides_client, - ).info("Request received") + with logger.contextualize(request_id=request_id): + try: + await self.app(scope, receive, wrapped_send) + status_code = get_status() + except Exception as e: + logger.exception(f"Unhandled exception processing request: '{e}'") + await self.send_response( + send_with_headers, 500, b"Internal Server Error" + ) + status_code = 500 + + handler_time = round((perf_counter() - start_time) * 1000, 3) + + logger.bind( + method=method, + status_code=status_code, + handler_time=f"{handler_time}ms", + path=path, + fides_client=fides_client, + ).info("Request received") class AuditLogMiddleware(BaseASGIMiddleware): diff --git a/src/fides/api/oauth/jwt.py b/src/fides/api/oauth/jwt.py index 1fb7eaace02..eab696a7eec 100644 --- a/src/fides/api/oauth/jwt.py +++ b/src/fides/api/oauth/jwt.py @@ -1,24 +1,41 @@ +from functools import lru_cache + from joserfc import jwe from joserfc.jwk import OctKey JWT_ENCRYPTION_ALGORITHM = "A256GCM" -def generate_jwe(payload: str, encryption_key: str, encoding: str = "UTF-8") -> str: - """Generates a JWE with the provided payload. +# Cache size is set to 3 somewhat arbitrarily. We currently don't allow +# encryption key rotation so we should effectively only have a single +# value in this cache. When we implement key rotation, there should only +# ever be 2 keys in use at a time, so 3 should be enough. +# This will need to be re-thought if we ever support multiple different +# encryption keys for encrypting tokens. +@lru_cache(maxsize=3) +def _get_oct_key(key_bytes: bytes) -> OctKey: + """Cache the OctKey so it's built once per distinct key value.""" + return OctKey.import_key(key_bytes) - Returns a string representation. - """ + +def _to_oct_key(encryption_key: str) -> OctKey: key_bytes = ( encryption_key.encode("utf-8") if isinstance(encryption_key, str) else encryption_key ) - key = OctKey.import_key(key_bytes) + return _get_oct_key(key_bytes) + + +def generate_jwe(payload: str, encryption_key: str, encoding: str = "UTF-8") -> str: + """Generates a JWE with the provided payload. + + Returns a string representation. + """ return jwe.encrypt_compact( {"alg": "dir", "enc": JWT_ENCRYPTION_ALGORITHM}, payload.encode(encoding), - key, + _to_oct_key(encryption_key), ) @@ -33,13 +50,7 @@ def decrypt_jwe(token: str, encryption_key: str, encoding: str = "UTF-8") -> str Returns: The decrypted payload as a string. """ - key_bytes = ( - encryption_key.encode("utf-8") - if isinstance(encryption_key, str) - else encryption_key - ) - key = OctKey.import_key(key_bytes) - result = jwe.decrypt_compact(token, key) + result = jwe.decrypt_compact(token, _to_oct_key(encryption_key)) if result.plaintext is None: raise ValueError("JWE decryption produced no plaintext") return result.plaintext.decode(encoding) diff --git a/src/fides/api/oauth/utils.py b/src/fides/api/oauth/utils.py index f290235af37..61f287b28b6 100644 --- a/src/fides/api/oauth/utils.py +++ b/src/fides/api/oauth/utils.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import json from datetime import datetime, timezone from functools import update_wrapper @@ -702,7 +703,11 @@ async def extract_token_and_load_client_async( raise AuthenticationError(detail="Authentication Failure") try: - token_data = json.loads(extract_payload(authorization, get_encryption_key())) + loop = asyncio.get_running_loop() + payload = await loop.run_in_executor( + None, extract_payload, authorization, get_encryption_key() + ) + token_data = await loop.run_in_executor(None, json.loads, payload) except (JoseError, ValueError) as exc: logger.debug("Unable to parse auth token.") raise AuthorizationError(detail="Not Authorized for this action") from exc diff --git a/src/fides/api/tasks/__init__.py b/src/fides/api/tasks/__init__.py index 9951b615874..db98f462fc1 100644 --- a/src/fides/api/tasks/__init__.py +++ b/src/fides/api/tasks/__init__.py @@ -1,3 +1,4 @@ +import contextvars from typing import Any, ContextManager, Dict, List, Optional import celery_redis_cluster_backend # type: ignore[import-untyped] # noqa: F401 - registers redis+cluster/rediss+cluster backends @@ -205,6 +206,11 @@ def _propagate_request_id(headers: Dict[str, Any], **kwargs: Any) -> None: headers["request_id"] = request_id +_task_log_context: contextvars.ContextVar = contextvars.ContextVar( + "_task_log_context", default=None +) + + @task_prerun.connect def _restore_request_id(task: Task, **kwargs: Any) -> None: """Restore request_id from the task headers into the worker's ContextVar. @@ -215,6 +221,9 @@ def _restore_request_id(task: Task, **kwargs: Any) -> None: request_id = getattr(task.request, "request_id", None) if request_id is not None: set_request_id(request_id) + ctx = logger.contextualize(request_id=request_id) + ctx.__enter__() + _task_log_context.set(ctx) @task_postrun.connect @@ -225,6 +234,10 @@ def _clear_request_id(**kwargs: Any) -> None: a request_id from Task A would leak into Task B if Task B was dispatched without a request_id header. """ + ctx = _task_log_context.get() + if ctx is not None: + ctx.__exit__(None, None, None) + _task_log_context.set(None) set_request_id(None) diff --git a/src/fides/api/util/logger.py b/src/fides/api/util/logger.py index 260e9e9c404..4db8004f798 100644 --- a/src/fides/api/util/logger.py +++ b/src/fides/api/util/logger.py @@ -16,7 +16,6 @@ from loguru import logger from loguru._handler import Message -from fides.api.request_context import get_request_context from fides.api.schemas.privacy_request import LogEntry, PrivacyRequestSource from fides.api.util.sqlalchemy_filter import SQLAlchemyGeneratedFilter from fides.config import CONFIG, FidesConfig @@ -26,11 +25,11 @@ MASKED = "MASKED" -# Keys injected by the Loguru patcher into every log record's extra dict. -# These are excluded when deciding whether a record has "custom" extra context -# (e.g. from logger.bind()) so that the log format only appends the extra -# block when the caller explicitly added context beyond the automatic fields. -_PATCHER_INJECTED_KEYS = {"request_id"} +# Keys injected into log records by logger.contextualize() in the ASGI +# middleware. Excluded when deciding whether a record has "custom" extra +# context so that the log format only appends the extra block when the +# caller explicitly added context via logger.bind(). +_CONTEXTUALIZED_KEYS = {"request_id"} def _safe_stdout_sink(message: str) -> None: @@ -194,7 +193,7 @@ def create_handler_dicts( def has_custom_extra(log_record: Dict) -> bool: """Check if log record has custom extra context beyond Loguru's defaults.""" extra = log_record.get("extra", {}) - return bool(extra.keys() - _PATCHER_INJECTED_KEYS) + return bool(extra.keys() - _CONTEXTUALIZED_KEYS) # Helper to filter logs without custom extra def filter_standard(log_record: Dict) -> bool: @@ -218,20 +217,6 @@ def filter_standard(log_record: Dict) -> bool: return [standard_dict, extra_dict] -def _inject_request_context(record: Dict[str, Any]) -> None: - """Loguru patcher that injects request-scoped context into every log record. - - Reads the current ``request_id`` from the ``RequestContext`` ContextVar and - adds it to ``record["extra"]``. Because the patcher runs on *every* log - record (including those from stdlib loggers intercepted by - ``InterceptHandler``), all logs emitted during a request are automatically - tagged without any changes to individual call-sites. - """ - ctx = get_request_context() - if ctx.request_id is not None: - record["extra"]["request_id"] = ctx.request_id - - def setup(config: FidesConfig) -> None: """ Configures logging with the appropriate sink based on configuration. @@ -278,7 +263,7 @@ def setup(config: FidesConfig) -> None: ) ) - logger.configure(handlers=handlers, patcher=_inject_request_context) # type: ignore[arg-type] + logger.configure(handlers=handlers) # type: ignore[arg-type] # Add InterceptHandler to root logger to capture standard library logs # This intercepts logs from SQLAlchemy, Alembic, Celery, etc. diff --git a/src/fides/api/v1/endpoints/health.py b/src/fides/api/v1/endpoints/health.py index 8f7cf1d0451..9bcc50730e1 100644 --- a/src/fides/api/v1/endpoints/health.py +++ b/src/fides/api/v1/endpoints/health.py @@ -154,20 +154,26 @@ async def database_health(db: Session = Depends(get_db)) -> Dict: pools: Dict[str, PoolStatus] = {} async_readonly_pool_prewarmed: Optional[bool] = None - migration_health, current_revision = get_db_health( - db_cred_provider.get_database_url(), db=db + loop = asyncio.get_running_loop() + + migration_health, current_revision = await loop.run_in_executor( + None, get_db_health, db_cred_provider.get_database_url(), db ) # Primary sync pool (already checked out by dependency-injected session). - pools["api_sync_primary"] = PoolStatus(health=_check_sync_session(db)) + pools["api_sync_primary"] = PoolStatus( + health=await loop.run_in_executor(None, _check_sync_session, db) + ) # Optional sync readonly pool. if CONFIG.database.sqlalchemy_readonly_database_uri: readonly_db: Optional[Session] = None try: - readonly_db = get_readonly_api_session() + readonly_db = await loop.run_in_executor(None, get_readonly_api_session) pools["api_sync_readonly"] = PoolStatus( - health=_check_sync_session(readonly_db) + health=await loop.run_in_executor( + None, _check_sync_session, readonly_db + ) ) except Exception as error: # pylint: disable=broad-except logger.error( diff --git a/tests/fides/api/middleware/test_request_id.py b/tests/fides/api/middleware/test_request_id.py index 3d9dca6c4ae..0ae294aef17 100644 --- a/tests/fides/api/middleware/test_request_id.py +++ b/tests/fides/api/middleware/test_request_id.py @@ -18,6 +18,7 @@ _clear_request_id, _propagate_request_id, _restore_request_id, + _task_log_context, ) from .conftest import ( @@ -29,9 +30,14 @@ @pytest.fixture(autouse=True) def _clean_request_context(): - """Reset request context before and after each test.""" + """Reset request context and Loguru log context before and after each test.""" reset_request_context() yield + # Clean up any leaked Loguru contextualize from Celery signal tests + ctx = _task_log_context.get() + if ctx is not None: + ctx.__exit__(None, None, None) + _task_log_context.set(None) reset_request_context() @@ -193,7 +199,7 @@ async def test_rejects_oversized_request_id(self, mock_asgi_app): assert UUID_PATTERN.match(request_id_headers[0]) async def test_request_id_in_log_output(self, mock_asgi_app, loguru_caplog): - """Request ID appears in log records via the patcher.""" + """Request ID appears in log records via logger.contextualize().""" app, _ = mock_asgi_app() middleware = LogRequestMiddleware(app) @@ -268,24 +274,32 @@ async def test_concurrent_requests_get_different_ids(self, mock_asgi_app): assert id1 != id2 -class TestLoggerPatcher: - """Tests for the Loguru patcher that injects request_id.""" +class TestLoggerContextualize: + """Tests for request_id injection via logger.contextualize().""" - def test_patcher_injects_request_id(self, loguru_caplog): - """When request_id is set, it appears in log records.""" - set_request_id("patcher-test-abc") - logger.info("test message") + def test_contextualize_injects_request_id(self, loguru_caplog): + """When logger.contextualize is active, request_id appears in log records.""" + with logger.contextualize(request_id="ctx-test-abc"): + logger.info("test message") record = loguru_caplog.records[-1] - assert record.extra.get("request_id") == "patcher-test-abc" + assert record.extra.get("request_id") == "ctx-test-abc" - def test_patcher_omits_request_id_when_none(self, loguru_caplog): - """When no request_id is set, it's not added to log records.""" + def test_no_request_id_outside_context(self, loguru_caplog): + """When no contextualize is active, request_id is not in log records.""" logger.info("test message without context") record = loguru_caplog.records[-1] assert "request_id" not in record.extra + def test_set_request_id_alone_does_not_inject_into_logs(self, loguru_caplog): + """set_request_id without contextualize does not inject into log records.""" + set_request_id("only-context-var") + logger.info("test message") + + record = loguru_caplog.records[-1] + assert "request_id" not in record.extra + class TestCelerySignals: """Tests for Celery request_id propagation signals.""" @@ -325,6 +339,23 @@ def test_restore_request_id_skips_when_absent(self): assert get_request_id() is None + def test_restore_injects_request_id_into_logs(self, loguru_caplog): + """task_prerun sets up logger.contextualize so logs get request_id.""" + mock_task = type( + "MockTask", + (), + {"request": type("MockRequest", (), {"request_id": "celery-log-789"})()}, + )() + + _restore_request_id(task=mock_task) + logger.info("task log message") + + record = loguru_caplog.records[-1] + assert record.extra.get("request_id") == "celery-log-789" + + # Cleanup + _clear_request_id() + def test_clear_request_id_after_task(self): """task_postrun clears only the request_id, not other context.""" set_request_context(request_id="should-be-cleared", user_id="keep-me") @@ -332,3 +363,18 @@ def test_clear_request_id_after_task(self): assert get_request_id() is None assert get_user_id() == "keep-me" + + def test_clear_removes_log_context(self, loguru_caplog): + """task_postrun cleans up logger.contextualize so logs no longer get request_id.""" + mock_task = type( + "MockTask", + (), + {"request": type("MockRequest", (), {"request_id": "celery-clear-test"})()}, + )() + + _restore_request_id(task=mock_task) + _clear_request_id() + + logger.info("after task") + record = loguru_caplog.records[-1] + assert "request_id" not in record.extra