Skip to content
Open
19 changes: 19 additions & 0 deletions docs/my-website/docs/mcp.md
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,25 @@ mcp_servers:
extra_headers: ["custom_key", "x-custom-header"] # These headers will be forwarded from client
```

### Static Headers

Sometimes your MCP server needs specific headers on every request. Maybe it's an API key, maybe it's a custom header the server expects. Instead of configuring auth, you can just set them directly.

```yaml title="config.yaml" showLineNumbers
mcp_servers:
my_mcp_server:
url: "https://my-mcp-server.com/mcp"
static_headers:
X-API-Key: "abc123"
X-Custom-Header: "some-value"
```

These headers get sent with every request to the server. That's it.

**When to use this:**
- Your server needs custom headers that don't fit the standard auth patterns
- You want full control over exactly what headers are sent
- You're debugging and need to quickly add headers without changing auth configuration

### MCP Aliases

Expand Down
5 changes: 2 additions & 3 deletions litellm/experimental_mcp_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,10 +279,9 @@ async def list_tools(self) -> List[MCPTool]:
await self.disconnect()
raise
except Exception as e:
verbose_logger.warning(f"MCP client list_tools failed: {str(e)}")
verbose_logger.debug(f"MCP client list_tools failed: {str(e)}")
await self.disconnect()
# Return empty list instead of raising to allow graceful degradation
return []
raise e

async def call_tool(
self, call_tool_request_params: MCPCallToolRequestParams
Expand Down
155 changes: 90 additions & 65 deletions litellm/proxy/_experimental/mcp_server/mcp_server_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ def load_servers_from_config(
disallowed_tools=server_config.get("disallowed_tools", None),
allowed_params=server_config.get("allowed_params", None),
access_groups=server_config.get("access_groups", None),
static_headers=server_config.get("static_headers", None),
)
self.config_mcp_servers[server_id] = new_server

Expand Down Expand Up @@ -632,6 +633,11 @@ async def _get_tools_from_server(
client = None

try:
if server.static_headers:
if extra_headers is None:
extra_headers = {}
extra_headers.update(server.static_headers)

client = self._create_mcp_client(
server=server,
mcp_auth_header=mcp_auth_header,
Expand All @@ -654,10 +660,10 @@ async def _get_tools_from_server(
return prefixed_or_original_tools

except Exception as e:
verbose_logger.warning(
verbose_logger.debug(
f"Failed to get tools from server {server.name}: {str(e)}"
)
return []
raise e
finally:
if client:
try:
Expand Down Expand Up @@ -686,14 +692,14 @@ async def _list_tools_task():
tools = await client.list_tools()
verbose_logger.debug(f"Tools from {server_name}: {tools}")
return tools
except asyncio.CancelledError:
verbose_logger.warning(f"Client operation cancelled for {server_name}")
return []
except asyncio.CancelledError as e:
verbose_logger.debug(f"Client operation cancelled for {server_name}")
raise e
except Exception as e:
verbose_logger.warning(
verbose_logger.debug(
f"Client operation failed for {server_name}: {str(e)}"
)
return []
raise e
finally:
try:
await client.disconnect()
Expand All @@ -702,22 +708,22 @@ async def _list_tools_task():

try:
return await asyncio.wait_for(_list_tools_task(), timeout=30.0)
except asyncio.TimeoutError:
verbose_logger.warning(f"Timeout while listing tools from {server_name}")
return []
except asyncio.CancelledError:
verbose_logger.warning(
except asyncio.TimeoutError as e:
verbose_logger.debug(f"Timeout while listing tools from {server_name}")
raise e
except asyncio.CancelledError as e:
verbose_logger.debug(
f"Task cancelled while listing tools from {server_name}"
)
return []
raise e
except ConnectionError as e:
verbose_logger.warning(
verbose_logger.debug(
f"Connection error while listing tools from {server_name}: {str(e)}"
)
return []
raise e
except Exception as e:
verbose_logger.warning(f"Error listing tools from {server_name}: {str(e)}")
return []
verbose_logger.debug(f"Error listing tools from {server_name}: {str(e)}")
raise e

def _create_prefixed_tools(
self, tools: List[MCPTool], server: MCPServer, add_prefix: bool = True
Expand Down Expand Up @@ -1097,14 +1103,7 @@ async def call_tool(
raise ValueError(f"Tool {name} not found")

# Validate that the server from prefix matches the actual server (if prefix was used)
if server_name_from_prefix:
expected_prefix = get_server_prefix(mcp_server)
if normalize_server_name(server_name_from_prefix) != normalize_server_name(
expected_prefix
):
raise ValueError(
f"Tool {name} server prefix mismatch: expected {expected_prefix}, got {server_name_from_prefix}"
)
self._validate_server_prefix_match(name, server_name_from_prefix, mcp_server)

#########################################################
# Pre MCP Tool Call Hook
Expand All @@ -1121,6 +1120,39 @@ async def call_tool(
server=mcp_server,
)

# Get server-specific auth header if available
server_auth_header: Optional[Union[Dict[str, str], str]] = None
if mcp_server_auth_headers and mcp_server.alias:
server_auth_header = mcp_server_auth_headers.get(mcp_server.alias)
elif mcp_server_auth_headers and mcp_server.server_name:
server_auth_header = mcp_server_auth_headers.get(mcp_server.server_name)

# Fall back to deprecated mcp_auth_header if no server-specific header found
if server_auth_header is None:
server_auth_header = mcp_auth_header

# oauth2 headers
extra_headers: Optional[Dict[str, str]] = None
if mcp_server.auth_type == MCPAuth.oauth2:
extra_headers = oauth2_headers

if mcp_server.extra_headers and raw_headers:
if extra_headers is None:
extra_headers = {}
for header in mcp_server.extra_headers:
if header in raw_headers:
extra_headers[header] = raw_headers[header]

if mcp_server.static_headers:
if extra_headers is None:
extra_headers = {}
extra_headers.update(mcp_server.static_headers)

client = self._create_mcp_client(
server=mcp_server,
mcp_auth_header=server_auth_header,
extra_headers=extra_headers,
)
# Prepare tasks for during hooks
tasks = []
if proxy_logging_obj:
Expand All @@ -1147,46 +1179,13 @@ async def call_tool(
else:
# For regular MCP servers, use the MCP client
# Get server-specific auth header if available
server_auth_header: Optional[Union[Dict[str, str], str]] = None
if mcp_server_auth_headers and mcp_server.alias:
server_auth_header = mcp_server_auth_headers.get(mcp_server.alias)
elif mcp_server_auth_headers and mcp_server.server_name:
server_auth_header = mcp_server_auth_headers.get(mcp_server.server_name)

# Fall back to deprecated mcp_auth_header if no server-specific header found
if server_auth_header is None:
server_auth_header = mcp_auth_header

# oauth2 headers
extra_headers: Optional[Dict[str, str]] = None
if mcp_server.auth_type == MCPAuth.oauth2:
extra_headers = oauth2_headers

if mcp_server.extra_headers and raw_headers:
if extra_headers is None:
extra_headers = {}
for header in mcp_server.extra_headers:
if header in raw_headers:
extra_headers[header] = raw_headers[header]

client = self._create_mcp_client(
server=mcp_server,
mcp_auth_header=server_auth_header,
extra_headers=extra_headers,
)

call_tool_params = MCPCallToolRequestParams(
name=original_tool_name,
arguments=arguments,
)

async def _call_tool_via_client(client, params):
async with client:
return await client.call_tool(params)

tasks.append(
asyncio.create_task(_call_tool_via_client(client, call_tool_params))
)
async with client:
# Use the original tool name (without prefix) for the actual call
call_tool_params = MCPCallToolRequestParams(
name=original_tool_name,
arguments=arguments,
)
tasks.append(asyncio.create_task(client.call_tool(call_tool_params)))

try:
mcp_responses = await asyncio.gather(*tasks)
Expand Down Expand Up @@ -1270,6 +1269,32 @@ def _get_mcp_server_from_tool_name(self, tool_name: str) -> Optional[MCPServer]:

return None

def _validate_server_prefix_match(
self,
tool_name: str,
server_name_from_prefix: Optional[str],
mcp_server: MCPServer,
) -> None:
"""
Validate that the server prefix from the tool name matches the actual server.

Args:
tool_name: Original tool name provided
server_name_from_prefix: Server name extracted from tool name prefix (if any)
mcp_server: The MCP server that was found for this tool

Raises:
ValueError: If the server prefix doesn't match the expected server
"""
if server_name_from_prefix:
expected_prefix = get_server_prefix(mcp_server)
if normalize_server_name(server_name_from_prefix) != normalize_server_name(
expected_prefix
):
raise ValueError(
f"Tool {tool_name} server prefix mismatch: expected {expected_prefix}, got {server_name_from_prefix}"
)

async def _add_mcp_servers_from_db_to_in_memory_registry(self):
from litellm.proxy._experimental.mcp_server.db import get_all_mcp_servers
from litellm.proxy.management_endpoints.mcp_management_endpoints import (
Expand Down
9 changes: 4 additions & 5 deletions litellm/proxy/_experimental/mcp_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,10 +206,8 @@ async def list_tools() -> List[MCPTool]:
)
return tools
except Exception as e:
verbose_logger.exception(f"Error in list_tools endpoint: {str(e)}")
# Return empty list instead of failing completely
# This prevents the HTTP stream from failing and allows the client to get a response
return []
verbose_logger.debug(f"Error in list_tools endpoint: {str(e)}")
raise e

@server.call_tool()
async def mcp_server_tool_call(
Expand Down Expand Up @@ -514,10 +512,11 @@ async def _get_tools_from_mcp_servers(
f"Successfully fetched {len(tools)} tools from server {server.name}, {len(filtered_tools)} after filtering"
)
except Exception as e:
verbose_logger.exception(
verbose_logger.debug(
f"Error getting tools from server {server.name}: {str(e)}"
)
# Continue with other servers instead of failing completely
raise e

verbose_logger.info(
f"Successfully fetched {len(all_tools)} tools total from all MCP servers"
Expand Down
5 changes: 4 additions & 1 deletion litellm/types/mcp_server/mcp_server_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class MCPServer(BaseModel):
authentication_token: Optional[str] = None
mcp_info: Optional[MCPInfo] = None
extra_headers: Optional[List[str]] = (
None # allow admin to specify which headers to forward to the MCP server
None # allow admin to specify which headers to forward from client to the MCP server
)
allowed_tools: Optional[List[str]] = None
disallowed_tools: Optional[List[str]] = None
Expand All @@ -40,4 +40,7 @@ class MCPServer(BaseModel):
args: Optional[List[str]] = None
env: Optional[Dict[str, str]] = None
access_groups: Optional[List[str]] = None
static_headers: Optional[Dict[str, str]] = (
None # static headers to forward to the MCP server
)
model_config = ConfigDict(arbitrary_types_allowed=True)
Loading
Loading