diff --git a/superset/mcp_service/app.py b/superset/mcp_service/app.py
index 205720c3c8d7..9163ab7c8630 100644
--- a/superset/mcp_service/app.py
+++ b/superset/mcp_service/app.py
@@ -222,10 +222,12 @@ def get_default_instructions(branding: str = "Apache Superset") -> str:
- PT1H (hourly), P1D (daily), P1W (weekly), P1M (monthly), P1Y (yearly)
Chart Types in Existing Charts (viewable via list_charts/get_chart_info):
-- pie, big_number, big_number_total, funnel, gauge_chart
-- echarts_timeseries_line, echarts_timeseries_bar, echarts_timeseries_area
-- pivot_table_v2, heatmap_v2, sankey_v2, sunburst_v2, treemap_v2
-- word_cloud, world_map, box_plot, bubble, mixed_timeseries
+Each chart returned by list_charts / get_chart_info includes a
+chart_type_display_name field with a human-readable name when available.
+This field is populated only for the 7 chart types supported by generate_chart
+(xy, pie, table, pivot_table, big_number, mixed_timeseries, handlebars).
+For all other viz_types (Funnel, Gauge, Heatmap, etc.) it will be null —
+use the raw viz_type field instead when referring to those chart types.
Query Examples:
- List all tables:
@@ -503,6 +505,7 @@ def create_mcp_app(
# NOTE: Always add new prompt/resource imports here when creating new prompts/resources.
# Prompts use @mcp.prompt decorators and resources use @mcp.resource decorators.
# They register automatically on import, similar to tools.
+import superset.mcp_service.chart.plugins # noqa: F401, E402 — registers all chart type plugins
from superset.mcp_service.chart import ( # noqa: F401, E402
prompts as chart_prompts,
resources as chart_resources,
diff --git a/superset/mcp_service/chart/chart_utils.py b/superset/mcp_service/chart/chart_utils.py
index 5b865c1f0dd6..f2ac21493791 100644
--- a/superset/mcp_service/chart/chart_utils.py
+++ b/superset/mcp_service/chart/chart_utils.py
@@ -318,29 +318,35 @@ def map_config_to_form_data(
| BigNumberChartConfig,
dataset_id: int | str | None = None,
) -> Dict[str, Any]:
- """Map chart config to Superset form_data."""
- if isinstance(config, TableChartConfig):
- return map_table_config(config)
- elif isinstance(config, XYChartConfig):
- return map_xy_config(config, dataset_id=dataset_id)
- elif isinstance(config, PieChartConfig):
- return map_pie_config(config)
- elif isinstance(config, PivotTableChartConfig):
- return map_pivot_table_config(config)
- elif isinstance(config, MixedTimeseriesChartConfig):
- return map_mixed_timeseries_config(config, dataset_id=dataset_id)
- elif isinstance(config, HandlebarsChartConfig):
- return map_handlebars_config(config)
- elif isinstance(config, BigNumberChartConfig):
- if config.show_trendline and config.temporal_column:
- if not is_column_truly_temporal(config.temporal_column, dataset_id):
- raise ValueError(
- f"Big Number trendline requires a temporal SQL column; "
- f"'{config.temporal_column}' is not temporal."
- )
- return map_big_number_config(config)
- else:
- raise ValueError(f"Unsupported config type: {type(config)}")
+ """Map chart config to Superset form_data via the plugin registry.
+
+ The previous if/elif chain across all 7 chart types has been replaced by a
+ single registry lookup. Cross-field constraints (e.g. BigNumber trendline
+ temporal check) are now owned by each plugin's post_map_validate() method
+ rather than being baked into this dispatcher.
+ """
+ # Local import: plugins call map_*_config from their to_form_data() methods,
+ # so chart_utils is loaded before plugins finish registering. A top-level
+ # import of registry here would trigger plugin loading mid-import = cycle.
+ from superset.mcp_service.chart.registry import get_registry
+
+ chart_type = getattr(config, "chart_type", None)
+ plugin = get_registry().get(chart_type) if chart_type else None
+
+ if plugin is None:
+ raise ValueError(
+ f"Unsupported config type: {type(config)} (chart_type={chart_type!r})"
+ )
+
+ form_data = plugin.to_form_data(config, dataset_id=dataset_id)
+
+ # Run post-map validation (e.g. BigNumber trendline temporal type check).
+ # Raise ValueError to preserve backward-compatible error handling in callers.
+ error = plugin.post_map_validate(config, form_data, dataset_id=dataset_id)
+ if error is not None:
+ raise ValueError(error.message)
+
+ return form_data
def _add_adhoc_filters(
@@ -1129,87 +1135,32 @@ def _big_number_chart_what(config: BigNumberChartConfig) -> str:
def generate_chart_name(
- config: TableChartConfig
- | XYChartConfig
- | PieChartConfig
- | PivotTableChartConfig
- | MixedTimeseriesChartConfig
- | HandlebarsChartConfig
- | BigNumberChartConfig,
+ config: Any,
dataset_name: str | None = None,
) -> str:
"""Generate a descriptive chart name following a standard format.
- Format conventions (by chart type):
- Aggregated (bar/scatter with group_by): [Metric] by [Dimension]
- Time-series (line/area, no group_by): [Metric] Over Time
- Table (no aggregates): [Dataset] Records
- Table (with aggregates): [Metric] Summary
- Pie: [Dimension] by [Metric]
- Pivot Table: Pivot Table – [Row1, Row2]
- Mixed Timeseries: [Primary] + [Secondary]
- An en-dash followed by context (filters / time grain) is appended
+ Delegates to each plugin's ``generate_name()`` method.
+ See each plugin's ``generate_name`` for chart-type-specific format conventions.
+ An en-dash followed by context (filters / time grain) is appended by the plugin
when such information is available.
"""
- if isinstance(config, TableChartConfig):
- what = _table_chart_what(config, dataset_name)
- context = _summarize_filters(config.filters)
- elif isinstance(config, XYChartConfig):
- what = _xy_chart_what(config)
- context = _xy_chart_context(config)
- elif isinstance(config, PieChartConfig):
- what = _pie_chart_what(config)
- context = _summarize_filters(config.filters)
- elif isinstance(config, PivotTableChartConfig):
- what = _pivot_table_what(config)
- context = _summarize_filters(config.filters)
- elif isinstance(config, MixedTimeseriesChartConfig):
- what = _mixed_timeseries_what(config)
- context = _summarize_filters(config.filters)
- elif isinstance(config, HandlebarsChartConfig):
- what = _handlebars_chart_what(config)
- context = _summarize_filters(getattr(config, "filters", None))
- elif isinstance(config, BigNumberChartConfig):
- what = _big_number_chart_what(config)
- context = _summarize_filters(getattr(config, "filters", None))
- else:
- return "Chart"
+ from superset.mcp_service.chart.registry import get_registry
- name = what
- if context:
- name = f"{what} \u2013 {context}"
- return _truncate(name)
+ plugin = get_registry().get(getattr(config, "chart_type", ""))
+ if plugin is None:
+ return "Chart"
+ return _truncate(plugin.generate_name(config, dataset_name))
def _resolve_viz_type(config: Any) -> str:
"""Resolve the Superset viz_type from a chart config object."""
- chart_type = getattr(config, "chart_type", "unknown")
- if chart_type == "xy":
- kind = getattr(config, "kind", "line")
- viz_type_map = {
- "line": "echarts_timeseries_line",
- "bar": "echarts_timeseries_bar",
- "area": "echarts_area",
- "scatter": "echarts_timeseries_scatter",
- }
- return viz_type_map.get(kind, "echarts_timeseries_line")
- elif chart_type == "table":
- return getattr(config, "viz_type", "table")
- elif chart_type == "pie":
- return "pie"
- elif chart_type == "pivot_table":
- return "pivot_table_v2"
- elif chart_type == "mixed_timeseries":
- return "mixed_timeseries"
- elif chart_type == "handlebars":
- return "handlebars"
- elif chart_type == "big_number":
- show_trendline = getattr(config, "show_trendline", False)
- temporal_column = getattr(config, "temporal_column", None)
- return (
- "big_number" if show_trendline and temporal_column else "big_number_total"
- )
- return "unknown"
+ from superset.mcp_service.chart.registry import get_registry
+
+ plugin = get_registry().get(getattr(config, "chart_type", ""))
+ if plugin is None:
+ return "unknown"
+ return plugin.resolve_viz_type(config)
def analyze_chart_capabilities(chart: Any | None, config: Any) -> ChartCapabilities:
diff --git a/superset/mcp_service/chart/plugin.py b/superset/mcp_service/chart/plugin.py
new file mode 100755
index 000000000000..6d88c7194d96
--- /dev/null
+++ b/superset/mcp_service/chart/plugin.py
@@ -0,0 +1,255 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+ChartTypePlugin protocol and BaseChartPlugin base class.
+
+Each chart type owns its pre-validation, column extraction, form_data mapping,
+and post-map validation in a single plugin class. This eliminates the previous
+pattern of 4 separate dispatch points (schema_validator.py, dataset_validator.py,
+chart_utils.py, pipeline.py) that had to be updated in sync whenever a new chart
+type was added.
+"""
+
+from __future__ import annotations
+
+from typing import Any, Protocol, runtime_checkable
+
+from superset.mcp_service.chart.schemas import ColumnRef
+from superset.mcp_service.common.error_schemas import ChartGenerationError
+
+
+@runtime_checkable
+class ChartTypePlugin(Protocol):
+ """
+ Protocol that every chart-type plugin must satisfy.
+
+ Implementing all eight methods in a single class guarantees that adding a
+ new chart type requires only one new file — the plugin — rather than edits
+ across multiple separate files.
+ """
+
+ #: Discriminator value matching ChartConfig's chart_type field.
+ chart_type: str
+
+ #: Human-readable name shown to users (e.g. "Line / Bar / Area / Scatter").
+ display_name: str
+
+ #: Maps every Superset-internal viz_type this plugin can produce to a
+ #: user-facing display name, e.g. {"echarts_timeseries_line": "Line Chart"}.
+ #: Used by the registry to resolve display names for existing charts without
+ #: needing a separate JSON mapping file.
+ native_viz_types: dict[str, str]
+
+ def pre_validate(
+ self,
+ config: dict[str, Any],
+ ) -> ChartGenerationError | None:
+ """
+ Early validation of the raw config dict before Pydantic parsing.
+
+ Called by SchemaValidator before attempting to parse the request.
+ Should check that required top-level keys are present and well-typed.
+
+ Returns None if valid, ChartGenerationError if invalid.
+ """
+ ...
+
+ def extract_column_refs(
+ self,
+ config: Any,
+ ) -> list[ColumnRef]:
+ """
+ Extract all column references from a parsed chart config.
+
+ Called by DatasetValidator to validate that all referenced columns exist
+ in the dataset. Must cover every field that holds a column name,
+ including filters.
+
+ Returns a list of ColumnRef objects (may be empty).
+ """
+ ...
+
+ def to_form_data(
+ self,
+ config: Any,
+ dataset_id: int | str | None = None,
+ ) -> dict[str, Any]:
+ """
+ Map a parsed chart config to Superset's internal form_data dict.
+
+ Replaces the if/elif chain in chart_utils.map_config_to_form_data().
+
+ Returns a Superset form_data dict ready for caching and rendering.
+ """
+ ...
+
+ def post_map_validate(
+ self,
+ config: Any,
+ form_data: dict[str, Any],
+ dataset_id: int | str | None = None,
+ ) -> ChartGenerationError | None:
+ """
+ Validate the mapped form_data after to_form_data() runs.
+
+ Use this for cross-field constraints that can only be checked once
+ form_data is assembled (e.g. BigNumber trendline requires a temporal
+ column whose type must be verified against the dataset).
+
+ Returns None if valid, ChartGenerationError if invalid.
+ """
+ ...
+
+ def normalize_column_refs(
+ self,
+ config: Any,
+ dataset_context: Any,
+ ) -> Any:
+ """
+ Return a new config with column names normalized to canonical dataset casing.
+
+ Called by DatasetValidator.normalize_column_names(). The default
+ implementation (in BaseChartPlugin) returns the config unchanged; plugins
+ with column fields override this to fix case sensitivity mismatches.
+
+ Returns a new config object (or the original if no normalization needed).
+ """
+ ...
+
+ def get_runtime_warnings(
+ self,
+ config: Any,
+ dataset_id: int | str,
+ ) -> list[str]:
+ """
+ Return chart-type-specific runtime warnings (performance, compatibility).
+
+ Called by RuntimeValidator to collect per-type warnings. Warnings are
+ informational only — they never block chart generation. The default
+ implementation returns an empty list; plugins override this to emit
+ chart-type-specific warnings (e.g. XY cardinality checks).
+
+ Returns a list of warning message strings (may be empty).
+ """
+ ...
+
+ def generate_name(
+ self,
+ config: Any,
+ dataset_name: str | None = None,
+ ) -> str:
+ """
+ Return a descriptive chart name for the given config.
+
+ Called by chart_utils.generate_chart_name(). The name should follow
+ the standard format conventions documented in that function. Plugins
+ that do not override this return the generic fallback "Chart".
+ """
+ ...
+
+ def resolve_viz_type(self, config: Any) -> str:
+ """
+ Return the Superset-internal viz_type string for this config.
+
+ Called by chart_utils._resolve_viz_type(). The returned string must
+ match a registered Superset viz plugin (e.g. "echarts_timeseries_line").
+ Plugins that do not override this return "unknown".
+ """
+ ...
+
+ def schema_error_hint(self) -> "ChartGenerationError | None":
+ """
+ Return a user-friendly error for Pydantic discriminated-union parse failures.
+
+ Called by SchemaValidator when Pydantic cannot parse the config union and
+ the chart_type is known. Returning None falls back to the generic error.
+ """
+ ...
+
+
+class BaseChartPlugin:
+ """
+ Base class providing sensible defaults for all ChartTypePlugin methods.
+
+ Concrete plugins extend this and override only what they need.
+ """
+
+ chart_type: str = ""
+ display_name: str = ""
+ native_viz_types: dict[str, str] = {}
+
+ def pre_validate(
+ self,
+ config: dict[str, Any],
+ ) -> ChartGenerationError | None:
+ return None
+
+ def extract_column_refs(
+ self,
+ config: Any,
+ ) -> list[ColumnRef]:
+ return []
+
+ def to_form_data(
+ self,
+ config: Any,
+ dataset_id: int | str | None = None,
+ ) -> dict[str, Any]:
+ raise NotImplementedError(
+ f"{self.__class__.__name__}.to_form_data() is not implemented"
+ )
+
+ def post_map_validate(
+ self,
+ config: Any,
+ form_data: dict[str, Any],
+ dataset_id: int | str | None = None,
+ ) -> ChartGenerationError | None:
+ return None
+
+ def normalize_column_refs(
+ self,
+ config: Any,
+ dataset_context: Any,
+ ) -> Any:
+ return config
+
+ def get_runtime_warnings(
+ self,
+ config: Any,
+ dataset_id: int | str,
+ ) -> list[str]:
+ return []
+
+ def generate_name(
+ self,
+ config: Any,
+ dataset_name: str | None = None,
+ ) -> str:
+ return "Chart"
+
+ def resolve_viz_type(self, config: Any) -> str:
+ return "unknown"
+
+ def schema_error_hint(self) -> ChartGenerationError | None:
+ return None
+
+ @staticmethod
+ def _with_context(what: str, context: str | None) -> str:
+ """Combine a 'what' label and optional context with an en-dash."""
+ return f"{what} – {context}" if context else what
diff --git a/superset/mcp_service/chart/plugins/__init__.py b/superset/mcp_service/chart/plugins/__init__.py
new file mode 100644
index 000000000000..5527c43a55f0
--- /dev/null
+++ b/superset/mcp_service/chart/plugins/__init__.py
@@ -0,0 +1,58 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Chart type plugins package.
+
+Importing this module registers all built-in chart type plugins in the global
+registry. This module is imported by app.py at startup.
+
+To add a new chart type:
+1. Create ``superset/mcp_service/chart/plugins/{chart_type}.py``
+2. Implement a class extending ``BaseChartPlugin``
+3. Import and register it here
+"""
+
+from superset.mcp_service.chart.plugins.big_number import BigNumberChartPlugin
+from superset.mcp_service.chart.plugins.handlebars import HandlebarsChartPlugin
+from superset.mcp_service.chart.plugins.mixed_timeseries import (
+ MixedTimeseriesChartPlugin,
+)
+from superset.mcp_service.chart.plugins.pie import PieChartPlugin
+from superset.mcp_service.chart.plugins.pivot_table import PivotTableChartPlugin
+from superset.mcp_service.chart.plugins.table import TableChartPlugin
+from superset.mcp_service.chart.plugins.xy import XYChartPlugin
+from superset.mcp_service.chart.registry import register
+
+# Register all built-in chart type plugins
+register(XYChartPlugin())
+register(TableChartPlugin())
+register(PieChartPlugin())
+register(PivotTableChartPlugin())
+register(MixedTimeseriesChartPlugin())
+register(HandlebarsChartPlugin())
+register(BigNumberChartPlugin())
+
+__all__ = [
+ "BigNumberChartPlugin",
+ "HandlebarsChartPlugin",
+ "MixedTimeseriesChartPlugin",
+ "PieChartPlugin",
+ "PivotTableChartPlugin",
+ "TableChartPlugin",
+ "XYChartPlugin",
+]
diff --git a/superset/mcp_service/chart/plugins/big_number.py b/superset/mcp_service/chart/plugins/big_number.py
new file mode 100755
index 000000000000..e542f8e75f01
--- /dev/null
+++ b/superset/mcp_service/chart/plugins/big_number.py
@@ -0,0 +1,220 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Big number chart type plugin."""
+
+from __future__ import annotations
+
+from typing import Any
+
+from superset.mcp_service.chart.chart_utils import (
+ _big_number_chart_what,
+ _summarize_filters,
+ is_column_truly_temporal,
+ map_big_number_config,
+)
+from superset.mcp_service.chart.plugin import BaseChartPlugin
+from superset.mcp_service.chart.schemas import BigNumberChartConfig, ColumnRef
+from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator
+from superset.mcp_service.common.error_schemas import ChartGenerationError
+
+
+class BigNumberChartPlugin(BaseChartPlugin):
+ """Plugin for big_number chart type."""
+
+ chart_type = "big_number"
+ display_name = "Big Number"
+ native_viz_types = {
+ "big_number": "Big Number with Trendline",
+ "big_number_total": "Big Number",
+ }
+
+ def pre_validate(
+ self,
+ config: dict[str, Any],
+ ) -> ChartGenerationError | None:
+ if "metric" not in config:
+ return ChartGenerationError(
+ error_type="missing_metric",
+ message="Big Number chart missing required field: metric",
+ details=(
+ "Big Number charts require a 'metric' field "
+ "specifying the value to display"
+ ),
+ suggestions=[
+ "Add 'metric' with name and aggregate: "
+ "{'name': 'revenue', 'aggregate': 'SUM'}",
+ "The aggregate function is required (SUM, COUNT, AVG, MIN, MAX)",
+ "Example: {'chart_type': 'big_number', "
+ "'metric': {'name': 'sales', 'aggregate': 'SUM'}}",
+ ],
+ error_code="MISSING_BIG_NUMBER_METRIC",
+ )
+
+ metric = config.get("metric", {})
+ if not isinstance(metric, dict):
+ return ChartGenerationError(
+ error_type="invalid_metric_type",
+ message="Big Number metric must be a dict with 'name' and 'aggregate'",
+ details=(
+ f"The 'metric' field must be an object, got {type(metric).__name__}"
+ ),
+ suggestions=[
+ "Use a dict: {'name': 'col', 'aggregate': 'SUM'}",
+ "Valid aggregates: SUM, COUNT, AVG, MIN, MAX",
+ ],
+ error_code="INVALID_BIG_NUMBER_METRIC_TYPE",
+ )
+ if not metric.get("aggregate") and not metric.get("saved_metric"):
+ return ChartGenerationError(
+ error_type="missing_metric_aggregate",
+ message=(
+ "Big Number metric must include an aggregate function "
+ "or reference a saved metric"
+ ),
+ details=(
+ "The metric must have an 'aggregate' field or 'saved_metric': true"
+ ),
+ suggestions=[
+ "Add 'aggregate': {'name': 'col', 'aggregate': 'SUM'}",
+ "Or use a saved metric: {'name': 'metric', 'saved_metric': true}",
+ "Valid aggregates: SUM, COUNT, AVG, MIN, MAX",
+ ],
+ error_code="MISSING_BIG_NUMBER_AGGREGATE",
+ )
+
+ show_trendline = config.get("show_trendline", False)
+ temporal_column = config.get("temporal_column")
+ if show_trendline and not temporal_column:
+ return ChartGenerationError(
+ error_type="missing_temporal_column",
+ message="Trendline requires a temporal column",
+ details=(
+ "When 'show_trendline' is True, "
+ "a 'temporal_column' must be specified"
+ ),
+ suggestions=[
+ "Add 'temporal_column': 'date_column_name'",
+ "Or set 'show_trendline': false for number only",
+ "Use get_dataset_info to find temporal columns",
+ ],
+ error_code="MISSING_TEMPORAL_COLUMN",
+ )
+
+ return None
+
+ def extract_column_refs(self, config: Any) -> list[ColumnRef]:
+ if not isinstance(config, BigNumberChartConfig):
+ return []
+ refs: list[ColumnRef] = [config.metric]
+ # temporal_column is a str field, not a ColumnRef — validate it exists
+ if config.temporal_column:
+ refs.append(ColumnRef(name=config.temporal_column))
+ if config.filters:
+ for f in config.filters:
+ refs.append(ColumnRef(name=f.column))
+ return refs
+
+ def to_form_data(
+ self, config: Any, dataset_id: int | str | None = None
+ ) -> dict[str, Any]:
+ return map_big_number_config(config)
+
+ def post_map_validate(
+ self,
+ config: Any,
+ form_data: dict[str, Any],
+ dataset_id: int | str | None = None,
+ ) -> ChartGenerationError | None:
+ """Verify the trendline temporal column is a real temporal SQL type.
+
+ This check was previously baked into map_config_to_form_data() in
+ chart_utils.py as a special case. Moving it here keeps the dispatcher
+ clean and makes the constraint explicit and discoverable.
+ """
+ if not isinstance(config, BigNumberChartConfig):
+ return None
+ if not (config.show_trendline and config.temporal_column):
+ return None
+
+ if not is_column_truly_temporal(config.temporal_column, dataset_id):
+ return ChartGenerationError(
+ error_type="non_temporal_trendline_column",
+ message=(
+ f"Big Number trendline requires a temporal SQL column; "
+ f"'{config.temporal_column}' is not temporal."
+ ),
+ details=(
+ f"Column '{config.temporal_column}' does not have a temporal "
+ f"SQL type (DATE, DATETIME, TIMESTAMP). The trendline requires "
+ f"a true temporal column for DATE_TRUNC to work."
+ ),
+ suggestions=[
+ "Use get_dataset_info to find columns with temporal SQL types",
+ "Set 'show_trendline': false to use any column as the metric",
+ "If the column contains dates stored as integers, "
+ "consider casting it in a virtual dataset",
+ ],
+ error_code="NON_TEMPORAL_TRENDLINE_COLUMN",
+ )
+
+ return None
+
+ def generate_name(self, config: Any, dataset_name: str | None = None) -> str:
+ what = _big_number_chart_what(config)
+ context = _summarize_filters(getattr(config, "filters", None))
+ return self._with_context(what, context)
+
+ def resolve_viz_type(self, config: Any) -> str:
+ show_trendline = getattr(config, "show_trendline", False)
+ temporal_column = getattr(config, "temporal_column", None)
+ if show_trendline and temporal_column:
+ return "big_number"
+ return "big_number_total"
+
+ def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any:
+ config_dict = config.model_dump()
+
+ if config_dict.get("metric") and not config_dict["metric"].get("saved_metric"):
+ config_dict["metric"]["name"] = DatasetValidator._get_canonical_column_name(
+ config_dict["metric"]["name"], dataset_context
+ )
+ if config_dict.get("temporal_column"):
+ config_dict["temporal_column"] = (
+ DatasetValidator._get_canonical_column_name(
+ config_dict["temporal_column"], dataset_context
+ )
+ )
+ DatasetValidator._normalize_filters(config_dict, dataset_context)
+ return BigNumberChartConfig.model_validate(config_dict)
+
+ def schema_error_hint(self) -> ChartGenerationError | None:
+ return ChartGenerationError(
+ error_type="big_number_validation_error",
+ message="Big Number chart configuration validation failed",
+ details=(
+ "The Big Number chart configuration is missing required "
+ "fields or has invalid structure"
+ ),
+ suggestions=[
+ "Ensure 'metric' field has 'name' and 'aggregate'",
+ "Example: 'metric': {'name': 'revenue', 'aggregate': 'SUM'}",
+ "For trendline: add show_trendline=true and temporal_column='col'",
+ "Without trendline: just provide the metric",
+ ],
+ error_code="BIG_NUMBER_VALIDATION_ERROR",
+ )
diff --git a/superset/mcp_service/chart/plugins/handlebars.py b/superset/mcp_service/chart/plugins/handlebars.py
new file mode 100755
index 000000000000..53d78cd8b824
--- /dev/null
+++ b/superset/mcp_service/chart/plugins/handlebars.py
@@ -0,0 +1,189 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Handlebars chart type plugin."""
+
+from __future__ import annotations
+
+from typing import Any
+
+from superset.mcp_service.chart.chart_utils import (
+ _handlebars_chart_what,
+ _summarize_filters,
+ map_handlebars_config,
+)
+from superset.mcp_service.chart.plugin import BaseChartPlugin
+from superset.mcp_service.chart.schemas import ColumnRef, HandlebarsChartConfig
+from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator
+from superset.mcp_service.common.error_schemas import ChartGenerationError
+
+
+class HandlebarsChartPlugin(BaseChartPlugin):
+ """Plugin for handlebars chart type (custom HTML template charts)."""
+
+ chart_type = "handlebars"
+ display_name = "Handlebars (Custom Template)"
+ native_viz_types = {
+ "handlebars": "Custom Template Chart",
+ }
+
+ def pre_validate(
+ self,
+ config: dict[str, Any],
+ ) -> ChartGenerationError | None:
+ if "handlebars_template" not in config:
+ return ChartGenerationError(
+ error_type="missing_handlebars_template",
+ message="Handlebars chart missing required field: handlebars_template",
+ details=(
+ "Handlebars charts require a 'handlebars_template' string "
+ "containing Handlebars HTML template markup"
+ ),
+ suggestions=[
+ "Add 'handlebars_template' with a Handlebars HTML template",
+ "Data is available as {{data}} array in the template",
+                "Example: '<ul>{{#each data}}<li>{{this.name}}: "
+                "{{this.value}}</li>{{/each}}</ul>'",
+ ],
+ error_code="MISSING_HANDLEBARS_TEMPLATE",
+ )
+
+ template = config.get("handlebars_template")
+ if not isinstance(template, str) or not template.strip():
+ return ChartGenerationError(
+ error_type="invalid_handlebars_template",
+ message="Handlebars template must be a non-empty string",
+ details=(
+ "The 'handlebars_template' field must be a non-empty string "
+ "containing valid Handlebars HTML template markup"
+ ),
+ suggestions=[
+ "Ensure handlebars_template is a non-empty string",
+                "Example: "
+                "'<ul>{{#each data}}<li>{{this.name}}</li>{{/each}}</ul>'",
+ ],
+ error_code="INVALID_HANDLEBARS_TEMPLATE",
+ )
+
+ query_mode = config.get("query_mode", "aggregate")
+ if query_mode not in ("aggregate", "raw"):
+ return ChartGenerationError(
+ error_type="invalid_query_mode",
+ message="Invalid query_mode for handlebars chart",
+ details="query_mode must be either 'aggregate' or 'raw'",
+ suggestions=[
+ "Use 'aggregate' for aggregated data (default)",
+ "Use 'raw' for individual rows",
+ ],
+ error_code="INVALID_QUERY_MODE",
+ )
+
+ if query_mode == "raw" and not config.get("columns"):
+ return ChartGenerationError(
+ error_type="missing_raw_columns",
+ message="Handlebars chart in 'raw' mode requires 'columns'",
+ details=(
+ "When query_mode is 'raw', you must specify which columns "
+ "to include in the query results"
+ ),
+ suggestions=[
+ "Add 'columns': [{'name': 'column_name'}] for raw mode",
+ "Or use query_mode='aggregate' with 'metrics' and optional 'groupby'", # noqa: E501
+ ],
+ error_code="MISSING_RAW_COLUMNS",
+ )
+
+ if query_mode == "aggregate" and not config.get("metrics"):
+ return ChartGenerationError(
+ error_type="missing_aggregate_metrics",
+ message="Handlebars chart in 'aggregate' mode requires 'metrics'",
+ details=(
+ "When query_mode is 'aggregate' (default), you must specify "
+ "at least one metric with an aggregate function"
+ ),
+ suggestions=[
+ "Add 'metrics': [{'name': 'column', 'aggregate': 'SUM'}]",
+ "Or use query_mode='raw' with 'columns' for individual rows",
+ ],
+ error_code="MISSING_AGGREGATE_METRICS",
+ )
+
+ return None
+
+ def extract_column_refs(self, config: Any) -> list[ColumnRef]:
+ if not isinstance(config, HandlebarsChartConfig):
+ return []
+ refs: list[ColumnRef] = []
+ if config.columns:
+ refs.extend(config.columns)
+ if config.metrics:
+ refs.extend(config.metrics)
+ if config.groupby:
+ refs.extend(config.groupby)
+ if config.filters:
+ for f in config.filters:
+ refs.append(ColumnRef(name=f.column))
+ return refs
+
+ def to_form_data(
+ self, config: Any, dataset_id: int | str | None = None
+ ) -> dict[str, Any]:
+ return map_handlebars_config(config)
+
+ def generate_name(self, config: Any, dataset_name: str | None = None) -> str:
+ what = _handlebars_chart_what(config)
+ context = _summarize_filters(getattr(config, "filters", None))
+ return self._with_context(what, context)
+
+ def resolve_viz_type(self, config: Any) -> str:
+ return "handlebars"
+
+ def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any:
+ config_dict = config.model_dump()
+
+ def _norm_list(key: str) -> None:
+ if config_dict.get(key):
+ for col in config_dict[key]:
+ if not col.get("saved_metric"):
+ col["name"] = DatasetValidator._get_canonical_column_name(
+ col["name"], dataset_context
+ )
+
+ _norm_list("columns")
+ _norm_list("metrics")
+ _norm_list("groupby")
+ DatasetValidator._normalize_filters(config_dict, dataset_context)
+ return HandlebarsChartConfig.model_validate(config_dict)
+
+ def schema_error_hint(self) -> ChartGenerationError | None:
+ return ChartGenerationError(
+ error_type="handlebars_validation_error",
+ message="Handlebars chart configuration validation failed",
+ details=(
+ "The handlebars chart configuration is missing "
+ "required fields or has invalid structure"
+ ),
+ suggestions=[
+ "Ensure 'handlebars_template' is a non-empty string",
+ "For aggregate mode: add 'metrics' with aggregate functions",
+ "For raw mode: set 'query_mode': 'raw' and add 'columns'",
+ "Example: {'chart_type': 'handlebars', "
+ "'handlebars_template': "
+                "'<ul>{{#each data}}<li>{{this.name}}</li>{{/each}}</ul>', "
+ "'metrics': [{'name': 'sales', 'aggregate': 'SUM'}]}",
+ ],
+ error_code="HANDLEBARS_VALIDATION_ERROR",
+ )
diff --git a/superset/mcp_service/chart/plugins/mixed_timeseries.py b/superset/mcp_service/chart/plugins/mixed_timeseries.py
new file mode 100755
index 000000000000..0cf7b82e80eb
--- /dev/null
+++ b/superset/mcp_service/chart/plugins/mixed_timeseries.py
@@ -0,0 +1,165 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Mixed timeseries chart type plugin."""
+
+from __future__ import annotations
+
+from typing import Any
+
+from superset.mcp_service.chart.chart_utils import (
+ _mixed_timeseries_what,
+ _summarize_filters,
+ map_mixed_timeseries_config,
+)
+from superset.mcp_service.chart.plugin import BaseChartPlugin
+from superset.mcp_service.chart.schemas import ColumnRef, MixedTimeseriesChartConfig
+from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator
+from superset.mcp_service.common.error_schemas import ChartGenerationError
+
+
+class MixedTimeseriesChartPlugin(BaseChartPlugin):
+ """Plugin for mixed_timeseries chart type."""
+
+ chart_type = "mixed_timeseries"
+ display_name = "Mixed Timeseries"
+ native_viz_types = {
+ "mixed_timeseries": "Mixed Timeseries Chart",
+ }
+
+ def pre_validate(
+ self,
+ config: dict[str, Any],
+ ) -> ChartGenerationError | None:
+ missing_fields = []
+
+ if "x" not in config:
+ missing_fields.append("'x' (X-axis temporal column)")
+ if "y" not in config:
+ missing_fields.append("'y' (primary Y-axis metrics)")
+ if "y_secondary" not in config:
+ missing_fields.append("'y_secondary' (secondary Y-axis metrics)")
+
+ if missing_fields:
+ return ChartGenerationError(
+ error_type="missing_mixed_timeseries_fields",
+ message=(
+ f"Mixed timeseries chart missing required fields: "
+ f"{', '.join(missing_fields)}"
+ ),
+ details=(
+ "Mixed timeseries charts require an x-axis, primary metrics, "
+ "and secondary metrics"
+ ),
+ suggestions=[
+ "Add 'x' field: {'name': 'date_column'}",
+ "Add 'y' field: [{'name': 'revenue', 'aggregate': 'SUM'}]",
+ "Add 'y_secondary': [{'name': 'orders', 'aggregate': 'COUNT'}]",
+ "Optional: 'primary_kind' and 'secondary_kind' for chart types",
+ ],
+ error_code="MISSING_MIXED_TIMESERIES_FIELDS",
+ )
+
+ for field_name in ["y", "y_secondary"]:
+ if not isinstance(config.get(field_name, []), list):
+ return ChartGenerationError(
+ error_type=f"invalid_{field_name}_format",
+ message=f"'{field_name}' must be a list of metrics",
+ details=(
+ f"The '{field_name}' field must be an array of metric "
+ "specifications"
+ ),
+ suggestions=[
+ f"Wrap in array: '{field_name}': "
+ "[{'name': 'col', 'aggregate': 'SUM'}]",
+ ],
+ error_code=f"INVALID_{field_name.upper()}_FORMAT",
+ )
+
+ return None
+
+ def extract_column_refs(self, config: Any) -> list[ColumnRef]:
+ if not isinstance(config, MixedTimeseriesChartConfig):
+ return []
+ refs: list[ColumnRef] = [config.x]
+ refs.extend(config.y)
+ refs.extend(config.y_secondary)
+ if config.group_by:
+ refs.extend(config.group_by)
+ if config.group_by_secondary:
+ refs.extend(config.group_by_secondary)
+ if config.filters:
+ for f in config.filters:
+ refs.append(ColumnRef(name=f.column))
+ return refs
+
+ def to_form_data(
+ self, config: Any, dataset_id: int | str | None = None
+ ) -> dict[str, Any]:
+ return map_mixed_timeseries_config(config, dataset_id=dataset_id)
+
+ def generate_name(self, config: Any, dataset_name: str | None = None) -> str:
+ what = _mixed_timeseries_what(config)
+ context = _summarize_filters(config.filters)
+ return self._with_context(what, context)
+
+ def resolve_viz_type(self, config: Any) -> str:
+ return "mixed_timeseries"
+
+ def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any:
+ config_dict = config.model_dump()
+
+ def _norm_single(key: str) -> None:
+ if config_dict.get(key):
+ config_dict[key]["name"] = DatasetValidator._get_canonical_column_name(
+ config_dict[key]["name"], dataset_context
+ )
+
+ def _norm_list(key: str) -> None:
+ if config_dict.get(key):
+ for col in config_dict[key]:
+ col["name"] = DatasetValidator._get_canonical_column_name(
+ col["name"], dataset_context
+ )
+
+ _norm_single("x")
+ _norm_list("y")
+ _norm_list("y_secondary")
+ _norm_list("group_by")
+ _norm_list("group_by_secondary")
+ DatasetValidator._normalize_filters(config_dict, dataset_context)
+ return MixedTimeseriesChartConfig.model_validate(config_dict)
+
+ def schema_error_hint(self) -> ChartGenerationError | None:
+ return ChartGenerationError(
+ error_type="mixed_timeseries_validation_error",
+ message="Mixed timeseries chart configuration validation failed",
+ details=(
+ "The mixed timeseries configuration is missing "
+ "required fields or has invalid structure"
+ ),
+ suggestions=[
+ "Ensure 'x' field has 'name' for the time axis column",
+ "Ensure 'y' is an array of primary-axis metrics",
+ "Ensure 'y_secondary' is an array of secondary-axis metrics",
+ "Example: {'chart_type': 'mixed_timeseries', "
+ "'x': {'name': 'order_date'}, "
+ "'y': [{'name': 'revenue', 'aggregate': 'SUM'}], "
+ "'y_secondary': [{'name': 'orders', 'aggregate': 'COUNT'}]}",
+ ],
+ error_code="MIXED_TIMESERIES_VALIDATION_ERROR",
+ )
diff --git a/superset/mcp_service/chart/plugins/pie.py b/superset/mcp_service/chart/plugins/pie.py
new file mode 100644
index 000000000000..3d87fe7f05f2
--- /dev/null
+++ b/superset/mcp_service/chart/plugins/pie.py
@@ -0,0 +1,128 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Pie chart type plugin."""
+
+from __future__ import annotations
+
+from typing import Any
+
+from superset.mcp_service.chart.chart_utils import (
+ _pie_chart_what,
+ _summarize_filters,
+ map_pie_config,
+)
+from superset.mcp_service.chart.plugin import BaseChartPlugin
+from superset.mcp_service.chart.schemas import ColumnRef, PieChartConfig
+from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator
+from superset.mcp_service.common.error_schemas import ChartGenerationError
+
+
+class PieChartPlugin(BaseChartPlugin):
+ """Plugin for pie chart type."""
+
+ chart_type = "pie"
+ display_name = "Pie / Donut Chart"
+ native_viz_types = {
+ "pie": "Pie Chart",
+ }
+
+ def pre_validate(
+ self,
+ config: dict[str, Any],
+ ) -> ChartGenerationError | None:
+ missing_fields = []
+
+ if "dimension" not in config:
+ missing_fields.append("'dimension' (category column for slices)")
+ if "metric" not in config:
+ missing_fields.append("'metric' (value metric for slice sizes)")
+
+ if missing_fields:
+ return ChartGenerationError(
+ error_type="missing_pie_fields",
+ message=(
+ f"Pie chart missing required fields: {', '.join(missing_fields)}"
+ ),
+ details=(
+ "Pie charts require a dimension (categories) and a metric (values)"
+ ),
+ suggestions=[
+ "Add 'dimension' field: {'name': 'category_column'}",
+ "Add 'metric' field: {'name': 'value_column', 'aggregate': 'SUM'}",
+ "Example: {'chart_type': 'pie', 'dimension': {'name': 'product'}, "
+ "'metric': {'name': 'revenue', 'aggregate': 'SUM'}}",
+ ],
+ error_code="MISSING_PIE_FIELDS",
+ )
+
+ return None
+
+ def extract_column_refs(self, config: Any) -> list[ColumnRef]:
+ if not isinstance(config, PieChartConfig):
+ return []
+ refs: list[ColumnRef] = [config.dimension, config.metric]
+ if config.filters:
+ for f in config.filters:
+ refs.append(ColumnRef(name=f.column))
+ return refs
+
+ def to_form_data(
+ self, config: Any, dataset_id: int | str | None = None
+ ) -> dict[str, Any]:
+ return map_pie_config(config)
+
+ def generate_name(self, config: Any, dataset_name: str | None = None) -> str:
+ what = _pie_chart_what(config)
+ context = _summarize_filters(config.filters)
+ return self._with_context(what, context)
+
+ def resolve_viz_type(self, config: Any) -> str:
+ return "pie"
+
+ def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any:
+ config_dict = config.model_dump()
+
+ if config_dict.get("dimension"):
+ config_dict["dimension"]["name"] = (
+ DatasetValidator._get_canonical_column_name(
+ config_dict["dimension"]["name"], dataset_context
+ )
+ )
+ if config_dict.get("metric") and not config_dict["metric"].get("saved_metric"):
+ config_dict["metric"]["name"] = DatasetValidator._get_canonical_column_name(
+ config_dict["metric"]["name"], dataset_context
+ )
+ DatasetValidator._normalize_filters(config_dict, dataset_context)
+ return PieChartConfig.model_validate(config_dict)
+
+ def schema_error_hint(self) -> ChartGenerationError | None:
+ return ChartGenerationError(
+ error_type="pie_validation_error",
+ message="Pie chart configuration validation failed",
+ details=(
+ "The pie chart configuration is missing required "
+ "fields or has invalid structure"
+ ),
+ suggestions=[
+ "Ensure 'dimension' field has 'name' for the slice label",
+ "Ensure 'metric' field has 'name' and 'aggregate'",
+ "Example: {'chart_type': 'pie', 'dimension': {'name': 'category'}, "
+ "'metric': {'name': 'revenue', 'aggregate': 'SUM'}}",
+ ],
+ error_code="PIE_VALIDATION_ERROR",
+ )
diff --git a/superset/mcp_service/chart/plugins/pivot_table.py b/superset/mcp_service/chart/plugins/pivot_table.py
new file mode 100644
index 000000000000..038f8c794161
--- /dev/null
+++ b/superset/mcp_service/chart/plugins/pivot_table.py
@@ -0,0 +1,153 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Pivot table chart type plugin."""
+
+from __future__ import annotations
+
+from typing import Any
+
+from superset.mcp_service.chart.chart_utils import (
+ _pivot_table_what,
+ _summarize_filters,
+ map_pivot_table_config,
+)
+from superset.mcp_service.chart.plugin import BaseChartPlugin
+from superset.mcp_service.chart.schemas import ColumnRef, PivotTableChartConfig
+from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator
+from superset.mcp_service.common.error_schemas import ChartGenerationError
+
+
+class PivotTableChartPlugin(BaseChartPlugin):
+ """Plugin for pivot_table chart type."""
+
+ chart_type = "pivot_table"
+ display_name = "Pivot Table"
+ native_viz_types = {
+ "pivot_table_v2": "Pivot Table",
+ }
+
+ def pre_validate(
+ self,
+ config: dict[str, Any],
+ ) -> ChartGenerationError | None:
+ missing_fields = []
+
+ if "rows" not in config:
+ missing_fields.append("'rows' (row grouping columns)")
+ if "metrics" not in config:
+ missing_fields.append("'metrics' (aggregation metrics)")
+
+ if missing_fields:
+ return ChartGenerationError(
+ error_type="missing_pivot_fields",
+ message=(
+ f"Pivot table missing required fields: {', '.join(missing_fields)}"
+ ),
+ details="Pivot tables require row groupings and metrics",
+ suggestions=[
+ "Add 'rows' field: [{'name': 'category'}]",
+ "Add 'metrics' field: [{'name': 'sales', 'aggregate': 'SUM'}]",
+ "Optional 'columns' for cross-tabulation: [{'name': 'region'}]",
+ ],
+ error_code="MISSING_PIVOT_FIELDS",
+ )
+
+ if not isinstance(config.get("rows", []), list):
+ return ChartGenerationError(
+ error_type="invalid_rows_format",
+ message="Rows must be a list of columns",
+ details="The 'rows' field must be an array of column specifications",
+ suggestions=[
+ "Wrap row columns in array: 'rows': [{'name': 'category'}]",
+ ],
+ error_code="INVALID_ROWS_FORMAT",
+ )
+
+ if not isinstance(config.get("metrics", []), list):
+ return ChartGenerationError(
+ error_type="invalid_metrics_format",
+ message="Metrics must be a list",
+ details="The 'metrics' field must be an array of metric specifications",
+ suggestions=[
+ "Wrap metrics in array: 'metrics': [{'name': 'sales', "
+ "'aggregate': 'SUM'}]",
+ ],
+ error_code="INVALID_METRICS_FORMAT",
+ )
+
+ return None
+
+ def extract_column_refs(self, config: Any) -> list[ColumnRef]:
+ if not isinstance(config, PivotTableChartConfig):
+ return []
+ refs: list[ColumnRef] = list(config.rows)
+ refs.extend(config.metrics)
+ if config.columns:
+ refs.extend(config.columns)
+ if config.filters:
+ for f in config.filters:
+ refs.append(ColumnRef(name=f.column))
+ return refs
+
+ def to_form_data(
+ self, config: Any, dataset_id: int | str | None = None
+ ) -> dict[str, Any]:
+ return map_pivot_table_config(config)
+
+ def generate_name(self, config: Any, dataset_name: str | None = None) -> str:
+ what = _pivot_table_what(config)
+ context = _summarize_filters(config.filters)
+ return self._with_context(what, context)
+
+ def resolve_viz_type(self, config: Any) -> str:
+ return "pivot_table_v2"
+
+ def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any:
+ config_dict = config.model_dump()
+
+ def _norm_col_list(key: str) -> None:
+ if config_dict.get(key):
+ for col in config_dict[key]:
+ col["name"] = DatasetValidator._get_canonical_column_name(
+ col["name"], dataset_context
+ )
+
+ _norm_col_list("rows")
+ _norm_col_list("metrics")
+ _norm_col_list("columns")
+ DatasetValidator._normalize_filters(config_dict, dataset_context)
+ return PivotTableChartConfig.model_validate(config_dict)
+
+ def schema_error_hint(self) -> ChartGenerationError | None:
+ return ChartGenerationError(
+ error_type="pivot_table_validation_error",
+ message="Pivot table configuration validation failed",
+ details=(
+ "The pivot table configuration is missing required "
+ "fields or has invalid structure"
+ ),
+ suggestions=[
+ "Ensure 'rows' field is an array of column specs",
+ "Ensure 'metrics' field is an array with aggregate funcs",
+ "Optional: add 'columns' for column grouping",
+ "Example: {'chart_type': 'pivot_table', "
+ "'rows': [{'name': 'region'}], "
+ "'metrics': [{'name': 'revenue', 'aggregate': 'SUM'}]}",
+ ],
+ error_code="PIVOT_TABLE_VALIDATION_ERROR",
+ )
diff --git a/superset/mcp_service/chart/plugins/table.py b/superset/mcp_service/chart/plugins/table.py
new file mode 100644
index 000000000000..86f5dcaead25
--- /dev/null
+++ b/superset/mcp_service/chart/plugins/table.py
@@ -0,0 +1,128 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Table chart type plugin."""
+
+from __future__ import annotations
+
+from typing import Any
+
+from superset.mcp_service.chart.chart_utils import (
+ _summarize_filters,
+ _table_chart_what,
+ map_table_config,
+)
+from superset.mcp_service.chart.plugin import BaseChartPlugin
+from superset.mcp_service.chart.schemas import ColumnRef, TableChartConfig
+from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator
+from superset.mcp_service.common.error_schemas import ChartGenerationError
+
+
+class TableChartPlugin(BaseChartPlugin):
+ """Plugin for table chart type."""
+
+ chart_type = "table"
+ display_name = "Table"
+ native_viz_types = {
+ "table": "Table",
+ "ag-grid-table": "Interactive Table",
+ }
+
+ def pre_validate(
+ self,
+ config: dict[str, Any],
+ ) -> ChartGenerationError | None:
+ if "columns" not in config:
+ return ChartGenerationError(
+ error_type="missing_columns",
+ message="Table chart missing required field: columns",
+ details=(
+ "Table charts require a 'columns' array to specify which "
+ "columns to display"
+ ),
+ suggestions=[
+ "Add 'columns' field with array of column specifications",
+ "Example: 'columns': [{'name': 'product'}, {'name': 'sales', "
+ "'aggregate': 'SUM'}]",
+ "Each column can have optional 'aggregate' for metrics",
+ ],
+ error_code="MISSING_COLUMNS",
+ )
+
+ if not isinstance(config.get("columns", []), list):
+ return ChartGenerationError(
+ error_type="invalid_columns_format",
+ message="Columns must be a list",
+ details="The 'columns' field must be an array of column specifications",
+ suggestions=[
+ "Ensure columns is an array: 'columns': [...]",
+ "Each column should be an object with 'name' field",
+ ],
+ error_code="INVALID_COLUMNS_FORMAT",
+ )
+
+ return None
+
+ def extract_column_refs(self, config: Any) -> list[ColumnRef]:
+ if not isinstance(config, TableChartConfig):
+ return []
+ refs: list[ColumnRef] = list(config.columns)
+ if config.filters:
+ for f in config.filters:
+ refs.append(ColumnRef(name=f.column))
+ return refs
+
+ def to_form_data(
+ self, config: Any, dataset_id: int | str | None = None
+ ) -> dict[str, Any]:
+ return map_table_config(config)
+
+ def generate_name(self, config: Any, dataset_name: str | None = None) -> str:
+ what = _table_chart_what(config, dataset_name)
+ context = _summarize_filters(config.filters)
+ return self._with_context(what, context)
+
+ def resolve_viz_type(self, config: Any) -> str:
+ return getattr(config, "viz_type", "table")
+
+ def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any:
+ config_dict = config.model_dump()
+ get_canonical = DatasetValidator._get_canonical_column_name
+
+ for col in config_dict.get("columns") or []:
+ col["name"] = get_canonical(col["name"], dataset_context)
+
+ DatasetValidator._normalize_filters(config_dict, dataset_context)
+ return TableChartConfig.model_validate(config_dict)
+
+ def schema_error_hint(self) -> ChartGenerationError | None:
+ return ChartGenerationError(
+ error_type="table_validation_error",
+ message="Table chart configuration validation failed",
+ details=(
+ "The table chart configuration is missing required "
+ "fields or has invalid structure"
+ ),
+ suggestions=[
+ "Ensure 'columns' field is an array of column specifications",
+ "Each column needs {'name': 'column_name'}",
+ "Optional: add 'aggregate' for metrics",
+ "Example: 'columns': [{'name': 'product'}, "
+ "{'name': 'sales', 'aggregate': 'SUM'}]",
+ ],
+ error_code="TABLE_VALIDATION_ERROR",
+ )
diff --git a/superset/mcp_service/chart/plugins/xy.py b/superset/mcp_service/chart/plugins/xy.py
new file mode 100644
index 000000000000..076826f3f08e
--- /dev/null
+++ b/superset/mcp_service/chart/plugins/xy.py
@@ -0,0 +1,192 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""XY chart type plugin (line, bar, area, scatter)."""
+
+from __future__ import annotations
+
+import logging
+from typing import Any
+
+from superset.mcp_service.chart.chart_utils import (
+ _xy_chart_context,
+ _xy_chart_what,
+ map_xy_config,
+)
+from superset.mcp_service.chart.plugin import BaseChartPlugin
+from superset.mcp_service.chart.schemas import ColumnRef, XYChartConfig
+from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator
+from superset.mcp_service.chart.validation.runtime.cardinality_validator import (
+ CardinalityValidator,
+)
+from superset.mcp_service.chart.validation.runtime.format_validator import (
+ FormatTypeValidator,
+)
+from superset.mcp_service.common.error_schemas import ChartGenerationError
+
+logger = logging.getLogger(__name__)
+
+
+class XYChartPlugin(BaseChartPlugin):
+ """Plugin for xy chart type (line, bar, area, scatter)."""
+
+ chart_type = "xy"
+ display_name = "Line / Bar / Area / Scatter Chart"
+ native_viz_types = {
+ "echarts_timeseries_line": "Line Chart",
+ "echarts_timeseries_bar": "Bar Chart",
+ "echarts_area": "Area Chart",
+ "echarts_timeseries_scatter": "Scatter Plot",
+ }
+
+ def pre_validate(
+ self,
+ config: dict[str, Any],
+ ) -> ChartGenerationError | None:
+ # x is optional — defaults to dataset's main_dttm_col in map_xy_config
+ if "y" not in config:
+ return ChartGenerationError(
+ error_type="missing_xy_fields",
+ message="XY chart missing required field: 'y' (Y-axis metrics)",
+ details=(
+ "XY charts require Y-axis (metrics) specifications. "
+ "X-axis is optional and defaults to the dataset's primary "
+ "datetime column when omitted."
+ ),
+ suggestions=[
+ "Add 'y' field: [{'name': 'metric_column', 'aggregate': 'SUM'}]",
+ "Example: {'chart_type': 'xy', 'x': {'name': 'date'}, "
+ "'y': [{'name': 'sales', 'aggregate': 'SUM'}]}",
+ ],
+ error_code="MISSING_XY_FIELDS",
+ )
+
+ if not isinstance(config.get("y", []), list):
+ return ChartGenerationError(
+ error_type="invalid_y_format",
+ message="Y-axis must be a list of metrics",
+ details="The 'y' field must be an array of metric specifications",
+ suggestions=[
+ "Wrap Y-axis metric in array: 'y': [{'name': 'column', "
+ "'aggregate': 'SUM'}]",
+ "Multiple metrics supported: 'y': [metric1, metric2, ...]",
+ ],
+ error_code="INVALID_Y_FORMAT",
+ )
+
+ return None
+
+ def extract_column_refs(self, config: Any) -> list[ColumnRef]:
+ if not isinstance(config, XYChartConfig):
+ return []
+ refs: list[ColumnRef] = []
+ if config.x is not None:
+ refs.append(config.x)
+ refs.extend(config.y)
+ if config.group_by:
+ refs.extend(config.group_by)
+ if config.filters:
+ for f in config.filters:
+ refs.append(ColumnRef(name=f.column))
+ return refs
+
+ def to_form_data(
+ self, config: Any, dataset_id: int | str | None = None
+ ) -> dict[str, Any]:
+ return map_xy_config(config, dataset_id=dataset_id)
+
+ def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any:
+ config_dict = config.model_dump()
+ get_canonical = DatasetValidator._get_canonical_column_name
+
+ if config_dict.get("x"):
+ config_dict["x"]["name"] = get_canonical(
+ config_dict["x"]["name"], dataset_context
+ )
+ for y_col in config_dict.get("y") or []:
+ y_col["name"] = get_canonical(y_col["name"], dataset_context)
+ for gb_col in config_dict.get("group_by") or []:
+ gb_col["name"] = get_canonical(gb_col["name"], dataset_context)
+
+ DatasetValidator._normalize_filters(config_dict, dataset_context)
+ return XYChartConfig.model_validate(config_dict)
+
+ def generate_name(self, config: Any, dataset_name: str | None = None) -> str:
+ what = _xy_chart_what(config)
+ context = _xy_chart_context(config)
+ return self._with_context(what, context)
+
+ def resolve_viz_type(self, config: Any) -> str:
+ kind = getattr(config, "kind", "line")
+ return {
+ "line": "echarts_timeseries_line",
+ "bar": "echarts_timeseries_bar",
+ "area": "echarts_area",
+ "scatter": "echarts_timeseries_scatter",
+ }.get(kind, "echarts_timeseries_line")
+
+ def get_runtime_warnings(self, config: Any, dataset_id: int | str) -> list[str]:
+ """Return format-compatibility and cardinality warnings for XY charts."""
+ if not isinstance(config, XYChartConfig):
+ return []
+
+ warnings: list[str] = []
+
+ try:
+ _valid, format_warnings = FormatTypeValidator.validate_format_compatibility(
+ config
+ )
+ if format_warnings:
+ warnings.extend(format_warnings)
+ except Exception as exc:
+ logger.warning("XY format validation failed: %s", exc)
+
+ try:
+ chart_kind = config.kind
+ group_by_col = config.group_by[0].name if config.group_by else None
+ if config.x is not None:
+ _ok, card_info = CardinalityValidator.check_cardinality(
+ dataset_id=dataset_id,
+ x_column=config.x.name,
+ chart_type=chart_kind,
+ group_by_column=group_by_col,
+ )
+ if not _ok and card_info:
+ warnings.extend(card_info.get("warnings", []))
+ warnings.extend(card_info.get("suggestions", []))
+ except Exception as exc:
+ logger.warning("XY cardinality validation failed: %s", exc)
+
+ return warnings
+
+ def schema_error_hint(self) -> ChartGenerationError | None:
+ return ChartGenerationError(
+ error_type="xy_validation_error",
+ message="XY chart configuration validation failed",
+ details=(
+ "The XY chart configuration is missing required "
+ "fields or has invalid structure"
+ ),
+ suggestions=[
+ "Note: 'x' is optional and defaults to the dataset's primary "
+ "datetime column",
+ "Ensure 'y' is an array: [{'name': 'metric', 'aggregate': 'SUM'}]",
+ "Check that all column names are strings",
+ "Verify aggregate functions are valid: SUM, COUNT, AVG, MIN, MAX",
+ ],
+ error_code="XY_VALIDATION_ERROR",
+ )
diff --git a/superset/mcp_service/chart/registry.py b/superset/mcp_service/chart/registry.py
new file mode 100644
index 000000000000..920cfcc5592f
--- /dev/null
+++ b/superset/mcp_service/chart/registry.py
@@ -0,0 +1,141 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+ChartTypeRegistry — central registry mapping chart_type strings to plugins.
+
+Replaces the four previously-scattered dispatch locations:
+ - schema_validator.py: chart_type_validators dict
+ - dataset_validator.py: isinstance branches in _extract_column_references()
+ - chart_utils.py: if/elif chain in map_config_to_form_data()
+ - dataset_validator.py: isinstance branches in normalize_column_names()
+
+Usage::
+
+ from superset.mcp_service.chart.registry import get_registry
+
+ plugin = get_registry().get("xy")
+ if plugin is None:
+ raise ValueError("Unknown chart type: xy")
+ form_data = plugin.to_form_data(config, dataset_id)
+"""
+
+from __future__ import annotations
+
+import logging
+import threading
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from superset.mcp_service.chart.plugin import ChartTypePlugin
+
+logger = logging.getLogger(__name__)
+
+_REGISTRY: dict[str, "ChartTypePlugin"] = {}
+_plugins_loaded = False
+_plugins_lock = threading.Lock()
+
+
+def _ensure_plugins_loaded() -> None:
+ """Lazily import the plugins package to populate _REGISTRY.
+
+ Called before every registry lookup so the registry is always populated,
+ even when callers (tests, chart_utils, validators) import this module
+ directly without first importing app.py.
+ """
+ global _plugins_loaded
+ if _plugins_loaded:
+ return
+ with _plugins_lock:
+ if not _plugins_loaded:
+ try:
+ import superset.mcp_service.chart.plugins # noqa: F401
+
+ _plugins_loaded = True
+ except Exception:
+ logger.exception("Failed to load built-in chart type plugins")
+
+
+def register(plugin: "ChartTypePlugin") -> None:
+ """Register a chart type plugin in the global registry."""
+ if not plugin.chart_type:
+ raise ValueError(f"{type(plugin).__name__} must define a non-empty chart_type")
+ if plugin.chart_type in _REGISTRY:
+ logger.warning(
+ "Overwriting existing plugin for chart_type=%r", plugin.chart_type
+ )
+ _REGISTRY[plugin.chart_type] = plugin
+ logger.debug("Registered chart plugin: %r", plugin.chart_type)
+
+
+def get(chart_type: str) -> "ChartTypePlugin | None":
+ """Return the plugin for a given chart_type, or None if not registered."""
+ _ensure_plugins_loaded()
+ return _REGISTRY.get(chart_type)
+
+
+def all_types() -> list[str]:
+ """Return all registered chart type strings in insertion order."""
+ _ensure_plugins_loaded()
+ return list(_REGISTRY.keys())
+
+
+def is_registered(chart_type: str) -> bool:
+ """Return True if chart_type has a registered plugin."""
+ _ensure_plugins_loaded()
+ return chart_type in _REGISTRY
+
+
+def display_name_for_viz_type(viz_type: str) -> str | None:
+ """Return the user-facing display name for a Superset-internal viz_type.
+
+ Searches every registered plugin's ``native_viz_types`` mapping.
+ Returns None if no plugin recognises the viz_type.
+
+ Example::
+
+ display_name_for_viz_type("echarts_timeseries_line") # "Line Chart"
+ display_name_for_viz_type("pivot_table_v2") # "Pivot Table"
+ display_name_for_viz_type("unknown_type") # None
+ """
+ _ensure_plugins_loaded()
+ for plugin in _REGISTRY.values():
+ name = plugin.native_viz_types.get(viz_type)
+ if name is not None:
+ return name
+ return None
+
+
+def get_registry() -> "_RegistryProxy":
+ """Return a proxy object for registry access (convenience wrapper)."""
+ return _RegistryProxy()
+
+
+class _RegistryProxy:
+ """Thin proxy exposing registry functions as instance methods."""
+
+ def get(self, chart_type: str) -> "ChartTypePlugin | None":
+ return get(chart_type)
+
+ def all_types(self) -> list[str]:
+ return all_types()
+
+ def is_registered(self, chart_type: str) -> bool:
+ return is_registered(chart_type)
+
+ def display_name_for_viz_type(self, viz_type: str) -> str | None:
+ return display_name_for_viz_type(viz_type)
diff --git a/superset/mcp_service/chart/schemas.py b/superset/mcp_service/chart/schemas.py
index 1fdb4f43ab6d..05aff20a8b65 100644
--- a/superset/mcp_service/chart/schemas.py
+++ b/superset/mcp_service/chart/schemas.py
@@ -101,7 +101,14 @@ class ChartInfo(BaseModel):
id: int | None = Field(None, description="Chart ID")
slice_name: str | None = Field(None, description="Chart name")
- viz_type: str | None = Field(None, description="Visualization type")
+ viz_type: str | None = Field(None, description="Visualization type (internal ID)")
+ chart_type_display_name: str | None = Field(
+ None,
+ description=(
+ "User-friendly chart type name (e.g. 'Line Chart', 'Pivot Table'). "
+ "Use this field when referring to chart types — never expose viz_type."
+ ),
+ )
datasource_name: str | None = Field(None, description="Datasource name")
datasource_type: str | None = Field(None, description="Datasource type")
url: str | None = Field(None, description="Chart explore page URL")
@@ -488,11 +495,20 @@ def serialize_chart_object(chart: ChartLike | None) -> ChartInfo | None:
# Extract structured filter information
filters_info = extract_filters_from_form_data(chart_form_data)
+ _viz_type = getattr(chart, "viz_type", None)
+ try:
+ from superset.mcp_service.chart.registry import display_name_for_viz_type
+
+ _display_name = display_name_for_viz_type(_viz_type) if _viz_type else None
+ except Exception:
+ _display_name = None
+
return sanitize_chart_info_for_llm_context(
ChartInfo(
id=chart_id,
slice_name=getattr(chart, "slice_name", None),
- viz_type=getattr(chart, "viz_type", None),
+ viz_type=_viz_type,
+ chart_type_display_name=_display_name,
datasource_name=getattr(chart, "datasource_name", None),
datasource_type=getattr(chart, "datasource_type", None),
url=chart_url,
@@ -669,7 +685,6 @@ class ColumnRef(BaseModel):
...,
min_length=1,
max_length=255,
- pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$",
validation_alias=AliasChoices("name", "column_name"),
)
label: str | None = Field(None, max_length=500)
@@ -743,7 +758,6 @@ class FilterConfig(BaseModel):
...,
min_length=1,
max_length=255,
- pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$",
validation_alias=AliasChoices("column", "col"),
)
op: Literal[
@@ -1082,7 +1096,6 @@ class BigNumberChartConfig(UnknownFieldCheckMixin):
),
min_length=1,
max_length=255,
- pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$",
)
time_grain: TimeGrain | None = Field(
None,
diff --git a/superset/mcp_service/chart/tool/generate_chart.py b/superset/mcp_service/chart/tool/generate_chart.py
index 8ad907c1abe6..b7b171fcf9b9 100644
--- a/superset/mcp_service/chart/tool/generate_chart.py
+++ b/superset/mcp_service/chart/tool/generate_chart.py
@@ -100,18 +100,34 @@ async def generate_chart( # noqa: C901
- Set save_chart=True to permanently save the chart
- LLM clients MUST display returned chart URL to users
- Use numeric dataset ID or UUID (NOT schema.table_name format)
- - MUST include chart_type in config (either 'xy' or 'table')
+ - MUST include chart_type in config (one of: 'xy', 'table', 'pie',
+ 'pivot_table', 'mixed_timeseries', 'handlebars', 'big_number')
IMPORTANT: The 'chart_type' field in the config is a DISCRIMINATOR that determines
which chart configuration schema to use. It MUST be included and MUST match the
other fields in your configuration:
- Use chart_type='xy' for charts with x and y axes (line, bar, area, scatter)
- Required fields: x, y
+ Required fields: y (x is optional — defaults to dataset's primary datetime column)
- Use chart_type='table' for tabular visualizations
Required fields: columns
+ - Use chart_type='pie' for pie/donut charts
+ Required fields: dimension, metric
+
+ - Use chart_type='pivot_table' for pivot table visualizations
+ Required fields: rows, metrics
+
+ - Use chart_type='mixed_timeseries' for dual-axis time-series charts
+ Required fields: x, y, y_secondary
+
+ - Use chart_type='handlebars' for custom template-based visualizations
+ Required fields: handlebars_template
+
+ - Use chart_type='big_number' for single KPI metric displays
+ Required fields: metric
+
Example usage for XY chart:
```json
{
diff --git a/superset/mcp_service/chart/tool/update_chart.py b/superset/mcp_service/chart/tool/update_chart.py
old mode 100644
new mode 100755
index 3e3057bdd2ee..5c56a39a7f94
--- a/superset/mcp_service/chart/tool/update_chart.py
+++ b/superset/mcp_service/chart/tool/update_chart.py
@@ -195,6 +195,29 @@ def _validate_update_against_dataset(
}
)
+ # Column existence + fuzzy-match validation
+ # (mirrors generate_chart pipeline layer 2)
+ from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator
+
+ is_col_valid, col_error = DatasetValidator.validate_against_dataset(
+ parsed_config, dataset.id
+ )
+ if not is_col_valid and col_error is not None:
+ logger.warning(
+ "update_chart column validation failed for chart %s: %s",
+ getattr(chart, "id", None),
+ col_error,
+ )
+ return GenerateChartResponse.model_validate(
+ {
+ "chart": None,
+ "error": col_error.model_dump(),
+ "success": False,
+ "schema_version": "2.0",
+ "api_version": "v1",
+ }
+ )
+
compile_result = validate_and_compile(
parsed_config, form_data, dataset, run_compile_check=True
)
@@ -388,6 +411,24 @@ async def update_chart( # noqa: C901
# config is already a typed ChartConfig | None (validated by Pydantic)
parsed_config = request.config
+ # Normalize column case to match dataset canonical names
+ # (mirrors generate_chart pipeline layer 4)
+ chart_datasource_id = getattr(chart, "datasource_id", None)
+ if parsed_config is not None and chart_datasource_id is not None:
+ from superset.mcp_service.chart.validation.dataset_validator import (
+ DatasetValidator,
+ NORMALIZATION_EXCEPTIONS,
+ )
+
+ try:
+ parsed_config = DatasetValidator.normalize_column_names(
+ parsed_config, chart.datasource_id
+ )
+ except NORMALIZATION_EXCEPTIONS as e:
+ logger.warning(
+ "Column normalization failed for chart %s: %s", chart.id, e
+ )
+
if not request.generate_preview:
from superset.commands.chart.update import UpdateChartCommand
diff --git a/superset/mcp_service/chart/validation/dataset_validator.py b/superset/mcp_service/chart/validation/dataset_validator.py
index 5602b7af1165..a27d18629b2c 100644
--- a/superset/mcp_service/chart/validation/dataset_validator.py
+++ b/superset/mcp_service/chart/validation/dataset_validator.py
@@ -22,17 +22,11 @@
import difflib
import logging
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Tuple, TypeVar
from superset.mcp_service.chart.schemas import (
- BigNumberChartConfig,
+ ChartConfig,
ColumnRef,
- HandlebarsChartConfig,
- MixedTimeseriesChartConfig,
- PieChartConfig,
- PivotTableChartConfig,
- TableChartConfig,
- XYChartConfig,
)
from superset.mcp_service.common.error_schemas import (
ChartGenerationError,
@@ -40,6 +34,8 @@
DatasetContext,
)
+_C = TypeVar("_C", bound=ChartConfig)
+
logger = logging.getLogger(__name__)
# Exceptions that can occur during column name normalization.
@@ -58,7 +54,7 @@ class DatasetValidator:
@staticmethod
def validate_against_dataset(
- config: Any,
+ config: ChartConfig,
dataset_id: int | str,
dataset_context: DatasetContext | None = None,
) -> Tuple[bool, ChartGenerationError | None]:
@@ -260,59 +256,31 @@ def _get_dataset_context(dataset_id: int | str) -> DatasetContext | None:
return None
@staticmethod
- def _extract_column_references(config: Any) -> List[ColumnRef]: # noqa: C901
- """Extract all column references from a chart configuration.
-
- Covers every supported ``ChartConfig`` variant so fast-path tools
- (``generate_explore_link``, ``update_chart_preview``) that only run
- Tier-1 validation still catch bad column refs in pie / pivot table /
- mixed timeseries / handlebars / big number charts — not just XY and
- table.
+ def _extract_column_references(
+ config: ChartConfig,
+ ) -> List[ColumnRef]:
+ """Extract all column references from configuration via the plugin registry.
+
+ Previously only handled TableChartConfig and XYChartConfig, causing
+ 5 of 7 chart types to silently skip column validation. Now delegates
+ to the plugin for each chart type so all types are covered.
"""
- refs: List[ColumnRef] = []
-
- if isinstance(config, TableChartConfig):
- refs.extend(config.columns)
- elif isinstance(config, XYChartConfig):
- if config.x is not None:
- refs.append(config.x)
- refs.extend(config.y)
- if config.group_by:
- refs.extend(config.group_by)
- elif isinstance(config, PieChartConfig):
- refs.append(config.dimension)
- refs.append(config.metric)
- elif isinstance(config, PivotTableChartConfig):
- refs.extend(config.rows)
- if config.columns:
- refs.extend(config.columns)
- refs.extend(config.metrics)
- elif isinstance(config, MixedTimeseriesChartConfig):
- refs.append(config.x)
- refs.extend(config.y)
- if config.group_by:
- refs.extend(config.group_by)
- refs.extend(config.y_secondary)
- if config.group_by_secondary:
- refs.extend(config.group_by_secondary)
- elif isinstance(config, HandlebarsChartConfig):
- if config.columns:
- refs.extend(config.columns)
- if config.groupby:
- refs.extend(config.groupby)
- if config.metrics:
- refs.extend(config.metrics)
- elif isinstance(config, BigNumberChartConfig):
- refs.append(config.metric)
- if config.temporal_column:
- refs.append(ColumnRef(name=config.temporal_column))
-
- # Filter columns (shared by every config type that defines ``filters``).
- if filters := getattr(config, "filters", None):
- for filter_config in filters:
- refs.append(ColumnRef(name=filter_config.column))
-
- return refs
+ # Local import: plugins call DatasetValidator helpers from
+ # normalize_column_refs().
+ # A top-level import of registry in dataset_validator would make loading this
+ # module implicitly trigger plugin registration, creating a circular dependency.
+ from superset.mcp_service.chart.registry import get_registry
+
+ chart_type = getattr(config, "chart_type", None)
+ if chart_type is None:
+ return []
+
+ plugin = get_registry().get(chart_type)
+ if plugin is None:
+ logger.warning("No plugin registered for chart_type=%r", chart_type)
+ return []
+
+ return plugin.extract_column_refs(config)
@staticmethod
def _column_exists(column_name: str, dataset_context: DatasetContext) -> bool:
@@ -365,42 +333,6 @@ def _get_canonical_column_name(
# Return original if not found (validation should catch this case)
return column_name
- @staticmethod
- def _normalize_xy_config(
- config_dict: Dict[str, Any], dataset_context: DatasetContext
- ) -> None:
- """Normalize column names in an XY chart config dict in place."""
- # Normalize x-axis column
- if "x" in config_dict and config_dict["x"]:
- config_dict["x"]["name"] = DatasetValidator._get_canonical_column_name(
- config_dict["x"]["name"], dataset_context
- )
-
- # Normalize y-axis columns
- if "y" in config_dict and config_dict["y"]:
- for y_col in config_dict["y"]:
- y_col["name"] = DatasetValidator._get_canonical_column_name(
- y_col["name"], dataset_context
- )
-
- # Normalize group_by columns
- if "group_by" in config_dict and config_dict["group_by"]:
- for gb_col in config_dict["group_by"]:
- gb_col["name"] = DatasetValidator._get_canonical_column_name(
- gb_col["name"], dataset_context
- )
-
- @staticmethod
- def _normalize_table_config(
- config_dict: Dict[str, Any], dataset_context: DatasetContext
- ) -> None:
- """Normalize column names in a table chart config dict in place."""
- if "columns" in config_dict and config_dict["columns"]:
- for col in config_dict["columns"]:
- col["name"] = DatasetValidator._get_canonical_column_name(
- col["name"], dataset_context
- )
-
@staticmethod
def _normalize_filters(
config_dict: Dict[str, Any], dataset_context: DatasetContext
@@ -417,10 +349,10 @@ def _normalize_filters(
@staticmethod
def normalize_column_names(
- config: TableChartConfig | XYChartConfig,
+ config: _C,
dataset_id: int | str,
dataset_context: DatasetContext | None = None,
- ) -> TableChartConfig | XYChartConfig:
+ ) -> _C:
"""
Normalize column names in config to match the canonical dataset column names.
@@ -429,6 +361,9 @@ def normalize_column_names(
(e.g., 'OrderDate'). The frontend performs case-sensitive comparisons,
so we need to ensure column names match exactly.
+ Previously only XYChartConfig and TableChartConfig were normalized; now
+ all 7 chart types are handled via the plugin registry.
+
Args:
config: Chart configuration with column references
dataset_id: Dataset ID to get canonical column names from
@@ -443,22 +378,24 @@ def normalize_column_names(
if not dataset_context:
return config
- # Create a mutable copy of the config
- config_dict = config.model_dump()
+ # Local import: plugins call DatasetValidator helpers from
+ # normalize_column_refs().
+ # A top-level import of registry in dataset_validator would make loading this
+ # module implicitly trigger plugin registration, creating a circular dependency.
+ from superset.mcp_service.chart.registry import get_registry
- # Normalize based on config type
- if isinstance(config, XYChartConfig):
- DatasetValidator._normalize_xy_config(config_dict, dataset_context)
- elif isinstance(config, TableChartConfig):
- DatasetValidator._normalize_table_config(config_dict, dataset_context)
+ chart_type = getattr(config, "chart_type", None)
+ if chart_type is None:
+ return config
- # Normalize filter columns (common to both config types)
- DatasetValidator._normalize_filters(config_dict, dataset_context)
+ plugin = get_registry().get(chart_type)
+ if plugin is None:
+ logger.warning(
+ "No plugin for chart_type=%r; skipping column normalization", chart_type
+ )
+ return config
- # Reconstruct the config with normalized names
- if isinstance(config, XYChartConfig):
- return XYChartConfig.model_validate(config_dict)
- return TableChartConfig.model_validate(config_dict)
+ return plugin.normalize_column_refs(config, dataset_context)
@staticmethod
def _get_column_suggestions(
diff --git a/superset/mcp_service/chart/validation/runtime/__init__.py b/superset/mcp_service/chart/validation/runtime/__init__.py
index 5e1c89d0a687..dce732ba9926 100644
--- a/superset/mcp_service/chart/validation/runtime/__init__.py
+++ b/superset/mcp_service/chart/validation/runtime/__init__.py
@@ -23,10 +23,7 @@
import logging
from typing import Any, Dict, List, Tuple
-from superset.mcp_service.chart.schemas import (
- ChartConfig,
- XYChartConfig,
-)
+from superset.mcp_service.chart.schemas import ChartConfig
logger = logging.getLogger(__name__)
@@ -56,20 +53,10 @@ def validate_runtime_issues(
warnings: List[str] = []
suggestions: List[str] = []
- # Only check XY charts for format and cardinality issues
- if isinstance(config, XYChartConfig):
- # Format-type compatibility validation
- format_warnings = RuntimeValidator._validate_format_compatibility(config)
- if format_warnings:
- warnings.extend(format_warnings)
-
- # Cardinality validation
- cardinality_warnings, cardinality_suggestions = (
- RuntimeValidator._validate_cardinality(config, dataset_id)
- )
- if cardinality_warnings:
- warnings.extend(cardinality_warnings)
- suggestions.extend(cardinality_suggestions)
+ # Per-plugin runtime warnings (format, cardinality, etc.)
+ plugin_warnings = RuntimeValidator._validate_plugin_runtime(config, dataset_id)
+ if plugin_warnings:
+ warnings.extend(plugin_warnings)
# Chart type appropriateness validation (for all chart types)
type_warnings, type_suggestions = RuntimeValidator._validate_chart_type(
@@ -98,61 +85,28 @@ def validate_runtime_issues(
return True, None
@staticmethod
- def _validate_format_compatibility(config: XYChartConfig) -> List[str]:
- """Validate format-type compatibility."""
- warnings: List[str] = []
-
- try:
- # Import here to avoid circular imports
- from .format_validator import FormatTypeValidator
-
- is_valid, format_warnings = (
- FormatTypeValidator.validate_format_compatibility(config)
- )
- if format_warnings:
- warnings.extend(format_warnings)
- except ImportError:
- logger.warning("Format validator not available")
- except Exception as e:
- logger.warning("Format validation failed: %s", e)
-
- return warnings
-
- @staticmethod
- def _validate_cardinality(
- config: XYChartConfig, dataset_id: int | str
- ) -> Tuple[List[str], List[str]]:
- """Validate cardinality issues."""
- warnings: List[str] = []
- suggestions: List[str] = []
+ def _validate_plugin_runtime(
+ config: ChartConfig, dataset_id: int | str
+ ) -> List[str]:
+ """Delegate per-chart-type runtime warnings to the plugin registry.
+ Each plugin's get_runtime_warnings() method returns chart-type-specific
+ warnings (e.g. format/cardinality for XY). The registry dispatch removes
+ the previous isinstance(config, XYChartConfig) hardcoding.
+ """
try:
- # Import here to avoid circular imports
- from .cardinality_validator import CardinalityValidator
-
- # Determine chart type for cardinality thresholds
- chart_type = config.kind if hasattr(config, "kind") else "default"
-
- # Check X-axis cardinality
- if config.x is None:
- return warnings, suggestions
- is_ok, cardinality_info = CardinalityValidator.check_cardinality(
- dataset_id=dataset_id,
- x_column=config.x.name,
- chart_type=chart_type,
- group_by_column=config.group_by[0].name if config.group_by else None,
- )
-
- if not is_ok and cardinality_info:
- warnings.extend(cardinality_info.get("warnings", []))
- suggestions.extend(cardinality_info.get("suggestions", []))
-
- except ImportError:
- logger.warning("Cardinality validator not available")
- except Exception as e:
- logger.warning("Cardinality validation failed: %s", e)
-
- return warnings, suggestions
+ from superset.mcp_service.chart.registry import get_registry
+
+ chart_type = getattr(config, "chart_type", None)
+ if chart_type is None:
+ return []
+ plugin = get_registry().get(chart_type)
+ if plugin is None:
+ return []
+ return plugin.get_runtime_warnings(config, dataset_id)
+ except Exception as exc:
+ logger.warning("Plugin runtime validation failed: %s", exc)
+ return []
@staticmethod
def _validate_chart_type(
diff --git a/superset/mcp_service/chart/validation/schema_validator.py b/superset/mcp_service/chart/validation/schema_validator.py
old mode 100644
new mode 100755
index 7cae450ff599..11f0f6ada8a4
--- a/superset/mcp_service/chart/validation/schema_validator.py
+++ b/superset/mcp_service/chart/validation/schema_validator.py
@@ -147,19 +147,13 @@ def _pre_validate_chart_type(
chart_type: str,
config: Dict[str, Any],
) -> Tuple[bool, ChartGenerationError | None]:
- """Validate chart type and dispatch to type-specific pre-validation."""
- chart_type_validators = {
- "xy": SchemaValidator._pre_validate_xy_config,
- "table": SchemaValidator._pre_validate_table_config,
- "pie": SchemaValidator._pre_validate_pie_config,
- "pivot_table": SchemaValidator._pre_validate_pivot_table_config,
- "mixed_timeseries": SchemaValidator._pre_validate_mixed_timeseries_config,
- "handlebars": SchemaValidator._pre_validate_handlebars_config,
- "big_number": SchemaValidator._pre_validate_big_number_config,
- }
+ """Validate chart type and dispatch to plugin pre-validation."""
+ from superset.mcp_service.chart.registry import get_registry
- if not isinstance(chart_type, str) or chart_type not in chart_type_validators:
- valid_types = ", ".join(chart_type_validators.keys())
+ registry = get_registry()
+
+ if not isinstance(chart_type, str) or not registry.is_registered(chart_type):
+ valid_types = ", ".join(registry.all_types())
return False, ChartGenerationError(
error_type="invalid_chart_type",
message=f"Invalid chart_type: '{chart_type}'",
@@ -178,351 +172,18 @@ def _pre_validate_chart_type(
error_code="INVALID_CHART_TYPE",
)
- return chart_type_validators[chart_type](config)
-
- @staticmethod
- def _pre_validate_xy_config(
- config: Dict[str, Any],
- ) -> Tuple[bool, ChartGenerationError | None]:
- """Pre-validate XY chart configuration."""
- # x is optional — defaults to dataset's main_dttm_col in map_xy_config
- if "y" not in config:
- return False, ChartGenerationError(
- error_type="missing_xy_fields",
- message="XY chart missing required field: 'y' (Y-axis metrics)",
- details="XY charts require Y-axis (metrics) specifications. "
- "X-axis is optional and defaults to the dataset's primary "
- "datetime column when omitted.",
- suggestions=[
- "Add 'y' field: [{'name': 'metric_column', 'aggregate': 'SUM'}] "
- "for Y-axis",
- "Example: {'chart_type': 'xy', 'x': {'name': 'date'}, "
- "'y': [{'name': 'sales', 'aggregate': 'SUM'}]}",
- ],
- error_code="MISSING_XY_FIELDS",
- )
-
- # Validate Y is a list
- if not isinstance(config.get("y", []), list):
- return False, ChartGenerationError(
- error_type="invalid_y_format",
- message="Y-axis must be a list of metrics",
- details="The 'y' field must be an array of metric specifications",
- suggestions=[
- "Wrap Y-axis metric in array: 'y': [{'name': 'column', "
- "'aggregate': 'SUM'}]",
- "Multiple metrics supported: 'y': [metric1, metric2, ...]",
- ],
- error_code="INVALID_Y_FORMAT",
- )
-
- return True, None
-
- @staticmethod
- def _pre_validate_table_config(
- config: Dict[str, Any],
- ) -> Tuple[bool, ChartGenerationError | None]:
- """Pre-validate table chart configuration."""
- if "columns" not in config:
- return False, ChartGenerationError(
- error_type="missing_columns",
- message="Table chart missing required field: columns",
- details="Table charts require a 'columns' array to specify which "
- "columns to display",
- suggestions=[
- "Add 'columns' field with array of column specifications",
- "Example: 'columns': [{'name': 'product'}, {'name': 'sales', "
- "'aggregate': 'SUM'}]",
- "Each column can have optional 'aggregate' for metrics",
- ],
- error_code="MISSING_COLUMNS",
- )
-
- if not isinstance(config.get("columns", []), list):
- return False, ChartGenerationError(
- error_type="invalid_columns_format",
- message="Columns must be a list",
- details="The 'columns' field must be an array of column specifications",
- suggestions=[
- "Ensure columns is an array: 'columns': [...]",
- "Each column should be an object with 'name' field",
- ],
- error_code="INVALID_COLUMNS_FORMAT",
- )
-
- return True, None
-
- @staticmethod
- def _pre_validate_pie_config(
- config: Dict[str, Any],
- ) -> Tuple[bool, ChartGenerationError | None]:
- """Pre-validate pie chart configuration."""
- missing_fields = []
-
- if "dimension" not in config:
- missing_fields.append("'dimension' (category column for slices)")
- if "metric" not in config:
- missing_fields.append("'metric' (value metric for slice sizes)")
-
- if missing_fields:
- return False, ChartGenerationError(
- error_type="missing_pie_fields",
- message=f"Pie chart missing required "
- f"fields: {', '.join(missing_fields)}",
- details="Pie charts require a dimension (categories) and a metric "
- "(values)",
- suggestions=[
- "Add 'dimension' field: {'name': 'category_column'}",
- "Add 'metric' field: {'name': 'value_column', 'aggregate': 'SUM'}",
- "Example: {'chart_type': 'pie', 'dimension': {'name': "
- "'product'}, 'metric': {'name': 'revenue', 'aggregate': 'SUM'}}",
- ],
- error_code="MISSING_PIE_FIELDS",
- )
-
- return True, None
-
- @staticmethod
- def _pre_validate_handlebars_config(
- config: Dict[str, Any],
- ) -> Tuple[bool, ChartGenerationError | None]:
- """Pre-validate handlebars chart configuration."""
- if "handlebars_template" not in config:
- return False, ChartGenerationError(
- error_type="missing_handlebars_template",
- message="Handlebars chart missing required field: handlebars_template",
- details="Handlebars charts require a 'handlebars_template' string "
- "containing Handlebars HTML template markup",
- suggestions=[
- "Add 'handlebars_template' with a Handlebars HTML template",
- "Data is available as {{data}} array in the template",
- "Example: '<ul>{{#each data}}<li>{{this.name}}: "
- "{{this.value}}</li>{{/each}}</ul>'",
- ],
- error_code="MISSING_HANDLEBARS_TEMPLATE",
- )
-
- template = config.get("handlebars_template")
- if not isinstance(template, str) or not template.strip():
- return False, ChartGenerationError(
- error_type="invalid_handlebars_template",
- message="Handlebars template must be a non-empty string",
- details="The 'handlebars_template' field must be a non-empty string "
- "containing valid Handlebars HTML template markup",
- suggestions=[
- "Ensure handlebars_template is a non-empty string",
- "Example: '<ul>{{#each data}}<li>{{this.name}}</li>{{/each}}</ul>'",
- ],
- error_code="INVALID_HANDLEBARS_TEMPLATE",
- )
-
- query_mode = config.get("query_mode", "aggregate")
- if query_mode not in ("aggregate", "raw"):
- return False, ChartGenerationError(
- error_type="invalid_query_mode",
- message="Invalid query_mode for handlebars chart",
- details="query_mode must be either 'aggregate' or 'raw'",
- suggestions=[
- "Use 'aggregate' for aggregated data (default)",
- "Use 'raw' for individual rows",
- ],
- error_code="INVALID_QUERY_MODE",
- )
-
- if query_mode == "raw" and not config.get("columns"):
- return False, ChartGenerationError(
- error_type="missing_raw_columns",
- message="Handlebars chart in 'raw' mode requires 'columns'",
- details="When query_mode is 'raw', you must specify which columns "
- "to include in the query results",
- suggestions=[
- "Add 'columns': [{'name': 'column_name'}] for raw mode",
- "Or use query_mode='aggregate' with 'metrics' "
- "and optional 'groupby'",
- ],
- error_code="MISSING_RAW_COLUMNS",
- )
-
- if query_mode == "aggregate" and not config.get("metrics"):
- return False, ChartGenerationError(
- error_type="missing_aggregate_metrics",
- message="Handlebars chart in 'aggregate' mode requires 'metrics'",
- details="When query_mode is 'aggregate' (default), you must specify "
- "at least one metric with an aggregate function",
- suggestions=[
- "Add 'metrics': [{'name': 'column', 'aggregate': 'SUM'}]",
- "Or use query_mode='raw' with 'columns' for individual rows",
- ],
- error_code="MISSING_AGGREGATE_METRICS",
- )
-
- return True, None
-
- @staticmethod
- def _pre_validate_big_number_config(
- config: Dict[str, Any],
- ) -> Tuple[bool, ChartGenerationError | None]:
- """Pre-validate big number chart configuration."""
- if "metric" not in config:
- return False, ChartGenerationError(
- error_type="missing_metric",
- message="Big Number chart missing required field: metric",
- details="Big Number charts require a 'metric' field "
- "specifying the value to display",
- suggestions=[
- "Add 'metric' with name and aggregate: "
- "{'name': 'revenue', 'aggregate': 'SUM'}",
- "The aggregate function is required (SUM, COUNT, AVG, MIN, MAX)",
- "Example: {'chart_type': 'big_number', "
- "'metric': {'name': 'sales', 'aggregate': 'SUM'}}",
- ],
- error_code="MISSING_BIG_NUMBER_METRIC",
- )
-
- metric = config.get("metric", {})
- if not isinstance(metric, dict):
+ plugin = registry.get(chart_type)
+ if plugin is None:
return False, ChartGenerationError(
- error_type="invalid_metric_type",
- message="Big Number metric must be a dict with 'name' and 'aggregate'",
- details="The 'metric' field must be an object, "
- f"got {type(metric).__name__}",
- suggestions=[
- "Use a dict: {'name': 'col', 'aggregate': 'SUM'}",
- "Valid aggregates: SUM, COUNT, AVG, MIN, MAX",
- ],
- error_code="INVALID_BIG_NUMBER_METRIC_TYPE",
- )
- if not metric.get("aggregate") and not metric.get("saved_metric"):
- return False, ChartGenerationError(
- error_type="missing_metric_aggregate",
- message="Big Number metric must include an aggregate function "
- "or reference a saved metric",
- details="The metric must have an 'aggregate' field "
- "or 'saved_metric': true",
- suggestions=[
- "Add 'aggregate' to your metric: "
- "{'name': 'col', 'aggregate': 'SUM'}",
- "Or use a saved metric: "
- "{'name': 'total_sales', 'saved_metric': true}",
- "Valid aggregates: SUM, COUNT, AVG, MIN, MAX",
- ],
- error_code="MISSING_BIG_NUMBER_AGGREGATE",
- )
-
- show_trendline = config.get("show_trendline", False)
- temporal_column = config.get("temporal_column")
- if show_trendline and not temporal_column:
- return False, ChartGenerationError(
- error_type="missing_temporal_column",
- message="Trendline requires a temporal column",
- details="When 'show_trendline' is True, a "
- "'temporal_column' must be specified",
- suggestions=[
- "Add 'temporal_column': 'date_column_name'",
- "Or set 'show_trendline': false for number only",
- "Use get_dataset_info to find temporal columns",
- ],
- error_code="MISSING_TEMPORAL_COLUMN",
- )
-
- return True, None
-
- @staticmethod
- def _pre_validate_pivot_table_config(
- config: Dict[str, Any],
- ) -> Tuple[bool, ChartGenerationError | None]:
- """Pre-validate pivot table configuration."""
- missing_fields = []
-
- if "rows" not in config:
- missing_fields.append("'rows' (row grouping columns)")
- if "metrics" not in config:
- missing_fields.append("'metrics' (aggregation metrics)")
-
- if missing_fields:
- return False, ChartGenerationError(
- error_type="missing_pivot_fields",
- message=f"Pivot table missing required "
- f"fields: {', '.join(missing_fields)}",
- details="Pivot tables require row groupings and metrics",
- suggestions=[
- "Add 'rows' field: [{'name': 'category'}]",
- "Add 'metrics' field: [{'name': 'sales', 'aggregate': 'SUM'}]",
- "Optional 'columns' for cross-tabulation: [{'name': 'region'}]",
- ],
- error_code="MISSING_PIVOT_FIELDS",
- )
-
- if not isinstance(config.get("rows", []), list):
- return False, ChartGenerationError(
- error_type="invalid_rows_format",
- message="Rows must be a list of columns",
- details="The 'rows' field must be an array of column specifications",
- suggestions=[
- "Wrap row columns in array: 'rows': [{'name': 'category'}]",
- ],
- error_code="INVALID_ROWS_FORMAT",
- )
-
- if not isinstance(config.get("metrics", []), list):
- return False, ChartGenerationError(
- error_type="invalid_metrics_format",
- message="Metrics must be a list",
- details="The 'metrics' field must be an array of metric specifications",
- suggestions=[
- "Wrap metrics in array: 'metrics': [{'name': 'sales', "
- "'aggregate': 'SUM'}]",
- ],
- error_code="INVALID_METRICS_FORMAT",
- )
-
- return True, None
-
- @staticmethod
- def _pre_validate_mixed_timeseries_config(
- config: Dict[str, Any],
- ) -> Tuple[bool, ChartGenerationError | None]:
- """Pre-validate mixed timeseries configuration."""
- missing_fields = []
-
- if "x" not in config:
- missing_fields.append("'x' (X-axis temporal column)")
- if "y" not in config:
- missing_fields.append("'y' (primary Y-axis metrics)")
- if "y_secondary" not in config:
- missing_fields.append("'y_secondary' (secondary Y-axis metrics)")
-
- if missing_fields:
- return False, ChartGenerationError(
- error_type="missing_mixed_timeseries_fields",
- message=f"Mixed timeseries chart missing required "
- f"fields: {', '.join(missing_fields)}",
- details="Mixed timeseries charts require an x-axis, primary metrics, "
- "and secondary metrics",
- suggestions=[
- "Add 'x' field: {'name': 'date_column'}",
- "Add 'y' field: [{'name': 'revenue', 'aggregate': 'SUM'}]",
- "Add 'y_secondary' field: [{'name': 'orders', "
- "'aggregate': 'COUNT'}]",
- "Optional: 'primary_kind' and 'secondary_kind' for chart types",
- ],
- error_code="MISSING_MIXED_TIMESERIES_FIELDS",
+ error_type="invalid_chart_type",
+ message=f"Chart type '{chart_type}' has no registered plugin",
+ details="Internal error: chart type is listed but has no plugin",
+ suggestions=["Use a supported chart_type"],
+ error_code="INVALID_CHART_TYPE",
)
- for field_name in ["y", "y_secondary"]:
- if not isinstance(config.get(field_name, []), list):
- return False, ChartGenerationError(
- error_type=f"invalid_{field_name}_format",
- message=f"'{field_name}' must be a list of metrics",
- details=f"The '{field_name}' field must be an array of metric "
- "specifications",
- suggestions=[
- f"Wrap in array: '{field_name}': "
- "[{'name': 'col', 'aggregate': 'SUM'}]",
- ],
- error_code=f"INVALID_{field_name.upper()}_FORMAT",
- )
-
+ if (error := plugin.pre_validate(config)) is not None:
+ return False, error
return True, None
@staticmethod
@@ -537,89 +198,26 @@ def _enhance_validation_error(
if err.get("type") == "union_tag_invalid" or "discriminator" in str(
err.get("ctx", {})
):
- # This is the generic union error - provide better message
- config = request_data.get("config", {})
- chart_type = config.get("chart_type", "unknown")
+ from superset.mcp_service.chart.registry import get_registry
- if chart_type == "xy":
- return ChartGenerationError(
- error_type="xy_validation_error",
- message="XY chart configuration validation failed",
- details="The XY chart configuration is missing required "
- "fields or has invalid structure",
- suggestions=[
- "Ensure 'x' field exists with {'name': 'column_name'}",
- "Ensure 'y' field is an array: [{'name': 'metric', "
- "'aggregate': 'SUM'}]",
- "Check that all column names are strings",
- "Verify aggregate functions are valid: SUM, COUNT, AVG, "
- "MIN, MAX",
- ],
- error_code="XY_VALIDATION_ERROR",
- )
- elif chart_type == "table":
- return ChartGenerationError(
- error_type="table_validation_error",
- message="Table chart configuration validation failed",
- details="The table chart configuration is missing required "
- "fields or has invalid structure",
- suggestions=[
- "Ensure 'columns' field is an array of column "
- "specifications",
- "Each column needs {'name': 'column_name'}",
- "Optional: add 'aggregate' for metrics",
- "Example: 'columns': [{'name': 'product'}, {'name': "
- "'sales', 'aggregate': 'SUM'}]",
- ],
- error_code="TABLE_VALIDATION_ERROR",
- )
- elif chart_type == "handlebars":
- return ChartGenerationError(
- error_type="handlebars_validation_error",
- message="Handlebars chart configuration validation failed",
- details="The handlebars chart configuration is missing "
- "required fields or has invalid structure",
- suggestions=[
- "Ensure 'handlebars_template' is a non-empty string",
- "For aggregate mode: add 'metrics' with aggregate "
- "functions",
- "For raw mode: set 'query_mode': 'raw' and add 'columns'",
- "Example: {'chart_type': 'handlebars', "
- "'handlebars_template': '{{#each data}}- "
- "{{this.name}}
{{/each}}
', "
- "'metrics': [{'name': 'sales', 'aggregate': 'SUM'}]}",
- ],
- error_code="HANDLEBARS_VALIDATION_ERROR",
- )
- elif chart_type == "big_number":
- return ChartGenerationError(
- error_type="big_number_validation_error",
- message="Big Number chart configuration validation failed",
- details="The Big Number chart configuration is "
- "missing required fields or has invalid "
- "structure",
- suggestions=[
- "Ensure 'metric' field has 'name' and 'aggregate'",
- "Example: 'metric': {'name': 'revenue', "
- "'aggregate': 'SUM'}",
- "For trendline: add 'show_trendline': true "
- "and 'temporal_column': 'date_col'",
- "Without trendline: just provide the metric",
- ],
- error_code="BIG_NUMBER_VALIDATION_ERROR",
- )
+ chart_type = request_data.get("config", {}).get("chart_type", "")
+ plugin = get_registry().get(chart_type)
+ if plugin is not None:
+ hint = plugin.schema_error_hint()
+ if hint is not None:
+ return hint
# Default enhanced error
error_details = []
for err in errors[:3]: # Show first 3 errors
loc = " -> ".join(str(location) for location in err.get("loc", []))
msg = err.get("msg", "Validation failed")
- error_details.append(f"{loc}: {msg}")
+ error_details.append(f"{loc}: {msg}" if loc else msg)
return ChartGenerationError(
error_type="validation_error",
message="Chart configuration validation failed",
- details="; ".join(error_details),
+ details="; ".join(error_details) or "Invalid chart configuration structure",
suggestions=[
"Check that all required fields are present",
"Ensure field types match the schema",
diff --git a/tests/unit_tests/mcp_service/chart/test_big_number_chart.py b/tests/unit_tests/mcp_service/chart/test_big_number_chart.py
index 59e142333bdb..c832d7793d15 100644
--- a/tests/unit_tests/mcp_service/chart/test_big_number_chart.py
+++ b/tests/unit_tests/mcp_service/chart/test_big_number_chart.py
@@ -90,7 +90,7 @@ def test_saved_metric_passes_pre_validation(self) -> None:
"chart_type": "big_number",
"metric": {"name": "total_sales", "saved_metric": True},
}
- is_valid, error = SchemaValidator._pre_validate_big_number_config(data)
+ is_valid, error = SchemaValidator._pre_validate_chart_type("big_number", data)
assert is_valid is True
assert error is None
diff --git a/tests/unit_tests/mcp_service/chart/test_registry.py b/tests/unit_tests/mcp_service/chart/test_registry.py
new file mode 100644
index 000000000000..0351b2d2bd33
--- /dev/null
+++ b/tests/unit_tests/mcp_service/chart/test_registry.py
@@ -0,0 +1,143 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Tests for the chart type plugin registry."""
+
+import pytest
+
+import superset.mcp_service.chart.registry as registry_module
+from superset.mcp_service.chart.plugin import BaseChartPlugin
+from superset.mcp_service.chart.registry import (
+ _RegistryProxy,
+ all_types,
+ display_name_for_viz_type,
+ get,
+ get_registry,
+ is_registered,
+ register,
+)
+
+
+@pytest.fixture(autouse=True)
+def _isolated_registry(monkeypatch):
+ """Run each test against a clean registry without touching the real one."""
+ monkeypatch.setattr(registry_module, "_REGISTRY", {})
+ monkeypatch.setattr(registry_module, "_plugins_loaded", True)
+
+
+class _FakePlugin(BaseChartPlugin):
+ chart_type = "fake"
+ display_name = "Fake Chart"
+ native_viz_types = {"fake_viz": "Fake Viz"}
+
+
+class _AnotherPlugin(BaseChartPlugin):
+ chart_type = "another"
+ display_name = "Another Chart"
+ native_viz_types = {"another_viz": "Another Viz"}
+
+
+def test_register_adds_plugin():
+ plugin = _FakePlugin()
+ register(plugin)
+ assert get("fake") is plugin
+
+
+def test_get_returns_none_for_unknown():
+ assert get("nonexistent") is None
+
+
+def test_all_types_returns_registered_keys():
+ register(_FakePlugin())
+ register(_AnotherPlugin())
+ types = all_types()
+ assert "fake" in types
+ assert "another" in types
+
+
+def test_all_types_insertion_order():
+ register(_FakePlugin())
+ register(_AnotherPlugin())
+ types = all_types()
+ assert types.index("fake") < types.index("another")
+
+
+def test_is_registered_true_for_known():
+ register(_FakePlugin())
+ assert is_registered("fake") is True
+
+
+def test_is_registered_false_for_unknown():
+ assert is_registered("nonexistent") is False
+
+
+def test_register_warns_on_duplicate(caplog):
+ register(_FakePlugin())
+ with caplog.at_level("WARNING"):
+ register(_FakePlugin())
+ assert "Overwriting" in caplog.text
+
+
+def test_register_raises_for_empty_chart_type():
+ class _BadPlugin(BaseChartPlugin):
+ chart_type = ""
+
+ with pytest.raises(ValueError, match="non-empty chart_type"):
+ register(_BadPlugin())
+
+
+def test_display_name_for_viz_type_found():
+ register(_FakePlugin())
+ assert display_name_for_viz_type("fake_viz") == "Fake Viz"
+
+
+def test_display_name_for_viz_type_not_found():
+ register(_FakePlugin())
+ assert display_name_for_viz_type("unknown_viz") is None
+
+
+def test_display_name_searches_all_plugins():
+ register(_FakePlugin())
+ register(_AnotherPlugin())
+ assert display_name_for_viz_type("another_viz") == "Another Viz"
+
+
+def test_get_registry_returns_proxy():
+ assert isinstance(get_registry(), _RegistryProxy)
+
+
+def test_registry_proxy_get():
+ plugin = _FakePlugin()
+ register(plugin)
+ assert get_registry().get("fake") is plugin
+
+
+def test_registry_proxy_all_types():
+ register(_FakePlugin())
+ assert "fake" in get_registry().all_types()
+
+
+def test_registry_proxy_is_registered():
+ register(_FakePlugin())
+ assert get_registry().is_registered("fake") is True
+ assert get_registry().is_registered("missing") is False
+
+
+def test_registry_proxy_display_name_for_viz_type():
+ register(_FakePlugin())
+ assert get_registry().display_name_for_viz_type("fake_viz") == "Fake Viz"
+ assert get_registry().display_name_for_viz_type("unknown") is None
diff --git a/tests/unit_tests/mcp_service/chart/tool/test_update_chart.py b/tests/unit_tests/mcp_service/chart/tool/test_update_chart.py
old mode 100644
new mode 100755
index c504d8bca598..21e2286fb6fb
--- a/tests/unit_tests/mcp_service/chart/tool/test_update_chart.py
+++ b/tests/unit_tests/mcp_service/chart/tool/test_update_chart.py
@@ -1175,6 +1175,11 @@ def _mock_chart_with_dataset(chart_id: int = 1) -> Mock:
)
@patch("superset.daos.chart.ChartDAO.find_by_id", new_callable=Mock)
@patch("superset.db.session")
+ @patch(
+ "superset.mcp_service.chart.validation.dataset_validator"
+ ".DatasetValidator.validate_against_dataset",
+ new=Mock(return_value=(True, None)),
+ )
@pytest.mark.asyncio
async def test_preview_path_validation_failure_skips_cache(
self,
@@ -1238,6 +1243,11 @@ async def test_preview_path_validation_failure_skips_cache(
)
@patch("superset.daos.chart.ChartDAO.find_by_id", new_callable=Mock)
@patch("superset.db.session")
+ @patch(
+ "superset.mcp_service.chart.validation.dataset_validator"
+ ".DatasetValidator.validate_against_dataset",
+ new=Mock(return_value=(True, None)),
+ )
@pytest.mark.asyncio
async def test_persist_path_validation_failure_skips_db_write(
self,
diff --git a/tests/unit_tests/mcp_service/chart/validation/test_column_name_normalization.py b/tests/unit_tests/mcp_service/chart/validation/test_column_name_normalization.py
index dbebe268b4b8..a81f0864f262 100644
--- a/tests/unit_tests/mcp_service/chart/validation/test_column_name_normalization.py
+++ b/tests/unit_tests/mcp_service/chart/validation/test_column_name_normalization.py
@@ -117,83 +117,6 @@ def test_unknown_column_returns_original(
assert result == "unknown_column"
-class TestNormalizeXYConfig:
- """Test _normalize_xy_config static method."""
-
- def test_normalize_x_axis_column(
- self, mock_dataset_context: DatasetContext
- ) -> None:
- """Test that x-axis column name is normalized."""
- config_dict: Dict[str, Any] = {
- "chart_type": "xy",
- "x": {"name": "orderdate"},
- "y": [{"name": "Sales", "aggregate": "SUM"}],
- "kind": "line",
- }
-
- DatasetValidator._normalize_xy_config(config_dict, mock_dataset_context)
-
- assert config_dict["x"]["name"] == "OrderDate"
-
- def test_normalize_y_axis_columns(
- self, mock_dataset_context: DatasetContext
- ) -> None:
- """Test that y-axis column names are normalized."""
- config_dict: Dict[str, Any] = {
- "chart_type": "xy",
- "x": {"name": "OrderDate"},
- "y": [
- {"name": "sales", "aggregate": "SUM"},
- {"name": "QUANTITY_ORDERED", "aggregate": "COUNT"},
- ],
- "kind": "bar",
- }
-
- DatasetValidator._normalize_xy_config(config_dict, mock_dataset_context)
-
- assert config_dict["y"][0]["name"] == "Sales"
- assert config_dict["y"][1]["name"] == "quantity_ordered"
-
- def test_normalize_group_by_column(
- self, mock_dataset_context: DatasetContext
- ) -> None:
- """Test that group_by column name is normalized."""
- config_dict: Dict[str, Any] = {
- "chart_type": "xy",
- "x": {"name": "OrderDate"},
- "y": [{"name": "Sales", "aggregate": "SUM"}],
- "kind": "line",
- "group_by": [{"name": "productline"}],
- }
-
- DatasetValidator._normalize_xy_config(config_dict, mock_dataset_context)
-
- assert config_dict["group_by"][0]["name"] == "ProductLine"
-
-
-class TestNormalizeTableConfig:
- """Test _normalize_table_config static method."""
-
- def test_normalize_table_columns(
- self, mock_dataset_context: DatasetContext
- ) -> None:
- """Test that table column names are normalized."""
- config_dict: Dict[str, Any] = {
- "chart_type": "table",
- "columns": [
- {"name": "orderdate"},
- {"name": "PRODUCTLINE"},
- {"name": "sales", "aggregate": "SUM"},
- ],
- }
-
- DatasetValidator._normalize_table_config(config_dict, mock_dataset_context)
-
- assert config_dict["columns"][0]["name"] == "OrderDate"
- assert config_dict["columns"][1]["name"] == "ProductLine"
- assert config_dict["columns"][2]["name"] == "Sales"
-
-
class TestNormalizeFilters:
"""Test _normalize_filters static method."""
diff --git a/tests/unit_tests/mcp_service/chart/validation/test_runtime_validator.py b/tests/unit_tests/mcp_service/chart/validation/test_runtime_validator.py
index c49677cb99f2..6aed0b112698 100644
--- a/tests/unit_tests/mcp_service/chart/validation/test_runtime_validator.py
+++ b/tests/unit_tests/mcp_service/chart/validation/test_runtime_validator.py
@@ -58,12 +58,12 @@ def test_validate_runtime_issues_non_blocking_with_format_warnings(self):
x_axis=AxisConfig(format="$,.2f"), # Currency format for date - mismatch
)
- # Mock the format validator to return warnings
+ # Mock the plugin runtime dispatcher to return format warnings
with patch(
"superset.mcp_service.chart.validation.runtime.RuntimeValidator."
- "_validate_format_compatibility"
- ) as mock_format:
- mock_format.return_value = [
+ "_validate_plugin_runtime"
+ ) as mock_plugin:
+ mock_plugin.return_value = [
"Currency format '$,.2f' may not display dates correctly"
]
@@ -87,15 +87,14 @@ def test_validate_runtime_issues_non_blocking_with_cardinality_warnings(self):
kind="bar",
)
- # Mock the cardinality validator to return warnings
+ # Mock the plugin runtime dispatcher to return cardinality warnings
with patch(
"superset.mcp_service.chart.validation.runtime.RuntimeValidator."
- "_validate_cardinality"
- ) as mock_cardinality:
- mock_cardinality.return_value = (
- ["High cardinality detected: 10000+ unique values"],
- ["Consider using aggregation or filtering"],
- )
+ "_validate_plugin_runtime"
+ ) as mock_plugin:
+ mock_plugin.return_value = [
+ "High cardinality detected: 10000+ unique values"
+ ]
is_valid, warnings_metadata = RuntimeValidator.validate_runtime_issues(
config, 1
@@ -148,26 +147,21 @@ def test_validate_runtime_issues_non_blocking_with_multiple_warnings(self):
x_axis=AxisConfig(format="smart_date"), # Wrong format for user_id
)
- # Mock all validators to return warnings
+ # Mock plugin runtime and chart type validators to return warnings
with (
patch(
"superset.mcp_service.chart.validation.runtime.RuntimeValidator."
- "_validate_format_compatibility"
- ) as mock_format,
- patch(
- "superset.mcp_service.chart.validation.runtime.RuntimeValidator."
- "_validate_cardinality"
- ) as mock_cardinality,
+ "_validate_plugin_runtime"
+ ) as mock_plugin,
patch(
"superset.mcp_service.chart.validation.runtime.RuntimeValidator."
"_validate_chart_type"
) as mock_type,
):
- mock_format.return_value = ["Format mismatch warning"]
- mock_cardinality.return_value = (
- ["High cardinality warning"],
- ["Cardinality suggestion"],
- )
+ mock_plugin.return_value = [
+ "Format mismatch warning",
+ "High cardinality warning",
+ ]
mock_type.return_value = (
["Chart type warning"],
["Chart type suggestion"],
@@ -197,13 +191,13 @@ def test_validate_runtime_issues_logs_warnings(self):
with (
patch(
"superset.mcp_service.chart.validation.runtime.RuntimeValidator."
- "_validate_format_compatibility"
- ) as mock_format,
+ "_validate_plugin_runtime"
+ ) as mock_plugin,
patch(
"superset.mcp_service.chart.validation.runtime.logger"
) as mock_logger,
):
- mock_format.return_value = ["Test warning message"]
+ mock_plugin.return_value = ["Test warning message"]
is_valid, warnings_metadata = RuntimeValidator.validate_runtime_issues(
config, 1
@@ -217,7 +211,7 @@ def test_validate_runtime_issues_logs_warnings(self):
assert "warnings" in warnings_metadata
def test_validate_table_chart_skips_xy_validations(self):
- """Test that table charts skip XY-specific validations."""
+ """Test that table charts produce no XY-specific runtime warnings."""
config = TableChartConfig(
chart_type="table",
columns=[
@@ -226,28 +220,15 @@ def test_validate_table_chart_skips_xy_validations(self):
],
)
- # These should not be called for table charts
- with (
- patch(
- "superset.mcp_service.chart.validation.runtime.RuntimeValidator."
- "_validate_format_compatibility"
- ) as mock_format,
- patch(
- "superset.mcp_service.chart.validation.runtime.RuntimeValidator."
- "_validate_cardinality"
- ) as mock_cardinality,
- patch(
- "superset.mcp_service.chart.validation.runtime.RuntimeValidator."
- "_validate_chart_type"
- ) as mock_chart_type,
- ):
- # Mock chart type validator to return no warnings
+ # Plugin runtime dispatches to TableChartPlugin which returns no warnings.
+ # Chart type suggester is also stubbed to return no warnings.
+ with patch(
+ "superset.mcp_service.chart.validation.runtime.RuntimeValidator."
+ "_validate_chart_type"
+ ) as mock_chart_type:
mock_chart_type.return_value = ([], [])
is_valid, error = RuntimeValidator.validate_runtime_issues(config, 1)
- # Format and cardinality validation should not be called for table charts
- mock_format.assert_not_called()
- mock_cardinality.assert_not_called()
assert is_valid is True
assert error is None