Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/mcp/server/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -809,7 +809,7 @@ async def _validate_request_headers(self, request: Request, send: Send) -> bool:

async def _validate_session(self, request: Request, send: Send) -> bool:
"""Validate the session ID in the request."""
if not self.mcp_session_id: # pragma: no cover
if not self.mcp_session_id:
# If we're not using session IDs, return True
return True

Expand Down Expand Up @@ -842,7 +842,7 @@ async def _validate_protocol_version(self, request: Request, send: Send) -> bool
protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER)

# If no protocol version provided, assume default version
if protocol_version is None: # pragma: no cover
if protocol_version is None:
protocol_version = DEFAULT_NEGOTIATED_VERSION

# Check if the protocol version is supported
Expand Down
32 changes: 29 additions & 3 deletions src/mcp/server/streamable_http_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ def __init__(
self._session_creation_lock = anyio.Lock()
self._server_instances: dict[str, StreamableHTTPServerTransport] = {}

# Track in-flight stateless transports for graceful shutdown
self._stateless_transports: set[StreamableHTTPServerTransport] = set()

# The task group will be set during lifespan
self._task_group = None
# Thread-safe tracking of run() calls
Expand Down Expand Up @@ -130,11 +133,28 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]:
yield # Let the application run
finally:
logger.info("StreamableHTTP session manager shutting down")

# Terminate all active transports before cancelling the task
# group. This closes their in-memory streams, which lets
# EventSourceResponse send a final ``more_body=False`` chunk
# — a clean HTTP close instead of a connection reset.
for transport in list(self._server_instances.values()):
try:
await transport.terminate()
except Exception: # pragma: no cover
logger.debug("Error terminating transport during shutdown", exc_info=True)
for transport in list(self._stateless_transports):
try:
await transport.terminate()
except Exception: # pragma: no cover
logger.debug("Error terminating stateless transport during shutdown", exc_info=True)

# Cancel task group to stop all spawned tasks
tg.cancel_scope.cancel()
self._task_group = None
# Clear any remaining server instances
self._server_instances.clear()
self._stateless_transports.clear()

async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Process ASGI request with proper session handling and transport setup.
Expand All @@ -161,6 +181,9 @@ async def _handle_stateless_request(self, scope: Scope, receive: Receive, send:
security_settings=self.security_settings,
)

# Track for graceful shutdown
self._stateless_transports.add(http_transport)

# Start server in a new task
async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED):
async with http_transport.connect() as streams:
Expand All @@ -173,16 +196,19 @@ async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STA
self.app.create_initialization_options(),
stateless=True,
)
except Exception: # pragma: no cover
except Exception: # pragma: lax no cover
logger.exception("Stateless session crashed")

# Assert task group is not None for type checking
assert self._task_group is not None
# Start the server task
await self._task_group.start(run_stateless_server)

# Handle the HTTP request and return the response
await http_transport.handle_request(scope, receive, send)
try:
# Handle the HTTP request and return the response
await http_transport.handle_request(scope, receive, send)
finally:
self._stateless_transports.discard(http_transport)

# Terminate the transport after the request is handled
await http_transport.terminate()
Expand Down
166 changes: 164 additions & 2 deletions tests/server/test_streamable_http_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
import pytest
from starlette.types import Message

from mcp import Client
from mcp import Client, types
from mcp.client.streamable_http import streamable_http_client
from mcp.server import Server, ServerRequestContext, streamable_http_manager
from mcp.server.streamable_http import MCP_SESSION_ID_HEADER, StreamableHTTPServerTransport
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
from mcp.server.streamable_http_manager import StreamableHTTPASGIApp, StreamableHTTPSessionManager
from mcp.types import INVALID_REQUEST, ListToolsResult, PaginatedRequestParams


Expand Down Expand Up @@ -413,3 +413,165 @@ def test_session_idle_timeout_rejects_non_positive():
def test_session_idle_timeout_rejects_stateless():
    """A session idle timeout cannot be combined with stateless mode."""
    with pytest.raises(RuntimeError, match="not supported in stateless"):
        StreamableHTTPSessionManager(
            app=Server("test"),
            session_idle_timeout=30,
            stateless=True,
        )


# Headers the Streamable HTTP transport requires on every request: clients
# must declare they accept both plain JSON responses and SSE event streams.
MCP_HEADERS = {
    "Accept": "application/json, text/event-stream",
    "Content-Type": "application/json",
}

# JSON-RPC ``initialize`` request (id 1) that opens an MCP session.
_INITIALIZE_REQUEST = {
    "jsonrpc": "2.0",
    "id": 1,
    "method": "initialize",
    "params": {
        "protocolVersion": "2025-03-26",
        "capabilities": {},
        "clientInfo": {"name": "test", "version": "0.1"},
    },
}

# JSON-RPC notification (no id, so no response is expected) telling the
# server that client-side initialization is complete.
_INITIALIZED_NOTIFICATION = {
    "jsonrpc": "2.0",
    "method": "notifications/initialized",
}

# JSON-RPC request (id 2) invoking ``slow_tool``, whose handler blocks
# forever — the resulting SSE response stays open until shutdown.
_TOOL_CALL_REQUEST = {
    "jsonrpc": "2.0",
    "id": 2,
    "method": "tools/call",
    "params": {"name": "slow_tool", "arguments": {"message": "hello"}},
}


