Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 21 additions & 18 deletions python/packages/core/agent_framework/_workflows/_edge.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@

import logging
import uuid
from collections.abc import Callable, Sequence
from collections.abc import Callable, MutableMapping, Sequence
from dataclasses import dataclass, field
from typing import Any, ClassVar

from .._serialization import SerializationMixin
from ._const import INTERNAL_SOURCE_ID
from ._executor import Executor
from ._model_utils import DictConvertible, encode_value
from ._model_utils import encode_value

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -62,7 +63,7 @@ def _raise(*_: Any, **__: Any) -> Any:


@dataclass(init=False)
class Edge(DictConvertible):
class Edge(SerializationMixin):
"""Model a directed, optionally-conditional hand-off between two executors.

Each `Edge` captures the minimal metadata required to move a message from
Expand Down Expand Up @@ -164,7 +165,7 @@ def should_route(self, data: Any) -> bool:
return True
return self._condition(data)

def to_dict(self) -> dict[str, Any]:
def to_dict(self, **kwargs: Any) -> dict[str, Any]:
"""Produce a JSON-serialisable view of the edge metadata.

The representation includes the source and target executor identifiers
Expand All @@ -184,7 +185,7 @@ def to_dict(self) -> dict[str, Any]:
return payload

@classmethod
def from_dict(cls, data: dict[str, Any]) -> "Edge":
def from_dict(cls, data: dict[str, Any] | MutableMapping[str, Any], /, **kwargs: Any) -> "Edge":
"""Reconstruct an `Edge` from its serialised dictionary form.

The deserialised edge will lack the executable predicate because we do
Expand Down Expand Up @@ -259,7 +260,7 @@ def __init__(self) -> None:


@dataclass(init=False)
class EdgeGroup(DictConvertible):
class EdgeGroup(SerializationMixin):
"""Bundle edges that share a common routing semantics under a single id.

The workflow runtime manipulates `EdgeGroup` instances rather than raw
Expand Down Expand Up @@ -342,7 +343,7 @@ def target_executor_ids(self) -> list[str]:
"""
return list(dict.fromkeys(edge.target_id for edge in self.edges))

def to_dict(self) -> dict[str, Any]:
def to_dict(self, **kwargs: Any) -> dict[str, Any]:
"""Serialise the group metadata and contained edges into primitives.

The payload captures each edge through its own `to_dict` call, enabling
Expand Down Expand Up @@ -385,7 +386,7 @@ class CustomGroup(EdgeGroup):
return subclass

@classmethod
def from_dict(cls, data: dict[str, Any]) -> "EdgeGroup":
def from_dict(cls, data: dict[str, Any] | MutableMapping[str, Any], /, **kwargs: Any) -> "EdgeGroup":
"""Hydrate the correct `EdgeGroup` subclass from serialised state.

The method inspects the `type` field, allocates the corresponding class
Expand Down Expand Up @@ -556,7 +557,7 @@ def selection_func(self) -> Callable[[Any, list[str]], list[str]] | None:
"""
return self._selection_func

def to_dict(self) -> dict[str, Any]:
def to_dict(self, **kwargs: Any) -> dict[str, Any]:
"""Serialise the fan-out group while preserving selection metadata.

In addition to the base `EdgeGroup` payload we embed the human-friendly
Expand All @@ -569,7 +570,7 @@ def to_dict(self) -> dict[str, Any]:
snapshot = group.to_dict()
assert snapshot["selection_func_name"] == "<lambda>"
"""
payload = super().to_dict()
payload = super().to_dict(**kwargs)
payload["selection_func_name"] = self.selection_func_name
return payload

