Skip to content
27 changes: 23 additions & 4 deletions superset/mcp_service/chart/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,10 @@ class ColumnRef(BaseModel):
...,
min_length=1,
max_length=255,
pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$",
# No regex pattern: sanitize_name() already blocks XSS/SQL injection;
# many valid column names (digit-prefixed, locale chars, etc.) would
# be rejected by a strict pattern while posing no security risk.
# Use get_dataset_info to find exact column names.
validation_alias=AliasChoices("name", "column_name"),
)
label: str | None = Field(None, max_length=500)
Expand Down Expand Up @@ -743,7 +746,10 @@ class FilterConfig(BaseModel):
...,
min_length=1,
max_length=255,
pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$",
# No regex pattern: sanitize_name() already blocks XSS/SQL injection;
# many valid column names (digit-prefixed, locale chars, etc.) would
# be rejected by a strict pattern while posing no security risk.
# Use get_dataset_info to find exact column names.
validation_alias=AliasChoices("column", "col"),
)
op: Literal[
Expand Down Expand Up @@ -775,7 +781,9 @@ def sanitize_column(cls, v: str) -> str:
"""Sanitize filter column name to prevent injection attacks."""
# sanitize_user_input raises ValueError when allow_empty=False (default)
# so the return value is guaranteed to be a non-None str
return sanitize_user_input(v, "Filter column", max_length=255) # type: ignore[return-value]
return sanitize_user_input( # type: ignore[return-value]
v, "Filter column", max_length=255, check_sql_keywords=True
)

@field_validator("value")
@classmethod
Expand Down Expand Up @@ -1082,8 +1090,19 @@ class BigNumberChartConfig(UnknownFieldCheckMixin):
),
min_length=1,
max_length=255,
pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$",
# No regex pattern — see field description above.
)

@field_validator("temporal_column")
@classmethod
def sanitize_temporal_column(cls, v: str | None) -> str | None:
    """Clean the optional temporal column name before it is stored.

    Args:
        v: The raw ``temporal_column`` value, or ``None`` when not supplied.

    Returns:
        ``None`` unchanged when no column was given; otherwise whatever
        ``sanitize_user_input`` returns for the value, called with a
        255-character limit and SQL keyword checking enabled.
    """
    # An omitted temporal column is passed through without sanitization.
    return (
        None
        if v is None
        else sanitize_user_input(
            v, "Temporal column", max_length=255, check_sql_keywords=True
        )
    )