def _make_slow_tool_server() -> tuple[Server, anyio.Event]:
    """Build an MCP server whose only tool blocks forever.

    Returns the server together with an event that fires as soon as the
    tool handler starts executing, so callers can synchronize with the
    in-flight request.
    """
    started = anyio.Event()

    async def on_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult:
        # Signal that the request reached the handler, then park forever.
        started.set()
        await anyio.sleep_forever()
        return types.CallToolResult(  # pragma: no cover
            content=[types.TextContent(type="text", text="never reached")]
        )

    async def on_list_tools(
        ctx: ServerRequestContext, params: PaginatedRequestParams | None
    ) -> ListToolsResult:  # pragma: no cover
        slow_tool = types.Tool(
            name="slow_tool",
            description="A tool that blocks forever",
            input_schema={"type": "object", "properties": {"message": {"type": "string"}}},
        )
        return ListToolsResult(tools=[slow_tool])

    server = Server("test-graceful-shutdown", on_call_tool=on_call_tool, on_list_tools=on_list_tools)
    return server, started


class SSECloseTracker:
    """ASGI middleware that records whether SSE responses close cleanly.

    A clean HTTP close is a final empty chunk (``0\\r\\n\\r\\n``); at the
    ASGI level that is a ``{"type": "http.response.body", "more_body":
    False}`` message. If the server task is cancelled without draining,
    nothing ever closes the stateless transport's streams: the SSE
    response hangs and no final body message is sent, which a reverse
    proxy (e.g. nginx) reports as "upstream prematurely closed connection
    while reading upstream".
    """

    def __init__(self, app: StreamableHTTPASGIApp) -> None:
        self.app = app
        # Counters compared by the test: every opened SSE stream should
        # eventually be matched by a clean close.
        self.sse_streams_opened = 0
        self.sse_streams_closed_cleanly = 0

    async def __call__(self, scope: dict[str, Any], receive: Any, send: Any) -> None:
        sse_response = False

        async def wrapped_send(message: dict[str, Any]) -> None:
            nonlocal sse_response
            msg_type = message["type"]
            if msg_type == "http.response.start":
                headers = message.get("headers", [])
                if any(name == b"content-type" and b"text/event-stream" in value for name, value in headers):
                    sse_response = True
                    self.sse_streams_opened += 1
            elif msg_type == "http.response.body" and sse_response and not message.get("more_body", False):
                # Final body message on an SSE response — a clean close.
                self.sse_streams_closed_cleanly += 1
            await send(message)

        await self.app(scope, receive, wrapped_send)


@pytest.mark.anyio
async def test_graceful_shutdown_closes_sse_streams_cleanly():
    """Verify that shutting down the session manager closes in-flight SSE
    streams with a proper ``more_body=False`` ASGI message.

    This is the ASGI equivalent of sending the final HTTP chunk — the signal
    that reverse proxies like nginx use to distinguish a clean close from a
    connection reset ("upstream prematurely closed connection").

    Without the graceful-drain fix, stateless transports are not tracked by
    the session manager. On shutdown nothing calls ``terminate()`` on them,
    so SSE responses hang indefinitely and never send the final body. With
    the fix, ``run()``'s finally block iterates ``_stateless_transports`` and
    terminates each one, closing the underlying memory streams and letting
    ``EventSourceResponse`` complete normally.
    """
    app, tool_started = _make_slow_tool_server()
    manager = StreamableHTTPSessionManager(app=app, stateless=True)

    # Wrap the manager's ASGI app so every SSE open/clean-close is counted.
    tracker = SSECloseTracker(StreamableHTTPASGIApp(manager))

    # Fires once the manager's lifespan is running, so the client task
    # doesn't send requests before the task group exists.
    manager_ready = anyio.Event()

    # Outer watchdog: fail the whole test rather than hang if shutdown
    # never closes the SSE stream.
    with anyio.fail_after(10):
        async with anyio.create_task_group() as tg:

            async def run_lifespan_and_shutdown() -> None:
                async with manager.run():
                    manager_ready.set()
                    with anyio.fail_after(5):
                        await tool_started.wait()
                    # manager.run() exits — graceful shutdown runs here

            async def make_requests() -> None:
                with anyio.fail_after(5):
                    await manager_ready.wait()
                async with (
                    httpx.ASGITransport(tracker, raise_app_exceptions=False) as transport,
                    httpx.AsyncClient(transport=transport, base_url="http://testserver") as client,
                ):
                    # Initialize
                    resp = await client.post("/mcp/", json=_INITIALIZE_REQUEST, headers=MCP_HEADERS)
                    resp.raise_for_status()

                    # Send initialized notification
                    resp = await client.post("/mcp/", json=_INITIALIZED_NOTIFICATION, headers=MCP_HEADERS)
                    assert resp.status_code == 202

                    # Send slow tool call — returns an SSE stream that blocks
                    # until shutdown terminates it
                    await client.post(
                        "/mcp/",
                        json=_TOOL_CALL_REQUEST,
                        headers=MCP_HEADERS,
                        timeout=httpx.Timeout(10, connect=5),
                    )

            tg.start_soon(run_lifespan_and_shutdown)
            tg.start_soon(make_requests)

    # Both counters come from SSECloseTracker; equality means every SSE
    # response ended with the final ``more_body=False`` message.
    assert tracker.sse_streams_opened > 0, "Test should have opened at least one SSE stream"
    assert tracker.sse_streams_closed_cleanly == tracker.sse_streams_opened, (
        f"All {tracker.sse_streams_opened} SSE stream(s) should have closed with "
        f"more_body=False, but only {tracker.sse_streams_closed_cleanly} did"
    )
Loading