diff --git a/fastapi_mcp/openapi/utils.py b/fastapi_mcp/openapi/utils.py index 1821d57..5493288 100644 --- a/fastapi_mcp/openapi/utils.py +++ b/fastapi_mcp/openapi/utils.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, FrozenSet def get_single_param_type_from_schema(param_schema: Dict[str, Any]) -> str: @@ -16,17 +16,25 @@ def get_single_param_type_from_schema(param_schema: Dict[str, Any]) -> str: return param_schema.get("type", "string") -def resolve_schema_references(schema_part: Dict[str, Any], reference_schema: Dict[str, Any]) -> Dict[str, Any]: +def resolve_schema_references( + schema_part: Dict[str, Any], + reference_schema: Dict[str, Any], + _visited: FrozenSet[str] | None = None, +) -> Dict[str, Any]: """ Resolve schema references in OpenAPI schemas. Args: schema_part: The part of the schema being processed that may contain references reference_schema: The complete schema used to resolve references from + _visited: Frozenset of $ref paths currently being resolved; used to detect cycles Returns: The schema with references resolved """ + if _visited is None: + _visited = frozenset() + # Make a copy to avoid modifying the input schema schema_part = schema_part.copy() @@ -35,6 +43,9 @@ def resolve_schema_references(schema_part: Dict[str, Any], reference_schema: Dic ref_path = schema_part["$ref"] # Standard OpenAPI references are in the format "#/components/schemas/ModelName" if ref_path.startswith("#/components/schemas/"): + # Circular reference detected — leave $ref in place to avoid infinite recursion + if ref_path in _visited: + return schema_part model_name = ref_path.split("/")[-1] if "components" in reference_schema and "schemas" in reference_schema["components"]: if model_name in reference_schema["components"]["schemas"]: @@ -43,15 +54,18 @@ def resolve_schema_references(schema_part: Dict[str, Any], reference_schema: Dic # Remove the $ref key and merge with the original schema schema_part.pop("$ref") schema_part.update(ref_schema) + # Mark this ref as being resolved for the current expansion path + _visited = _visited | {ref_path} # Recursively resolve references in all dictionary values for key, value in schema_part.items(): if isinstance(value, dict): - schema_part[key] = resolve_schema_references(value, reference_schema) + schema_part[key] = resolve_schema_references(value, reference_schema, _visited) elif isinstance(value, list): # Only process list items that are dictionaries since only they can contain refs schema_part[key] = [ - resolve_schema_references(item, reference_schema) if isinstance(item, dict) else item for item in value + resolve_schema_references(item, reference_schema, _visited) if isinstance(item, dict) else item + for item in value ] return schema_part diff --git a/tests/test_openapi_conversion.py b/tests/test_openapi_conversion.py index aefe643..a5310f4 100644 --- a/tests/test_openapi_conversion.py +++ b/tests/test_openapi_conversion.py @@ -1,6 +1,8 @@ from fastapi import FastAPI from fastapi.openapi.utils import get_openapi import mcp.types as types +from pydantic import BaseModel +from typing import List, Literal, Union from fastapi_mcp.openapi.convert import convert_openapi_to_mcp_tools from fastapi_mcp.openapi.utils import ( @@ -422,3 +424,51 @@ def test_body_params_edge_cases(complex_fastapi_app: FastAPI): if "items" in properties: item_props = properties["items"]["items"]["properties"] assert "total" in item_props + + +def test_self_referencing_schema_does_not_raise_recursion_error(): + """ + Regression test for https://github.com/tadata-org/fastapi_mcp/issues/287. + resolve_schema_references() must not raise RecursionError when the OpenAPI + schema contains a self-referencing (recursive) Pydantic model. + """ + + class ComparisonFilter(BaseModel): + type: Literal["eq", "ne", "gt", "lt"] + key: str + value: str + + class CompoundFilter(BaseModel): + type: Literal["and", "or"] + filters: List[Union[ComparisonFilter, "CompoundFilter"]] + + CompoundFilter.model_rebuild() + + class SearchRequest(BaseModel): + query: str + filters: CompoundFilter | None = None + + app = FastAPI() + + @app.post("/search") + async def search(request: SearchRequest): + return {"status": "ok"} + + openapi_schema = get_openapi( + title=app.title, + version=app.version, + openapi_version=app.openapi_version, + description=app.description, + routes=app.routes, + ) + + # Must not raise RecursionError + tools, operation_map = convert_openapi_to_mcp_tools(openapi_schema) + + assert len(tools) == 1 + assert "search_search_post" in operation_map + + tool = tools[0] + assert isinstance(tool, types.Tool) + assert tool.inputSchema is not None + assert "query" in tool.inputSchema["properties"]