From 0191dd152c1ccaa8e543e640cb38d17cbefa85f9 Mon Sep 17 00:00:00 2001 From: Sid Murching Date: Mon, 21 Jul 2025 14:04:05 -0700 Subject: [PATCH 1/5] Update unit test Signed-off-by: Sid Murching --- .gitignore | 1 + tests/client/test_auth.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index e9fdca176..7c80ae585 100644 --- a/.gitignore +++ b/.gitignore @@ -168,3 +168,4 @@ cython_debug/ .vscode/ .windsurfrules **/CLAUDE.local.md +.idea diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 46208d69c..f44383478 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -208,7 +208,7 @@ async def callback_handler() -> tuple[str, str | None]: return "test_auth_code", "test_state" provider = OAuthClientProvider( - server_url="https://api.example.com", + server_url="https://api.example.com/api/2.0/mcp", client_metadata=client_metadata, storage=mock_storage, redirect_handler=redirect_handler, From 3eeb0f22c77f1dd9991822398f4e81e116f3d367 Mon Sep 17 00:00:00 2001 From: Sid Murching Date: Mon, 21 Jul 2025 14:06:29 -0700 Subject: [PATCH 2/5] empty From f2a9b8285544e155639540756ee5fdc800c40502 Mon Sep 17 00:00:00 2001 From: Sid Murching Date: Mon, 21 Jul 2025 14:19:50 -0700 Subject: [PATCH 3/5] Include path to resource in oauth-protected-resource request Signed-off-by: Sid Murching --- src/mcp/client/auth.py | 13 ++++++++++--- tests/client/test_auth.py | 2 +- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/mcp/client/auth.py b/src/mcp/client/auth.py index b00db7b9b..33269d43c 100644 --- a/src/mcp/client/auth.py +++ b/src/mcp/client/auth.py @@ -233,9 +233,16 @@ async def _discover_protected_resource(self, init_response: httpx.Response) -> h url = self._extract_resource_metadata_from_www_auth(init_response) if not url: - # Fallback to well-known discovery - auth_base_url = self.context.get_authorization_base_url(self.context.server_url) - url = urljoin(auth_base_url, "/.well-known/oauth-protected-resource") + # Fallback to well-known discovery with path component included + parsed = urlparse(self.context.server_url) + auth_base_url = f"{parsed.scheme}://{parsed.netloc}" + + if parsed.path and parsed.path != "/": + # Include path component in the well-known URL + path_component = parsed.path.rstrip("/") + url = urljoin(auth_base_url, f"/.well-known/oauth-protected-resource{path_component}") + else: + url = urljoin(auth_base_url, "/.well-known/oauth-protected-resource") return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}) diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index f44383478..49a3fb633 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -222,7 +222,7 @@ async def callback_handler() -> tuple[str, str | None]: request = await provider._discover_protected_resource(init_response) assert request.method == "GET" - assert str(request.url) == "https://api.example.com/.well-known/oauth-protected-resource" + assert str(request.url) == "https://api.example.com/.well-known/oauth-protected-resource/api/2.0/mcp" assert "mcp-protocol-version" in request.headers # Test with WWW-Authenticate header From a39c24dbce52418e41cbfb5b1501b504c6e297ce Mon Sep 17 00:00:00 2001 From: Sid Murching Date: Mon, 21 Jul 2025 14:35:42 -0700 Subject: [PATCH 4/5] Simplify Signed-off-by: Sid Murching --- src/mcp/client/auth.py | 59 ++++++++++++++++++++++++--------------- tests/client/test_auth.py | 35 +++++++++++------------ 2 files changed, 54 insertions(+), 40 deletions(-) diff --git a/src/mcp/client/auth.py b/src/mcp/client/auth.py index 33269d43c..d7794e79c 100644 --- a/src/mcp/client/auth.py +++ b/src/mcp/client/auth.py @@ -108,10 +108,6 @@ class OAuthContext: # State lock: anyio.Lock = field(default_factory=anyio.Lock) - # Discovery state for fallback support - discovery_base_url: str | None = None - discovery_pathname: str | None = None - def get_authorization_base_url(self, server_url: str) -> str: """Extract base URL by removing path component.""" parsed = urlparse(server_url) @@ -228,23 +224,23 @@ def _extract_resource_metadata_from_www_auth(self, init_response: httpx.Response return None - async def _discover_protected_resource(self, init_response: httpx.Response) -> httpx.Request: - # RFC9728: Try to extract resource_metadata URL from WWW-Authenticate header of the initial response - url = self._extract_resource_metadata_from_www_auth(init_response) - - if not url: - # Fallback to well-known discovery with path component included - parsed = urlparse(self.context.server_url) - auth_base_url = f"{parsed.scheme}://{parsed.netloc}" + def _get_protected_resource_discovery_urls(self) -> list[str]: + """Generate ordered list of URLs for protected resource discovery attempts.""" + urls: list[str] = [] + parsed = urlparse(self.context.server_url) + base_url = f"{parsed.scheme}://{parsed.netloc}" - if parsed.path and parsed.path != "/": - # Include path component in the well-known URL - path_component = parsed.path.rstrip("/") - url = urljoin(auth_base_url, f"/.well-known/oauth-protected-resource{path_component}") - else: - url = urljoin(auth_base_url, "/.well-known/oauth-protected-resource") + if parsed.path and parsed.path != "/": + # Try path-specific endpoint first + path_component = parsed.path.rstrip("/") + urls.append(urljoin(base_url, f"/.well-known/oauth-protected-resource{path_component}")) + # Then fallback to base endpoint + urls.append(urljoin(base_url, "/.well-known/oauth-protected-resource")) + else: + # No path, just use base endpoint + urls.append(urljoin(base_url, "/.well-known/oauth-protected-resource")) - return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}) + return urls async def _handle_protected_resource_response(self, response: httpx.Response) -> None: """Handle discovery response.""" @@ -517,9 +513,28 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. try: # OAuth flow must be inline due to generator constraints # Step 1: Discover protected resource metadata (RFC9728 with WWW-Authenticate support) - discovery_request = await self._discover_protected_resource(response) - discovery_response = yield discovery_request - await self._handle_protected_resource_response(discovery_response) + # Check if WWW-Authenticate provides resource_metadata URL first + www_auth_url = self._extract_resource_metadata_from_www_auth(response) + if www_auth_url: + discovery_request = httpx.Request( + "GET", www_auth_url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION} + ) + discovery_response = yield discovery_request + await self._handle_protected_resource_response(discovery_response) + else: + # Try well-known discovery URLs with fallback + discovery_urls = self._get_protected_resource_discovery_urls() + for url in discovery_urls: + discovery_request = httpx.Request( + "GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION} + ) + discovery_response = yield discovery_request + + if discovery_response.status_code == 200: + await self._handle_protected_resource_response(discovery_response) + break # Success, stop trying other URLs + elif discovery_response.status_code != 404: + break # Non-404 error, stop trying # Step 2: Discover OAuth metadata (with fallback for legacy servers) discovery_urls = self._get_discovery_urls() diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 49a3fb633..07651731c 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -198,8 +198,8 @@ class TestOAuthFlow: """Test OAuth flow methods.""" @pytest.mark.anyio - async def test_discover_protected_resource_request(self, client_metadata, mock_storage): - """Test protected resource discovery request building maintains backward compatibility.""" + async def test_protected_resource_discovery_urls(self, client_metadata, mock_storage): + """Test protected resource discovery URL generation with fallback.""" async def redirect_handler(url: str) -> None: pass @@ -207,6 +207,7 @@ async def redirect_handler(url: str) -> None: async def callback_handler() -> tuple[str, str | None]: return "test_auth_code", "test_state" + # Test with path component provider = OAuthClientProvider( server_url="https://api.example.com/api/2.0/mcp", client_metadata=client_metadata, @@ -215,25 +216,23 @@ async def callback_handler() -> tuple[str, str | None]: callback_handler=callback_handler, ) - # Test without WWW-Authenticate (fallback) - init_response = httpx.Response( - status_code=401, headers={}, request=httpx.Request("GET", "https://request-api.example.com") - ) - - request = await provider._discover_protected_resource(init_response) - assert request.method == "GET" - assert str(request.url) == "https://api.example.com/.well-known/oauth-protected-resource/api/2.0/mcp" - assert "mcp-protocol-version" in request.headers + urls = provider._get_protected_resource_discovery_urls() + assert urls == [ + "https://api.example.com/.well-known/oauth-protected-resource/api/2.0/mcp", + "https://api.example.com/.well-known/oauth-protected-resource", + ] - # Test with WWW-Authenticate header - init_response.headers["WWW-Authenticate"] = ( - 'Bearer resource_metadata="https://prm.example.com/.well-known/oauth-protected-resource/path"' + # Test without path component + provider = OAuthClientProvider( + server_url="https://api.example.com", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, ) - request = await provider._discover_protected_resource(init_response) - assert request.method == "GET" - assert str(request.url) == "https://prm.example.com/.well-known/oauth-protected-resource/path" - assert "mcp-protocol-version" in request.headers + urls = provider._get_protected_resource_discovery_urls() + assert urls == ["https://api.example.com/.well-known/oauth-protected-resource"] @pytest.mark.anyio def test_create_oauth_metadata_request(self, oauth_provider): From 2fe913483edf6d178f1449241d3ff6f497168340 Mon Sep 17 00:00:00 2001 From: Sid Murching Date: Mon, 21 Jul 2025 14:49:15 -0700 Subject: [PATCH 5/5] Add tests Signed-off-by: Sid Murching --- tests/client/test_auth.py | 179 +++++++++++++++++++++++++++++++++++++- 1 file changed, 175 insertions(+), 4 deletions(-) diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 07651731c..86caae84f 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -198,8 +198,8 @@ class TestOAuthFlow: """Test OAuth flow methods.""" @pytest.mark.anyio - async def test_protected_resource_discovery_urls(self, client_metadata, mock_storage): - """Test protected resource discovery URL generation with fallback.""" + async def test_protected_resource_discovery_urls_generation(self, client_metadata, mock_storage): + """Test that discovery URL generation works correctly for different server URLs.""" async def redirect_handler(url: str) -> None: pass @@ -207,7 +207,7 @@ async def redirect_handler(url: str) -> None: async def callback_handler() -> tuple[str, str | None]: return "test_auth_code", "test_state" - # Test with path component + # Test with path component - should have both path-specific and base endpoints provider = OAuthClientProvider( server_url="https://api.example.com/api/2.0/mcp", client_metadata=client_metadata, @@ -222,7 +222,7 @@ async def callback_handler() -> tuple[str, str | None]: "https://api.example.com/.well-known/oauth-protected-resource", ] - # Test without path component + # Test without path component - should only have base endpoint provider = OAuthClientProvider( server_url="https://api.example.com", client_metadata=client_metadata, @@ -594,6 +594,177 @@ async def test_auth_flow_with_no_tokens(self, oauth_provider, mock_storage): assert oauth_provider.context.current_tokens.access_token == "new_access_token" assert oauth_provider.context.token_expiry_time is not None + @pytest.mark.anyio + async def test_auth_flow_protected_resource_fallback(self, client_metadata, mock_storage): + """Test that the OAuth flow correctly implements fallback from path-specific to base endpoint.""" + + async def redirect_handler(url: str) -> None: + pass + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" + + provider = OAuthClientProvider( + server_url="https://api.example.com/api/2.0/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + + provider.context.current_tokens = None + provider.context.token_expiry_time = None + provider._initialized = True + + test_request = httpx.Request("GET", "https://api.example.com/api/2.0/mcp") + auth_flow = provider.async_auth_flow(test_request) + + # Step 1: Original request without auth + request = await auth_flow.__anext__() + assert "Authorization" not in request.headers + + # Step 2: 401 triggers protected resource discovery - should try path-specific first + response = httpx.Response(401, request=test_request) + path_discovery_request = await auth_flow.asend(response) + assert ( + str(path_discovery_request.url) + == "https://api.example.com/.well-known/oauth-protected-resource/api/2.0/mcp" + ) + + # Step 3: Path-specific fails with 404 - should trigger fallback + path_404_response = httpx.Response(404, request=path_discovery_request) + base_discovery_request = await auth_flow.asend(path_404_response) + assert str(base_discovery_request.url) == "https://api.example.com/.well-known/oauth-protected-resource" + + # Step 4: Base endpoint succeeds - should store metadata and continue to OAuth discovery + successful_response = httpx.Response( + 200, + content=b'{"resource": "https://api.example.com", "authorization_servers": ["https://api.example.com"]}', + request=base_discovery_request, + ) + + # Verify the fallback worked and metadata was stored + await auth_flow.asend(successful_response) + assert provider.context.protected_resource_metadata is not None + assert str(provider.context.protected_resource_metadata.resource) == "https://api.example.com/" + + # Clean up the generator + try: + await auth_flow.aclose() + except Exception: + pass + + @pytest.mark.anyio + async def test_auth_flow_www_authenticate_no_fallback(self, client_metadata, mock_storage): + """Test that WWW-Authenticate header skips fallback logic entirely.""" + + async def redirect_handler(url: str) -> None: + pass + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" + + provider = OAuthClientProvider( + server_url="https://api.example.com/api/2.0/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + + provider.context.current_tokens = None + provider.context.token_expiry_time = None + provider._initialized = True + + test_request = httpx.Request("GET", "https://api.example.com/api/2.0/mcp") + auth_flow = provider.async_auth_flow(test_request) + + # Step 1: Original request without auth + request = await auth_flow.__anext__() + assert "Authorization" not in request.headers + + # Step 2: 401 with WWW-Authenticate should use that URL directly + response = httpx.Response( + 401, + headers={ + "WWW-Authenticate": 'Bearer resource_metadata="https://custom.example.com/.well-known/oauth-protected-resource"' + }, + request=test_request, + ) + + www_auth_request = await auth_flow.asend(response) + assert str(www_auth_request.url) == "https://custom.example.com/.well-known/oauth-protected-resource" + + # Step 3: Should proceed directly to OAuth metadata discovery (no fallback attempted) + successful_response = httpx.Response( + 200, + content=b'{"resource": "https://api.example.com/api/2.0/mcp", "authorization_servers": ["https://api.example.com"]}', + request=www_auth_request, + ) + + await auth_flow.asend(successful_response) + assert provider.context.protected_resource_metadata is not None + + # Clean up the generator + try: + await auth_flow.aclose() + except Exception: + pass + + @pytest.mark.anyio + async def test_auth_flow_no_fallback_on_success(self, client_metadata, mock_storage): + """Test that first successful discovery response stops the fallback process.""" + + async def redirect_handler(url: str) -> None: + pass + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" + + provider = OAuthClientProvider( + server_url="https://api.example.com/api/2.0/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + + provider.context.current_tokens = None + provider.context.token_expiry_time = None + provider._initialized = True + + test_request = httpx.Request("GET", "https://api.example.com/api/2.0/mcp") + auth_flow = provider.async_auth_flow(test_request) + + # Step 1: Original request without auth + request = await auth_flow.__anext__() + assert "Authorization" not in request.headers + + # Step 2: 401 triggers path-specific discovery + response = httpx.Response(401, request=test_request) + path_discovery_request = await auth_flow.asend(response) + assert ( + str(path_discovery_request.url) + == "https://api.example.com/.well-known/oauth-protected-resource/api/2.0/mcp" + ) + + # Step 3: Path-specific succeeds - should skip fallback and go to OAuth discovery + successful_response = httpx.Response( + 200, + content=b'{"resource": "https://api.example.com/api/2.0/mcp", "authorization_servers": ["https://api.example.com"]}', + request=path_discovery_request, + ) + + await auth_flow.asend(successful_response) + assert provider.context.protected_resource_metadata is not None + assert str(provider.context.protected_resource_metadata.resource) == "https://api.example.com/api/2.0/mcp" + + # Clean up the generator + try: + await auth_flow.aclose() + except Exception: + pass + @pytest.mark.parametrize( (