Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 133 additions & 1 deletion cecli/mcp/server.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,32 @@
import asyncio
import logging
import os
import random
import webbrowser
from contextlib import AsyncExitStack
from enum import Enum, auto
from urllib.parse import urlparse

# Inclusive bounds, in seconds, for the optional "keepalive_interval" server
# setting; start_keepalive() ignores values outside this range.
# NOTE(review): these definitions sit above the third-party imports, which is
# what the E402 pre-commit failures point at - move them below the import block.
MIN_KEEPALIVE_INTERVAL = 5
MAX_KEEPALIVE_INTERVAL = 300
# Number of consecutive failed keepalive pings after which the server is
# marked DISCONNECTED and a reconnect is attempted.
FAILED_PING_THRESHOLD = 3


class ConnectionState(Enum):
    """Health of an HTTP-based MCP server connection as tracked by the keepalive loop."""

    # The last keepalive ping succeeded (also the initial value set in __init__).
    CONNECTED = auto()
    # At least one ping failed, but fewer than FAILED_PING_THRESHOLD in a row.
    UNHEALTHY = auto()
    # FAILED_PING_THRESHOLD pings failed in a row; reconnect() retries while in this state.
    DISCONNECTED = auto()


import httpx

Check failure on line 21 in cecli/mcp/server.py

View workflow job for this annotation

GitHub Actions / pre-commit

E402 module level import not at top of file
from mcp import ClientSession, StdioServerParameters

Check failure on line 22 in cecli/mcp/server.py

View workflow job for this annotation

GitHub Actions / pre-commit

E402 module level import not at top of file
from mcp.client.auth import OAuthClientProvider

Check failure on line 23 in cecli/mcp/server.py

View workflow job for this annotation

GitHub Actions / pre-commit

E402 module level import not at top of file
from mcp.client.sse import sse_client

Check failure on line 24 in cecli/mcp/server.py

View workflow job for this annotation

GitHub Actions / pre-commit

E402 module level import not at top of file
from mcp.client.stdio import stdio_client

Check failure on line 25 in cecli/mcp/server.py

View workflow job for this annotation

GitHub Actions / pre-commit

E402 module level import not at top of file
from mcp.client.streamable_http import streamable_http_client

Check failure on line 26 in cecli/mcp/server.py

View workflow job for this annotation

GitHub Actions / pre-commit

E402 module level import not at top of file
from mcp.shared.auth import OAuthClientMetadata

Check failure on line 27 in cecli/mcp/server.py

View workflow job for this annotation

GitHub Actions / pre-commit

E402 module level import not at top of file

from .oauth import (

Check failure on line 29 in cecli/mcp/server.py

View workflow job for this annotation

GitHub Actions / pre-commit

E402 module level import not at top of file
FileBasedTokenStorage,
create_oauth_callback_server,
get_mcp_oauth_token,
Expand Down Expand Up @@ -111,6 +124,13 @@
class HttpBasedMcpServer(McpServer):
"""Base class for HTTP-based MCP servers (HTTP streaming and SSE)."""

def __init__(self, server_config, io=None, verbose=False):
    """Set up the base server, then the keepalive bookkeeping.

    Args:
        server_config: Server configuration, forwarded to the base class.
        io: Optional I/O object used for user-facing messages.
        verbose: Whether informational messages should be emitted.
    """
    super().__init__(server_config, io, verbose)

    # Handles that are filled in once a connection / keepalive loop exists.
    self._http_client: httpx.AsyncClient | None = None
    self._keepalive_task: asyncio.Task | None = None

    # Health tracking maintained by the keepalive loop.
    self._failed_pings: int = 0
    self._state: ConnectionState = ConnectionState.CONNECTED

async def _create_oauth_provider(self):
"""Create an OAuthClientProvider using the MCP SDK."""
parsed = urlparse(self.config.get("url"))
Expand Down Expand Up @@ -214,6 +234,7 @@
timeout=30,
)
)
self._http_client = http_client