Expand Down Expand Up @@ -610,7 +611,7 @@ def __init__(self, source_ids: Sequence[str], target_id: str, *, id: str | None


@dataclass(init=False)
class SwitchCaseEdgeGroupCase(DictConvertible):
class SwitchCaseEdgeGroupCase(SerializationMixin):
"""Persistable description of a single conditional branch in a switch-case.

Unlike the runtime `Case` object this serialisable variant stores only the
Expand Down Expand Up @@ -684,7 +685,7 @@ def condition(self) -> Callable[[Any], bool]:
"""
return self._condition

def to_dict(self) -> dict[str, Any]:
def to_dict(self, **kwargs: Any) -> dict[str, Any]:
"""Serialise the case metadata without the executable predicate.

Examples:
Expand All @@ -699,7 +700,7 @@ def to_dict(self) -> dict[str, Any]:
return payload

@classmethod
def from_dict(cls, data: dict[str, Any]) -> "SwitchCaseEdgeGroupCase":
def from_dict(cls, data: dict[str, Any] | MutableMapping[str, Any], /, **kwargs: Any) -> "SwitchCaseEdgeGroupCase":
"""Instantiate a case from its serialised dictionary payload.

Examples:
Expand All @@ -717,7 +718,7 @@ def from_dict(cls, data: dict[str, Any]) -> "SwitchCaseEdgeGroupCase":


@dataclass(init=False)
class SwitchCaseEdgeGroupDefault(DictConvertible):
class SwitchCaseEdgeGroupDefault(SerializationMixin):
"""Persistable descriptor for the fallback branch of a switch-case group.

The default branch is guaranteed to exist and is invoked when every other
Expand All @@ -741,7 +742,7 @@ def __init__(self, target_id: str) -> None:
self.target_id = target_id
self.type = "Default"

def to_dict(self) -> dict[str, Any]:
def to_dict(self, **kwargs: Any) -> dict[str, Any]:
"""Serialise the default branch metadata for persistence or logging.

Examples:
Expand All @@ -753,7 +754,9 @@ def to_dict(self) -> dict[str, Any]:
return {"target_id": self.target_id, "type": self.type}

@classmethod
def from_dict(cls, data: dict[str, Any]) -> "SwitchCaseEdgeGroupDefault":
def from_dict(
cls, data: dict[str, Any] | MutableMapping[str, Any], /, **kwargs: Any
) -> "SwitchCaseEdgeGroupDefault":
"""Recreate the default branch from its persisted form.

Examples:
Expand Down Expand Up @@ -844,7 +847,7 @@ def selection_func(message: Any, targets: list[str]) -> list[str]:
self.selection_func_name = None # type: ignore[attr-defined]
self.cases = list(cases)

def to_dict(self) -> dict[str, Any]:
def to_dict(self, **kwargs: Any) -> dict[str, Any]:
"""Serialise the switch-case group, capturing all case descriptors.

Each case is converted using `encode_value` to respect dataclass
Expand All @@ -863,7 +866,7 @@ def to_dict(self) -> dict[str, Any]:
snapshot = group.to_dict()
assert len(snapshot["cases"]) == 2
"""
payload = super().to_dict()
payload = super().to_dict(**kwargs)
payload["cases"] = [encode_value(case) for case in self.cases]
return payload

Expand Down
6 changes: 3 additions & 3 deletions python/packages/core/agent_framework/_workflows/_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from collections.abc import Awaitable, Callable
from typing import Any, TypeVar

from .._serialization import SerializationMixin
from ..observability import create_processing_span
from ._events import (
ExecutorCompletedEvent,
Expand All @@ -15,7 +16,6 @@
WorkflowErrorDetails,
_framework_event_origin, # type: ignore[reportPrivateUsage]
)
from ._model_utils import DictConvertible
from ._request_info_mixin import RequestInfoMixin
from ._runner_context import Message, MessageType, RunnerContext
from ._shared_state import SharedState
Expand All @@ -26,7 +26,7 @@


# region Executor
class Executor(RequestInfoMixin, DictConvertible):
class Executor(RequestInfoMixin, SerializationMixin):
"""Base class for all workflow executors that process messages and perform computations.

## Overview
Expand Down Expand Up @@ -422,7 +422,7 @@ def workflow_output_types(self) -> list[type[Any]]:

return list(output_types)

def to_dict(self) -> dict[str, Any]:
def to_dict(self, **kwargs: Any) -> dict[str, Any]:
"""Serialize executor definition for workflow topology export."""
return {"id": self.id, "type": self.type}

Expand Down
46 changes: 28 additions & 18 deletions python/packages/core/agent_framework/_workflows/_magentic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import re
import sys
from abc import ABC, abstractmethod
from collections.abc import AsyncIterable, Sequence
from collections.abc import AsyncIterable, MutableMapping, Sequence
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Protocol, TypeVar, Union, cast
Expand All @@ -23,6 +23,7 @@
FunctionResultContent,
Role,
)
from agent_framework._serialization import SerializationMixin

from ._base_group_chat_orchestrator import BaseGroupChatOrchestrator
from ._checkpoint import CheckpointStorage, WorkflowCheckpoint
Expand All @@ -38,7 +39,7 @@
group_chat_orchestrator,
)
from ._message_utils import normalize_messages_input
from ._model_utils import DictConvertible, encode_value
from ._model_utils import encode_value
from ._participant_utils import GroupChatParticipantSpec, participant_description
from ._request_info_mixin import response_handler
from ._workflow import Workflow, WorkflowRunResult
Expand Down Expand Up @@ -328,7 +329,7 @@ def _new_chat_message_list() -> list[ChatMessage]:


@dataclass
class _MagenticStartMessage(DictConvertible):
class _MagenticStartMessage(SerializationMixin):
"""Internal: A message to start a magentic workflow."""

messages: list[ChatMessage] = field(default_factory=_new_chat_message_list)
Expand Down Expand Up @@ -356,15 +357,15 @@ def from_string(cls, task_text: str) -> "_MagenticStartMessage":
"""Create a MagenticStartMessage from a simple string."""
return cls(task_text)

def to_dict(self) -> dict[str, Any]:
def to_dict(self, **kwargs: Any) -> dict[str, Any]:
"""Create a dict representation of the message."""
return {
"messages": [message.to_dict() for message in self.messages],
"task": self.task.to_dict(),
}

@classmethod
def from_dict(cls, data: dict[str, Any]) -> "_MagenticStartMessage":
def from_dict(cls, data: dict[str, Any] | MutableMapping[str, Any], /, **kwargs: Any) -> "_MagenticStartMessage":
"""Create from a dict."""
if "messages" in data:
raw_messages = data["messages"]
Expand Down Expand Up @@ -446,43 +447,45 @@ class _MagenticPlanReviewReply:


