diff --git a/superset/mcp_service/chart/schemas.py b/superset/mcp_service/chart/schemas.py index 1fdb4f43ab6d..045f13de99a7 100644 --- a/superset/mcp_service/chart/schemas.py +++ b/superset/mcp_service/chart/schemas.py @@ -22,7 +22,7 @@ from __future__ import annotations import difflib -from datetime import datetime, timezone +from datetime import datetime from typing import Annotated, Any, Dict, List, Literal, Protocol import humanize @@ -50,7 +50,7 @@ OwnedByMeMixin, QueryCacheControl, ) -from superset.mcp_service.common.error_schemas import ChartGenerationError +from superset.mcp_service.common.error_schemas import ChartGenerationError, MCPBaseError from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE from superset.mcp_service.privacy import filter_user_directory_fields from superset.mcp_service.system.schemas import ( @@ -183,16 +183,8 @@ def _filter_fields_by_context(self, serializer: Any, info: Any) -> Dict[str, Any return data -class ChartError(BaseModel): - error: str = Field(..., description="Error message") - error_type: str = Field(..., description="Type of error") - timestamp: datetime = Field( - default_factory=lambda: datetime.now(timezone.utc), - description="Error timestamp", - ) - model_config = ConfigDict(ser_json_timedelta="iso8601") - - @field_validator("error") +class ChartError(MCPBaseError): + @field_validator("message") @classmethod def sanitize_error_for_llm_context(cls, value: str) -> str: """Wrap error text before it is exposed to LLM context.""" diff --git a/superset/mcp_service/common/error_schemas.py b/superset/mcp_service/common/error_schemas.py index ec0274cc0be8..f3b1b5f7c6ff 100644 --- a/superset/mcp_service/common/error_schemas.py +++ b/superset/mcp_service/common/error_schemas.py @@ -19,9 +19,53 @@ Enhanced error schemas for MCP chart generation with contextual information """ +from __future__ import annotations + +from datetime import datetime, timezone from typing import Any, Dict, List -from pydantic import BaseModel, Field +from pydantic import BaseModel, computed_field, ConfigDict, Field, model_validator + + +class MCPBaseError(BaseModel): + """Base error shape for all MCP tool responses. + + Provides a consistent set of fields that every error response includes, + allowing LLM clients to handle errors uniformly regardless of which tool + produced them. + """ + + error_type: str = Field( + ..., description="Type of error (validation, execution, etc.)" + ) + message: str = Field(..., description="Human-readable error message") + timestamp: datetime = Field( + default_factory=lambda: datetime.now(timezone.utc), + description="Error timestamp", + ) + details: str | None = Field(None, description="Detailed error explanation") + suggestions: list[str] = Field( + default_factory=list, description="Actionable suggestions to fix the error" + ) + error_code: str | None = Field( + None, description="Unique error code for support reference" + ) + + model_config = ConfigDict(ser_json_timedelta="iso8601") + + @model_validator(mode="before") + @classmethod + def _compat_error_to_message(cls, data: Any) -> Any: + """Allow construction with error= kwarg for backward compatibility.""" + if isinstance(data, dict) and "error" in data and "message" not in data: + data["message"] = data.pop("error") + return data + + @computed_field # type: ignore[prop-decorator] + @property + def error(self) -> str: + """Backward-compatible field: mirrors 'message' in serialized output.""" + return self.message class ColumnSuggestion(BaseModel): @@ -67,13 +111,9 @@ class DatasetContext(BaseModel): ) -class ChartGenerationError(BaseModel): +class ChartGenerationError(MCPBaseError): """Enhanced error response for chart generation failures""" - error_type: str = Field( - ..., description="Type of error (validation, execution, etc.)" - ) - message: str = Field(..., description="High-level error message") details: str = Field(..., description="Detailed error explanation") validation_errors: List[ValidationError] = Field( default_factory=list, description="Specific field validation errors" @@ -84,15 +124,9 @@ class ChartGenerationError(BaseModel): query_info: Dict[str, Any] | None = Field( None, description="Query execution details" ) - suggestions: List[str] = Field( - default_factory=list, description="Actionable suggestions to fix the error" - ) help_url: str | None = Field( None, description="URL to documentation for this error type" ) - error_code: str | None = Field( - None, description="Unique error code for support reference" - ) class ChartGenerationResponse(BaseModel):