Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions fastapi_mcp/server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import httpx
from contextlib import asynccontextmanager
from typing import Dict, Optional, Any, List, Union, Literal, Sequence
from typing_extensions import Annotated, Doc

Expand Down Expand Up @@ -355,6 +356,34 @@ def mount_http(
self._setup_auth()
self._http_transport = http_transport # Store reference

# Wrap the main FastAPI app's lifespan to include the MCP session
# manager lifecycle. We cannot use add_event_handler() here because
# when the user has already passed lifespan= to FastAPI(), Starlette
# stores that function as `lifespan_context` and completely bypasses
# on_startup / on_shutdown handlers. Wrapping lifespan_context works
# for both cases:
# • No custom lifespan: wraps Starlette's _DefaultLifespan, which
# still calls on_startup / on_shutdown handlers as expected.
# • Custom lifespan: wraps the user's context manager so the
# session manager starts after user resources are initialised
# and stops before they are torn down.
_main_router = self.fastapi.router
_original_lifespan = _main_router.lifespan_context
_transport = http_transport # capture for closure

@asynccontextmanager
async def _mcp_lifespan(app: Any):
# Run user startup first so their resources (DB, cache …) are
# available when the MCP session manager begins accepting clients.
async with _original_lifespan(app) as state:
await _transport.startup()
try:
yield state
finally:
await _transport.shutdown()

_main_router.lifespan_context = _mcp_lifespan

# HACK: If we got a router and not a FastAPI instance, we need to re-include the router so that
# FastAPI will pick up the new routes we added. The problem with this approach is that we assume
# that the router is a sub-router of self.fastapi, which may not always be the case.
Expand Down
42 changes: 30 additions & 12 deletions fastapi_mcp/transport/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,39 @@ def __init__(
self._manager_started = False
self._startup_lock = asyncio.Lock()

async def startup(self) -> None:
"""
Start the session manager.

Call this from a FastAPI ``on_startup`` event handler (or lifespan) so
that the StreamableHTTP session manager is ready before the first
request arrives and the ASGI lifespan protocol is handled correctly.
"""
await self._ensure_session_manager_started()

async def shutdown(self) -> None:
"""
Stop the session manager.

Call this from a FastAPI ``on_shutdown`` event handler (or lifespan)
so that in-flight sessions are drained and the background task is
cancelled cleanly.
"""
if self._manager_task and not self._manager_task.done():
self._manager_task.cancel()
try:
await self._manager_task
except asyncio.CancelledError:
pass
self._manager_started = False
self._session_manager = None

async def _ensure_session_manager_started(self) -> None:
"""
Ensure the session manager is started.

This is called lazily on the first request to start the session manager
if it hasn't been started yet.
Called eagerly from ``startup()`` and also as a lazy fallback on the
first request in case the startup hook was not registered.
"""
if self._manager_started:
return
Expand Down Expand Up @@ -84,7 +111,7 @@ async def handle_fastapi_request(self, request: Request) -> Response:
This converts FastAPI's Request/Response to ASGI scope/receive/send
and then converts the result back to a FastAPI Response.
"""
# Ensure session manager is started
# Ensure session manager is started (lazy fallback if startup hook wasn't used)
await self._ensure_session_manager_started()

if not self._session_manager:
Expand Down Expand Up @@ -125,12 +152,3 @@ async def send_callback(message):
logger.exception("Error in StreamableHTTPSessionManager")
raise HTTPException(status_code=500, detail="Internal server error")

async def shutdown(self) -> None:
"""Clean up the session manager and background task."""
if self._manager_task and not self._manager_task.done():
self._manager_task.cancel()
try:
await self._manager_task
except asyncio.CancelledError:
pass
self._manager_started = False
288 changes: 288 additions & 0 deletions tests/test_lifespan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,288 @@
"""
Tests for lifespan context manager support with mount_http().

Regression tests for:
https://github.com/tadata-org/fastapi_mcp/issues/256

When mount_http() is used, the FastAPI app's lifespan startup/shutdown events
must fire correctly and not be suppressed by the ASGI lifespan protocol.
"""

import asyncio
from contextlib import asynccontextmanager

import pytest
import httpx
from fastapi import FastAPI
from fastapi_mcp import FastApiMCP
import mcp.types as types


# ---------------------------------------------------------------------------
# ASGI lifespan helper
# ---------------------------------------------------------------------------

@asynccontextmanager
async def run_asgi_lifespan(app: FastAPI):
"""
Manually drive the ASGI lifespan protocol for an app in tests.

httpx.ASGITransport only sends 'http' scoped requests — it never sends
the 'lifespan' scope. This helper fills that gap by manually sending
lifespan.startup / lifespan.shutdown events to the ASGI app, which is
exactly what uvicorn does in production.
"""
receive_q: asyncio.Queue = asyncio.Queue()
send_q: asyncio.Queue = asyncio.Queue()

async def receive():
return await receive_q.get()

async def send(message):
await send_q.put(message)

scope = {"type": "lifespan", "asgi": {"version": "3.0"}, "state": {}}
lifespan_task = asyncio.create_task(app(scope, receive, send))

# Trigger startup
await receive_q.put({"type": "lifespan.startup"})
startup_msg = await asyncio.wait_for(send_q.get(), timeout=5.0)
assert startup_msg["type"] == "lifespan.startup.complete", (
f"Expected lifespan.startup.complete, got: {startup_msg}"
)

try:
yield
finally:
# Trigger shutdown
await receive_q.put({"type": "lifespan.shutdown"})
try:
await asyncio.wait_for(send_q.get(), timeout=3.0)
except asyncio.TimeoutError:
pass

# Clean up the background lifespan task
if not lifespan_task.done():
lifespan_task.cancel()
try:
await asyncio.wait_for(lifespan_task, timeout=2.0)
except (asyncio.CancelledError, asyncio.TimeoutError, Exception):
pass


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _make_lifespan_app():
"""Return (app, events) where events records startup/shutdown calls."""
events: list[str] = []

@asynccontextmanager
async def lifespan(app: FastAPI):
events.append("startup")
yield
events.append("shutdown")

app = FastAPI(title="Lifespan Test App", lifespan=lifespan)

@app.get("/health", operation_id="health_check")
async def health():
return {"status": "healthy"}

return app, events


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------


@pytest.mark.anyio
async def test_lifespan_startup_fires_with_mount_http():
"""on_startup events must run when mount_http() is used (issue #256)."""
app, events = _make_lifespan_app()
mcp = FastApiMCP(app)
mcp.mount_http()

async with run_asgi_lifespan(app):
assert "startup" in events, "lifespan startup event was not triggered"

assert "shutdown" in events, "lifespan shutdown event was not triggered"


@pytest.mark.anyio
async def test_lifespan_order():
"""startup must happen before shutdown and both must occur exactly once."""
app, events = _make_lifespan_app()
mcp = FastApiMCP(app)
mcp.mount_http()

async with run_asgi_lifespan(app):
pass # startup already done by now

assert events == ["startup", "shutdown"], f"Unexpected event order: {events}"


@pytest.mark.anyio
async def test_lifespan_resource_available_during_request():
"""Resources initialised in the lifespan must be accessible during tool calls."""
db_state: dict = {}

@asynccontextmanager
async def lifespan(app: FastAPI):
db_state["connected"] = True
yield
db_state["connected"] = False

app = FastAPI(title="DB App", lifespan=lifespan)

@app.get("/db-status", operation_id="db_status")
async def db_status():
return {"connected": db_state.get("connected", False)}

mcp = FastApiMCP(app)
mcp.mount_http()

async with run_asgi_lifespan(app):
assert db_state.get("connected") is True

# Make an HTTP request while lifespan is active
async with httpx.AsyncClient(
transport=httpx.ASGITransport(app=app),
base_url="http://testserver",
) as client:
response = await client.get("/db-status")
assert response.status_code == 200
assert response.json()["connected"] is True

# After shutdown the resource must be cleaned up
assert db_state.get("connected") is False


@pytest.mark.anyio
async def test_mcp_endpoint_reachable_after_lifespan_startup():
"""The /mcp endpoint must respond after lifespan startup completes."""
app, events = _make_lifespan_app()
mcp = FastApiMCP(app)
mcp.mount_http()

async with run_asgi_lifespan(app):
assert "startup" in events

async with httpx.AsyncClient(
transport=httpx.ASGITransport(app=app),
base_url="http://testserver",
) as client:
response = await client.post(
"/mcp",
json={
"jsonrpc": "2.0",
"method": "initialize",
"id": 1,
"params": {
"protocolVersion": types.LATEST_PROTOCOL_VERSION,
"capabilities": {},
"clientInfo": {"name": "test-client", "version": "1.0.0"},
},
},
headers={
"Accept": "application/json, text/event-stream",
"Content-Type": "application/json",
},
)

assert response.status_code == 200
result = response.json()
assert result["jsonrpc"] == "2.0"
assert "result" in result

assert "shutdown" in events


@pytest.mark.anyio
async def test_no_lifespan_still_works():
"""Apps without a lifespan context manager must continue to work fine."""
app = FastAPI(title="No Lifespan App")

@app.get("/ping", operation_id="ping")
async def ping():
return {"pong": True}

mcp = FastApiMCP(app)
mcp.mount_http()

async with run_asgi_lifespan(app):
async with httpx.AsyncClient(
transport=httpx.ASGITransport(app=app),
base_url="http://testserver",
) as client:
response = await client.post(
"/mcp",
json={
"jsonrpc": "2.0",
"method": "initialize",
"id": 1,
"params": {
"protocolVersion": types.LATEST_PROTOCOL_VERSION,
"capabilities": {},
"clientInfo": {"name": "test-client", "version": "1.0.0"},
},
},
headers={
"Accept": "application/json, text/event-stream",
"Content-Type": "application/json",
},
)
assert response.status_code == 200


@pytest.mark.anyio
async def test_session_manager_starts_after_user_startup():
"""
The MCP session manager must start AFTER the user's startup so that tools
can access resources initialised in the lifespan.
"""
order: list[str] = []

@asynccontextmanager
async def lifespan(app: FastAPI):
order.append("user_startup")
yield
order.append("user_shutdown")

app = FastAPI(title="Order App", lifespan=lifespan)

@app.get("/noop", operation_id="noop")
async def noop():
return {}

mcp = FastApiMCP(app)
mcp.mount_http()

# Patch startup/shutdown to record order
original_startup = mcp._http_transport.startup
original_shutdown = mcp._http_transport.shutdown

async def patched_startup():
order.append("mcp_startup")
await original_startup()

async def patched_shutdown():
order.append("mcp_shutdown")
await original_shutdown()

mcp._http_transport.startup = patched_startup
mcp._http_transport.shutdown = patched_shutdown

async with run_asgi_lifespan(app):
pass

# user startup must come before MCP startup
assert order.index("user_startup") < order.index("mcp_startup"), (
f"Expected user_startup before mcp_startup, got: {order}"
)
# MCP shutdown must come before user shutdown
assert order.index("mcp_shutdown") < order.index("user_shutdown"), (
f"Expected mcp_shutdown before user_shutdown, got: {order}"
)