time_grain: TimeGrain | None = Field(
None,
description=(
Expand Down
83 changes: 82 additions & 1 deletion superset/mcp_service/chart/tool/generate_chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,8 @@ async def generate_chart( # noqa: C901
- Set save_chart=True to permanently save the chart
- LLM clients MUST display returned chart URL to users
- Use numeric dataset ID or UUID (NOT schema.table_name format)
- MUST include chart_type in config (either 'xy' or 'table')
- MUST include chart_type in config (one of: 'xy', 'table', 'pie',
'big_number', 'pivot_table', 'mixed_timeseries', 'handlebars')

IMPORTANT: The 'chart_type' field in the config is a DISCRIMINATOR that determines
which chart configuration schema to use. It MUST be included and MUST match the
Expand All @@ -200,6 +201,86 @@ async def generate_chart( # noqa: C901
}
```


Example usage for Pie chart:
```json
{
"dataset_id": 123,
"config": {
"chart_type": "pie",
"dimension": {"name": "product_category"},
"metric": {"name": "revenue", "aggregate": "SUM"},
"donut": false
}
}
```

Example usage for Big Number (no trendline):
```json
{
"dataset_id": 123,
"config": {
"chart_type": "big_number",
"metric": {"name": "total_sales", "aggregate": "SUM"}
}
}
```

Example usage for Big Number with trendline:
```json
{
"dataset_id": 123,
"config": {
"chart_type": "big_number",
"metric": {"name": "revenue", "aggregate": "SUM"},
"temporal_column": "order_date",
"time_grain": "P1M",
"show_trendline": true
}
}
```

Example usage for Pivot Table:
```json
{
"dataset_id": 123,
"config": {
"chart_type": "pivot_table",
"rows": [{"name": "region"}],
"columns": [{"name": "product_category"}],
"metrics": [{"name": "revenue", "aggregate": "SUM"}]
}
}
```

Example usage for Mixed Timeseries:
```json
{
"dataset_id": 123,
"config": {
"chart_type": "mixed_timeseries",
"x": {"name": "order_date"},
"y": [{"name": "revenue", "aggregate": "SUM"}],
"primary_kind": "line",
"y_secondary": [{"name": "order_count", "aggregate": "COUNT"}],
"secondary_kind": "bar"
}
}
```

Example usage for Handlebars:
```json
{
"dataset_id": 123,
"config": {
"chart_type": "handlebars",
"handlebars_template": "{{#each data}}{{this.name}}{{/each}}",
"groupby": [{"name": "product"}],
"metrics": [{"name": "revenue", "aggregate": "SUM"}]
}
}
```

Example usage for Table chart:
```json
{
Expand Down
61 changes: 50 additions & 11 deletions superset/mcp_service/chart/validation/schema_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,38 @@ def _pre_validate_mixed_timeseries_config(

return True, None

@staticmethod
def _format_single_error(err: Dict[str, Any]) -> tuple[str, str]:
    """Return ``(detail_message, suggestion)`` for one pydantic error dict.

    Args:
        err: A single error entry with optional ``loc`` (sequence of path
            segments), ``msg`` and ``type`` keys.

    Returns:
        A 2-tuple: the human-readable detail message, and a suggestion
        string that is empty (``""``) when there is no specific advice.

    The short field label used in some messages is taken from the end of
    ``loc``. When ``loc`` ends in one or more integer positions (an error
    on a list element), the label is the nearest named segment followed by
    the index suffix, e.g. ``y[0]``, instead of a bare ``0`` that would not
    tell the caller which field is wrong.
    """
    raw_loc = tuple(err.get("loc") or ())
    loc_parts = [str(p) for p in raw_loc]
    loc = " -> ".join(loc_parts)
    msg = err.get("msg", "Validation failed")
    err_type = err.get("type", "")

    # Skip trailing integer segments to find the last *named* segment.
    named_end = len(raw_loc)
    while named_end > 0 and isinstance(raw_loc[named_end - 1], int):
        named_end -= 1

    if named_end:
        index_suffix = "".join(f"[{i}]" for i in raw_loc[named_end:])
        field = f"{loc_parts[named_end - 1]}{index_suffix}"
    elif loc_parts:
        # Location made only of positions: nothing better to show.
        field = loc_parts[-1]
    else:
        field = "field"

    if err_type == "string_pattern_mismatch":
        return (
            f"'{field}' value contains disallowed characters. "
            "Column names must not contain HTML, script tags, or SQL "
            "injection patterns. Use the exact column name from your dataset.",
            "Use get_dataset_info to find exact column names",
        )
    if err_type == "literal_error":
        # Preserve the pydantic message ("Input should be ...") which is
        # already human-readable; just prefix with the field name for context.
        return f"'{field}': {msg}", ""
    if err_type == "missing":
        return (
            f"Required field '{field}' is missing",
            "Check the chart_type examples in the tool description",
        )
    if err_type == "value_error":
        return (
            f"{loc}: {msg}",
            "Use get_dataset_info to verify column names and types",
        )
    return f"{loc}: {msg}", ""

@staticmethod
def _enhance_validation_error(
error: PydanticValidationError, request_data: Dict[str, Any]
Expand Down Expand Up @@ -609,22 +641,29 @@ def _enhance_validation_error(
error_code="BIG_NUMBER_VALIDATION_ERROR",
)

# Default enhanced error
# Default enhanced error: build actionable per-field messages
error_details = []
for err in errors[:3]: # Show first 3 errors
loc = " -> ".join(str(location) for location in err.get("loc", []))
msg = err.get("msg", "Validation failed")
error_details.append(f"{loc}: {msg}")
extra_suggestions: list[str] = []
for err in errors[:5]: # Surface up to 5 errors
detail, suggestion = SchemaValidator._format_single_error(err)
error_details.append(detail)
if suggestion:
extra_suggestions.append(suggestion)

return ChartGenerationError(
error_type="validation_error",
message="Chart configuration validation failed",
details="; ".join(error_details),
suggestions=[
"Check that all required fields are present",
"Ensure field types match the schema",
"Use get_dataset_info to verify column names",
"Refer to the API documentation for field requirements",
],
suggestions=list(
dict.fromkeys(
[
"Check that all required fields are present",
"Ensure field types match the schema",
"Use get_dataset_info to verify column names",
"Refer to the API documentation for field requirements",
]
+ extra_suggestions
)
),
error_code="VALIDATION_ERROR",
)
101 changes: 101 additions & 0 deletions tests/unit_tests/mcp_service/chart/test_chart_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,3 +778,104 @@ def test_client_warnings_discarded_even_when_server_also_warns(self) -> None:
assert len(req.sanitization_warnings) == 1
assert "chart_name" in req.sanitization_warnings[0]
assert "injected" not in req.sanitization_warnings[0]


class TestColumnRefNameRelaxedPattern:
    """Pin the relaxed acceptance rules for ``ColumnRef.name``.

    The strict regex on the field was removed because it rejected
    legitimate database column names, e.g. digit-prefixed ones such as
    "1Q_revenue" or names with locale-specific characters. These tests
    check that such names now validate, and that hostile input is still
    rejected or neutralized by the model's own validation.
    """

    def test_digit_prefixed_name_accepted(self) -> None:
        """A leading digit is allowed in a column name."""
        ref = ColumnRef(name="1Q_revenue")
        assert ref.name == "1Q_revenue"

    def test_name_with_hyphen_accepted(self) -> None:
        """A hyphen inside the name is preserved as-is."""
        ref = ColumnRef(name="order-date")
        assert ref.name == "order-date"

    def test_name_with_dot_accepted(self) -> None:
        """A dotted name is preserved as-is."""
        ref = ColumnRef(name="schema.column")
        assert ref.name == "schema.column"

    def test_name_with_spaces_accepted(self) -> None:
        """Inner whitespace is preserved as-is."""
        ref = ColumnRef(name="Total Revenue")
        assert ref.name == "Total Revenue"

    def test_script_tag_neutralized(self) -> None:
        """A script-tag payload never survives as a raw ``<script>`` tag.

        Two outcomes are accepted: the model rejects the value with a
        ValidationError, or it stores a cleaned value that no longer
        contains the opening tag.
        """
        try:
            ref = ColumnRef(name="<script>alert(1)</script>")
        except ValidationError:
            # Rejected outright; nothing further to check.
            return
        assert "<script>" not in ref.name

    def test_event_handler_injection_blocked(self) -> None:
        """An inline event-handler pattern (on...=) must be rejected."""
        with pytest.raises(ValidationError):
            ColumnRef(name="col onclick=alert(1)")

    def test_sql_keyword_blocked(self) -> None:
        """A stacked SQL statement in the name must be rejected."""
        with pytest.raises(ValidationError):
            ColumnRef(name="1; DROP TABLE users; --")

    def test_empty_name_blocked(self) -> None:
        """An empty name must be rejected."""
        with pytest.raises(ValidationError):
            ColumnRef(name="")

    def test_table_chart_with_digit_prefixed_column(self) -> None:
        """End-to-end: a digit-prefixed column validates in a table chart."""
        table_config = {
            "chart_type": "table",
            "columns": [
                {"name": "1Q_revenue"},
                {"name": "product_name"},
            ],
        }
        request = GenerateChartRequest(dataset_id=1, config=table_config)
        assert request.config.chart_type == "table"

    def test_xy_chart_with_hyphenated_column(self) -> None:
        """End-to-end: hyphenated columns validate in an xy chart."""
        xy_config = {
            "chart_type": "xy",
            "x": {"name": "order-date"},
            "y": [{"name": "1Q-revenue", "aggregate": "SUM"}],
        }
        request = GenerateChartRequest(dataset_id=1, config=xy_config)
        assert request.config.chart_type == "xy"


class TestFilterConfigColumnRelaxedPattern:
    """Pin the relaxed acceptance rules for ``FilterConfig.column``.

    The field no longer carries a strict regex, so digit-prefixed and
    hyphenated column names must validate, while an SQL-injection payload
    must still be rejected.
    """

    def test_digit_prefixed_filter_column_accepted(self) -> None:
        """A filter column starting with a digit is stored unchanged."""
        from superset.mcp_service.chart.schemas import FilterConfig

        filter_cfg = FilterConfig(column="1Q_flag", op="=", value="active")
        assert filter_cfg.column == "1Q_flag"

    def test_hyphenated_filter_column_accepted(self) -> None:
        """A hyphenated filter column is stored unchanged."""
        from superset.mcp_service.chart.schemas import FilterConfig

        filter_cfg = FilterConfig(column="order-status", op="=", value="shipped")
        assert filter_cfg.column == "order-status"

    def test_sql_injection_in_filter_column_blocked(self) -> None:
        """A stacked SQL statement in the filter column must be rejected."""
        from superset.mcp_service.chart.schemas import FilterConfig

        with pytest.raises(ValidationError):
            FilterConfig(column="col; DROP TABLE users; --", op="=", value="x")
Loading