diff --git a/superset/mcp_service/chart/tool/get_chart_preview.py b/superset/mcp_service/chart/tool/get_chart_preview.py index 1fb3740f1164..7ab8b29473a2 100644 --- a/superset/mcp_service/chart/tool/get_chart_preview.py +++ b/superset/mcp_service/chart/tool/get_chart_preview.py @@ -28,7 +28,7 @@ from superset.commands.exceptions import CommandException from superset.exceptions import OAuth2Error, OAuth2RedirectError, SupersetException -from superset.extensions import event_logger +from superset.extensions import db, event_logger from superset.mcp_service.chart.ascii_charts import ( generate_ascii_chart, generate_ascii_table, @@ -1140,6 +1140,15 @@ def __init__(self, fd: Dict[str, Any]): ) chart = find_chart_by_identifier(request.identifier) + # Eagerly refresh all attributes while the session is still + # active. SQLAlchemy expires object attributes after any + # commit; if a downstream operation commits before the strategy + # classes access chart attributes, a DetachedInstanceError will + # be raised. Calling refresh() here ensures all column values + # are loaded into the object's __dict__ upfront. + if chart is not None: + db.session.refresh(chart) + # If not found and looks like a form_data_key, try transient if ( not chart @@ -1371,6 +1380,20 @@ def __init__(self, form_data: Dict[str, Any]): return _sanitize_chart_preview_for_llm_context(result) + except SQLAlchemyError as e: + # Catch DetachedInstanceError and other SQLAlchemy errors that can + # surface when the ORM session expires or commits mid-request. + await ctx.error( + "Chart preview failed due to database session error: " + "identifier=%s, error_type=%s, error=%s" + % (request.identifier, type(e).__name__, str(e)) + ) + logger.exception("SQLAlchemy error in get_chart_preview: %s", e) + return ChartError( + error="Database session error while generating chart preview. " + "Please retry the request.", + error_type="InternalError", + ) except ( CommandException, SupersetException, diff --git a/tests/unit_tests/mcp_service/chart/tool/test_get_chart_preview.py b/tests/unit_tests/mcp_service/chart/tool/test_get_chart_preview.py index e451dd7c5ee7..e5fcf909f7fd 100644 --- a/tests/unit_tests/mcp_service/chart/tool/test_get_chart_preview.py +++ b/tests/unit_tests/mcp_service/chart/tool/test_get_chart_preview.py @@ -595,3 +595,191 @@ async def test_ascii_art_variations(self): """ # These demonstrate the expected ASCII formats for different chart types + + +class TestDetachedInstanceError: + """Tests that DetachedInstanceError is handled gracefully. + + When the SQLAlchemy session commits mid-request, ORM objects expire and + become detached. Accessing lazy attributes on a detached Slice raises + DetachedInstanceError. The tool must: + 1. Call db.session.refresh() immediately after loading the chart so all + column values are loaded upfront before any downstream operation. + 2. Catch SQLAlchemyError (the base class) and return a ChartError + instead of propagating the exception. + """ + + @pytest.mark.asyncio + async def test_session_refresh_called_after_chart_load(self): + """db.session.refresh() is invoked right after find_chart_by_identifier.""" + import importlib + from contextlib import nullcontext + from unittest.mock import MagicMock, patch + + from superset.mcp_service.chart.schemas import URLPreview + from superset.utils import json + + get_chart_preview_module = importlib.import_module( + "superset.mcp_service.chart.tool.get_chart_preview" + ) + + mock_chart = MagicMock() + mock_chart.id = 42 + mock_chart.slice_name = "Sales Chart" + mock_chart.viz_type = "table" + mock_chart.datasource_id = 1 + mock_chart.datasource_type = "table" + mock_chart.params = "{}" + + refresh_calls: list[object] = [] + + def _fake_refresh(obj: object) -> None: + refresh_calls.append(obj) + + url_preview = URLPreview( + preview_url="http://localhost/explore/?slice_id=42", + width=800, + height=600, + ) + + with ( + patch.object( + get_chart_preview_module, + "find_chart_by_identifier", + return_value=mock_chart, + ), + patch.object( + get_chart_preview_module.db, + "session", + **{"refresh.side_effect": _fake_refresh}, + ), + patch.object( + get_chart_preview_module, + "validate_chart_dataset", + return_value=MagicMock(is_valid=True, warnings=[]), + ), + patch.object( + get_chart_preview_module.event_logger, + "log_context", + return_value=nullcontext(), + ), + # Return a real URLPreview so Pydantic model validation succeeds + patch.object( + get_chart_preview_module.PreviewFormatGenerator, + "generate", + return_value=url_preview, + ), + patch( + "superset.mcp_service.utils.url_utils.get_superset_base_url", + return_value="http://localhost", + ), + ): + from fastmcp import Client + + from superset.mcp_service.app import mcp + from superset.mcp_service.chart.schemas import GetChartPreviewRequest + + with patch("superset.mcp_service.auth.get_user_from_request") as mu: + mu.return_value = MagicMock(id=1, username="admin") + with patch( + "superset.mcp_service.auth.check_tool_permission", return_value=True + ): + async with Client(mcp) as client: + response = await client.call_tool( + "get_chart_preview", + { + "request": GetChartPreviewRequest( + identifier=42, format="url" + ).model_dump() + }, + ) + + data = json.loads(response.content[0].text) + # The tool should succeed — not return a ChartError + assert "error_type" not in data, ( + f"Expected ChartPreview but got ChartError: {data.get('error')}" + ) + assert data.get("chart_id") == 42 + + assert len(refresh_calls) == 1, ( + "db.session.refresh() should be called once after loading the chart" + ) + assert refresh_calls[0] is mock_chart + + @pytest.mark.asyncio + async def test_detached_instance_error_returns_chart_error(self): + """DetachedInstanceError during preview generation returns ChartError.""" + import importlib + from contextlib import nullcontext + from unittest.mock import MagicMock, patch + + from sqlalchemy.orm.exc import DetachedInstanceError + + get_chart_preview_module = importlib.import_module( + "superset.mcp_service.chart.tool.get_chart_preview" + ) + + mock_chart = MagicMock() + mock_chart.id = 7 + mock_chart.slice_name = "Broken Chart" + mock_chart.viz_type = "bar" + mock_chart.datasource_id = 3 + mock_chart.datasource_type = "table" + mock_chart.params = "{}" + + with ( + patch.object( + get_chart_preview_module, + "find_chart_by_identifier", + return_value=mock_chart, + ), + patch.object( + get_chart_preview_module.db, + "session", + **{"refresh.return_value": None}, + ), + patch.object( + get_chart_preview_module, + "validate_chart_dataset", + return_value=MagicMock(is_valid=True, warnings=[]), + ), + patch.object( + get_chart_preview_module.event_logger, + "log_context", + return_value=nullcontext(), + ), + # Simulate the session expiring inside the strategy + patch.object( + get_chart_preview_module.PreviewFormatGenerator, + "generate", + side_effect=DetachedInstanceError(), + ), + patch( + "superset.mcp_service.utils.url_utils.get_superset_base_url", + return_value="http://localhost", + ), + ): + from fastmcp import Client + + from superset.mcp_service.app import mcp + from superset.mcp_service.chart.schemas import GetChartPreviewRequest + from superset.utils import json + + with patch("superset.mcp_service.auth.get_user_from_request") as mu: + mu.return_value = MagicMock(id=1, username="admin") + with patch( + "superset.mcp_service.auth.check_tool_permission", return_value=True + ): + async with Client(mcp) as client: + response = await client.call_tool( + "get_chart_preview", + { + "request": GetChartPreviewRequest( + identifier=7, format="ascii" + ).model_dump() + }, + ) + + data = json.loads(response.content[0].text) + assert data["error_type"] == "InternalError" + assert "session" in data["error"].lower() or "retry" in data["error"].lower()