From 9db3583a9dbbd171025796618e07724a9e89cc8b Mon Sep 17 00:00:00 2001 From: Bhargav <7ksb24@gmail.com> Date: Fri, 29 May 2026 12:09:58 +0530 Subject: [PATCH] fix: trigger lifespan events when using mount_http() (closes #256) When FastAPI(lifespan=...) is used, Starlette replaces the router's lifespan_context with the user's function and completely bypasses any on_startup / on_shutdown handlers registered via add_event_handler(). This caused the ASGI lifespan protocol to appear unsupported to uvicorn, silently skipping DB connections, caches, and background task setup. Fix: wrap self.fastapi.router.lifespan_context instead of using add_event_handler(). The wrapper runs the user's startup first (so their resources are available to MCP tools), then starts the StreamableHTTP session manager, and shuts them down in reverse order. Also add explicit startup() / shutdown() methods to FastApiHttpSessionManager so callers can integrate the session manager into their own lifecycle management if needed. Tests: 6 new tests in tests/test_lifespan.py covering - lifespan startup/shutdown fires with mount_http() - correct startup-before-shutdown ordering - user resources available during HTTP requests - MCP /mcp endpoint reachable post-startup - apps without lifespan continue working - session manager starts after user startup --- fastapi_mcp/server.py | 29 ++++ fastapi_mcp/transport/http.py | 42 +++-- tests/test_lifespan.py | 288 ++++++++++++++++++++++++++++++++++ 3 files changed, 347 insertions(+), 12 deletions(-) create mode 100644 tests/test_lifespan.py diff --git a/fastapi_mcp/server.py b/fastapi_mcp/server.py index bb75106..0409f67 100644 --- a/fastapi_mcp/server.py +++ b/fastapi_mcp/server.py @@ -1,5 +1,6 @@ import json import httpx +from contextlib import asynccontextmanager from typing import Dict, Optional, Any, List, Union, Literal, Sequence from typing_extensions import Annotated, Doc @@ -355,6 +356,34 @@ def mount_http( self._setup_auth() self._http_transport = http_transport # Store reference + # Wrap the main FastAPI app's lifespan to include the MCP session + # manager lifecycle. We cannot use add_event_handler() here because + # when the user has already passed lifespan= to FastAPI(), Starlette + # stores that function as `lifespan_context` and completely bypasses + # on_startup / on_shutdown handlers. Wrapping lifespan_context works + # for both cases: + # • No custom lifespan: wraps Starlette's _DefaultLifespan, which + # still calls on_startup / on_shutdown handlers as expected. + # • Custom lifespan: wraps the user's context manager so the + # session manager starts after user resources are initialised + # and stops before they are torn down. + _main_router = self.fastapi.router + _original_lifespan = _main_router.lifespan_context + _transport = http_transport # capture for closure + + @asynccontextmanager + async def _mcp_lifespan(app: Any): + # Run user startup first so their resources (DB, cache …) are + # available when the MCP session manager begins accepting clients. + async with _original_lifespan(app) as state: + await _transport.startup() + try: + yield state + finally: + await _transport.shutdown() + + _main_router.lifespan_context = _mcp_lifespan + # HACK: If we got a router and not a FastAPI instance, we need to re-include the router so that # FastAPI will pick up the new routes we added. The problem with this approach is that we assume # that the router is a sub-router of self.fastapi, which may not always be the case. diff --git a/fastapi_mcp/transport/http.py b/fastapi_mcp/transport/http.py index 47af6f0..6fe9759 100644 --- a/fastapi_mcp/transport/http.py +++ b/fastapi_mcp/transport/http.py @@ -30,12 +30,39 @@ def __init__( self._manager_started = False self._startup_lock = asyncio.Lock() + async def startup(self) -> None: + """ + Start the session manager. + + Call this from a FastAPI ``on_startup`` event handler (or lifespan) so + that the StreamableHTTP session manager is ready before the first + request arrives and the ASGI lifespan protocol is handled correctly. + """ + await self._ensure_session_manager_started() + + async def shutdown(self) -> None: + """ + Stop the session manager. + + Call this from a FastAPI ``on_shutdown`` event handler (or lifespan) + so that in-flight sessions are drained and the background task is + cancelled cleanly. + """ + if self._manager_task and not self._manager_task.done(): + self._manager_task.cancel() + try: + await self._manager_task + except asyncio.CancelledError: + pass + self._manager_started = False + self._session_manager = None + async def _ensure_session_manager_started(self) -> None: """ Ensure the session manager is started. - This is called lazily on the first request to start the session manager - if it hasn't been started yet. + Called eagerly from ``startup()`` and also as a lazy fallback on the + first request in case the startup hook was not registered. """ if self._manager_started: return @@ -84,7 +111,7 @@ async def handle_fastapi_request(self, request: Request) -> Response: This converts FastAPI's Request/Response to ASGI scope/receive/send and then converts the result back to a FastAPI Response. """ - # Ensure session manager is started + # Ensure session manager is started (lazy fallback if startup hook wasn't used) await self._ensure_session_manager_started() if not self._session_manager: @@ -125,12 +152,3 @@ async def send_callback(message): logger.exception("Error in StreamableHTTPSessionManager") raise HTTPException(status_code=500, detail="Internal server error") - async def shutdown(self) -> None: - """Clean up the session manager and background task.""" - if self._manager_task and not self._manager_task.done(): - self._manager_task.cancel() - try: - await self._manager_task - except asyncio.CancelledError: - pass - self._manager_started = False diff --git a/tests/test_lifespan.py b/tests/test_lifespan.py new file mode 100644 index 0000000..8a8545e --- /dev/null +++ b/tests/test_lifespan.py @@ -0,0 +1,288 @@ +""" +Tests for lifespan context manager support with mount_http(). + +Regression tests for: + https://github.com/tadata-org/fastapi_mcp/issues/256 + +When mount_http() is used, the FastAPI app's lifespan startup/shutdown events +must fire correctly and not be suppressed by the ASGI lifespan protocol. +""" + +import asyncio +from contextlib import asynccontextmanager + +import pytest +import httpx +from fastapi import FastAPI +from fastapi_mcp import FastApiMCP +import mcp.types as types + + +# --------------------------------------------------------------------------- +# ASGI lifespan helper +# --------------------------------------------------------------------------- + +@asynccontextmanager +async def run_asgi_lifespan(app: FastAPI): + """ + Manually drive the ASGI lifespan protocol for an app in tests. + + httpx.ASGITransport only sends 'http' scoped requests — it never sends + the 'lifespan' scope. This helper fills that gap by manually sending + lifespan.startup / lifespan.shutdown events to the ASGI app, which is + exactly what uvicorn does in production. + """ + receive_q: asyncio.Queue = asyncio.Queue() + send_q: asyncio.Queue = asyncio.Queue() + + async def receive(): + return await receive_q.get() + + async def send(message): + await send_q.put(message) + + scope = {"type": "lifespan", "asgi": {"version": "3.0"}, "state": {}} + lifespan_task = asyncio.create_task(app(scope, receive, send)) + + # Trigger startup + await receive_q.put({"type": "lifespan.startup"}) + startup_msg = await asyncio.wait_for(send_q.get(), timeout=5.0) + assert startup_msg["type"] == "lifespan.startup.complete", ( + f"Expected lifespan.startup.complete, got: {startup_msg}" + ) + + try: + yield + finally: + # Trigger shutdown + await receive_q.put({"type": "lifespan.shutdown"}) + try: + await asyncio.wait_for(send_q.get(), timeout=3.0) + except asyncio.TimeoutError: + pass + + # Clean up the background lifespan task + if not lifespan_task.done(): + lifespan_task.cancel() + try: + await asyncio.wait_for(lifespan_task, timeout=2.0) + except (asyncio.CancelledError, asyncio.TimeoutError, Exception): + pass + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_lifespan_app(): + """Return (app, events) where events records startup/shutdown calls.""" + events: list[str] = [] + + @asynccontextmanager + async def lifespan(app: FastAPI): + events.append("startup") + yield + events.append("shutdown") + + app = FastAPI(title="Lifespan Test App", lifespan=lifespan) + + @app.get("/health", operation_id="health_check") + async def health(): + return {"status": "healthy"} + + return app, events + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.anyio +async def test_lifespan_startup_fires_with_mount_http(): + """on_startup events must run when mount_http() is used (issue #256).""" + app, events = _make_lifespan_app() + mcp = FastApiMCP(app) + mcp.mount_http() + + async with run_asgi_lifespan(app): + assert "startup" in events, "lifespan startup event was not triggered" + + assert "shutdown" in events, "lifespan shutdown event was not triggered" + + +@pytest.mark.anyio +async def test_lifespan_order(): + """startup must happen before shutdown and both must occur exactly once.""" + app, events = _make_lifespan_app() + mcp = FastApiMCP(app) + mcp.mount_http() + + async with run_asgi_lifespan(app): + pass # startup already done by now + + assert events == ["startup", "shutdown"], f"Unexpected event order: {events}" + + +@pytest.mark.anyio +async def test_lifespan_resource_available_during_request(): + """Resources initialised in the lifespan must be accessible during tool calls.""" + db_state: dict = {} + + @asynccontextmanager + async def lifespan(app: FastAPI): + db_state["connected"] = True + yield + db_state["connected"] = False + + app = FastAPI(title="DB App", lifespan=lifespan) + + @app.get("/db-status", operation_id="db_status") + async def db_status(): + return {"connected": db_state.get("connected", False)} + + mcp = FastApiMCP(app) + mcp.mount_http() + + async with run_asgi_lifespan(app): + assert db_state.get("connected") is True + + # Make an HTTP request while lifespan is active + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), + base_url="http://testserver", + ) as client: + response = await client.get("/db-status") + assert response.status_code == 200 + assert response.json()["connected"] is True + + # After shutdown the resource must be cleaned up + assert db_state.get("connected") is False + + +@pytest.mark.anyio +async def test_mcp_endpoint_reachable_after_lifespan_startup(): + """The /mcp endpoint must respond after lifespan startup completes.""" + app, events = _make_lifespan_app() + mcp = FastApiMCP(app) + mcp.mount_http() + + async with run_asgi_lifespan(app): + assert "startup" in events + + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), + base_url="http://testserver", + ) as client: + response = await client.post( + "/mcp", + json={ + "jsonrpc": "2.0", + "method": "initialize", + "id": 1, + "params": { + "protocolVersion": types.LATEST_PROTOCOL_VERSION, + "capabilities": {}, + "clientInfo": {"name": "test-client", "version": "1.0.0"}, + }, + }, + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + ) + + assert response.status_code == 200 + result = response.json() + assert result["jsonrpc"] == "2.0" + assert "result" in result + + assert "shutdown" in events + + +@pytest.mark.anyio +async def test_no_lifespan_still_works(): + """Apps without a lifespan context manager must continue to work fine.""" + app = FastAPI(title="No Lifespan App") + + @app.get("/ping", operation_id="ping") + async def ping(): + return {"pong": True} + + mcp = FastApiMCP(app) + mcp.mount_http() + + async with run_asgi_lifespan(app): + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), + base_url="http://testserver", + ) as client: + response = await client.post( + "/mcp", + json={ + "jsonrpc": "2.0", + "method": "initialize", + "id": 1, + "params": { + "protocolVersion": types.LATEST_PROTOCOL_VERSION, + "capabilities": {}, + "clientInfo": {"name": "test-client", "version": "1.0.0"}, + }, + }, + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + ) + assert response.status_code == 200 + + +@pytest.mark.anyio +async def test_session_manager_starts_after_user_startup(): + """ + The MCP session manager must start AFTER the user's startup so that tools + can access resources initialised in the lifespan. + """ + order: list[str] = [] + + @asynccontextmanager + async def lifespan(app: FastAPI): + order.append("user_startup") + yield + order.append("user_shutdown") + + app = FastAPI(title="Order App", lifespan=lifespan) + + @app.get("/noop", operation_id="noop") + async def noop(): + return {} + + mcp = FastApiMCP(app) + mcp.mount_http() + + # Patch startup/shutdown to record order + original_startup = mcp._http_transport.startup + original_shutdown = mcp._http_transport.shutdown + + async def patched_startup(): + order.append("mcp_startup") + await original_startup() + + async def patched_shutdown(): + order.append("mcp_shutdown") + await original_shutdown() + + mcp._http_transport.startup = patched_startup + mcp._http_transport.shutdown = patched_shutdown + + async with run_asgi_lifespan(app): + pass + + # user startup must come before MCP startup + assert order.index("user_startup") < order.index("mcp_startup"), ( + f"Expected user_startup before mcp_startup, got: {order}" + ) + # MCP shutdown must come before user shutdown + assert order.index("mcp_shutdown") < order.index("user_shutdown"), ( + f"Expected mcp_shutdown before user_shutdown, got: {order}" + )