-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
50690ae
commit 0aa62bf
Showing
3 changed files
with
223 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
from collections import defaultdict | ||
from dataclasses import dataclass | ||
|
||
import rustworkx as rx | ||
|
||
from armonik.client import ArmoniKTasks | ||
from armonik.common import Filter | ||
from grpc import Channel | ||
|
||
|
||
@dataclass
class ArmoniKGraphAttr:
    """Metadata record attached to an ArmoniK task graph.

    Attributes
    ----------
    name : str
        Human-readable name of the graph.
    description : str
        Free-text description of the graph.
    task_filter : armonik.common.Filter
        Filter selecting the tasks that make up the graph.
    """

    # Human-readable name of the graph.
    name: str
    # Free-text description of the graph.
    description: str
    # Filter selecting the tasks that make up the graph.
    task_filter: Filter
|
||
|
||
class ArmoniKGraph:
    """A class to represent workloads on ArmoniK.

    The execution of a program corresponds to a graph, each node of which is a task. This class
    provides this representation and makes it possible to analyse an execution.

    The flexibility of ArmoniK means that a workload can share its session with other workloads,
    or be dispersed between several sessions or partitions. To identify a workload, and therefore
    the corresponding graph, the user must provide a filter on the tasks within a cluster.

    In this graph, only the tasks are represented. Results are not included in the graph.
    Each node contains an 'armonik.common.Task' object which contains the task metadata up to
    date at the time the graph is loaded.

    Parameters
    ----------
    task_filter : armonik.common.Filter
        A filter identifying the tasks belonging to the graph within the ArmoniK cluster.
    name : str | None
        An optional name for the graph. Default is None (stored as an empty string).
    description : str | None
        An optional description for the graph. Default is None (stored as an empty string).

    Example
    -------
    >>> from armonik.client import TaskFieldFilter
    >>> from armonik_analytics import ArmoniKGraph
    >>> g = ArmoniKGraph(task_filter=(TaskFieldFilter.SESSION_ID == "session_id"))
    """

    def __init__(
        self, task_filter: Filter, name: str | None = None, description: str | None = None
    ) -> None:
        self.graph = rx.PyDiGraph(check_cycle=True, multigraph=True)
        self.graph.attrs = ArmoniKGraphAttr(
            name=name if name else "",
            description=description if description else "",
            task_filter=task_filter,
        )

    def update(self, channel: Channel) -> None:
        """Updates in-place the contents of the graph from the state database of a running cluster.

        Note that this operation deletes and then re-downloads the content, even if it has not
        changed. This can take a significant amount of time.

        Parameters
        ----------
        channel : grpc.Channel
            An open gRPC channel to the running cluster.

        Raises
        ------
        AssertionError
            If the number of nodes loaded differs from the total reported by the cluster.
        """
        # Clear current content. Should be improved in future versions.
        self.graph.clear()

        client = ArmoniKTasks(channel)
        task_filter = self.graph.attrs.task_filter

        # Tasks depend on each other through their input/output data. The following dictionary
        # is used to build dependencies between tasks. It stores for each input/output data item
        # which unique task produces it and which task(s) consume(s) it, as
        # [producer_node_id | None, [consumer_node_id, ...]].
        edges = defaultdict(lambda: [None, []])

        # Iterates over all tasks corresponding to the filter defining the graph. The page index
        # must be forwarded to the client, otherwise the same first page is fetched forever.
        page = 0
        total, tasks = client.list_tasks(task_filter=task_filter, with_errors=True, page=page)
        while tasks:
            for task in tasks:
                # Add task to graph
                node_id = self.graph.add_node(task)
                # Add task inputs/outputs to 'edges' dictionary
                for in_data_dep in task.data_dependencies:
                    edges[in_data_dep][1].append(node_id)
                for out_data_dep in task.expected_output_ids:
                    # An output data can only be produced by a single task. However, an ArmoniK
                    # graph is dynamic and a task can transfer its responsibility for generating
                    # an output to another task. This operation is not reflected in the task's
                    # metadata. So care must be taken to select only the task that actually
                    # produces the result. This is the task that has all the other tasks as
                    # parents among the tasks claiming to produce this output data.
                    #
                    # A dedicated variable is used so that the current task's own node index is
                    # left untouched for its remaining outputs.
                    producer_id = node_id
                    old_tail_id = edges[out_data_dep][0]
                    # Node index 0 is a valid producer: compare against None, not truthiness.
                    if old_tail_id is not None:
                        old_tail = self.graph.get_node_data(old_tail_id)
                        if task.id in old_tail.parent_task_ids:
                            producer_id = old_tail_id
                    edges[out_data_dep][0] = producer_id

            page += 1
            _, tasks = client.list_tasks(task_filter=task_filter, with_errors=True, page=page)

        # Once built, the 'edges' dictionary is used to construct dependencies between tasks in
        # the graph.
        for tail, heads in edges.values():
            # Root input data have no tails (not produced by any task) and don't correspond to
            # any dependency between two tasks. Such data are ignored.
            if tail is not None:
                self.graph.add_edges_from_no_data([(tail, head) for head in heads])

        assert self.graph.num_nodes() == total, (
            f"loaded {self.graph.num_nodes()} tasks but the cluster reported {total}"
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
import grpc | ||
import rustworkx as rx | ||
import pytest | ||
|
||
from armonik.client import ArmoniKTasks, TaskFieldFilter | ||
from armonik.common import Task | ||
|
||
from armonik_analytics.graph import ArmoniKGraph | ||
|
||
|
||
class ListTasks:
    """Callable fake standing in for a paginated task-listing call.

    The first invocation returns every task; any later invocation returns an empty page.
    Both report the total number of tasks. All arguments are accepted and ignored.

    Parameters
    ----------
    tasks : list
        The tasks served on the first call.
    """

    def __init__(self, tasks):
        self.tasks = tasks
        self.call_count = 0

    def __call__(self, *args, **kwds):
        self.call_count += 1
        is_first_call = self.call_count == 1
        # Only the very first call yields content; subsequent pages are empty.
        served = self.tasks if is_first_call else []
        return len(self.tasks), served
|
||
|
||
def single_node():
    """Build the fixture for a workload made of a single task.

    Returns
    -------
    tuple
        The task list and the expected graph holding that one node and no edge.
    """
    task_list = [
        Task(id="t0", data_dependencies=["i0"], expected_output_ids=["o0"]),
    ]

    expected = rx.PyDiGraph(check_cycle=True)
    expected.add_nodes_from(task_list)
    return task_list, expected
|
||
|
||
def three_parallel_nodes():
    """Build the fixture for three tasks that share no data.

    Returns
    -------
    tuple
        The task list and the expected graph holding three nodes and no edge.
    """
    task_list = [
        Task(id="t0", data_dependencies=["i0"], expected_output_ids=["o0"]),
        Task(id="t1", data_dependencies=["i1"], expected_output_ids=["o1"]),
        Task(id="t2", data_dependencies=["i2"], expected_output_ids=["o2"]),
    ]

    expected = rx.PyDiGraph(check_cycle=True)
    expected.add_nodes_from(task_list)
    return task_list, expected
|
||
|
||
def three_dependant_nodes():
    """Build the fixture for one producer task feeding two consumer tasks.

    Returns
    -------
    tuple
        The task list and the expected graph, where t0 points to both t1 and t2.
    """
    task_list = [
        Task(id="t0", data_dependencies=["i0"], expected_output_ids=["o0"]),
        Task(id="t1", data_dependencies=["o0"], expected_output_ids=["o1"]),
        Task(id="t2", data_dependencies=["o0"], expected_output_ids=["o2"]),
    ]
    # (producer index, consumer index) pairs, indices following task_list order.
    dependencies = [(0, 1), (0, 2)]

    expected = rx.PyDiGraph(check_cycle=True)
    expected.add_nodes_from(task_list)
    expected.add_edges_from_no_data(dependencies)
    return task_list, expected
|
||
|
||
def seven_dependant_nodes():
    """Build the fixture for a seven-task workload with fan-out and fan-in dependencies.

    Returns
    -------
    tuple
        The task list and the expected graph with its seven dependency edges.
    """
    task_list = [
        Task(id="t0", data_dependencies=["i0"], expected_output_ids=["o0"]),
        Task(id="t1", data_dependencies=["o0"], expected_output_ids=["o1"]),
        Task(id="t2", data_dependencies=["o0"], expected_output_ids=["o2"]),
        Task(id="t3", data_dependencies=["o1", "o2"], expected_output_ids=["o3"]),
        Task(id="t4", data_dependencies=["o1"], expected_output_ids=["o4"]),
        Task(id="t5", data_dependencies=["o4"], expected_output_ids=["o5"]),
        Task(id="t6", data_dependencies=["o3"], expected_output_ids=["o6", "o7"]),
    ]
    # (producer index, consumer index) pairs, indices following task_list order.
    dependencies = [(0, 1), (0, 2), (1, 3), (1, 4), (2, 3), (4, 5), (3, 6)]

    expected = rx.PyDiGraph(check_cycle=True)
    expected.add_nodes_from(task_list)
    expected.add_edges_from_no_data(dependencies)
    return task_list, expected
|
||
|
||
@pytest.mark.parametrize(
    ("tasks", "expected_graph"),
    [
        generator()
        for generator in (
            single_node,
            three_parallel_nodes,
            three_dependant_nodes,
            seven_dependant_nodes,
        )
    ],
)
def test_graph_update(mocker, tasks, expected_graph):
    """Check that ArmoniKGraph.update rebuilds the expected nodes and dependencies.

    The task-listing call of the client is replaced by a fake serving ``tasks``, so no
    cluster is contacted.
    """
    mocker.patch.object(ArmoniKTasks, "list_tasks", new=ListTasks(tasks))
    g = ArmoniKGraph(task_filter=(TaskFieldFilter.SESSION_ID == "session_id"))
    with grpc.insecure_channel("host") as channel:
        g.update(channel)
    assert g.graph.nodes() == expected_graph.nodes()
    # edges() only yields edge payloads, which are all None for edges added without data, so
    # it would merely compare edge counts. edge_list() yields (source, target) pairs and
    # therefore checks the topology. Sorting makes the check independent of insertion order.
    assert sorted(g.graph.edge_list()) == sorted(expected_graph.edge_list())