Skip to content

Commit

Permalink
refactor: clear distinction between message and edge handlers
Browse files Browse the repository at this point in the history
  • Loading branch information
Archento committed Apr 10, 2024
1 parent c9c38a7 commit 01267a6
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -144,13 +144,13 @@ async def persisting_function(
ctx.logger.info("I was not overwritten, hehe.")


init_session.set_default_behaviour(InitiateChitChatDialogue, start_chitchat)
start_dialogue.set_default_behaviour(AcceptChitChatDialogue, accept_chitchat)
# cont_dialogue.set_default_behaviour(ChitChatDialogueMessage, default, persist=False)
cont_dialogue.set_default_behaviour(
ChitChatDialogueMessage, persisting_function, persist=True
)
end_session.set_default_behaviour(ConcludeChitChatDialogue, conclude_chitchat)
init_session.set_message_handler(InitiateChitChatDialogue, start_chitchat)
start_dialogue.set_message_handler(AcceptChitChatDialogue, accept_chitchat)

cont_dialogue.set_message_handler(ChitChatDialogueMessage, default)
cont_dialogue.set_edge_handler(ChitChatDialogueMessage, persisting_function)

end_session.set_message_handler(ConcludeChitChatDialogue, conclude_chitchat)


class ChitChatDialogue(Dialogue):
Expand Down
75 changes: 51 additions & 24 deletions python/src/uagents/experimental/dialogues/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,11 @@ def __init__(
self.description = description
self.parent = parent
self.child = child
self.starter = False
self.ender = False
self._model = None
self._func: Optional[tuple[MessageCallback, bool]] = None
self.starter: bool = False
self.ender: bool = False
self._model: Type[Model] = None
self._func: Optional[MessageCallback] = None
self._efunc: Optional[MessageCallback] = None

@property
def model(self) -> Optional[Type[Model]]:
Expand All @@ -63,21 +64,39 @@ def model(self, model: Type[Model]) -> None:
self._model = model

@property
def func(self) -> MessageCallback:
def func(self) -> Optional[MessageCallback]:
"""The message handler that is associated with the edge."""
return self._func

def set_default_behaviour(
self, model: Type[Model], func: MessageCallback, persist: bool = False
):
@func.setter
def func(self, func: MessageCallback) -> None:
"""Set the message handler that will be called when a message is received."""
self._func = func

@property
def efunc(self) -> MessageCallback:
"""The edge handler that is associated with the edge."""
return self._efunc

def set_edge_handler(self, model: Type[Model], func: MessageCallback):
"""
Set the edge handler that will be called when a message is received
This handler can not be overwritten by a decorator.
"""
if self._model and self._model is not model:
raise ValueError("Functionality already set with a different model!")
self._model = model
self._efunc = func

def set_message_handler(self, model: Type[Model], func: MessageCallback):
"""
Set the default behaviour for the edge that will be overwritten if
Set the default message handler for the edge that will be overwritten if
a decorator defines a new function to be called.
"""
if self._model:
raise ValueError("Functionality already set for edge!")
if self._model and self._model is not model:
raise ValueError("Functionality already set with a different model!")
self._model = model
self._func = func, persist
self._func = func


class Dialogue(Protocol):
Expand Down Expand Up @@ -349,11 +368,27 @@ def is_finished(self, session_id: UUID) -> bool:
"""
return self.is_ender(self.get_current_state(session_id))

def _build_function_handler(self, edge: Edge) -> MessageCallback:
"""Build the function handler for a message."""

@functools.wraps(edge.func)
async def handler(ctx: Context, sender: str, message: Any):
if edge.efunc:
await edge.efunc(ctx, sender, message)
return await edge.func(ctx, sender, message)

return handler

def _auto_add_message_handler(self) -> None:
"""Automatically add message handlers for edges with models."""
for edge in self._edges:
if edge.model and edge.func:
self._add_message_handler(edge.model, edge.func[0], None, False)
self._add_message_handler(
edge.model,
self._build_function_handler(edge),
None, # no replies
False, # only verified
)

def update_state(self, digest: str, session_id: UUID) -> None:
"""
Expand Down Expand Up @@ -511,18 +546,10 @@ def _on_state_transition(self, edge_name: str, model: Type[Model]):
if edge_name not in self._digest_by_edge:
raise ValueError("Edge does not exist in the dialogue!")

persisting_function = None
edge = self.get_edge(edge_name)
if edge.func[1]:
persisting_function = edge.func[0]

def decorator_on_state_transition(func: MessageCallback):
@functools.wraps(func)
async def handler(*args, **kwargs):
if persisting_function:
await persisting_function(*args, **kwargs)
return await func(*args, **kwargs)

edge = self.get_edge(edge_name)
edge.func = func
handler = self._build_function_handler(edge)
self._update_transition_model(edge, model)
self._add_message_handler(model, handler, None, False)
return handler
Expand Down

0 comments on commit 01267a6

Please sign in to comment.