Skip to content

Add get_prompt and list_prompts to MCPServer #1013

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions examples/mcp/streamablehttp_example/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
from agents.model_settings import ModelSettings


async def run(mcp_server: MCPServer):
async def run(mcp_server: MCPServer, instructions):
agent = Agent(
name="Assistant",
instructions="Use the tools to answer the questions.",
instructions=instructions,
mcp_servers=[mcp_server],
model_settings=ModelSettings(tool_choice="required"),
)
Expand Down Expand Up @@ -46,8 +46,14 @@ async def main():
) as server:
trace_id = gen_trace_id()
with trace(workflow_name="Streamable HTTP Example", trace_id=trace_id):
# List available prompts
prompts = await server.list_prompts()
print(f"Prompts list -> {prompts}")
system_prompt = await server.get_prompt("system_prompt")
instructions = system_prompt.messages[0].content.text
print(f"instructions -> {instructions}")
print(f"View trace: https://platform.openai.com/traces/trace?trace_id={trace_id}\n")
await run(server)
await run(server, instructions)


if __name__ == "__main__":
Expand Down
5 changes: 5 additions & 0 deletions examples/mcp/streamablehttp_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,10 @@ def get_current_weather(city: str) -> str:
return response.text


@mcp.prompt()
def system_prompt():
return "Use the tools to answer the questions."


if __name__ == "__main__":
mcp.run(transport="streamable-http")
31 changes: 30 additions & 1 deletion src/agents/mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from mcp.client.sse import sse_client
from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client
from mcp.shared.message import SessionMessage
from mcp.types import CallToolResult, InitializeResult
from mcp.types import CallToolResult, GetPromptResult, InitializeResult, ListPromptsResult
from typing_extensions import NotRequired, TypedDict

from ..exceptions import UserError
Expand Down Expand Up @@ -63,6 +63,18 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> C
"""Invoke a tool on the server."""
pass

@abc.abstractmethod
async def list_prompts(self) -> ListPromptsResult | None:
"""List the prompts available on the server."""
pass

@abc.abstractmethod
async def get_prompt(
self, name: str, arguments: dict[str, str] | None = None
) -> GetPromptResult | None:
"""Returns an existing prompt from the server."""
pass


class _MCPServerWithClientSession(MCPServer, abc.ABC):
"""Base class for MCP servers that use a `ClientSession` to communicate with the server."""
Expand Down Expand Up @@ -261,6 +273,23 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> C

return await self.session.call_tool(tool_name, arguments)

async def list_prompts(
self,
) -> ListPromptsResult | None:
"""List the prompts available on the server."""
if not self.session:
raise UserError("Server not initialized. Make sure you call `connect()` first.")

return await self.session.list_prompts()

async def get_prompt(
self, name: str, arguments: dict[str, str] | None = None
) -> GetPromptResult | None:
if not self.session:
raise UserError("Server not initialized. Make sure you call `connect()` first.")

return await self.session.get_prompt(name, arguments)

async def cleanup(self):
"""Cleanup the server."""
async with self._cleanup_lock:
Expand Down
14 changes: 13 additions & 1 deletion tests/mcp/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any

from mcp import Tool as MCPTool
from mcp.types import CallToolResult, TextContent
from mcp.types import CallToolResult, GetPromptResult, ListPromptsResult, TextContent

from agents.mcp import MCPServer
from agents.mcp.server import _MCPServerWithClientSession
Expand Down Expand Up @@ -57,10 +57,12 @@ def name(self) -> str:
class FakeMCPServer(MCPServer):
def __init__(
self,
prompts: ListPromptsResult | None = None,
tools: list[MCPTool] | None = None,
tool_filter: ToolFilter = None,
server_name: str = "fake_mcp_server",
):
self.prompts = prompts
self.tools: list[MCPTool] = tools or []
self.tool_calls: list[str] = []
self.tool_results: list[str] = []
Expand Down Expand Up @@ -94,6 +96,16 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> C
content=[TextContent(text=self.tool_results[-1], type="text")],
)

async def list_prompts(
self,
) -> ListPromptsResult | None:
return self.prompts

async def get_prompt(
self, name: str, arguments: dict[str, str] | None = None
) -> GetPromptResult | None:
return None

@property
def name(self) -> str:
return self._server_name