diff --git a/docs/my-website/docs/mcp.md b/docs/my-website/docs/mcp.md index 6333f7f09131..d143c02ac5c1 100644 --- a/docs/my-website/docs/mcp.md +++ b/docs/my-website/docs/mcp.md @@ -219,6 +219,25 @@ mcp_servers: extra_headers: ["custom_key", "x-custom-header"] # These headers will be forwarded from client ``` +### Static Headers + +Sometimes your MCP server needs specific headers on every request. Maybe it's an API key, maybe it's a custom header the server expects. Instead of configuring auth, you can just set them directly. + +```yaml title="config.yaml" showLineNumbers +mcp_servers: + my_mcp_server: + url: "https://my-mcp-server.com/mcp" + static_headers: + X-API-Key: "abc123" + X-Custom-Header: "some-value" +``` + +These headers get sent with every request to the server. That's it. + +**When to use this:** +- Your server needs custom headers that don't fit the standard auth patterns +- You want full control over exactly what headers are sent +- You're debugging and need to quickly add headers without changing auth configuration ### MCP Aliases diff --git a/litellm/experimental_mcp_client/client.py b/litellm/experimental_mcp_client/client.py index b10ddc9e8126..5a31749d46ac 100644 --- a/litellm/experimental_mcp_client/client.py +++ b/litellm/experimental_mcp_client/client.py @@ -279,10 +279,9 @@ async def list_tools(self) -> List[MCPTool]: await self.disconnect() raise except Exception as e: - verbose_logger.warning(f"MCP client list_tools failed: {str(e)}") + verbose_logger.debug(f"MCP client list_tools failed: {str(e)}") await self.disconnect() - # Return empty list instead of raising to allow graceful degradation - return [] + raise e async def call_tool( self, call_tool_request_params: MCPCallToolRequestParams diff --git a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py index 28581c778c6a..28655663a592 100644 --- a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py +++ b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py @@ -218,6 +218,7 @@ def load_servers_from_config( disallowed_tools=server_config.get("disallowed_tools", None), allowed_params=server_config.get("allowed_params", None), access_groups=server_config.get("access_groups", None), + static_headers=server_config.get("static_headers", None), ) self.config_mcp_servers[server_id] = new_server @@ -632,6 +633,11 @@ async def _get_tools_from_server( client = None try: + if server.static_headers: + if extra_headers is None: + extra_headers = {} + extra_headers.update(server.static_headers) + client = self._create_mcp_client( server=server, mcp_auth_header=mcp_auth_header, @@ -654,10 +660,10 @@ async def _get_tools_from_server( return prefixed_or_original_tools except Exception as e: - verbose_logger.warning( + verbose_logger.debug( f"Failed to get tools from server {server.name}: {str(e)}" ) - return [] + raise e finally: if client: try: @@ -686,14 +692,14 @@ async def _list_tools_task(): tools = await client.list_tools() verbose_logger.debug(f"Tools from {server_name}: {tools}") return tools - except asyncio.CancelledError: - verbose_logger.warning(f"Client operation cancelled for {server_name}") - return [] + except asyncio.CancelledError as e: + verbose_logger.debug(f"Client operation cancelled for {server_name}") + raise e except Exception as e: - verbose_logger.warning( + verbose_logger.debug( f"Client operation failed for {server_name}: {str(e)}" ) - return [] + raise e finally: try: await client.disconnect() @@ -702,22 +708,22 @@ async def _list_tools_task(): try: return await asyncio.wait_for(_list_tools_task(), timeout=30.0) - except asyncio.TimeoutError: - verbose_logger.warning(f"Timeout while listing tools from {server_name}") - return [] - except asyncio.CancelledError: - verbose_logger.warning( + except asyncio.TimeoutError as e: + verbose_logger.debug(f"Timeout while listing tools from {server_name}") + raise e + except asyncio.CancelledError as e: + verbose_logger.debug( f"Task cancelled while listing tools from {server_name}" ) - return [] + raise e except ConnectionError as e: - verbose_logger.warning( + verbose_logger.debug( f"Connection error while listing tools from {server_name}: {str(e)}" ) - return [] + raise e except Exception as e: - verbose_logger.warning(f"Error listing tools from {server_name}: {str(e)}") - return [] + verbose_logger.debug(f"Error listing tools from {server_name}: {str(e)}") + raise e def _create_prefixed_tools( self, tools: List[MCPTool], server: MCPServer, add_prefix: bool = True @@ -1097,14 +1103,7 @@ async def call_tool( raise ValueError(f"Tool {name} not found") # Validate that the server from prefix matches the actual server (if prefix was used) - if server_name_from_prefix: - expected_prefix = get_server_prefix(mcp_server) - if normalize_server_name(server_name_from_prefix) != normalize_server_name( - expected_prefix - ): - raise ValueError( - f"Tool {name} server prefix mismatch: expected {expected_prefix}, got {server_name_from_prefix}" - ) + self._validate_server_prefix_match(name, server_name_from_prefix, mcp_server) ######################################################### # Pre MCP Tool Call Hook @@ -1121,6 +1120,39 @@ async def call_tool( server=mcp_server, ) + # Get server-specific auth header if available + server_auth_header: Optional[Union[Dict[str, str], str]] = None + if mcp_server_auth_headers and mcp_server.alias: + server_auth_header = mcp_server_auth_headers.get(mcp_server.alias) + elif mcp_server_auth_headers and mcp_server.server_name: + server_auth_header = mcp_server_auth_headers.get(mcp_server.server_name) + + # Fall back to deprecated mcp_auth_header if no server-specific header found + if server_auth_header is None: + server_auth_header = mcp_auth_header + + # oauth2 headers + extra_headers: Optional[Dict[str, str]] = None + if mcp_server.auth_type == MCPAuth.oauth2: + extra_headers = oauth2_headers + + if mcp_server.extra_headers and raw_headers: + if extra_headers is None: + extra_headers = {} + for header in mcp_server.extra_headers: + if header in raw_headers: + extra_headers[header] = raw_headers[header] + + if mcp_server.static_headers: + if extra_headers is None: + extra_headers = {} + extra_headers.update(mcp_server.static_headers) + + client = self._create_mcp_client( + server=mcp_server, + mcp_auth_header=server_auth_header, + extra_headers=extra_headers, + ) # Prepare tasks for during hooks tasks = [] if proxy_logging_obj: @@ -1147,46 +1179,13 @@ async def call_tool( else: # For regular MCP servers, use the MCP client # Get server-specific auth header if available - server_auth_header: Optional[Union[Dict[str, str], str]] = None - if mcp_server_auth_headers and mcp_server.alias: - server_auth_header = mcp_server_auth_headers.get(mcp_server.alias) - elif mcp_server_auth_headers and mcp_server.server_name: - server_auth_header = mcp_server_auth_headers.get(mcp_server.server_name) - - # Fall back to deprecated mcp_auth_header if no server-specific header found - if server_auth_header is None: - server_auth_header = mcp_auth_header - - # oauth2 headers - extra_headers: Optional[Dict[str, str]] = None - if mcp_server.auth_type == MCPAuth.oauth2: - extra_headers = oauth2_headers - - if mcp_server.extra_headers and raw_headers: - if extra_headers is None: - extra_headers = {} - for header in mcp_server.extra_headers: - if header in raw_headers: - extra_headers[header] = raw_headers[header] - - client = self._create_mcp_client( - server=mcp_server, - mcp_auth_header=server_auth_header, - extra_headers=extra_headers, - ) - - call_tool_params = MCPCallToolRequestParams( - name=original_tool_name, - arguments=arguments, - ) - - async def _call_tool_via_client(client, params): - async with client: - return await client.call_tool(params) - - tasks.append( - asyncio.create_task(_call_tool_via_client(client, call_tool_params)) - ) + async with client: + # Use the original tool name (without prefix) for the actual call + call_tool_params = MCPCallToolRequestParams( + name=original_tool_name, + arguments=arguments, + ) + tasks.append(asyncio.create_task(client.call_tool(call_tool_params))) try: mcp_responses = await asyncio.gather(*tasks) @@ -1270,6 +1269,32 @@ def _get_mcp_server_from_tool_name(self, tool_name: str) -> Optional[MCPServer]: return None + def _validate_server_prefix_match( + self, + tool_name: str, + server_name_from_prefix: Optional[str], + mcp_server: MCPServer, + ) -> None: + """ + Validate that the server prefix from the tool name matches the actual server. + + Args: + tool_name: Original tool name provided + server_name_from_prefix: Server name extracted from tool name prefix (if any) + mcp_server: The MCP server that was found for this tool + + Raises: + ValueError: If the server prefix doesn't match the expected server + """ + if server_name_from_prefix: + expected_prefix = get_server_prefix(mcp_server) + if normalize_server_name(server_name_from_prefix) != normalize_server_name( + expected_prefix + ): + raise ValueError( + f"Tool {tool_name} server prefix mismatch: expected {expected_prefix}, got {server_name_from_prefix}" + ) + async def _add_mcp_servers_from_db_to_in_memory_registry(self): from litellm.proxy._experimental.mcp_server.db import get_all_mcp_servers from litellm.proxy.management_endpoints.mcp_management_endpoints import ( diff --git a/litellm/proxy/_experimental/mcp_server/server.py b/litellm/proxy/_experimental/mcp_server/server.py index 77d6abfed62b..d5879289ae16 100644 --- a/litellm/proxy/_experimental/mcp_server/server.py +++ b/litellm/proxy/_experimental/mcp_server/server.py @@ -206,10 +206,8 @@ async def list_tools() -> List[MCPTool]: ) return tools except Exception as e: - verbose_logger.exception(f"Error in list_tools endpoint: {str(e)}") - # Return empty list instead of failing completely - # This prevents the HTTP stream from failing and allows the client to get a response - return [] + verbose_logger.debug(f"Error in list_tools endpoint: {str(e)}") + raise e @server.call_tool() async def mcp_server_tool_call( @@ -514,10 +512,11 @@ async def _get_tools_from_mcp_servers( f"Successfully fetched {len(tools)} tools from server {server.name}, {len(filtered_tools)} after filtering" ) except Exception as e: - verbose_logger.exception( + verbose_logger.debug( f"Error getting tools from server {server.name}: {str(e)}" ) # Continue with other servers instead of failing completely + raise e verbose_logger.info( f"Successfully fetched {len(all_tools)} tools total from all MCP servers" diff --git a/litellm/types/mcp_server/mcp_server_manager.py b/litellm/types/mcp_server/mcp_server_manager.py index 03bee94eb8b9..b0ddc017d174 100644 --- a/litellm/types/mcp_server/mcp_server_manager.py +++ b/litellm/types/mcp_server/mcp_server_manager.py @@ -22,7 +22,7 @@ class MCPServer(BaseModel): authentication_token: Optional[str] = None mcp_info: Optional[MCPInfo] = None extra_headers: Optional[List[str]] = ( - None # allow admin to specify which headers to forward to the MCP server + None # allow admin to specify which headers to forward from client to the MCP server ) allowed_tools: Optional[List[str]] = None disallowed_tools: Optional[List[str]] = None @@ -40,4 +40,7 @@ class MCPServer(BaseModel): args: Optional[List[str]] = None env: Optional[Dict[str, str]] = None access_groups: Optional[List[str]] = None + static_headers: Optional[Dict[str, str]] = ( + None # static headers to forward to the MCP server + ) model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py index 769785b92145..85f842d25683 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py @@ -69,172 +69,6 @@ async def mock_call_mcp_tool(*args, **kwargs): assert body["arguments"] == tool_arguments -@pytest.mark.asyncio -async def test_get_tools_from_mcp_servers_continues_when_one_server_fails(): - """Test that _get_tools_from_mcp_servers continues when one server fails""" - try: - from litellm.proxy._experimental.mcp_server.server import ( - _get_tools_from_mcp_servers, - set_auth_context, - ) - except ImportError: - pytest.skip("MCP server not available") - - # Mock user auth - user_api_key_auth = UserAPIKeyAuth(api_key="test_key", user_id="test_user") - set_auth_context(user_api_key_auth) - - # Mock servers - working_server = MagicMock() - working_server.name = "working_server" - working_server.alias = "working" - working_server.allowed_tools = None - working_server.disallowed_tools = None - - failing_server = MagicMock() - failing_server.name = "failing_server" - failing_server.alias = "failing" - failing_server.allowed_tools = None - failing_server.disallowed_tools = None - - # Mock global_mcp_server_manager - mock_manager = MagicMock() - mock_manager.get_allowed_mcp_servers = AsyncMock( - return_value=["working_server", "failing_server"] - ) - mock_manager.get_mcp_server_by_id = lambda server_id: ( - working_server if server_id == "working_server" else failing_server - ) - - async def mock_get_tools_from_server( - server, mcp_auth_header=None, extra_headers=None, add_prefix=True - ): - if server.name == "working_server": - # Working server returns tools - tool1 = MagicMock() - tool1.name = "working_tool_1" - tool1.description = "Working tool 1" - tool1.inputSchema = {} - return [tool1] - else: - # Failing server raises an exception - raise Exception("Server connection failed") - - mock_manager._get_tools_from_server = mock_get_tools_from_server - - with patch( - "litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager", - mock_manager, - ): - with patch( - "litellm.proxy._experimental.mcp_server.server.verbose_logger", - ) as mock_logger: - # Test with server-specific auth headers - mcp_server_auth_headers = { - "working": "Bearer working-token", - "failing": "Bearer failing-token", - } - - result = await _get_tools_from_mcp_servers( - user_api_key_auth=user_api_key_auth, - mcp_auth_header=None, - mcp_servers=None, - mcp_server_auth_headers=mcp_server_auth_headers, - ) - - # Verify that tools from the working server are returned - assert len(result) == 1 - assert result[0].name == "working_tool_1" - - # Verify failure logging - mock_logger.exception.assert_any_call( - "Error getting tools from server failing_server: Server connection failed" - ) - - # Verify success logging - mock_logger.info.assert_any_call( - "Successfully fetched 1 tools total from all MCP servers" - ) - - -@pytest.mark.asyncio -async def test_get_tools_from_mcp_servers_handles_all_servers_failing(): - """Test that _get_tools_from_mcp_servers handles all servers failing gracefully""" - try: - from litellm.proxy._experimental.mcp_server.server import ( - _get_tools_from_mcp_servers, - set_auth_context, - ) - except ImportError: - pytest.skip("MCP server not available") - - # Mock user auth - user_api_key_auth = UserAPIKeyAuth(api_key="test_key", user_id="test_user") - set_auth_context(user_api_key_auth) - - # Mock servers - failing_server1 = MagicMock() - failing_server1.name = "failing_server1" - failing_server1.alias = "failing1" - - failing_server2 = MagicMock() - failing_server2.name = "failing_server2" - failing_server2.alias = "failing2" - - # Mock global_mcp_server_manager - mock_manager = MagicMock() - mock_manager.get_allowed_mcp_servers = AsyncMock( - return_value=["failing_server1", "failing_server2"] - ) - mock_manager.get_mcp_server_by_id = lambda server_id: ( - failing_server1 if server_id == "failing_server1" else failing_server2 - ) - - async def mock_get_tools_from_server( - server, mcp_auth_header=None, extra_headers=None, add_prefix=True - ): - # All servers fail - raise Exception(f"Server {server.name} connection failed") - - mock_manager._get_tools_from_server = mock_get_tools_from_server - - with patch( - "litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager", - mock_manager, - ): - with patch( - "litellm.proxy._experimental.mcp_server.server.verbose_logger", - ) as mock_logger: - # Test with server-specific auth headers - mcp_server_auth_headers = { - "failing1": "Bearer failing1-token", - "failing2": "Bearer failing2-token", - } - - result = await _get_tools_from_mcp_servers( - user_api_key_auth=user_api_key_auth, - mcp_auth_header=None, - mcp_servers=None, - mcp_server_auth_headers=mcp_server_auth_headers, - ) - - # Verify that empty list is returned - assert len(result) == 0 - - # Verify failure logging for both servers - mock_logger.exception.assert_any_call( - "Error getting tools from server failing_server1: Server failing_server1 connection failed" - ) - mock_logger.exception.assert_any_call( - "Error getting tools from server failing_server2: Server failing_server2 connection failed" - ) - - # Verify total logging - mock_logger.info.assert_any_call( - "Successfully fetched 0 tools total from all MCP servers" - ) - - @pytest.mark.asyncio async def test_mcp_server_tool_call_body_with_none_arguments(): """Test that proxy_server_request body handles None arguments correctly"""