From 97a2f31c475347c8014a32fa93057057762e0cc2 Mon Sep 17 00:00:00 2001 From: ashish-dahal Date: Sun, 6 Apr 2025 04:57:23 +0300 Subject: [PATCH 1/9] Enhance agent visualization to prevent infinite recursion for agents with recursive handoffs --- src/agents/extensions/visualization.py | 60 +++++++++--- tests/test_visualization.py | 130 +++++++++++-------------- uv.lock | 2 +- 3 files changed, 104 insertions(+), 88 deletions(-) diff --git a/src/agents/extensions/visualization.py b/src/agents/extensions/visualization.py index 5fb35062..826380fb 100644 --- a/src/agents/extensions/visualization.py +++ b/src/agents/extensions/visualization.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Set import graphviz # type: ignore @@ -31,16 +31,23 @@ def get_main_graph(agent: Agent) -> str: return "".join(parts) -def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str: +def get_all_nodes( + agent: Agent, parent: Optional[Agent] = None, visited: Optional[Set[int]] = None +) -> str: """ Recursively generates the nodes for the given agent and its handoffs in DOT format. Args: agent (Agent): The agent for which the nodes are to be generated. + parent (Agent, optional): The parent agent. Defaults to None. + visited (Set[int], optional): Set of already visited agent IDs to prevent infinite recursion. Returns: str: The DOT format string representing the nodes. """ + if visited is None: + visited = set() + parts = [] # Start and end the graph @@ -63,53 +70,80 @@ def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str: f"fillcolor=lightgreen, width=0.5, height=0.3];" ) + # Add current agent's ID to visited set + visited.add(id(agent)) + for handoff in agent.handoffs: if isinstance(handoff, Handoff): parts.append( f'"{handoff.agent_name}" [label="{handoff.agent_name}", ' - f"shape=box, style=filled, style=rounded, " + f"shape=box, style=filled, " f"fillcolor=lightyellow, width=1.5, height=0.8];" ) if isinstance(handoff, Agent): parts.append( f'"{handoff.name}" [label="{handoff.name}", ' - f"shape=box, style=filled, style=rounded, " + f"shape=box, style=filled, " f"fillcolor=lightyellow, width=1.5, height=0.8];" ) - parts.append(get_all_nodes(handoff)) + # Only recursively add nodes if we haven't visited this agent before + if id(handoff) not in visited: + parts.append(get_all_nodes(handoff, agent, visited)) return "".join(parts) -def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str: +def get_all_edges( + agent: Agent, parent: Optional[Agent] = None, visited: Optional[Set[int]] = None +) -> str: """ Recursively generates the edges for the given agent and its handoffs in DOT format. Args: agent (Agent): The agent for which the edges are to be generated. parent (Agent, optional): The parent agent. Defaults to None. + visited (Set[int], optional): Set of already visited agent IDs to prevent infinite recursion. Returns: str: The DOT format string representing the edges. """ + if visited is None: + visited = set() + parts = [] if not parent: parts.append(f'"__start__" -> "{agent.name}";') for tool in agent.tools: - parts.append(f""" + parts.append( + f""" "{agent.name}" -> "{tool.name}" [style=dotted, penwidth=1.5]; - "{tool.name}" -> "{agent.name}" [style=dotted, penwidth=1.5];""") + "{tool.name}" -> "{agent.name}" [style=dotted, penwidth=1.5];""" + ) + + # Add current agent's ID to visited set + visited.add(id(agent)) for handoff in agent.handoffs: if isinstance(handoff, Handoff): - parts.append(f""" - "{agent.name}" -> "{handoff.agent_name}";""") + parts.append( + f""" + "{agent.name}" -> "{handoff.agent_name}";""" + ) if isinstance(handoff, Agent): - parts.append(f""" - "{agent.name}" -> "{handoff.name}";""") - parts.append(get_all_edges(handoff, agent)) + # Check for loops + if id(handoff) in visited: + parts.append( + f""" + "{agent.name}" -> "{handoff.name}";""" + ) + else: + parts.append( + f""" + "{agent.name}" -> "{handoff.name}";""" + ) + parts.append(get_all_edges(handoff, agent, visited)) if not agent.handoffs and not isinstance(agent, Tool): # type: ignore parts.append(f'"{agent.name}" -> "__end__";') diff --git a/tests/test_visualization.py b/tests/test_visualization.py index 6aa86774..fca24c84 100644 --- a/tests/test_visualization.py +++ b/tests/test_visualization.py @@ -12,6 +12,14 @@ ) from agents.handoffs import Handoff +# Common test graph elements +START_NODE = '"__start__" [label="__start__", shape=ellipse, style=filled, fillcolor=lightblue, width=0.5, height=0.3];' +END_NODE = '"__end__" [label="__end__", shape=ellipse, style=filled, fillcolor=lightblue, width=0.5, height=0.3];' +AGENT_NODE = '"Agent1" [label="Agent1", shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8];' +TOOL1_NODE = '"Tool1" [label="Tool1", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' +TOOL2_NODE = '"Tool2" [label="Tool2", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' +HANDOFF_NODE = '"Handoff1" [label="Handoff1", shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8];' + @pytest.fixture def mock_agent(): @@ -31,71 +39,46 @@ def mock_agent(): return agent +@pytest.fixture +def mock_recursive_agents(): + agent1 = Mock(spec=Agent) + agent1.name = "Agent1" + agent1.tools = [] + agent2 = Mock(spec=Agent) + agent2.name = "Agent2" + agent2.tools = [] + agent1.handoffs = [agent2] + agent2.handoffs = [agent1] + return agent1 + + def test_get_main_graph(mock_agent): result = get_main_graph(mock_agent) - print(result) assert "digraph G" in result assert "graph [splines=true];" in result assert 'node [fontname="Arial"];' in result assert "edge [penwidth=1.5];" in result - assert ( - '"__start__" [label="__start__", shape=ellipse, style=filled, ' - "fillcolor=lightblue, width=0.5, height=0.3];" in result - ) - assert ( - '"__end__" [label="__end__", shape=ellipse, style=filled, ' - "fillcolor=lightblue, width=0.5, height=0.3];" in result - ) - assert ( - '"Agent1" [label="Agent1", shape=box, style=filled, ' - "fillcolor=lightyellow, width=1.5, height=0.8];" in result - ) - assert ( - '"Tool1" [label="Tool1", shape=ellipse, style=filled, ' - "fillcolor=lightgreen, width=0.5, height=0.3];" in result - ) - assert ( - '"Tool2" [label="Tool2", shape=ellipse, style=filled, ' - "fillcolor=lightgreen, width=0.5, height=0.3];" in result - ) - assert ( - '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, ' - "fillcolor=lightyellow, width=1.5, height=0.8];" in result - ) + assert START_NODE in result + assert END_NODE in result + assert AGENT_NODE in result + assert TOOL1_NODE in result + assert TOOL2_NODE in result + assert HANDOFF_NODE in result def test_get_all_nodes(mock_agent): result = get_all_nodes(mock_agent) - assert ( - '"__start__" [label="__start__", shape=ellipse, style=filled, ' - "fillcolor=lightblue, width=0.5, height=0.3];" in result - ) - assert ( - '"__end__" [label="__end__", shape=ellipse, style=filled, ' - "fillcolor=lightblue, width=0.5, height=0.3];" in result - ) - assert ( - '"Agent1" [label="Agent1", shape=box, style=filled, ' - "fillcolor=lightyellow, width=1.5, height=0.8];" in result - ) - assert ( - '"Tool1" [label="Tool1", shape=ellipse, style=filled, ' - "fillcolor=lightgreen, width=0.5, height=0.3];" in result - ) - assert ( - '"Tool2" [label="Tool2", shape=ellipse, style=filled, ' - "fillcolor=lightgreen, width=0.5, height=0.3];" in result - ) - assert ( - '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, ' - "fillcolor=lightyellow, width=1.5, height=0.8];" in result - ) + assert START_NODE in result + assert END_NODE in result + assert AGENT_NODE in result + assert TOOL1_NODE in result + assert TOOL2_NODE in result + assert HANDOFF_NODE in result def test_get_all_edges(mock_agent): result = get_all_edges(mock_agent) assert '"__start__" -> "Agent1";' in result - assert '"Agent1" -> "__end__";' assert '"Agent1" -> "Tool1" [style=dotted, penwidth=1.5];' in result assert '"Tool1" -> "Agent1" [style=dotted, penwidth=1.5];' in result assert '"Agent1" -> "Tool2" [style=dotted, penwidth=1.5];' in result @@ -106,31 +89,30 @@ def test_get_all_edges(mock_agent): def test_draw_graph(mock_agent): graph = draw_graph(mock_agent) assert isinstance(graph, graphviz.Source) - assert "digraph G" in graph.source - assert "graph [splines=true];" in graph.source - assert 'node [fontname="Arial"];' in graph.source - assert "edge [penwidth=1.5];" in graph.source - assert ( - '"__start__" [label="__start__", shape=ellipse, style=filled, ' - "fillcolor=lightblue, width=0.5, height=0.3];" in graph.source - ) - assert ( - '"__end__" [label="__end__", shape=ellipse, style=filled, ' - "fillcolor=lightblue, width=0.5, height=0.3];" in graph.source - ) - assert ( - '"Agent1" [label="Agent1", shape=box, style=filled, ' - "fillcolor=lightyellow, width=1.5, height=0.8];" in graph.source - ) - assert ( - '"Tool1" [label="Tool1", shape=ellipse, style=filled, ' - "fillcolor=lightgreen, width=0.5, height=0.3];" in graph.source - ) + source = graph.source + assert "digraph G" in source + assert "graph [splines=true];" in source + assert 'node [fontname="Arial"];' in source + assert "edge [penwidth=1.5];" in source + assert START_NODE in source + assert END_NODE in source + assert AGENT_NODE in source + assert TOOL1_NODE in source + assert TOOL2_NODE in source + assert HANDOFF_NODE in source + + +def test_recursive_handoff_loop(mock_recursive_agents): + agent1 = mock_recursive_agents + dot = get_main_graph(agent1) + assert ( - '"Tool2" [label="Tool2", shape=ellipse, style=filled, ' - "fillcolor=lightgreen, width=0.5, height=0.3];" in graph.source + '"Agent1" [label="Agent1", shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8];' + in dot ) assert ( - '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, ' - "fillcolor=lightyellow, width=1.5, height=0.8];" in graph.source + '"Agent2" [label="Agent2", shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8];' + in dot ) + assert '"Agent1" -> "Agent2";' in dot + assert '"Agent2" -> "Agent1";' in dot diff --git a/uv.lock b/uv.lock index e443c009..a6018eeb 100644 --- a/uv.lock +++ b/uv.lock @@ -1087,7 +1087,7 @@ wheels = [ [[package]] name = "openai-agents" -version = "0.0.7" +version = "0.0.8" source = { editable = "." } dependencies = [ { name = "griffe" }, From 8c99e321d9f7dc1f3487a6b533f6906a771d15ba Mon Sep 17 00:00:00 2001 From: ashish-dahal Date: Sun, 6 Apr 2025 16:37:07 +0300 Subject: [PATCH 2/9] Build graph representation and add graph views for mermaid --- src/agents/extensions/visualization.py | 516 ++++++++++++++++++------- 1 file changed, 377 insertions(+), 139 deletions(-) diff --git a/src/agents/extensions/visualization.py b/src/agents/extensions/visualization.py index 826380fb..472e4a9b 100644 --- a/src/agents/extensions/visualization.py +++ b/src/agents/extensions/visualization.py @@ -1,171 +1,409 @@ -from typing import Optional, Set - -import graphviz # type: ignore +from typing import Optional, Set, Dict, List, TypeVar, Generic +from dataclasses import dataclass +from enum import Enum +import warnings +import abc +import graphviz +import base64 +import requests from agents import Agent from agents.handoffs import Handoff from agents.tool import Tool -def get_main_graph(agent: Agent) -> str: - """ - Generates the main graph structure in DOT format for the given agent. +class NodeType(Enum): + START = "start" + END = "end" + AGENT = "agent" + TOOL = "tool" + HANDOFF = "handoff" - Args: - agent (Agent): The agent for which the graph is to be generated. - Returns: - str: The DOT format string representing the graph. - """ - parts = [ +class EdgeType(Enum): + HANDOFF = "handoff" + TOOL = "tool" + + +@dataclass(frozen=True) +class Node: + id: str + label: str + type: NodeType + + +@dataclass(frozen=True) +class Edge: + source: str + target: str + type: EdgeType + + +class Graph: + def __init__(self): + self.nodes: Dict[str, Node] = {} + self.edges: List[Edge] = [] + + def add_node(self, node: Node) -> None: + self.nodes[node.id] = node + + def add_edge(self, edge: Edge) -> None: + """Add an edge to the graph. + + Args: + edge (Edge): The edge to add. + + Raises: + ValueError: If the source or target node does not exist in the graph. """ + if edge.source not in self.nodes: + raise ValueError(f"Source node '{edge.source}' does not exist in the graph") + if edge.target not in self.nodes: + raise ValueError(f"Target node '{edge.target}' does not exist in the graph") + self.edges.append(edge) + + def has_node(self, node_id: str) -> bool: + """Check if a node exists in the graph. + + Args: + node_id (str): The ID of the node to check. + + Returns: + bool: True if the node exists, False otherwise. + """ + return node_id in self.nodes + + def get_node(self, node_id: str) -> Optional[Node]: + """Get a node from the graph. + + Args: + node_id (str): The ID of the node to get. + + Returns: + Optional[Node]: The node if it exists, None otherwise. + """ + return self.nodes.get(node_id) + + +class GraphBuilder: + def __init__(self): + self._visited: Set[int] = set() + + def build_from_agent(self, agent: Agent) -> Graph: + """Build a graph from an agent. + + Args: + agent (Agent): The agent to build the graph from. + + Returns: + Graph: The built graph. + """ + self._visited.clear() + graph = Graph() + + # Add start and end nodes + graph.add_node(Node("__start__", "__start__", NodeType.START)) + graph.add_node(Node("__end__", "__end__", NodeType.END)) + + self._add_agent_nodes_and_edges(agent, None, graph) + return graph + + def _add_agent_nodes_and_edges( + self, + agent: Agent, + parent: Optional[Agent], + graph: Graph, + ) -> None: + # Add agent node + graph.add_node(Node(agent.name, agent.name, NodeType.AGENT)) + + # Connect start node if root agent + if not parent: + graph.add_edge(Edge("__start__", agent.name, EdgeType.HANDOFF)) + + # Add tool nodes and edges + for tool in agent.tools: + graph.add_node(Node(tool.name, tool.name, NodeType.TOOL)) + graph.add_edge(Edge(agent.name, tool.name, EdgeType.TOOL)) + graph.add_edge(Edge(tool.name, agent.name, EdgeType.TOOL)) + + # Add current agent's ID to visited set + self._visited.add(id(agent)) + + # Process handoffs + has_handoffs = False + for handoff in agent.handoffs: + has_handoffs = True + if isinstance(handoff, Handoff): + graph.add_node(Node(handoff.agent_name, handoff.agent_name, NodeType.HANDOFF)) + graph.add_edge(Edge(agent.name, handoff.agent_name, EdgeType.HANDOFF)) + elif isinstance(handoff, Agent): + graph.add_node(Node(handoff.name, handoff.name, NodeType.AGENT)) + graph.add_edge(Edge(agent.name, handoff.name, EdgeType.HANDOFF)) + if id(handoff) not in self._visited: + self._add_agent_nodes_and_edges(handoff, agent, graph) + + # Connect to end node if no handoffs + if not has_handoffs and not isinstance(agent, Tool): + graph.add_edge(Edge(agent.name, "__end__", EdgeType.HANDOFF)) + + +T = TypeVar('T') + +class GraphRenderer(Generic[T], abc.ABC): + """Abstract base class for graph renderers.""" + + @abc.abstractmethod + def render(self, graph: Graph) -> T: + """Render the graph in the specific format. + + Args: + graph (Graph): The graph to render. + + Returns: + T: The rendered graph in the format specific to the renderer. + """ + pass + + @abc.abstractmethod + def save(self, rendered: T, filename: str) -> None: + """Save the rendered graph to a file. + + Args: + rendered (T): The rendered graph returned by render(). + filename (str): The name of the file to save the graph as. + """ + pass + + +class GraphvizRenderer(GraphRenderer[str]): + """Renderer that outputs graphs in Graphviz DOT format.""" + + def render(self, graph: Graph) -> str: + parts = [ + """ digraph G { graph [splines=true]; node [fontname="Arial"]; edge [penwidth=1.5]; """ - ] - parts.append(get_all_nodes(agent)) - parts.append(get_all_edges(agent)) - parts.append("}") - return "".join(parts) + ] + # Add nodes + for node in graph.nodes.values(): + parts.append(self._render_node(node)) -def get_all_nodes( - agent: Agent, parent: Optional[Agent] = None, visited: Optional[Set[int]] = None -) -> str: - """ - Recursively generates the nodes for the given agent and its handoffs in DOT format. + # Add edges + for edge in graph.edges: + parts.append(self._render_edge(edge)) - Args: - agent (Agent): The agent for which the nodes are to be generated. - parent (Agent, optional): The parent agent. Defaults to None. - visited (Set[int], optional): Set of already visited agent IDs to prevent infinite recursion. + parts.append("}") + return "".join(parts) + def save(self, rendered: str, filename: str) -> None: + """Save the rendered graph as a PNG file using graphviz. + + Args: + rendered (str): The DOT format string. + filename (str): The name of the file to save the graph as. + """ + graphviz.Source(rendered).render(filename, format="png") + + def _render_node(self, node: Node) -> str: + style_map = { + NodeType.START: 'shape=ellipse, style=filled, fillcolor=lightblue, width=0.5, height=0.3', + NodeType.END: 'shape=ellipse, style=filled, fillcolor=lightblue, width=0.5, height=0.3', + NodeType.AGENT: 'shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8', + NodeType.TOOL: 'shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3', + NodeType.HANDOFF: 'shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8', + } + return f'"{node.id}" [label="{node.label}", {style_map[node.type]}];' + + def _render_edge(self, edge: Edge) -> str: + if edge.type == EdgeType.TOOL: + return f'"{edge.source}" -> "{edge.target}" [style=dotted, penwidth=1.5];' + return f'"{edge.source}" -> "{edge.target}";' + + +class MermaidRenderer(GraphRenderer[str]): + """Renderer that outputs graphs in Mermaid flowchart syntax.""" + + def render(self, graph: Graph) -> str: + parts = ["graph TD\n"] + + # Add nodes with styles + for node in graph.nodes.values(): + parts.append(self._render_node(node)) + + # Add edges + for edge in graph.edges: + parts.append(self._render_edge(edge)) + + return "".join(parts) + + def save(self, rendered: str, filename: str) -> None: + """Save the rendered graph as a PNG file using mermaid.ink API. + + Args: + rendered (str): The Mermaid syntax string. + filename (str): The name of the file to save the graph as. + """ + # Encode the graph to base64 + graphbytes = rendered.encode("utf8") + base64_bytes = base64.urlsafe_b64encode(graphbytes) + base64_string = base64_bytes.decode("ascii") + + # Get the image from mermaid.ink + response = requests.get(f'https://mermaid.ink/img/{base64_string}') + response.raise_for_status() + + # Save the image directly from response content + with open(f"{filename}.png", "wb") as f: + f.write(response.content) + + def _render_node(self, node: Node) -> str: + # Map node types to Mermaid shapes + style_map = { + NodeType.START: ["(", ")", "lightblue"], + NodeType.END: ["(", ")", "lightblue"], + NodeType.AGENT: ["[", "]", "lightyellow"], + NodeType.TOOL: ["((", "))", "lightgreen"], + NodeType.HANDOFF: ["[", "]", "lightyellow"], + } + + start, end, color = style_map[node.type] + node_id = self._sanitize_id(node.id) + # Use sanitized ID and original label + return f"{node_id}{start}{node.label}{end}\nstyle {node_id} fill:{color}\n" + + def _render_edge(self, edge: Edge) -> str: + source = self._sanitize_id(edge.source) + target = self._sanitize_id(edge.target) + if edge.type == EdgeType.TOOL: + return f"{source} -.-> {target}\n" + return f"{source} --> {target}\n" + + def _sanitize_id(self, id: str) -> str: + """Sanitize node IDs to work with Mermaid's stricter ID requirements.""" + return id.replace(" ", "_").replace("-", "_") + + +class GraphView: + def __init__(self, rendered_graph: str, renderer: GraphRenderer, filename: Optional[str] = None): + self.rendered_graph = rendered_graph + self.renderer = renderer + self.filename = filename + + def view(self) -> None: + """Opens the rendered graph in a separate window.""" + import tempfile + import os + import webbrowser + + if self.filename: + webbrowser.open(f"file://{os.path.abspath(self.filename)}.png") + else: + temp_dir = tempfile.gettempdir() + temp_path = os.path.join(temp_dir, next(tempfile._get_candidate_names())) + self.renderer.save(self.rendered_graph, temp_path) + webbrowser.open(f"file://{os.path.abspath(temp_path)}.png") + + +def draw_graph(agent: Agent, filename: Optional[str] = None, renderer: str = "graphviz") -> GraphView: + """ + Draws the graph for the given agent using the specified renderer. + + Args: + agent (Agent): The agent for which the graph is to be drawn. + filename (str, optional): The name of the file to save the graph as PNG. Defaults to None. + renderer (str, optional): The renderer to use. Must be one of: "graphviz" (offline), + "mermaid" (requires internet). Defaults to "graphviz". + Returns: - str: The DOT format string representing the nodes. + GraphView: A view object that can be used to display the graph. + + Raises: + ValueError: If the specified renderer is not supported. + requests.RequestException: If using mermaid renderer and unable to connect to mermaid.ink API. """ - if visited is None: - visited = set() + builder = GraphBuilder() + graph = builder.build_from_agent(agent) + + if renderer == "graphviz": + renderer_instance = GraphvizRenderer() + elif renderer == "mermaid": + renderer_instance = MermaidRenderer() + else: + raise ValueError(f"Unsupported renderer: {renderer}") + + rendered = renderer_instance.render(graph) + + if filename: + filename = filename.rsplit('.', 1)[0] + renderer_instance.save(rendered, filename) + + return GraphView(rendered, renderer_instance, filename) - parts = [] - # Start and end the graph - parts.append( - '"__start__" [label="__start__", shape=ellipse, style=filled, ' - "fillcolor=lightblue, width=0.5, height=0.3];" - '"__end__" [label="__end__", shape=ellipse, style=filled, ' - "fillcolor=lightblue, width=0.5, height=0.3];" +def get_main_graph(agent: Agent) -> str: + """ + Generates the main graph structure in DOT format for the given agent. + + Args: + agent (Agent): The agent for which the graph is to be generated. + + Returns: + str: The DOT format string representing the graph. + + Deprecated: + This function is deprecated. Use GraphBuilder and GraphvizRenderer instead. + """ + warnings.warn( + "get_main_graph is deprecated. Use GraphBuilder and GraphvizRenderer instead.", + DeprecationWarning, + stacklevel=2, ) - # Ensure parent agent node is colored - if not parent: - parts.append( - f'"{agent.name}" [label="{agent.name}", shape=box, style=filled, ' - "fillcolor=lightyellow, width=1.5, height=0.8];" - ) - - for tool in agent.tools: - parts.append( - f'"{tool.name}" [label="{tool.name}", shape=ellipse, style=filled, ' - f"fillcolor=lightgreen, width=0.5, height=0.3];" - ) - - # Add current agent's ID to visited set - visited.add(id(agent)) - - for handoff in agent.handoffs: - if isinstance(handoff, Handoff): - parts.append( - f'"{handoff.agent_name}" [label="{handoff.agent_name}", ' - f"shape=box, style=filled, " - f"fillcolor=lightyellow, width=1.5, height=0.8];" - ) - if isinstance(handoff, Agent): - parts.append( - f'"{handoff.name}" [label="{handoff.name}", ' - f"shape=box, style=filled, " - f"fillcolor=lightyellow, width=1.5, height=0.8];" - ) - # Only recursively add nodes if we haven't visited this agent before - if id(handoff) not in visited: - parts.append(get_all_nodes(handoff, agent, visited)) - - return "".join(parts) + builder = GraphBuilder() + renderer = GraphvizRenderer() + graph = builder.build_from_agent(agent) + return renderer.render(graph) -def get_all_edges( +def get_all_nodes( agent: Agent, parent: Optional[Agent] = None, visited: Optional[Set[int]] = None ) -> str: """ - Recursively generates the edges for the given agent and its handoffs in DOT format. - - Args: - agent (Agent): The agent for which the edges are to be generated. - parent (Agent, optional): The parent agent. Defaults to None. - visited (Set[int], optional): Set of already visited agent IDs to prevent infinite recursion. - - Returns: - str: The DOT format string representing the edges. - """ - if visited is None: - visited = set() - - parts = [] - - if not parent: - parts.append(f'"__start__" -> "{agent.name}";') - - for tool in agent.tools: - parts.append( - f""" - "{agent.name}" -> "{tool.name}" [style=dotted, penwidth=1.5]; - "{tool.name}" -> "{agent.name}" [style=dotted, penwidth=1.5];""" - ) - - # Add current agent's ID to visited set - visited.add(id(agent)) - - for handoff in agent.handoffs: - if isinstance(handoff, Handoff): - parts.append( - f""" - "{agent.name}" -> "{handoff.agent_name}";""" - ) - if isinstance(handoff, Agent): - # Check for loops - if id(handoff) in visited: - parts.append( - f""" - "{agent.name}" -> "{handoff.name}";""" - ) - else: - parts.append( - f""" - "{agent.name}" -> "{handoff.name}";""" - ) - parts.append(get_all_edges(handoff, agent, visited)) - - if not agent.handoffs and not isinstance(agent, Tool): # type: ignore - parts.append(f'"{agent.name}" -> "__end__";') - - return "".join(parts) - - -def draw_graph(agent: Agent, filename: Optional[str] = None) -> graphviz.Source: + Recursively generates the nodes for the given agent and its handoffs in DOT format. + + Deprecated: + This function is deprecated. Use GraphBuilder and GraphvizRenderer instead. """ - Draws the graph for the given agent and optionally saves it as a PNG file. + warnings.warn( + "get_all_nodes is deprecated. Use GraphBuilder and GraphvizRenderer instead.", + DeprecationWarning, + stacklevel=2, + ) + builder = GraphBuilder() + renderer = GraphvizRenderer() + graph = builder.build_from_agent(agent) + return "\n".join(renderer._render_node(node) for node in graph.nodes.values()) - Args: - agent (Agent): The agent for which the graph is to be drawn. - filename (str): The name of the file to save the graph as a PNG. - Returns: - graphviz.Source: The graphviz Source object representing the graph. +def get_all_edges( + agent: Agent, parent: Optional[Agent] = None, visited: Optional[Set[int]] = None +) -> str: """ - dot_code = get_main_graph(agent) - graph = graphviz.Source(dot_code) - - if filename: - graph.render(filename, format="png") - - return graph + Recursively generates the edges for the given agent and its handoffs in DOT format. + + Deprecated: + This function is deprecated. Use GraphBuilder and GraphvizRenderer instead. + """ + warnings.warn( + "get_all_edges is deprecated. Use GraphBuilder and GraphvizRenderer instead.", + DeprecationWarning, + stacklevel=2, + ) + builder = GraphBuilder() + renderer = GraphvizRenderer() + graph = builder.build_from_agent(agent) + return "\n".join(renderer._render_edge(edge) for edge in graph.edges) From 110f7417214c3c0cf30c36f45c665b18c866728a Mon Sep 17 00:00:00 2001 From: ashish-dahal Date: Sun, 6 Apr 2025 16:37:37 +0300 Subject: [PATCH 3/9] Add tests for graph representation and mermaid graph generation --- tests/test_visualization.py | 292 +++++++++++++++++++++++++++++++++--- 1 file changed, 270 insertions(+), 22 deletions(-) diff --git a/tests/test_visualization.py b/tests/test_visualization.py index fca24c84..440cd26a 100644 --- a/tests/test_visualization.py +++ b/tests/test_visualization.py @@ -1,4 +1,6 @@ -from unittest.mock import Mock +from unittest.mock import Mock, patch +import io +import base64 import graphviz # type: ignore import pytest @@ -9,9 +11,20 @@ get_all_edges, get_all_nodes, get_main_graph, + Graph, + GraphBuilder, + GraphvizRenderer, + GraphRenderer, + Node, + Edge, + NodeType, + EdgeType, + MermaidRenderer, + GraphView ) from agents.handoffs import Handoff + # Common test graph elements START_NODE = '"__start__" [label="__start__", shape=ellipse, style=filled, fillcolor=lightblue, width=0.5, height=0.3];' END_NODE = '"__end__" [label="__end__", shape=ellipse, style=filled, fillcolor=lightblue, width=0.5, height=0.3];' @@ -52,8 +65,208 @@ def mock_recursive_agents(): return agent1 +# Tests for the new graph abstraction +def test_graph_builder(mock_agent): + builder = GraphBuilder() + graph = builder.build_from_agent(mock_agent) + + # Check nodes + assert "__start__" in graph.nodes + assert "__end__" in graph.nodes + assert "Agent1" in graph.nodes + assert "Tool1" in graph.nodes + assert "Tool2" in graph.nodes + assert "Handoff1" in graph.nodes + + # Check node types + assert graph.nodes["__start__"].type == NodeType.START + assert graph.nodes["__end__"].type == NodeType.END + assert graph.nodes["Agent1"].type == NodeType.AGENT + assert graph.nodes["Tool1"].type == NodeType.TOOL + assert graph.nodes["Tool2"].type == NodeType.TOOL + assert graph.nodes["Handoff1"].type == NodeType.HANDOFF + + # Check edges + start_to_agent = Edge("__start__", "Agent1", EdgeType.HANDOFF) + agent_to_tool1 = Edge("Agent1", "Tool1", EdgeType.TOOL) + tool1_to_agent = Edge("Tool1", "Agent1", EdgeType.TOOL) + agent_to_tool2 = Edge("Agent1", "Tool2", EdgeType.TOOL) + tool2_to_agent = Edge("Tool2", "Agent1", EdgeType.TOOL) + agent_to_handoff = Edge("Agent1", "Handoff1", EdgeType.HANDOFF) + + assert any(e.source == start_to_agent.source and e.target == start_to_agent.target for e in graph.edges) + assert any(e.source == agent_to_tool1.source and e.target == agent_to_tool1.target for e in graph.edges) + assert any(e.source == tool1_to_agent.source and e.target == tool1_to_agent.target for e in graph.edges) + assert any(e.source == agent_to_tool2.source and e.target == agent_to_tool2.target for e in graph.edges) + assert any(e.source == tool2_to_agent.source and e.target == tool2_to_agent.target for e in graph.edges) + assert any(e.source == agent_to_handoff.source and e.target == agent_to_handoff.target for e in graph.edges) + + +def test_graphviz_renderer(mock_agent): + builder = GraphBuilder() + graph = builder.build_from_agent(mock_agent) + renderer = GraphvizRenderer() + dot_code = renderer.render(graph) + + assert "digraph G" in dot_code + assert "graph [splines=true];" in dot_code + assert 'node [fontname="Arial"];' in dot_code + assert "edge [penwidth=1.5];" in dot_code + assert START_NODE in dot_code + assert END_NODE in dot_code + assert AGENT_NODE in dot_code + assert TOOL1_NODE in dot_code + assert TOOL2_NODE in dot_code + assert HANDOFF_NODE in dot_code + + +def test_recursive_graph_builder(mock_recursive_agents): + builder = GraphBuilder() + graph = builder.build_from_agent(mock_recursive_agents) + + # Check nodes + assert "Agent1" in graph.nodes + assert "Agent2" in graph.nodes + assert graph.nodes["Agent1"].type == NodeType.AGENT + assert graph.nodes["Agent2"].type == NodeType.AGENT + + # Check edges + agent1_to_agent2 = Edge("Agent1", "Agent2", EdgeType.HANDOFF) + agent2_to_agent1 = Edge("Agent2", "Agent1", EdgeType.HANDOFF) + + assert any(e.source == agent1_to_agent2.source and e.target == agent1_to_agent2.target for e in graph.edges) + assert any(e.source == agent2_to_agent1.source and e.target == agent2_to_agent1.target for e in graph.edges) + + +def test_graph_validation(): + graph = Graph() + + # Test adding valid nodes and edges + node1 = Node("1", "Node 1", NodeType.AGENT) + node2 = Node("2", "Node 2", NodeType.TOOL) + graph.add_node(node1) + graph.add_node(node2) + + valid_edge = Edge("1", "2", EdgeType.TOOL) + graph.add_edge(valid_edge) + + # Test adding edge with non-existent source + invalid_edge1 = Edge("3", "2", EdgeType.TOOL) + with pytest.raises(ValueError, match="Source node '3' does not exist in the graph"): + graph.add_edge(invalid_edge1) + + # Test adding edge with non-existent target + invalid_edge2 = Edge("1", "3", EdgeType.TOOL) + with pytest.raises(ValueError, match="Target node '3' does not exist in the graph"): + graph.add_edge(invalid_edge2) + + # Test helper methods + assert graph.has_node("1") + assert graph.has_node("2") + assert not graph.has_node("3") + + assert graph.get_node("1") == node1 + assert graph.get_node("2") == node2 + assert graph.get_node("3") is None + + +def test_node_immutability(): + node = Node("1", "Node 1", NodeType.AGENT) + with pytest.raises(Exception): # dataclasses.FrozenInstanceError + node.id = "2" + with pytest.raises(Exception): # dataclasses.FrozenInstanceError + node.label = "Node 2" + with pytest.raises(Exception): # dataclasses.FrozenInstanceError + node.type = NodeType.TOOL + + +def test_edge_immutability(): + edge = Edge("1", "2", EdgeType.TOOL) + with pytest.raises(Exception): # dataclasses.FrozenInstanceError + edge.source = "3" + with pytest.raises(Exception): # dataclasses.FrozenInstanceError + edge.target = "3" + with pytest.raises(Exception): # dataclasses.FrozenInstanceError + edge.type = EdgeType.HANDOFF + + +def test_draw_graph_with_invalid_renderer(mock_agent): + with pytest.raises(ValueError, match=f"Unsupported renderer: invalid"): + draw_graph(mock_agent, renderer="invalid") + + +def test_draw_graph_default_renderer(mock_agent): + result = draw_graph(mock_agent) + assert isinstance(result, GraphView) + assert "digraph G" in result.rendered_graph + + +def test_draw_graph_with_filename(mock_agent, tmp_path): + filename = tmp_path / "test_graph" + result = draw_graph(mock_agent, filename=str(filename)) + assert isinstance(result, GraphView) + assert "digraph G" in result.rendered_graph + assert (tmp_path / "test_graph.png").exists() + + +def test_draw_graph_with_graphviz(mock_agent): + result = draw_graph(mock_agent, renderer="graphviz") + assert isinstance(result, GraphView) + assert "digraph G" in result.rendered_graph + assert "graph [splines=true];" in result.rendered_graph + assert 'node [fontname="Arial"];' in result.rendered_graph + assert "edge [penwidth=1.5];" in result.rendered_graph + assert START_NODE in result.rendered_graph + assert END_NODE in result.rendered_graph + assert AGENT_NODE in result.rendered_graph + assert TOOL1_NODE in result.rendered_graph + assert TOOL2_NODE in result.rendered_graph + assert HANDOFF_NODE in result.rendered_graph + + +def test_draw_graph_with_mermaid(mock_agent): + result = draw_graph(mock_agent, renderer="mermaid") + assert isinstance(result, GraphView) + assert "graph TD" in result.rendered_graph + assert "__start__(__start__)" in result.rendered_graph + assert "style __start__ fill:lightblue" in result.rendered_graph + assert "Agent1[Agent1]" in result.rendered_graph + assert "style Agent1 fill:lightyellow" in result.rendered_graph + + +def test_draw_graph_with_filename_graphviz(mock_agent, tmp_path): + filename = tmp_path / "test_graph" + result = draw_graph(mock_agent, filename=str(filename), renderer="graphviz") + assert isinstance(result, GraphView) + assert "digraph G" in result.rendered_graph + assert (tmp_path / "test_graph.png").exists() + + +def test_draw_graph_with_filename_mermaid(mock_agent, tmp_path): + filename = tmp_path / "test_graph" + mock_response = Mock() + mock_response.content = b"mock image data" + mock_response.raise_for_status = Mock() + + with patch("requests.get", return_value=mock_response): + result = draw_graph(mock_agent, filename=str(filename), renderer="mermaid") + assert isinstance(result, GraphView) + assert "graph TD" in result.rendered_graph + assert (tmp_path / "test_graph.png").exists() + with open(tmp_path / "test_graph.png", "rb") as f: + assert f.read() == b"mock image data" + + +def test_draw_graph(mock_agent): + result = draw_graph(mock_agent) + assert isinstance(result, GraphView) + assert "digraph G" in result.rendered_graph + + +# Legacy function tests def test_get_main_graph(mock_agent): - result = get_main_graph(mock_agent) + with pytest.warns(DeprecationWarning): + result = get_main_graph(mock_agent) assert "digraph G" in result assert "graph [splines=true];" in result assert 'node [fontname="Arial"];' in result @@ -67,7 +280,8 @@ def test_get_main_graph(mock_agent): def test_get_all_nodes(mock_agent): - result = get_all_nodes(mock_agent) + with pytest.warns(DeprecationWarning): + result = get_all_nodes(mock_agent) assert START_NODE in result assert END_NODE in result assert AGENT_NODE in result @@ -77,7 +291,8 @@ def test_get_all_nodes(mock_agent): def test_get_all_edges(mock_agent): - result = get_all_edges(mock_agent) + with pytest.warns(DeprecationWarning): + result = get_all_edges(mock_agent) assert '"__start__" -> "Agent1";' in result assert '"Agent1" -> "Tool1" [style=dotted, penwidth=1.5];' in result assert '"Tool1" -> "Agent1" [style=dotted, penwidth=1.5];' in result @@ -86,25 +301,9 @@ def test_get_all_edges(mock_agent): assert '"Agent1" -> "Handoff1";' in result -def test_draw_graph(mock_agent): - graph = draw_graph(mock_agent) - assert isinstance(graph, graphviz.Source) - source = graph.source - assert "digraph G" in source - assert "graph [splines=true];" in source - assert 'node [fontname="Arial"];' in source - assert "edge [penwidth=1.5];" in source - assert START_NODE in source - assert END_NODE in source - assert AGENT_NODE in source - assert TOOL1_NODE in source - assert TOOL2_NODE in source - assert HANDOFF_NODE in source - - def test_recursive_handoff_loop(mock_recursive_agents): - agent1 = mock_recursive_agents - dot = get_main_graph(agent1) + with pytest.warns(DeprecationWarning): + dot = get_main_graph(mock_recursive_agents) assert ( '"Agent1" [label="Agent1", shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8];' @@ -116,3 +315,52 @@ def test_recursive_handoff_loop(mock_recursive_agents): ) assert '"Agent1" -> "Agent2";' in dot assert '"Agent2" -> "Agent1";' in dot + + +def test_mermaid_renderer(mock_agent): + builder = GraphBuilder() + graph = builder.build_from_agent(mock_agent) + renderer = MermaidRenderer() + mermaid_code = renderer.render(graph) + + # Test flowchart header + assert "graph TD" in mermaid_code + + # Test node rendering + assert "__start__(__start__)" in mermaid_code + assert "style __start__ fill:lightblue" in mermaid_code + assert "__end__(__end__)" in mermaid_code + assert "style __end__ fill:lightblue" in mermaid_code + assert "Agent1[Agent1]" in mermaid_code + assert "style Agent1 fill:lightyellow" in mermaid_code + assert "Tool1((Tool1))" in mermaid_code + assert "style Tool1 fill:lightgreen" in mermaid_code + assert "Tool2((Tool2))" in mermaid_code + assert "style Tool2 fill:lightgreen" in mermaid_code + assert "Handoff1[Handoff1]" in mermaid_code + assert "style Handoff1 fill:lightyellow" in mermaid_code + + # Test edge rendering + assert "__start__ --> Agent1" in mermaid_code + assert "Agent1 -.-> Tool1" in mermaid_code + assert "Tool1 -.-> Agent1" in mermaid_code + assert "Agent1 -.-> Tool2" in mermaid_code + assert "Tool2 -.-> Agent1" in mermaid_code + assert "Agent1 --> Handoff1" in mermaid_code + + +def test_mermaid_renderer_save(mock_agent, tmp_path): + renderer = MermaidRenderer() + graph = GraphBuilder().build_from_agent(mock_agent) + rendered = renderer.render(graph) + filename = tmp_path / "test_graph" + + mock_response = Mock() + mock_response.content = b"mock image data" + mock_response.raise_for_status = Mock() + + with patch("requests.get", return_value=mock_response): + renderer.save(rendered, str(filename)) + assert (tmp_path / "test_graph.png").exists() + with open(tmp_path / "test_graph.png", "rb") as f: + assert f.read() == b"mock image data" From 5f2e60fa9cf3977ce7185425af91765cf37605bc Mon Sep 17 00:00:00 2001 From: ashish-dahal Date: Sun, 6 Apr 2025 16:37:48 +0300 Subject: [PATCH 4/9] Update agent visualization documentation to include Mermaid rendering option --- docs/visualization.md | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/docs/visualization.md b/docs/visualization.md index 00f3126d..00816f3b 100644 --- a/docs/visualization.md +++ b/docs/visualization.md @@ -1,6 +1,10 @@ # Agent Visualization -Agent visualization allows you to generate a structured graphical representation of agents and their relationships using **Graphviz**. This is useful for understanding how agents, tools, and handoffs interact within an application. +Agent visualization allows you to generate a structured graphical representation of agents and their relationships. Two rendering options are available: +- **Graphviz** (offline): Default renderer that generates graphs locally +- **Mermaid** (online): Alternative renderer that uses mermaid.ink API + +This is useful for understanding how agents, tools, and handoffs interact within an application. ## Installation @@ -18,6 +22,15 @@ You can generate an agent visualization using the `draw_graph` function. This fu - **Tools** are represented as green ellipses. - **Handoffs** are directed edges from one agent to another. +The renderer can be specified using the `renderer` parameter: +```python +# Using Graphviz (default) +draw_graph(agent, renderer="graphviz") + +# Using Mermaid API +draw_graph(agent, renderer="mermaid") +``` + ### Example Usage ```python @@ -82,5 +95,3 @@ draw_graph(triage_agent, filename="agent_graph.png") ``` This will generate `agent_graph.png` in the working directory. - - From 78c56bfc14ff220faef8ccdcc2097be5f30e0452 Mon Sep 17 00:00:00 2001 From: ashish-dahal Date: Sun, 6 Apr 2025 20:56:49 +0300 Subject: [PATCH 5/9] Fix linting issues --- src/agents/extensions/visualization.py | 139 ++++++++++++++----------- tests/test_visualization.py | 139 +++++++++++++++---------- 2 files changed, 164 insertions(+), 114 deletions(-) diff --git a/src/agents/extensions/visualization.py b/src/agents/extensions/visualization.py index 472e4a9b..c2f3139f 100644 --- a/src/agents/extensions/visualization.py +++ b/src/agents/extensions/visualization.py @@ -1,10 +1,11 @@ -from typing import Optional, Set, Dict, List, TypeVar, Generic +import abc +import base64 +import warnings from dataclasses import dataclass from enum import Enum -import warnings -import abc +from typing import Generic, Optional, TypeVar + import graphviz -import base64 import requests from agents import Agent @@ -41,18 +42,18 @@ class Edge: class Graph: def __init__(self): - self.nodes: Dict[str, Node] = {} - self.edges: List[Edge] = [] + self.nodes: dict[str, Node] = {} + self.edges: list[Edge] = [] def add_node(self, node: Node) -> None: self.nodes[node.id] = node def add_edge(self, edge: Edge) -> None: """Add an edge to the graph. - + Args: edge (Edge): The edge to add. - + Raises: ValueError: If the source or target node does not exist in the graph. """ @@ -64,10 +65,10 @@ def add_edge(self, edge: Edge) -> None: def has_node(self, node_id: str) -> bool: """Check if a node exists in the graph. - + Args: node_id (str): The ID of the node to check. - + Returns: bool: True if the node exists, False otherwise. """ @@ -75,10 +76,10 @@ def has_node(self, node_id: str) -> bool: def get_node(self, node_id: str) -> Optional[Node]: """Get a node from the graph. - + Args: node_id (str): The ID of the node to get. - + Returns: Optional[Node]: The node if it exists, None otherwise. """ @@ -87,14 +88,14 @@ def get_node(self, node_id: str) -> Optional[Node]: class GraphBuilder: def __init__(self): - self._visited: Set[int] = set() - + self._visited: set[int] = set() + def build_from_agent(self, agent: Agent) -> Graph: """Build a graph from an agent. - + Args: agent (Agent): The agent to build the graph from. - + Returns: Graph: The built graph. """ @@ -148,27 +149,28 @@ def _add_agent_nodes_and_edges( graph.add_edge(Edge(agent.name, "__end__", EdgeType.HANDOFF)) -T = TypeVar('T') +T = TypeVar("T") + class GraphRenderer(Generic[T], abc.ABC): """Abstract base class for graph renderers.""" - + @abc.abstractmethod def render(self, graph: Graph) -> T: """Render the graph in the specific format. - + Args: graph (Graph): The graph to render. - + Returns: T: The rendered graph in the format specific to the renderer. """ pass - + @abc.abstractmethod def save(self, rendered: T, filename: str) -> None: """Save the rendered graph to a file. - + Args: rendered (T): The rendered graph returned by render(). filename (str): The name of the file to save the graph as. @@ -178,7 +180,7 @@ def save(self, rendered: T, filename: str) -> None: class GraphvizRenderer(GraphRenderer[str]): """Renderer that outputs graphs in Graphviz DOT format.""" - + def render(self, graph: Graph) -> str: parts = [ """ @@ -202,7 +204,7 @@ def render(self, graph: Graph) -> str: def save(self, rendered: str, filename: str) -> None: """Save the rendered graph as a PNG file using graphviz. - + Args: rendered (str): The DOT format string. filename (str): The name of the file to save the graph as. @@ -211,11 +213,21 @@ def save(self, rendered: str, filename: str) -> None: def _render_node(self, node: Node) -> str: style_map = { - NodeType.START: 'shape=ellipse, style=filled, fillcolor=lightblue, width=0.5, height=0.3', - NodeType.END: 'shape=ellipse, style=filled, fillcolor=lightblue, width=0.5, height=0.3', - NodeType.AGENT: 'shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8', - NodeType.TOOL: 'shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3', - NodeType.HANDOFF: 'shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8', + NodeType.START: ( + "shape=ellipse, style=filled, fillcolor=lightblue, width=0.5, height=0.3" + ), + NodeType.END: ( + "shape=ellipse, style=filled, fillcolor=lightblue, width=0.5, height=0.3" + ), + NodeType.AGENT: ( + "shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8" + ), + NodeType.TOOL: ( + "shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3" + ), + NodeType.HANDOFF: ( + "shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8" + ), } return f'"{node.id}" [label="{node.label}", {style_map[node.type]}];' @@ -227,23 +239,23 @@ def _render_edge(self, edge: Edge) -> str: class MermaidRenderer(GraphRenderer[str]): """Renderer that outputs graphs in Mermaid flowchart syntax.""" - + def render(self, graph: Graph) -> str: parts = ["graph TD\n"] - + # Add nodes with styles for node in graph.nodes.values(): parts.append(self._render_node(node)) - + # Add edges for edge in graph.edges: parts.append(self._render_edge(edge)) - + return "".join(parts) def save(self, rendered: str, filename: str) -> None: """Save the rendered graph as a PNG file using mermaid.ink API. - + Args: rendered (str): The Mermaid syntax string. filename (str): The name of the file to save the graph as. @@ -252,11 +264,11 @@ def save(self, rendered: str, filename: str) -> None: graphbytes = rendered.encode("utf8") base64_bytes = base64.urlsafe_b64encode(graphbytes) base64_string = base64_bytes.decode("ascii") - + # Get the image from mermaid.ink - response = requests.get(f'https://mermaid.ink/img/{base64_string}') + response = requests.get(f"https://mermaid.ink/img/{base64_string}") response.raise_for_status() - + # Save the image directly from response content with open(f"{filename}.png", "wb") as f: f.write(response.content) @@ -270,36 +282,38 @@ def _render_node(self, node: Node) -> str: NodeType.TOOL: ["((", "))", "lightgreen"], NodeType.HANDOFF: ["[", "]", "lightyellow"], } - + start, end, color = style_map[node.type] node_id = self._sanitize_id(node.id) # Use sanitized ID and original label return f"{node_id}{start}{node.label}{end}\nstyle {node_id} fill:{color}\n" - + def _render_edge(self, edge: Edge) -> str: source = self._sanitize_id(edge.source) target = self._sanitize_id(edge.target) if edge.type == EdgeType.TOOL: return f"{source} -.-> {target}\n" return f"{source} --> {target}\n" - + def _sanitize_id(self, id: str) -> str: """Sanitize node IDs to work with Mermaid's stricter ID requirements.""" return id.replace(" ", "_").replace("-", "_") class GraphView: - def __init__(self, rendered_graph: str, renderer: GraphRenderer, filename: Optional[str] = None): + def __init__( + self, rendered_graph: str, renderer: GraphRenderer, filename: Optional[str] = None + ): self.rendered_graph = rendered_graph self.renderer = renderer self.filename = filename def view(self) -> None: """Opens the rendered graph in a separate window.""" - import tempfile import os + import tempfile import webbrowser - + if self.filename: webbrowser.open(f"file://{os.path.abspath(self.filename)}.png") else: @@ -309,52 +323,55 @@ def view(self) -> None: webbrowser.open(f"file://{os.path.abspath(temp_path)}.png") -def draw_graph(agent: Agent, filename: Optional[str] = None, renderer: str = "graphviz") -> GraphView: +def draw_graph( + agent: Agent, filename: Optional[str] = None, renderer: str = "graphviz" +) -> GraphView: """ Draws the graph for the given agent using the specified renderer. - + Args: agent (Agent): The agent for which the graph is to be drawn. filename (str, optional): The name of the file to save the graph as PNG. Defaults to None. - renderer (str, optional): The renderer to use. Must be one of: "graphviz" (offline), + renderer (str, optional): The renderer to use. Must be one of: "graphviz" (offline), "mermaid" (requires internet). Defaults to "graphviz". - + Returns: GraphView: A view object that can be used to display the graph. - + Raises: ValueError: If the specified renderer is not supported. - requests.RequestException: If using mermaid renderer and unable to connect to mermaid.ink API. + requests.RequestException: If using mermaid renderer and unable to connect + to mermaid.ink API. """ builder = GraphBuilder() graph = builder.build_from_agent(agent) - + if renderer == "graphviz": renderer_instance = GraphvizRenderer() elif renderer == "mermaid": renderer_instance = MermaidRenderer() else: raise ValueError(f"Unsupported renderer: {renderer}") - + rendered = renderer_instance.render(graph) - + if filename: - filename = filename.rsplit('.', 1)[0] + filename = filename.rsplit(".", 1)[0] renderer_instance.save(rendered, filename) - + return GraphView(rendered, renderer_instance, filename) def get_main_graph(agent: Agent) -> str: """ Generates the main graph structure in DOT format for the given agent. - + Args: agent (Agent): The agent for which the graph is to be generated. - + Returns: str: The DOT format string representing the graph. - + Deprecated: This function is deprecated. Use GraphBuilder and GraphvizRenderer instead. """ @@ -370,11 +387,11 @@ def get_main_graph(agent: Agent) -> str: def get_all_nodes( - agent: Agent, parent: Optional[Agent] = None, visited: Optional[Set[int]] = None + agent: Agent, parent: Optional[Agent] = None, visited: Optional[set[int]] = None ) -> str: """ Recursively generates the nodes for the given agent and its handoffs in DOT format. - + Deprecated: This function is deprecated. Use GraphBuilder and GraphvizRenderer instead. """ @@ -390,11 +407,11 @@ def get_all_nodes( def get_all_edges( - agent: Agent, parent: Optional[Agent] = None, visited: Optional[Set[int]] = None + agent: Agent, parent: Optional[Agent] = None, visited: Optional[set[int]] = None ) -> str: """ Recursively generates the edges for the given agent and its handoffs in DOT format. - + Deprecated: This function is deprecated. Use GraphBuilder and GraphvizRenderer instead. """ diff --git a/tests/test_visualization.py b/tests/test_visualization.py index 440cd26a..e80a20de 100644 --- a/tests/test_visualization.py +++ b/tests/test_visualization.py @@ -1,37 +1,51 @@ +import dataclasses from unittest.mock import Mock, patch -import io -import base64 -import graphviz # type: ignore import pytest from agents import Agent from agents.extensions.visualization import ( - draw_graph, - get_all_edges, - get_all_nodes, - get_main_graph, + Edge, + EdgeType, Graph, GraphBuilder, + GraphView, GraphvizRenderer, - GraphRenderer, + MermaidRenderer, Node, - Edge, NodeType, - EdgeType, - MermaidRenderer, - GraphView + draw_graph, + get_all_edges, + get_all_nodes, + get_main_graph, ) from agents.handoffs import Handoff - # Common test graph elements -START_NODE = '"__start__" [label="__start__", shape=ellipse, style=filled, fillcolor=lightblue, width=0.5, height=0.3];' -END_NODE = '"__end__" [label="__end__", shape=ellipse, style=filled, fillcolor=lightblue, width=0.5, height=0.3];' -AGENT_NODE = '"Agent1" [label="Agent1", shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8];' -TOOL1_NODE = '"Tool1" [label="Tool1", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' -TOOL2_NODE = '"Tool2" [label="Tool2", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' -HANDOFF_NODE = '"Handoff1" [label="Handoff1", shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8];' +START_NODE = ( + '"__start__" [label="__start__", shape=ellipse, style=filled, ' + "fillcolor=lightblue, width=0.5, height=0.3];" +) +END_NODE = ( + '"__end__" [label="__end__", shape=ellipse, style=filled, ' + "fillcolor=lightblue, width=0.5, height=0.3];" +) +AGENT_NODE = ( + '"Agent1" [label="Agent1", shape=box, style=filled, ' + "fillcolor=lightyellow, width=1.5, height=0.8];" +) +TOOL1_NODE = ( + '"Tool1" [label="Tool1", shape=ellipse, style=filled, ' + "fillcolor=lightgreen, width=0.5, height=0.3];" +) +TOOL2_NODE = ( + '"Tool2" [label="Tool2", shape=ellipse, style=filled, ' + "fillcolor=lightgreen, width=0.5, height=0.3];" +) +HANDOFF_NODE = ( + '"Handoff1" [label="Handoff1", shape=box, style=filled, ' + "fillcolor=lightyellow, width=1.5, height=0.8];" +) @pytest.fixture @@ -69,7 +83,7 @@ def mock_recursive_agents(): def test_graph_builder(mock_agent): builder = GraphBuilder() graph = builder.build_from_agent(mock_agent) - + # Check nodes assert "__start__" in graph.nodes assert "__end__" in graph.nodes @@ -77,7 +91,7 @@ def test_graph_builder(mock_agent): assert "Tool1" in graph.nodes assert "Tool2" in graph.nodes assert "Handoff1" in graph.nodes - + # Check node types assert graph.nodes["__start__"].type == NodeType.START assert graph.nodes["__end__"].type == NodeType.END @@ -94,12 +108,25 @@ def test_graph_builder(mock_agent): tool2_to_agent = Edge("Tool2", "Agent1", EdgeType.TOOL) agent_to_handoff = Edge("Agent1", "Handoff1", EdgeType.HANDOFF) - assert any(e.source == start_to_agent.source and e.target == start_to_agent.target for e in graph.edges) - assert any(e.source == agent_to_tool1.source and e.target == agent_to_tool1.target for e in graph.edges) - assert any(e.source == tool1_to_agent.source and e.target == tool1_to_agent.target for e in graph.edges) - assert any(e.source == agent_to_tool2.source and e.target == agent_to_tool2.target for e in graph.edges) - assert any(e.source == tool2_to_agent.source and e.target == tool2_to_agent.target for e in graph.edges) - assert any(e.source == agent_to_handoff.source and e.target == agent_to_handoff.target for e in graph.edges) + assert any( + e.source == start_to_agent.source and e.target == start_to_agent.target for e in graph.edges + ) + assert any( + e.source == agent_to_tool1.source and e.target == agent_to_tool1.target for e in graph.edges + ) + assert any( + e.source == tool1_to_agent.source and e.target == tool1_to_agent.target for e in graph.edges + ) + assert any( + e.source == agent_to_tool2.source and e.target == agent_to_tool2.target for e in graph.edges + ) + assert any( + e.source == tool2_to_agent.source and e.target == tool2_to_agent.target for e in graph.edges + ) + assert any( + e.source == agent_to_handoff.source and e.target == agent_to_handoff.target + for e in graph.edges + ) def test_graphviz_renderer(mock_agent): @@ -107,7 +134,7 @@ def test_graphviz_renderer(mock_agent): graph = builder.build_from_agent(mock_agent) renderer = GraphvizRenderer() dot_code = renderer.render(graph) - + assert "digraph G" in dot_code assert "graph [splines=true];" in dot_code assert 'node [fontname="Arial"];' in dot_code @@ -123,48 +150,54 @@ def test_graphviz_renderer(mock_agent): def test_recursive_graph_builder(mock_recursive_agents): builder = GraphBuilder() graph = builder.build_from_agent(mock_recursive_agents) - + # Check nodes assert "Agent1" in graph.nodes assert "Agent2" in graph.nodes assert graph.nodes["Agent1"].type == NodeType.AGENT assert graph.nodes["Agent2"].type == NodeType.AGENT - + # Check edges agent1_to_agent2 = Edge("Agent1", "Agent2", EdgeType.HANDOFF) agent2_to_agent1 = Edge("Agent2", "Agent1", EdgeType.HANDOFF) - - assert any(e.source == agent1_to_agent2.source and e.target == agent1_to_agent2.target for e in graph.edges) - assert any(e.source == agent2_to_agent1.source and e.target == agent2_to_agent1.target for e in graph.edges) + + assert any( + e.source == agent1_to_agent2.source and e.target == agent1_to_agent2.target + for e in graph.edges + ) + assert any( + e.source == agent2_to_agent1.source and e.target == agent2_to_agent1.target + for e in graph.edges + ) def test_graph_validation(): graph = Graph() - + # Test adding valid nodes and edges node1 = Node("1", "Node 1", NodeType.AGENT) node2 = Node("2", "Node 2", NodeType.TOOL) graph.add_node(node1) graph.add_node(node2) - + valid_edge = Edge("1", "2", EdgeType.TOOL) graph.add_edge(valid_edge) - + # Test adding edge with non-existent source invalid_edge1 = Edge("3", "2", EdgeType.TOOL) with pytest.raises(ValueError, match="Source node '3' does not exist in the graph"): graph.add_edge(invalid_edge1) - + # Test adding edge with non-existent target invalid_edge2 = Edge("1", "3", EdgeType.TOOL) with pytest.raises(ValueError, match="Target node '3' does not exist in the graph"): graph.add_edge(invalid_edge2) - + # Test helper methods assert graph.has_node("1") assert graph.has_node("2") assert not graph.has_node("3") - + assert graph.get_node("1") == node1 assert graph.get_node("2") == node2 assert graph.get_node("3") is None @@ -172,26 +205,26 @@ def test_graph_validation(): def test_node_immutability(): node = Node("1", "Node 1", NodeType.AGENT) - with pytest.raises(Exception): # dataclasses.FrozenInstanceError + with pytest.raises(dataclasses.FrozenInstanceError): node.id = "2" - with pytest.raises(Exception): # dataclasses.FrozenInstanceError + with pytest.raises(dataclasses.FrozenInstanceError): node.label = "Node 2" - with pytest.raises(Exception): # dataclasses.FrozenInstanceError + with pytest.raises(dataclasses.FrozenInstanceError): node.type = NodeType.TOOL def test_edge_immutability(): edge = Edge("1", "2", EdgeType.TOOL) - with pytest.raises(Exception): # dataclasses.FrozenInstanceError + with pytest.raises(dataclasses.FrozenInstanceError): edge.source = "3" - with pytest.raises(Exception): # dataclasses.FrozenInstanceError + with pytest.raises(dataclasses.FrozenInstanceError): edge.target = "3" - with pytest.raises(Exception): # dataclasses.FrozenInstanceError + with pytest.raises(dataclasses.FrozenInstanceError): edge.type = EdgeType.HANDOFF def test_draw_graph_with_invalid_renderer(mock_agent): - with pytest.raises(ValueError, match=f"Unsupported renderer: invalid"): + with pytest.raises(ValueError, match="Unsupported renderer: invalid"): draw_graph(mock_agent, renderer="invalid") @@ -306,12 +339,12 @@ def test_recursive_handoff_loop(mock_recursive_agents): dot = get_main_graph(mock_recursive_agents) assert ( - '"Agent1" [label="Agent1", shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8];' - in dot + '"Agent1" [label="Agent1", shape=box, style=filled, ' + "fillcolor=lightyellow, width=1.5, height=0.8];" in dot ) assert ( - '"Agent2" [label="Agent2", shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8];' - in dot + '"Agent2" [label="Agent2", shape=box, style=filled, ' + "fillcolor=lightyellow, width=1.5, height=0.8];" in dot ) assert '"Agent1" -> "Agent2";' in dot assert '"Agent2" -> "Agent1";' in dot @@ -322,10 +355,10 @@ def test_mermaid_renderer(mock_agent): graph = builder.build_from_agent(mock_agent) renderer = MermaidRenderer() mermaid_code = renderer.render(graph) - + # Test flowchart header assert "graph TD" in mermaid_code - + # Test node rendering assert "__start__(__start__)" in mermaid_code assert "style __start__ fill:lightblue" in mermaid_code @@ -339,7 +372,7 @@ def test_mermaid_renderer(mock_agent): assert "style Tool2 fill:lightgreen" in mermaid_code assert "Handoff1[Handoff1]" in mermaid_code assert "style Handoff1 fill:lightyellow" in mermaid_code - + # Test edge rendering assert "__start__ --> Agent1" in mermaid_code assert "Agent1 -.-> Tool1" in mermaid_code From 73ebe9a64f50884d3800a16a901c9b941a9198e8 Mon Sep 17 00:00:00 2001 From: ashish-dahal Date: Tue, 8 Apr 2025 23:45:36 +0300 Subject: [PATCH 6/9] Refactor draw_graph function signature to use Literal for renderer type and improve type hints --- src/agents/extensions/visualization.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/agents/extensions/visualization.py b/src/agents/extensions/visualization.py index c2f3139f..679876e4 100644 --- a/src/agents/extensions/visualization.py +++ b/src/agents/extensions/visualization.py @@ -3,7 +3,7 @@ import warnings from dataclasses import dataclass from enum import Enum -from typing import Generic, Optional, TypeVar +from typing import Generic, Literal, Optional, TypeVar import graphviz import requests @@ -324,24 +324,23 @@ def view(self) -> None: def draw_graph( - agent: Agent, filename: Optional[str] = None, renderer: str = "graphviz" + agent: Agent, + filename: str | None = None, + renderer: Literal["graphviz", "mermaid"] = "graphviz", ) -> GraphView: """ Draws the graph for the given agent using the specified renderer. Args: agent (Agent): The agent for which the graph is to be drawn. - filename (str, optional): The name of the file to save the graph as PNG. Defaults to None. - renderer (str, optional): The renderer to use. Must be one of: "graphviz" (offline), - "mermaid" (requires internet). Defaults to "graphviz". + filename (str | None): The name of the file to save the graph as PNG. Defaults to None. + renderer (Literal["graphviz", "mermaid"]): The renderer to use. Defaults to "graphviz". Returns: GraphView: A view object that can be used to display the graph. Raises: ValueError: If the specified renderer is not supported. - requests.RequestException: If using mermaid renderer and unable to connect - to mermaid.ink API. """ builder = GraphBuilder() graph = builder.build_from_agent(agent) From 57e9c18e65042460a04e52d4528abecae4283f4f Mon Sep 17 00:00:00 2001 From: ashish-dahal Date: Tue, 8 Apr 2025 23:49:58 +0300 Subject: [PATCH 7/9] Handle None case for agent in _add_agent_nodes_and_edges method --- src/agents/extensions/visualization.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/agents/extensions/visualization.py b/src/agents/extensions/visualization.py index 679876e4..b808d6b4 100644 --- a/src/agents/extensions/visualization.py +++ b/src/agents/extensions/visualization.py @@ -111,10 +111,13 @@ def build_from_agent(self, agent: Agent) -> Graph: def _add_agent_nodes_and_edges( self, - agent: Agent, + agent: Agent | None, parent: Optional[Agent], graph: Graph, ) -> None: + if agent is None: + return + # Add agent node graph.add_node(Node(agent.name, agent.name, NodeType.AGENT)) From 3658d3a8c8434ef286b34b470b08dbbbab2fc18f Mon Sep 17 00:00:00 2001 From: ashish-dahal Date: Wed, 9 Apr 2025 00:14:59 +0300 Subject: [PATCH 8/9] Update Edge class to use Node instances for source and target --- src/agents/extensions/visualization.py | 48 +++++++++++++++-------- tests/test_visualization.py | 53 +++++++++++++++++--------- 2 files changed, 66 insertions(+), 35 deletions(-) diff --git a/src/agents/extensions/visualization.py b/src/agents/extensions/visualization.py index b808d6b4..fc8aa8bb 100644 --- a/src/agents/extensions/visualization.py +++ b/src/agents/extensions/visualization.py @@ -35,8 +35,8 @@ class Node: @dataclass(frozen=True) class Edge: - source: str - target: str + source: Node + target: Node type: EdgeType @@ -57,10 +57,10 @@ def add_edge(self, edge: Edge) -> None: Raises: ValueError: If the source or target node does not exist in the graph. """ - if edge.source not in self.nodes: - raise ValueError(f"Source node '{edge.source}' does not exist in the graph") - if edge.target not in self.nodes: - raise ValueError(f"Target node '{edge.target}' does not exist in the graph") + if edge.source.id not in self.nodes: + raise ValueError(f"Source node '{edge.source.id}' does not exist in the graph") + if edge.target.id not in self.nodes: + raise ValueError(f"Target node '{edge.target.id}' does not exist in the graph") self.edges.append(edge) def has_node(self, node_id: str) -> bool: @@ -123,13 +123,19 @@ def _add_agent_nodes_and_edges( # Connect start node if root agent if not parent: - graph.add_edge(Edge("__start__", agent.name, EdgeType.HANDOFF)) + graph.add_edge( + Edge(graph.get_node("__start__"), graph.get_node(agent.name), EdgeType.HANDOFF) + ) # Add tool nodes and edges for tool in agent.tools: graph.add_node(Node(tool.name, tool.name, NodeType.TOOL)) - graph.add_edge(Edge(agent.name, tool.name, EdgeType.TOOL)) - graph.add_edge(Edge(tool.name, agent.name, EdgeType.TOOL)) + graph.add_edge( + Edge(graph.get_node(agent.name), graph.get_node(tool.name), EdgeType.TOOL) + ) + graph.add_edge( + Edge(graph.get_node(tool.name), graph.get_node(agent.name), EdgeType.TOOL) + ) # Add current agent's ID to visited set self._visited.add(id(agent)) @@ -140,16 +146,26 @@ def _add_agent_nodes_and_edges( has_handoffs = True if isinstance(handoff, Handoff): graph.add_node(Node(handoff.agent_name, handoff.agent_name, NodeType.HANDOFF)) - graph.add_edge(Edge(agent.name, handoff.agent_name, EdgeType.HANDOFF)) + graph.add_edge( + Edge( + graph.get_node(agent.name), + graph.get_node(handoff.agent_name), + EdgeType.HANDOFF, + ) + ) elif isinstance(handoff, Agent): graph.add_node(Node(handoff.name, handoff.name, NodeType.AGENT)) - graph.add_edge(Edge(agent.name, handoff.name, EdgeType.HANDOFF)) + graph.add_edge( + Edge(graph.get_node(agent.name), graph.get_node(handoff.name), EdgeType.HANDOFF) + ) if id(handoff) not in self._visited: self._add_agent_nodes_and_edges(handoff, agent, graph) # Connect to end node if no handoffs if not has_handoffs and not isinstance(agent, Tool): - graph.add_edge(Edge(agent.name, "__end__", EdgeType.HANDOFF)) + graph.add_edge( + Edge(graph.get_node(agent.name), graph.get_node("__end__"), EdgeType.HANDOFF) + ) T = TypeVar("T") @@ -236,8 +252,8 @@ def _render_node(self, node: Node) -> str: def _render_edge(self, edge: Edge) -> str: if edge.type == EdgeType.TOOL: - return f'"{edge.source}" -> "{edge.target}" [style=dotted, penwidth=1.5];' - return f'"{edge.source}" -> "{edge.target}";' + return f'"{edge.source.id}" -> "{edge.target.id}" [style=dotted, penwidth=1.5];' + return f'"{edge.source.id}" -> "{edge.target.id}";' class MermaidRenderer(GraphRenderer[str]): @@ -292,8 +308,8 @@ def _render_node(self, node: Node) -> str: return f"{node_id}{start}{node.label}{end}\nstyle {node_id} fill:{color}\n" def _render_edge(self, edge: Edge) -> str: - source = self._sanitize_id(edge.source) - target = self._sanitize_id(edge.target) + source = self._sanitize_id(edge.source.id) + target = self._sanitize_id(edge.target.id) if edge.type == EdgeType.TOOL: return f"{source} -.-> {target}\n" return f"{source} --> {target}\n" diff --git a/tests/test_visualization.py b/tests/test_visualization.py index e80a20de..544006da 100644 --- a/tests/test_visualization.py +++ b/tests/test_visualization.py @@ -101,30 +101,41 @@ def test_graph_builder(mock_agent): assert graph.nodes["Handoff1"].type == NodeType.HANDOFF # Check edges - start_to_agent = Edge("__start__", "Agent1", EdgeType.HANDOFF) - agent_to_tool1 = Edge("Agent1", "Tool1", EdgeType.TOOL) - tool1_to_agent = Edge("Tool1", "Agent1", EdgeType.TOOL) - agent_to_tool2 = Edge("Agent1", "Tool2", EdgeType.TOOL) - tool2_to_agent = Edge("Tool2", "Agent1", EdgeType.TOOL) - agent_to_handoff = Edge("Agent1", "Handoff1", EdgeType.HANDOFF) + start_node = graph.nodes["__start__"] + agent_node = graph.nodes["Agent1"] + tool1_node = graph.nodes["Tool1"] + tool2_node = graph.nodes["Tool2"] + handoff_node = graph.nodes["Handoff1"] + + start_to_agent = Edge(start_node, agent_node, EdgeType.HANDOFF) + agent_to_tool1 = Edge(agent_node, tool1_node, EdgeType.TOOL) + tool1_to_agent = Edge(tool1_node, agent_node, EdgeType.TOOL) + agent_to_tool2 = Edge(agent_node, tool2_node, EdgeType.TOOL) + tool2_to_agent = Edge(tool2_node, agent_node, EdgeType.TOOL) + agent_to_handoff = Edge(agent_node, handoff_node, EdgeType.HANDOFF) assert any( - e.source == start_to_agent.source and e.target == start_to_agent.target for e in graph.edges + e.source.id == start_to_agent.source.id and e.target.id == start_to_agent.target.id + for e in graph.edges ) assert any( - e.source == agent_to_tool1.source and e.target == agent_to_tool1.target for e in graph.edges + e.source.id == agent_to_tool1.source.id and e.target.id == agent_to_tool1.target.id + for e in graph.edges ) assert any( - e.source == tool1_to_agent.source and e.target == tool1_to_agent.target for e in graph.edges + e.source.id == tool1_to_agent.source.id and e.target.id == tool1_to_agent.target.id + for e in graph.edges ) assert any( - e.source == agent_to_tool2.source and e.target == agent_to_tool2.target for e in graph.edges + e.source.id == agent_to_tool2.source.id and e.target.id == agent_to_tool2.target.id + for e in graph.edges ) assert any( - e.source == tool2_to_agent.source and e.target == tool2_to_agent.target for e in graph.edges + e.source.id == tool2_to_agent.source.id and e.target.id == tool2_to_agent.target.id + for e in graph.edges ) assert any( - e.source == agent_to_handoff.source and e.target == agent_to_handoff.target + e.source.id == agent_to_handoff.source.id and e.target.id == agent_to_handoff.target.id for e in graph.edges ) @@ -158,15 +169,18 @@ def test_recursive_graph_builder(mock_recursive_agents): assert graph.nodes["Agent2"].type == NodeType.AGENT # Check edges - agent1_to_agent2 = Edge("Agent1", "Agent2", EdgeType.HANDOFF) - agent2_to_agent1 = Edge("Agent2", "Agent1", EdgeType.HANDOFF) + agent1_node = graph.nodes["Agent1"] + agent2_node = graph.nodes["Agent2"] + + agent1_to_agent2 = Edge(agent1_node, agent2_node, EdgeType.HANDOFF) + agent2_to_agent1 = Edge(agent2_node, agent1_node, EdgeType.HANDOFF) assert any( - e.source == agent1_to_agent2.source and e.target == agent1_to_agent2.target + e.source.id == agent1_to_agent2.source.id and e.target.id == agent1_to_agent2.target.id for e in graph.edges ) assert any( - e.source == agent2_to_agent1.source and e.target == agent2_to_agent1.target + e.source.id == agent2_to_agent1.source.id and e.target.id == agent2_to_agent1.target.id for e in graph.edges ) @@ -180,16 +194,17 @@ def test_graph_validation(): graph.add_node(node1) graph.add_node(node2) - valid_edge = Edge("1", "2", EdgeType.TOOL) + valid_edge = Edge(node1, node2, EdgeType.TOOL) graph.add_edge(valid_edge) # Test adding edge with non-existent source - invalid_edge1 = Edge("3", "2", EdgeType.TOOL) + node3 = Node("3", "Node 3", NodeType.TOOL) + invalid_edge1 = Edge(node3, node2, EdgeType.TOOL) with pytest.raises(ValueError, match="Source node '3' does not exist in the graph"): graph.add_edge(invalid_edge1) # Test adding edge with non-existent target - invalid_edge2 = Edge("1", "3", EdgeType.TOOL) + invalid_edge2 = Edge(node1, node3, EdgeType.TOOL) with pytest.raises(ValueError, match="Target node '3' does not exist in the graph"): graph.add_edge(invalid_edge2) From 72adcc9c80a569dff0f95288efa8817a50019299 Mon Sep 17 00:00:00 2001 From: ashish-dahal Date: Wed, 9 Apr 2025 04:09:57 +0300 Subject: [PATCH 9/9] Refactor GraphBuilder to use instance equality checks instead of agent name --- src/agents/extensions/visualization.py | 62 +++-- tests/test_visualization.py | 316 +++++++++++++++---------- 2 files changed, 218 insertions(+), 160 deletions(-) diff --git a/src/agents/extensions/visualization.py b/src/agents/extensions/visualization.py index fc8aa8bb..c66e2499 100644 --- a/src/agents/extensions/visualization.py +++ b/src/agents/extensions/visualization.py @@ -10,7 +10,6 @@ from agents import Agent from agents.handoffs import Handoff -from agents.tool import Tool class NodeType(Enum): @@ -118,54 +117,44 @@ def _add_agent_nodes_and_edges( if agent is None: return + start_node = graph.get_node("__start__") + end_node = graph.get_node("__end__") + # Add agent node - graph.add_node(Node(agent.name, agent.name, NodeType.AGENT)) + agent_id = str(id(agent)) + agent_node = Node(agent_id, agent.name, NodeType.AGENT) + graph.add_node(agent_node) + self._visited.add(agent_id) # Connect start node if root agent if not parent: - graph.add_edge( - Edge(graph.get_node("__start__"), graph.get_node(agent.name), EdgeType.HANDOFF) - ) + graph.add_edge(Edge(start_node, agent_node, EdgeType.HANDOFF)) # Add tool nodes and edges for tool in agent.tools: - graph.add_node(Node(tool.name, tool.name, NodeType.TOOL)) - graph.add_edge( - Edge(graph.get_node(agent.name), graph.get_node(tool.name), EdgeType.TOOL) - ) - graph.add_edge( - Edge(graph.get_node(tool.name), graph.get_node(agent.name), EdgeType.TOOL) - ) - - # Add current agent's ID to visited set - self._visited.add(id(agent)) + tool_id = str(id(tool)) + tool_node = Node(tool_id, tool.name, NodeType.TOOL) + graph.add_node(tool_node) + graph.add_edge(Edge(agent_node, tool_node, EdgeType.TOOL)) + graph.add_edge(Edge(tool_node, agent_node, EdgeType.TOOL)) # Process handoffs - has_handoffs = False for handoff in agent.handoffs: - has_handoffs = True + handoff_id = str(id(handoff)) if isinstance(handoff, Handoff): - graph.add_node(Node(handoff.agent_name, handoff.agent_name, NodeType.HANDOFF)) - graph.add_edge( - Edge( - graph.get_node(agent.name), - graph.get_node(handoff.agent_name), - EdgeType.HANDOFF, - ) - ) + handoff_node = Node(handoff_id, handoff.agent_name, NodeType.HANDOFF) + graph.add_node(handoff_node) + graph.add_edge(Edge(agent_node, handoff_node, EdgeType.HANDOFF)) elif isinstance(handoff, Agent): - graph.add_node(Node(handoff.name, handoff.name, NodeType.AGENT)) - graph.add_edge( - Edge(graph.get_node(agent.name), graph.get_node(handoff.name), EdgeType.HANDOFF) - ) - if id(handoff) not in self._visited: + handoff_node = Node(handoff_id, handoff.name, NodeType.AGENT) + graph.add_node(handoff_node) + graph.add_edge(Edge(agent_node, handoff_node, EdgeType.HANDOFF)) + if handoff_id not in self._visited: self._add_agent_nodes_and_edges(handoff, agent, graph) # Connect to end node if no handoffs - if not has_handoffs and not isinstance(agent, Tool): - graph.add_edge( - Edge(graph.get_node(agent.name), graph.get_node("__end__"), EdgeType.HANDOFF) - ) + if not agent.handoffs: + graph.add_edge(Edge(agent_node, end_node, EdgeType.HANDOFF)) T = TypeVar("T") @@ -321,7 +310,10 @@ def _sanitize_id(self, id: str) -> str: class GraphView: def __init__( - self, rendered_graph: str, renderer: GraphRenderer, filename: Optional[str] = None + self, + rendered_graph: str, + renderer: GraphRenderer, + filename: Optional[str] = None, ): self.rendered_graph = rendered_graph self.renderer = renderer diff --git a/tests/test_visualization.py b/tests/test_visualization.py index 544006da..aa00c81c 100644 --- a/tests/test_visualization.py +++ b/tests/test_visualization.py @@ -21,32 +21,6 @@ ) from agents.handoffs import Handoff -# Common test graph elements -START_NODE = ( - '"__start__" [label="__start__", shape=ellipse, style=filled, ' - "fillcolor=lightblue, width=0.5, height=0.3];" -) -END_NODE = ( - '"__end__" [label="__end__", shape=ellipse, style=filled, ' - "fillcolor=lightblue, width=0.5, height=0.3];" -) -AGENT_NODE = ( - '"Agent1" [label="Agent1", shape=box, style=filled, ' - "fillcolor=lightyellow, width=1.5, height=0.8];" -) -TOOL1_NODE = ( - '"Tool1" [label="Tool1", shape=ellipse, style=filled, ' - "fillcolor=lightgreen, width=0.5, height=0.3];" -) -TOOL2_NODE = ( - '"Tool2" [label="Tool2", shape=ellipse, style=filled, ' - "fillcolor=lightgreen, width=0.5, height=0.3];" -) -HANDOFF_NODE = ( - '"Handoff1" [label="Handoff1", shape=box, style=filled, ' - "fillcolor=lightyellow, width=1.5, height=0.8];" -) - @pytest.fixture def mock_agent(): @@ -87,25 +61,23 @@ def test_graph_builder(mock_agent): # Check nodes assert "__start__" in graph.nodes assert "__end__" in graph.nodes - assert "Agent1" in graph.nodes - assert "Tool1" in graph.nodes - assert "Tool2" in graph.nodes - assert "Handoff1" in graph.nodes + + # Find nodes by name + agent_node = next(node for node in graph.nodes.values() if node.label == "Agent1") + tool1_node = next(node for node in graph.nodes.values() if node.label == "Tool1") + tool2_node = next(node for node in graph.nodes.values() if node.label == "Tool2") + handoff_node = next(node for node in graph.nodes.values() if node.label == "Handoff1") # Check node types assert graph.nodes["__start__"].type == NodeType.START assert graph.nodes["__end__"].type == NodeType.END - assert graph.nodes["Agent1"].type == NodeType.AGENT - assert graph.nodes["Tool1"].type == NodeType.TOOL - assert graph.nodes["Tool2"].type == NodeType.TOOL - assert graph.nodes["Handoff1"].type == NodeType.HANDOFF + assert agent_node.type == NodeType.AGENT + assert tool1_node.type == NodeType.TOOL + assert tool2_node.type == NodeType.TOOL + assert handoff_node.type == NodeType.HANDOFF # Check edges start_node = graph.nodes["__start__"] - agent_node = graph.nodes["Agent1"] - tool1_node = graph.nodes["Tool1"] - tool2_node = graph.nodes["Tool2"] - handoff_node = graph.nodes["Handoff1"] start_to_agent = Edge(start_node, agent_node, EdgeType.HANDOFF) agent_to_tool1 = Edge(agent_node, tool1_node, EdgeType.TOOL) @@ -150,28 +122,49 @@ def test_graphviz_renderer(mock_agent): assert "graph [splines=true];" in dot_code assert 'node [fontname="Arial"];' in dot_code assert "edge [penwidth=1.5];" in dot_code - assert START_NODE in dot_code - assert END_NODE in dot_code - assert AGENT_NODE in dot_code - assert TOOL1_NODE in dot_code - assert TOOL2_NODE in dot_code - assert HANDOFF_NODE in dot_code + + # Find nodes by name in rendered output + agent_node = next(node for node in graph.nodes.values() if node.label == "Agent1") + tool1_node = next(node for node in graph.nodes.values() if node.label == "Tool1") + tool2_node = next(node for node in graph.nodes.values() if node.label == "Tool2") + handoff_node = next(node for node in graph.nodes.values() if node.label == "Handoff1") + + # Check node definitions in dot code + agent_style = ( + f'"{agent_node.id}" [label="Agent1", shape=box, style=filled, ' + "fillcolor=lightyellow, width=1.5, height=0.8];" + ) + assert agent_style in dot_code + tool1_style = ( + f'"{tool1_node.id}" [label="Tool1", shape=ellipse, style=filled, ' + "fillcolor=lightgreen, width=0.5, height=0.3];" + ) + assert tool1_style in dot_code + tool2_style = ( + f'"{tool2_node.id}" [label="Tool2", shape=ellipse, style=filled, ' + "fillcolor=lightgreen, width=0.5, height=0.3];" + ) + assert tool2_style in dot_code + handoff_style = ( + f'"{handoff_node.id}" [label="Handoff1", shape=box, style=filled, ' + "fillcolor=lightyellow, width=1.5, height=0.8];" + ) + assert handoff_style in dot_code def test_recursive_graph_builder(mock_recursive_agents): builder = GraphBuilder() graph = builder.build_from_agent(mock_recursive_agents) - # Check nodes - assert "Agent1" in graph.nodes - assert "Agent2" in graph.nodes - assert graph.nodes["Agent1"].type == NodeType.AGENT - assert graph.nodes["Agent2"].type == NodeType.AGENT + # Find nodes by name + agent1_node = next(node for node in graph.nodes.values() if node.label == "Agent1") + agent2_node = next(node for node in graph.nodes.values() if node.label == "Agent2") - # Check edges - agent1_node = graph.nodes["Agent1"] - agent2_node = graph.nodes["Agent2"] + # Check node types + assert agent1_node.type == NodeType.AGENT + assert agent2_node.type == NodeType.AGENT + # Check edges agent1_to_agent2 = Edge(agent1_node, agent2_node, EdgeType.HANDOFF) agent2_to_agent1 = Edge(agent2_node, agent1_node, EdgeType.HANDOFF) @@ -264,22 +257,50 @@ def test_draw_graph_with_graphviz(mock_agent): assert "graph [splines=true];" in result.rendered_graph assert 'node [fontname="Arial"];' in result.rendered_graph assert "edge [penwidth=1.5];" in result.rendered_graph - assert START_NODE in result.rendered_graph - assert END_NODE in result.rendered_graph - assert AGENT_NODE in result.rendered_graph - assert TOOL1_NODE in result.rendered_graph - assert TOOL2_NODE in result.rendered_graph - assert HANDOFF_NODE in result.rendered_graph + + # Get the graph to find node IDs + builder = GraphBuilder() + graph = builder.build_from_agent(mock_agent) + agent_node = next(node for node in graph.nodes.values() if node.label == "Agent1") + tool1_node = next(node for node in graph.nodes.values() if node.label == "Tool1") + tool2_node = next(node for node in graph.nodes.values() if node.label == "Tool2") + handoff_node = next(node for node in graph.nodes.values() if node.label == "Handoff1") + + # Check node definitions in dot code + agent_style = ( + f'"{agent_node.id}" [label="Agent1", shape=box, style=filled, ' + "fillcolor=lightyellow, width=1.5, height=0.8];" + ) + assert agent_style in result.rendered_graph + tool1_style = ( + f'"{tool1_node.id}" [label="Tool1", shape=ellipse, style=filled, ' + "fillcolor=lightgreen, width=0.5, height=0.3];" + ) + assert tool1_style in result.rendered_graph + tool2_style = ( + f'"{tool2_node.id}" [label="Tool2", shape=ellipse, style=filled, ' + "fillcolor=lightgreen, width=0.5, height=0.3];" + ) + assert tool2_style in result.rendered_graph + handoff_style = ( + f'"{handoff_node.id}" [label="Handoff1", shape=box, style=filled, ' + "fillcolor=lightyellow, width=1.5, height=0.8];" + ) + assert handoff_style in result.rendered_graph def test_draw_graph_with_mermaid(mock_agent): result = draw_graph(mock_agent, renderer="mermaid") assert isinstance(result, GraphView) assert "graph TD" in result.rendered_graph - assert "__start__(__start__)" in result.rendered_graph - assert "style __start__ fill:lightblue" in result.rendered_graph - assert "Agent1[Agent1]" in result.rendered_graph - assert "style Agent1 fill:lightyellow" in result.rendered_graph + + # Get the graph to find node IDs + builder = GraphBuilder() + graph = builder.build_from_agent(mock_agent) + agent_node = next(node for node in graph.nodes.values() if node.label == "Agent1") + + assert f"{agent_node.id}[Agent1]" in result.rendered_graph + assert f"style {agent_node.id} fill:lightyellow" in result.rendered_graph def test_draw_graph_with_filename_graphviz(mock_agent, tmp_path): @@ -319,50 +340,118 @@ def test_get_main_graph(mock_agent): assert "graph [splines=true];" in result assert 'node [fontname="Arial"];' in result assert "edge [penwidth=1.5];" in result - assert START_NODE in result - assert END_NODE in result - assert AGENT_NODE in result - assert TOOL1_NODE in result - assert TOOL2_NODE in result - assert HANDOFF_NODE in result + + # Get the graph to find node IDs + builder = GraphBuilder() + graph = builder.build_from_agent(mock_agent) + agent_node = next(node for node in graph.nodes.values() if node.label == "Agent1") + tool1_node = next(node for node in graph.nodes.values() if node.label == "Tool1") + tool2_node = next(node for node in graph.nodes.values() if node.label == "Tool2") + handoff_node = next(node for node in graph.nodes.values() if node.label == "Handoff1") + + # Check node definitions in dot code + agent_style = ( + f'"{agent_node.id}" [label="Agent1", shape=box, style=filled, ' + "fillcolor=lightyellow, width=1.5, height=0.8];" + ) + assert agent_style in result + tool1_style = ( + f'"{tool1_node.id}" [label="Tool1", shape=ellipse, style=filled, ' + "fillcolor=lightgreen, width=0.5, height=0.3];" + ) + assert tool1_style in result + tool2_style = ( + f'"{tool2_node.id}" [label="Tool2", shape=ellipse, style=filled, ' + "fillcolor=lightgreen, width=0.5, height=0.3];" + ) + assert tool2_style in result + handoff_style = ( + f'"{handoff_node.id}" [label="Handoff1", shape=box, style=filled, ' + "fillcolor=lightyellow, width=1.5, height=0.8];" + ) + assert handoff_style in result def test_get_all_nodes(mock_agent): with pytest.warns(DeprecationWarning): result = get_all_nodes(mock_agent) - assert START_NODE in result - assert END_NODE in result - assert AGENT_NODE in result - assert TOOL1_NODE in result - assert TOOL2_NODE in result - assert HANDOFF_NODE in result + + # Get the graph to find node IDs + builder = GraphBuilder() + graph = builder.build_from_agent(mock_agent) + agent_node = next(node for node in graph.nodes.values() if node.label == "Agent1") + tool1_node = next(node for node in graph.nodes.values() if node.label == "Tool1") + tool2_node = next(node for node in graph.nodes.values() if node.label == "Tool2") + handoff_node = next(node for node in graph.nodes.values() if node.label == "Handoff1") + + # Check node definitions in dot code + agent_style = ( + f'"{agent_node.id}" [label="Agent1", shape=box, style=filled, ' + "fillcolor=lightyellow, width=1.5, height=0.8];" + ) + assert agent_style in result + tool1_style = ( + f'"{tool1_node.id}" [label="Tool1", shape=ellipse, style=filled, ' + "fillcolor=lightgreen, width=0.5, height=0.3];" + ) + assert tool1_style in result + tool2_style = ( + f'"{tool2_node.id}" [label="Tool2", shape=ellipse, style=filled, ' + "fillcolor=lightgreen, width=0.5, height=0.3];" + ) + assert tool2_style in result + handoff_style = ( + f'"{handoff_node.id}" [label="Handoff1", shape=box, style=filled, ' + "fillcolor=lightyellow, width=1.5, height=0.8];" + ) + assert handoff_style in result def test_get_all_edges(mock_agent): with pytest.warns(DeprecationWarning): result = get_all_edges(mock_agent) - assert '"__start__" -> "Agent1";' in result - assert '"Agent1" -> "Tool1" [style=dotted, penwidth=1.5];' in result - assert '"Tool1" -> "Agent1" [style=dotted, penwidth=1.5];' in result - assert '"Agent1" -> "Tool2" [style=dotted, penwidth=1.5];' in result - assert '"Tool2" -> "Agent1" [style=dotted, penwidth=1.5];' in result - assert '"Agent1" -> "Handoff1";' in result + + # Get the graph to find node IDs + builder = GraphBuilder() + graph = builder.build_from_agent(mock_agent) + start_node = graph.nodes["__start__"] + agent_node = next(node for node in graph.nodes.values() if node.label == "Agent1") + tool1_node = next(node for node in graph.nodes.values() if node.label == "Tool1") + tool2_node = next(node for node in graph.nodes.values() if node.label == "Tool2") + handoff_node = next(node for node in graph.nodes.values() if node.label == "Handoff1") + + # Check edge definitions + assert f'"{start_node.id}" -> "{agent_node.id}";' in result + assert f'"{agent_node.id}" -> "{tool1_node.id}" [style=dotted, penwidth=1.5];' in result + assert f'"{tool1_node.id}" -> "{agent_node.id}" [style=dotted, penwidth=1.5];' in result + assert f'"{agent_node.id}" -> "{tool2_node.id}" [style=dotted, penwidth=1.5];' in result + assert f'"{tool2_node.id}" -> "{agent_node.id}" [style=dotted, penwidth=1.5];' in result + assert f'"{agent_node.id}" -> "{handoff_node.id}";' in result def test_recursive_handoff_loop(mock_recursive_agents): with pytest.warns(DeprecationWarning): dot = get_main_graph(mock_recursive_agents) - assert ( - '"Agent1" [label="Agent1", shape=box, style=filled, ' - "fillcolor=lightyellow, width=1.5, height=0.8];" in dot + # Get the graph to find node IDs + builder = GraphBuilder() + graph = builder.build_from_agent(mock_recursive_agents) + agent1_node = next(node for node in graph.nodes.values() if node.label == "Agent1") + agent2_node = next(node for node in graph.nodes.values() if node.label == "Agent2") + + # Check node and edge definitions + agent1_style = ( + f'"{agent1_node.id}" [label="Agent1", shape=box, style=filled, ' + "fillcolor=lightyellow, width=1.5, height=0.8];" ) - assert ( - '"Agent2" [label="Agent2", shape=box, style=filled, ' - "fillcolor=lightyellow, width=1.5, height=0.8];" in dot + assert agent1_style in dot + agent2_style = ( + f'"{agent2_node.id}" [label="Agent2", shape=box, style=filled, ' + "fillcolor=lightyellow, width=1.5, height=0.8];" ) - assert '"Agent1" -> "Agent2";' in dot - assert '"Agent2" -> "Agent1";' in dot + assert agent2_style in dot + assert f'"{agent1_node.id}" -> "{agent2_node.id}";' in dot + assert f'"{agent2_node.id}" -> "{agent1_node.id}";' in dot def test_mermaid_renderer(mock_agent): @@ -374,41 +463,18 @@ def test_mermaid_renderer(mock_agent): # Test flowchart header assert "graph TD" in mermaid_code - # Test node rendering - assert "__start__(__start__)" in mermaid_code - assert "style __start__ fill:lightblue" in mermaid_code - assert "__end__(__end__)" in mermaid_code - assert "style __end__ fill:lightblue" in mermaid_code - assert "Agent1[Agent1]" in mermaid_code - assert "style Agent1 fill:lightyellow" in mermaid_code - assert "Tool1((Tool1))" in mermaid_code - assert "style Tool1 fill:lightgreen" in mermaid_code - assert "Tool2((Tool2))" in mermaid_code - assert "style Tool2 fill:lightgreen" in mermaid_code - assert "Handoff1[Handoff1]" in mermaid_code - assert "style Handoff1 fill:lightyellow" in mermaid_code - - # Test edge rendering - assert "__start__ --> Agent1" in mermaid_code - assert "Agent1 -.-> Tool1" in mermaid_code - assert "Tool1 -.-> Agent1" in mermaid_code - assert "Agent1 -.-> Tool2" in mermaid_code - assert "Tool2 -.-> Agent1" in mermaid_code - assert "Agent1 --> Handoff1" in mermaid_code - - -def test_mermaid_renderer_save(mock_agent, tmp_path): - renderer = MermaidRenderer() - graph = GraphBuilder().build_from_agent(mock_agent) - rendered = renderer.render(graph) - filename = tmp_path / "test_graph" + # Find nodes by name + agent_node = next(node for node in graph.nodes.values() if node.label == "Agent1") + tool1_node = next(node for node in graph.nodes.values() if node.label == "Tool1") + tool2_node = next(node for node in graph.nodes.values() if node.label == "Tool2") + handoff_node = next(node for node in graph.nodes.values() if node.label == "Handoff1") - mock_response = Mock() - mock_response.content = b"mock image data" - mock_response.raise_for_status = Mock() - - with patch("requests.get", return_value=mock_response): - renderer.save(rendered, str(filename)) - assert (tmp_path / "test_graph.png").exists() - with open(tmp_path / "test_graph.png", "rb") as f: - assert f.read() == b"mock image data" + # Test node rendering + assert f"{agent_node.id}[Agent1]" in mermaid_code + assert f"style {agent_node.id} fill:lightyellow" in mermaid_code + assert f"{tool1_node.id}((Tool1))" in mermaid_code + assert f"style {tool1_node.id} fill:lightgreen" in mermaid_code + assert f"{tool2_node.id}((Tool2))" in mermaid_code + assert f"style {tool2_node.id} fill:lightgreen" in mermaid_code + assert f"{handoff_node.id}[Handoff1]" in mermaid_code + assert f"style {handoff_node.id} fill:lightyellow" in mermaid_code