transport = await self.exit_stack.enter_async_context(
self._create_transport(url, http_client=http_client)
Expand All @@ -224,6 +245,7 @@
session = await self.exit_stack.enter_async_context(ClientSession(read, write))
await session.initialize()
self.session = session
await self.start_keepalive()

if oauth_provider.context.oauth_metadata:
token_endpoint = oauth_provider._get_token_endpoint()
Expand All @@ -241,10 +263,119 @@
await self.disconnect()
raise

async def disconnect(self):
async def start_keepalive(self):
    """Start the background keepalive loop if one is configured.

    Reads ``keepalive_interval`` (seconds) from the server config. Nothing is
    started when the setting is absent, is not convertible to ``int``, or lies
    outside ``MIN_KEEPALIVE_INTERVAL``..``MAX_KEEPALIVE_INTERVAL`` (inclusive);
    the latter two cases emit a warning when ``verbose`` and ``io`` are set.

    A previously started loop that is still running is cancelled and replaced,
    unless that loop is the task currently executing this coroutine.
    """
    raw_interval = self.config.get("keepalive_interval")
    if raw_interval is None:
        return

    # Only the conversion itself can raise ValueError/TypeError.
    try:
        interval = int(raw_interval)
    except (ValueError, TypeError):
        if self.verbose and self.io:
            self.io.tool_warning(f"Invalid keepalive interval {raw_interval}. Must be an integer.")
        return

    if not (MIN_KEEPALIVE_INTERVAL <= interval <= MAX_KEEPALIVE_INTERVAL):
        if self.verbose and self.io:
            self.io.tool_warning(
                f"Keepalive interval {interval} out of range"
                f" ({MIN_KEEPALIVE_INTERVAL}-{MAX_KEEPALIVE_INTERVAL}). Ignoring."
            )
        return

    existing = self._keepalive_task
    if existing is not None and not existing.done():
        if existing is asyncio.current_task():
            # Reached from inside the keepalive loop itself (it calls reconnect(),
            # which calls connect(), which calls this method). Cancelling would
            # cancel the task that is running us and abort the reconnect at its
            # next await; the loop is already alive, so leave it in place.
            return
        existing.cancel()

    self._keepalive_task = asyncio.create_task(self._keepalive_loop(interval))
    if self.verbose and self.io:
        self.io.tool_output(f"Started keepalive loop for {self.name} (interval: {interval}s)")

async def _keepalive_loop(self, interval: int):
    """Background loop that sends periodic heartbeats to the MCP server.

    Every ``interval`` seconds (with ±10% jitter) an HTTP ``OPTIONS`` request is
    sent to the configured URL. Any 2xx response marks the connection
    ``CONNECTED`` and resets the failure counter. A non-2xx response or any
    exception counts as a failed ping: below ``FAILED_PING_THRESHOLD`` the state
    becomes ``UNHEALTHY``; at the threshold it becomes ``DISCONNECTED`` and
    ``reconnect()`` is awaited.

    Cancellation ends the loop quietly; any other error escaping the loop is
    logged and ends the loop.

    Args:
        interval: Base delay between pings, in seconds.
    """
    try:
        while True:
            # Jitter: ±10% of the interval.
            jitter = interval * 0.1 * (2 * random.random() - 1)
            await asyncio.sleep(interval + jitter)

            # No client means there is currently nothing to ping.
            if not self._http_client:
                continue

            # Use an OPTIONS request as a lightweight heartbeat. Any failure to
            # get a 2xx answer is treated the same way, so no exception object
            # needs to be kept around.
            try:
                response = await self._http_client.options(
                    self.config.get("url"), headers=self.config.get("headers", {})
                )
                ping_ok = 200 <= response.status_code < 300
            except Exception:
                logging.debug("Keepalive ping for %s failed", self.name, exc_info=True)
                ping_ok = False

            if ping_ok:
                self._state = ConnectionState.CONNECTED
                self._failed_pings = 0
                continue

            self._failed_pings += 1
            if self._failed_pings >= FAILED_PING_THRESHOLD:
                self._state = ConnectionState.DISCONNECTED
                if self.verbose and self.io:
                    self.io.tool_warning(
                        f"MCP server {self.name} disconnected after"
                        f" {self._failed_pings} failed pings. Attempting reconnect..."
                    )
                await self.reconnect()
            else:
                self._state = ConnectionState.UNHEALTHY
                if self.verbose and self.io:
                    self.io.tool_output(
                        f"MCP server {self.name} unhealthy"
                        f" (ping {self._failed_pings}/{FAILED_PING_THRESHOLD})"
                    )
    except asyncio.CancelledError:
        # Cancellation is the normal way to stop the loop.
        pass
    except Exception as e:
        logging.error(f"Keepalive loop for {self.name} crashed: {e}")

async def reconnect(self):
    """Re-establish the connection, retrying with capped exponential backoff.

    Runs only while the state is ``DISCONNECTED``. Each round sleeps for
    ``min(1 * 2**n, 300)`` seconds (±20% jitter, ``n`` = failures so far),
    tears down the old session without cancelling the keepalive task, and
    connects again. On success the state is reset to ``CONNECTED`` and the
    failed-ping counter to zero; on failure the next round waits longer.
    """
    base_delay, growth, delay_cap = 1, 2, 300
    failures = 0

    while self._state == ConnectionState.DISCONNECTED:
        backoff = min(base_delay * (growth**failures), delay_cap)
        # Jitter: ±20%
        wobble = backoff * 0.2 * (2 * random.random() - 1)
        await asyncio.sleep(backoff + wobble)

        try:
            if self.verbose and self.io:
                self.io.tool_output(
                    f"Attempting to reconnect to {self.name} (attempt {failures + 1})..."
                )

            # Drop the old session/client but keep the keepalive task alive.
            await self.disconnect(cancel_keepalive=False)
            await self.connect()

            self._failed_pings = 0
            self._state = ConnectionState.CONNECTED
            if self.verbose and self.io:
                self.io.tool_output(f"Successfully reconnected to {self.name}")
            return
        except Exception as err:
            failures += 1
            if self.verbose and self.io:
                self.io.tool_warning(
                    f"Reconnection attempt {failures} failed for {self.name}: {err}"
                )

async def disconnect(self, cancel_keepalive: bool = True):
"""Disconnect from the MCP server and clean up resources."""
async with self._cleanup_lock:
try:
if cancel_keepalive and self._keepalive_task:
self._keepalive_task.cancel()
if hasattr(self, "_oauth_shutdown"):
self._oauth_shutdown()
await self.exit_stack.aclose()
Expand All @@ -256,6 +387,7 @@
logging.error(f"Error during cleanup of server {self.name}: {e}")
finally:
self.session = None
self._http_client = None


class HttpStreamingServer(HttpBasedMcpServer):
Expand Down
104 changes: 104 additions & 0 deletions tests/mcp/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import asyncio
import random
from typing import Any, AsyncGenerator, Dict
from unittest.mock import AsyncMock, MagicMock

import pytest

from cecli.mcp.server import HttpBasedMcpServer, HttpStreamingServer
from tests.mcp.mock_server import MockMcpServer


@pytest.fixture
def mock_mcp_server() -> MockMcpServer:
    """Provide a fresh MockMcpServer instance for each test."""
    return MockMcpServer()


@pytest.fixture
async def running_mock_server(mock_mcp_server) -> AsyncGenerator[MockMcpServer, None]:
    """Start the mock MCP server for the duration of a test, then stop it.

    Yields:
        The started ``MockMcpServer`` instance.

    NOTE(review): the value returned by ``start()`` is discarded. Consumers that
    place this fixture's value under a ``"url"`` config key therefore receive the
    server object rather than a URL string - confirm that this is intended.
    """
    await mock_mcp_server.start()
    yield mock_mcp_server
    await mock_mcp_server.stop()


@pytest.fixture
def http_server_config(running_mock_server) -> Dict[str, Any]:
    """Build a minimal configuration dict for an HTTP MCP server under test."""
    # NOTE(review): a keepalive_interval of 1 is meant to keep tests fast, but it is
    # below MIN_KEEPALIVE_INTERVAL (5) in cecli/mcp/server.py, so start_keepalive()
    # would ignore it unless the tests adjust that bound - confirm.
    config: Dict[str, Any] = {
        "name": "test-server",
        "url": running_mock_server,
        "type": "http",
        "keepalive_interval": 1,
        "headers": {},
        "enabled": True,
    }
    return config


@pytest.fixture
def http_streaming_server_config(running_mock_server) -> Dict[str, Any]:
    """Build a minimal configuration dict for a streamable-HTTP MCP server under test."""
    # NOTE(review): keepalive_interval 1 is below MIN_KEEPALIVE_INTERVAL (5) in
    # cecli/mcp/server.py, so start_keepalive() would ignore it - confirm.
    config: Dict[str, Any] = {
        "name": "test-streaming-server",
        "url": running_mock_server,
        "type": "streamable_http",
        "keepalive_interval": 1,
        "headers": {},
        "enabled": True,
    }
    return config


@pytest.fixture
def mock_io():
    """Provide a MagicMock standing in for the IO object, with explicit output channels."""
    io = MagicMock()
    # Give each reporting channel its own mock so tests can inspect them separately.
    for channel in ("tool_output", "tool_error", "tool_warning"):
        setattr(io, channel, MagicMock())
    return io


@pytest.fixture
def http_based_server(http_server_config, mock_io) -> HttpBasedMcpServer:
    """Construct an HttpBasedMcpServer from the HTTP test config and the mock IO."""
    server = HttpBasedMcpServer(http_server_config, io=mock_io)
    return server


@pytest.fixture
def http_streaming_server(http_streaming_server_config, mock_io) -> HttpStreamingServer:
    """Construct an HttpStreamingServer from the streaming test config and the mock IO."""
    server = HttpStreamingServer(http_streaming_server_config, io=mock_io)
    return server


# Test utilities for inspecting internal state
class ServerStateInspector:
    """Read-only helpers that expose HttpBasedMcpServer's private state to tests."""

    @staticmethod
    def get_state(server: HttpBasedMcpServer):
        """Return the server's current connection state (``_state``)."""
        return server._state

    @staticmethod
    def get_failed_pings(server: HttpBasedMcpServer):
        """Return the server's failed-ping counter (``_failed_pings``)."""
        return server._failed_pings

    @staticmethod
    def get_keepalive_task(server: HttpBasedMcpServer):
        """Return the server's keepalive task, or ``None`` if there is none."""
        return server._keepalive_task

    @staticmethod
    def is_keepalive_running(server: HttpBasedMcpServer):
        """Return True when a keepalive task exists and has not finished."""
        keepalive = server._keepalive_task
        if keepalive is None:
            return False
        return not keepalive.done()


@pytest.fixture
def server_inspector():
    """Provide a ServerStateInspector for peeking at server internals in tests."""
    inspector = ServerStateInspector()
    return inspector
Loading
Loading