Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion superset/mcp_service/chart/tool/get_chart_preview.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

from superset.commands.exceptions import CommandException
from superset.exceptions import OAuth2Error, OAuth2RedirectError, SupersetException
from superset.extensions import event_logger
from superset.extensions import db, event_logger
from superset.mcp_service.chart.ascii_charts import (
generate_ascii_chart,
generate_ascii_table,
Expand Down Expand Up @@ -1140,6 +1140,15 @@ def __init__(self, fd: Dict[str, Any]):
)
chart = find_chart_by_identifier(request.identifier)

# Eagerly refresh all attributes while the session is still
# active. SQLAlchemy expires object attributes after any
# commit; if a downstream operation commits before the strategy
# classes access chart attributes, a DetachedInstanceError will
# be raised. Calling refresh() here ensures all column values
# are loaded into the object's __dict__ upfront.
if chart is not None:
db.session.refresh(chart)

# If not found and looks like a form_data_key, try transient
if (
not chart
Expand Down Expand Up @@ -1371,6 +1380,20 @@ def __init__(self, form_data: Dict[str, Any]):

return _sanitize_chart_preview_for_llm_context(result)

except SQLAlchemyError as e:
# Catch DetachedInstanceError and other SQLAlchemy errors that can
# surface when the ORM session expires or commits mid-request.
await ctx.error(
"Chart preview failed due to database session error: "
"identifier=%s, error_type=%s, error=%s"
% (request.identifier, type(e).__name__, str(e))
)
logger.exception("SQLAlchemy error in get_chart_preview: %s", e)
return ChartError(
error="Database session error while generating chart preview. "
"Please retry the request.",
error_type="InternalError",
)
except (
CommandException,
SupersetException,
Expand Down
188 changes: 188 additions & 0 deletions tests/unit_tests/mcp_service/chart/tool/test_get_chart_preview.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,3 +595,191 @@ async def test_ascii_art_variations(self):
"""

# These demonstrate the expected ASCII formats for different chart types


class TestDetachedInstanceError:
"""Tests that DetachedInstanceError is handled gracefully.

When the SQLAlchemy session commits mid-request, ORM objects expire and
become detached. Accessing lazy attributes on a detached Slice raises
DetachedInstanceError. The tool must:
1. Call db.session.refresh() immediately after loading the chart so all
column values are loaded upfront before any downstream operation.
2. Catch SQLAlchemyError (the base class) and return a ChartError
instead of propagating the exception.
"""

@pytest.mark.asyncio
async def test_session_refresh_called_after_chart_load(self):
"""db.session.refresh() is invoked right after find_chart_by_identifier."""
import importlib
from contextlib import nullcontext
from unittest.mock import MagicMock, patch

from superset.mcp_service.chart.schemas import URLPreview
from superset.utils import json

get_chart_preview_module = importlib.import_module(
"superset.mcp_service.chart.tool.get_chart_preview"
)

mock_chart = MagicMock()
mock_chart.id = 42
mock_chart.slice_name = "Sales Chart"
mock_chart.viz_type = "table"
mock_chart.datasource_id = 1
mock_chart.datasource_type = "table"
mock_chart.params = "{}"

refresh_calls: list[object] = []

def _fake_refresh(obj: object) -> None:
refresh_calls.append(obj)

url_preview = URLPreview(
preview_url="http://localhost/explore/?slice_id=42",
width=800,
height=600,
)

with (
patch.object(
get_chart_preview_module,
"find_chart_by_identifier",
return_value=mock_chart,
),
patch.object(
get_chart_preview_module.db,
"session",
**{"refresh.side_effect": _fake_refresh},
),
patch.object(
get_chart_preview_module,
"validate_chart_dataset",
return_value=MagicMock(is_valid=True, warnings=[]),
),
patch.object(
get_chart_preview_module.event_logger,
"log_context",
return_value=nullcontext(),
),
# Return a real URLPreview so Pydantic model validation succeeds
patch.object(
get_chart_preview_module.PreviewFormatGenerator,
"generate",
return_value=url_preview,
),
patch(
"superset.mcp_service.utils.url_utils.get_superset_base_url",
return_value="http://localhost",
),
):
from fastmcp import Client

from superset.mcp_service.app import mcp
from superset.mcp_service.chart.schemas import GetChartPreviewRequest

with patch("superset.mcp_service.auth.get_user_from_request") as mu:
mu.return_value = MagicMock(id=1, username="admin")
with patch(
"superset.mcp_service.auth.check_tool_permission", return_value=True
):
async with Client(mcp) as client:
response = await client.call_tool(
"get_chart_preview",
{
"request": GetChartPreviewRequest(
identifier=42, format="url"
).model_dump()
},
)

data = json.loads(response.content[0].text)
# The tool should succeed — not return a ChartError
assert "error_type" not in data, (
f"Expected ChartPreview but got ChartError: {data.get('error')}"
)
assert data.get("chart_id") == 42

assert len(refresh_calls) == 1, (
"db.session.refresh() should be called once after loading the chart"
)
assert refresh_calls[0] is mock_chart

@pytest.mark.asyncio
async def test_detached_instance_error_returns_chart_error(self):
"""DetachedInstanceError during preview generation returns ChartError."""
import importlib
from contextlib import nullcontext
from unittest.mock import MagicMock, patch

from sqlalchemy.orm.exc import DetachedInstanceError

get_chart_preview_module = importlib.import_module(
"superset.mcp_service.chart.tool.get_chart_preview"
)

mock_chart = MagicMock()
mock_chart.id = 7
mock_chart.slice_name = "Broken Chart"
mock_chart.viz_type = "bar"
mock_chart.datasource_id = 3
mock_chart.datasource_type = "table"
mock_chart.params = "{}"

with (
patch.object(
get_chart_preview_module,
"find_chart_by_identifier",
return_value=mock_chart,
),
patch.object(
get_chart_preview_module.db,
"session",
**{"refresh.return_value": None},
),
patch.object(
get_chart_preview_module,
"validate_chart_dataset",
return_value=MagicMock(is_valid=True, warnings=[]),
),
patch.object(
get_chart_preview_module.event_logger,
"log_context",
return_value=nullcontext(),
),
# Simulate the session expiring inside the strategy
patch.object(
get_chart_preview_module.PreviewFormatGenerator,
"generate",
side_effect=DetachedInstanceError(),
),
patch(
"superset.mcp_service.utils.url_utils.get_superset_base_url",
return_value="http://localhost",
),
):
from fastmcp import Client

from superset.mcp_service.app import mcp
from superset.mcp_service.chart.schemas import GetChartPreviewRequest
from superset.utils import json

with patch("superset.mcp_service.auth.get_user_from_request") as mu:
mu.return_value = MagicMock(id=1, username="admin")
with patch(
"superset.mcp_service.auth.check_tool_permission", return_value=True
):
async with Client(mcp) as client:
response = await client.call_tool(
"get_chart_preview",
{
"request": GetChartPreviewRequest(
identifier=7, format="ascii"
).model_dump()
},
)

data = json.loads(response.content[0].text)
assert data["error_type"] == "InternalError"
assert "session" in data["error"].lower() or "retry" in data["error"].lower()
Loading