Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@ Changes can also be flagged with a GitHub label for tracking purposes. The URL o
- https://github.com/ethyca/fides/labels/high-risk: to indicate that a change is a "high-risk" change that could potentially lead to unanticipated regressions or degradations
- https://github.com/ethyca/fides/labels/db-migration: to indicate that a given change includes a DB migration

## [Unreleased](https://github.com/ethyca/fides/compare/2.86.1..main)
## [Unreleased](https://github.com/ethyca/fides/compare/2.86.2..main)

## [2.86.2](https://github.com/ethyca/fides/compare/2.86.1..2.86.2)

### Changed
- Reduced per-request overhead in logging and JWE token decryption to improve API performance [#8284](https://github.com/ethyca/fides/pull/8284)

## [2.86.1](https://github.com/ethyca/fides/compare/2.86.0..2.86.1)

Expand Down
37 changes: 20 additions & 17 deletions src/fides/api/asgi_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,23 +199,26 @@ async def handle_http(self, scope: Scope, receive: Receive, send: Send) -> None:

status_code = 500

try:
await self.app(scope, receive, wrapped_send)
status_code = get_status()
except Exception as e:
logger.exception(f"Unhandled exception processing request: '{e}'")
await self.send_response(send_with_headers, 500, b"Internal Server Error")
status_code = 500

handler_time = round((perf_counter() - start_time) * 1000, 3)

logger.bind(
method=method,
status_code=status_code,
handler_time=f"{handler_time}ms",
path=path,
fides_client=fides_client,
).info("Request received")
with logger.contextualize(request_id=request_id):
try:
await self.app(scope, receive, wrapped_send)
status_code = get_status()
except Exception as e:
logger.exception(f"Unhandled exception processing request: '{e}'")
await self.send_response(
send_with_headers, 500, b"Internal Server Error"
)
status_code = 500

handler_time = round((perf_counter() - start_time) * 1000, 3)

logger.bind(
method=method,
status_code=status_code,
handler_time=f"{handler_time}ms",
path=path,
fides_client=fides_client,
).info("Request received")


class AuditLogMiddleware(BaseASGIMiddleware):
Expand Down
37 changes: 24 additions & 13 deletions src/fides/api/oauth/jwt.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,41 @@
from functools import lru_cache

from joserfc import jwe
from joserfc.jwk import OctKey

JWT_ENCRYPTION_ALGORITHM = "A256GCM"


def generate_jwe(payload: str, encryption_key: str, encoding: str = "UTF-8") -> str:
"""Generates a JWE with the provided payload.
# The cache size of 3 is somewhat arbitrary. Encryption key rotation is not
# currently supported, so in practice this cache should only ever hold a
# single value. Once key rotation is implemented there should be at most
# 2 keys in use at any one time, so 3 entries should still be enough.
# This will need to be revisited if we ever support encrypting tokens with
# several different encryption keys at once.
@lru_cache(maxsize=3)
def _get_oct_key(key_bytes: bytes) -> OctKey:
    """Build the ``OctKey`` for ``key_bytes``, memoized per distinct key value.

    ``lru_cache`` hands back the previously imported key when the same bytes
    are passed again, so ``OctKey.import_key`` is not re-run for a key that is
    still in the cache.
    """
    oct_key = OctKey.import_key(key_bytes)
    return oct_key

Returns a string representation.
"""

def _to_oct_key(encryption_key: str) -> OctKey:
key_bytes = (
encryption_key.encode("utf-8")
if isinstance(encryption_key, str)
else encryption_key
)
key = OctKey.import_key(key_bytes)
return _get_oct_key(key_bytes)


def generate_jwe(payload: str, encryption_key: str, encoding: str = "UTF-8") -> str:
"""Generates a JWE with the provided payload.

Returns a string representation.
"""
return jwe.encrypt_compact(
{"alg": "dir", "enc": JWT_ENCRYPTION_ALGORITHM},
payload.encode(encoding),
key,
_to_oct_key(encryption_key),
)


Expand All @@ -33,13 +50,7 @@ def decrypt_jwe(token: str, encryption_key: str, encoding: str = "UTF-8") -> str
Returns:
The decrypted payload as a string.
"""
key_bytes = (
encryption_key.encode("utf-8")
if isinstance(encryption_key, str)
else encryption_key
)
key = OctKey.import_key(key_bytes)
result = jwe.decrypt_compact(token, key)
result = jwe.decrypt_compact(token, _to_oct_key(encryption_key))
if result.plaintext is None:
raise ValueError("JWE decryption produced no plaintext")
return result.plaintext.decode(encoding)
7 changes: 6 additions & 1 deletion src/fides/api/oauth/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
import json
from datetime import datetime, timezone
from functools import update_wrapper
Expand Down Expand Up @@ -702,7 +703,11 @@ async def extract_token_and_load_client_async(
raise AuthenticationError(detail="Authentication Failure")

try:
token_data = json.loads(extract_payload(authorization, get_encryption_key()))
loop = asyncio.get_running_loop()
payload = await loop.run_in_executor(
None, extract_payload, authorization, get_encryption_key()
)
token_data = await loop.run_in_executor(None, json.loads, payload)
except (JoseError, ValueError) as exc:
logger.debug("Unable to parse auth token.")
raise AuthorizationError(detail="Not Authorized for this action") from exc
Expand Down
13 changes: 13 additions & 0 deletions src/fides/api/tasks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import contextvars
from typing import Any, ContextManager, Dict, List, Optional

import celery_redis_cluster_backend # type: ignore[import-untyped] # noqa: F401 - registers redis+cluster/rediss+cluster backends
Expand Down Expand Up @@ -205,6 +206,11 @@ def _propagate_request_id(headers: Dict[str, Any], **kwargs: Any) -> None:
headers["request_id"] = request_id


# Holds the context manager returned by ``logger.contextualize(...)`` that the
# ``task_prerun`` handler entered, so that the ``task_postrun`` handler can
# exit it again. ``None`` (the default) means no log context is active.
_task_log_context: contextvars.ContextVar = contextvars.ContextVar(
    "_task_log_context", default=None
)


@task_prerun.connect
def _restore_request_id(task: Task, **kwargs: Any) -> None:
"""Restore request_id from the task headers into the worker's ContextVar.
Expand All @@ -215,6 +221,9 @@ def _restore_request_id(task: Task, **kwargs: Any) -> None:
request_id = getattr(task.request, "request_id", None)
if request_id is not None:
set_request_id(request_id)
ctx = logger.contextualize(request_id=request_id)
ctx.__enter__()
_task_log_context.set(ctx)


@task_postrun.connect
Expand All @@ -225,6 +234,10 @@ def _clear_request_id(**kwargs: Any) -> None:
a request_id from Task A would leak into Task B if Task B was dispatched
without a request_id header.
"""
ctx = _task_log_context.get()
if ctx is not None:
ctx.__exit__(None, None, None)
_task_log_context.set(None)
set_request_id(None)


Expand Down
29 changes: 7 additions & 22 deletions src/fides/api/util/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from loguru import logger
from loguru._handler import Message

from fides.api.request_context import get_request_context
from fides.api.schemas.privacy_request import LogEntry, PrivacyRequestSource
from fides.api.util.sqlalchemy_filter import SQLAlchemyGeneratedFilter
from fides.config import CONFIG, FidesConfig
Expand All @@ -26,11 +25,11 @@

MASKED = "MASKED"

# Keys injected by the Loguru patcher into every log record's extra dict.
# These are excluded when deciding whether a record has "custom" extra context
# (e.g. from logger.bind()) so that the log format only appends the extra
# block when the caller explicitly added context beyond the automatic fields.
_PATCHER_INJECTED_KEYS = {"request_id"}
# Keys injected into log records by logger.contextualize() in the ASGI
# middleware. Excluded when deciding whether a record has "custom" extra
# context so that the log format only appends the extra block when the
# caller explicitly added context via logger.bind().
_CONTEXTUALIZED_KEYS = {"request_id"}


def _safe_stdout_sink(message: str) -> None:
Expand Down Expand Up @@ -194,7 +193,7 @@ def create_handler_dicts(
def has_custom_extra(log_record: Dict) -> bool:
"""Check if log record has custom extra context beyond Loguru's defaults."""
extra = log_record.get("extra", {})
return bool(extra.keys() - _PATCHER_INJECTED_KEYS)
return bool(extra.keys() - _CONTEXTUALIZED_KEYS)

# Helper to filter logs without custom extra
def filter_standard(log_record: Dict) -> bool:
Expand All @@ -218,20 +217,6 @@ def filter_standard(log_record: Dict) -> bool:
return [standard_dict, extra_dict]


def _inject_request_context(record: Dict[str, Any]) -> None:
"""Loguru patcher that injects request-scoped context into every log record.

Reads the current ``request_id`` from the ``RequestContext`` ContextVar and
adds it to ``record["extra"]``. Because the patcher runs on *every* log
record (including those from stdlib loggers intercepted by
``InterceptHandler``), all logs emitted during a request are automatically
tagged without any changes to individual call-sites.
"""
ctx = get_request_context()
if ctx.request_id is not None:
record["extra"]["request_id"] = ctx.request_id


def setup(config: FidesConfig) -> None:
"""
Configures logging with the appropriate sink based on configuration.
Expand Down Expand Up @@ -278,7 +263,7 @@ def setup(config: FidesConfig) -> None:
)
)

logger.configure(handlers=handlers, patcher=_inject_request_context) # type: ignore[arg-type]
logger.configure(handlers=handlers) # type: ignore[arg-type]

# Add InterceptHandler to root logger to capture standard library logs
# This intercepts logs from SQLAlchemy, Alembic, Celery, etc.
Expand Down
16 changes: 11 additions & 5 deletions src/fides/api/v1/endpoints/health.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,20 +154,26 @@ async def database_health(db: Session = Depends(get_db)) -> Dict:
pools: Dict[str, PoolStatus] = {}
async_readonly_pool_prewarmed: Optional[bool] = None

migration_health, current_revision = get_db_health(
db_cred_provider.get_database_url(), db=db
loop = asyncio.get_running_loop()

migration_health, current_revision = await loop.run_in_executor(
None, get_db_health, db_cred_provider.get_database_url(), db
)

# Primary sync pool (already checked out by dependency-injected session).
pools["api_sync_primary"] = PoolStatus(health=_check_sync_session(db))
pools["api_sync_primary"] = PoolStatus(
health=await loop.run_in_executor(None, _check_sync_session, db)
)

# Optional sync readonly pool.
if CONFIG.database.sqlalchemy_readonly_database_uri:
readonly_db: Optional[Session] = None
try:
readonly_db = get_readonly_api_session()
readonly_db = await loop.run_in_executor(None, get_readonly_api_session)
pools["api_sync_readonly"] = PoolStatus(
health=_check_sync_session(readonly_db)
health=await loop.run_in_executor(
None, _check_sync_session, readonly_db
)
)
except Exception as error: # pylint: disable=broad-except
logger.error(
Expand Down
68 changes: 57 additions & 11 deletions tests/fides/api/middleware/test_request_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
_clear_request_id,
_propagate_request_id,
_restore_request_id,
_task_log_context,
)

from .conftest import (
Expand All @@ -29,9 +30,14 @@

@pytest.fixture(autouse=True)
def _clean_request_context():
"""Reset request context before and after each test."""
"""Reset request context and Loguru log context before and after each test."""
reset_request_context()
yield
# Clean up any leaked Loguru contextualize from Celery signal tests
ctx = _task_log_context.get()
if ctx is not None:
ctx.__exit__(None, None, None)
_task_log_context.set(None)
reset_request_context()


Expand Down Expand Up @@ -193,7 +199,7 @@ async def test_rejects_oversized_request_id(self, mock_asgi_app):
assert UUID_PATTERN.match(request_id_headers[0])

async def test_request_id_in_log_output(self, mock_asgi_app, loguru_caplog):
"""Request ID appears in log records via the patcher."""
"""Request ID appears in log records via logger.contextualize()."""
app, _ = mock_asgi_app()
middleware = LogRequestMiddleware(app)

Expand Down Expand Up @@ -268,24 +274,32 @@ async def test_concurrent_requests_get_different_ids(self, mock_asgi_app):
assert id1 != id2


class TestLoggerPatcher:
"""Tests for the Loguru patcher that injects request_id."""
class TestLoggerContextualize:
"""Tests for request_id injection via logger.contextualize()."""

def test_patcher_injects_request_id(self, loguru_caplog):
"""When request_id is set, it appears in log records."""
set_request_id("patcher-test-abc")
logger.info("test message")
def test_contextualize_injects_request_id(self, loguru_caplog):
"""When logger.contextualize is active, request_id appears in log records."""
with logger.contextualize(request_id="ctx-test-abc"):
logger.info("test message")

record = loguru_caplog.records[-1]
assert record.extra.get("request_id") == "patcher-test-abc"
assert record.extra.get("request_id") == "ctx-test-abc"

def test_patcher_omits_request_id_when_none(self, loguru_caplog):
"""When no request_id is set, it's not added to log records."""
def test_no_request_id_outside_context(self, loguru_caplog):
    """A record logged outside any contextualize block has no request_id."""
    logger.info("test message without context")

    latest_record = loguru_caplog.records[-1]
    assert "request_id" not in latest_record.extra

def test_set_request_id_alone_does_not_inject_into_logs(self, loguru_caplog):
    """Calling set_request_id without contextualize leaves log records untagged."""
    set_request_id("only-context-var")
    logger.info("test message")

    latest_record = loguru_caplog.records[-1]
    assert "request_id" not in latest_record.extra


class TestCelerySignals:
"""Tests for Celery request_id propagation signals."""
Expand Down Expand Up @@ -325,10 +339,42 @@ def test_restore_request_id_skips_when_absent(self):

assert get_request_id() is None

def test_restore_injects_request_id_into_logs(self, loguru_caplog):
    """After the task_prerun handler runs, emitted logs carry the task's request_id."""
    # Minimal stand-ins: the handler only reads ``task.request.request_id``.
    request_stub = type("MockRequest", (), {"request_id": "celery-log-789"})()
    task_stub = type("MockTask", (), {"request": request_stub})()

    _restore_request_id(task=task_stub)
    logger.info("task log message")

    latest_record = loguru_caplog.records[-1]
    assert latest_record.extra.get("request_id") == "celery-log-789"

    # Cleanup
    _clear_request_id()

def test_clear_request_id_after_task(self):
    """The task_postrun handler resets request_id and leaves user_id in place."""
    initial_context = {"request_id": "should-be-cleared", "user_id": "keep-me"}
    set_request_context(**initial_context)
    _clear_request_id()

    cleared_request_id = get_request_id()
    assert cleared_request_id is None
    kept_user_id = get_user_id()
    assert kept_user_id == "keep-me"

def test_clear_removes_log_context(self, loguru_caplog):
    """Once the task_postrun handler has run, new logs no longer carry request_id."""
    # Minimal stand-ins: the handler only reads ``task.request.request_id``.
    request_stub = type("MockRequest", (), {"request_id": "celery-clear-test"})()
    task_stub = type("MockTask", (), {"request": request_stub})()

    _restore_request_id(task=task_stub)
    _clear_request_id()

    logger.info("after task")
    latest_record = loguru_caplog.records[-1]
    assert "request_id" not in latest_record.extra
Loading