Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 40 additions & 3 deletions superset/mcp_service/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,45 @@ def _cleanup_session_on_error() -> None:
logger.warning("Error cleaning up session after exception: %s", e)


def _remove_session_safe() -> None:
    """Drop the scoped SQLAlchemy session without failing on a dead connection.

    Worker threads in the pool are reused from one request to the next, so
    the session is removed ahead of every tool call; otherwise the
    thread-local session left behind by an earlier request could carry over
    into the next one. Removing the session closes it, and that close
    performs an implicit rollback. If the DBAPI connection was dropped while
    idle (for example an RDS SSL idle-timeout or max-connection-age), the
    rollback raises ``OperationalError``.

    Recovery when ``OperationalError`` is raised:

    1. Log a warning describing the failure.
    2. Invalidate the session so the pool discards the dead connection
       instead of handing it to the next caller. A failure at this step is
       logged at debug level and otherwise ignored.
    3. Call ``remove()`` a second time to deregister the session from the
       scoped registry. An exception from this second attempt is not caught
       here and propagates to the caller.

    Any exception other than ``OperationalError`` raised by the first
    ``remove()`` propagates unchanged. Once cleanup has succeeded, the next
    DB access obtains a fresh connection, so the tool call can continue.
    """
    from sqlalchemy.exc import OperationalError

    from superset.extensions import db

    scoped_session = db.session
    try:
        scoped_session.remove()
    except OperationalError as conn_err:
        logger.warning(
            "Connection error during pre-call session cleanup (likely SSL/idle "
            "timeout); invalidating connection and retrying: %s",
            conn_err,
        )
        # Best effort: invalidation may fail too, but the registry still has
        # to be cleared, so only record the failure and carry on to the retry.
        try:
            scoped_session.invalidate()
        except Exception as invalidate_err:
            logger.debug(
                "Could not invalidate session after connection error: %s",
                invalidate_err,
            )
        # Second attempt; deliberately unguarded so a persistent failure
        # surfaces to the caller.
        scoped_session.remove()


def mcp_auth_hook(tool_func: F) -> F: # noqa: C901
"""
Authentication and authorization decorator for MCP tools.
Expand Down Expand Up @@ -638,9 +677,7 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
# still be bound to a different tenant's DB engine. Removing it here
# ensures the next DB access creates a fresh session bound to the
# correct engine for the current request.
from superset.extensions import db

db.session.remove()
_remove_session_safe()
user = _setup_user_context()

# No Flask context - this is a FastMCP internal operation
Expand Down
53 changes: 50 additions & 3 deletions tests/unit_tests/mcp_service/test_auth_user_resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,9 +352,9 @@ def dummy_tool():

def _assert_preserved_then_return():
"""Verify g.user was preserved (not cleared) before returning."""
assert hasattr(g, "user"), (
"g.user should be preserved in request context but was removed"
)
assert hasattr(
g, "user"
), "g.user should be preserved in request context but was removed"
assert g.user is middleware_user, (
"g.user should be preserved in request context but was changed; "
f"g.user={g.user}"
Expand Down Expand Up @@ -409,6 +409,53 @@ def _assert_remove_already_called() -> MagicMock:
assert result == "fresh"


def test_sync_wrapper_handles_ssl_error_on_pre_call_remove(app) -> None:
    """A dead connection during pre-call session removal must not fail the tool.

    The first ``db.session.remove()`` is made to raise ``OperationalError``,
    imitating a DBAPI connection that dropped between requests (e.g. an RDS
    SSL idle-timeout). The wrapped tool is still expected to:
    - return its normal result
    - have had ``session.invalidate()`` called for the dead connection
    - have had ``session.remove()`` called a second time
    """
    from sqlalchemy.exc import OperationalError as SAOperationalError

    expected_user = _make_mock_user("fresh")

    def dummy_tool() -> str:
        """Dummy sync tool."""
        return g.user.username

    wrapped_tool = mcp_auth_hook(dummy_tool)

    # One entry is appended per remove() invocation; only the first one fails.
    remove_attempts: list[int] = []

    def _fail_first_remove() -> None:
        remove_attempts.append(1)
        if len(remove_attempts) == 1:
            raise SAOperationalError(
                "SSL connection has been closed unexpectedly", None, None
            )

    with app.test_request_context():
        g.user = expected_user
        with patch("superset.extensions.db") as mock_db, patch(
            "superset.mcp_service.auth.get_user_from_request",
            return_value=expected_user,
        ):
            mock_db.session.remove.side_effect = _fail_first_remove
            outcome = wrapped_tool()

    assert outcome == "fresh"
    assert (
        mock_db.session.invalidate.called
    ), "invalidate() must be called on SSL error"
    assert len(remove_attempts) == 2, "remove() must be retried after SSL error"


# -- default_user_resolver --


Expand Down
Loading