Skip to content

Commit

Permalink
programmatically identify starting and ending edges and final state
Browse files Browse the repository at this point in the history
  • Loading branch information
Archento committed Apr 2, 2024
1 parent 58e3a5a commit e5d4e40
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
"This is the default state of the dialogue. Every session starts in "
"this state and is automatically updated once the dialogue starts."
),
starter=True,
initial=True,
)
init_state = Node(
name="Initiated",
Expand Down
29 changes: 24 additions & 5 deletions python/src/uagents/experimental/dialogues/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@ def __init__(
self,
name: str,
description: str,
starter: Optional[bool] = None,
initial: bool = False,
) -> None:
self.name = name
self.description = description
self.starter = starter
self.initial = initial
self.final = False


class Edge:
Expand All @@ -50,6 +51,8 @@ def __init__(
self.child = child
self._model = model
self._func = func
self.starter = False
self.ender = False

@property
def model(self) -> Optional[Type[Model]]:
Expand Down Expand Up @@ -227,6 +230,8 @@ def get_overview(self) -> Dict:
"parent": edge.parent.name if edge.parent else None,
"child": edge.child.name,
"model": edge.model.__name__ if edge.model else None,
"starter": edge.starter,
"ender": edge.ender,
}
for edge in self._edges
],
Expand Down Expand Up @@ -271,7 +276,7 @@ def _build_rules(self) -> Dict[str, List[str]]:

def _build_starter(self) -> str:
"""Build the starting message of the dialogue."""
starter_nodes = list(filter(lambda n: n.starter, self._nodes))
starter_nodes = list(filter(lambda n: n.initial, self._nodes))
# check if starter property has been set and if there is only one
if len(starter_nodes) > 1:
raise ValueError("Dialogue has more than one entry point!")
Expand All @@ -282,6 +287,7 @@ def _build_starter(self) -> str:
# validate if the graph is correct
starters = list(filter(lambda e: e.parent is starter_nodes[0], self._edges))
if starters:
self._edges[self._edges.index(starters[0])].starter = True
return starters[0].name
if starter_nodes and edges_without_entry:
warnings.warn(
Expand All @@ -293,12 +299,23 @@ def _build_starter(self) -> str:
if len(edges_without_entry) > 1:
raise ValueError("Dialogue has more than one entry point!")
if edges_without_entry:
self._edges[self._edges.index(edges_without_entry[0])].starter = True
return edges_without_entry[0].name
raise ValueError("Dialogue has no entry point!")

def _build_ender(self) -> set[str]:
"""Build the last message(s) of the dialogue."""
return set(edge for edge in self._rules if not self._rules[edge])
"""Build the last message(s) of the dialogue and set final state."""
for node, edges in self._graph.items():
if not edges:
self._nodes[
self._nodes.index(next(n for n in self._nodes if n.name == node))
].final = True
enders = set()
for edge in self._edges:
if edge.child.final:
enders.add(edge.name)
edge.ender = True
return enders

def is_starter(self, digest: str) -> bool:
"""
Expand Down Expand Up @@ -541,6 +558,8 @@ def manifest(self) -> Dict[str, Any]:
"parent": edge.parent.name if edge.parent else None,
"child": edge.child.name,
"model": edge.model.__name__ if edge.model else None,
"starter": edge.starter,
"ender": edge.ender,
}
for edge in self._edges
],
Expand Down

0 comments on commit e5d4e40

Please sign in to comment.