diff --git a/.gitignore b/.gitignore index e9fdca176..7c80ae585 100644 --- a/.gitignore +++ b/.gitignore @@ -168,3 +168,4 @@ cython_debug/ .vscode/ .windsurfrules **/CLAUDE.local.md +.idea diff --git a/src/mcp/client/auth.py b/src/mcp/client/auth.py index b00db7b9b..d7794e79c 100644 --- a/src/mcp/client/auth.py +++ b/src/mcp/client/auth.py @@ -108,10 +108,6 @@ class OAuthContext: # State lock: anyio.Lock = field(default_factory=anyio.Lock) - # Discovery state for fallback support - discovery_base_url: str | None = None - discovery_pathname: str | None = None - def get_authorization_base_url(self, server_url: str) -> str: """Extract base URL by removing path component.""" parsed = urlparse(server_url) @@ -228,16 +224,23 @@ def _extract_resource_metadata_from_www_auth(self, init_response: httpx.Response return None - async def _discover_protected_resource(self, init_response: httpx.Response) -> httpx.Request: - # RFC9728: Try to extract resource_metadata URL from WWW-Authenticate header of the initial response - url = self._extract_resource_metadata_from_www_auth(init_response) + def _get_protected_resource_discovery_urls(self) -> list[str]: + """Generate ordered list of URLs for protected resource discovery attempts.""" + urls: list[str] = [] + parsed = urlparse(self.context.server_url) + base_url = f"{parsed.scheme}://{parsed.netloc}" - if not url: - # Fallback to well-known discovery - auth_base_url = self.context.get_authorization_base_url(self.context.server_url) - url = urljoin(auth_base_url, "/.well-known/oauth-protected-resource") + if parsed.path and parsed.path != "/": + # Try path-specific endpoint first + path_component = parsed.path.rstrip("/") + urls.append(urljoin(base_url, f"/.well-known/oauth-protected-resource{path_component}")) + # Then fallback to base endpoint + urls.append(urljoin(base_url, "/.well-known/oauth-protected-resource")) + else: + # No path, just use base endpoint + urls.append(urljoin(base_url, "/.well-known/oauth-protected-resource")) - return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}) + return urls async def _handle_protected_resource_response(self, response: httpx.Response) -> None: """Handle discovery response.""" @@ -510,9 +513,28 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. try: # OAuth flow must be inline due to generator constraints # Step 1: Discover protected resource metadata (RFC9728 with WWW-Authenticate support) - discovery_request = await self._discover_protected_resource(response) - discovery_response = yield discovery_request - await self._handle_protected_resource_response(discovery_response) + # Check if WWW-Authenticate provides resource_metadata URL first + www_auth_url = self._extract_resource_metadata_from_www_auth(response) + if www_auth_url: + discovery_request = httpx.Request( + "GET", www_auth_url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION} + ) + discovery_response = yield discovery_request + await self._handle_protected_resource_response(discovery_response) + else: + # Try well-known discovery URLs with fallback + discovery_urls = self._get_protected_resource_discovery_urls() + for url in discovery_urls: + discovery_request = httpx.Request( + "GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION} + ) + discovery_response = yield discovery_request + + if discovery_response.status_code == 200: + await self._handle_protected_resource_response(discovery_response) + break # Success, stop trying other URLs + elif discovery_response.status_code != 404: + break # Non-404 error, stop trying # Step 2: Discover OAuth metadata (with fallback for legacy servers) discovery_urls = self._get_discovery_urls() diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 46208d69c..86caae84f 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -198,8 +198,8 @@ class TestOAuthFlow: """Test OAuth flow methods.""" @pytest.mark.anyio - async def test_discover_protected_resource_request(self, client_metadata, mock_storage): - """Test protected resource discovery request building maintains backward compatibility.""" + async def test_protected_resource_discovery_urls_generation(self, client_metadata, mock_storage): + """Test that discovery URL generation works correctly for different server URLs.""" async def redirect_handler(url: str) -> None: pass @@ -207,33 +207,32 @@ async def redirect_handler(url: str) -> None: async def callback_handler() -> tuple[str, str | None]: return "test_auth_code", "test_state" + # Test with path component - should have both path-specific and base endpoints provider = OAuthClientProvider( - server_url="https://api.example.com", + server_url="https://api.example.com/api/2.0/mcp", client_metadata=client_metadata, storage=mock_storage, redirect_handler=redirect_handler, callback_handler=callback_handler, ) - # Test without WWW-Authenticate (fallback) - init_response = httpx.Response( - status_code=401, headers={}, request=httpx.Request("GET", "https://request-api.example.com") - ) - - request = await provider._discover_protected_resource(init_response) - assert request.method == "GET" - assert str(request.url) == "https://api.example.com/.well-known/oauth-protected-resource" - assert "mcp-protocol-version" in request.headers + urls = provider._get_protected_resource_discovery_urls() + assert urls == [ + "https://api.example.com/.well-known/oauth-protected-resource/api/2.0/mcp", + "https://api.example.com/.well-known/oauth-protected-resource", + ] - # Test with WWW-Authenticate header - init_response.headers["WWW-Authenticate"] = ( - 'Bearer resource_metadata="https://prm.example.com/.well-known/oauth-protected-resource/path"' + # Test without path component - should only have base endpoint + provider = OAuthClientProvider( + server_url="https://api.example.com", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, ) - request = await provider._discover_protected_resource(init_response) - assert request.method == "GET" - assert str(request.url) == "https://prm.example.com/.well-known/oauth-protected-resource/path" - assert "mcp-protocol-version" in request.headers + urls = provider._get_protected_resource_discovery_urls() + assert urls == ["https://api.example.com/.well-known/oauth-protected-resource"] @pytest.mark.anyio def test_create_oauth_metadata_request(self, oauth_provider): @@ -595,6 +594,177 @@ async def test_auth_flow_with_no_tokens(self, oauth_provider, mock_storage): assert oauth_provider.context.current_tokens.access_token == "new_access_token" assert oauth_provider.context.token_expiry_time is not None + @pytest.mark.anyio + async def test_auth_flow_protected_resource_fallback(self, client_metadata, mock_storage): + """Test that the OAuth flow correctly implements fallback from path-specific to base endpoint.""" + + async def redirect_handler(url: str) -> None: + pass + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" + + provider = OAuthClientProvider( + server_url="https://api.example.com/api/2.0/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + + provider.context.current_tokens = None + provider.context.token_expiry_time = None + provider._initialized = True + + test_request = httpx.Request("GET", "https://api.example.com/api/2.0/mcp") + auth_flow = provider.async_auth_flow(test_request) + + # Step 1: Original request without auth + request = await auth_flow.__anext__() + assert "Authorization" not in request.headers + + # Step 2: 401 triggers protected resource discovery - should try path-specific first + response = httpx.Response(401, request=test_request) + path_discovery_request = await auth_flow.asend(response) + assert ( + str(path_discovery_request.url) + == "https://api.example.com/.well-known/oauth-protected-resource/api/2.0/mcp" + ) + + # Step 3: Path-specific fails with 404 - should trigger fallback + path_404_response = httpx.Response(404, request=path_discovery_request) + base_discovery_request = await auth_flow.asend(path_404_response) + assert str(base_discovery_request.url) == "https://api.example.com/.well-known/oauth-protected-resource" + + # Step 4: Base endpoint succeeds - should store metadata and continue to OAuth discovery + successful_response = httpx.Response( + 200, + content=b'{"resource": "https://api.example.com", "authorization_servers": ["https://api.example.com"]}', + request=base_discovery_request, + ) + + # Verify the fallback worked and metadata was stored + await auth_flow.asend(successful_response) + assert provider.context.protected_resource_metadata is not None + assert str(provider.context.protected_resource_metadata.resource) == "https://api.example.com/" + + # Clean up the generator + try: + await auth_flow.aclose() + except Exception: + pass + + @pytest.mark.anyio + async def test_auth_flow_www_authenticate_no_fallback(self, client_metadata, mock_storage): + """Test that WWW-Authenticate header skips fallback logic entirely.""" + + async def redirect_handler(url: str) -> None: + pass + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" + + provider = OAuthClientProvider( + server_url="https://api.example.com/api/2.0/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + + provider.context.current_tokens = None + provider.context.token_expiry_time = None + provider._initialized = True + + test_request = httpx.Request("GET", "https://api.example.com/api/2.0/mcp") + auth_flow = provider.async_auth_flow(test_request) + + # Step 1: Original request without auth + request = await auth_flow.__anext__() + assert "Authorization" not in request.headers + + # Step 2: 401 with WWW-Authenticate should use that URL directly + response = httpx.Response( + 401, + headers={ + "WWW-Authenticate": 'Bearer resource_metadata="https://custom.example.com/.well-known/oauth-protected-resource"' + }, + request=test_request, + ) + + www_auth_request = await auth_flow.asend(response) + assert str(www_auth_request.url) == "https://custom.example.com/.well-known/oauth-protected-resource" + + # Step 3: Should proceed directly to OAuth metadata discovery (no fallback attempted) + successful_response = httpx.Response( + 200, + content=b'{"resource": "https://api.example.com/api/2.0/mcp", "authorization_servers": ["https://api.example.com"]}', + request=www_auth_request, + ) + + await auth_flow.asend(successful_response) + assert provider.context.protected_resource_metadata is not None + + # Clean up the generator + try: + await auth_flow.aclose() + except Exception: + pass + + @pytest.mark.anyio + async def test_auth_flow_no_fallback_on_success(self, client_metadata, mock_storage): + """Test that first successful discovery response stops the fallback process.""" + + async def redirect_handler(url: str) -> None: + pass + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" + + provider = OAuthClientProvider( + server_url="https://api.example.com/api/2.0/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + + provider.context.current_tokens = None + provider.context.token_expiry_time = None + provider._initialized = True + + test_request = httpx.Request("GET", "https://api.example.com/api/2.0/mcp") + auth_flow = provider.async_auth_flow(test_request) + + # Step 1: Original request without auth + request = await auth_flow.__anext__() + assert "Authorization" not in request.headers + + # Step 2: 401 triggers path-specific discovery + response = httpx.Response(401, request=test_request) + path_discovery_request = await auth_flow.asend(response) + assert ( + str(path_discovery_request.url) + == "https://api.example.com/.well-known/oauth-protected-resource/api/2.0/mcp" + ) + + # Step 3: Path-specific succeeds - should skip fallback and go to OAuth discovery + successful_response = httpx.Response( + 200, + content=b'{"resource": "https://api.example.com/api/2.0/mcp", "authorization_servers": ["https://api.example.com"]}', + request=path_discovery_request, + ) + + await auth_flow.asend(successful_response) + assert provider.context.protected_resource_metadata is not None + assert str(provider.context.protected_resource_metadata.resource) == "https://api.example.com/api/2.0/mcp" + + # Clean up the generator + try: + await auth_flow.aclose() + except Exception: + pass + @pytest.mark.parametrize( (