@dataclass
class _MagenticTaskLedger(DictConvertible):
class _MagenticTaskLedger(SerializationMixin):
"""Internal: Task ledger for the Standard Magentic manager."""

facts: ChatMessage
plan: ChatMessage

def to_dict(self) -> dict[str, Any]:
def to_dict(self, **kwargs: Any) -> dict[str, Any]:
return {"facts": _message_to_payload(self.facts), "plan": _message_to_payload(self.plan)}

@classmethod
def from_dict(cls, data: dict[str, Any]) -> "_MagenticTaskLedger":
def from_dict(cls, data: dict[str, Any] | MutableMapping[str, Any], /, **kwargs: Any) -> "_MagenticTaskLedger":
return cls(
facts=_message_from_payload(data.get("facts")),
plan=_message_from_payload(data.get("plan")),
)


@dataclass
class _MagenticProgressLedgerItem(DictConvertible):
class _MagenticProgressLedgerItem(SerializationMixin):
"""Internal: A progress ledger item."""

reason: str
answer: str | bool

def to_dict(self) -> dict[str, Any]:
def to_dict(self, **kwargs: Any) -> dict[str, Any]:
return {"reason": self.reason, "answer": self.answer}

@classmethod
def from_dict(cls, data: dict[str, Any]) -> "_MagenticProgressLedgerItem":
def from_dict(
cls, data: dict[str, Any] | MutableMapping[str, Any], /, **kwargs: Any
) -> "_MagenticProgressLedgerItem":
answer_value = data.get("answer")
if not isinstance(answer_value, (str, bool)):
answer_value = "" # Default to empty string if not str or bool
return cls(reason=data.get("reason", ""), answer=answer_value)


@dataclass
class _MagenticProgressLedger(DictConvertible):
class _MagenticProgressLedger(SerializationMixin):
"""Internal: A progress ledger for tracking workflow progress."""

is_request_satisfied: _MagenticProgressLedgerItem
Expand All @@ -491,7 +494,7 @@ class _MagenticProgressLedger(DictConvertible):
next_speaker: _MagenticProgressLedgerItem
instruction_or_question: _MagenticProgressLedgerItem

def to_dict(self) -> dict[str, Any]:
def to_dict(self, **kwargs: Any) -> dict[str, Any]:
return {
"is_request_satisfied": self.is_request_satisfied.to_dict(),
"is_in_loop": self.is_in_loop.to_dict(),
Expand All @@ -501,7 +504,7 @@ def to_dict(self) -> dict[str, Any]:
}

@classmethod
def from_dict(cls, data: dict[str, Any]) -> "_MagenticProgressLedger":
def from_dict(cls, data: dict[str, Any] | MutableMapping[str, Any], /, **kwargs: Any) -> "_MagenticProgressLedger":
return cls(
is_request_satisfied=_MagenticProgressLedgerItem.from_dict(data.get("is_request_satisfied", {})),
is_in_loop=_MagenticProgressLedgerItem.from_dict(data.get("is_in_loop", {})),
Expand All @@ -512,7 +515,7 @@ def from_dict(cls, data: dict[str, Any]) -> "_MagenticProgressLedger":


@dataclass
class MagenticContext(DictConvertible):
class MagenticContext(SerializationMixin):
"""Context for the Magentic manager."""

task: ChatMessage
Expand All @@ -522,7 +525,7 @@ class MagenticContext(DictConvertible):
stall_count: int = 0
reset_count: int = 0

def to_dict(self) -> dict[str, Any]:
def to_dict(self, **kwargs: Any) -> dict[str, Any]:
return {
"task": _message_to_payload(self.task),
"chat_history": [_message_to_payload(msg) for msg in self.chat_history],
Expand All @@ -533,7 +536,7 @@ def to_dict(self) -> dict[str, Any]:
}

@classmethod
def from_dict(cls, data: dict[str, Any]) -> "MagenticContext":
def from_dict(cls, data: dict[str, Any] | MutableMapping[str, Any], /, **kwargs: Any) -> "MagenticContext":
chat_history_payload = data.get("chat_history", [])
history: list[ChatMessage] = []
for item in chat_history_payload:
Expand All @@ -557,6 +560,12 @@ def reset(self) -> None:
self.stall_count = 0
self.reset_count += 1

def clone(self, *, deep: bool = True) -> Self:
"""Create a copy of this context."""
import copy

return copy.deepcopy(self) if deep else copy.copy(self) # type: ignore[return-value]


# endregion Messages and Types

Expand Down Expand Up @@ -2418,7 +2427,8 @@ async def _validate_checkpoint_participants(
if not isinstance(orchestrator_state, dict):
return

context_payload = orchestrator_state.get("magentic_context")
orchestrator_state_dict = cast(dict[str, Any], orchestrator_state)
context_payload = orchestrator_state_dict.get("magentic_context")
if not isinstance(context_payload, dict):
return

Expand Down
Loading
Loading