diff --git a/examples/mcp/streamablehttp_example/main.py b/examples/mcp/streamablehttp_example/main.py index cc95e798b..44b0c1eff 100644 --- a/examples/mcp/streamablehttp_example/main.py +++ b/examples/mcp/streamablehttp_example/main.py @@ -10,10 +10,10 @@ from agents.model_settings import ModelSettings -async def run(mcp_server: MCPServer): +async def run(mcp_server: MCPServer, instructions): agent = Agent( name="Assistant", - instructions="Use the tools to answer the questions.", + instructions=instructions, mcp_servers=[mcp_server], model_settings=ModelSettings(tool_choice="required"), ) @@ -46,8 +46,14 @@ async def main(): ) as server: trace_id = gen_trace_id() with trace(workflow_name="Streamable HTTP Example", trace_id=trace_id): + # List available prompts + prompts = await server.list_prompts() + print(f"Prompts list -> {prompts}") + system_prompt = await server.get_prompt("system_prompt") + instructions = system_prompt.messages[0].content.text + print(f"instructions -> {instructions}") print(f"View trace: https://platform.openai.com/traces/trace?trace_id={trace_id}\n") - await run(server) + await run(server, instructions) if __name__ == "__main__": diff --git a/examples/mcp/streamablehttp_example/server.py b/examples/mcp/streamablehttp_example/server.py index d8f839652..3d979fccf 100644 --- a/examples/mcp/streamablehttp_example/server.py +++ b/examples/mcp/streamablehttp_example/server.py @@ -29,5 +29,10 @@ def get_current_weather(city: str) -> str: return response.text +@mcp.prompt() +def system_prompt(): + return "Use the tools to answer the questions." + + if __name__ == "__main__": mcp.run(transport="streamable-http") diff --git a/src/agents/mcp/server.py b/src/agents/mcp/server.py index f6c2b58ef..31986b283 100644 --- a/src/agents/mcp/server.py +++ b/src/agents/mcp/server.py @@ -13,7 +13,7 @@ from mcp.client.sse import sse_client from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client from mcp.shared.message import SessionMessage -from mcp.types import CallToolResult, InitializeResult +from mcp.types import CallToolResult, GetPromptResult, InitializeResult, ListPromptsResult from typing_extensions import NotRequired, TypedDict from ..exceptions import UserError @@ -63,6 +63,18 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> C """Invoke a tool on the server.""" pass + @abc.abstractmethod + async def list_prompts(self) -> ListPromptsResult | None: + """List the prompts available on the server.""" + pass + + @abc.abstractmethod + async def get_prompt( + self, name: str, arguments: dict[str, str] | None = None + ) -> GetPromptResult | None: + """Returns an existing prompt from the server.""" + pass + class _MCPServerWithClientSession(MCPServer, abc.ABC): """Base class for MCP servers that use a `ClientSession` to communicate with the server.""" @@ -261,6 +273,23 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> C return await self.session.call_tool(tool_name, arguments) + async def list_prompts( + self, + ) -> ListPromptsResult | None: + """List the prompts available on the server.""" + if not self.session: + raise UserError("Server not initialized. Make sure you call `connect()` first.") + + return await self.session.list_prompts() + + async def get_prompt( + self, name: str, arguments: dict[str, str] | None = None + ) -> GetPromptResult | None: + if not self.session: + raise UserError("Server not initialized. Make sure you call `connect()` first.") + + return await self.session.get_prompt(name, arguments) + async def cleanup(self): """Cleanup the server.""" async with self._cleanup_lock: diff --git a/tests/mcp/helpers.py b/tests/mcp/helpers.py index e0d8a813d..5a2365607 100644 --- a/tests/mcp/helpers.py +++ b/tests/mcp/helpers.py @@ -4,7 +4,7 @@ from typing import Any from mcp import Tool as MCPTool -from mcp.types import CallToolResult, TextContent +from mcp.types import CallToolResult, GetPromptResult, ListPromptsResult, TextContent from agents.mcp import MCPServer from agents.mcp.server import _MCPServerWithClientSession @@ -57,10 +57,12 @@ def name(self) -> str: class FakeMCPServer(MCPServer): def __init__( self, + prompts: ListPromptsResult | None = None, tools: list[MCPTool] | None = None, tool_filter: ToolFilter = None, server_name: str = "fake_mcp_server", ): + self.prompts = prompts self.tools: list[MCPTool] = tools or [] self.tool_calls: list[str] = [] self.tool_results: list[str] = [] @@ -94,6 +96,16 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> C content=[TextContent(text=self.tool_results[-1], type="text")], ) + async def list_prompts( + self, + ) -> ListPromptsResult | None: + return self.prompts + + async def get_prompt( + self, name: str, arguments: dict[str, str] | None = None + ) -> GetPromptResult | None: + return None + @property def name(self) -> str: return self._server_name