Skip to content

Commit

Permalink
use get_processes_latest for workgraph.update (#174)
Browse files Browse the repository at this point in the history
* use `get_processes_latest` for workgraph.update
* serialize task.process in `to_dict`
* handle data node and process node differently
  - when loading the WorkGraph and reading outputs of the tasks
  - When setting the results of the task in the context of the Engine
  • Loading branch information
superstar54 authored Jul 17, 2024
1 parent 01007f4 commit 24c9239
Show file tree
Hide file tree
Showing 7 changed files with 70 additions and 114 deletions.
41 changes: 16 additions & 25 deletions aiida_workgraph/engine/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from aiida.common.lang import override
from aiida import orm
from aiida.orm import load_node, Node, ProcessNode, WorkChainNode
from aiida.orm.utils.serialize import deserialize_unsafe, serialize

from aiida.engine.processes.exit_code import ExitCode
from aiida.engine.processes.process import Process
Expand Down Expand Up @@ -453,7 +454,6 @@ def setup_ctx_workgraph(self, wgdata: t.Dict[str, t.Any]) -> None:

def read_wgdata_from_base(self) -> t.Dict[str, t.Any]:
"""Read workgraph data from base.extras."""
from aiida.orm.utils.serialize import deserialize_unsafe

wgdata = self.node.base.extras.get("_workgraph")
for name, task in wgdata["tasks"].items():
Expand All @@ -480,7 +480,6 @@ def update_task(self, task: Task):

def get_task_state_info(self, name: str, key: str) -> str:
"""Get task state info from base.extras."""
from aiida.orm.utils.serialize import deserialize_unsafe

if key == "process":
value = deserialize_unsafe(
Expand All @@ -492,7 +491,6 @@ def get_task_state_info(self, name: str, key: str) -> str:

def set_task_state_info(self, name: str, key: str, value: any) -> None:
"""Set task state info to base.extras."""
from aiida.orm.utils.serialize import serialize

if key == "process":
self.node.base.extras.set(f"_task_{key}_{name}", serialize(value))
Expand All @@ -518,52 +516,40 @@ def set_task_results(self) -> None:
for name, task in self.ctx.tasks.items():
if self.get_task_state_info(name, "action").upper() == "RESET":
self.reset_task(task["name"])
if self.get_task_state_info(name, "process"):
if isinstance(self.get_task_state_info(task["name"], "process"), str):
self.set_task_state_info(
task["name"],
"process",
orm.load_node(
self.get_task_state_info(task["name"], "process")
),
)
process = self.get_task_state_info(name, "process")
if process:
self.set_task_result(task)
self.set_task_result(task)

def set_task_result(self, task: t.Dict[str, t.Any]) -> None:
name = task["name"]
# print(f"set task result: {name}")
if self.get_task_state_info(name, "process"):
node = self.get_task_state_info(name, "process")
if isinstance(node, orm.ProcessNode):
# print(f"set task result: {name} process")
state = self.get_task_state_info(
task["name"], "process"
).process_state.value.upper()
if self.get_task_state_info(task["name"], "process").is_finished_ok:
if node.is_finished_ok:
self.set_task_state_info(task["name"], "state", state)
if task["metadata"]["node_type"].upper() == "WORKGRAPH":
# expose the outputs of all the tasks in the workgraph
task["results"] = {}
outgoing = self.get_task_state_info(
task["name"], "process"
).base.links.get_outgoing()
outgoing = node.base.links.get_outgoing()
for link in outgoing.all():
if isinstance(link.node, ProcessNode) and getattr(
link.node, "process_state", False
):
task["results"][link.link_label] = link.node.outputs
else:
task["results"] = self.get_task_state_info(
task["name"], "process"
).outputs
task["results"] = node.outputs
# self.ctx.new_data[name] = task["results"]
self.set_task_state_info(task["name"], "state", "FINISHED")
self.task_set_context(name)
self.report(f"Task: {name} finished.")
# all other states are considered as failed
else:
task["results"] = self.get_task_state_info(
task["name"], "process"
).outputs
task["results"] = node.outputs
# self.ctx.new_data[name] = task["results"]
self.set_task_state_info(task["name"], "state", "FAILED")
# set child tasks state to SKIPPED
Expand All @@ -572,6 +558,11 @@ def set_task_result(self, task: t.Dict[str, t.Any]) -> None:
)
self.report(f"Task: {name} failed.")
self.run_error_handlers(name)
elif isinstance(node, orm.Data):
task["results"] = {task["outputs"][0]["name"]: node}
self.set_task_state_info(task["name"], "state", "FINISHED")
self.task_set_context(name)
self.report(f"Task: {name} finished.")
else:
task["results"] = None

Expand Down Expand Up @@ -801,8 +792,8 @@ def run_tasks(self, names: t.List[str], continue_workgraph: bool = True) -> None
kwargs[key] = args[i]
# update the port namespace
kwargs = update_nested_dict_with_special_keys(kwargs)
# print("args: ", args)
# print("kwargs: ", kwargs)
print("args: ", args)
print("kwargs: ", kwargs)
# print("var_kwargs: ", var_kwargs)
# kwargs["meta.label"] = name
# output must be a Data type or a mapping of {string: Data}
Expand Down
4 changes: 3 additions & 1 deletion aiida_workgraph/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,14 @@ def __init__(
self.action = ""

def to_dict(self) -> Dict[str, Any]:
from aiida.orm.utils.serialize import serialize

tdata = super().to_dict()
tdata["context_mapping"] = self.context_mapping
tdata["wait"] = [
task if isinstance(task, str) else task.name for task in self.wait
]
tdata["process"] = self.process.uuid if self.process else None
tdata["process"] = serialize(self.process) if self.process else serialize(None)
tdata["metadata"]["pk"] = self.process.pk if self.process else None
tdata["metadata"]["is_aiida_component"] = self.is_aiida_component

Expand Down
5 changes: 0 additions & 5 deletions aiida_workgraph/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,11 +260,9 @@ def get_processes_latest(
from aiida.orm.utils.serialize import deserialize_unsafe
from aiida.orm import QueryBuilder
from aiida_workgraph.engine.workgraph import WorkGraphEngine
import time

tasks = {}
node_names = [node_name] if node_name else []
tstart = time.time()
if node_name:
projections = [
f"extras._task_state_{node_name}",
Expand Down Expand Up @@ -297,9 +295,6 @@ def get_processes_latest(
"ctime": task_process.ctime if task_process else None,
"mtime": task_process.mtime if task_process else None,
}
# print("tasks: ", tasks)
print(f"Time to deserialize data: {time.time() - tstart}")

return tasks


Expand Down
10 changes: 4 additions & 6 deletions aiida_workgraph/utils/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

# import datetime
from aiida.orm import ProcessNode
from aiida.orm.utils.serialize import serialize, deserialize_unsafe


class WorkGraphSaver:
Expand Down Expand Up @@ -99,7 +100,6 @@ def insert_workgraph_to_db(self) -> None:
- workgraph
- all tasks
"""
from aiida.orm.utils.serialize import serialize
from aiida_workgraph.utils import workgraph_to_short_json

# pprint(self.wgdata)
Expand All @@ -117,14 +117,13 @@ def insert_workgraph_to_db(self) -> None:

def save_task_states(self) -> Dict:
"""Get task states."""
from aiida.orm.utils.serialize import serialize

task_states = {}
task_processes = {}
task_actions = {}
for name, task in self.wgdata["tasks"].items():
task_states[f"_task_state_{name}"] = task["state"]
task_processes[f"_task_process_{name}"] = serialize(task["process"])
task_processes[f"_task_process_{name}"] = task["process"]
task_actions[f"_task_action_{name}"] = task["action"]
self.process.base.extras.set_many(task_states)
self.process.base.extras.set_many(task_processes)
Expand All @@ -147,13 +146,13 @@ def reset_tasks(self, tasks: List[str]) -> None:
):
for name in tasks:
self.wgdata["tasks"][name]["state"] = "PLANNED"
self.wgdata["tasks"][name]["process"] = None
self.wgdata["tasks"][name]["process"] = serialize(None)
self.wgdata["tasks"][name]["result"] = None
names = self.wgdata["connectivity"]["child_node"][name]
for name in names:
self.wgdata["tasks"][name]["state"] = "PLANNED"
self.wgdata["tasks"][name]["result"] = None
self.wgdata["tasks"][name]["process"] = None
self.wgdata["tasks"][name]["process"] = serialize(None)
else:
create_task_action(self.process.pk, tasks=tasks, action="reset")

Expand All @@ -166,7 +165,6 @@ def set_tasks_action(self, action: str) -> None:
def get_wgdata_from_db(
self, process: Optional[ProcessNode] = None
) -> Optional[Dict]:
from aiida.orm.utils.serialize import deserialize_unsafe

process = self.process if process is None else process
wgdata = process.base.extras.get("_workgraph", None)
Expand Down
85 changes: 28 additions & 57 deletions aiida_workgraph/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,69 +235,40 @@ def update(self) -> None:
linked to the current process, and data nodes linked to the current process.
"""
# from aiida_workgraph.utils import get_executor
from aiida_workgraph.utils import get_nested_dict
from aiida_workgraph.utils import get_nested_dict, get_processes_latest

if self.process is None:
return

self.state = self.process.process_state.value.upper()
outgoing = self.process.base.links.get_outgoing()
for link in outgoing.all():
node = link.node
# the link is added in order
# so the restarted node will be the last one
# thus the task is correct
if isinstance(node, aiida.orm.ProcessNode) and getattr(
node, "process_state", False
):
self.tasks[link.link_label].process = node
self.tasks[link.link_label].state = node.process_state.value.upper()
self.tasks[link.link_label].node = node
self.tasks[link.link_label].pk = node.pk
self.tasks[link.link_label].ctime = node.ctime
self.tasks[link.link_label].mtime = node.mtime
if self.tasks[link.link_label].state == "FINISHED":
# update the output sockets
i = 0
for socket in self.tasks[link.link_label].outputs:
socket.value = get_nested_dict(
node.outputs, socket.name, allow_none=True
)
i += 1
elif isinstance(node, aiida.orm.Data):
if link.link_label.startswith("new_data__"):
label = link.link_label.split("__", 1)[1]
if label in self.tasks.keys():
self.tasks[label].state = "FINISHED"
self.tasks[label].node = node
self.tasks[label].pk = node.pk
elif link.link_label == "execution_count":
self.execution_count = node.value
# read results from the process outputs
for task in self.tasks:
if task.node_type.upper() == "DATA":
if not getattr(self.process.outputs, "new_data", False):
continue
task.outputs[0].value = getattr(
self.process.outputs.new_data, task.name, None
)
# for normal tasks, we try to read the results from the extras of the task
# this is disabled for now
# if task.node_type.upper() == "NORMAL":
# results = self.process.base.extras.get(
# f"nodes__results__{task.name}", {}
# )
# for key, value in results.items():
# # if value is an AiiDA data node, we don't need to deserialize it
# deserializer = node.outputs[key].get_deserialize()
# executor = get_executor(deserializer)[0]
# try:
# value = executor(bytes(value))
# except Exception:
# pass
# node.outputs[key].value = value
processes_data = get_processes_latest(self.pk)
for name, data in processes_data.items():
self.tasks[name].state = data["state"]
self.tasks[name].ctime = data["ctime"]
self.tasks[name].mtime = data["mtime"]
self.tasks[name].pk = data["pk"]
if data["pk"] is not None:
node = aiida.orm.load_node(data["pk"])
self.tasks[name].process = self.tasks[name].node = node
if isinstance(node, aiida.orm.ProcessNode) and getattr(
node, "process_state", False
):
if self.tasks[name].state == "FINISHED":
# update the output sockets
i = 0
for socket in self.tasks[name].outputs:
socket.value = get_nested_dict(
node.outputs, socket.name, allow_none=True
)
i += 1
# read results from the process outputs
elif isinstance(node, aiida.orm.Data):
self.tasks[name].outputs[0].value = node
execution_count = getattr(self.process.outputs, "execution_count", None)
self.execution_count = execution_count if execution_count else 0
if self._widget is not None:
self._widget.states = {task.name: node.state for node in self.tasks}
states = {name: data["state"] for name, data in processes_data.items()}
self._widget.states = states

@property
def pk(self) -> Optional[int]:
Expand Down
37 changes: 18 additions & 19 deletions docs/source/quick_start.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"To run this tutorial, you need to install `aiida-workgraph`. Open a terminal and run:\n",
"\n",
"```console\n",
"pip install aiida-workgraph\n",
"pip install aiida-workgraph[widget]\n",
"```\n",
"\n",
"Start (or restart) the AiiDA daemon if needed:\n",
Expand Down Expand Up @@ -122,7 +122,7 @@
" "
],
"text/plain": [
"<IPython.lib.display.IFrame at 0x7e51a6031710>"
"<IPython.lib.display.IFrame at 0x759f13f08290>"
]
},
"execution_count": 3,
Expand Down Expand Up @@ -168,20 +168,13 @@
"outputs": [
{
"data": {
"text/html": [
"\n",
" <iframe\n",
" width=\"100%\"\n",
" height=\"600px\"\n",
" src=\"html/add_multiply_workflow.html\"\n",
" frameborder=\"0\"\n",
" allowfullscreen\n",
" \n",
" ></iframe>\n",
" "
],
"application/vnd.jupyter.widget-view+json": {
"model_id": "72c8e6fc42f644b8a1a672a0daec1d80",
"version_major": 2,
"version_minor": 1
},
"text/plain": [
"<IPython.lib.display.IFrame at 0x7e51a6043f10>"
"NodeGraphWidget(settings={'minimap': True}, style={'width': '90%', 'height': '600px'}, value={'name': 'add_mul…"
]
},
"execution_count": 4,
Expand All @@ -200,7 +193,7 @@
"# export the workgraph to html file so that it can be visualized in a browser\n",
"wg.to_html()\n",
"# comment out the following line to visualize the workgraph in jupyter-notebook\n",
"# wg"
"wg"
]
},
{
Expand All @@ -221,10 +214,16 @@
"name": "stdout",
"output_type": "stream",
"text": [
"WorkGraph process created, PK: 87594\n",
"WorkGraph process created, PK: 92155\n",
"Time to deserialize data: 0.03729963302612305\n",
"Time to deserialize data: 0.035559654235839844\n",
"Time to deserialize data: 0.032629966735839844\n",
"Time to deserialize data: 0.034082889556884766\n",
"Time to deserialize data: 0.04560279846191406\n",
"Time to deserialize data: 0.04309391975402832\n",
"State of WorkGraph: FINISHED\n",
"Result of add : uuid: 81d4b1a9-e7d6-479e-9820-6459bf15efb9 (pk: 87596) value: 5\n",
"Result of multiply : uuid: 27f87728-b771-4d8a-a13f-a2b115a16353 (pk: 87598) value: 20\n"
"Result of add : uuid: d86318af-3390-4ffb-afe4-025d0bf61d62 (pk: 92157) value: 5\n",
"Result of multiply : uuid: 12c16973-c6a3-4769-8db3-5e960c313069 (pk: 92159) value: 20\n"
]
}
],
Expand Down
2 changes: 1 addition & 1 deletion docs/source/tutorial/zero_to_hero.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
"To run this tutorial, you need to install `aiida-workgraph`, `aiida-quantumespresso`. Open a terminal and run:\n",
"\n",
"```console\n",
"pip install aiida-workgraph aiida-quantumespresso\n",
"pip install aiida-workgraph[widget] aiida-quantumespresso\n",
"```\n",
"\n",
"Restart (or start) the AiiDA daemon if needed:\n",
Expand Down

0 comments on commit 24c9239

Please sign in to comment.