Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,7 @@ async def save_state(self) -> Mapping[str, Any]:
"remaining": {target: dict(counter) for target, counter in self._remaining.items()},
"enqueued_any": dict(self._enqueued_any),
"ready": list(self._ready),
"triggered_activation_groups": {target: list(groups) for target, groups in self._triggered_activation_groups.items()},
}
return state

Expand All @@ -527,6 +528,7 @@ async def load_state(self, state: Mapping[str, Any]) -> None:
self._remaining = {target: Counter(groups) for target, groups in state["remaining"].items()}
self._enqueued_any = state["enqueued_any"]
self._ready = deque(state["ready"])
self._triggered_activation_groups = {target: set(groups) for target, groups in state["triggered_activation_groups"].items()}

async def reset(self) -> None:
"""Reset execution state to the start of the graph."""
Expand Down
61 changes: 60 additions & 1 deletion python/packages/autogen-agentchat/tests/test_group_chat_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
PerSourceFilter,
)
from autogen_agentchat.base import Response, TaskResult
from autogen_agentchat.conditions import MaxMessageTermination, SourceMatchTermination
from autogen_agentchat.conditions import MaxMessageTermination, SourceMatchTermination, StopMessageTermination
from autogen_agentchat.messages import BaseChatMessage, ChatMessage, MessageFactory, StopMessage, TextMessage
from autogen_agentchat.teams import (
DiGraphBuilder,
Expand Down Expand Up @@ -1766,3 +1766,62 @@ async def test_digraph_group_chat_resume_with_termination_condition(runtime: Age
assert agent_a.total_messages == 1 # Still only ran once
assert agent_b.total_messages == 1 # Still only ran once
assert agent_c.total_messages == 1 # Now ran once

@pytest.mark.asyncio
async def test_digraph_group_chat_resumes_from_state_with_triggered_activation_groups(runtime: AgentRuntime | None) -> None:
class _NoOpAgent(BaseChatAgent):
def __init__(self, name: str, description: str = "", target: str = "", stop_when: str = "") -> None:
super().__init__(name, description)
self._target = target
self._stop_when = stop_when

@property
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
return (TextMessage,)

async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:

# Force the stop message when specific string is matched
if self.name == self._stop_when:
return Response(chat_message=StopMessage(content=self.name, source=self.name))

return Response(chat_message=TextMessage(content=self._target, source=self.name))

async def on_reset(self, cancellation_token: CancellationToken) -> None:
pass

def create_flow() -> GraphFlow:
agent_a = _NoOpAgent(name="A", target="B", stop_when="A")
agent_b = _NoOpAgent(name="B", target="A")

builder = DiGraphBuilder()
builder.add_node(agent_a).add_node(agent_b)

# Agent A will loop forever
builder.add_edge(agent_a, agent_a, "A", "loopback")
builder.add_edge(agent_a, agent_b, "B")

builder.set_entry_point(agent_a)

return GraphFlow(
participants=builder.get_participants(),
graph=builder.build(),
runtime=runtime,
termination_condition=StopMessageTermination(),
)

# Run the graph flow until termination condition is reached
team_one = create_flow()
result_one: TaskResult = await team_one.run(task="Start")
assert result_one.stop_reason == "Stop message received"

# Export state.
state = await team_one.save_state()

# Load team 1's state into team 2.
team_two = create_flow()
await team_two.load_state(state)

# Team 2 should resume and immediately hit the stop condition again
result_two: TaskResult = await team_two.run(task="Continue")
assert result_two.stop_reason == "Stop message received"