diff --git a/superset/mcp_service/app.py b/superset/mcp_service/app.py index 205720c3c8d7..9163ab7c8630 100644 --- a/superset/mcp_service/app.py +++ b/superset/mcp_service/app.py @@ -222,10 +222,12 @@ def get_default_instructions(branding: str = "Apache Superset") -> str: - PT1H (hourly), P1D (daily), P1W (weekly), P1M (monthly), P1Y (yearly) Chart Types in Existing Charts (viewable via list_charts/get_chart_info): -- pie, big_number, big_number_total, funnel, gauge_chart -- echarts_timeseries_line, echarts_timeseries_bar, echarts_timeseries_area -- pivot_table_v2, heatmap_v2, sankey_v2, sunburst_v2, treemap_v2 -- word_cloud, world_map, box_plot, bubble, mixed_timeseries +Each chart returned by list_charts / get_chart_info includes a +chart_type_display_name field with a human-readable name when available. +This field is populated only for the 7 chart types supported by generate_chart +(xy, pie, table, pivot_table, big_number, mixed_timeseries, handlebars). +For all other viz_types (Funnel, Gauge, Heatmap, etc.) it will be null — +use the raw viz_type field instead when referring to those chart types. Query Examples: - List all tables: @@ -503,6 +505,7 @@ def create_mcp_app( # NOTE: Always add new prompt/resource imports here when creating new prompts/resources. # Prompts use @mcp.prompt decorators and resources use @mcp.resource decorators. # They register automatically on import, similar to tools. 
+import superset.mcp_service.chart.plugins # noqa: F401, E402 — registers all chart type plugins from superset.mcp_service.chart import ( # noqa: F401, E402 prompts as chart_prompts, resources as chart_resources, diff --git a/superset/mcp_service/chart/chart_utils.py b/superset/mcp_service/chart/chart_utils.py index 5b865c1f0dd6..f2ac21493791 100644 --- a/superset/mcp_service/chart/chart_utils.py +++ b/superset/mcp_service/chart/chart_utils.py @@ -318,29 +318,35 @@ def map_config_to_form_data( | BigNumberChartConfig, dataset_id: int | str | None = None, ) -> Dict[str, Any]: - """Map chart config to Superset form_data.""" - if isinstance(config, TableChartConfig): - return map_table_config(config) - elif isinstance(config, XYChartConfig): - return map_xy_config(config, dataset_id=dataset_id) - elif isinstance(config, PieChartConfig): - return map_pie_config(config) - elif isinstance(config, PivotTableChartConfig): - return map_pivot_table_config(config) - elif isinstance(config, MixedTimeseriesChartConfig): - return map_mixed_timeseries_config(config, dataset_id=dataset_id) - elif isinstance(config, HandlebarsChartConfig): - return map_handlebars_config(config) - elif isinstance(config, BigNumberChartConfig): - if config.show_trendline and config.temporal_column: - if not is_column_truly_temporal(config.temporal_column, dataset_id): - raise ValueError( - f"Big Number trendline requires a temporal SQL column; " - f"'{config.temporal_column}' is not temporal." - ) - return map_big_number_config(config) - else: - raise ValueError(f"Unsupported config type: {type(config)}") + """Map chart config to Superset form_data via the plugin registry. + + The previous if/elif chain across all 7 chart types has been replaced by a + single registry lookup. Cross-field constraints (e.g. BigNumber trendline + temporal check) are now owned by each plugin's post_map_validate() method + rather than being baked into this dispatcher. 
+ """ + # Local import: plugins call map_*_config from their to_form_data() methods, + # so chart_utils is loaded before plugins finish registering. A top-level + # import of registry here would trigger plugin loading mid-import = cycle. + from superset.mcp_service.chart.registry import get_registry + + chart_type = getattr(config, "chart_type", None) + plugin = get_registry().get(chart_type) if chart_type else None + + if plugin is None: + raise ValueError( + f"Unsupported config type: {type(config)} (chart_type={chart_type!r})" + ) + + form_data = plugin.to_form_data(config, dataset_id=dataset_id) + + # Run post-map validation (e.g. BigNumber trendline temporal type check). + # Raise ValueError to preserve backward-compatible error handling in callers. + error = plugin.post_map_validate(config, form_data, dataset_id=dataset_id) + if error is not None: + raise ValueError(error.message) + + return form_data def _add_adhoc_filters( @@ -1129,87 +1135,32 @@ def _big_number_chart_what(config: BigNumberChartConfig) -> str: def generate_chart_name( - config: TableChartConfig - | XYChartConfig - | PieChartConfig - | PivotTableChartConfig - | MixedTimeseriesChartConfig - | HandlebarsChartConfig - | BigNumberChartConfig, + config: Any, dataset_name: str | None = None, ) -> str: """Generate a descriptive chart name following a standard format. - Format conventions (by chart type): - Aggregated (bar/scatter with group_by): [Metric] by [Dimension] - Time-series (line/area, no group_by): [Metric] Over Time - Table (no aggregates): [Dataset] Records - Table (with aggregates): [Metric] Summary - Pie: [Dimension] by [Metric] - Pivot Table: Pivot Table – [Row1, Row2] - Mixed Timeseries: [Primary] + [Secondary] - An en-dash followed by context (filters / time grain) is appended + Delegates to each plugin's ``generate_name()`` method. + See each plugin's ``generate_name`` for chart-type-specific format conventions. 
+ An en-dash followed by context (filters / time grain) is appended by the plugin when such information is available. """ - if isinstance(config, TableChartConfig): - what = _table_chart_what(config, dataset_name) - context = _summarize_filters(config.filters) - elif isinstance(config, XYChartConfig): - what = _xy_chart_what(config) - context = _xy_chart_context(config) - elif isinstance(config, PieChartConfig): - what = _pie_chart_what(config) - context = _summarize_filters(config.filters) - elif isinstance(config, PivotTableChartConfig): - what = _pivot_table_what(config) - context = _summarize_filters(config.filters) - elif isinstance(config, MixedTimeseriesChartConfig): - what = _mixed_timeseries_what(config) - context = _summarize_filters(config.filters) - elif isinstance(config, HandlebarsChartConfig): - what = _handlebars_chart_what(config) - context = _summarize_filters(getattr(config, "filters", None)) - elif isinstance(config, BigNumberChartConfig): - what = _big_number_chart_what(config) - context = _summarize_filters(getattr(config, "filters", None)) - else: - return "Chart" + from superset.mcp_service.chart.registry import get_registry - name = what - if context: - name = f"{what} \u2013 {context}" - return _truncate(name) + plugin = get_registry().get(getattr(config, "chart_type", "")) + if plugin is None: + return "Chart" + return _truncate(plugin.generate_name(config, dataset_name)) def _resolve_viz_type(config: Any) -> str: """Resolve the Superset viz_type from a chart config object.""" - chart_type = getattr(config, "chart_type", "unknown") - if chart_type == "xy": - kind = getattr(config, "kind", "line") - viz_type_map = { - "line": "echarts_timeseries_line", - "bar": "echarts_timeseries_bar", - "area": "echarts_area", - "scatter": "echarts_timeseries_scatter", - } - return viz_type_map.get(kind, "echarts_timeseries_line") - elif chart_type == "table": - return getattr(config, "viz_type", "table") - elif chart_type == "pie": - return "pie" - elif 
chart_type == "pivot_table": - return "pivot_table_v2" - elif chart_type == "mixed_timeseries": - return "mixed_timeseries" - elif chart_type == "handlebars": - return "handlebars" - elif chart_type == "big_number": - show_trendline = getattr(config, "show_trendline", False) - temporal_column = getattr(config, "temporal_column", None) - return ( - "big_number" if show_trendline and temporal_column else "big_number_total" - ) - return "unknown" + from superset.mcp_service.chart.registry import get_registry + + plugin = get_registry().get(getattr(config, "chart_type", "")) + if plugin is None: + return "unknown" + return plugin.resolve_viz_type(config) def analyze_chart_capabilities(chart: Any | None, config: Any) -> ChartCapabilities: diff --git a/superset/mcp_service/chart/plugin.py b/superset/mcp_service/chart/plugin.py new file mode 100755 index 000000000000..6d88c7194d96 --- /dev/null +++ b/superset/mcp_service/chart/plugin.py @@ -0,0 +1,255 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +ChartTypePlugin protocol and BaseChartPlugin base class. + +Each chart type owns its pre-validation, column extraction, form_data mapping, +and post-map validation in a single plugin class. 
This eliminates the previous +pattern of 4 separate dispatch points (schema_validator.py, dataset_validator.py, +chart_utils.py, pipeline.py) that had to be updated in sync whenever a new chart +type was added. +""" + +from __future__ import annotations + +from typing import Any, Protocol, runtime_checkable + +from superset.mcp_service.chart.schemas import ColumnRef +from superset.mcp_service.common.error_schemas import ChartGenerationError + + +@runtime_checkable +class ChartTypePlugin(Protocol): + """ + Protocol that every chart-type plugin must satisfy. + + Implementing all nine methods in a single class guarantees that adding a + new chart type requires only one new file — the plugin — rather than edits + across multiple separate files. + """ + + #: Discriminator value matching ChartConfig's chart_type field. + chart_type: str + + #: Human-readable name shown to users (e.g. "Line / Bar / Area / Scatter"). + display_name: str + + #: Maps every Superset-internal viz_type this plugin can produce to a + #: user-facing display name, e.g. {"echarts_timeseries_line": "Line Chart"}. + #: Used by the registry to resolve display names for existing charts without + #: needing a separate JSON mapping file. + native_viz_types: dict[str, str] + + def pre_validate( + self, + config: dict[str, Any], + ) -> ChartGenerationError | None: + """ + Early validation of the raw config dict before Pydantic parsing. + + Called by SchemaValidator before attempting to parse the request. + Should check that required top-level keys are present and well-typed. + + Returns None if valid, ChartGenerationError if invalid. + """ + ... + + def extract_column_refs( + self, + config: Any, + ) -> list[ColumnRef]: + """ + Extract all column references from a parsed chart config. + + Called by DatasetValidator to validate that all referenced columns exist + in the dataset. Must cover every field that holds a column name, + including filters. + + Returns a list of ColumnRef objects (may be empty). 
+ """ + ... + + def to_form_data( + self, + config: Any, + dataset_id: int | str | None = None, + ) -> dict[str, Any]: + """ + Map a parsed chart config to Superset's internal form_data dict. + + Replaces the if/elif chain in chart_utils.map_config_to_form_data(). + + Returns a Superset form_data dict ready for caching and rendering. + """ + ... + + def post_map_validate( + self, + config: Any, + form_data: dict[str, Any], + dataset_id: int | str | None = None, + ) -> ChartGenerationError | None: + """ + Validate the mapped form_data after to_form_data() runs. + + Use this for cross-field constraints that can only be checked once + form_data is assembled (e.g. BigNumber trendline requires a temporal + column whose type must be verified against the dataset). + + Returns None if valid, ChartGenerationError if invalid. + """ + ... + + def normalize_column_refs( + self, + config: Any, + dataset_context: Any, + ) -> Any: + """ + Return a new config with column names normalized to canonical dataset casing. + + Called by DatasetValidator.normalize_column_names(). The default + implementation (in BaseChartPlugin) returns the config unchanged; plugins + with column fields override this to fix case sensitivity mismatches. + + Returns a new config object (or the original if no normalization needed). + """ + ... + + def get_runtime_warnings( + self, + config: Any, + dataset_id: int | str, + ) -> list[str]: + """ + Return chart-type-specific runtime warnings (performance, compatibility). + + Called by RuntimeValidator to collect per-type warnings. Warnings are + informational only — they never block chart generation. The default + implementation returns an empty list; plugins override this to emit + chart-type-specific warnings (e.g. XY cardinality checks). + + Returns a list of warning message strings (may be empty). + """ + ... 
+ + def generate_name( + self, + config: Any, + dataset_name: str | None = None, + ) -> str: + """ + Return a descriptive chart name for the given config. + + Called by chart_utils.generate_chart_name(). The name should follow + the standard format conventions documented in that function. Plugins + that do not override this return the generic fallback "Chart". + """ + ... + + def resolve_viz_type(self, config: Any) -> str: + """ + Return the Superset-internal viz_type string for this config. + + Called by chart_utils._resolve_viz_type(). The returned string must + match a registered Superset viz plugin (e.g. "echarts_timeseries_line"). + Plugins that do not override this return "unknown". + """ + ... + + def schema_error_hint(self) -> "ChartGenerationError | None": + """ + Return a user-friendly error for Pydantic discriminated-union parse failures. + + Called by SchemaValidator when Pydantic cannot parse the config union and + the chart_type is known. Returning None falls back to the generic error. + """ + ... + + +class BaseChartPlugin: + """ + Base class providing sensible defaults for all ChartTypePlugin methods. + + Concrete plugins extend this and override only what they need. 
+ """ + + chart_type: str = "" + display_name: str = "" + native_viz_types: dict[str, str] = {} + + def pre_validate( + self, + config: dict[str, Any], + ) -> ChartGenerationError | None: + return None + + def extract_column_refs( + self, + config: Any, + ) -> list[ColumnRef]: + return [] + + def to_form_data( + self, + config: Any, + dataset_id: int | str | None = None, + ) -> dict[str, Any]: + raise NotImplementedError( + f"{self.__class__.__name__}.to_form_data() is not implemented" + ) + + def post_map_validate( + self, + config: Any, + form_data: dict[str, Any], + dataset_id: int | str | None = None, + ) -> ChartGenerationError | None: + return None + + def normalize_column_refs( + self, + config: Any, + dataset_context: Any, + ) -> Any: + return config + + def get_runtime_warnings( + self, + config: Any, + dataset_id: int | str, + ) -> list[str]: + return [] + + def generate_name( + self, + config: Any, + dataset_name: str | None = None, + ) -> str: + return "Chart" + + def resolve_viz_type(self, config: Any) -> str: + return "unknown" + + def schema_error_hint(self) -> ChartGenerationError | None: + return None + + @staticmethod + def _with_context(what: str, context: str | None) -> str: + """Combine a 'what' label and optional context with an en-dash.""" + return f"{what} – {context}" if context else what diff --git a/superset/mcp_service/chart/plugins/__init__.py b/superset/mcp_service/chart/plugins/__init__.py new file mode 100644 index 000000000000..5527c43a55f0 --- /dev/null +++ b/superset/mcp_service/chart/plugins/__init__.py @@ -0,0 +1,58 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Chart type plugins package. + +Importing this module registers all built-in chart type plugins in the global +registry. This module is imported by app.py at startup. + +To add a new chart type: +1. Create ``superset/mcp_service/chart/plugins/{chart_type}.py`` +2. Implement a class extending ``BaseChartPlugin`` +3. Import and register it here +""" + +from superset.mcp_service.chart.plugins.big_number import BigNumberChartPlugin +from superset.mcp_service.chart.plugins.handlebars import HandlebarsChartPlugin +from superset.mcp_service.chart.plugins.mixed_timeseries import ( + MixedTimeseriesChartPlugin, +) +from superset.mcp_service.chart.plugins.pie import PieChartPlugin +from superset.mcp_service.chart.plugins.pivot_table import PivotTableChartPlugin +from superset.mcp_service.chart.plugins.table import TableChartPlugin +from superset.mcp_service.chart.plugins.xy import XYChartPlugin +from superset.mcp_service.chart.registry import register + +# Register all built-in chart type plugins +register(XYChartPlugin()) +register(TableChartPlugin()) +register(PieChartPlugin()) +register(PivotTableChartPlugin()) +register(MixedTimeseriesChartPlugin()) +register(HandlebarsChartPlugin()) +register(BigNumberChartPlugin()) + +__all__ = [ + "BigNumberChartPlugin", + "HandlebarsChartPlugin", + "MixedTimeseriesChartPlugin", + "PieChartPlugin", + "PivotTableChartPlugin", + "TableChartPlugin", + "XYChartPlugin", +] diff --git a/superset/mcp_service/chart/plugins/big_number.py b/superset/mcp_service/chart/plugins/big_number.py new file mode 100755 index 
000000000000..e542f8e75f01 --- /dev/null +++ b/superset/mcp_service/chart/plugins/big_number.py @@ -0,0 +1,220 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Big number chart type plugin.""" + +from __future__ import annotations + +from typing import Any + +from superset.mcp_service.chart.chart_utils import ( + _big_number_chart_what, + _summarize_filters, + is_column_truly_temporal, + map_big_number_config, +) +from superset.mcp_service.chart.plugin import BaseChartPlugin +from superset.mcp_service.chart.schemas import BigNumberChartConfig, ColumnRef +from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator +from superset.mcp_service.common.error_schemas import ChartGenerationError + + +class BigNumberChartPlugin(BaseChartPlugin): + """Plugin for big_number chart type.""" + + chart_type = "big_number" + display_name = "Big Number" + native_viz_types = { + "big_number": "Big Number with Trendline", + "big_number_total": "Big Number", + } + + def pre_validate( + self, + config: dict[str, Any], + ) -> ChartGenerationError | None: + if "metric" not in config: + return ChartGenerationError( + error_type="missing_metric", + message="Big Number chart missing required field: metric", + details=( + "Big 
Number charts require a 'metric' field " + "specifying the value to display" + ), + suggestions=[ + "Add 'metric' with name and aggregate: " + "{'name': 'revenue', 'aggregate': 'SUM'}", + "The aggregate function is required (SUM, COUNT, AVG, MIN, MAX)", + "Example: {'chart_type': 'big_number', " + "'metric': {'name': 'sales', 'aggregate': 'SUM'}}", + ], + error_code="MISSING_BIG_NUMBER_METRIC", + ) + + metric = config.get("metric", {}) + if not isinstance(metric, dict): + return ChartGenerationError( + error_type="invalid_metric_type", + message="Big Number metric must be a dict with 'name' and 'aggregate'", + details=( + f"The 'metric' field must be an object, got {type(metric).__name__}" + ), + suggestions=[ + "Use a dict: {'name': 'col', 'aggregate': 'SUM'}", + "Valid aggregates: SUM, COUNT, AVG, MIN, MAX", + ], + error_code="INVALID_BIG_NUMBER_METRIC_TYPE", + ) + if not metric.get("aggregate") and not metric.get("saved_metric"): + return ChartGenerationError( + error_type="missing_metric_aggregate", + message=( + "Big Number metric must include an aggregate function " + "or reference a saved metric" + ), + details=( + "The metric must have an 'aggregate' field or 'saved_metric': true" + ), + suggestions=[ + "Add 'aggregate': {'name': 'col', 'aggregate': 'SUM'}", + "Or use a saved metric: {'name': 'metric', 'saved_metric': true}", + "Valid aggregates: SUM, COUNT, AVG, MIN, MAX", + ], + error_code="MISSING_BIG_NUMBER_AGGREGATE", + ) + + show_trendline = config.get("show_trendline", False) + temporal_column = config.get("temporal_column") + if show_trendline and not temporal_column: + return ChartGenerationError( + error_type="missing_temporal_column", + message="Trendline requires a temporal column", + details=( + "When 'show_trendline' is True, " + "a 'temporal_column' must be specified" + ), + suggestions=[ + "Add 'temporal_column': 'date_column_name'", + "Or set 'show_trendline': false for number only", + "Use get_dataset_info to find temporal columns", + ], + 
error_code="MISSING_TEMPORAL_COLUMN", + ) + + return None + + def extract_column_refs(self, config: Any) -> list[ColumnRef]: + if not isinstance(config, BigNumberChartConfig): + return [] + refs: list[ColumnRef] = [config.metric] + # temporal_column is a str field, not a ColumnRef — validate it exists + if config.temporal_column: + refs.append(ColumnRef(name=config.temporal_column)) + if config.filters: + for f in config.filters: + refs.append(ColumnRef(name=f.column)) + return refs + + def to_form_data( + self, config: Any, dataset_id: int | str | None = None + ) -> dict[str, Any]: + return map_big_number_config(config) + + def post_map_validate( + self, + config: Any, + form_data: dict[str, Any], + dataset_id: int | str | None = None, + ) -> ChartGenerationError | None: + """Verify the trendline temporal column is a real temporal SQL type. + + This check was previously baked into map_config_to_form_data() in + chart_utils.py as a special case. Moving it here keeps the dispatcher + clean and makes the constraint explicit and discoverable. + """ + if not isinstance(config, BigNumberChartConfig): + return None + if not (config.show_trendline and config.temporal_column): + return None + + if not is_column_truly_temporal(config.temporal_column, dataset_id): + return ChartGenerationError( + error_type="non_temporal_trendline_column", + message=( + f"Big Number trendline requires a temporal SQL column; " + f"'{config.temporal_column}' is not temporal." + ), + details=( + f"Column '{config.temporal_column}' does not have a temporal " + f"SQL type (DATE, DATETIME, TIMESTAMP). The trendline requires " + f"a true temporal column for DATE_TRUNC to work." 
+ ), + suggestions=[ + "Use get_dataset_info to find columns with temporal SQL types", + "Set 'show_trendline': false to use any column as the metric", + "If the column contains dates stored as integers, " + "consider casting it in a virtual dataset", + ], + error_code="NON_TEMPORAL_TRENDLINE_COLUMN", + ) + + return None + + def generate_name(self, config: Any, dataset_name: str | None = None) -> str: + what = _big_number_chart_what(config) + context = _summarize_filters(getattr(config, "filters", None)) + return self._with_context(what, context) + + def resolve_viz_type(self, config: Any) -> str: + show_trendline = getattr(config, "show_trendline", False) + temporal_column = getattr(config, "temporal_column", None) + if show_trendline and temporal_column: + return "big_number" + return "big_number_total" + + def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any: + config_dict = config.model_dump() + + if config_dict.get("metric") and not config_dict["metric"].get("saved_metric"): + config_dict["metric"]["name"] = DatasetValidator._get_canonical_column_name( + config_dict["metric"]["name"], dataset_context + ) + if config_dict.get("temporal_column"): + config_dict["temporal_column"] = ( + DatasetValidator._get_canonical_column_name( + config_dict["temporal_column"], dataset_context + ) + ) + DatasetValidator._normalize_filters(config_dict, dataset_context) + return BigNumberChartConfig.model_validate(config_dict) + + def schema_error_hint(self) -> ChartGenerationError | None: + return ChartGenerationError( + error_type="big_number_validation_error", + message="Big Number chart configuration validation failed", + details=( + "The Big Number chart configuration is missing required " + "fields or has invalid structure" + ), + suggestions=[ + "Ensure 'metric' field has 'name' and 'aggregate'", + "Example: 'metric': {'name': 'revenue', 'aggregate': 'SUM'}", + "For trendline: add show_trendline=true and temporal_column='col'", + "Without trendline: 
just provide the metric", + ], + error_code="BIG_NUMBER_VALIDATION_ERROR", + ) diff --git a/superset/mcp_service/chart/plugins/handlebars.py b/superset/mcp_service/chart/plugins/handlebars.py new file mode 100755 index 000000000000..53d78cd8b824 --- /dev/null +++ b/superset/mcp_service/chart/plugins/handlebars.py @@ -0,0 +1,189 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""Handlebars chart type plugin.""" + +from __future__ import annotations + +from typing import Any + +from superset.mcp_service.chart.chart_utils import ( + _handlebars_chart_what, + _summarize_filters, + map_handlebars_config, +) +from superset.mcp_service.chart.plugin import BaseChartPlugin +from superset.mcp_service.chart.schemas import ColumnRef, HandlebarsChartConfig +from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator +from superset.mcp_service.common.error_schemas import ChartGenerationError + + +class HandlebarsChartPlugin(BaseChartPlugin): + """Plugin for handlebars chart type (custom HTML template charts).""" + + chart_type = "handlebars" + display_name = "Handlebars (Custom Template)" + native_viz_types = { + "handlebars": "Custom Template Chart", + } + + def pre_validate( + self, + config: dict[str, Any], + ) -> ChartGenerationError | None: + if "handlebars_template" not in config: + return ChartGenerationError( + error_type="missing_handlebars_template", + message="Handlebars chart missing required field: handlebars_template", + details=( + "Handlebars charts require a 'handlebars_template' string " + "containing Handlebars HTML template markup" + ), + suggestions=[ + "Add 'handlebars_template' with a Handlebars HTML template", + "Data is available as {{data}} array in the template", + "Example: ''", + ], + error_code="MISSING_HANDLEBARS_TEMPLATE", + ) + + template = config.get("handlebars_template") + if not isinstance(template, str) or not template.strip(): + return ChartGenerationError( + error_type="invalid_handlebars_template", + message="Handlebars template must be a non-empty string", + details=( + "The 'handlebars_template' field must be a non-empty string " + "containing valid Handlebars HTML template markup" + ), + suggestions=[ + "Ensure handlebars_template is a non-empty string", + "Example: ''", + ], + error_code="INVALID_HANDLEBARS_TEMPLATE", + ) + + query_mode = config.get("query_mode", "aggregate") + 
if query_mode not in ("aggregate", "raw"): + return ChartGenerationError( + error_type="invalid_query_mode", + message="Invalid query_mode for handlebars chart", + details="query_mode must be either 'aggregate' or 'raw'", + suggestions=[ + "Use 'aggregate' for aggregated data (default)", + "Use 'raw' for individual rows", + ], + error_code="INVALID_QUERY_MODE", + ) + + if query_mode == "raw" and not config.get("columns"): + return ChartGenerationError( + error_type="missing_raw_columns", + message="Handlebars chart in 'raw' mode requires 'columns'", + details=( + "When query_mode is 'raw', you must specify which columns " + "to include in the query results" + ), + suggestions=[ + "Add 'columns': [{'name': 'column_name'}] for raw mode", + "Or use query_mode='aggregate' with 'metrics' and optional 'groupby'", # noqa: E501 + ], + error_code="MISSING_RAW_COLUMNS", + ) + + if query_mode == "aggregate" and not config.get("metrics"): + return ChartGenerationError( + error_type="missing_aggregate_metrics", + message="Handlebars chart in 'aggregate' mode requires 'metrics'", + details=( + "When query_mode is 'aggregate' (default), you must specify " + "at least one metric with an aggregate function" + ), + suggestions=[ + "Add 'metrics': [{'name': 'column', 'aggregate': 'SUM'}]", + "Or use query_mode='raw' with 'columns' for individual rows", + ], + error_code="MISSING_AGGREGATE_METRICS", + ) + + return None + + def extract_column_refs(self, config: Any) -> list[ColumnRef]: + if not isinstance(config, HandlebarsChartConfig): + return [] + refs: list[ColumnRef] = [] + if config.columns: + refs.extend(config.columns) + if config.metrics: + refs.extend(config.metrics) + if config.groupby: + refs.extend(config.groupby) + if config.filters: + for f in config.filters: + refs.append(ColumnRef(name=f.column)) + return refs + + def to_form_data( + self, config: Any, dataset_id: int | str | None = None + ) -> dict[str, Any]: + return map_handlebars_config(config) + + def 
generate_name(self, config: Any, dataset_name: str | None = None) -> str: + what = _handlebars_chart_what(config) + context = _summarize_filters(getattr(config, "filters", None)) + return self._with_context(what, context) + + def resolve_viz_type(self, config: Any) -> str: + return "handlebars" + + def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any: + config_dict = config.model_dump() + + def _norm_list(key: str) -> None: + if config_dict.get(key): + for col in config_dict[key]: + if not col.get("saved_metric"): + col["name"] = DatasetValidator._get_canonical_column_name( + col["name"], dataset_context + ) + + _norm_list("columns") + _norm_list("metrics") + _norm_list("groupby") + DatasetValidator._normalize_filters(config_dict, dataset_context) + return HandlebarsChartConfig.model_validate(config_dict) + + def schema_error_hint(self) -> ChartGenerationError | None: + return ChartGenerationError( + error_type="handlebars_validation_error", + message="Handlebars chart configuration validation failed", + details=( + "The handlebars chart configuration is missing " + "required fields or has invalid structure" + ), + suggestions=[ + "Ensure 'handlebars_template' is a non-empty string", + "For aggregate mode: add 'metrics' with aggregate functions", + "For raw mode: set 'query_mode': 'raw' and add 'columns'", + "Example: {'chart_type': 'handlebars', " + "'handlebars_template': " + "'', " + "'metrics': [{'name': 'sales', 'aggregate': 'SUM'}]}", + ], + error_code="HANDLEBARS_VALIDATION_ERROR", + ) diff --git a/superset/mcp_service/chart/plugins/mixed_timeseries.py b/superset/mcp_service/chart/plugins/mixed_timeseries.py new file mode 100755 index 000000000000..0cf7b82e80eb --- /dev/null +++ b/superset/mcp_service/chart/plugins/mixed_timeseries.py @@ -0,0 +1,165 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Mixed timeseries chart type plugin.""" + +from __future__ import annotations + +from typing import Any + +from superset.mcp_service.chart.chart_utils import ( + _mixed_timeseries_what, + _summarize_filters, + map_mixed_timeseries_config, +) +from superset.mcp_service.chart.plugin import BaseChartPlugin +from superset.mcp_service.chart.schemas import ColumnRef, MixedTimeseriesChartConfig +from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator +from superset.mcp_service.common.error_schemas import ChartGenerationError + + +class MixedTimeseriesChartPlugin(BaseChartPlugin): + """Plugin for mixed_timeseries chart type.""" + + chart_type = "mixed_timeseries" + display_name = "Mixed Timeseries" + native_viz_types = { + "mixed_timeseries": "Mixed Timeseries Chart", + } + + def pre_validate( + self, + config: dict[str, Any], + ) -> ChartGenerationError | None: + missing_fields = [] + + if "x" not in config: + missing_fields.append("'x' (X-axis temporal column)") + if "y" not in config: + missing_fields.append("'y' (primary Y-axis metrics)") + if "y_secondary" not in config: + missing_fields.append("'y_secondary' (secondary Y-axis metrics)") + + if missing_fields: + return ChartGenerationError( + error_type="missing_mixed_timeseries_fields", + message=( + 
f"Mixed timeseries chart missing required fields: " + f"{', '.join(missing_fields)}" + ), + details=( + "Mixed timeseries charts require an x-axis, primary metrics, " + "and secondary metrics" + ), + suggestions=[ + "Add 'x' field: {'name': 'date_column'}", + "Add 'y' field: [{'name': 'revenue', 'aggregate': 'SUM'}]", + "Add 'y_secondary': [{'name': 'orders', 'aggregate': 'COUNT'}]", + "Optional: 'primary_kind' and 'secondary_kind' for chart types", + ], + error_code="MISSING_MIXED_TIMESERIES_FIELDS", + ) + + for field_name in ["y", "y_secondary"]: + if not isinstance(config.get(field_name, []), list): + return ChartGenerationError( + error_type=f"invalid_{field_name}_format", + message=f"'{field_name}' must be a list of metrics", + details=( + f"The '{field_name}' field must be an array of metric " + "specifications" + ), + suggestions=[ + f"Wrap in array: '{field_name}': " + "[{'name': 'col', 'aggregate': 'SUM'}]", + ], + error_code=f"INVALID_{field_name.upper()}_FORMAT", + ) + + return None + + def extract_column_refs(self, config: Any) -> list[ColumnRef]: + if not isinstance(config, MixedTimeseriesChartConfig): + return [] + refs: list[ColumnRef] = [config.x] + refs.extend(config.y) + refs.extend(config.y_secondary) + if config.group_by: + refs.extend(config.group_by) + if config.group_by_secondary: + refs.extend(config.group_by_secondary) + if config.filters: + for f in config.filters: + refs.append(ColumnRef(name=f.column)) + return refs + + def to_form_data( + self, config: Any, dataset_id: int | str | None = None + ) -> dict[str, Any]: + return map_mixed_timeseries_config(config, dataset_id=dataset_id) + + def generate_name(self, config: Any, dataset_name: str | None = None) -> str: + what = _mixed_timeseries_what(config) + context = _summarize_filters(config.filters) + return self._with_context(what, context) + + def resolve_viz_type(self, config: Any) -> str: + return "mixed_timeseries" + + def normalize_column_refs(self, config: Any, dataset_context: 
Any) -> Any: + config_dict = config.model_dump() + + def _norm_single(key: str) -> None: + if config_dict.get(key): + config_dict[key]["name"] = DatasetValidator._get_canonical_column_name( + config_dict[key]["name"], dataset_context + ) + + def _norm_list(key: str) -> None: + if config_dict.get(key): + for col in config_dict[key]: + col["name"] = DatasetValidator._get_canonical_column_name( + col["name"], dataset_context + ) + + _norm_single("x") + _norm_list("y") + _norm_list("y_secondary") + _norm_list("group_by") + _norm_list("group_by_secondary") + DatasetValidator._normalize_filters(config_dict, dataset_context) + return MixedTimeseriesChartConfig.model_validate(config_dict) + + def schema_error_hint(self) -> ChartGenerationError | None: + return ChartGenerationError( + error_type="mixed_timeseries_validation_error", + message="Mixed timeseries chart configuration validation failed", + details=( + "The mixed timeseries configuration is missing " + "required fields or has invalid structure" + ), + suggestions=[ + "Ensure 'x' field has 'name' for the time axis column", + "Ensure 'y' is an array of primary-axis metrics", + "Ensure 'y_secondary' is an array of secondary-axis metrics", + "Example: {'chart_type': 'mixed_timeseries', " + "'x': {'name': 'order_date'}, " + "'y': [{'name': 'revenue', 'aggregate': 'SUM'}], " + "'y_secondary': [{'name': 'orders', 'aggregate': 'COUNT'}]}", + ], + error_code="MIXED_TIMESERIES_VALIDATION_ERROR", + ) diff --git a/superset/mcp_service/chart/plugins/pie.py b/superset/mcp_service/chart/plugins/pie.py new file mode 100755 index 000000000000..3d87fe7f05f2 --- /dev/null +++ b/superset/mcp_service/chart/plugins/pie.py @@ -0,0 +1,128 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Pie chart type plugin.""" + +from __future__ import annotations + +from typing import Any + +from superset.mcp_service.chart.chart_utils import ( + _pie_chart_what, + _summarize_filters, + map_pie_config, +) +from superset.mcp_service.chart.plugin import BaseChartPlugin +from superset.mcp_service.chart.schemas import ColumnRef, PieChartConfig +from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator +from superset.mcp_service.common.error_schemas import ChartGenerationError + + +class PieChartPlugin(BaseChartPlugin): + """Plugin for pie chart type.""" + + chart_type = "pie" + display_name = "Pie / Donut Chart" + native_viz_types = { + "pie": "Pie Chart", + } + + def pre_validate( + self, + config: dict[str, Any], + ) -> ChartGenerationError | None: + missing_fields = [] + + if "dimension" not in config: + missing_fields.append("'dimension' (category column for slices)") + if "metric" not in config: + missing_fields.append("'metric' (value metric for slice sizes)") + + if missing_fields: + return ChartGenerationError( + error_type="missing_pie_fields", + message=( + f"Pie chart missing required fields: {', '.join(missing_fields)}" + ), + details=( + "Pie charts require a dimension (categories) and a metric (values)" + ), + suggestions=[ + "Add 'dimension' field: {'name': 'category_column'}", + "Add 'metric' field: {'name': 'value_column', 'aggregate': 'SUM'}", + 
"Example: {'chart_type': 'pie', 'dimension': {'name': 'product'}, " + "'metric': {'name': 'revenue', 'aggregate': 'SUM'}}", + ], + error_code="MISSING_PIE_FIELDS", + ) + + return None + + def extract_column_refs(self, config: Any) -> list[ColumnRef]: + if not isinstance(config, PieChartConfig): + return [] + refs: list[ColumnRef] = [config.dimension, config.metric] + if config.filters: + for f in config.filters: + refs.append(ColumnRef(name=f.column)) + return refs + + def to_form_data( + self, config: Any, dataset_id: int | str | None = None + ) -> dict[str, Any]: + return map_pie_config(config) + + def generate_name(self, config: Any, dataset_name: str | None = None) -> str: + what = _pie_chart_what(config) + context = _summarize_filters(config.filters) + return self._with_context(what, context) + + def resolve_viz_type(self, config: Any) -> str: + return "pie" + + def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any: + config_dict = config.model_dump() + + if config_dict.get("dimension"): + config_dict["dimension"]["name"] = ( + DatasetValidator._get_canonical_column_name( + config_dict["dimension"]["name"], dataset_context + ) + ) + if config_dict.get("metric") and not config_dict["metric"].get("saved_metric"): + config_dict["metric"]["name"] = DatasetValidator._get_canonical_column_name( + config_dict["metric"]["name"], dataset_context + ) + DatasetValidator._normalize_filters(config_dict, dataset_context) + return PieChartConfig.model_validate(config_dict) + + def schema_error_hint(self) -> ChartGenerationError | None: + return ChartGenerationError( + error_type="pie_validation_error", + message="Pie chart configuration validation failed", + details=( + "The pie chart configuration is missing required " + "fields or has invalid structure" + ), + suggestions=[ + "Ensure 'dimension' field has 'name' for the slice label", + "Ensure 'metric' field has 'name' and 'aggregate'", + "Example: {'chart_type': 'pie', 'dimension': {'name': 
'category'}, " + "'metric': {'name': 'revenue', 'aggregate': 'SUM'}}", + ], + error_code="PIE_VALIDATION_ERROR", + ) diff --git a/superset/mcp_service/chart/plugins/pivot_table.py b/superset/mcp_service/chart/plugins/pivot_table.py new file mode 100755 index 000000000000..038f8c794161 --- /dev/null +++ b/superset/mcp_service/chart/plugins/pivot_table.py @@ -0,0 +1,153 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""Pivot table chart type plugin.""" + +from __future__ import annotations + +from typing import Any + +from superset.mcp_service.chart.chart_utils import ( + _pivot_table_what, + _summarize_filters, + map_pivot_table_config, +) +from superset.mcp_service.chart.plugin import BaseChartPlugin +from superset.mcp_service.chart.schemas import ColumnRef, PivotTableChartConfig +from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator +from superset.mcp_service.common.error_schemas import ChartGenerationError + + +class PivotTableChartPlugin(BaseChartPlugin): + """Plugin for pivot_table chart type.""" + + chart_type = "pivot_table" + display_name = "Pivot Table" + native_viz_types = { + "pivot_table_v2": "Pivot Table", + } + + def pre_validate( + self, + config: dict[str, Any], + ) -> ChartGenerationError | None: + missing_fields = [] + + if "rows" not in config: + missing_fields.append("'rows' (row grouping columns)") + if "metrics" not in config: + missing_fields.append("'metrics' (aggregation metrics)") + + if missing_fields: + return ChartGenerationError( + error_type="missing_pivot_fields", + message=( + f"Pivot table missing required fields: {', '.join(missing_fields)}" + ), + details="Pivot tables require row groupings and metrics", + suggestions=[ + "Add 'rows' field: [{'name': 'category'}]", + "Add 'metrics' field: [{'name': 'sales', 'aggregate': 'SUM'}]", + "Optional 'columns' for cross-tabulation: [{'name': 'region'}]", + ], + error_code="MISSING_PIVOT_FIELDS", + ) + + if not isinstance(config.get("rows", []), list): + return ChartGenerationError( + error_type="invalid_rows_format", + message="Rows must be a list of columns", + details="The 'rows' field must be an array of column specifications", + suggestions=[ + "Wrap row columns in array: 'rows': [{'name': 'category'}]", + ], + error_code="INVALID_ROWS_FORMAT", + ) + + if not isinstance(config.get("metrics", []), list): + return ChartGenerationError( + 
error_type="invalid_metrics_format", + message="Metrics must be a list", + details="The 'metrics' field must be an array of metric specifications", + suggestions=[ + "Wrap metrics in array: 'metrics': [{'name': 'sales', " + "'aggregate': 'SUM'}]", + ], + error_code="INVALID_METRICS_FORMAT", + ) + + return None + + def extract_column_refs(self, config: Any) -> list[ColumnRef]: + if not isinstance(config, PivotTableChartConfig): + return [] + refs: list[ColumnRef] = list(config.rows) + refs.extend(config.metrics) + if config.columns: + refs.extend(config.columns) + if config.filters: + for f in config.filters: + refs.append(ColumnRef(name=f.column)) + return refs + + def to_form_data( + self, config: Any, dataset_id: int | str | None = None + ) -> dict[str, Any]: + return map_pivot_table_config(config) + + def generate_name(self, config: Any, dataset_name: str | None = None) -> str: + what = _pivot_table_what(config) + context = _summarize_filters(config.filters) + return self._with_context(what, context) + + def resolve_viz_type(self, config: Any) -> str: + return "pivot_table_v2" + + def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any: + config_dict = config.model_dump() + + def _norm_col_list(key: str) -> None: + if config_dict.get(key): + for col in config_dict[key]: + col["name"] = DatasetValidator._get_canonical_column_name( + col["name"], dataset_context + ) + + _norm_col_list("rows") + _norm_col_list("metrics") + _norm_col_list("columns") + DatasetValidator._normalize_filters(config_dict, dataset_context) + return PivotTableChartConfig.model_validate(config_dict) + + def schema_error_hint(self) -> ChartGenerationError | None: + return ChartGenerationError( + error_type="pivot_table_validation_error", + message="Pivot table configuration validation failed", + details=( + "The pivot table configuration is missing required " + "fields or has invalid structure" + ), + suggestions=[ + "Ensure 'rows' field is an array of column specs", + 
"Ensure 'metrics' field is an array with aggregate funcs", + "Optional: add 'columns' for column grouping", + "Example: {'chart_type': 'pivot_table', " + "'rows': [{'name': 'region'}], " + "'metrics': [{'name': 'revenue', 'aggregate': 'SUM'}]}", + ], + error_code="PIVOT_TABLE_VALIDATION_ERROR", + ) diff --git a/superset/mcp_service/chart/plugins/table.py b/superset/mcp_service/chart/plugins/table.py new file mode 100755 index 000000000000..86f5dcaead25 --- /dev/null +++ b/superset/mcp_service/chart/plugins/table.py @@ -0,0 +1,128 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""Table chart type plugin.""" + +from __future__ import annotations + +from typing import Any + +from superset.mcp_service.chart.chart_utils import ( + _summarize_filters, + _table_chart_what, + map_table_config, +) +from superset.mcp_service.chart.plugin import BaseChartPlugin +from superset.mcp_service.chart.schemas import ColumnRef, TableChartConfig +from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator +from superset.mcp_service.common.error_schemas import ChartGenerationError + + +class TableChartPlugin(BaseChartPlugin): + """Plugin for table chart type.""" + + chart_type = "table" + display_name = "Table" + native_viz_types = { + "table": "Table", + "ag-grid-table": "Interactive Table", + } + + def pre_validate( + self, + config: dict[str, Any], + ) -> ChartGenerationError | None: + if "columns" not in config: + return ChartGenerationError( + error_type="missing_columns", + message="Table chart missing required field: columns", + details=( + "Table charts require a 'columns' array to specify which " + "columns to display" + ), + suggestions=[ + "Add 'columns' field with array of column specifications", + "Example: 'columns': [{'name': 'product'}, {'name': 'sales', " + "'aggregate': 'SUM'}]", + "Each column can have optional 'aggregate' for metrics", + ], + error_code="MISSING_COLUMNS", + ) + + if not isinstance(config.get("columns", []), list): + return ChartGenerationError( + error_type="invalid_columns_format", + message="Columns must be a list", + details="The 'columns' field must be an array of column specifications", + suggestions=[ + "Ensure columns is an array: 'columns': [...]", + "Each column should be an object with 'name' field", + ], + error_code="INVALID_COLUMNS_FORMAT", + ) + + return None + + def extract_column_refs(self, config: Any) -> list[ColumnRef]: + if not isinstance(config, TableChartConfig): + return [] + refs: list[ColumnRef] = list(config.columns) + if config.filters: + for f in config.filters: + 
refs.append(ColumnRef(name=f.column)) + return refs + + def to_form_data( + self, config: Any, dataset_id: int | str | None = None + ) -> dict[str, Any]: + return map_table_config(config) + + def generate_name(self, config: Any, dataset_name: str | None = None) -> str: + what = _table_chart_what(config, dataset_name) + context = _summarize_filters(config.filters) + return self._with_context(what, context) + + def resolve_viz_type(self, config: Any) -> str: + return getattr(config, "viz_type", "table") + + def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any: + config_dict = config.model_dump() + get_canonical = DatasetValidator._get_canonical_column_name + + for col in config_dict.get("columns") or []: + col["name"] = get_canonical(col["name"], dataset_context) + + DatasetValidator._normalize_filters(config_dict, dataset_context) + return TableChartConfig.model_validate(config_dict) + + def schema_error_hint(self) -> ChartGenerationError | None: + return ChartGenerationError( + error_type="table_validation_error", + message="Table chart configuration validation failed", + details=( + "The table chart configuration is missing required " + "fields or has invalid structure" + ), + suggestions=[ + "Ensure 'columns' field is an array of column specifications", + "Each column needs {'name': 'column_name'}", + "Optional: add 'aggregate' for metrics", + "Example: 'columns': [{'name': 'product'}, " + "{'name': 'sales', 'aggregate': 'SUM'}]", + ], + error_code="TABLE_VALIDATION_ERROR", + ) diff --git a/superset/mcp_service/chart/plugins/xy.py b/superset/mcp_service/chart/plugins/xy.py new file mode 100755 index 000000000000..076826f3f08e --- /dev/null +++ b/superset/mcp_service/chart/plugins/xy.py @@ -0,0 +1,192 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""XY chart type plugin (line, bar, area, scatter).""" + +from __future__ import annotations + +import logging +from typing import Any + +from superset.mcp_service.chart.chart_utils import ( + _xy_chart_context, + _xy_chart_what, + map_xy_config, +) +from superset.mcp_service.chart.plugin import BaseChartPlugin +from superset.mcp_service.chart.schemas import ColumnRef, XYChartConfig +from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator +from superset.mcp_service.chart.validation.runtime.cardinality_validator import ( + CardinalityValidator, +) +from superset.mcp_service.chart.validation.runtime.format_validator import ( + FormatTypeValidator, +) +from superset.mcp_service.common.error_schemas import ChartGenerationError + +logger = logging.getLogger(__name__) + + +class XYChartPlugin(BaseChartPlugin): + """Plugin for xy chart type (line, bar, area, scatter).""" + + chart_type = "xy" + display_name = "Line / Bar / Area / Scatter Chart" + native_viz_types = { + "echarts_timeseries_line": "Line Chart", + "echarts_timeseries_bar": "Bar Chart", + "echarts_area": "Area Chart", + "echarts_timeseries_scatter": "Scatter Plot", + } + + def pre_validate( + self, + config: dict[str, Any], + ) -> ChartGenerationError | None: + # x is optional — defaults to dataset's main_dttm_col in map_xy_config + if "y" not in config: + return ChartGenerationError( + 
error_type="missing_xy_fields", + message="XY chart missing required field: 'y' (Y-axis metrics)", + details=( + "XY charts require Y-axis (metrics) specifications. " + "X-axis is optional and defaults to the dataset's primary " + "datetime column when omitted." + ), + suggestions=[ + "Add 'y' field: [{'name': 'metric_column', 'aggregate': 'SUM'}]", + "Example: {'chart_type': 'xy', 'x': {'name': 'date'}, " + "'y': [{'name': 'sales', 'aggregate': 'SUM'}]}", + ], + error_code="MISSING_XY_FIELDS", + ) + + if not isinstance(config.get("y", []), list): + return ChartGenerationError( + error_type="invalid_y_format", + message="Y-axis must be a list of metrics", + details="The 'y' field must be an array of metric specifications", + suggestions=[ + "Wrap Y-axis metric in array: 'y': [{'name': 'column', " + "'aggregate': 'SUM'}]", + "Multiple metrics supported: 'y': [metric1, metric2, ...]", + ], + error_code="INVALID_Y_FORMAT", + ) + + return None + + def extract_column_refs(self, config: Any) -> list[ColumnRef]: + if not isinstance(config, XYChartConfig): + return [] + refs: list[ColumnRef] = [] + if config.x is not None: + refs.append(config.x) + refs.extend(config.y) + if config.group_by: + refs.extend(config.group_by) + if config.filters: + for f in config.filters: + refs.append(ColumnRef(name=f.column)) + return refs + + def to_form_data( + self, config: Any, dataset_id: int | str | None = None + ) -> dict[str, Any]: + return map_xy_config(config, dataset_id=dataset_id) + + def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any: + config_dict = config.model_dump() + get_canonical = DatasetValidator._get_canonical_column_name + + if config_dict.get("x"): + config_dict["x"]["name"] = get_canonical( + config_dict["x"]["name"], dataset_context + ) + for y_col in config_dict.get("y") or []: + y_col["name"] = get_canonical(y_col["name"], dataset_context) + for gb_col in config_dict.get("group_by") or []: + gb_col["name"] = get_canonical(gb_col["name"], 
dataset_context) + + DatasetValidator._normalize_filters(config_dict, dataset_context) + return XYChartConfig.model_validate(config_dict) + + def generate_name(self, config: Any, dataset_name: str | None = None) -> str: + what = _xy_chart_what(config) + context = _xy_chart_context(config) + return self._with_context(what, context) + + def resolve_viz_type(self, config: Any) -> str: + kind = getattr(config, "kind", "line") + return { + "line": "echarts_timeseries_line", + "bar": "echarts_timeseries_bar", + "area": "echarts_area", + "scatter": "echarts_timeseries_scatter", + }.get(kind, "echarts_timeseries_line") + + def get_runtime_warnings(self, config: Any, dataset_id: int | str) -> list[str]: + """Return format-compatibility and cardinality warnings for XY charts.""" + if not isinstance(config, XYChartConfig): + return [] + + warnings: list[str] = [] + + try: + _valid, format_warnings = FormatTypeValidator.validate_format_compatibility( + config + ) + if format_warnings: + warnings.extend(format_warnings) + except Exception as exc: + logger.warning("XY format validation failed: %s", exc) + + try: + chart_kind = config.kind + group_by_col = config.group_by[0].name if config.group_by else None + if config.x is not None: + _ok, card_info = CardinalityValidator.check_cardinality( + dataset_id=dataset_id, + x_column=config.x.name, + chart_type=chart_kind, + group_by_column=group_by_col, + ) + if not _ok and card_info: + warnings.extend(card_info.get("warnings", [])) + warnings.extend(card_info.get("suggestions", [])) + except Exception as exc: + logger.warning("XY cardinality validation failed: %s", exc) + + return warnings + + def schema_error_hint(self) -> ChartGenerationError | None: + return ChartGenerationError( + error_type="xy_validation_error", + message="XY chart configuration validation failed", + details=( + "The XY chart configuration is missing required " + "fields or has invalid structure" + ), + suggestions=[ + "Note: 'x' is optional and defaults to the 
dataset's primary " + "datetime column", + "Ensure 'y' is an array: [{'name': 'metric', 'aggregate': 'SUM'}]", + "Check that all column names are strings", + "Verify aggregate functions are valid: SUM, COUNT, AVG, MIN, MAX", + ], + error_code="XY_VALIDATION_ERROR", + ) diff --git a/superset/mcp_service/chart/registry.py b/superset/mcp_service/chart/registry.py new file mode 100755 index 000000000000..920cfcc5592f --- /dev/null +++ b/superset/mcp_service/chart/registry.py @@ -0,0 +1,141 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +ChartTypeRegistry — central registry mapping chart_type strings to plugins. 
+ +Replaces the four previously-scattered dispatch locations: + - schema_validator.py: chart_type_validators dict + - dataset_validator.py: isinstance branches in _extract_column_references() + - chart_utils.py: if/elif chain in map_config_to_form_data() + - dataset_validator.py: isinstance branches in normalize_column_names() + +Usage:: + + from superset.mcp_service.chart.registry import get_registry + + plugin = get_registry().get("xy") + if plugin is None: + raise ValueError("Unknown chart type: xy") + form_data = plugin.to_form_data(config, dataset_id) +""" + +from __future__ import annotations + +import logging +import threading +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from superset.mcp_service.chart.plugin import ChartTypePlugin + +logger = logging.getLogger(__name__) + +_REGISTRY: dict[str, "ChartTypePlugin"] = {} +_plugins_loaded = False +_plugins_lock = threading.Lock() + + +def _ensure_plugins_loaded() -> None: + """Lazily import the plugins package to populate _REGISTRY. + + Called before every registry lookup so the registry is always populated, + even when callers (tests, chart_utils, validators) import this module + directly without first importing app.py. 
+ """ + global _plugins_loaded + if _plugins_loaded: + return + with _plugins_lock: + if not _plugins_loaded: + try: + import superset.mcp_service.chart.plugins # noqa: F401 + + _plugins_loaded = True + except Exception: + logger.exception("Failed to load built-in chart type plugins") + + +def register(plugin: "ChartTypePlugin") -> None: + """Register a chart type plugin in the global registry.""" + if not plugin.chart_type: + raise ValueError(f"{type(plugin).__name__} must define a non-empty chart_type") + if plugin.chart_type in _REGISTRY: + logger.warning( + "Overwriting existing plugin for chart_type=%r", plugin.chart_type + ) + _REGISTRY[plugin.chart_type] = plugin + logger.debug("Registered chart plugin: %r", plugin.chart_type) + + +def get(chart_type: str) -> "ChartTypePlugin | None": + """Return the plugin for a given chart_type, or None if not registered.""" + _ensure_plugins_loaded() + return _REGISTRY.get(chart_type) + + +def all_types() -> list[str]: + """Return all registered chart type strings in insertion order.""" + _ensure_plugins_loaded() + return list(_REGISTRY.keys()) + + +def is_registered(chart_type: str) -> bool: + """Return True if chart_type has a registered plugin.""" + _ensure_plugins_loaded() + return chart_type in _REGISTRY + + +def display_name_for_viz_type(viz_type: str) -> str | None: + """Return the user-facing display name for a Superset-internal viz_type. + + Searches every registered plugin's ``native_viz_types`` mapping. + Returns None if no plugin recognises the viz_type. 
+ + Example:: + + display_name_for_viz_type("echarts_timeseries_line") # "Line Chart" + display_name_for_viz_type("pivot_table_v2") # "Pivot Table" + display_name_for_viz_type("unknown_type") # None + """ + _ensure_plugins_loaded() + for plugin in _REGISTRY.values(): + name = plugin.native_viz_types.get(viz_type) + if name is not None: + return name + return None + + +def get_registry() -> "_RegistryProxy": + """Return a proxy object for registry access (convenience wrapper).""" + return _RegistryProxy() + + +class _RegistryProxy: + """Thin proxy exposing registry functions as instance methods.""" + + def get(self, chart_type: str) -> "ChartTypePlugin | None": + return get(chart_type) + + def all_types(self) -> list[str]: + return all_types() + + def is_registered(self, chart_type: str) -> bool: + return is_registered(chart_type) + + def display_name_for_viz_type(self, viz_type: str) -> str | None: + return display_name_for_viz_type(viz_type) diff --git a/superset/mcp_service/chart/schemas.py b/superset/mcp_service/chart/schemas.py old mode 100644 new mode 100755 index 1fdb4f43ab6d..05aff20a8b65 --- a/superset/mcp_service/chart/schemas.py +++ b/superset/mcp_service/chart/schemas.py @@ -101,7 +101,14 @@ class ChartInfo(BaseModel): id: int | None = Field(None, description="Chart ID") slice_name: str | None = Field(None, description="Chart name") - viz_type: str | None = Field(None, description="Visualization type") + viz_type: str | None = Field(None, description="Visualization type (internal ID)") + chart_type_display_name: str | None = Field( + None, + description=( + "User-friendly chart type name (e.g. 'Line Chart', 'Pivot Table'). " + "Use this field when referring to chart types — never expose viz_type." 
+ ), + ) datasource_name: str | None = Field(None, description="Datasource name") datasource_type: str | None = Field(None, description="Datasource type") url: str | None = Field(None, description="Chart explore page URL") @@ -488,11 +495,20 @@ def serialize_chart_object(chart: ChartLike | None) -> ChartInfo | None: # Extract structured filter information filters_info = extract_filters_from_form_data(chart_form_data) + _viz_type = getattr(chart, "viz_type", None) + try: + from superset.mcp_service.chart.registry import display_name_for_viz_type + + _display_name = display_name_for_viz_type(_viz_type) if _viz_type else None + except Exception: + _display_name = None + return sanitize_chart_info_for_llm_context( ChartInfo( id=chart_id, slice_name=getattr(chart, "slice_name", None), - viz_type=getattr(chart, "viz_type", None), + viz_type=_viz_type, + chart_type_display_name=_display_name, datasource_name=getattr(chart, "datasource_name", None), datasource_type=getattr(chart, "datasource_type", None), url=chart_url, @@ -669,7 +685,6 @@ class ColumnRef(BaseModel): ..., min_length=1, max_length=255, - pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$", validation_alias=AliasChoices("name", "column_name"), ) label: str | None = Field(None, max_length=500) @@ -743,7 +758,6 @@ class FilterConfig(BaseModel): ..., min_length=1, max_length=255, - pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$", validation_alias=AliasChoices("column", "col"), ) op: Literal[ @@ -1082,7 +1096,6 @@ class BigNumberChartConfig(UnknownFieldCheckMixin): ), min_length=1, max_length=255, - pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$", ) time_grain: TimeGrain | None = Field( None, diff --git a/superset/mcp_service/chart/tool/generate_chart.py b/superset/mcp_service/chart/tool/generate_chart.py index 8ad907c1abe6..b7b171fcf9b9 100644 --- a/superset/mcp_service/chart/tool/generate_chart.py +++ b/superset/mcp_service/chart/tool/generate_chart.py @@ -100,18 +100,34 @@ async def generate_chart( # noqa: C901 - Set 
save_chart=True to permanently save the chart - LLM clients MUST display returned chart URL to users - Use numeric dataset ID or UUID (NOT schema.table_name format) - - MUST include chart_type in config (either 'xy' or 'table') + - MUST include chart_type in config (one of: 'xy', 'table', 'pie', + 'pivot_table', 'mixed_timeseries', 'handlebars', 'big_number') IMPORTANT: The 'chart_type' field in the config is a DISCRIMINATOR that determines which chart configuration schema to use. It MUST be included and MUST match the other fields in your configuration: - Use chart_type='xy' for charts with x and y axes (line, bar, area, scatter) - Required fields: x, y + Required fields: y (x is optional — defaults to dataset's primary datetime column) - Use chart_type='table' for tabular visualizations Required fields: columns + - Use chart_type='pie' for pie/donut charts + Required fields: dimension, metric + + - Use chart_type='pivot_table' for pivot table visualizations + Required fields: rows, metrics + + - Use chart_type='mixed_timeseries' for dual-axis time-series charts + Required fields: x, y, y_secondary + + - Use chart_type='handlebars' for custom template-based visualizations + Required fields: handlebars_template + + - Use chart_type='big_number' for single KPI metric displays + Required fields: metric + Example usage for XY chart: ```json { diff --git a/superset/mcp_service/chart/tool/update_chart.py b/superset/mcp_service/chart/tool/update_chart.py old mode 100644 new mode 100755 index 3e3057bdd2ee..5c56a39a7f94 --- a/superset/mcp_service/chart/tool/update_chart.py +++ b/superset/mcp_service/chart/tool/update_chart.py @@ -195,6 +195,29 @@ def _validate_update_against_dataset( } ) + # Column existence + fuzzy-match validation + # (mirrors generate_chart pipeline layer 2) + from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator + + is_col_valid, col_error = DatasetValidator.validate_against_dataset( + parsed_config, dataset.id + ) + if 
not is_col_valid and col_error is not None: + logger.warning( + "update_chart column validation failed for chart %s: %s", + getattr(chart, "id", None), + col_error, + ) + return GenerateChartResponse.model_validate( + { + "chart": None, + "error": col_error.model_dump(), + "success": False, + "schema_version": "2.0", + "api_version": "v1", + } + ) + compile_result = validate_and_compile( parsed_config, form_data, dataset, run_compile_check=True ) @@ -388,6 +411,24 @@ async def update_chart( # noqa: C901 # config is already a typed ChartConfig | None (validated by Pydantic) parsed_config = request.config + # Normalize column case to match dataset canonical names + # (mirrors generate_chart pipeline layer 4) + chart_datasource_id = getattr(chart, "datasource_id", None) + if parsed_config is not None and chart_datasource_id is not None: + from superset.mcp_service.chart.validation.dataset_validator import ( + DatasetValidator, + NORMALIZATION_EXCEPTIONS, + ) + + try: + parsed_config = DatasetValidator.normalize_column_names( + parsed_config, chart.datasource_id + ) + except NORMALIZATION_EXCEPTIONS as e: + logger.warning( + "Column normalization failed for chart %s: %s", chart.id, e + ) + if not request.generate_preview: from superset.commands.chart.update import UpdateChartCommand diff --git a/superset/mcp_service/chart/validation/dataset_validator.py b/superset/mcp_service/chart/validation/dataset_validator.py index 5602b7af1165..a27d18629b2c 100644 --- a/superset/mcp_service/chart/validation/dataset_validator.py +++ b/superset/mcp_service/chart/validation/dataset_validator.py @@ -22,17 +22,11 @@ import difflib import logging -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Tuple, TypeVar from superset.mcp_service.chart.schemas import ( - BigNumberChartConfig, + ChartConfig, ColumnRef, - HandlebarsChartConfig, - MixedTimeseriesChartConfig, - PieChartConfig, - PivotTableChartConfig, - TableChartConfig, - XYChartConfig, ) from 
superset.mcp_service.common.error_schemas import ( ChartGenerationError, @@ -40,6 +34,8 @@ DatasetContext, ) +_C = TypeVar("_C", bound=ChartConfig) + logger = logging.getLogger(__name__) # Exceptions that can occur during column name normalization. @@ -58,7 +54,7 @@ class DatasetValidator: @staticmethod def validate_against_dataset( - config: Any, + config: ChartConfig, dataset_id: int | str, dataset_context: DatasetContext | None = None, ) -> Tuple[bool, ChartGenerationError | None]: @@ -260,59 +256,31 @@ def _get_dataset_context(dataset_id: int | str) -> DatasetContext | None: return None @staticmethod - def _extract_column_references(config: Any) -> List[ColumnRef]: # noqa: C901 - """Extract all column references from a chart configuration. - - Covers every supported ``ChartConfig`` variant so fast-path tools - (``generate_explore_link``, ``update_chart_preview``) that only run - Tier-1 validation still catch bad column refs in pie / pivot table / - mixed timeseries / handlebars / big number charts — not just XY and - table. + def _extract_column_references( + config: ChartConfig, + ) -> List[ColumnRef]: + """Extract all column references from configuration via the plugin registry. + + Previously only handled TableChartConfig and XYChartConfig, causing + 5 of 7 chart types to silently skip column validation. Now delegates + to the plugin for each chart type so all types are covered. 
""" - refs: List[ColumnRef] = [] - - if isinstance(config, TableChartConfig): - refs.extend(config.columns) - elif isinstance(config, XYChartConfig): - if config.x is not None: - refs.append(config.x) - refs.extend(config.y) - if config.group_by: - refs.extend(config.group_by) - elif isinstance(config, PieChartConfig): - refs.append(config.dimension) - refs.append(config.metric) - elif isinstance(config, PivotTableChartConfig): - refs.extend(config.rows) - if config.columns: - refs.extend(config.columns) - refs.extend(config.metrics) - elif isinstance(config, MixedTimeseriesChartConfig): - refs.append(config.x) - refs.extend(config.y) - if config.group_by: - refs.extend(config.group_by) - refs.extend(config.y_secondary) - if config.group_by_secondary: - refs.extend(config.group_by_secondary) - elif isinstance(config, HandlebarsChartConfig): - if config.columns: - refs.extend(config.columns) - if config.groupby: - refs.extend(config.groupby) - if config.metrics: - refs.extend(config.metrics) - elif isinstance(config, BigNumberChartConfig): - refs.append(config.metric) - if config.temporal_column: - refs.append(ColumnRef(name=config.temporal_column)) - - # Filter columns (shared by every config type that defines ``filters``). - if filters := getattr(config, "filters", None): - for filter_config in filters: - refs.append(ColumnRef(name=filter_config.column)) - - return refs + # Local import: plugins call DatasetValidator helpers from + # normalize_column_refs(). + # A top-level import of registry in dataset_validator would make loading this + # module implicitly trigger plugin registration, creating a circular dependency. 
+ from superset.mcp_service.chart.registry import get_registry + + chart_type = getattr(config, "chart_type", None) + if chart_type is None: + return [] + + plugin = get_registry().get(chart_type) + if plugin is None: + logger.warning("No plugin registered for chart_type=%r", chart_type) + return [] + + return plugin.extract_column_refs(config) @staticmethod def _column_exists(column_name: str, dataset_context: DatasetContext) -> bool: @@ -365,42 +333,6 @@ def _get_canonical_column_name( # Return original if not found (validation should catch this case) return column_name - @staticmethod - def _normalize_xy_config( - config_dict: Dict[str, Any], dataset_context: DatasetContext - ) -> None: - """Normalize column names in an XY chart config dict in place.""" - # Normalize x-axis column - if "x" in config_dict and config_dict["x"]: - config_dict["x"]["name"] = DatasetValidator._get_canonical_column_name( - config_dict["x"]["name"], dataset_context - ) - - # Normalize y-axis columns - if "y" in config_dict and config_dict["y"]: - for y_col in config_dict["y"]: - y_col["name"] = DatasetValidator._get_canonical_column_name( - y_col["name"], dataset_context - ) - - # Normalize group_by columns - if "group_by" in config_dict and config_dict["group_by"]: - for gb_col in config_dict["group_by"]: - gb_col["name"] = DatasetValidator._get_canonical_column_name( - gb_col["name"], dataset_context - ) - - @staticmethod - def _normalize_table_config( - config_dict: Dict[str, Any], dataset_context: DatasetContext - ) -> None: - """Normalize column names in a table chart config dict in place.""" - if "columns" in config_dict and config_dict["columns"]: - for col in config_dict["columns"]: - col["name"] = DatasetValidator._get_canonical_column_name( - col["name"], dataset_context - ) - @staticmethod def _normalize_filters( config_dict: Dict[str, Any], dataset_context: DatasetContext @@ -417,10 +349,10 @@ def _normalize_filters( @staticmethod def normalize_column_names( - config: 
TableChartConfig | XYChartConfig, + config: _C, dataset_id: int | str, dataset_context: DatasetContext | None = None, - ) -> TableChartConfig | XYChartConfig: + ) -> _C: """ Normalize column names in config to match the canonical dataset column names. @@ -429,6 +361,9 @@ def normalize_column_names( (e.g., 'OrderDate'). The frontend performs case-sensitive comparisons, so we need to ensure column names match exactly. + Previously only XYChartConfig and TableChartConfig were normalized; now + all 7 chart types are handled via the plugin registry. + Args: config: Chart configuration with column references dataset_id: Dataset ID to get canonical column names from @@ -443,22 +378,24 @@ def normalize_column_names( if not dataset_context: return config - # Create a mutable copy of the config - config_dict = config.model_dump() + # Local import: plugins call DatasetValidator helpers from + # normalize_column_refs(). + # A top-level import of registry in dataset_validator would make loading this + # module implicitly trigger plugin registration, creating a circular dependency. 
+ from superset.mcp_service.chart.registry import get_registry - # Normalize based on config type - if isinstance(config, XYChartConfig): - DatasetValidator._normalize_xy_config(config_dict, dataset_context) - elif isinstance(config, TableChartConfig): - DatasetValidator._normalize_table_config(config_dict, dataset_context) + chart_type = getattr(config, "chart_type", None) + if chart_type is None: + return config - # Normalize filter columns (common to both config types) - DatasetValidator._normalize_filters(config_dict, dataset_context) + plugin = get_registry().get(chart_type) + if plugin is None: + logger.warning( + "No plugin for chart_type=%r; skipping column normalization", chart_type + ) + return config - # Reconstruct the config with normalized names - if isinstance(config, XYChartConfig): - return XYChartConfig.model_validate(config_dict) - return TableChartConfig.model_validate(config_dict) + return plugin.normalize_column_refs(config, dataset_context) @staticmethod def _get_column_suggestions( diff --git a/superset/mcp_service/chart/validation/runtime/__init__.py b/superset/mcp_service/chart/validation/runtime/__init__.py index 5e1c89d0a687..dce732ba9926 100644 --- a/superset/mcp_service/chart/validation/runtime/__init__.py +++ b/superset/mcp_service/chart/validation/runtime/__init__.py @@ -23,10 +23,7 @@ import logging from typing import Any, Dict, List, Tuple -from superset.mcp_service.chart.schemas import ( - ChartConfig, - XYChartConfig, -) +from superset.mcp_service.chart.schemas import ChartConfig logger = logging.getLogger(__name__) @@ -56,20 +53,10 @@ def validate_runtime_issues( warnings: List[str] = [] suggestions: List[str] = [] - # Only check XY charts for format and cardinality issues - if isinstance(config, XYChartConfig): - # Format-type compatibility validation - format_warnings = RuntimeValidator._validate_format_compatibility(config) - if format_warnings: - warnings.extend(format_warnings) - - # Cardinality validation - 
cardinality_warnings, cardinality_suggestions = ( - RuntimeValidator._validate_cardinality(config, dataset_id) - ) - if cardinality_warnings: - warnings.extend(cardinality_warnings) - suggestions.extend(cardinality_suggestions) + # Per-plugin runtime warnings (format, cardinality, etc.) + plugin_warnings = RuntimeValidator._validate_plugin_runtime(config, dataset_id) + if plugin_warnings: + warnings.extend(plugin_warnings) # Chart type appropriateness validation (for all chart types) type_warnings, type_suggestions = RuntimeValidator._validate_chart_type( @@ -98,61 +85,28 @@ def validate_runtime_issues( return True, None @staticmethod - def _validate_format_compatibility(config: XYChartConfig) -> List[str]: - """Validate format-type compatibility.""" - warnings: List[str] = [] - - try: - # Import here to avoid circular imports - from .format_validator import FormatTypeValidator - - is_valid, format_warnings = ( - FormatTypeValidator.validate_format_compatibility(config) - ) - if format_warnings: - warnings.extend(format_warnings) - except ImportError: - logger.warning("Format validator not available") - except Exception as e: - logger.warning("Format validation failed: %s", e) - - return warnings - - @staticmethod - def _validate_cardinality( - config: XYChartConfig, dataset_id: int | str - ) -> Tuple[List[str], List[str]]: - """Validate cardinality issues.""" - warnings: List[str] = [] - suggestions: List[str] = [] + def _validate_plugin_runtime( + config: ChartConfig, dataset_id: int | str + ) -> List[str]: + """Delegate per-chart-type runtime warnings to the plugin registry. + Each plugin's get_runtime_warnings() method returns chart-type-specific + warnings (e.g. format/cardinality for XY). The registry dispatch removes + the previous isinstance(config, XYChartConfig) hardcoding. 
+ """ try: - # Import here to avoid circular imports - from .cardinality_validator import CardinalityValidator - - # Determine chart type for cardinality thresholds - chart_type = config.kind if hasattr(config, "kind") else "default" - - # Check X-axis cardinality - if config.x is None: - return warnings, suggestions - is_ok, cardinality_info = CardinalityValidator.check_cardinality( - dataset_id=dataset_id, - x_column=config.x.name, - chart_type=chart_type, - group_by_column=config.group_by[0].name if config.group_by else None, - ) - - if not is_ok and cardinality_info: - warnings.extend(cardinality_info.get("warnings", [])) - suggestions.extend(cardinality_info.get("suggestions", [])) - - except ImportError: - logger.warning("Cardinality validator not available") - except Exception as e: - logger.warning("Cardinality validation failed: %s", e) - - return warnings, suggestions + from superset.mcp_service.chart.registry import get_registry + + chart_type = getattr(config, "chart_type", None) + if chart_type is None: + return [] + plugin = get_registry().get(chart_type) + if plugin is None: + return [] + return plugin.get_runtime_warnings(config, dataset_id) + except Exception as exc: + logger.warning("Plugin runtime validation failed: %s", exc) + return [] @staticmethod def _validate_chart_type( diff --git a/superset/mcp_service/chart/validation/schema_validator.py b/superset/mcp_service/chart/validation/schema_validator.py old mode 100644 new mode 100755 index 7cae450ff599..11f0f6ada8a4 --- a/superset/mcp_service/chart/validation/schema_validator.py +++ b/superset/mcp_service/chart/validation/schema_validator.py @@ -147,19 +147,13 @@ def _pre_validate_chart_type( chart_type: str, config: Dict[str, Any], ) -> Tuple[bool, ChartGenerationError | None]: - """Validate chart type and dispatch to type-specific pre-validation.""" - chart_type_validators = { - "xy": SchemaValidator._pre_validate_xy_config, - "table": SchemaValidator._pre_validate_table_config, - "pie": 
SchemaValidator._pre_validate_pie_config, - "pivot_table": SchemaValidator._pre_validate_pivot_table_config, - "mixed_timeseries": SchemaValidator._pre_validate_mixed_timeseries_config, - "handlebars": SchemaValidator._pre_validate_handlebars_config, - "big_number": SchemaValidator._pre_validate_big_number_config, - } + """Validate chart type and dispatch to plugin pre-validation.""" + from superset.mcp_service.chart.registry import get_registry - if not isinstance(chart_type, str) or chart_type not in chart_type_validators: - valid_types = ", ".join(chart_type_validators.keys()) + registry = get_registry() + + if not isinstance(chart_type, str) or not registry.is_registered(chart_type): + valid_types = ", ".join(registry.all_types()) return False, ChartGenerationError( error_type="invalid_chart_type", message=f"Invalid chart_type: '{chart_type}'", @@ -178,351 +172,18 @@ def _pre_validate_chart_type( error_code="INVALID_CHART_TYPE", ) - return chart_type_validators[chart_type](config) - - @staticmethod - def _pre_validate_xy_config( - config: Dict[str, Any], - ) -> Tuple[bool, ChartGenerationError | None]: - """Pre-validate XY chart configuration.""" - # x is optional — defaults to dataset's main_dttm_col in map_xy_config - if "y" not in config: - return False, ChartGenerationError( - error_type="missing_xy_fields", - message="XY chart missing required field: 'y' (Y-axis metrics)", - details="XY charts require Y-axis (metrics) specifications. 
" - "X-axis is optional and defaults to the dataset's primary " - "datetime column when omitted.", - suggestions=[ - "Add 'y' field: [{'name': 'metric_column', 'aggregate': 'SUM'}] " - "for Y-axis", - "Example: {'chart_type': 'xy', 'x': {'name': 'date'}, " - "'y': [{'name': 'sales', 'aggregate': 'SUM'}]}", - ], - error_code="MISSING_XY_FIELDS", - ) - - # Validate Y is a list - if not isinstance(config.get("y", []), list): - return False, ChartGenerationError( - error_type="invalid_y_format", - message="Y-axis must be a list of metrics", - details="The 'y' field must be an array of metric specifications", - suggestions=[ - "Wrap Y-axis metric in array: 'y': [{'name': 'column', " - "'aggregate': 'SUM'}]", - "Multiple metrics supported: 'y': [metric1, metric2, ...]", - ], - error_code="INVALID_Y_FORMAT", - ) - - return True, None - - @staticmethod - def _pre_validate_table_config( - config: Dict[str, Any], - ) -> Tuple[bool, ChartGenerationError | None]: - """Pre-validate table chart configuration.""" - if "columns" not in config: - return False, ChartGenerationError( - error_type="missing_columns", - message="Table chart missing required field: columns", - details="Table charts require a 'columns' array to specify which " - "columns to display", - suggestions=[ - "Add 'columns' field with array of column specifications", - "Example: 'columns': [{'name': 'product'}, {'name': 'sales', " - "'aggregate': 'SUM'}]", - "Each column can have optional 'aggregate' for metrics", - ], - error_code="MISSING_COLUMNS", - ) - - if not isinstance(config.get("columns", []), list): - return False, ChartGenerationError( - error_type="invalid_columns_format", - message="Columns must be a list", - details="The 'columns' field must be an array of column specifications", - suggestions=[ - "Ensure columns is an array: 'columns': [...]", - "Each column should be an object with 'name' field", - ], - error_code="INVALID_COLUMNS_FORMAT", - ) - - return True, None - - @staticmethod - def 
_pre_validate_pie_config( - config: Dict[str, Any], - ) -> Tuple[bool, ChartGenerationError | None]: - """Pre-validate pie chart configuration.""" - missing_fields = [] - - if "dimension" not in config: - missing_fields.append("'dimension' (category column for slices)") - if "metric" not in config: - missing_fields.append("'metric' (value metric for slice sizes)") - - if missing_fields: - return False, ChartGenerationError( - error_type="missing_pie_fields", - message=f"Pie chart missing required " - f"fields: {', '.join(missing_fields)}", - details="Pie charts require a dimension (categories) and a metric " - "(values)", - suggestions=[ - "Add 'dimension' field: {'name': 'category_column'}", - "Add 'metric' field: {'name': 'value_column', 'aggregate': 'SUM'}", - "Example: {'chart_type': 'pie', 'dimension': {'name': " - "'product'}, 'metric': {'name': 'revenue', 'aggregate': 'SUM'}}", - ], - error_code="MISSING_PIE_FIELDS", - ) - - return True, None - - @staticmethod - def _pre_validate_handlebars_config( - config: Dict[str, Any], - ) -> Tuple[bool, ChartGenerationError | None]: - """Pre-validate handlebars chart configuration.""" - if "handlebars_template" not in config: - return False, ChartGenerationError( - error_type="missing_handlebars_template", - message="Handlebars chart missing required field: handlebars_template", - details="Handlebars charts require a 'handlebars_template' string " - "containing Handlebars HTML template markup", - suggestions=[ - "Add 'handlebars_template' with a Handlebars HTML template", - "Data is available as {{data}} array in the template", - "Example: ''", - ], - error_code="MISSING_HANDLEBARS_TEMPLATE", - ) - - template = config.get("handlebars_template") - if not isinstance(template, str) or not template.strip(): - return False, ChartGenerationError( - error_type="invalid_handlebars_template", - message="Handlebars template must be a non-empty string", - details="The 'handlebars_template' field must be a non-empty string " - 
"containing valid Handlebars HTML template markup", - suggestions=[ - "Ensure handlebars_template is a non-empty string", - "Example: ''", - ], - error_code="INVALID_HANDLEBARS_TEMPLATE", - ) - - query_mode = config.get("query_mode", "aggregate") - if query_mode not in ("aggregate", "raw"): - return False, ChartGenerationError( - error_type="invalid_query_mode", - message="Invalid query_mode for handlebars chart", - details="query_mode must be either 'aggregate' or 'raw'", - suggestions=[ - "Use 'aggregate' for aggregated data (default)", - "Use 'raw' for individual rows", - ], - error_code="INVALID_QUERY_MODE", - ) - - if query_mode == "raw" and not config.get("columns"): - return False, ChartGenerationError( - error_type="missing_raw_columns", - message="Handlebars chart in 'raw' mode requires 'columns'", - details="When query_mode is 'raw', you must specify which columns " - "to include in the query results", - suggestions=[ - "Add 'columns': [{'name': 'column_name'}] for raw mode", - "Or use query_mode='aggregate' with 'metrics' " - "and optional 'groupby'", - ], - error_code="MISSING_RAW_COLUMNS", - ) - - if query_mode == "aggregate" and not config.get("metrics"): - return False, ChartGenerationError( - error_type="missing_aggregate_metrics", - message="Handlebars chart in 'aggregate' mode requires 'metrics'", - details="When query_mode is 'aggregate' (default), you must specify " - "at least one metric with an aggregate function", - suggestions=[ - "Add 'metrics': [{'name': 'column', 'aggregate': 'SUM'}]", - "Or use query_mode='raw' with 'columns' for individual rows", - ], - error_code="MISSING_AGGREGATE_METRICS", - ) - - return True, None - - @staticmethod - def _pre_validate_big_number_config( - config: Dict[str, Any], - ) -> Tuple[bool, ChartGenerationError | None]: - """Pre-validate big number chart configuration.""" - if "metric" not in config: - return False, ChartGenerationError( - error_type="missing_metric", - message="Big Number chart missing 
required field: metric", - details="Big Number charts require a 'metric' field " - "specifying the value to display", - suggestions=[ - "Add 'metric' with name and aggregate: " - "{'name': 'revenue', 'aggregate': 'SUM'}", - "The aggregate function is required (SUM, COUNT, AVG, MIN, MAX)", - "Example: {'chart_type': 'big_number', " - "'metric': {'name': 'sales', 'aggregate': 'SUM'}}", - ], - error_code="MISSING_BIG_NUMBER_METRIC", - ) - - metric = config.get("metric", {}) - if not isinstance(metric, dict): + plugin = registry.get(chart_type) + if plugin is None: return False, ChartGenerationError( - error_type="invalid_metric_type", - message="Big Number metric must be a dict with 'name' and 'aggregate'", - details="The 'metric' field must be an object, " - f"got {type(metric).__name__}", - suggestions=[ - "Use a dict: {'name': 'col', 'aggregate': 'SUM'}", - "Valid aggregates: SUM, COUNT, AVG, MIN, MAX", - ], - error_code="INVALID_BIG_NUMBER_METRIC_TYPE", - ) - if not metric.get("aggregate") and not metric.get("saved_metric"): - return False, ChartGenerationError( - error_type="missing_metric_aggregate", - message="Big Number metric must include an aggregate function " - "or reference a saved metric", - details="The metric must have an 'aggregate' field " - "or 'saved_metric': true", - suggestions=[ - "Add 'aggregate' to your metric: " - "{'name': 'col', 'aggregate': 'SUM'}", - "Or use a saved metric: " - "{'name': 'total_sales', 'saved_metric': true}", - "Valid aggregates: SUM, COUNT, AVG, MIN, MAX", - ], - error_code="MISSING_BIG_NUMBER_AGGREGATE", - ) - - show_trendline = config.get("show_trendline", False) - temporal_column = config.get("temporal_column") - if show_trendline and not temporal_column: - return False, ChartGenerationError( - error_type="missing_temporal_column", - message="Trendline requires a temporal column", - details="When 'show_trendline' is True, a " - "'temporal_column' must be specified", - suggestions=[ - "Add 'temporal_column': 
'date_column_name'", - "Or set 'show_trendline': false for number only", - "Use get_dataset_info to find temporal columns", - ], - error_code="MISSING_TEMPORAL_COLUMN", - ) - - return True, None - - @staticmethod - def _pre_validate_pivot_table_config( - config: Dict[str, Any], - ) -> Tuple[bool, ChartGenerationError | None]: - """Pre-validate pivot table configuration.""" - missing_fields = [] - - if "rows" not in config: - missing_fields.append("'rows' (row grouping columns)") - if "metrics" not in config: - missing_fields.append("'metrics' (aggregation metrics)") - - if missing_fields: - return False, ChartGenerationError( - error_type="missing_pivot_fields", - message=f"Pivot table missing required " - f"fields: {', '.join(missing_fields)}", - details="Pivot tables require row groupings and metrics", - suggestions=[ - "Add 'rows' field: [{'name': 'category'}]", - "Add 'metrics' field: [{'name': 'sales', 'aggregate': 'SUM'}]", - "Optional 'columns' for cross-tabulation: [{'name': 'region'}]", - ], - error_code="MISSING_PIVOT_FIELDS", - ) - - if not isinstance(config.get("rows", []), list): - return False, ChartGenerationError( - error_type="invalid_rows_format", - message="Rows must be a list of columns", - details="The 'rows' field must be an array of column specifications", - suggestions=[ - "Wrap row columns in array: 'rows': [{'name': 'category'}]", - ], - error_code="INVALID_ROWS_FORMAT", - ) - - if not isinstance(config.get("metrics", []), list): - return False, ChartGenerationError( - error_type="invalid_metrics_format", - message="Metrics must be a list", - details="The 'metrics' field must be an array of metric specifications", - suggestions=[ - "Wrap metrics in array: 'metrics': [{'name': 'sales', " - "'aggregate': 'SUM'}]", - ], - error_code="INVALID_METRICS_FORMAT", - ) - - return True, None - - @staticmethod - def _pre_validate_mixed_timeseries_config( - config: Dict[str, Any], - ) -> Tuple[bool, ChartGenerationError | None]: - """Pre-validate mixed 
timeseries configuration.""" - missing_fields = [] - - if "x" not in config: - missing_fields.append("'x' (X-axis temporal column)") - if "y" not in config: - missing_fields.append("'y' (primary Y-axis metrics)") - if "y_secondary" not in config: - missing_fields.append("'y_secondary' (secondary Y-axis metrics)") - - if missing_fields: - return False, ChartGenerationError( - error_type="missing_mixed_timeseries_fields", - message=f"Mixed timeseries chart missing required " - f"fields: {', '.join(missing_fields)}", - details="Mixed timeseries charts require an x-axis, primary metrics, " - "and secondary metrics", - suggestions=[ - "Add 'x' field: {'name': 'date_column'}", - "Add 'y' field: [{'name': 'revenue', 'aggregate': 'SUM'}]", - "Add 'y_secondary' field: [{'name': 'orders', " - "'aggregate': 'COUNT'}]", - "Optional: 'primary_kind' and 'secondary_kind' for chart types", - ], - error_code="MISSING_MIXED_TIMESERIES_FIELDS", + error_type="invalid_chart_type", + message=f"Chart type '{chart_type}' has no registered plugin", + details="Internal error: chart type is listed but has no plugin", + suggestions=["Use a supported chart_type"], + error_code="INVALID_CHART_TYPE", ) - for field_name in ["y", "y_secondary"]: - if not isinstance(config.get(field_name, []), list): - return False, ChartGenerationError( - error_type=f"invalid_{field_name}_format", - message=f"'{field_name}' must be a list of metrics", - details=f"The '{field_name}' field must be an array of metric " - "specifications", - suggestions=[ - f"Wrap in array: '{field_name}': " - "[{'name': 'col', 'aggregate': 'SUM'}]", - ], - error_code=f"INVALID_{field_name.upper()}_FORMAT", - ) - + if (error := plugin.pre_validate(config)) is not None: + return False, error return True, None @staticmethod @@ -537,89 +198,26 @@ def _enhance_validation_error( if err.get("type") == "union_tag_invalid" or "discriminator" in str( err.get("ctx", {}) ): - # This is the generic union error - provide better message - config = 
request_data.get("config", {}) - chart_type = config.get("chart_type", "unknown") + from superset.mcp_service.chart.registry import get_registry - if chart_type == "xy": - return ChartGenerationError( - error_type="xy_validation_error", - message="XY chart configuration validation failed", - details="The XY chart configuration is missing required " - "fields or has invalid structure", - suggestions=[ - "Ensure 'x' field exists with {'name': 'column_name'}", - "Ensure 'y' field is an array: [{'name': 'metric', " - "'aggregate': 'SUM'}]", - "Check that all column names are strings", - "Verify aggregate functions are valid: SUM, COUNT, AVG, " - "MIN, MAX", - ], - error_code="XY_VALIDATION_ERROR", - ) - elif chart_type == "table": - return ChartGenerationError( - error_type="table_validation_error", - message="Table chart configuration validation failed", - details="The table chart configuration is missing required " - "fields or has invalid structure", - suggestions=[ - "Ensure 'columns' field is an array of column " - "specifications", - "Each column needs {'name': 'column_name'}", - "Optional: add 'aggregate' for metrics", - "Example: 'columns': [{'name': 'product'}, {'name': " - "'sales', 'aggregate': 'SUM'}]", - ], - error_code="TABLE_VALIDATION_ERROR", - ) - elif chart_type == "handlebars": - return ChartGenerationError( - error_type="handlebars_validation_error", - message="Handlebars chart configuration validation failed", - details="The handlebars chart configuration is missing " - "required fields or has invalid structure", - suggestions=[ - "Ensure 'handlebars_template' is a non-empty string", - "For aggregate mode: add 'metrics' with aggregate " - "functions", - "For raw mode: set 'query_mode': 'raw' and add 'columns'", - "Example: {'chart_type': 'handlebars', " - "'handlebars_template': '', " - "'metrics': [{'name': 'sales', 'aggregate': 'SUM'}]}", - ], - error_code="HANDLEBARS_VALIDATION_ERROR", - ) - elif chart_type == "big_number": - return 
ChartGenerationError( - error_type="big_number_validation_error", - message="Big Number chart configuration validation failed", - details="The Big Number chart configuration is " - "missing required fields or has invalid " - "structure", - suggestions=[ - "Ensure 'metric' field has 'name' and 'aggregate'", - "Example: 'metric': {'name': 'revenue', " - "'aggregate': 'SUM'}", - "For trendline: add 'show_trendline': true " - "and 'temporal_column': 'date_col'", - "Without trendline: just provide the metric", - ], - error_code="BIG_NUMBER_VALIDATION_ERROR", - ) + chart_type = request_data.get("config", {}).get("chart_type", "") + plugin = get_registry().get(chart_type) + if plugin is not None: + hint = plugin.schema_error_hint() + if hint is not None: + return hint # Default enhanced error error_details = [] for err in errors[:3]: # Show first 3 errors loc = " -> ".join(str(location) for location in err.get("loc", [])) msg = err.get("msg", "Validation failed") - error_details.append(f"{loc}: {msg}") + error_details.append(f"{loc}: {msg}" if loc else msg) return ChartGenerationError( error_type="validation_error", message="Chart configuration validation failed", - details="; ".join(error_details), + details="; ".join(error_details) or "Invalid chart configuration structure", suggestions=[ "Check that all required fields are present", "Ensure field types match the schema", diff --git a/tests/unit_tests/mcp_service/chart/test_big_number_chart.py b/tests/unit_tests/mcp_service/chart/test_big_number_chart.py index 59e142333bdb..c832d7793d15 100644 --- a/tests/unit_tests/mcp_service/chart/test_big_number_chart.py +++ b/tests/unit_tests/mcp_service/chart/test_big_number_chart.py @@ -90,7 +90,7 @@ def test_saved_metric_passes_pre_validation(self) -> None: "chart_type": "big_number", "metric": {"name": "total_sales", "saved_metric": True}, } - is_valid, error = SchemaValidator._pre_validate_big_number_config(data) + is_valid, error = 
SchemaValidator._pre_validate_chart_type("big_number", data) assert is_valid is True assert error is None diff --git a/tests/unit_tests/mcp_service/chart/test_registry.py b/tests/unit_tests/mcp_service/chart/test_registry.py new file mode 100644 index 000000000000..0351b2d2bd33 --- /dev/null +++ b/tests/unit_tests/mcp_service/chart/test_registry.py @@ -0,0 +1,143 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""Tests for the chart type plugin registry.""" + +import pytest + +import superset.mcp_service.chart.registry as registry_module +from superset.mcp_service.chart.plugin import BaseChartPlugin +from superset.mcp_service.chart.registry import ( + _RegistryProxy, + all_types, + display_name_for_viz_type, + get, + get_registry, + is_registered, + register, +) + + +@pytest.fixture(autouse=True) +def _isolated_registry(monkeypatch): + """Run each test against a clean registry without touching the real one.""" + monkeypatch.setattr(registry_module, "_REGISTRY", {}) + monkeypatch.setattr(registry_module, "_plugins_loaded", True) + + +class _FakePlugin(BaseChartPlugin): + chart_type = "fake" + display_name = "Fake Chart" + native_viz_types = {"fake_viz": "Fake Viz"} + + +class _AnotherPlugin(BaseChartPlugin): + chart_type = "another" + display_name = "Another Chart" + native_viz_types = {"another_viz": "Another Viz"} + + +def test_register_adds_plugin(): + plugin = _FakePlugin() + register(plugin) + assert get("fake") is plugin + + +def test_get_returns_none_for_unknown(): + assert get("nonexistent") is None + + +def test_all_types_returns_registered_keys(): + register(_FakePlugin()) + register(_AnotherPlugin()) + types = all_types() + assert "fake" in types + assert "another" in types + + +def test_all_types_insertion_order(): + register(_FakePlugin()) + register(_AnotherPlugin()) + types = all_types() + assert types.index("fake") < types.index("another") + + +def test_is_registered_true_for_known(): + register(_FakePlugin()) + assert is_registered("fake") is True + + +def test_is_registered_false_for_unknown(): + assert is_registered("nonexistent") is False + + +def test_register_warns_on_duplicate(caplog): + register(_FakePlugin()) + with caplog.at_level("WARNING"): + register(_FakePlugin()) + assert "Overwriting" in caplog.text + + +def test_register_raises_for_empty_chart_type(): + class _BadPlugin(BaseChartPlugin): + chart_type = "" + + with 
pytest.raises(ValueError, match="non-empty chart_type"): + register(_BadPlugin()) + + +def test_display_name_for_viz_type_found(): + register(_FakePlugin()) + assert display_name_for_viz_type("fake_viz") == "Fake Viz" + + +def test_display_name_for_viz_type_not_found(): + register(_FakePlugin()) + assert display_name_for_viz_type("unknown_viz") is None + + +def test_display_name_searches_all_plugins(): + register(_FakePlugin()) + register(_AnotherPlugin()) + assert display_name_for_viz_type("another_viz") == "Another Viz" + + +def test_get_registry_returns_proxy(): + assert isinstance(get_registry(), _RegistryProxy) + + +def test_registry_proxy_get(): + plugin = _FakePlugin() + register(plugin) + assert get_registry().get("fake") is plugin + + +def test_registry_proxy_all_types(): + register(_FakePlugin()) + assert "fake" in get_registry().all_types() + + +def test_registry_proxy_is_registered(): + register(_FakePlugin()) + assert get_registry().is_registered("fake") is True + assert get_registry().is_registered("missing") is False + + +def test_registry_proxy_display_name_for_viz_type(): + register(_FakePlugin()) + assert get_registry().display_name_for_viz_type("fake_viz") == "Fake Viz" + assert get_registry().display_name_for_viz_type("unknown") is None diff --git a/tests/unit_tests/mcp_service/chart/tool/test_update_chart.py b/tests/unit_tests/mcp_service/chart/tool/test_update_chart.py old mode 100644 new mode 100755 index c504d8bca598..21e2286fb6fb --- a/tests/unit_tests/mcp_service/chart/tool/test_update_chart.py +++ b/tests/unit_tests/mcp_service/chart/tool/test_update_chart.py @@ -1175,6 +1175,11 @@ def _mock_chart_with_dataset(chart_id: int = 1) -> Mock: ) @patch("superset.daos.chart.ChartDAO.find_by_id", new_callable=Mock) @patch("superset.db.session") + @patch( + "superset.mcp_service.chart.validation.dataset_validator" + ".DatasetValidator.validate_against_dataset", + new=Mock(return_value=(True, None)), + ) @pytest.mark.asyncio async def 
test_preview_path_validation_failure_skips_cache( self, @@ -1238,6 +1243,11 @@ async def test_preview_path_validation_failure_skips_cache( ) @patch("superset.daos.chart.ChartDAO.find_by_id", new_callable=Mock) @patch("superset.db.session") + @patch( + "superset.mcp_service.chart.validation.dataset_validator" + ".DatasetValidator.validate_against_dataset", + new=Mock(return_value=(True, None)), + ) @pytest.mark.asyncio async def test_persist_path_validation_failure_skips_db_write( self, diff --git a/tests/unit_tests/mcp_service/chart/validation/test_column_name_normalization.py b/tests/unit_tests/mcp_service/chart/validation/test_column_name_normalization.py index dbebe268b4b8..a81f0864f262 100644 --- a/tests/unit_tests/mcp_service/chart/validation/test_column_name_normalization.py +++ b/tests/unit_tests/mcp_service/chart/validation/test_column_name_normalization.py @@ -117,83 +117,6 @@ def test_unknown_column_returns_original( assert result == "unknown_column" -class TestNormalizeXYConfig: - """Test _normalize_xy_config static method.""" - - def test_normalize_x_axis_column( - self, mock_dataset_context: DatasetContext - ) -> None: - """Test that x-axis column name is normalized.""" - config_dict: Dict[str, Any] = { - "chart_type": "xy", - "x": {"name": "orderdate"}, - "y": [{"name": "Sales", "aggregate": "SUM"}], - "kind": "line", - } - - DatasetValidator._normalize_xy_config(config_dict, mock_dataset_context) - - assert config_dict["x"]["name"] == "OrderDate" - - def test_normalize_y_axis_columns( - self, mock_dataset_context: DatasetContext - ) -> None: - """Test that y-axis column names are normalized.""" - config_dict: Dict[str, Any] = { - "chart_type": "xy", - "x": {"name": "OrderDate"}, - "y": [ - {"name": "sales", "aggregate": "SUM"}, - {"name": "QUANTITY_ORDERED", "aggregate": "COUNT"}, - ], - "kind": "bar", - } - - DatasetValidator._normalize_xy_config(config_dict, mock_dataset_context) - - assert config_dict["y"][0]["name"] == "Sales" - assert 
config_dict["y"][1]["name"] == "quantity_ordered" - - def test_normalize_group_by_column( - self, mock_dataset_context: DatasetContext - ) -> None: - """Test that group_by column name is normalized.""" - config_dict: Dict[str, Any] = { - "chart_type": "xy", - "x": {"name": "OrderDate"}, - "y": [{"name": "Sales", "aggregate": "SUM"}], - "kind": "line", - "group_by": [{"name": "productline"}], - } - - DatasetValidator._normalize_xy_config(config_dict, mock_dataset_context) - - assert config_dict["group_by"][0]["name"] == "ProductLine" - - -class TestNormalizeTableConfig: - """Test _normalize_table_config static method.""" - - def test_normalize_table_columns( - self, mock_dataset_context: DatasetContext - ) -> None: - """Test that table column names are normalized.""" - config_dict: Dict[str, Any] = { - "chart_type": "table", - "columns": [ - {"name": "orderdate"}, - {"name": "PRODUCTLINE"}, - {"name": "sales", "aggregate": "SUM"}, - ], - } - - DatasetValidator._normalize_table_config(config_dict, mock_dataset_context) - - assert config_dict["columns"][0]["name"] == "OrderDate" - assert config_dict["columns"][1]["name"] == "ProductLine" - assert config_dict["columns"][2]["name"] == "Sales" - - class TestNormalizeFilters: """Test _normalize_filters static method.""" diff --git a/tests/unit_tests/mcp_service/chart/validation/test_runtime_validator.py b/tests/unit_tests/mcp_service/chart/validation/test_runtime_validator.py index c49677cb99f2..6aed0b112698 100644 --- a/tests/unit_tests/mcp_service/chart/validation/test_runtime_validator.py +++ b/tests/unit_tests/mcp_service/chart/validation/test_runtime_validator.py @@ -58,12 +58,12 @@ def test_validate_runtime_issues_non_blocking_with_format_warnings(self): x_axis=AxisConfig(format="$,.2f"), # Currency format for date - mismatch ) - # Mock the format validator to return warnings + # Mock the plugin runtime dispatcher to return format warnings with patch( "superset.mcp_service.chart.validation.runtime.RuntimeValidator." 
- "_validate_format_compatibility" - ) as mock_format: - mock_format.return_value = [ + "_validate_plugin_runtime" + ) as mock_plugin: + mock_plugin.return_value = [ "Currency format '$,.2f' may not display dates correctly" ] @@ -87,15 +87,14 @@ def test_validate_runtime_issues_non_blocking_with_cardinality_warnings(self): kind="bar", ) - # Mock the cardinality validator to return warnings + # Mock the plugin runtime dispatcher to return cardinality warnings with patch( "superset.mcp_service.chart.validation.runtime.RuntimeValidator." - "_validate_cardinality" - ) as mock_cardinality: - mock_cardinality.return_value = ( - ["High cardinality detected: 10000+ unique values"], - ["Consider using aggregation or filtering"], - ) + "_validate_plugin_runtime" + ) as mock_plugin: + mock_plugin.return_value = [ + "High cardinality detected: 10000+ unique values" + ] is_valid, warnings_metadata = RuntimeValidator.validate_runtime_issues( config, 1 @@ -148,26 +147,21 @@ def test_validate_runtime_issues_non_blocking_with_multiple_warnings(self): x_axis=AxisConfig(format="smart_date"), # Wrong format for user_id ) - # Mock all validators to return warnings + # Mock plugin runtime and chart type validators to return warnings with ( patch( "superset.mcp_service.chart.validation.runtime.RuntimeValidator." - "_validate_format_compatibility" - ) as mock_format, - patch( - "superset.mcp_service.chart.validation.runtime.RuntimeValidator." - "_validate_cardinality" - ) as mock_cardinality, + "_validate_plugin_runtime" + ) as mock_plugin, patch( "superset.mcp_service.chart.validation.runtime.RuntimeValidator." 
"_validate_chart_type" ) as mock_type, ): - mock_format.return_value = ["Format mismatch warning"] - mock_cardinality.return_value = ( - ["High cardinality warning"], - ["Cardinality suggestion"], - ) + mock_plugin.return_value = [ + "Format mismatch warning", + "High cardinality warning", + ] mock_type.return_value = ( ["Chart type warning"], ["Chart type suggestion"], @@ -197,13 +191,13 @@ def test_validate_runtime_issues_logs_warnings(self): with ( patch( "superset.mcp_service.chart.validation.runtime.RuntimeValidator." - "_validate_format_compatibility" - ) as mock_format, + "_validate_plugin_runtime" + ) as mock_plugin, patch( "superset.mcp_service.chart.validation.runtime.logger" ) as mock_logger, ): - mock_format.return_value = ["Test warning message"] + mock_plugin.return_value = ["Test warning message"] is_valid, warnings_metadata = RuntimeValidator.validate_runtime_issues( config, 1 @@ -217,7 +211,7 @@ def test_validate_runtime_issues_logs_warnings(self): assert "warnings" in warnings_metadata def test_validate_table_chart_skips_xy_validations(self): - """Test that table charts skip XY-specific validations.""" + """Test that table charts produce no XY-specific runtime warnings.""" config = TableChartConfig( chart_type="table", columns=[ @@ -226,28 +220,15 @@ def test_validate_table_chart_skips_xy_validations(self): ], ) - # These should not be called for table charts - with ( - patch( - "superset.mcp_service.chart.validation.runtime.RuntimeValidator." - "_validate_format_compatibility" - ) as mock_format, - patch( - "superset.mcp_service.chart.validation.runtime.RuntimeValidator." - "_validate_cardinality" - ) as mock_cardinality, - patch( - "superset.mcp_service.chart.validation.runtime.RuntimeValidator." - "_validate_chart_type" - ) as mock_chart_type, - ): - # Mock chart type validator to return no warnings + # Plugin runtime dispatches to TableChartPlugin which returns no warnings. + # Chart type suggester is also stubbed to return no warnings. 
+ with patch( + "superset.mcp_service.chart.validation.runtime.RuntimeValidator." + "_validate_chart_type" + ) as mock_chart_type: mock_chart_type.return_value = ([], []) is_valid, error = RuntimeValidator.validate_runtime_issues(config, 1) - # Format and cardinality validation should not be called for table charts - mock_format.assert_not_called() - mock_cardinality.assert_not_called() assert is_valid is True assert error is None