Skip to content

fix: resolve URL path truncation in SSE transport for proxied servers #1211

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
86 changes: 64 additions & 22 deletions src/mcp/server/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,16 @@
SSE Server Transport Module
This module implements a Server-Sent Events (SSE) transport layer for MCP servers.
Endpoints are specified as relative paths. This aligns with common client URL
construction patterns (for example, `urllib.parse.urljoin`) and works correctly
when applications are deployed behind proxies or at subpaths.
Example usage:
```
# Create an SSE transport at an endpoint
sse = SseServerTransport("/messages/")
```python
# Recommended: provide a relative path segment (no scheme/host/query/fragment).
# Using "messages/" works well with clients that build absolute URLs using
# `urllib.parse.urljoin`, including in proxied/subpath deployments.
sse = SseServerTransport("messages/")
# Create Starlette routes for SSE and message handling
routes = [
Expand All @@ -30,6 +35,17 @@ async def handle_sse(request):
uvicorn.run(starlette_app, host="127.0.0.1", port=port)
```
Path behavior examples inside the server (final path emitted to clients):
- root_path="" and endpoint="messages/" -> "/messages/"
- root_path="/api" and endpoint="messages/" -> "/api/messages/"
Note: When clients use `urllib.parse.urljoin(base, path)`, joining a segment that
starts with "/" replaces the base path. Providing a relative segment like
`"messages/?id=1"` preserves the base path as intended.
For servers behind proxies or mounted at subpaths, prefer a relative path without
leading slash (e.g., "messages/") to ensure correct joining with `urljoin`.
Note: The handle_sse function must return a Response to avoid a "TypeError: 'NoneType'
object is not callable" error when client disconnects. The example above returns
an empty Response() after the SSE connection ends to fix this.
Expand Down Expand Up @@ -83,8 +99,10 @@ def __init__(self, endpoint: str, security_settings: TransportSecuritySettings |
messages to the relative path given.
Args:
endpoint: A relative path where messages should be posted
(e.g., "/messages/").
endpoint: Relative path segment where messages should be posted
(e.g., "messages/"). Avoid scheme/host/query/fragment. When
clients construct absolute URLs using `urllib.parse.urljoin`,
relative segments preserve any existing base path.
security_settings: Optional security settings for DNS rebinding protection.
Note:
Expand All @@ -96,28 +114,60 @@ def __init__(self, endpoint: str, security_settings: TransportSecuritySettings |
3. Portability: The same endpoint configuration works across different
environments (development, staging, production)
The endpoint path handling preserves the provided relative path and is
suitable for deployments under proxies or subpaths.
Raises:
ValueError: If the endpoint is a full URL instead of a relative path
"""

super().__init__()

# Validate that endpoint is a relative path and not a full URL
# Validate that endpoint is a relative path and not a full URL.
if "://" in endpoint or endpoint.startswith("//") or "?" in endpoint or "#" in endpoint:
raise ValueError(
f"Given endpoint: {endpoint} is not a relative path (e.g., '/messages/'), "
"expecting a relative path (e.g., '/messages/')."
f"Given endpoint: {endpoint} is not a relative path (e.g., 'messages/'), "
"expecting a relative path with no scheme/host/query/fragment."
)

# Ensure endpoint starts with a forward slash
if not endpoint.startswith("/"):
endpoint = "/" + endpoint

# Store the endpoint as provided to retain relative-path semantics and make
# client URL construction predictable across deployment topologies.
self._endpoint = endpoint
self._read_stream_writers = {}
self._security = TransportSecurityMiddleware(security_settings)
logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}")

def _build_message_path(self, root_path: str) -> str:
"""
Helper method to properly construct the message path
Constructs the message path relative to the app's mount point and the
provided `root_path`. The stored endpoint is treated as path-absolute if
it starts with "/", otherwise as a relative segment.
Args:
root_path: The root path from ASGI scope (e.g., "" or "/api_prefix")
Returns:
The properly constructed path for client message posting
"""
# Clean up the root path
clean_root_path = root_path.rstrip("/")

# If endpoint starts with "/", treat it as path-absolute from the app mount;
# otherwise, treat it as relative to `root_path`.
if self._endpoint.startswith("/"):
# Path-absolute within the app mount - just concatenate
full_path = clean_root_path + self._endpoint
else:
# Relative path - ensure proper joining
if clean_root_path:
full_path = clean_root_path + "/" + self._endpoint
else:
full_path = "/" + self._endpoint

return full_path

@asynccontextmanager
async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
if scope["type"] != "http":
Expand Down Expand Up @@ -145,17 +195,9 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
self._read_stream_writers[session_id] = read_stream_writer
logger.debug(f"Created new session with ID: {session_id}")

# Determine the full path for the message endpoint to be sent to the client.
# scope['root_path'] is the prefix where the current Starlette app
# instance is mounted.
# e.g., "" if top-level, or "/api_prefix" if mounted under "/api_prefix".
# Use the new helper method for proper path construction
root_path = scope.get("root_path", "")

# self._endpoint is the path *within* this app, e.g., "/messages".
# Concatenating them gives the full absolute path from the server root.
# e.g., "" + "/messages" -> "/messages"
# e.g., "/api_prefix" + "/messages" -> "/api_prefix/messages"
full_message_path_for_client = root_path.rstrip("/") + self._endpoint
full_message_path_for_client = self._build_message_path(root_path)

# This is the URI (path + query) the client will use to POST messages.
client_post_uri_data = f"{quote(full_message_path_for_client)}?session_id={session_id.hex}"
Expand Down
44 changes: 44 additions & 0 deletions tests/server/test_sse_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,3 +291,47 @@ async def test_sse_security_post_valid_content_type(server_port: int):
finally:
process.terminate()
process.join()


@pytest.mark.anyio
async def test_endpoint_validation_rejects_absolute_urls():
"""Validate endpoint format: relative path segments only.

Context on URL joining (urllib.parse.urljoin):
- Joining a segment starting with "/" resets to the host root:
urljoin("http://host/app/sse", "/messages") -> "http://host/messages"
- Joining a relative segment appends relative to the base:
urljoin("http://host/hello/world", "messages") -> "http://host/hello/messages"
urljoin("http://host/hello/world/", "messages") -> "http://host/hello/world/messages"

This test ensures the transport accepts relative path segments (e.g., "messages/"),
rejects full URLs or paths containing query/fragment components, and stores accepted
values verbatim (no normalization). Both leading-slash and non-leading-slash forms
are permitted because the server handles construction relative to its mount path.
"""
# Reject: fully-qualified URLs and segments that include query/fragment
invalid_endpoints = [
"http://example.com/messages/",
"https://example.com/messages/",
"//example.com/messages/",
"/messages/?query=test",
"/messages/#fragment",
]

for invalid_endpoint in invalid_endpoints:
with pytest.raises(ValueError, match="is not a relative path"):
SseServerTransport(invalid_endpoint)

# Accept: relative path forms; endpoint is stored as provided (no normalization)
valid_endpoints_and_expected = [
("/messages/", "/messages/"), # Leading-slash path segment
("messages/", "messages/"), # Non-leading-slash path segment
("/api/v1/messages/", "/api/v1/messages/"),
("api/v1/messages/", "api/v1/messages/"),
]

for valid_endpoint, expected_stored_value in valid_endpoints_and_expected:
# Should not raise an exception
transport = SseServerTransport(valid_endpoint)
# Endpoint should be stored exactly as provided (no normalization)
assert transport._endpoint == expected_stored_value
22 changes: 16 additions & 6 deletions tests/shared/test_sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,9 +488,9 @@ def test_sse_message_id_coercion():
@pytest.mark.parametrize(
"endpoint, expected_result",
[
# Valid endpoints - should normalize and work
# Accept: relative path forms; endpoint is stored verbatim (no normalization)
("/messages/", "/messages/"),
("messages/", "/messages/"),
("messages/", "messages/"),
("/", "/"),
# Invalid endpoints - should raise ValueError
("http://example.com/messages/", ValueError),
Expand All @@ -501,13 +501,23 @@ def test_sse_message_id_coercion():
],
)
def test_sse_server_transport_endpoint_validation(endpoint: str, expected_result: str | type[Exception]):
"""Test that SseServerTransport properly validates and normalizes endpoints."""
if isinstance(expected_result, type):
"""Validate relative endpoint semantics and storage.

Context on URL joining (urllib.parse.urljoin):
- Joining a segment starting with "/" resets to the host root:
urljoin("http://host/hello/world", "/messages") -> "http://host/messages"
- Joining a relative segment appends relative to the base:
urljoin("http://host/hello/world", "messages") -> "http://host/hello/messages"
urljoin("http://host/hello/world/", "messages/") -> "http://host/hello/world/messages/"

The transport validates that endpoints are relative path segments (no scheme/host/query/fragment)
and stores accepted values exactly as provided.
"""
if isinstance(expected_result, type) and issubclass(expected_result, Exception):
# Test invalid endpoints that should raise an exception
with pytest.raises(expected_result, match="is not a relative path.*expecting a relative path"):
SseServerTransport(endpoint)
else:
# Test valid endpoints that should normalize correctly
# Endpoint should be stored exactly as provided (no normalization)
sse = SseServerTransport(endpoint)
assert sse._endpoint == expected_result
assert sse._endpoint.startswith("/")
Loading