diff --git a/superset/mcp_service/chart/schemas.py b/superset/mcp_service/chart/schemas.py index 1fdb4f43ab6d..0fb1ca1c1e1b 100644 --- a/superset/mcp_service/chart/schemas.py +++ b/superset/mcp_service/chart/schemas.py @@ -669,7 +669,10 @@ class ColumnRef(BaseModel): ..., min_length=1, max_length=255, - pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$", + # No regex pattern: sanitize_name() already blocks XSS/SQL injection; + # many valid column names (digit-prefixed, locale chars, etc.) would + # be rejected by a strict pattern while posing no security risk. + # Use get_dataset_info to find exact column names. validation_alias=AliasChoices("name", "column_name"), ) label: str | None = Field(None, max_length=500) @@ -743,7 +746,10 @@ class FilterConfig(BaseModel): ..., min_length=1, max_length=255, - pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$", + # No regex pattern: sanitize_column() already blocks XSS/SQL injection; + # many valid column names (digit-prefixed, locale chars, etc.) would + # be rejected by a strict pattern while posing no security risk. + # Use get_dataset_info to find exact column names. validation_alias=AliasChoices("column", "col"), ) op: Literal[ @@ -775,7 +781,9 @@ def sanitize_column(cls, v: str) -> str: """Sanitize filter column name to prevent injection attacks.""" # sanitize_user_input raises ValueError when allow_empty=False (default) # so the return value is guaranteed to be a non-None str - return sanitize_user_input(v, "Filter column", max_length=255) # type: ignore[return-value] + return sanitize_user_input( # type: ignore[return-value] + v, "Filter column", max_length=255, check_sql_keywords=True + ) @field_validator("value") @classmethod @@ -1082,8 +1090,19 @@ class BigNumberChartConfig(UnknownFieldCheckMixin): ), min_length=1, max_length=255, - pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$", + # No regex pattern — see field description above. ) + + @field_validator("temporal_column") + @classmethod + def sanitize_temporal_column(cls, v: str | None) -> str | None: + """Sanitize temporal column name to prevent XSS and SQL injection.""" + if v is None: + return None + return sanitize_user_input( + v, "Temporal column", max_length=255, check_sql_keywords=True + ) + time_grain: TimeGrain | None = Field( None, description=( diff --git a/superset/mcp_service/chart/tool/generate_chart.py b/superset/mcp_service/chart/tool/generate_chart.py index 646ac4d4c2a7..697698bbda03 100644 --- a/superset/mcp_service/chart/tool/generate_chart.py +++ b/superset/mcp_service/chart/tool/generate_chart.py @@ -175,7 +175,8 @@ async def generate_chart( # noqa: C901 - Set save_chart=True to permanently save the chart - LLM clients MUST display returned chart URL to users - Use numeric dataset ID or UUID (NOT schema.table_name format) - - MUST include chart_type in config (either 'xy' or 'table') + - MUST include chart_type in config (one of: 'xy', 'table', 'pie', + 'big_number', 'pivot_table', 'mixed_timeseries', 'handlebars') IMPORTANT: The 'chart_type' field in the config is a DISCRIMINATOR that determines which chart configuration schema to use. It MUST be included and MUST match the @@ -200,6 +201,86 @@ async def generate_chart( # noqa: C901 } ``` + + Example usage for Pie chart: + ```json + { + "dataset_id": 123, + "config": { + "chart_type": "pie", + "dimension": {"name": "product_category"}, + "metric": {"name": "revenue", "aggregate": "SUM"}, + "donut": false + } + } + ``` + + Example usage for Big Number (no trendline): + ```json + { + "dataset_id": 123, + "config": { + "chart_type": "big_number", + "metric": {"name": "total_sales", "aggregate": "SUM"} + } + } + ``` + + Example usage for Big Number with trendline: + ```json + { + "dataset_id": 123, + "config": { + "chart_type": "big_number", + "metric": {"name": "revenue", "aggregate": "SUM"}, + "temporal_column": "order_date", + "time_grain": "P1M", + "show_trendline": true + } + } + ``` + + Example usage for Pivot Table: + ```json + { + "dataset_id": 123, + "config": { + "chart_type": "pivot_table", + "rows": [{"name": "region"}], + "columns": [{"name": "product_category"}], + "metrics": [{"name": "revenue", "aggregate": "SUM"}] + } + } + ``` + + Example usage for Mixed Timeseries: + ```json + { + "dataset_id": 123, + "config": { + "chart_type": "mixed_timeseries", + "x": {"name": "order_date"}, + "y": [{"name": "revenue", "aggregate": "SUM"}], + "primary_kind": "line", + "y_secondary": [{"name": "order_count", "aggregate": "COUNT"}], + "secondary_kind": "bar" + } + } + ``` + + Example usage for Handlebars: + ```json + { + "dataset_id": 123, + "config": { + "chart_type": "handlebars", + "handlebars_template": "{{#each data}}{{this.name}}{{/each}}", + "groupby": [{"name": "product"}], + "metrics": [{"name": "revenue", "aggregate": "SUM"}] + } + } + ``` + Example usage for Table chart: ```json { diff --git a/superset/mcp_service/chart/validation/schema_validator.py b/superset/mcp_service/chart/validation/schema_validator.py index 7cae450ff599..8ae2b2aad86a 100644 --- a/superset/mcp_service/chart/validation/schema_validator.py +++ b/superset/mcp_service/chart/validation/schema_validator.py @@ -525,6 +525,37 @@ def _pre_validate_mixed_timeseries_config( return True, None + @staticmethod + def _format_single_error(err: Dict[str, Any]) -> tuple[str, str]: + """Return (detail_message, optional_suggestion) for one pydantic error.""" + loc_parts = [str(p) for p in err.get("loc", [])] + loc = " -> ".join(loc_parts) + msg = err.get("msg", "Validation failed") + err_type = err.get("type", "") + field = loc_parts[-1] if loc_parts else "field" + + if err_type == "string_pattern_mismatch": + return ( + f"'{field}' value does not match the required pattern. " + "Use the exact value from your dataset.", + "Use get_dataset_info to find exact column names and values", + ) + if err_type == "literal_error": + # Preserve the pydantic message ("Input should be ...") which is + # already human-readable; just prefix with the field name for context. + return f"'{field}': {msg}", "" + if err_type == "missing": + return ( + f"Required field '{field}' is missing", + "Check the chart_type examples in the tool description", + ) + if err_type == "value_error": + return ( + f"{loc}: {msg}", + "Use get_dataset_info to verify column names and types", + ) + return f"{loc}: {msg}", "" + @staticmethod def _enhance_validation_error( error: PydanticValidationError, request_data: Dict[str, Any] @@ -609,22 +640,29 @@ def _enhance_validation_error( error_code="BIG_NUMBER_VALIDATION_ERROR", ) - # Default enhanced error + # Default enhanced error: build actionable per-field messages error_details = [] - for err in errors[:3]: # Show first 3 errors - loc = " -> ".join(str(location) for location in err.get("loc", [])) - msg = err.get("msg", "Validation failed") - error_details.append(f"{loc}: {msg}") + extra_suggestions: list[str] = [] + for err in errors[:5]: # Surface up to 5 errors + detail, suggestion = SchemaValidator._format_single_error(err) + error_details.append(detail) + if suggestion: + extra_suggestions.append(suggestion) return ChartGenerationError( error_type="validation_error", message="Chart configuration validation failed", details="; ".join(error_details), - suggestions=[ - "Check that all required fields are present", - "Ensure field types match the schema", - "Use get_dataset_info to verify column names", - "Refer to the API documentation for field requirements", - ], + suggestions=list( + dict.fromkeys( + [ + "Check that all required fields are present", + "Ensure field types match the schema", + "Use get_dataset_info to verify column names", + "Refer to the API documentation for field requirements", + ] + + extra_suggestions + ) + ), error_code="VALIDATION_ERROR", ) diff --git a/tests/unit_tests/mcp_service/chart/test_chart_schemas.py b/tests/unit_tests/mcp_service/chart/test_chart_schemas.py index 1da9d1bc3a74..a401a6f47eb8 100644 --- a/tests/unit_tests/mcp_service/chart/test_chart_schemas.py +++ b/tests/unit_tests/mcp_service/chart/test_chart_schemas.py @@ -778,3 +778,104 @@ def test_client_warnings_discarded_even_when_server_also_warns(self) -> None: assert len(req.sanitization_warnings) == 1 assert "chart_name" in req.sanitization_warnings[0] assert "injected" not in req.sanitization_warnings[0] + + +class TestColumnRefNameRelaxedPattern: + """ColumnRef.name no longer enforces a strict regex pattern. + + Many valid database column names were previously rejected: + - Names starting with a digit (e.g. "1Q_revenue") + - Names with locale-specific characters + The field_validator sanitize_name() still blocks XSS and SQL injection. + """ + + def test_digit_prefixed_name_accepted(self) -> None: + """Column names starting with a digit must now be accepted.""" + col = ColumnRef(name="1Q_revenue") + assert col.name == "1Q_revenue" + + def test_name_with_hyphen_accepted(self) -> None: + col = ColumnRef(name="order-date") + assert col.name == "order-date" + + def test_name_with_dot_accepted(self) -> None: + col = ColumnRef(name="schema.column") + assert col.name == "schema.column" + + def test_name_with_spaces_accepted(self) -> None: + col = ColumnRef(name="Total Revenue") + assert col.name == "Total Revenue" + + def test_script_tag_neutralized(self) -> None: + """sanitize_name() neutralizes script-tag XSS via nh3. + Depending on nh3 version, nh3 either strips the entire script element + including its content (leaving empty → ValidationError) or strips only + the tag delimiters (leaving 'alert(1)'). Either way, no raw ") + assert "