diff --git a/README.md b/README.md index 993b6006b..2adcbba0c 100644 --- a/README.md +++ b/README.md @@ -1125,7 +1125,7 @@ server = Server("example-server", lifespan=server_lifespan) @server.list_tools() -async def handle_list_tools() -> list[types.Tool]: +async def handle_list_tools(_) -> list[types.Tool]: """List available tools.""" return [ types.Tool( @@ -1207,7 +1207,7 @@ server = Server("example-server") @server.list_prompts() -async def handle_list_prompts() -> list[types.Prompt]: +async def handle_list_prompts(_) -> list[types.Prompt]: """List available prompts.""" return [ types.Prompt( @@ -1286,7 +1286,7 @@ server = Server("example-server") @server.list_tools() -async def list_tools() -> list[types.Tool]: +async def list_tools(_) -> list[types.Tool]: """List available tools with structured output schemas.""" return [ types.Tool( diff --git a/examples/servers/simple-prompt/mcp_simple_prompt/server.py b/examples/servers/simple-prompt/mcp_simple_prompt/server.py index b562cc932..e62c66f42 100644 --- a/examples/servers/simple-prompt/mcp_simple_prompt/server.py +++ b/examples/servers/simple-prompt/mcp_simple_prompt/server.py @@ -49,7 +49,7 @@ def main(port: int, transport: str) -> int: app = Server("mcp-simple-prompt") @app.list_prompts() - async def list_prompts() -> list[types.Prompt]: + async def list_prompts(_) -> list[types.Prompt]: return [ types.Prompt( name="simple", diff --git a/examples/servers/simple-resource/mcp_simple_resource/server.py b/examples/servers/simple-resource/mcp_simple_resource/server.py index cef29b851..43bcaba5b 100644 --- a/examples/servers/simple-resource/mcp_simple_resource/server.py +++ b/examples/servers/simple-resource/mcp_simple_resource/server.py @@ -32,7 +32,7 @@ def main(port: int, transport: str) -> int: app = Server("mcp-simple-resource") @app.list_resources() - async def list_resources() -> list[types.Resource]: + async def list_resources(_) -> list[types.Resource]: return [ types.Resource( uri=FileUrl(f"file:///{name}.txt"), diff --git a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py index 68f3ac6a6..b5da75c05 100644 --- a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py +++ b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py @@ -69,7 +69,7 @@ async def call_tool(name: str, arguments: dict) -> list[types.ContentBlock]: ] @app.list_tools() - async def list_tools() -> list[types.Tool]: + async def list_tools(_) -> list[types.Tool]: return [ types.Tool( name="start-notification-stream", diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py index 9c25cc569..b69f8071b 100644 --- a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py @@ -87,7 +87,7 @@ async def call_tool(name: str, arguments: dict) -> list[types.ContentBlock]: ] @app.list_tools() - async def list_tools() -> list[types.Tool]: + async def list_tools(_) -> list[types.Tool]: return [ types.Tool( name="start-notification-stream", diff --git a/examples/servers/simple-tool/mcp_simple_tool/server.py b/examples/servers/simple-tool/mcp_simple_tool/server.py index bf3683c9e..d7ac21c8d 100644 --- a/examples/servers/simple-tool/mcp_simple_tool/server.py +++ b/examples/servers/simple-tool/mcp_simple_tool/server.py @@ -37,7 +37,7 @@ async def fetch_tool(name: str, arguments: dict) -> list[types.ContentBlock]: return await fetch_website(arguments["url"]) @app.list_tools() - async def list_tools() -> list[types.Tool]: + async def list_tools(_) -> list[types.Tool]: return [ types.Tool( name="fetch", diff --git a/examples/servers/structured_output_lowlevel.py b/examples/servers/structured_output_lowlevel.py index 7f102ff8b..db57f958a 100644 --- a/examples/servers/structured_output_lowlevel.py +++ b/examples/servers/structured_output_lowlevel.py @@ -21,7 +21,7 @@ @server.list_tools() -async def list_tools() -> list[types.Tool]: +async def list_tools(_) -> list[types.Tool]: """List available tools with their schemas.""" return [ types.Tool( diff --git a/examples/snippets/servers/hierarchical_organization.py b/examples/snippets/servers/hierarchical_organization.py new file mode 100644 index 000000000..2e6cd49c7 --- /dev/null +++ b/examples/snippets/servers/hierarchical_organization.py @@ -0,0 +1,182 @@ +"""Example demonstrating hierarchical organization of tools, prompts, and resources using custom URIs. + +This example shows how to: +1. Register tools, prompts, and resources with hierarchical URIs +2. Create group discovery resources at well-known URIs +3. Filter items by URI paths for better organization +""" + +import json +from typing import cast + +from pydantic import AnyUrl + +from mcp.server.fastmcp import FastMCP +from mcp.types import ListFilters, TextContent, TextResourceContents + +# Create FastMCP server instance +mcp = FastMCP("hierarchical-example") + + +# Group discovery resources +@mcp.resource("mcp://groups/tools") +def get_tool_groups() -> str: + """Discover available tool groups.""" + return json.dumps( + { + "groups": [ + {"name": "math", "description": "Mathematical operations", "uri_paths": ["mcp://tools/math/"]}, + {"name": "string", "description": "String manipulation", "uri_paths": ["mcp://tools/string/"]}, + ] + }, + indent=2, + ) + + +@mcp.resource("mcp://groups/prompts") +def get_prompt_groups() -> str: + """Discover available prompt groups.""" + return json.dumps( + { + "groups": [ + {"name": "greetings", "description": "Greeting prompts", "uri_paths": ["mcp://prompts/greetings/"]}, + { + "name": "instructions", + "description": "Instructional prompts", + "uri_paths": ["mcp://prompts/instructions/"], + }, + ] + }, + indent=2, + ) + + +# Math tools organized under mcp://tools/math/ +@mcp.tool(uri="mcp://tools/math/add") +def add(a: float, b: float) -> float: + """Add two numbers.""" + return a + b + + +@mcp.tool(uri="mcp://tools/math/multiply") +def multiply(a: float, b: float) -> float: + """Multiply two numbers.""" + return a * b + + +# String tools organized under mcp://tools/string/ +@mcp.tool(uri="mcp://tools/string/reverse") +def reverse(text: str) -> str: + """Reverse a string.""" + return text[::-1] + + +@mcp.tool(uri="mcp://tools/string/upper") +def upper(text: str) -> str: + """Convert to uppercase.""" + return text.upper() + + +# Greeting prompts organized under mcp://prompts/greetings/ +@mcp.prompt(uri="mcp://prompts/greetings/hello") +def hello_prompt(name: str) -> str: + """Generate a hello greeting.""" + return f"Hello, {name}! How can I help you today?" + + +@mcp.prompt(uri="mcp://prompts/greetings/goodbye") +def goodbye_prompt(name: str) -> str: + """Generate a goodbye message.""" + return f"Goodbye, {name}! Have a great day!" + + +# Instruction prompts organized under mcp://prompts/instructions/ +@mcp.prompt(uri="mcp://prompts/instructions/setup") +def setup_prompt(tool: str) -> str: + """Generate setup instructions for a tool.""" + return ( + f"To set up {tool}, follow these steps:\n" + "1. Install the required dependencies\n" + "2. Configure the settings\n" + "3. Run the initialization script\n" + "4. Verify the installation" + ) + + +@mcp.prompt(uri="mcp://prompts/instructions/debug") +def debug_prompt(error: str) -> str: + """Generate debugging instructions for an error.""" + return ( + f"To debug '{error}':\n" + "1. Check the error logs\n" + "2. Verify input parameters\n" + "3. Enable verbose logging\n" + "4. Isolate the issue with minimal reproduction" + ) + + +if __name__ == "__main__": + # Example of testing the hierarchical organization + import asyncio + + from mcp.shared.memory import create_connected_server_and_client_session + + async def test_hierarchy(): + """Test the hierarchical organization.""" + async with create_connected_server_and_client_session(mcp._mcp_server) as client: + # 1. Discover tool groups and list tools in each group + print("\n=== Discovering Tool Groups ===") + result = await client.read_resource(uri=AnyUrl("mcp://groups/tools")) + tool_groups = json.loads(cast(TextResourceContents, result.contents[0]).text) + + for group in tool_groups["groups"]: + print(f"\n--- {group['name'].upper()} Tools ({group['description']}) ---") + # Use the URI paths from the group definition + group_tools = await client.list_tools( + filters=ListFilters(uri_paths=[AnyUrl(uri) for uri in group["uri_paths"]]) + ) + for tool in group_tools.tools: + print(f" - {tool.name}: {tool.description}") + + # 2. Call tools by name (still works!) + print("\n=== Calling Tools by Name ===") + result = await client.call_tool("add", {"a": 10, "b": 5}) + print(f"add(10, 5) = {cast(TextContent, result.content[0]).text}") + + result = await client.call_tool("reverse", {"text": "Hello"}) + print(f"reverse('Hello') = {cast(TextContent, result.content[0]).text}") + + # 3. Call tools by URI + print("\n=== Calling Tools by URI ===") + result = await client.call_tool("mcp://tools/math/multiply", {"a": 7, "b": 8}) + print( + f"Call mcp://tools/math/multiply with {{'a': 7, 'b': 8}} = {cast(TextContent, result.content[0]).text}" + ) + + result = await client.call_tool("mcp://tools/string/upper", {"text": "hello world"}) + print( + f"Call mcp://tools/string/upper with {{'text': 'hello world'}} = " + f"{cast(TextContent, result.content[0]).text}" + ) + + # 4. Discover prompt groups and list prompts in each group + print("\n=== Discovering Prompt Groups ===") + result = await client.read_resource(uri=AnyUrl("mcp://groups/prompts")) + prompt_groups = json.loads(cast(TextResourceContents, result.contents[0]).text) + + for group in prompt_groups["groups"]: + print(f"\n--- {group['name'].upper()} Prompts ({group['description']}) ---") + # Use the URI paths from the group definition + group_prompts = await client.list_prompts( + filters=ListFilters(uri_paths=[AnyUrl(uri) for uri in group["uri_paths"]]) + ) + for prompt in group_prompts.prompts: + print(f" - {prompt.name}: {prompt.description}") + + # 5. Use a prompt + print("\n=== Using a Prompt ===") + result = await client.get_prompt("hello_prompt", {"name": "Alice"}) + print(f"Prompt result: {cast(TextContent, result.messages[0].content).text}") + + # Run the test + asyncio.run(test_hierarchy()) diff --git a/examples/snippets/servers/lowlevel/basic.py b/examples/snippets/servers/lowlevel/basic.py index a5c4149df..bbdc1d63e 100644 --- a/examples/snippets/servers/lowlevel/basic.py +++ b/examples/snippets/servers/lowlevel/basic.py @@ -15,7 +15,7 @@ @server.list_prompts() -async def handle_list_prompts() -> list[types.Prompt]: +async def handle_list_prompts(_) -> list[types.Prompt]: """List available prompts.""" return [ types.Prompt( diff --git a/examples/snippets/servers/lowlevel/lifespan.py b/examples/snippets/servers/lowlevel/lifespan.py index 61a9fe78e..652ac4754 100644 --- a/examples/snippets/servers/lowlevel/lifespan.py +++ b/examples/snippets/servers/lowlevel/lifespan.py @@ -49,7 +49,7 @@ async def server_lifespan(_server: Server) -> AsyncIterator[dict]: @server.list_tools() -async def handle_list_tools() -> list[types.Tool]: +async def handle_list_tools(_) -> list[types.Tool]: """List available tools.""" return [ types.Tool( diff --git a/examples/snippets/servers/lowlevel/structured_output.py b/examples/snippets/servers/lowlevel/structured_output.py index 0237c9ab3..9e25cfd42 100644 --- a/examples/snippets/servers/lowlevel/structured_output.py +++ b/examples/snippets/servers/lowlevel/structured_output.py @@ -15,7 +15,7 @@ @server.list_tools() -async def list_tools() -> list[types.Tool]: +async def list_tools(_) -> list[types.Tool]: """List available tools with structured output schemas.""" return [ types.Tool( diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 1853ce7c1..6a8c6d027 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -221,25 +221,37 @@ async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResul types.EmptyResult, ) - async def list_resources(self, cursor: str | None = None) -> types.ListResourcesResult: + async def list_resources( + self, filters: types.ListFilters | None = None, cursor: str | None = None + ) -> types.ListResourcesResult: """Send a resources/list request.""" + params = None + if cursor is not None or filters is not None: + params = types.ListRequestParams(filters=filters, cursor=cursor) return await self.send_request( types.ClientRequest( types.ListResourcesRequest( method="resources/list", - params=types.PaginatedRequestParams(cursor=cursor) if cursor is not None else None, + params=params, ) ), types.ListResourcesResult, ) - async def list_resource_templates(self, cursor: str | None = None) -> types.ListResourceTemplatesResult: + async def list_resource_templates( + self, + filters: types.ListFilters | None = None, + cursor: str | None = None, + ) -> types.ListResourceTemplatesResult: """Send a resources/templates/list request.""" + params = None + if cursor is not None or filters is not None: + params = types.ListRequestParams(filters=filters, cursor=cursor) return await self.send_request( types.ClientRequest( types.ListResourceTemplatesRequest( method="resources/templates/list", - params=types.PaginatedRequestParams(cursor=cursor) if cursor is not None else None, + params=params, ) ), types.ListResourceTemplatesResult, @@ -332,13 +344,18 @@ async def _validate_tool_result(self, name: str, result: types.CallToolResult) - except SchemaError as e: raise RuntimeError(f"Invalid schema for tool {name}: {e}") - async def list_prompts(self, cursor: str | None = None) -> types.ListPromptsResult: + async def list_prompts( + self, filters: types.ListFilters | None = None, cursor: str | None = None + ) -> types.ListPromptsResult: """Send a prompts/list request.""" + params = None + if cursor is not None or filters is not None: + params = types.ListRequestParams(filters=filters, cursor=cursor) return await self.send_request( types.ClientRequest( types.ListPromptsRequest( method="prompts/list", - params=types.PaginatedRequestParams(cursor=cursor) if cursor is not None else None, + params=params, ) ), types.ListPromptsResult, @@ -381,13 +398,18 @@ async def complete( types.CompleteResult, ) - async def list_tools(self, cursor: str | None = None) -> types.ListToolsResult: + async def list_tools( + self, filters: types.ListFilters | None = None, cursor: str | None = None + ) -> types.ListToolsResult: """Send a tools/list request.""" + params = None + if cursor is not None or filters is not None: + params = types.ListRequestParams(filters=filters, cursor=cursor) result = await self.send_request( types.ClientRequest( types.ListToolsRequest( method="tools/list", - params=types.PaginatedRequestParams(cursor=cursor) if cursor is not None else None, + params=params, ) ), types.ListToolsResult, diff --git a/src/mcp/server/fastmcp/prompts/base.py b/src/mcp/server/fastmcp/prompts/base.py index b45cfc917..52d331b78 100644 --- a/src/mcp/server/fastmcp/prompts/base.py +++ b/src/mcp/server/fastmcp/prompts/base.py @@ -5,9 +5,9 @@ from typing import Any, Literal import pydantic_core -from pydantic import BaseModel, Field, TypeAdapter, validate_call +from pydantic import AnyUrl, BaseModel, Field, TypeAdapter, validate_call -from mcp.types import ContentBlock, TextContent +from mcp.types import PROMPT_SCHEME, ContentBlock, TextContent class Message(BaseModel): @@ -58,16 +58,24 @@ class Prompt(BaseModel): """A prompt template that can be rendered with parameters.""" name: str = Field(description="Name of the prompt") + uri: AnyUrl = Field(description="URI of the prompt") title: str | None = Field(None, description="Human-readable title of the prompt") description: str | None = Field(None, description="Description of what the prompt does") arguments: list[PromptArgument] | None = Field(None, description="Arguments that can be passed to the prompt") fn: Callable[..., PromptResult | Awaitable[PromptResult]] = Field(exclude=True) + def __init__(self, **data: Any) -> None: + """Initialize Prompt, generating URI from name if not provided.""" + if not data.get("uri", None): + data["uri"] = AnyUrl(f"{PROMPT_SCHEME}/{data['name']}") + super().__init__(**data) + @classmethod def from_function( cls, fn: Callable[..., PromptResult | Awaitable[PromptResult]], name: str | None = None, + uri: str | AnyUrl | None = None, title: str | None = None, description: str | None = None, ) -> "Prompt": @@ -105,6 +113,7 @@ def from_function( return cls( name=func_name, + uri=uri, title=title, description=description or fn.__doc__ or "", arguments=arguments, diff --git a/src/mcp/server/fastmcp/prompts/manager.py b/src/mcp/server/fastmcp/prompts/manager.py index 6b01d91cd..b804cd0cd 100644 --- a/src/mcp/server/fastmcp/prompts/manager.py +++ b/src/mcp/server/fastmcp/prompts/manager.py @@ -1,8 +1,11 @@ """Prompt management functionality.""" -from typing import Any +from typing import Any, overload + +from pydantic import AnyUrl from mcp.server.fastmcp.prompts.base import Message, Prompt +from mcp.server.fastmcp.uri_utils import filter_by_uri_paths, normalize_to_prompt_uri from mcp.server.fastmcp.utilities.logging import get_logger logger = get_logger(__name__) @@ -15,34 +18,77 @@ def __init__(self, warn_on_duplicate_prompts: bool = True): self._prompts: dict[str, Prompt] = {} self.warn_on_duplicate_prompts = warn_on_duplicate_prompts - def get_prompt(self, name: str) -> Prompt | None: + def _normalize_to_uri(self, name_or_uri: str) -> str: + """Convert name to URI if needed.""" + return normalize_to_prompt_uri(name_or_uri) + + @overload + def get_prompt(self, name_or_uri: str) -> Prompt | None: """Get prompt by name.""" - return self._prompts.get(name) + ... + + @overload + def get_prompt(self, name_or_uri: AnyUrl) -> Prompt | None: + """Get prompt by URI.""" + ... + + def get_prompt(self, name_or_uri: AnyUrl | str) -> Prompt | None: + """Get prompt by name or URI.""" + if isinstance(name_or_uri, AnyUrl): + return self._prompts.get(str(name_or_uri)) + + # Try as a direct URI first + if name_or_uri in self._prompts: + return self._prompts[name_or_uri] - def list_prompts(self) -> list[Prompt]: - """List all registered prompts.""" - return list(self._prompts.values()) + # Try to find a prompt by name + for prompt in self._prompts.values(): + if prompt.name == name_or_uri: + return prompt + + # Finally try normalizing to URI + uri = self._normalize_to_uri(name_or_uri) + return self._prompts.get(uri) + + def list_prompts(self, uri_paths: list[AnyUrl] | None = None) -> list[Prompt]: + """List all registered prompts, optionally filtered by URI paths.""" + prompts = list(self._prompts.values()) + if uri_paths: + prompts = filter_by_uri_paths(prompts, uri_paths) + logger.debug("Listing prompts", extra={"count": len(prompts), "uri_paths": uri_paths}) + return prompts def add_prompt( self, prompt: Prompt, ) -> Prompt: """Add a prompt to the manager.""" + logger.debug(f"Adding prompt: {prompt.name} with URI: {prompt.uri}") # Check for duplicates - existing = self._prompts.get(prompt.name) + existing = self._prompts.get(str(prompt.uri)) if existing: if self.warn_on_duplicate_prompts: - logger.warning(f"Prompt already exists: {prompt.name}") + logger.warning(f"Prompt already exists: {prompt.uri}") return existing - self._prompts[prompt.name] = prompt + self._prompts[str(prompt.uri)] = prompt return prompt - async def render_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> list[Message]: + @overload + async def render_prompt(self, name_or_uri: str, arguments: dict[str, Any] | None = None) -> list[Message]: """Render a prompt by name with arguments.""" - prompt = self.get_prompt(name) + ... + + @overload + async def render_prompt(self, name_or_uri: AnyUrl, arguments: dict[str, Any] | None = None) -> list[Message]: + """Render a prompt by URI with arguments.""" + ... + + async def render_prompt(self, name_or_uri: AnyUrl | str, arguments: dict[str, Any] | None = None) -> list[Message]: + """Render a prompt by name or URI with arguments.""" + prompt = self.get_prompt(name_or_uri) if not prompt: - raise ValueError(f"Unknown prompt: {name}") + raise ValueError(f"Unknown prompt: {name_or_uri}") return await prompt.render(arguments) diff --git a/src/mcp/server/fastmcp/prompts/prompt_manager.py b/src/mcp/server/fastmcp/prompts/prompt_manager.py index 389e89624..92b329d7c 100644 --- a/src/mcp/server/fastmcp/prompts/prompt_manager.py +++ b/src/mcp/server/fastmcp/prompts/prompt_manager.py @@ -1,6 +1,9 @@ """Prompt management functionality.""" +from pydantic import AnyUrl + from mcp.server.fastmcp.prompts.base import Prompt +from mcp.server.fastmcp.uri_utils import filter_by_uri_paths, normalize_to_prompt_uri from mcp.server.fastmcp.utilities.logging import get_logger logger = get_logger(__name__) @@ -13,21 +16,30 @@ def __init__(self, warn_on_duplicate_prompts: bool = True): self._prompts: dict[str, Prompt] = {} self.warn_on_duplicate_prompts = warn_on_duplicate_prompts + def _normalize_to_uri(self, name_or_uri: str) -> str: + """Convert name to URI if needed.""" + return normalize_to_prompt_uri(name_or_uri) + def add_prompt(self, prompt: Prompt) -> Prompt: """Add a prompt to the manager.""" - logger.debug(f"Adding prompt: {prompt.name}") - existing = self._prompts.get(prompt.name) + logger.debug(f"Adding prompt: {prompt.name} with URI: {prompt.uri}") + existing = self._prompts.get(str(prompt.uri)) if existing: if self.warn_on_duplicate_prompts: - logger.warning(f"Prompt already exists: {prompt.name}") + logger.warning(f"Prompt already exists: {prompt.uri}") return existing - self._prompts[prompt.name] = prompt + self._prompts[str(prompt.uri)] = prompt return prompt def get_prompt(self, name: str) -> Prompt | None: - """Get prompt by name.""" - return self._prompts.get(name) - - def list_prompts(self) -> list[Prompt]: - """List all registered prompts.""" - return list(self._prompts.values()) + """Get prompt by name or URI.""" + uri = self._normalize_to_uri(name) + return self._prompts.get(uri) + + def list_prompts(self, uri_paths: list[AnyUrl] | None = None) -> list[Prompt]: + """List all registered prompts, optionally filtered by URI paths.""" + prompts = list(self._prompts.values()) + if uri_paths: + prompts = filter_by_uri_paths(prompts, uri_paths) + logger.debug("Listing prompts", extra={"count": len(prompts), "uri_paths": uri_paths}) + return prompts diff --git a/src/mcp/server/fastmcp/resources/resource_manager.py b/src/mcp/server/fastmcp/resources/resource_manager.py index 35e4ec04d..aaabf2a0f 100644 --- a/src/mcp/server/fastmcp/resources/resource_manager.py +++ b/src/mcp/server/fastmcp/resources/resource_manager.py @@ -7,6 +7,7 @@ from mcp.server.fastmcp.resources.base import Resource from mcp.server.fastmcp.resources.templates import ResourceTemplate +from mcp.server.fastmcp.uri_utils import filter_by_uri_paths from mcp.server.fastmcp.utilities.logging import get_logger logger = get_logger(__name__) @@ -86,12 +87,28 @@ async def get_resource(self, uri: AnyUrl | str) -> Resource | None: raise ValueError(f"Unknown resource: {uri}") - def list_resources(self) -> list[Resource]: - """List all registered resources.""" - logger.debug("Listing resources", extra={"count": len(self._resources)}) - return list(self._resources.values()) - - def list_templates(self) -> list[ResourceTemplate]: - """List all registered templates.""" - logger.debug("Listing templates", extra={"count": len(self._templates)}) - return list(self._templates.values()) + def list_resources(self, uri_paths: list[AnyUrl] | None = None) -> list[Resource]: + """List all registered resources, optionally filtered by URI paths.""" + resources = list(self._resources.values()) + if uri_paths: + resources = filter_by_uri_paths(resources, uri_paths) + logger.debug("Listing resources", extra={"count": len(resources), "uri_paths": uri_paths}) + return resources + + def list_templates(self, uri_paths: list[AnyUrl] | None = None) -> list[ResourceTemplate]: + """List all registered templates, optionally filtered by URI paths.""" + templates = list(self._templates.values()) + if uri_paths: + filtered: list[ResourceTemplate] = [] + for template in templates: + for prefix in uri_paths: + # Ensure prefix ends with / for proper path matching + prefix_str = str(prefix) + if not prefix_str.endswith("/"): + prefix_str = prefix_str + "/" + if template.matches_prefix(prefix_str): + filtered.append(template) + break + templates = filtered + logger.debug("Listing templates", extra={"count": len(templates), "uri_paths": uri_paths}) + return templates diff --git a/src/mcp/server/fastmcp/resources/templates.py b/src/mcp/server/fastmcp/resources/templates.py index b1c7b2711..54b04dde3 100644 --- a/src/mcp/server/fastmcp/resources/templates.py +++ b/src/mcp/server/fastmcp/resources/templates.py @@ -63,6 +63,44 @@ def matches(self, uri: str) -> dict[str, Any] | None: return match.groupdict() return None + def matches_prefix(self, prefix: str) -> bool: + """Check if this template could match URIs with the given prefix.""" + + # First, simple check: does the template itself start with the prefix? + if self.uri_template.startswith(prefix): + return True + + template_segments = self.uri_template.split("/") + prefix_segments = prefix.split("/") + + # Handle trailing slash - it creates an empty last segment + has_trailing_slash = prefix.endswith("/") and prefix_segments[-1] == "" + if has_trailing_slash: + # Remove the empty segment for comparison + prefix_segments = prefix_segments[:-1] + # Template must have more segments to generate something "under" this path + if len(template_segments) <= len(prefix_segments): + return False + else: + # Without trailing slash, prefix can't have more segments than template + if len(prefix_segments) > len(template_segments): + return False + + # Compare each segment + for i, prefix_seg in enumerate(prefix_segments): + template_seg = template_segments[i] + + # If template segment is a parameter, it can match any value + if template_seg.startswith("{") and template_seg.endswith("}"): + continue + + # If both are literals, they must match exactly + if template_seg != prefix_seg: + return False + + # All prefix segments matched + return True + async def create_resource(self, uri: str, params: dict[str, Any]) -> Resource: """Create a resource from the template with the given parameters.""" try: diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 924baaa9b..f442c96d7 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -21,6 +21,7 @@ from starlette.routing import Mount, Route from starlette.types import Receive, Scope, Send +from mcp import types from mcp.server.auth.middleware.auth_context import AuthContextMiddleware from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend, RequireAuthMiddleware from mcp.server.auth.provider import OAuthAuthorizationServerProvider, ProviderTokenVerifier, TokenVerifier @@ -266,14 +267,18 @@ def _setup_handlers(self) -> None: self._mcp_server.get_prompt()(self.get_prompt) self._mcp_server.list_resource_templates()(self.list_resource_templates) - async def list_tools(self) -> list[MCPTool]: - """List all available tools.""" - tools = self._tool_manager.list_tools() + async def list_tools(self, request: types.ListToolsRequest | None = None) -> list[MCPTool]: + """List all available tools, optionally filtered by URI paths.""" + uri_paths = None + if request and request.params and request.params.filters: + uri_paths = request.params.filters.uri_paths + tools = self._tool_manager.list_tools(uri_paths=uri_paths) return [ MCPTool( name=info.name, title=info.title, description=info.description, + uri=info.uri, inputSchema=info.parameters, outputSchema=info.output_schema, annotations=info.annotations, @@ -297,10 +302,12 @@ async def call_tool(self, name: str, arguments: dict[str, Any]) -> Sequence[Cont context = self.get_context() return await self._tool_manager.call_tool(name, arguments, context=context, convert_result=True) - async def list_resources(self) -> list[MCPResource]: - """List all available resources.""" - - resources = self._resource_manager.list_resources() + async def list_resources(self, request: types.ListResourcesRequest | None = None) -> list[MCPResource]: + """List all available resources, optionally filtered by URI paths.""" + uri_paths = None + if request and request.params and request.params.filters: + uri_paths = request.params.filters.uri_paths + resources = self._resource_manager.list_resources(uri_paths=uri_paths) return [ MCPResource( uri=resource.uri, @@ -312,8 +319,14 @@ async def list_resources(self) -> list[MCPResource]: for resource in resources ] - async def list_resource_templates(self) -> list[MCPResourceTemplate]: - templates = self._resource_manager.list_templates() + async def list_resource_templates( + self, request: types.ListResourceTemplatesRequest | None = None + ) -> list[MCPResourceTemplate]: + """List all available resource templates, optionally filtered by URI paths.""" + uri_paths = None + if request and request.params and request.params.filters: + uri_paths = request.params.filters.uri_paths + templates = self._resource_manager.list_templates(uri_paths=uri_paths) return [ MCPResourceTemplate( uriTemplate=template.uri_template, @@ -342,6 +355,7 @@ def add_tool( self, fn: AnyFunction, name: str | None = None, + uri: str | AnyUrl | None = None, title: str | None = None, description: str | None = None, annotations: ToolAnnotations | None = None, @@ -355,6 +369,7 @@ def add_tool( Args: fn: The function to register as a tool name: Optional name for the tool (defaults to function name) + uri: Optional URI for the tool (defaults to {TOOL_SCHEME}/{{name}}) title: Optional human-readable title for the tool description: Optional description of what the tool does annotations: Optional ToolAnnotations providing additional tool information @@ -366,6 +381,7 @@ def add_tool( self._tool_manager.add_tool( fn, name=name, + uri=uri, title=title, description=description, annotations=annotations, @@ -375,6 +391,7 @@ def add_tool( def tool( self, name: str | None = None, + uri: str | AnyUrl | None = None, title: str | None = None, description: str | None = None, annotations: ToolAnnotations | None = None, @@ -388,6 +405,7 @@ def tool( Args: name: Optional name for the tool (defaults to function name) + uri: Optional URI for the tool (defaults to {TOOL_SCHEME}/{{name}}) title: Optional human-readable title for the tool description: Optional description of what the tool does annotations: Optional ToolAnnotations providing additional tool information @@ -421,6 +439,7 @@ def decorator(fn: AnyFunction) -> AnyFunction: self.add_tool( fn, name=name, + uri=uri, title=title, description=description, annotations=annotations, @@ -557,12 +576,17 @@ def add_prompt(self, prompt: Prompt) -> None: self._prompt_manager.add_prompt(prompt) def prompt( - self, name: str | None = None, title: str | None = None, description: str | None = None + self, + name: str | None = None, + uri: str | AnyUrl | None = None, + title: str | None = None, + description: str | None = None, ) -> Callable[[AnyFunction], AnyFunction]: """Decorator to register a prompt. Args: name: Optional name for the prompt (defaults to function name) + uri: Optional URI for the prompt (defaults to {PROMPT_SCHEME}/{{name}}) title: Optional human-readable title for the prompt description: Optional description of what the prompt does @@ -601,7 +625,7 @@ async def analyze_file(path: str) -> list[Message]: ) def decorator(func: AnyFunction) -> AnyFunction: - prompt = Prompt.from_function(func, name=name, title=title, description=description) + prompt = Prompt.from_function(func, name=name, uri=uri, title=title, description=description) self.add_prompt(prompt) return func @@ -956,14 +980,18 @@ def streamable_http_app(self) -> Starlette: lifespan=lambda app: self.session_manager.run(), ) - async def list_prompts(self) -> list[MCPPrompt]: - """List all available prompts.""" - prompts = self._prompt_manager.list_prompts() + async def list_prompts(self, request: types.ListPromptsRequest | None = None) -> list[MCPPrompt]: + """List all available prompts, optionally filtered by URI paths.""" + uri_paths = None + if request and request.params and request.params.filters: + uri_paths = request.params.filters.uri_paths + prompts = self._prompt_manager.list_prompts(uri_paths=uri_paths) return [ MCPPrompt( name=prompt.name, title=prompt.title, description=prompt.description, + uri=prompt.uri, arguments=[ MCPPromptArgument( name=arg.name, diff --git a/src/mcp/server/fastmcp/tools/base.py b/src/mcp/server/fastmcp/tools/base.py index f50126081..79d539d2b 100644 --- a/src/mcp/server/fastmcp/tools/base.py +++ b/src/mcp/server/fastmcp/tools/base.py @@ -6,11 +6,11 @@ from functools import cached_property from typing import TYPE_CHECKING, Any, get_origin -from pydantic import BaseModel, Field +from pydantic import AnyUrl, BaseModel, Field from mcp.server.fastmcp.exceptions import ToolError from mcp.server.fastmcp.utilities.func_metadata import FuncMetadata, func_metadata -from mcp.types import ToolAnnotations +from mcp.types import TOOL_SCHEME, ToolAnnotations if TYPE_CHECKING: from mcp.server.fastmcp.server import Context @@ -21,10 +21,11 @@ class Tool(BaseModel): """Internal tool registration info.""" - fn: Callable[..., Any] = Field(exclude=True) name: str = Field(description="Name of the tool") + uri: AnyUrl = Field(description="URI of the tool") title: str | None = Field(None, description="Human-readable title of the tool") description: str = Field(description="Description of what the tool does") + fn: Callable[..., Any] = Field(exclude=True) parameters: dict[str, Any] = Field(description="JSON schema for tool parameters") fn_metadata: FuncMetadata = Field( description="Metadata about the function including a pydantic model for tool arguments" @@ -33,6 +34,12 @@ class Tool(BaseModel): context_kwarg: str | None = Field(None, description="Name of the kwarg that should receive context") annotations: ToolAnnotations | None = Field(None, description="Optional annotations for the tool") + def __init__(self, **data: Any) -> None: + """Initialize Tool, generating URI from name if not provided.""" + if not data.get("uri", None): + data["uri"] = AnyUrl(f"{TOOL_SCHEME}/{data['name']}") + super().__init__(**data) + @cached_property def output_schema(self) -> dict[str, Any] | None: return self.fn_metadata.output_schema @@ -42,6 +49,7 @@ def from_function( cls, fn: Callable[..., Any], name: str | None = None, + uri: str | AnyUrl | None = None, title: str | None = None, description: str | None = None, context_kwarg: str | None = None, @@ -78,6 +86,7 @@ def from_function( return cls( fn=fn, name=func_name, + uri=uri, title=title, description=func_doc, parameters=parameters, diff --git a/src/mcp/server/fastmcp/tools/tool_manager.py b/src/mcp/server/fastmcp/tools/tool_manager.py index bfa8b2382..8efa5acd8 100644 --- a/src/mcp/server/fastmcp/tools/tool_manager.py +++ b/src/mcp/server/fastmcp/tools/tool_manager.py @@ -1,10 +1,13 @@ from __future__ import annotations as _annotations from collections.abc import Callable -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, overload + +from pydantic import AnyUrl from mcp.server.fastmcp.exceptions import ToolError from mcp.server.fastmcp.tools.base import Tool +from mcp.server.fastmcp.uri_utils import filter_by_uri_paths, normalize_to_tool_uri from mcp.server.fastmcp.utilities.logging import get_logger from mcp.shared.context import LifespanContextT, RequestT from mcp.types import ToolAnnotations @@ -28,24 +31,57 @@ def __init__( self._tools: dict[str, Tool] = {} if tools is not None: for tool in tools: - if warn_on_duplicate_tools and tool.name in self._tools: - logger.warning(f"Tool already exists: {tool.name}") - self._tools[tool.name] = tool + if warn_on_duplicate_tools and str(tool.uri) in self._tools: + logger.warning(f"Tool already exists: {tool.uri}") + self._tools[str(tool.uri)] = tool self.warn_on_duplicate_tools = warn_on_duplicate_tools - def get_tool(self, name: str) -> Tool | None: + def _normalize_to_uri(self, name_or_uri: str) -> str: + """Convert name to URI if needed.""" + return normalize_to_tool_uri(name_or_uri) + + @overload + def get_tool(self, name_or_uri: str) -> Tool | None: """Get tool by name.""" - return self._tools.get(name) + ... + + @overload + def get_tool(self, name_or_uri: AnyUrl) -> Tool | None: + """Get tool by URI.""" + ... + + def get_tool(self, name_or_uri: AnyUrl | str) -> Tool | None: + """Get tool by name or URI.""" + if isinstance(name_or_uri, AnyUrl): + return self._tools.get(str(name_or_uri)) + + # Try as a direct URI first + if name_or_uri in self._tools: + return self._tools[name_or_uri] + + # Try to find a tool by name + for tool in self._tools.values(): + if tool.name == name_or_uri: + return tool - def list_tools(self) -> list[Tool]: - """List all registered tools.""" - return list(self._tools.values()) + # Finally try normalizing to URI + uri = self._normalize_to_uri(name_or_uri) + return self._tools.get(uri) + + def list_tools(self, uri_paths: list[AnyUrl] | None = None) -> list[Tool]: + """List all registered tools, optionally filtered by URI paths.""" + tools = list(self._tools.values()) + if uri_paths: + tools = filter_by_uri_paths(tools, uri_paths) + logger.debug("Listing tools", extra={"count": len(tools), "uri_paths": uri_paths}) + return tools def add_tool( self, fn: Callable[..., Any], name: str | None = None, + uri: str | AnyUrl | None = None, title: str | None = None, description: str | None = None, annotations: ToolAnnotations | None = None, @@ -55,29 +91,52 @@ def add_tool( tool = Tool.from_function( fn, name=name, + uri=uri, title=title, description=description, annotations=annotations, structured_output=structured_output, ) - existing = self._tools.get(tool.name) + existing = self._tools.get(str(tool.uri)) if existing: if self.warn_on_duplicate_tools: - logger.warning(f"Tool already exists: {tool.name}") + logger.warning(f"Tool already exists: {tool.uri}") return existing - self._tools[tool.name] = tool + self._tools[str(tool.uri)] = tool return tool + @overload async def call_tool( self, - name: str, + name_or_uri: str, arguments: dict[str, Any], context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None, convert_result: bool = False, ) -> Any: """Call a tool by name with arguments.""" - tool = self.get_tool(name) + ... + + @overload + async def call_tool( + self, + name_or_uri: AnyUrl, + arguments: dict[str, Any], + context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None, + convert_result: bool = False, + ) -> Any: + """Call a tool by URI with arguments.""" + ... + + async def call_tool( + self, + name_or_uri: AnyUrl | str, + arguments: dict[str, Any], + context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None, + convert_result: bool = False, + ) -> Any: + """Call a tool by name or URI with arguments.""" + tool = self.get_tool(name_or_uri) if not tool: - raise ToolError(f"Unknown tool: {name}") + raise ToolError(f"Unknown tool: {name_or_uri}") return await tool.run(arguments, context=context, convert_result=convert_result) diff --git a/src/mcp/server/fastmcp/uri_utils.py b/src/mcp/server/fastmcp/uri_utils.py new file mode 100644 index 000000000..ab6931323 --- /dev/null +++ b/src/mcp/server/fastmcp/uri_utils.py @@ -0,0 +1,72 @@ +"""Common URI utilities for FastMCP.""" + +from collections.abc import Sequence +from typing import Protocol, TypeVar, runtime_checkable + +from pydantic import AnyUrl + +from mcp.types import PROMPT_SCHEME, TOOL_SCHEME + +T = TypeVar("T", bound="HasUri") + + +def normalize_to_uri(name_or_uri: str, scheme: str) -> str: + """Convert name to URI if needed. + + Args: + name_or_uri: Either a name or a full URI + scheme: The URI scheme to use (e.g., TOOL_SCHEME or PROMPT_SCHEME) + + Returns: + A properly formatted URI + """ + if name_or_uri.startswith(scheme): + return name_or_uri + return f"{scheme}/{name_or_uri}" + + +def normalize_to_tool_uri(name_or_uri: str) -> str: + """Convert name to tool URI if needed.""" + return normalize_to_uri(name_or_uri, TOOL_SCHEME) + + +def normalize_to_prompt_uri(name_or_uri: str) -> str: + """Convert name to prompt URI if needed.""" + return normalize_to_uri(name_or_uri, PROMPT_SCHEME) + + +@runtime_checkable +class HasUri(Protocol): + """Protocol for objects that have a URI attribute.""" + + uri: AnyUrl + + +def filter_by_uri_paths(items: Sequence[T], uri_paths: Sequence[AnyUrl]) -> list[T]: + """Filter items by multiple URI path prefixes. + + Args: + items: List of items that have a 'uri' attribute + uri_paths: List of URI path prefixes to filter by. + + Returns: + Filtered list of items matching any of the provided prefixes + """ + + # Filter items where the URI matches any of the prefixes + filtered: list[T] = [] + for item in items: + uri = str(item.uri) + for prefix in uri_paths: + prefix_str = str(prefix) + if uri.startswith(prefix_str): + # If prefix ends with a separator, we already have a proper boundary + if prefix_str.endswith(("/", "?", "#")): + filtered.append(item) + break + # Otherwise check if it's an exact match or if the next character is a separator + elif len(uri) == len(prefix_str) or uri[len(prefix_str)] in ("/", "?", "#"): + filtered.append(item) + break + + return filtered diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index ab6a3d15c..e6c6d9ca7 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -231,11 +231,11 @@ def request_context( return request_ctx.get() def list_prompts(self): - def decorator(func: Callable[[], Awaitable[list[types.Prompt]]]): + def decorator(func: Callable[[types.ListPromptsRequest], Awaitable[list[types.Prompt]]]): logger.debug("Registering handler for PromptListRequest") - async def handler(_: Any): - prompts = await func() + async def handler(request: types.ListPromptsRequest): + prompts = await func(request) return types.ServerResult(types.ListPromptsResult(prompts=prompts)) self.request_handlers[types.ListPromptsRequest] = handler @@ -259,11 +259,11 @@ async def handler(req: types.GetPromptRequest): return decorator def list_resources(self): - def decorator(func: Callable[[], Awaitable[list[types.Resource]]]): + def decorator(func: Callable[[types.ListResourcesRequest], Awaitable[list[types.Resource]]]): logger.debug("Registering handler for ListResourcesRequest") - async def handler(_: Any): - resources = await func() + async def handler(request: types.ListResourcesRequest): + resources = await func(request) return types.ServerResult(types.ListResourcesResult(resources=resources)) self.request_handlers[types.ListResourcesRequest] = handler @@ -272,11 +272,11 @@ async def handler(_: Any): return decorator def list_resource_templates(self): - def decorator(func: Callable[[], Awaitable[list[types.ResourceTemplate]]]): + def decorator(func: Callable[[types.ListResourceTemplatesRequest], Awaitable[list[types.ResourceTemplate]]]): logger.debug("Registering handler for ListResourceTemplatesRequest") - async def handler(_: Any): - templates = await func() + async def handler(request: types.ListResourceTemplatesRequest): + templates = await func(request) return types.ServerResult(types.ListResourceTemplatesResult(resourceTemplates=templates)) self.request_handlers[types.ListResourceTemplatesRequest] = handler @@ -382,11 +382,11 @@ async def handler(req: types.UnsubscribeRequest): return decorator def list_tools(self): - def decorator(func: Callable[[], Awaitable[list[types.Tool]]]): + def decorator(func: Callable[[types.ListToolsRequest], Awaitable[list[types.Tool]]]): logger.debug("Registering handler for ListToolsRequest") - async def handler(_: Any): - tools = await func() + async def handler(request: types.ListToolsRequest): + tools = await func(request) # Refresh the tool cache self._tool_cache.clear() for tool in tools: @@ -415,7 +415,7 @@ async def _get_cached_tool_definition(self, tool_name: str) -> types.Tool | None if tool_name not in self._tool_cache: if types.ListToolsRequest in self.request_handlers: logger.debug("Tool cache miss for %s, refreshing cache", tool_name) - await self.request_handlers[types.ListToolsRequest](None) + await self.request_handlers[types.ListToolsRequest](types.ListToolsRequest(method="tools/list")) tool = self._tool_cache.get(tool_name) if tool is None: diff --git a/src/mcp/types.py b/src/mcp/types.py index 98fefa080..d63d22244 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -1,10 +1,24 @@ from collections.abc import Callable from typing import Annotated, Any, Generic, Literal, TypeAlias, TypeVar -from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel -from pydantic.networks import AnyUrl, UrlConstraints +from pydantic import AnyUrl, BaseModel, ConfigDict, Field, FileUrl, RootModel, model_validator +from pydantic.networks import UrlConstraints from typing_extensions import deprecated +# URI scheme constants +MCP_SCHEME = "mcp" +TOOL_SCHEME = f"{MCP_SCHEME}://tools" +PROMPT_SCHEME = f"{MCP_SCHEME}://prompts" +RESOURCE_SCHEME = f"{MCP_SCHEME}://resources" +GROUP_SCHEME = f"{MCP_SCHEME}://groups" + +# Well-known MCP group URIs +WELL_KNOWN_GROUP_URIS = { + f"{GROUP_SCHEME}/tools", + f"{GROUP_SCHEME}/prompts", + f"{GROUP_SCHEME}/resources", +} + """ Model Context Protocol bindings for Python @@ -55,7 +69,17 @@ class Meta(BaseModel): meta: Meta | None = Field(alias="_meta", default=None) -class PaginatedRequestParams(RequestParams): +class ListFilters(BaseModel): + """Filters for list operations.""" + + uri_paths: list[AnyUrl] | None = None + """Optional list of absolute URI path prefixes to filter results.""" + + +class ListRequestParams(RequestParams): + filters: ListFilters | None = None + """Optional filters to apply to the list results.""" + cursor: Cursor | None = None """ An opaque token representing the current pagination position. @@ -87,11 +111,11 @@ class Request(BaseModel, Generic[RequestParamsT, MethodT]): model_config = ConfigDict(extra="allow") -class PaginatedRequest(Request[PaginatedRequestParams | None, MethodT], Generic[MethodT]): - """Base class for paginated requests, - matching the schema's PaginatedRequest interface.""" +class ListRequest(Request[ListRequestParams | None, MethodT], Generic[MethodT]): + """Base class for list requests, + matching the schema's ListRequest interface.""" - params: PaginatedRequestParams | None = None + params: ListRequestParams | None = None class Notification(BaseModel, Generic[NotificationParamsT, MethodT]): @@ -113,7 +137,7 @@ class Result(BaseModel): model_config = ConfigDict(extra="allow") -class PaginatedResult(Result): +class ListResult(Result): nextCursor: Cursor | None = None """ An opaque token representing the pagination position after the last returned result. @@ -394,7 +418,7 @@ class ProgressNotification(Notification[ProgressNotificationParams, Literal["not params: ProgressNotificationParams -class ListResourcesRequest(PaginatedRequest[Literal["resources/list"]]): +class ListResourcesRequest(ListRequest[Literal["resources/list"]]): """Sent from the client to request a list of resources the server has.""" method: Literal["resources/list"] @@ -430,6 +454,14 @@ class Resource(BaseMetadata): """ model_config = ConfigDict(extra="allow") + @model_validator(mode="after") + def validate_uri_scheme(self) -> "Resource": + """Ensure resource URI doesn't use reserved MCP scheme, except for well-known group URIs.""" + uri_str = str(self.uri) + if uri_str.startswith(f"{MCP_SCHEME}://") and uri_str not in WELL_KNOWN_GROUP_URIS: + raise ValueError(f"Resource URI cannot use reserved MCP scheme '{MCP_SCHEME}://', got: {self.uri}") + return self + class ResourceTemplate(BaseMetadata): """A template description for resources available on the server.""" @@ -455,19 +487,19 @@ class ResourceTemplate(BaseMetadata): model_config = ConfigDict(extra="allow") -class ListResourcesResult(PaginatedResult): +class ListResourcesResult(ListResult): """The server's response to a resources/list request from the client.""" resources: list[Resource] -class ListResourceTemplatesRequest(PaginatedRequest[Literal["resources/templates/list"]]): +class ListResourceTemplatesRequest(ListRequest[Literal["resources/templates/list"]]): """Sent from the client to request a list of resource templates the server has.""" method: Literal["resources/templates/list"] -class ListResourceTemplatesResult(PaginatedResult): +class ListResourceTemplatesResult(ListResult): """The server's response to a resources/templates/list request from the client.""" resourceTemplates: list[ResourceTemplate] @@ -603,7 +635,7 @@ class ResourceUpdatedNotification( params: ResourceUpdatedNotificationParams -class ListPromptsRequest(PaginatedRequest[Literal["prompts/list"]]): +class ListPromptsRequest(ListRequest[Literal["prompts/list"]]): """Sent from the client to request a list of prompts and prompt templates.""" method: Literal["prompts/list"] @@ -624,6 +656,8 @@ class PromptArgument(BaseModel): class Prompt(BaseMetadata): """A prompt or prompt template that the server offers.""" + uri: Annotated[AnyUrl, UrlConstraints(allowed_schemes=[MCP_SCHEME], host_required=False)] + """URI for the prompt. Auto-generated if not provided.""" description: str | None = None """An optional description of what this prompt provides.""" arguments: list[PromptArgument] | None = None @@ -635,8 +669,22 @@ class Prompt(BaseMetadata): """ model_config = ConfigDict(extra="allow") + def __init__(self, **data: Any) -> None: + """Initialize prompt with auto-generated URI if not provided.""" + if "uri" not in data: + data["uri"] = AnyUrl(f"{PROMPT_SCHEME}/{data['name']}") + super().__init__(**data) + + @model_validator(mode="after") + def validate_prompt_uri(self) -> "Prompt": + """Validate that prompt URI starts with the correct prefix.""" + uri_str = str(self.uri) + if not uri_str.startswith(f"{PROMPT_SCHEME}/"): + raise ValueError(f"Prompt URI must start with {PROMPT_SCHEME}/") + return self + -class ListPromptsResult(PaginatedResult): +class ListPromptsResult(ListResult): """The server's response to a prompts/list request from the client.""" prompts: list[Prompt] @@ -786,7 +834,7 @@ class PromptListChangedNotification( params: NotificationParams | None = None -class ListToolsRequest(PaginatedRequest[Literal["tools/list"]]): +class ListToolsRequest(ListRequest[Literal["tools/list"]]): """Sent from the client to request a list of tools the server has.""" method: Literal["tools/list"] @@ -843,6 +891,8 @@ class ToolAnnotations(BaseModel): class Tool(BaseMetadata): """Definition for a tool the client can call.""" + uri: Annotated[AnyUrl, UrlConstraints(allowed_schemes=[MCP_SCHEME], host_required=False)] + """URI for the tool. Auto-generated if not provided.""" description: str | None = None """A human-readable description of the tool.""" inputSchema: dict[str, Any] @@ -861,8 +911,22 @@ class Tool(BaseMetadata): """ model_config = ConfigDict(extra="allow") + def __init__(self, **data: Any) -> None: + """Initialize tool with auto-generated URI if not provided.""" + if "uri" not in data: + data["uri"] = AnyUrl(f"{TOOL_SCHEME}/{data['name']}") + super().__init__(**data) + + @model_validator(mode="after") + def validate_tool_uri(self) -> "Tool": + """Validate that tool URI starts with the correct prefix.""" + uri_str = str(self.uri) + if not uri_str.startswith(f"{TOOL_SCHEME}/"): + raise ValueError(f"Tool URI must start with {TOOL_SCHEME}/") + return self + -class ListToolsResult(PaginatedResult): +class ListToolsResult(ListResult): """The server's response to a tools/list request from the client.""" tools: list[Tool] diff --git a/tests/client/test_output_schema_validation.py b/tests/client/test_output_schema_validation.py index 242515b96..6a3a4d98b 100644 --- a/tests/client/test_output_schema_validation.py +++ b/tests/client/test_output_schema_validation.py @@ -42,7 +42,7 @@ async def test_tool_structured_output_client_side_validation_basemodel(self): } @server.list_tools() - async def list_tools(): + async def list_tools(_): return [ Tool( name="get_user", @@ -81,7 +81,7 @@ async def test_tool_structured_output_client_side_validation_primitive(self): } @server.list_tools() - async def list_tools(): + async def list_tools(_): return [ Tool( name="calculate", @@ -112,7 +112,7 @@ async def test_tool_structured_output_client_side_validation_dict_typed(self): output_schema = {"type": "object", "additionalProperties": {"type": "integer"}, "title": "get_scores_Output"} @server.list_tools() - async def list_tools(): + async def list_tools(_): return [ Tool( name="get_scores", @@ -147,7 +147,7 @@ async def test_tool_structured_output_client_side_validation_missing_required(se } @server.list_tools() - async def list_tools(): + async def list_tools(_): return [ Tool( name="get_person", @@ -175,7 +175,7 @@ async def test_tool_not_listed_warning(self, caplog): server = Server("test-server") @server.list_tools() - async def list_tools(): + async def list_tools(_): # Return empty list - tool is not listed return [] diff --git a/tests/issues/test_152_resource_mime_type.py b/tests/issues/test_152_resource_mime_type.py index a99e5a5c7..8960c1376 100644 --- a/tests/issues/test_152_resource_mime_type.py +++ b/tests/issues/test_152_resource_mime_type.py @@ -79,7 +79,7 @@ async def test_lowlevel_resource_mime_type(): ] @server.list_resources() - async def handle_list_resources(): + async def handle_list_resources(_): return test_resources @server.read_resource() diff --git a/tests/issues/test_88_random_error.py b/tests/issues/test_88_random_error.py index c3570a39c..8457aa9c1 100644 --- a/tests/issues/test_88_random_error.py +++ b/tests/issues/test_88_random_error.py @@ -31,7 +31,7 @@ async def test_notification_validation_error(tmp_path: Path): slow_request_lock = anyio.Event() @server.list_tools() - async def list_tools() -> list[types.Tool]: + async def list_tools(_) -> list[types.Tool]: return [ types.Tool( name="slow", diff --git a/tests/server/fastmcp/prompts/test_manager.py b/tests/server/fastmcp/prompts/test_manager.py index 82b234638..9c55afaf7 100644 --- a/tests/server/fastmcp/prompts/test_manager.py +++ b/tests/server/fastmcp/prompts/test_manager.py @@ -1,7 +1,9 @@ import pytest +from pydantic import AnyUrl from mcp.server.fastmcp.prompts.base import Prompt, TextContent, UserMessage from mcp.server.fastmcp.prompts.manager import PromptManager +from mcp.types import PROMPT_SCHEME class TestPromptManager: @@ -61,6 +63,87 @@ def fn2() -> str: assert len(prompts) == 2 assert prompts == [prompt1, prompt2] + def test_list_prompts_with_prefix(self): + """Test listing prompts with prefix filtering.""" + + def greeting_hello() -> str: + return "Hello!" + + def greeting_goodbye() -> str: + return "Goodbye!" + + def question_name() -> str: + return "What's your name?" + + def question_age() -> str: + return "How old are you?" + + manager = PromptManager() + + # Create prompts with custom URIs + hello_prompt = Prompt.from_function(greeting_hello) + hello_prompt.uri = AnyUrl(f"{PROMPT_SCHEME}/greeting/hello") + + goodbye_prompt = Prompt.from_function(greeting_goodbye) + goodbye_prompt.uri = AnyUrl(f"{PROMPT_SCHEME}/greeting/goodbye") + + name_prompt = Prompt.from_function(question_name) + name_prompt.uri = AnyUrl(f"{PROMPT_SCHEME}/question/name") + + age_prompt = Prompt.from_function(question_age) + age_prompt.uri = AnyUrl(f"{PROMPT_SCHEME}/question/age") + + # Add prompts directly to manager's internal storage + manager._prompts = { + str(hello_prompt.uri): hello_prompt, + str(goodbye_prompt.uri): goodbye_prompt, + str(name_prompt.uri): name_prompt, + str(age_prompt.uri): age_prompt, + } + + # Test listing all prompts + all_prompts = manager.list_prompts() + assert len(all_prompts) == 4 + + # Test uri_paths filtering - greeting prompts + greeting_prompts = manager.list_prompts(uri_paths=[AnyUrl(f"{PROMPT_SCHEME}/greeting/")]) + assert len(greeting_prompts) == 2 + assert all(str(p.uri).startswith(f"{PROMPT_SCHEME}/greeting/") for p in greeting_prompts) + assert hello_prompt in greeting_prompts + assert goodbye_prompt in greeting_prompts + + # Test uri_paths filtering - question prompts + question_prompts = manager.list_prompts(uri_paths=[AnyUrl(f"{PROMPT_SCHEME}/question/")]) + assert len(question_prompts) == 2 + assert all(str(p.uri).startswith(f"{PROMPT_SCHEME}/question/") for p in question_prompts) + assert name_prompt in question_prompts + assert age_prompt in question_prompts + + # Test exact URI match + hello_prompts = manager.list_prompts(uri_paths=[AnyUrl(f"{PROMPT_SCHEME}/greeting/hello")]) + assert len(hello_prompts) == 1 + assert hello_prompts[0] == hello_prompt + + # Test partial prefix doesn't match + no_partial = manager.list_prompts(uri_paths=[AnyUrl(f"{PROMPT_SCHEME}/greeting/h")]) + assert len(no_partial) == 0 # Won't match because next char is 'e' not a separator + + # Test no matches + no_matches = manager.list_prompts(uri_paths=[AnyUrl(f"{PROMPT_SCHEME}/nonexistent")]) + assert len(no_matches) == 0 + + # Test with trailing slash + greeting_prompts_slash = manager.list_prompts(uri_paths=[AnyUrl(f"{PROMPT_SCHEME}/greeting/")]) + assert len(greeting_prompts_slash) == 2 + assert greeting_prompts_slash == greeting_prompts + + # Test multiple uri_paths + greeting_and_question = manager.list_prompts( + uri_paths=[AnyUrl(f"{PROMPT_SCHEME}/greeting/"), AnyUrl(f"{PROMPT_SCHEME}/question/")] + ) + assert len(greeting_and_question) == 4 + assert all(p in greeting_and_question for p in all_prompts) + @pytest.mark.anyio async def test_render_prompt(self): """Test rendering a prompt.""" @@ -106,3 +189,186 @@ def fn(name: str) -> str: manager.add_prompt(prompt) with pytest.raises(ValueError, match="Missing required arguments"): await manager.render_prompt("fn") + + def test_get_prompt_by_uri(self): + """Test getting prompts by their URI.""" + + def greeting() -> str: + return "Hello!" + + def custom_prompt() -> str: + return "Custom message" + + manager = PromptManager() + + # Add prompt with default URI + manager.add_prompt(Prompt.from_function(greeting)) + + # Add prompt with custom URI + custom = Prompt.from_function(custom_prompt, uri=f"{PROMPT_SCHEME}/custom/messages/welcome") + manager.add_prompt(custom) + + # Get by name + prompt = manager.get_prompt("greeting") + assert prompt is not None + assert prompt.name == "greeting" + + # Get by default URI + prompt_by_uri = manager.get_prompt(f"{PROMPT_SCHEME}/greeting") + assert prompt_by_uri is not None + assert prompt_by_uri.name == "greeting" + assert prompt_by_uri == prompt + + # Get by custom URI + custom_by_uri = manager.get_prompt(f"{PROMPT_SCHEME}/custom/messages/welcome") + assert custom_by_uri is not None + assert custom_by_uri == custom + + # Custom URI prompt should also work with name + custom_by_name = manager.get_prompt("custom_prompt") + assert custom_by_name is not None + assert custom_by_name == custom + + @pytest.mark.anyio + async def test_render_prompt_by_uri(self): + """Test rendering prompts by their URI.""" + + def welcome(name: str) -> str: + return f"Welcome, {name}!" + + def farewell(name: str) -> str: + return f"Goodbye, {name}!" + + manager = PromptManager() + + # Add prompt with default URI + manager.add_prompt(Prompt.from_function(welcome)) + + # Add prompt with custom URI + farewell_prompt = Prompt.from_function(farewell, uri=f"{PROMPT_SCHEME}/custom/farewell") + manager.add_prompt(farewell_prompt) + + # Render by default URI + messages = await manager.render_prompt(f"{PROMPT_SCHEME}/welcome", arguments={"name": "Alice"}) + assert messages == [UserMessage(content=TextContent(type="text", text="Welcome, Alice!"))] + + # Render by custom URI + messages = await manager.render_prompt(f"{PROMPT_SCHEME}/custom/farewell", arguments={"name": "Bob"}) + assert messages == [UserMessage(content=TextContent(type="text", text="Goodbye, Bob!"))] + + # Should still work with name + messages = await manager.render_prompt("welcome", arguments={"name": "Charlie"}) + assert messages == [UserMessage(content=TextContent(type="text", text="Welcome, Charlie!"))] + + # Custom URI prompt should also work with name + messages = await manager.render_prompt("farewell", arguments={"name": "David"}) + assert messages == [UserMessage(content=TextContent(type="text", text="Goodbye, David!"))] + + def test_add_prompt_with_custom_uri(self): + """Test adding prompts with custom URI parameter.""" + + def greeting(name: str) -> str: + return f"Hello, {name}!" + + def question(topic: str) -> str: + return f"What do you think about {topic}?" + + manager = PromptManager() + + # Add prompt with custom hierarchical URI + prompt1 = Prompt.from_function(greeting, uri="mcp://prompts/greetings/hello") + added1 = manager.add_prompt(prompt1) + assert added1.name == "greeting" + assert str(added1.uri) == "mcp://prompts/greetings/hello" + + # Add prompt with AnyUrl + prompt2 = Prompt.from_function(question, uri=AnyUrl("mcp://prompts/questions/general")) + added2 = manager.add_prompt(prompt2) + assert added2.name == "question" + assert str(added2.uri) == "mcp://prompts/questions/general" + + # Verify prompts are stored by URI + assert str(prompt1.uri) in manager._prompts + assert str(prompt2.uri) in manager._prompts + + def test_get_prompt_by_name_with_custom_uri(self): + """Test getting prompts by name when they have custom URIs.""" + + def welcome(name: str) -> str: + return f"Welcome, {name}!" + + def goodbye(name: str) -> str: + return f"Goodbye, {name}!" + + manager = PromptManager() + + # Add prompts with custom URIs + welcome_prompt = Prompt.from_function(welcome, uri="mcp://prompts/greetings/welcome") + goodbye_prompt = Prompt.from_function(goodbye, uri="mcp://prompts/greetings/goodbye") + + manager.add_prompt(welcome_prompt) + manager.add_prompt(goodbye_prompt) + + # Should be able to get by name + prompt_by_name = manager.get_prompt("welcome") + assert prompt_by_name is not None + assert prompt_by_name == welcome_prompt + + # Should also work for the second prompt + prompt_by_name2 = manager.get_prompt("goodbye") + assert prompt_by_name2 is not None + assert prompt_by_name2 == goodbye_prompt + + # Should also be able to get by URI + prompt_by_uri = manager.get_prompt("mcp://prompts/greetings/welcome") + assert prompt_by_uri == welcome_prompt + + # Get by AnyUrl + prompt_by_anyurl = manager.get_prompt(AnyUrl("mcp://prompts/greetings/goodbye")) + assert prompt_by_anyurl == goodbye_prompt + + @pytest.mark.anyio + async def test_prompt_name_lookup_with_hierarchical_uri(self): + """Test name lookup works correctly with hierarchical URIs.""" + + def hello(name: str) -> str: + return f"Hello, {name}!" + + def askname() -> str: + return "What is your name?" + + def confirm(action: str) -> str: + return f"Are you sure you want to {action}?" + + manager = PromptManager() + + # Add prompts with hierarchical URIs + hello_prompt = Prompt.from_function(hello, uri="mcp://prompts/greetings/hello") + askname_prompt = Prompt.from_function(askname, uri="mcp://prompts/questions/askname") + confirm_prompt = Prompt.from_function(confirm, uri="mcp://prompts/confirmations/confirm") + + manager.add_prompt(hello_prompt) + manager.add_prompt(askname_prompt) + manager.add_prompt(confirm_prompt) + + # Test rendering by name + messages = await manager.render_prompt("hello", {"name": "Alice"}) + assert messages == [UserMessage(content=TextContent(type="text", text="Hello, Alice!"))] + + messages = await manager.render_prompt("askname") + assert messages == [UserMessage(content=TextContent(type="text", text="What is your name?"))] + + messages = await manager.render_prompt("confirm", {"action": "delete"}) + assert messages == [UserMessage(content=TextContent(type="text", text="Are you sure you want to delete?"))] + + # Test rendering by full URI + messages = await manager.render_prompt("mcp://prompts/greetings/hello", {"name": "Bob"}) + assert messages == [UserMessage(content=TextContent(type="text", text="Hello, Bob!"))] + + messages = await manager.render_prompt(AnyUrl("mcp://prompts/questions/askname")) + assert messages == [UserMessage(content=TextContent(type="text", text="What is your name?"))] + + # Verify that the standard normalization doesn't work + # (since prompts are at custom URIs, not standard ones) + prompt = manager.get_prompt(f"{PROMPT_SCHEME}/hello") + assert prompt is None # Should not find it at the standard URI diff --git a/tests/server/fastmcp/resources/test_resource_manager.py b/tests/server/fastmcp/resources/test_resource_manager.py index 4423e5315..c566d6196 100644 --- a/tests/server/fastmcp/resources/test_resource_manager.py +++ b/tests/server/fastmcp/resources/test_resource_manager.py @@ -139,3 +139,108 @@ def test_list_resources(self, temp_file: Path): resources = manager.list_resources() assert len(resources) == 2 assert resources == [resource1, resource2] + + def test_list_resources_with_prefix(self, temp_file: Path): + """Test listing resources with prefix filtering.""" + manager = ResourceManager() + + # Add resources with different URIs + resource1 = FileResource( + uri=FileUrl("file:///data/images/test.jpg"), + name="test_image", + path=temp_file, + ) + resource2 = FileResource( + uri=FileUrl("file:///data/docs/test.txt"), + name="test_doc", + path=temp_file, + ) + resource3 = FileResource( + uri=FileUrl("file:///other/test.txt"), + name="other_test", + path=temp_file, + ) + + manager.add_resource(resource1) + manager.add_resource(resource2) + manager.add_resource(resource3) + + # Test uri_paths filtering + data_resources = manager.list_resources(uri_paths=[AnyUrl("file:///data/")]) + assert len(data_resources) == 2 + assert resource1 in data_resources + assert resource2 in data_resources + + # More specific prefix + image_resources = manager.list_resources(uri_paths=[AnyUrl("file:///data/images/")]) + assert len(image_resources) == 1 + assert resource1 in image_resources + + # No matches + no_matches = manager.list_resources(uri_paths=[AnyUrl("file:///nonexistent/")]) + assert len(no_matches) == 0 + + # Multiple uri_paths + multi_resources = manager.list_resources(uri_paths=[AnyUrl("file:///data/"), AnyUrl("file:///other/")]) + assert len(multi_resources) == 3 + assert all(r in multi_resources for r in [resource1, resource2, resource3]) + + def test_list_templates_with_prefix(self): + """Test listing templates with prefix filtering.""" + manager = ResourceManager() + + # Add templates with different URI patterns + def user_func(user_id: str) -> str: + return f"User {user_id}" + + def post_func(user_id: str, post_id: str) -> str: + return f"User {user_id} Post {post_id}" + + def product_func(product_id: str) -> str: + return f"Product {product_id}" + + template1 = manager.add_template(user_func, uri_template="http://api.com/users/{user_id}", name="user_template") + template2 = manager.add_template( + post_func, uri_template="http://api.com/users/{user_id}/posts/{post_id}", name="post_template" + ) + template3 = manager.add_template( + product_func, uri_template="http://api.com/products/{product_id}", name="product_template" + ) + + # Test listing all templates + all_templates = manager.list_templates() + assert len(all_templates) == 3 + + # Test uri_paths filtering - matches both user templates + user_templates = manager.list_templates(uri_paths=[AnyUrl("http://api.com/users/")]) + assert len(user_templates) == 2 + assert template1 in user_templates + assert template2 in user_templates + + # Test partial materialization - only matches post template + # The template users/{user_id} generates "users/123" not "users/123/" + # But users/{user_id}/posts/{post_id} can generate "users/123/posts/456" + user_123_templates = manager.list_templates(uri_paths=[AnyUrl("http://api.com/users/123/")]) + assert len(user_123_templates) == 1 + assert template2 in user_123_templates # users/{user_id}/posts/{post_id} matches + + # Without trailing slash, it gets added automatically so only posts template matches + user_123_no_slash = manager.list_templates(uri_paths=[AnyUrl("http://api.com/users/123")]) + assert len(user_123_no_slash) == 1 + assert template2 in user_123_no_slash # Only posts template has path after users/123/ + + # Test product prefix + product_templates = manager.list_templates(uri_paths=[AnyUrl("http://api.com/products/")]) + assert len(product_templates) == 1 + assert template3 in product_templates + + # No matches + no_matches = manager.list_templates(uri_paths=[AnyUrl("http://api.com/orders/")]) + assert len(no_matches) == 0 + + # Multiple uri_paths + users_and_products = manager.list_templates( + uri_paths=[AnyUrl("http://api.com/users/"), AnyUrl("http://api.com/products/")] + ) + assert len(users_and_products) == 3 + assert all(t in users_and_products for t in all_templates) diff --git a/tests/server/fastmcp/resources/test_resource_template.py b/tests/server/fastmcp/resources/test_resource_template.py index f47244361..bcf040813 100644 --- a/tests/server/fastmcp/resources/test_resource_template.py +++ b/tests/server/fastmcp/resources/test_resource_template.py @@ -186,3 +186,117 @@ def get_data(value: str) -> CustomData: assert isinstance(resource, FunctionResource) content = await resource.read() assert content == '"hello"' + + def test_matches_prefix_exact_template(self): + """Test that templates match when prefix matches template exactly.""" + + def dummy_func() -> str: + return "data" + + template = ResourceTemplate.from_function( + dummy_func, uri_template="http://api.example.com/users/{user_id}", name="test" + ) + + # Exact prefix of template + assert template.matches_prefix("http://api.example.com/users/") + assert template.matches_prefix("http://api.example.com/users") + assert template.matches_prefix("http://api.example.com/") + assert template.matches_prefix("http://") + + def test_matches_prefix_partial_materialization(self): + """Test matching with partially materialized parameters.""" + + def dummy_func(user_id: str, post_id: str) -> str: + return f"User {user_id} Post {post_id}" + + template = ResourceTemplate.from_function( + dummy_func, uri_template="http://api.example.com/users/{user_id}/posts/{post_id}", name="test" + ) + + # Partial materialization - user_id replaced with value + assert template.matches_prefix("http://api.example.com/users/123/") + assert template.matches_prefix("http://api.example.com/users/123/posts/") + assert template.matches_prefix("http://api.example.com/users/alice/posts/") + + # Without trailing slash + assert template.matches_prefix("http://api.example.com/users/123") + assert template.matches_prefix("http://api.example.com/users/123/posts") + + def test_matches_prefix_no_match_different_structure(self): + """Test that templates don't match when structure differs.""" + + def dummy_func(user_id: str) -> str: + return f"User {user_id}" + + template = ResourceTemplate.from_function( + dummy_func, uri_template="http://api.example.com/users/{user_id}", name="test" + ) + + # Different path structure + assert not template.matches_prefix("http://api.example.com/products/") + assert not template.matches_prefix("http://api.example.com/users/123/invalid/") + assert not template.matches_prefix("http://different.com/users/") + + def test_matches_prefix_complex_nested(self): + """Test matching with complex nested templates.""" + + def dummy_func(org_id: str, team_id: str, user_id: str) -> str: + return f"Org {org_id} Team {team_id} User {user_id}" + + template = ResourceTemplate.from_function( + dummy_func, uri_template="http://api.example.com/orgs/{org_id}/teams/{team_id}/users/{user_id}", name="test" + ) + + # Various levels of partial materialization + assert template.matches_prefix("http://api.example.com/orgs/") + assert template.matches_prefix("http://api.example.com/orgs/acme/") + assert template.matches_prefix("http://api.example.com/orgs/acme/teams/") + assert template.matches_prefix("http://api.example.com/orgs/acme/teams/dev/") + assert template.matches_prefix("http://api.example.com/orgs/acme/teams/dev/users/") + + def test_matches_prefix_file_uri(self): + """Test matching with file:// URI templates.""" + + def dummy_func(category: str, filename: str) -> str: + return f"File {category}/{filename}" + + template = ResourceTemplate.from_function( + dummy_func, uri_template="file:///data/{category}/{filename}", name="test" + ) + + assert template.matches_prefix("file:///data/") + assert template.matches_prefix("file:///data/images/") + assert template.matches_prefix("file:///data/docs/") + assert not template.matches_prefix("file:///other/") + + def test_matches_prefix_trailing_slash_semantics(self): + """Test that trailing slashes have semantic meaning.""" + + def dummy_func(id: str) -> str: + return f"Item {id}" + + template = ResourceTemplate.from_function( + dummy_func, uri_template="http://api.example.com/items/{id}", name="test" + ) + + # Prefix without trailing slash matches (looking for items or under items) + assert template.matches_prefix("http://api.example.com/items") + assert template.matches_prefix("http://api.example.com/items/123") + + # Prefix with trailing slash only matches if template generates something under it + assert template.matches_prefix("http://api.example.com/items/") # template generates items/X + assert not template.matches_prefix("http://api.example.com/items/123/") # template can't generate items/123/... + + def test_matches_prefix_longer_than_template(self): + """Test that prefixes longer than template don't match.""" + + def dummy_func(id: str) -> str: + return f"Item {id}" + + template = ResourceTemplate.from_function( + dummy_func, uri_template="http://api.example.com/items/{id}", name="test" + ) + + # Prefix has more segments than template + assert not template.matches_prefix("http://api.example.com/items/123/extra/") + assert not template.matches_prefix("http://api.example.com/items/123/extra/more/") diff --git a/tests/server/fastmcp/test_server.py b/tests/server/fastmcp/test_server.py index a9e0d182a..f69bc720b 100644 --- a/tests/server/fastmcp/test_server.py +++ b/tests/server/fastmcp/test_server.py @@ -16,6 +16,8 @@ create_connected_server_and_client_session as client_session, ) from mcp.types import ( + PROMPT_SCHEME, + TOOL_SCHEME, AudioContent, BlobResourceContents, ContentBlock, @@ -217,6 +219,9 @@ async def test_list_tools(self): async with client_session(mcp._mcp_server) as client: tools = await client.list_tools() assert len(tools.tools) == 1 + # Verify URI is generated + tool = tools.tools[0] + assert str(tool.uri) == f"{TOOL_SCHEME}/tool_fn" @pytest.mark.anyio async def test_call_tool(self): @@ -957,6 +962,8 @@ def fn(name: str, optional: str = "default") -> str: assert len(result.prompts) == 1 prompt = result.prompts[0] assert prompt.name == "fn" + # Verify URI is generated + assert str(prompt.uri) == f"{PROMPT_SCHEME}/fn" assert prompt.arguments is not None assert len(prompt.arguments) == 2 assert prompt.arguments[0].name == "name" diff --git a/tests/server/fastmcp/test_title.py b/tests/server/fastmcp/test_title.py index a94f6671d..1dc3ccaae 100644 --- a/tests/server/fastmcp/test_title.py +++ b/tests/server/fastmcp/test_title.py @@ -187,11 +187,18 @@ async def test_get_display_name_utility(): tool_with_title = Tool(name="test_tool", title="Test Tool", inputSchema={}) assert get_display_name(tool_with_title) == "Test Tool" - tool_with_annotations = Tool(name="test_tool", inputSchema={}, annotations=ToolAnnotations(title="Annotated Tool")) + tool_with_annotations = Tool( + name="test_tool", + inputSchema={}, + annotations=ToolAnnotations(title="Annotated Tool"), + ) assert get_display_name(tool_with_annotations) == "Annotated Tool" tool_with_both = Tool( - name="test_tool", title="Primary Title", inputSchema={}, annotations=ToolAnnotations(title="Secondary Title") + name="test_tool", + title="Primary Title", + inputSchema={}, + annotations=ToolAnnotations(title="Secondary Title"), ) assert get_display_name(tool_with_both) == "Primary Title" diff --git a/tests/server/fastmcp/test_tool_manager.py b/tests/server/fastmcp/test_tool_manager.py index 27e16cc8e..1973ed7bf 100644 --- a/tests/server/fastmcp/test_tool_manager.py +++ b/tests/server/fastmcp/test_tool_manager.py @@ -4,7 +4,7 @@ from typing import Any, TypedDict import pytest -from pydantic import BaseModel +from pydantic import AnyUrl, BaseModel from mcp.server.fastmcp import Context, FastMCP from mcp.server.fastmcp.exceptions import ToolError @@ -12,7 +12,7 @@ from mcp.server.fastmcp.utilities.func_metadata import ArgModelBase, FuncMetadata from mcp.server.session import ServerSessionT from mcp.shared.context import LifespanContextT, RequestT -from mcp.types import TextContent, ToolAnnotations +from mcp.types import TOOL_SCHEME, TextContent, ToolAnnotations class TestAddTools: @@ -62,7 +62,7 @@ class AddArguments(ArgModelBase): # warn on duplicate tools with caplog.at_level(logging.WARNING): manager = ToolManager(True, tools=[original_tool, original_tool]) - assert "Tool already exists: sum" in caplog.text + assert f"Tool already exists: {TOOL_SCHEME}/sum" in caplog.text @pytest.mark.anyio async def test_async_function(self): @@ -163,7 +163,7 @@ def f(x: int) -> int: manager.add_tool(f) with caplog.at_level(logging.WARNING): manager.add_tool(f) - assert "Tool already exists: f" in caplog.text + assert f"Tool already exists: {TOOL_SCHEME}/f" in caplog.text def test_disable_warn_on_duplicate_tools(self, caplog): """Test disabling warning on duplicate tools.""" @@ -178,6 +178,189 @@ def f(x: int) -> int: manager.add_tool(f) assert "Tool already exists: f" not in caplog.text + def test_list_tools_with_prefix(self): + """Test listing tools with prefix filtering.""" + manager = ToolManager() + + # Add tools with different URI prefixes + def math_add(a: int, b: int) -> int: + return a + b + + def math_multiply(a: int, b: int) -> int: + return a * b + + def string_concat(a: str, b: str) -> str: + return a + b + + def string_upper(text: str) -> str: + return text.upper() + + # Add tools with custom URIs + math_add_tool = Tool.from_function(math_add) + math_add_tool.uri = AnyUrl(f"{TOOL_SCHEME}/math/add") + + math_multiply_tool = Tool.from_function(math_multiply) + math_multiply_tool.uri = AnyUrl(f"{TOOL_SCHEME}/math/multiply") + + string_concat_tool = Tool.from_function(string_concat) + string_concat_tool.uri = AnyUrl(f"{TOOL_SCHEME}/string/concat") + + string_upper_tool = Tool.from_function(string_upper) + string_upper_tool.uri = AnyUrl(f"{TOOL_SCHEME}/string/upper") + + manager._tools = { + str(math_add_tool.uri): math_add_tool, + str(math_multiply_tool.uri): math_multiply_tool, + str(string_concat_tool.uri): string_concat_tool, + str(string_upper_tool.uri): string_upper_tool, + } + + # Test listing all tools + all_tools = manager.list_tools() + assert len(all_tools) == 4 + + # Test uri_paths filtering - math tools + math_tools = manager.list_tools(uri_paths=[AnyUrl(f"{TOOL_SCHEME}/math/")]) + assert len(math_tools) == 2 + assert all(str(t.uri).startswith(f"{TOOL_SCHEME}/math/") for t in math_tools) + assert math_add_tool in math_tools + assert math_multiply_tool in math_tools + + # Test uri_paths filtering - string tools + string_tools = manager.list_tools(uri_paths=[AnyUrl(f"{TOOL_SCHEME}/string/")]) + assert len(string_tools) == 2 + assert all(str(t.uri).startswith(f"{TOOL_SCHEME}/string/") for t in string_tools) + assert string_concat_tool in string_tools + assert string_upper_tool in string_tools + + # Test exact URI match + add_tools = manager.list_tools(uri_paths=[AnyUrl(f"{TOOL_SCHEME}/math/add")]) + assert len(add_tools) == 1 + assert add_tools[0] == math_add_tool + + # Test partial prefix doesn't match + no_partial = manager.list_tools(uri_paths=[AnyUrl(f"{TOOL_SCHEME}/math/a")]) + assert len(no_partial) == 0 # Won't match because next char is 'd' not a separator + + # Test no matches + no_matches = manager.list_tools(uri_paths=[AnyUrl(f"{TOOL_SCHEME}/nonexistent")]) + assert len(no_matches) == 0 + + # Test with trailing slash + math_tools_slash = manager.list_tools(uri_paths=[AnyUrl(f"{TOOL_SCHEME}/math/")]) + assert len(math_tools_slash) == 2 + assert math_tools_slash == math_tools + + # Test multiple uri_paths + math_and_string = manager.list_tools( + uri_paths=[AnyUrl(f"{TOOL_SCHEME}/math/"), AnyUrl(f"{TOOL_SCHEME}/string/")] + ) + assert len(math_and_string) == 4 + assert all(t in math_and_string for t in all_tools) + + def test_add_tool_with_custom_uri(self): + """Test adding tools with custom URI parameter.""" + + def math_add(a: int, b: int) -> int: + return a + b + + def string_concat(a: str, b: str) -> str: + return a + b + + manager = ToolManager() + + # Add tool with custom hierarchical URI + tool1 = manager.add_tool(math_add, uri="mcp://tools/math/add") + assert tool1.name == "math_add" + assert str(tool1.uri) == "mcp://tools/math/add" + + # Add tool with AnyUrl + tool2 = manager.add_tool(string_concat, uri=AnyUrl("mcp://tools/string/concat")) + assert tool2.name == "string_concat" + assert str(tool2.uri) == "mcp://tools/string/concat" + + # Verify tools are stored by URI + assert str(tool1.uri) in manager._tools + assert str(tool2.uri) in manager._tools + + def test_get_tool_by_name_with_custom_uri(self): + """Test getting tools by name when they have custom URIs.""" + + def calculator(x: int, y: int) -> int: + return x + y + + def formatter(text: str) -> str: + return text.upper() + + manager = ToolManager() + + # Add tools with custom URIs + calc_tool = manager.add_tool(calculator, uri="mcp://tools/utils/calculator") + format_tool = manager.add_tool(formatter, uri="mcp://tools/text/formatter") + + # Should be able to get by name + tool_by_name = manager.get_tool("calculator") + assert tool_by_name is not None + assert tool_by_name == calc_tool + + # Should also work for the second tool + tool_by_name2 = manager.get_tool("formatter") + assert tool_by_name2 is not None + assert tool_by_name2 == format_tool + + # Should also be able to get by URI + tool_by_uri = manager.get_tool("mcp://tools/utils/calculator") + assert tool_by_uri == calc_tool + + # Get by AnyUrl + tool_by_anyurl = manager.get_tool(AnyUrl("mcp://tools/text/formatter")) + assert tool_by_anyurl == format_tool + + @pytest.mark.anyio + async def test_tool_name_lookup_with_hierarchical_uri(self): + """Test name lookup works correctly with hierarchical URIs.""" + + def add(a: int, b: int) -> int: + """Add two numbers.""" + return a + b + + def multiply(a: int, b: int) -> int: + """Multiply two numbers.""" + return a * b + + def concat(a: str, b: str) -> str: + """Concatenate strings.""" + return a + b + + manager = ToolManager() + + # Add tools with hierarchical URIs + _ = manager.add_tool(add, uri="mcp://tools/math/add") + _ = manager.add_tool(multiply, uri="mcp://tools/math/multiply") + _ = manager.add_tool(concat, uri="mcp://tools/string/concat") + + # Test calling by name + result = await manager.call_tool("add", {"a": 5, "b": 3}) + assert result == 8 + + result = await manager.call_tool("multiply", {"a": 4, "b": 6}) + assert result == 24 + + result = await manager.call_tool("concat", {"a": "Hello, ", "b": "World!"}) + assert result == "Hello, World!" + + # Test calling by full URI + result = await manager.call_tool("mcp://tools/math/add", {"a": 10, "b": 20}) + assert result == 30 + + result = await manager.call_tool(AnyUrl("mcp://tools/string/concat"), {"a": "Foo", "b": "Bar"}) + assert result == "FooBar" + + # Verify that the standard normalization doesn't work + # (since tools are at custom URIs, not standard ones) + tool = manager.get_tool(f"{TOOL_SCHEME}/add") + assert tool is None # Should not find it at the standard URI + class TestCallTools: @pytest.mark.anyio @@ -258,6 +441,62 @@ async def test_call_unknown_tool(self): with pytest.raises(ToolError): await manager.call_tool("unknown", {"a": 1}) + @pytest.mark.anyio + async def test_call_tool_by_uri(self): + """Test calling tools by their URI instead of name.""" + + def math_add(a: int, b: int) -> int: + return a + b + + def math_multiply(a: int, b: int) -> int: + return a * b + + manager = ToolManager() + + # Add tool with default URI + manager.add_tool(math_add) + + # Add tool with custom URI + manager.add_tool(math_multiply, uri=f"{TOOL_SCHEME}/custom/math/multiply") + + # Call by default URI (TOOL_SCHEME/function_name) + result = await manager.call_tool(f"{TOOL_SCHEME}/math_add", {"a": 5, "b": 3}) + assert result == 8 + + # Call by custom URI + result = await manager.call_tool(f"{TOOL_SCHEME}/custom/math/multiply", {"a": 4, "b": 7}) + assert result == 28 + + # Should still work with name + result = await manager.call_tool("math_add", {"a": 2, "b": 2}) + assert result == 4 + + # Custom URI tool should also work with name + result = await manager.call_tool("math_multiply", {"a": 3, "b": 3}) + assert result == 9 + + def test_get_tool_by_uri(self): + """Test getting tools by their URI.""" + + def calculator() -> str: + return "calculator" + + manager = ToolManager() + + # Add tool with default URI + manager.add_tool(calculator) + + # Get by name + tool = manager.get_tool("calculator") + assert tool is not None + assert tool.name == "calculator" + + # Get by URI + tool_by_uri = manager.get_tool(f"{TOOL_SCHEME}/calculator") + assert tool_by_uri is not None + assert tool_by_uri.name == "calculator" + assert tool_by_uri == tool + @pytest.mark.anyio async def test_call_tool_with_list_int_input(self): def sum_vals(vals: list[int]) -> int: diff --git a/tests/server/fastmcp/test_uri_utils.py b/tests/server/fastmcp/test_uri_utils.py new file mode 100644 index 000000000..9a8d90121 --- /dev/null +++ b/tests/server/fastmcp/test_uri_utils.py @@ -0,0 +1,204 @@ +"""Tests for URI utility functions.""" + +from pydantic import AnyUrl + +from mcp.server.fastmcp.uri_utils import ( + filter_by_uri_paths, + normalize_to_prompt_uri, + normalize_to_tool_uri, + normalize_to_uri, +) +from mcp.types import PROMPT_SCHEME, TOOL_SCHEME + + +class TestNormalizeToUri: + """Test the generic normalize_to_uri function.""" + + def test_normalize_name_to_uri(self): + """Test converting a name to URI.""" + result = normalize_to_uri("test_name", TOOL_SCHEME) + assert result == f"{TOOL_SCHEME}/test_name" + + def test_normalize_already_uri(self): + """Test that URIs are returned unchanged.""" + uri = f"{TOOL_SCHEME}/existing_uri" + result = normalize_to_uri(uri, TOOL_SCHEME) + assert result == uri + + def test_normalize_with_different_scheme(self): + """Test normalizing with different schemes.""" + result = normalize_to_uri("test", PROMPT_SCHEME) + assert result == f"{PROMPT_SCHEME}/test" + + def test_normalize_empty_name(self): + """Test normalizing empty string.""" + result = normalize_to_uri("", TOOL_SCHEME) + assert result == f"{TOOL_SCHEME}/" + + def test_normalize_special_characters(self): + """Test normalizing names with special characters.""" + result = normalize_to_uri("test-name_123", TOOL_SCHEME) + assert result == f"{TOOL_SCHEME}/test-name_123" + + +class TestNormalizeToToolUri: + """Test the tool-specific URI normalization.""" + + def test_normalize_tool_name(self): + """Test converting tool name to URI.""" + result = normalize_to_tool_uri("calculator") + assert result == f"{TOOL_SCHEME}/calculator" + + def test_normalize_existing_tool_uri(self): + """Test that tool URIs are returned unchanged.""" + uri = f"{TOOL_SCHEME}/existing_tool" + result = normalize_to_tool_uri(uri) + assert result == uri + + def test_normalize_tool_with_path(self): + """Test normalizing tool names that look like paths.""" + result = normalize_to_tool_uri("math/calculator") + assert result == f"{TOOL_SCHEME}/math/calculator" + + +class TestNormalizeToPromptUri: + """Test the prompt-specific URI normalization.""" + + def test_normalize_prompt_name(self): + """Test converting prompt name to URI.""" + result = normalize_to_prompt_uri("greeting") + assert result == f"{PROMPT_SCHEME}/greeting" + + def test_normalize_existing_prompt_uri(self): + """Test that prompt URIs are returned unchanged.""" + uri = f"{PROMPT_SCHEME}/existing_prompt" + result = normalize_to_prompt_uri(uri) + assert result == uri + + def test_normalize_prompt_with_path(self): + """Test normalizing prompt names that look like paths.""" + result = normalize_to_prompt_uri("templates/greeting") + assert result == f"{PROMPT_SCHEME}/templates/greeting" + + +class MockUriItem: + """Mock item with a URI for testing.""" + + def __init__(self, uri: str): + self.uri = AnyUrl(uri) + + +class TestFilterByUriPaths: + """Test the URI paths filtering function.""" + + def test_filter_empty_paths(self): + """Test that empty paths list returns empty result.""" + + class MockItem: + def __init__(self, uri: str): + self.uri = AnyUrl(uri) + + items = [MockItem("http://example.com/item1"), MockItem("http://example.com/item2")] + result = filter_by_uri_paths(items, []) + assert result == [] + + def test_filter_single_path(self): + """Test filtering with a single path.""" + items = [ + MockUriItem(f"{TOOL_SCHEME}/math/add"), + MockUriItem(f"{TOOL_SCHEME}/math/subtract"), + MockUriItem(f"{TOOL_SCHEME}/string/concat"), + ] + result = filter_by_uri_paths(items, [AnyUrl(f"{TOOL_SCHEME}/math")]) + assert len(result) == 2 + assert any(str(item.uri) == f"{TOOL_SCHEME}/math/add" for item in result) + assert any(str(item.uri) == f"{TOOL_SCHEME}/math/subtract" for item in result) + + def test_filter_multiple_paths(self): + """Test filtering with multiple paths.""" + items = [ + MockUriItem(f"{TOOL_SCHEME}/math/add"), + MockUriItem(f"{TOOL_SCHEME}/math/subtract"), + MockUriItem(f"{TOOL_SCHEME}/string/concat"), + MockUriItem(f"{PROMPT_SCHEME}/greet/hello"), + ] + result = filter_by_uri_paths(items, [AnyUrl(f"{TOOL_SCHEME}/math"), AnyUrl(f"{PROMPT_SCHEME}/greet")]) + assert len(result) == 3 + assert any(str(item.uri) == f"{TOOL_SCHEME}/math/add" for item in result) + assert any(str(item.uri) == f"{TOOL_SCHEME}/math/subtract" for item in result) + assert any(str(item.uri) == f"{PROMPT_SCHEME}/greet/hello" for item in result) + assert not any(str(item.uri) == f"{TOOL_SCHEME}/string/concat" for item in result) + + def test_filter_paths_without_slash(self): + """Test that paths without trailing slash only match at boundaries.""" + items = [ + MockUriItem(f"{TOOL_SCHEME}/math/add"), + MockUriItem(f"{TOOL_SCHEME}/math/subtract"), + MockUriItem(f"{TOOL_SCHEME}/string/concat"), + MockUriItem(f"{TOOL_SCHEME}/mathematic"), + ] + result = filter_by_uri_paths(items, [AnyUrl(f"{TOOL_SCHEME}/math"), AnyUrl(f"{TOOL_SCHEME}/string")]) + assert len(result) == 3 + assert any(str(item.uri) == f"{TOOL_SCHEME}/math/add" for item in result) + assert any(str(item.uri) == f"{TOOL_SCHEME}/math/subtract" for item in result) + assert any(str(item.uri) == f"{TOOL_SCHEME}/string/concat" for item in result) + assert not any(str(item.uri) == f"{TOOL_SCHEME}/mathematic" for item in result) + + def test_filter_with_trailing_slashes(self): + """Test filtering when paths have trailing slashes.""" + items = [ + MockUriItem(f"{PROMPT_SCHEME}/greet/hello"), + MockUriItem(f"{PROMPT_SCHEME}/greet/goodbye"), + MockUriItem(f"{PROMPT_SCHEME}/chat/start"), + ] + result = filter_by_uri_paths(items, [AnyUrl(f"{PROMPT_SCHEME}/greet/"), AnyUrl(f"{PROMPT_SCHEME}/chat/")]) + assert len(result) == 3 + + def test_filter_overlapping_paths(self): + """Test filtering with overlapping paths.""" + items = [ + MockUriItem(f"{TOOL_SCHEME}/math"), + MockUriItem(f"{TOOL_SCHEME}/math/add"), + MockUriItem(f"{TOOL_SCHEME}/math/advanced/multiply"), + ] + result = filter_by_uri_paths(items, [AnyUrl(f"{TOOL_SCHEME}/math"), AnyUrl(f"{TOOL_SCHEME}/math/advanced")]) + assert len(result) == 3 # All items match + + def test_filter_no_matches(self): + """Test filtering when no items match any path.""" + items = [MockUriItem(f"{TOOL_SCHEME}/math/add"), MockUriItem(f"{TOOL_SCHEME}/math/subtract")] + result = filter_by_uri_paths(items, [AnyUrl(f"{TOOL_SCHEME}/string"), AnyUrl(f"{PROMPT_SCHEME}/greet")]) + assert result == [] + + def test_filter_with_objects(self): + """Test filtering objects with URI attributes.""" + resources = [ + MockUriItem(f"{TOOL_SCHEME}/math/add"), + MockUriItem(f"{TOOL_SCHEME}/string/concat"), + MockUriItem(f"{PROMPT_SCHEME}/greet/hello"), + ] + + result = filter_by_uri_paths(resources, [AnyUrl(f"{TOOL_SCHEME}/math"), AnyUrl(f"{PROMPT_SCHEME}/greet")]) + assert len(result) == 2 + assert str(result[0].uri) == f"{TOOL_SCHEME}/math/add" + assert str(result[1].uri) == f"{PROMPT_SCHEME}/greet/hello" + + def test_filter_case_sensitive(self): + """Test that filtering is case sensitive.""" + items = [MockUriItem(f"{TOOL_SCHEME}/Math/add"), MockUriItem(f"{TOOL_SCHEME}/math/add")] + result = filter_by_uri_paths(items, [AnyUrl(f"{TOOL_SCHEME}/math")]) + assert len(result) == 1 + assert any(str(item.uri) == f"{TOOL_SCHEME}/math/add" for item in result) + + def test_filter_exact_path_match(self): + """Test that exact path matches work correctly.""" + items = [ + MockUriItem(f"{TOOL_SCHEME}/test"), + MockUriItem(f"{TOOL_SCHEME}/test/sub"), + MockUriItem(f"{TOOL_SCHEME}/testing"), + ] + result = filter_by_uri_paths(items, [AnyUrl(f"{TOOL_SCHEME}/test")]) + assert len(result) == 2 + assert any(str(item.uri) == f"{TOOL_SCHEME}/test" for item in result) + assert any(str(item.uri) == f"{TOOL_SCHEME}/test/sub" for item in result) + assert not any(str(item.uri) == f"{TOOL_SCHEME}/testing" for item in result) diff --git a/tests/server/test_cancel_handling.py b/tests/server/test_cancel_handling.py index 44b9a924d..a8377d054 100644 --- a/tests/server/test_cancel_handling.py +++ b/tests/server/test_cancel_handling.py @@ -31,7 +31,7 @@ async def test_server_remains_functional_after_cancel(): first_request_id = None @server.list_tools() - async def handle_list_tools() -> list[Tool]: + async def handle_list_tools(_) -> list[Tool]: return [ Tool( name="test_tool", diff --git a/tests/server/test_lowlevel_input_validation.py b/tests/server/test_lowlevel_input_validation.py index 250159733..4f729c5fe 100644 --- a/tests/server/test_lowlevel_input_validation.py +++ b/tests/server/test_lowlevel_input_validation.py @@ -35,7 +35,7 @@ async def run_tool_test( server = Server("test") @server.list_tools() - async def list_tools(): + async def list_tools(_): return tools @server.call_tool() diff --git a/tests/server/test_lowlevel_output_validation.py b/tests/server/test_lowlevel_output_validation.py index 39f0d970d..126cd3991 100644 --- a/tests/server/test_lowlevel_output_validation.py +++ b/tests/server/test_lowlevel_output_validation.py @@ -35,7 +35,7 @@ async def run_tool_test( server = Server("test") @server.list_tools() - async def list_tools(): + async def list_tools(_): return tools @server.call_tool() diff --git a/tests/server/test_lowlevel_tool_annotations.py b/tests/server/test_lowlevel_tool_annotations.py index 2eb3b7ddb..fd70eaaec 100644 --- a/tests/server/test_lowlevel_tool_annotations.py +++ b/tests/server/test_lowlevel_tool_annotations.py @@ -20,7 +20,7 @@ async def test_lowlevel_server_tool_annotations(): # Create a tool with annotations @server.list_tools() - async def list_tools(): + async def list_tools(_): return [ Tool( name="echo", diff --git a/tests/server/test_session.py b/tests/server/test_session.py index 154c3a368..d89b6ecd7 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -90,7 +90,7 @@ async def test_server_capabilities(): # Add a prompts handler @server.list_prompts() - async def list_prompts(): + async def list_prompts(_): return [] caps = server.get_capabilities(notification_options, experimental_capabilities) @@ -100,7 +100,7 @@ async def list_prompts(): # Add a resources handler @server.list_resources() - async def list_resources(): + async def list_resources(_): return [] caps = server.get_capabilities(notification_options, experimental_capabilities) diff --git a/tests/shared/test_memory.py b/tests/shared/test_memory.py index a0c32f556..8ccb7fb7f 100644 --- a/tests/shared/test_memory.py +++ b/tests/shared/test_memory.py @@ -18,7 +18,7 @@ def mcp_server() -> Server: server = Server(name="test_server") @server.list_resources() - async def handle_list_resources(): + async def handle_list_resources(_): return [ Resource( uri=AnyUrl("memory://test"), diff --git a/tests/shared/test_progress_notifications.py b/tests/shared/test_progress_notifications.py index 93cc712b4..bc1bcd531 100644 --- a/tests/shared/test_progress_notifications.py +++ b/tests/shared/test_progress_notifications.py @@ -76,7 +76,7 @@ async def handle_progress( # Register list tool handler @server.list_tools() - async def handle_list_tools() -> list[types.Tool]: + async def handle_list_tools(_) -> list[types.Tool]: return [ types.Tool( name="test_tool", diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index 864e0d1b4..55eac9b9d 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -72,7 +72,7 @@ async def handle_call_tool(name: str, arguments: dict | None) -> list: # Register the tool so it shows up in list_tools @server.list_tools() - async def handle_list_tools() -> list[types.Tool]: + async def handle_list_tools(_) -> list[types.Tool]: return [ types.Tool( name="slow_tool", diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 39ae13524..aec9d1ee9 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -64,7 +64,7 @@ async def handle_read_resource(uri: AnyUrl) -> str | bytes: raise McpError(error=ErrorData(code=404, message="OOPS! no resource with that URI was found")) @self.list_tools() - async def handle_list_tools() -> list[Tool]: + async def handle_list_tools(_) -> list[Tool]: return [ Tool( name="test_tool", @@ -323,7 +323,7 @@ async def handle_call_tool(name: str, args: dict) -> list[TextContent]: return [TextContent(type="text", text=f"Called {name}")] @self.list_tools() - async def handle_list_tools() -> list[Tool]: + async def handle_list_tools(_) -> list[Tool]: return [ Tool( name="echo_headers", diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 3fea54f0b..56ecf59fe 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -138,7 +138,7 @@ async def handle_read_resource(uri: AnyUrl) -> str | bytes: raise ValueError(f"Unknown resource: {uri}") @self.list_tools() - async def handle_list_tools() -> list[Tool]: + async def handle_list_tools(_) -> list[Tool]: return [ Tool( name="test_tool", @@ -1277,7 +1277,7 @@ def __init__(self): super().__init__("ContextAwareServer") @self.list_tools() - async def handle_list_tools() -> list[Tool]: + async def handle_list_tools(_) -> list[Tool]: return [ Tool( name="echo_headers", diff --git a/tests/shared/test_ws.py b/tests/shared/test_ws.py index 5081f1d53..8fd680ac1 100644 --- a/tests/shared/test_ws.py +++ b/tests/shared/test_ws.py @@ -57,7 +57,7 @@ async def handle_read_resource(uri: AnyUrl) -> str | bytes: raise McpError(error=ErrorData(code=404, message="OOPS! no resource with that URI was found")) @self.list_tools() - async def handle_list_tools() -> list[Tool]: + async def handle_list_tools(_) -> list[Tool]: return [ Tool( name="test_tool", diff --git a/tests/test_types.py b/tests/test_types.py index a39d33412..88f23aefd 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -1,10 +1,17 @@ import pytest +from pydantic import AnyUrl from mcp.types import ( LATEST_PROTOCOL_VERSION, + PROMPT_SCHEME, + TOOL_SCHEME, ClientRequest, + Implementation, JSONRPCMessage, JSONRPCRequest, + Prompt, + Resource, + Tool, ) @@ -30,3 +37,54 @@ async def test_jsonrpc_request(): assert request.root.method == "initialize" assert request.root.params is not None assert request.root.params["protocolVersion"] == LATEST_PROTOCOL_VERSION + + +def test_implementation_no_uri(): + """Test that Implementation doesn't have URI field.""" + impl = Implementation(name="test-server", version="1.0.0") + assert impl.name == "test-server" + assert impl.version == "1.0.0" + assert not hasattr(impl, "uri") + + +def test_resource_uri(): + """Test that Resource requires URI and validates scheme.""" + # Resource should require URI + with pytest.raises(ValueError): + Resource(name="test") # pyright: ignore[reportCallIssue] + + # This should work + resource = Resource(name="test", uri=AnyUrl("file://test.txt")) + assert resource.name == "test" + assert str(resource.uri) == "file://test.txt/" # AnyUrl adds trailing slash + + # Should reject MCP scheme + with pytest.raises(ValueError, match="reserved MCP scheme"): + Resource(name="test", uri=AnyUrl(f"{TOOL_SCHEME}/test")) + + with pytest.raises(ValueError, match="reserved MCP scheme"): + Resource(name="test", uri=AnyUrl(f"{PROMPT_SCHEME}/test")) + + +def test_tool_uri_validation(): + """Test that Tool requires URI with tool scheme.""" + # Tool requires URI with TOOL_SCHEME + tool = Tool(name="calculator", inputSchema={"type": "object"}, uri=f"{TOOL_SCHEME}/calculator") + assert tool.name == "calculator" + assert str(tool.uri) == f"{TOOL_SCHEME}/calculator" + + # Should reject non-tool schemes + with pytest.raises(ValueError): + Tool(name="calculator", inputSchema={"type": "object"}, uri="custom://calc") + + +def test_prompt_uri_validation(): + """Test that Prompt requires URI with prompt scheme.""" + # Prompt requires URI with PROMPT_SCHEME + prompt = Prompt(name="greeting", uri=f"{PROMPT_SCHEME}/greeting") + assert prompt.name == "greeting" + assert str(prompt.uri) == f"{PROMPT_SCHEME}/greeting" + + # Should reject non-prompt schemes + with pytest.raises(ValueError): + Prompt(name="greeting", uri="custom://greet")