Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions superset/mcp_service/chart/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,10 @@ class ColumnRef(BaseModel):
...,
min_length=1,
max_length=255,
pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$",
# No regex pattern: sanitize_name() already blocks XSS/SQL injection;
# many valid column names (digit-prefixed, locale chars, etc.) would
# be rejected by a strict pattern while posing no security risk.
# Use get_dataset_info to find exact column names.
validation_alias=AliasChoices("name", "column_name"),
)
label: str | None = Field(None, max_length=500)
Expand Down Expand Up @@ -743,7 +746,10 @@ class FilterConfig(BaseModel):
...,
min_length=1,
max_length=255,
pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$",
# No regex pattern: sanitize_name() already blocks XSS/SQL injection;
# many valid column names (digit-prefixed, locale chars, etc.) would
# be rejected by a strict pattern while posing no security risk.
# Use get_dataset_info to find exact column names.
validation_alias=AliasChoices("column", "col"),
)
Comment thread
aminghadersohi marked this conversation as resolved.
op: Literal[
Expand Down Expand Up @@ -1082,7 +1088,10 @@ class BigNumberChartConfig(UnknownFieldCheckMixin):
),
min_length=1,
max_length=255,
pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$",
# No regex pattern: sanitize_name() already blocks XSS/SQL injection;
# many valid column names (digit-prefixed, locale chars, etc.) would
# be rejected by a strict pattern while posing no security risk.
# Use get_dataset_info to find exact column names.
)
Comment thread
aminghadersohi marked this conversation as resolved.
time_grain: TimeGrain | None = Field(
None,
Expand Down
80 changes: 80 additions & 0 deletions superset/mcp_service/chart/tool/generate_chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,86 @@ async def generate_chart( # noqa: C901
}
```


Example usage for Pie chart:
```json
{
"dataset_id": 123,
"config": {
"chart_type": "pie",
"dimension": {"name": "product_category"},
Comment thread
aminghadersohi marked this conversation as resolved.
"metric": {"name": "revenue", "aggregate": "SUM"},
"donut": false
}
}
```

Example usage for Big Number (no trendline):
```json
{
"dataset_id": 123,
"config": {
"chart_type": "big_number",
"metric": {"name": "total_sales", "aggregate": "SUM"}
}
}
```

Example usage for Big Number with trendline:
```json
{
"dataset_id": 123,
"config": {
"chart_type": "big_number",
"metric": {"name": "revenue", "aggregate": "SUM"},
"temporal_column": "order_date",
"time_grain": "P1M",
"show_trendline": true
}
}
```

Example usage for Pivot Table:
```json
{
"dataset_id": 123,
"config": {
"chart_type": "pivot_table",
"rows": [{"name": "region"}],
"columns": [{"name": "product_category"}],
"metrics": [{"name": "revenue", "aggregate": "SUM"}]
}
}
```

Example usage for Mixed Timeseries:
```json
{
"dataset_id": 123,
"config": {
"chart_type": "mixed_timeseries",
"x": {"name": "order_date"},
"y": [{"name": "revenue", "aggregate": "SUM"}],
"primary_kind": "line",
"y_secondary": [{"name": "order_count", "aggregate": "COUNT"}],
"secondary_kind": "bar"
}
}
```

Example usage for Handlebars:
```json
{
"dataset_id": 123,
"config": {
"chart_type": "handlebars",
"handlebars_template": "{{#each data}}{{this.name}}{{/each}}",
"groupby": [{"name": "product"}],
"metrics": [{"name": "revenue", "aggregate": "SUM"}]
}
}
```

Example usage for Table chart:
```json
{
Expand Down
61 changes: 50 additions & 11 deletions superset/mcp_service/chart/validation/schema_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,38 @@ def _pre_validate_mixed_timeseries_config(

return True, None

@staticmethod
def _format_single_error(err: Dict[str, Any]) -> tuple[str, str]:
"""Return (detail_message, optional_suggestion) for one pydantic error."""
loc_parts = [str(p) for p in err.get("loc", [])]
loc = " -> ".join(loc_parts)
msg = err.get("msg", "Validation failed")
err_type = err.get("type", "")
ctx = err.get("ctx", {}) or {}
field = loc_parts[-1] if loc_parts else "field"

if err_type == "string_pattern_mismatch":
return (
f"'{field}' value contains disallowed characters. "
"Column names must not contain HTML, script tags, or SQL "
"injection patterns. Use the exact column name from your dataset.",
"Use get_dataset_info to find exact column names",
)
if err_type == "literal_error":
expected = ctx.get("expected", "")
return f"'{field}' has an invalid value. Expected one of: {expected}", ""
if err_type == "missing":
return (
f"Required field '{field}' is missing",
"Check the chart_type examples in the tool description",
)
if err_type == "value_error":
return (
f"{loc}: {msg}",
"Use get_dataset_info to verify column names and types",
)
return f"{loc}: {msg}", ""

@staticmethod
def _enhance_validation_error(
error: PydanticValidationError, request_data: Dict[str, Any]
Expand Down Expand Up @@ -609,22 +641,29 @@ def _enhance_validation_error(
error_code="BIG_NUMBER_VALIDATION_ERROR",
)

# Default enhanced error
# Default enhanced error: build actionable per-field messages
error_details = []
for err in errors[:3]: # Show first 3 errors
loc = " -> ".join(str(location) for location in err.get("loc", []))
msg = err.get("msg", "Validation failed")
error_details.append(f"{loc}: {msg}")
extra_suggestions: list[str] = []
for err in errors[:5]: # Surface up to 5 errors
detail, suggestion = SchemaValidator._format_single_error(err)
error_details.append(detail)
if suggestion:
extra_suggestions.append(suggestion)

return ChartGenerationError(
error_type="validation_error",
message="Chart configuration validation failed",
details="; ".join(error_details),
suggestions=[
"Check that all required fields are present",
"Ensure field types match the schema",
"Use get_dataset_info to verify column names",
"Refer to the API documentation for field requirements",
],
suggestions=list(
dict.fromkeys(
[
"Check that all required fields are present",
"Ensure field types match the schema",
"Use get_dataset_info to verify column names",
"Refer to the API documentation for field requirements",
]
+ extra_suggestions
)
),
error_code="VALIDATION_ERROR",
)
82 changes: 82 additions & 0 deletions tests/unit_tests/mcp_service/chart/test_chart_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,3 +778,85 @@ def test_client_warnings_discarded_even_when_server_also_warns(self) -> None:
assert len(req.sanitization_warnings) == 1
assert "chart_name" in req.sanitization_warnings[0]
assert "injected" not in req.sanitization_warnings[0]


class TestColumnRefNameRelaxedPattern:
"""ColumnRef.name no longer enforces a strict regex pattern.

Many valid database column names were previously rejected:
- Names starting with a digit (e.g. "1Q_revenue")
- Names with locale-specific characters
The field_validator sanitize_name() still blocks XSS and SQL injection.
"""

def test_digit_prefixed_name_accepted(self) -> None:
"""Column names starting with a digit must now be accepted."""
col = ColumnRef(name="1Q_revenue")
assert col.name == "1Q_revenue"

def test_name_with_hyphen_accepted(self) -> None:
col = ColumnRef(name="order-date")
assert col.name == "order-date"

def test_name_with_dot_accepted(self) -> None:
col = ColumnRef(name="schema.column")
assert col.name == "schema.column"

def test_name_with_spaces_accepted(self) -> None:
col = ColumnRef(name="Total Revenue")
assert col.name == "Total Revenue"

def test_xss_attempt_blocked(self) -> None:
"""sanitize_name() still blocks XSS even without the regex."""
with pytest.raises(ValidationError):
ColumnRef(name="<script>alert(1)</script>")

def test_sql_keyword_blocked(self) -> None:
"""check_sql_keywords=True still blocks pure SQL statements."""
with pytest.raises(ValidationError):
ColumnRef(name="1; DROP TABLE users; --")

def test_empty_name_blocked(self) -> None:
with pytest.raises(ValidationError):
ColumnRef(name="")

def test_table_chart_with_digit_prefixed_column(self) -> None:
"""End-to-end: digit-prefixed column passes through GenerateChartRequest."""
req = GenerateChartRequest(
dataset_id=1,
config={
"chart_type": "table",
"columns": [
{"name": "1Q_revenue"},
{"name": "product_name"},
],
},
)
assert req.config.chart_type == "table"

def test_xy_chart_with_hyphenated_column(self) -> None:
req = GenerateChartRequest(
dataset_id=1,
config={
"chart_type": "xy",
"x": {"name": "order-date"},
"y": [{"name": "1Q-revenue", "aggregate": "SUM"}],
},
)
assert req.config.chart_type == "xy"


class TestFilterConfigColumnRelaxedPattern:
"""FilterConfig.column no longer enforces a strict regex pattern."""

def test_digit_prefixed_filter_column_accepted(self) -> None:
from superset.mcp_service.chart.schemas import FilterConfig

f = FilterConfig(column="1Q_flag", op="=", value="active")
assert f.column == "1Q_flag"

def test_hyphenated_filter_column_accepted(self) -> None:
from superset.mcp_service.chart.schemas import FilterConfig

f = FilterConfig(column="order-status", op="=", value="shipped")
assert f.column == "order-status"
Loading