diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index cbba0fecf..68628029a 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -1,13 +1,14 @@ -"""Directed Acyclic Graph (DAG) Multi-Agent Pattern Implementation. +"""Directed Graph Multi-Agent Pattern Implementation. -This module provides a deterministic DAG-based agent orchestration system where +This module provides a deterministic graph-based agent orchestration system where agents or MultiAgentBase instances (like Swarm or Graph) are nodes in a graph, executed according to edge dependencies, with output from one node passed as input to connected nodes. Key Features: - Agents and MultiAgentBase instances (Swarm, Graph, etc.) as graph nodes -- Deterministic execution order based on DAG structure +- Deterministic execution order based on graph structure +- Cycles are permitted only if at least one edge is conditional. - Output propagation along edges - Topological sort for execution ordering - Clear dependency management @@ -253,18 +254,21 @@ def has_cycle_from(node_id: str) -> bool: colors[node_id] = GRAY # Check all outgoing edges for cycles for edge in self.edges: - if edge.from_node.node_id == node_id and has_cycle_from(edge.to_node.node_id): + if not edge.condition and edge.from_node.node_id == node_id and has_cycle_from(edge.to_node.node_id): return True colors[node_id] = BLACK return False # Check for cycles from each unvisited node if any(colors[node_id] == WHITE and has_cycle_from(node_id) for node_id in self.nodes): - raise ValueError("Graph contains cycles - must be a directed acyclic graph") + raise ValueError( + "Graph contains unconditional cycles — it must either be a Directed Acyclic Graph (DAG) " + "or contain at least one conditional edge within each cycle." + ) class Graph(MultiAgentBase): - """Directed Acyclic Graph multi-agent orchestration.""" + """Directed graph multi-agent orchestration.""" def __init__(self, nodes: dict[str, GraphNode], edges: set[GraphEdge], entry_points: set[GraphNode]) -> None: """Initialize Graph.""" @@ -332,45 +336,35 @@ def _validate_graph(self, nodes: dict[str, GraphNode]) -> None: async def _execute_graph(self) -> None: """Unified execution flow with conditional routing.""" - ready_nodes = list(self.entry_points) + ready_nodes = set(self.entry_points) while ready_nodes: - current_batch = ready_nodes.copy() - ready_nodes.clear() - # Execute current batch of ready nodes concurrently - tasks = [ - asyncio.create_task(self._execute_node(node)) - for node in current_batch - if node not in self.state.completed_nodes - ] + tasks = [asyncio.create_task(self._execute_node(node)) for node in ready_nodes] for task in tasks: await task # Find newly ready nodes after batch execution - ready_nodes.extend(self._find_newly_ready_nodes()) + ready_nodes = self._find_newly_ready_nodes(ready_nodes) - def _find_newly_ready_nodes(self) -> list["GraphNode"]: + def _find_newly_ready_nodes(self, executed_nodes: set["GraphNode"]) -> set["GraphNode"]: """Find nodes that became ready after the last execution.""" newly_ready = [] for _node_id, node in self.nodes.items(): if ( - node not in self.state.completed_nodes + node.dependencies & executed_nodes and node not in self.state.failed_nodes and self._is_node_ready_with_conditions(node) ): newly_ready.append(node) - return newly_ready + return set(newly_ready) def _is_node_ready_with_conditions(self, node: GraphNode) -> bool: """Check if a node is ready considering conditional edges.""" # Get incoming edges to this node incoming_edges = [edge for edge in self.edges if edge.to_node == node] - if not incoming_edges: - return node in self.entry_points - # Check if at least one incoming edge condition is satisfied for edge in incoming_edges: if edge.from_node in self.state.completed_nodes: diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index cb74f515c..2cb58ec14 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -1,3 +1,4 @@ +import re from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest @@ -378,9 +379,26 @@ def test_graph_builder_validation(): builder.add_edge("c", "a") # Creates cycle builder.set_entry_point("a") - with pytest.raises(ValueError, match="Graph contains cycles"): + with pytest.raises( + ValueError, + match=re.escape( + "Graph contains unconditional cycles — it must either be a Directed Acyclic Graph (DAG) " + "or contain at least one conditional edge within each cycle." + ), + ): builder.build() + # Test cycle detection with back edge condition + builder = GraphBuilder() + builder.add_node(agent1, "a") + builder.add_node(agent2, "b") + builder.add_node(create_mock_agent("agent3"), "c") + builder.add_edge("a", "b") + builder.add_edge("b", "c") + builder.add_edge("c", "a", condition=lambda _: True) # Creates cycle + builder.set_entry_point("a") + builder.build() + # Test auto-detection of entry points builder = GraphBuilder() builder.add_node(agent1, "entry")