diff --git a/a2a_client.py b/a2a_client.py new file mode 100644 index 00000000..e8683aa3 --- /dev/null +++ b/a2a_client.py @@ -0,0 +1,58 @@ +import logging +from typing import Any +from uuid import uuid4 + +import httpx +from a2a.client import A2ACardResolver, A2AClient +from a2a.types import ( + AgentCard, + MessageSendParams, + SendMessageRequest, +) + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) +PUBLIC_AGENT_CARD_PATH = "/.well-known/agent.json" +BASE_URL = "http://localhost:9000" + + +async def main() -> None: + async with httpx.AsyncClient() as httpx_client: + # Initialize A2ACardResolver + resolver = A2ACardResolver( + httpx_client=httpx_client, + base_url=BASE_URL, + ) + + # Fetch Public Agent Card and Initialize Client + agent_card: AgentCard | None = None + + try: + logger.info("Attempting to fetch public agent card from: {} {}", BASE_URL, PUBLIC_AGENT_CARD_PATH) + agent_card = await resolver.get_agent_card() # Fetches from default public path + logger.info("Successfully fetched public agent card:") + logger.info(agent_card.model_dump_json(indent=2, exclude_none=True)) + except Exception as e: + logger.exception("Critical error fetching public agent card") + raise RuntimeError("Failed to fetch the public agent card. Cannot continue.") from e + + client = A2AClient(httpx_client=httpx_client, agent_card=agent_card) + logger.info("A2AClient initialized.") + + send_message_payload: dict[str, Any] = { + "message": { + "role": "user", + "parts": [{"kind": "text", "text": "how much is 10 USD in INR?"}], + "messageId": uuid4().hex, + }, + } + request = SendMessageRequest(id=str(uuid4()), params=MessageSendParams(**send_message_payload)) + + response = await client.send_message(request) + print(response.model_dump(mode="json", exclude_none=True)) + + +if __name__ == "__main__": + import asyncio + + asyncio.run(main()) diff --git a/a2a_server.py b/a2a_server.py new file mode 100644 index 00000000..4e55b3ae --- /dev/null +++ b/a2a_server.py @@ -0,0 +1,19 @@ +import logging +import sys + +from strands import Agent +from strands.multiagent.a2a import A2AAgent + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler(sys.stdout)], + force=True, +) + +# Log that we're starting +logging.info("Starting A2A server with root logger") + +strands_agent = Agent(model="us.anthropic.claude-3-haiku-20240307-v1:0", callback_handler=None) +strands_a2a_agent = A2AAgent(agent=strands_agent, name="Hello World Agent", description="Just a hello world agent") +strands_a2a_agent.serve() diff --git a/pyproject.toml b/pyproject.toml index 835def0f..35bd8f31 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,6 +86,7 @@ a2a = [ "httpx>=0.28.1", "fastapi>=0.115.12", "starlette>=0.46.2", + "protobuf==6.31.1", ] [tool.hatch.version] @@ -117,7 +118,7 @@ lint-fix = [ [tool.hatch.envs.hatch-test] features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel"] -extra-dependencies = [ +extra-dependencies = "moto>=5.1.0,<6.0.0", "pytest>=8.0.0,<9.0.0", "pytest-asyncio>=0.26.0,<0.27.0", diff --git a/src/strands/multiagent/__init__.py b/src/strands/multiagent/__init__.py new file mode 100644 index 00000000..1cef1425 --- /dev/null +++ b/src/strands/multiagent/__init__.py @@ -0,0 +1,13 @@ +"""Multiagent capabilities for Strands Agents. + +This module provides support for multiagent systems, including agent-to-agent (A2A) +communication protocols and coordination mechanisms. + +Submodules: + a2a: Implementation of the Agent-to-Agent (A2A) protocol, which enables + standardized communication between agents. +""" + +from . import a2a + +__all__ = ["a2a"] diff --git a/src/strands/multiagent/a2a/__init__.py b/src/strands/multiagent/a2a/__init__.py new file mode 100644 index 00000000..c5425618 --- /dev/null +++ b/src/strands/multiagent/a2a/__init__.py @@ -0,0 +1,14 @@ +"""Agent-to-Agent (A2A) communication protocol implementation for Strands Agents. + +This module provides classes and utilities for enabling Strands Agents to communicate +with other agents using the Agent-to-Agent (A2A) protocol. + +Docs: https://google-a2a.github.io/A2A/latest/ + +Classes: + A2AAgent: A wrapper that adapts a Strands Agent to be A2A-compatible. +""" + +from .agent import A2AAgent + +__all__ = ["A2AAgent"] diff --git a/src/strands/multiagent/a2a/agent.py b/src/strands/multiagent/a2a/agent.py new file mode 100644 index 00000000..56ba1016 --- /dev/null +++ b/src/strands/multiagent/a2a/agent.py @@ -0,0 +1,139 @@ +"""A2A-compatible wrapper for Strands Agent. + +This module provides the A2AAgent class, which adapts a Strands Agent to the A2A protocol, +allowing it to be used in A2A-compatible systems. +""" + +import logging +from typing import Any, Literal + +import uvicorn +from a2a.server.apps import A2AFastAPIApplication, A2AStarletteApplication +from a2a.server.request_handlers import DefaultRequestHandler +from a2a.server.tasks import InMemoryTaskStore +from a2a.types import AgentCapabilities, AgentCard, AgentSkill +from fastapi import FastAPI +from starlette.applications import Starlette + +from ...agent.agent import Agent as SAAgent +from .executor import StrandsA2AExecutor + +log = logging.getLogger(__name__) + + +class A2AAgent: + """A2A-compatible wrapper for Strands Agent.""" + + def __init__( + self, + agent: SAAgent, + *, + name: str, + description: str, + host: str = "localhost", + port: int = 9000, + version: str = "0.0.1", + ): + """Initialize an A2A-compatible agent from a Strands agent. + + Args: + agent: The Strands Agent to wrap with A2A compatibility. + name: The name of the agent, used in the AgentCard. + description: A description of the agent's capabilities, used in the AgentCard. + host: The hostname or IP address to bind the A2A server to. Defaults to "localhost". + port: The port to bind the A2A server to. Defaults to 9000. + version: The version of the agent. Defaults to "0.0.1". + """ + self.name = name + self.description = description + self.host = host + self.port = port + self.http_url = f"http://{self.host}:{self.port}/" + self.version = version + self.strands_agent = agent + self.capabilities = AgentCapabilities() + self.request_handler = DefaultRequestHandler( + agent_executor=StrandsA2AExecutor(self.strands_agent), + task_store=InMemoryTaskStore(), + ) + + @property + def public_agent_card(self) -> AgentCard: + """Get the public AgentCard for this agent. + + The AgentCard contains metadata about the agent, including its name, + description, URL, version, skills, and capabilities. This information + is used by other agents and systems to discover and interact with this agent. + + Returns: + AgentCard: The public agent card containing metadata about this agent. + """ + return AgentCard( + name=self.name, + description=self.description, + url=self.http_url, + version=self.version, + skills=self.agent_skills, + defaultInputModes=["text"], + defaultOutputModes=["text"], + capabilities=self.capabilities, + ) + + @property + def agent_skills(self) -> list[AgentSkill]: + """Get the list of skills this agent provides. + + Skills represent specific capabilities that the agent can perform. + Strands agent tools are adapted to A2A skills. + + Returns: + list[AgentSkill]: A list of skills this agent provides. + """ + return [] + + def to_starlette_app(self) -> Starlette: + """Create a Starlette application for serving this agent via HTTP. + + This method creates a Starlette application that can be used to serve + the agent via HTTP using the A2A protocol. + + Returns: + Starlette: A Starlette application configured to serve this agent. + """ + starlette_app = A2AStarletteApplication(agent_card=self.public_agent_card, http_handler=self.request_handler) + return starlette_app.build() + + def to_fastapi_app(self) -> FastAPI: + """Create a FastAPI application for serving this agent via HTTP. + + This method creates a FastAPI application that can be used to serve + the agent via HTTP using the A2A protocol. + + Returns: + FastAPI: A FastAPI application configured to serve this agent. + """ + fastapi_app = A2AFastAPIApplication(agent_card=self.public_agent_card, http_handler=self.request_handler) + return fastapi_app.build() + + def serve(self, app_type: Literal["fastapi", "starlette"] = "starlette", **kwargs: Any) -> None: + """Start the A2A server with the specified application type. + + This method starts an HTTP server that exposes the agent via the A2A protocol. + The server can be implemented using either FastAPI or Starlette, depending on + the specified app_type. + + Args: + app_type: The type of application to serve, either "fastapi" or "starlette". + Defaults to "starlette". + **kwargs: Additional keyword arguments to pass to uvicorn.run. + """ + try: + log.info("Starting Strands agent A2A server...") + if app_type == "fastapi": + uvicorn.run(self.to_fastapi_app(), host=self.host, port=self.port, **kwargs) + else: + uvicorn.run(self.to_starlette_app(), host=self.host, port=self.port, **kwargs) + except KeyboardInterrupt: + log.warning("Server shutdown requested (KeyboardInterrupt).") + finally: + log.info("Strands agent A2A server has shutdown.") diff --git a/src/strands/multiagent/a2a/executor.py b/src/strands/multiagent/a2a/executor.py new file mode 100644 index 00000000..b7a7af09 --- /dev/null +++ b/src/strands/multiagent/a2a/executor.py @@ -0,0 +1,67 @@ +"""Strands Agent executor for the A2A protocol. + +This module provides the StrandsA2AExecutor class, which adapts a Strands Agent +to be used as an executor in the A2A protocol. It handles the execution of agent +requests and the conversion of Strands Agent responses to A2A events. +""" + +import logging + +from a2a.server.agent_execution import AgentExecutor, RequestContext +from a2a.server.events import EventQueue +from a2a.types import UnsupportedOperationError +from a2a.utils import new_agent_text_message +from a2a.utils.errors import ServerError + +from ...agent.agent import Agent as SAAgent +from ...agent.agent_result import AgentResult as SAAgentResult + +log = logging.getLogger(__name__) + + +class StrandsA2AExecutor(AgentExecutor): + """Executor that adapts a Strands Agent to the A2A protocol.""" + + def __init__(self, agent: SAAgent): + """Initialize a StrandsA2AExecutor. + + Args: + agent: The Strands Agent to adapt to the A2A protocol. + """ + self.agent = agent + + async def execute( + self, + context: RequestContext, + event_queue: EventQueue, + ) -> None: + """Execute a request using the Strands Agent and send the response as A2A events. + + This method executes the user's input using the Strands Agent and converts + the agent's response to A2A events, which are then sent to the event queue. + + Args: + context: The A2A request context, containing the user's input and other metadata. + event_queue: The A2A event queue, used to send response events. + """ + result: SAAgentResult = self.agent(context.get_user_input()) + if result.message and "content" in result.message: + for content_block in result.message["content"]: + if "text" in content_block: + await event_queue.enqueue_event(new_agent_text_message(content_block["text"])) + + async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None: + """Cancel an ongoing execution. + + This method is called when a request is cancelled. Currently, cancellation + is not supported, so this method raises an UnsupportedOperationError. + + Args: + context: The A2A request context. + event_queue: The A2A event queue. + + Raises: + ServerError: Always raised with an UnsupportedOperationError, as cancellation + is not currently supported. + """ + raise ServerError(error=UnsupportedOperationError()) diff --git a/tests/multiagent/__init__.py b/tests/multiagent/__init__.py new file mode 100644 index 00000000..b43bae53 --- /dev/null +++ b/tests/multiagent/__init__.py @@ -0,0 +1 @@ +"""Tests for the multiagent module.""" diff --git a/tests/multiagent/a2a/__init__.py b/tests/multiagent/a2a/__init__.py new file mode 100644 index 00000000..ea8e5990 --- /dev/null +++ b/tests/multiagent/a2a/__init__.py @@ -0,0 +1 @@ +"""Tests for the A2A implementation.""" diff --git a/tests/multiagent/a2a/conftest.py b/tests/multiagent/a2a/conftest.py new file mode 100644 index 00000000..0500db93 --- /dev/null +++ b/tests/multiagent/a2a/conftest.py @@ -0,0 +1,12 @@ +"""Pytest configuration for A2A tests.""" + +from unittest.mock import patch + +import pytest + + +@pytest.fixture(autouse=True) +def mock_uvicorn(): + """Mock uvicorn.run to prevent actual server startup during tests.""" + with patch("uvicorn.run") as mock: + yield mock diff --git a/tests/multiagent/a2a/test_agent.py b/tests/multiagent/a2a/test_agent.py new file mode 100644 index 00000000..4580bc1b --- /dev/null +++ b/tests/multiagent/a2a/test_agent.py @@ -0,0 +1,70 @@ +"""Tests for the A2AAgent class.""" + +import pytest +from a2a.types import AgentCard +from fastapi import FastAPI +from starlette.applications import Starlette + +from strands import Agent +from strands.multiagent.a2a import A2AAgent + + +@pytest.fixture +def strands_agent(): + """Create a Strands agent for testing.""" + return Agent() + + +@pytest.fixture +def a2a_agent(strands_agent): + """Create an A2A agent for testing.""" + return A2AAgent( + agent=strands_agent, + name="Test Agent", + description="A test agent", + host="localhost", + port=9000, + ) + + +def test_a2a_agent_initialization(a2a_agent, strands_agent): + """Test that the A2AAgent initializes correctly.""" + assert a2a_agent.name == "Test Agent" + assert a2a_agent.description == "A test agent" + assert a2a_agent.host == "localhost" + assert a2a_agent.port == 9000 + assert a2a_agent.http_url == "http://localhost:9000/" + assert a2a_agent.version == "0.0.1" + assert a2a_agent.strands_agent == strands_agent + + +def test_public_agent_card(a2a_agent): + """Test that the public agent card is created correctly.""" + card = a2a_agent.public_agent_card + assert isinstance(card, AgentCard) + assert card.name == "Test Agent" + assert card.description == "A test agent" + assert card.url == "http://localhost:9000/" + assert card.version == "0.0.1" + assert card.defaultInputModes == ["text"] + assert card.defaultOutputModes == ["text"] + assert len(card.skills) == 0 # No skills defined yet + + +def test_agent_skills(a2a_agent): + """Test that agent skills are returned correctly.""" + skills = a2a_agent.agent_skills + assert isinstance(skills, list) + assert len(skills) == 0 # No skills defined yet + + +def test_to_starlette_app(a2a_agent): + """Test that a Starlette app is created correctly.""" + app = a2a_agent.to_starlette_app() + assert isinstance(app, Starlette) + + +def test_to_fastapi_app(a2a_agent): + """Test that a FastAPI app is created correctly.""" + app = a2a_agent.to_fastapi_app() + assert isinstance(app, FastAPI) diff --git a/tests/multiagent/a2a/test_executor.py b/tests/multiagent/a2a/test_executor.py new file mode 100644 index 00000000..7e7c8ba5 --- /dev/null +++ b/tests/multiagent/a2a/test_executor.py @@ -0,0 +1,99 @@ +"""Tests for the StrandsA2AExecutor class.""" + +import pytest +from a2a.types import UnsupportedOperationError +from a2a.utils.errors import ServerError + +from strands.agent.agent_result import AgentResult +from strands.multiagent.a2a.executor import StrandsA2AExecutor +from strands.telemetry.metrics import EventLoopMetrics + + +class MockAgent: + """Mock Strands Agent for testing.""" + + def __init__(self, response_text="Test response"): + """Initialize the mock agent with a predefined response.""" + self.response_text = response_text + self.called_with = None + + def __call__(self, input_text): + """Mock the agent call method.""" + self.called_with = input_text + return AgentResult( + stop_reason="end_turn", + message={"content": [{"text": self.response_text}]}, + metrics=EventLoopMetrics(), + state={}, + ) + + +class MockEventQueue: + """Mock EventQueue for testing.""" + + def __init__(self): + """Initialize the mock event queue.""" + self.events = [] + + async def enqueue_event(self, event): + """Mock the enqueue_event method.""" + self.events.append(event) + return None + + +class MockRequestContext: + """Mock RequestContext for testing.""" + + def __init__(self, user_input="Test input"): + """Initialize the mock request context.""" + self.user_input = user_input + + def get_user_input(self): + """Mock the get_user_input method.""" + return self.user_input + + +@pytest.fixture +def mock_agent(): + """Create a mock Strands agent for testing.""" + return MockAgent() + + +@pytest.fixture +def executor(mock_agent): + """Create a StrandsA2AExecutor for testing.""" + return StrandsA2AExecutor(mock_agent) + + +@pytest.fixture +def event_queue(): + """Create a mock event queue for testing.""" + return MockEventQueue() + + +@pytest.fixture +def request_context(): + """Create a mock request context for testing.""" + return MockRequestContext() + + +@pytest.mark.asyncio +async def test_execute(executor, event_queue, request_context): + """Test that the execute method works correctly.""" + await executor.execute(request_context, event_queue) + + # Check that the agent was called with the correct input + assert executor.agent.called_with == "Test input" + + # Check that an event was enqueued (we can't check the content directly) + assert len(event_queue.events) == 1 + + +@pytest.mark.asyncio +async def test_cancel(executor, event_queue, request_context): + """Test that the cancel method raises the expected error.""" + with pytest.raises(ServerError) as excinfo: + await executor.cancel(request_context, event_queue) + + # Check that the error contains an UnsupportedOperationError + assert isinstance(excinfo.value.error, UnsupportedOperationError)