Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(core): add capability to define persistent functionality to a dialogue edge #310

Merged
merged 4 commits into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
the messages that are expected to be exchanged.
"""

from typing import Type
from warnings import warn

from uagents import Model
Expand Down Expand Up @@ -95,7 +96,7 @@ class RejectChitChatDialogue(Model):
async def start_chitchat(
ctx: Context,
sender: str,
_msg: type[Model],
_msg: Type[Model],
):
ctx.logger.info(f"Received init message from {sender}. Accepting Dialogue.")
await ctx.send(sender, AcceptChitChatDialogue())
Expand All @@ -104,7 +105,7 @@ async def start_chitchat(
async def accept_chitchat(
ctx: Context,
sender: str,
_msg: type[Model],
_msg: Type[Model],
):
ctx.logger.info(
f"Dialogue session with {sender} was accepted. "
Expand All @@ -116,7 +117,7 @@ async def accept_chitchat(
async def conclude_chitchat(
ctx: Context,
sender: str,
_msg: type[Model],
_msg: Type[Model],
):
ctx.logger.info(f"Received conclude message from: {sender}; accessing history:")
ctx.logger.info(ctx.dialogue)
Expand All @@ -125,7 +126,7 @@ async def conclude_chitchat(
async def default(
_ctx: Context,
_sender: str,
_msg: type[Model],
_msg: Type[Model],
):
warn(
"There is no handler for this message, please add your own logic by "
Expand All @@ -135,9 +136,20 @@ async def default(
)


async def persisting_function(
ctx: Context,
_sender: str,
_msg: Type[Model],
):
ctx.logger.info("I was not overwritten, hehe.")


init_session.set_default_behaviour(InitiateChitChatDialogue, start_chitchat)
start_dialogue.set_default_behaviour(AcceptChitChatDialogue, accept_chitchat)
cont_dialogue.set_default_behaviour(ChitChatDialogueMessage, default)
# cont_dialogue.set_default_behaviour(ChitChatDialogueMessage, default, persist=False)
cont_dialogue.set_default_behaviour(
ChitChatDialogueMessage, persisting_function, persist=True
)
end_session.set_default_behaviour(ConcludeChitChatDialogue, conclude_chitchat)


Expand Down
36 changes: 24 additions & 12 deletions python/src/uagents/experimental/dialogues/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from uagents import Context, Model, Protocol
from uagents.storage import KeyValueStore

DEFAULT_SESSION_TIMEOUT_IN_SECONDS = 100
DEFAULT_SESSION_TIMEOUT_IN_SECONDS = 60
TARGET_UUID_VERSION = 4

JsonStr = str
Expand Down Expand Up @@ -50,7 +50,7 @@ def __init__(
self.starter = False
self.ender = False
self._model = None
self._func = None
self._func = Optional[tuple[MessageCallback, bool]]
jrriehl marked this conversation as resolved.
Show resolved Hide resolved

@property
def model(self) -> Optional[Type[Model]]:
Expand All @@ -67,10 +67,17 @@ def func(self) -> MessageCallback:
"""The message handler that is associated with the edge."""
return self._func

def set_default_behaviour(self, model: Type[Model], func: MessageCallback):
"""Set the default behaviour for the edge."""
def set_default_behaviour(
self, model: Type[Model], func: MessageCallback, persist: bool = False
):
"""
Set the default behaviour for the edge that will be overwritten if
a decorator defines a new function to be called.
"""
if self._model:
raise ValueError("Functionality already set for edge!")
self._model = model
self._func = func
self._func = func, persist


class Dialogue(Protocol):
Expand Down Expand Up @@ -326,7 +333,7 @@ def is_starter(self, digest: str) -> bool:

def is_ender(self, digest: str) -> bool:
"""
Return True if the digest is the last message of the dialogue.
Return True if the digest is one of the last messages of the dialogue.
False otherwise.
"""
return digest in [self._digest_by_edge[edge] for edge in self._ender]
Expand All @@ -346,7 +353,7 @@ def _auto_add_message_handler(self) -> None:
"""Automatically add message handlers for edges with models."""
for edge in self._edges:
if edge.model and edge.func:
self._add_message_handler(edge.model, edge.func, None, False)
self._add_message_handler(edge.model, edge.func[0], None, False)

def update_state(self, digest: str, session_id: UUID) -> None:
"""
Expand Down Expand Up @@ -504,17 +511,22 @@ def _on_state_transition(self, edge_name: str, model: Type[Model]):
if edge_name not in self._digest_by_edge:
raise ValueError("Edge does not exist in the dialogue!")

persisting_function = None
edge = self.get_edge(edge_name)
if edge.func[1]:
persisting_function = edge.func[0]

def decorator_on_state_transition(func: MessageCallback):
@functools.wraps(func)
def handler(*args, **kwargs):
return func(*args, **kwargs)
async def handler(*args, **kwargs):
if persisting_function:
await persisting_function(*args, **kwargs)
return await func(*args, **kwargs)

edge = self.get_edge(edge_name)
self._update_transition_model(edge, model)
self._add_message_handler(model, func, None, False)
self._add_message_handler(model, handler, None, False)
return handler

# NOTE: recalculate manifest after each update and re-register /w agent
return decorator_on_state_transition

@property
Expand Down
Loading