From 24c92396b587b75238a425582f84a99e957ff9a2 Mon Sep 17 00:00:00 2001 From: Xing Wang Date: Wed, 17 Jul 2024 09:32:43 +0200 Subject: [PATCH] use `get_processes_latest` for workgraph.update (#174) * use `get_processes_latest` for workgraph.update * serialize task.process in `to_dict` * handle data node and process node differently - when loading the WorkGraph and reading outputs of the tasks - When setting the results of the task in the context of the Engine --- aiida_workgraph/engine/workgraph.py | 41 +++++------- aiida_workgraph/task.py | 4 +- aiida_workgraph/utils/__init__.py | 5 -- aiida_workgraph/utils/analysis.py | 10 ++- aiida_workgraph/workgraph.py | 85 ++++++++----------------- docs/source/quick_start.ipynb | 37 ++++++----- docs/source/tutorial/zero_to_hero.ipynb | 2 +- 7 files changed, 70 insertions(+), 114 deletions(-) diff --git a/aiida_workgraph/engine/workgraph.py b/aiida_workgraph/engine/workgraph.py index c0fa6d8e..03c8c55f 100644 --- a/aiida_workgraph/engine/workgraph.py +++ b/aiida_workgraph/engine/workgraph.py @@ -17,6 +17,7 @@ from aiida.common.lang import override from aiida import orm from aiida.orm import load_node, Node, ProcessNode, WorkChainNode +from aiida.orm.utils.serialize import deserialize_unsafe, serialize from aiida.engine.processes.exit_code import ExitCode from aiida.engine.processes.process import Process @@ -453,7 +454,6 @@ def setup_ctx_workgraph(self, wgdata: t.Dict[str, t.Any]) -> None: def read_wgdata_from_base(self) -> t.Dict[str, t.Any]: """Read workgraph data from base.extras.""" - from aiida.orm.utils.serialize import deserialize_unsafe wgdata = self.node.base.extras.get("_workgraph") for name, task in wgdata["tasks"].items(): @@ -480,7 +480,6 @@ def update_task(self, task: Task): def get_task_state_info(self, name: str, key: str) -> str: """Get task state info from base.extras.""" - from aiida.orm.utils.serialize import deserialize_unsafe if key == "process": value = deserialize_unsafe( @@ -492,7 +491,6 @@ def 
get_task_state_info(self, name: str, key: str) -> str: def set_task_state_info(self, name: str, key: str, value: any) -> None: """Set task state info to base.extras.""" - from aiida.orm.utils.serialize import serialize if key == "process": self.node.base.extras.set(f"_task_{key}_{name}", serialize(value)) @@ -518,52 +516,40 @@ def set_task_results(self) -> None: for name, task in self.ctx.tasks.items(): if self.get_task_state_info(name, "action").upper() == "RESET": self.reset_task(task["name"]) - if self.get_task_state_info(name, "process"): - if isinstance(self.get_task_state_info(task["name"], "process"), str): - self.set_task_state_info( - task["name"], - "process", - orm.load_node( - self.get_task_state_info(task["name"], "process") - ), - ) + process = self.get_task_state_info(name, "process") + if process: self.set_task_result(task) self.set_task_result(task) def set_task_result(self, task: t.Dict[str, t.Any]) -> None: name = task["name"] # print(f"set task result: {name}") - if self.get_task_state_info(name, "process"): + node = self.get_task_state_info(name, "process") + if isinstance(node, orm.ProcessNode): # print(f"set task result: {name} process") state = self.get_task_state_info( task["name"], "process" ).process_state.value.upper() - if self.get_task_state_info(task["name"], "process").is_finished_ok: + if node.is_finished_ok: self.set_task_state_info(task["name"], "state", state) if task["metadata"]["node_type"].upper() == "WORKGRAPH": # expose the outputs of all the tasks in the workgraph task["results"] = {} - outgoing = self.get_task_state_info( - task["name"], "process" - ).base.links.get_outgoing() + outgoing = node.base.links.get_outgoing() for link in outgoing.all(): if isinstance(link.node, ProcessNode) and getattr( link.node, "process_state", False ): task["results"][link.link_label] = link.node.outputs else: - task["results"] = self.get_task_state_info( - task["name"], "process" - ).outputs + task["results"] = node.outputs # 
self.ctx.new_data[name] = task["results"] self.set_task_state_info(task["name"], "state", "FINISHED") self.task_set_context(name) self.report(f"Task: {name} finished.") # all other states are considered as failed else: - task["results"] = self.get_task_state_info( - task["name"], "process" - ).outputs + task["results"] = node.outputs # self.ctx.new_data[name] = task["results"] self.set_task_state_info(task["name"], "state", "FAILED") # set child tasks state to SKIPPED @@ -572,6 +558,11 @@ def set_task_result(self, task: t.Dict[str, t.Any]) -> None: ) self.report(f"Task: {name} failed.") self.run_error_handlers(name) + elif isinstance(node, orm.Data): + task["results"] = {task["outputs"][0]["name"]: node} + self.set_task_state_info(task["name"], "state", "FINISHED") + self.task_set_context(name) + self.report(f"Task: {name} finished.") else: task["results"] = None @@ -801,8 +792,8 @@ def run_tasks(self, names: t.List[str], continue_workgraph: bool = True) -> None kwargs[key] = args[i] # update the port namespace kwargs = update_nested_dict_with_special_keys(kwargs) - # print("args: ", args) - # print("kwargs: ", kwargs) + print("args: ", args) + print("kwargs: ", kwargs) # print("var_kwargs: ", var_kwargs) # kwargs["meta.label"] = name # output must be a Data type or a mapping of {string: Data} diff --git a/aiida_workgraph/task.py b/aiida_workgraph/task.py index d60aa267..6dac3160 100644 --- a/aiida_workgraph/task.py +++ b/aiida_workgraph/task.py @@ -57,12 +57,14 @@ def __init__( self.action = "" def to_dict(self) -> Dict[str, Any]: + from aiida.orm.utils.serialize import serialize + tdata = super().to_dict() tdata["context_mapping"] = self.context_mapping tdata["wait"] = [ task if isinstance(task, str) else task.name for task in self.wait ] - tdata["process"] = self.process.uuid if self.process else None + tdata["process"] = serialize(self.process) if self.process else serialize(None) tdata["metadata"]["pk"] = self.process.pk if self.process else None 
tdata["metadata"]["is_aiida_component"] = self.is_aiida_component diff --git a/aiida_workgraph/utils/__init__.py b/aiida_workgraph/utils/__init__.py index 076f328b..cab8bbb3 100644 --- a/aiida_workgraph/utils/__init__.py +++ b/aiida_workgraph/utils/__init__.py @@ -260,11 +260,9 @@ def get_processes_latest( from aiida.orm.utils.serialize import deserialize_unsafe from aiida.orm import QueryBuilder from aiida_workgraph.engine.workgraph import WorkGraphEngine - import time tasks = {} node_names = [node_name] if node_name else [] - tstart = time.time() if node_name: projections = [ f"extras._task_state_{node_name}", @@ -297,9 +295,6 @@ def get_processes_latest( "ctime": task_process.ctime if task_process else None, "mtime": task_process.mtime if task_process else None, } - # print("tasks: ", tasks) - print(f"Time to deserialize data: {time.time() - tstart}") - return tasks diff --git a/aiida_workgraph/utils/analysis.py b/aiida_workgraph/utils/analysis.py index 24a025a0..8d5c023d 100644 --- a/aiida_workgraph/utils/analysis.py +++ b/aiida_workgraph/utils/analysis.py @@ -2,6 +2,7 @@ # import datetime from aiida.orm import ProcessNode +from aiida.orm.utils.serialize import serialize, deserialize_unsafe class WorkGraphSaver: @@ -99,7 +100,6 @@ def insert_workgraph_to_db(self) -> None: - workgraph - all tasks """ - from aiida.orm.utils.serialize import serialize from aiida_workgraph.utils import workgraph_to_short_json # pprint(self.wgdata) @@ -117,14 +117,13 @@ def insert_workgraph_to_db(self) -> None: def save_task_states(self) -> Dict: """Get task states.""" - from aiida.orm.utils.serialize import serialize task_states = {} task_processes = {} task_actions = {} for name, task in self.wgdata["tasks"].items(): task_states[f"_task_state_{name}"] = task["state"] - task_processes[f"_task_process_{name}"] = serialize(task["process"]) + task_processes[f"_task_process_{name}"] = task["process"] task_actions[f"_task_action_{name}"] = task["action"] 
self.process.base.extras.set_many(task_states) self.process.base.extras.set_many(task_processes) @@ -147,13 +146,13 @@ def reset_tasks(self, tasks: List[str]) -> None: ): for name in tasks: self.wgdata["tasks"][name]["state"] = "PLANNED" - self.wgdata["tasks"][name]["process"] = None + self.wgdata["tasks"][name]["process"] = serialize(None) self.wgdata["tasks"][name]["result"] = None names = self.wgdata["connectivity"]["child_node"][name] for name in names: self.wgdata["tasks"][name]["state"] = "PLANNED" self.wgdata["tasks"][name]["result"] = None - self.wgdata["tasks"][name]["process"] = None + self.wgdata["tasks"][name]["process"] = serialize(None) else: create_task_action(self.process.pk, tasks=tasks, action="reset") @@ -166,7 +165,6 @@ def set_tasks_action(self, action: str) -> None: def get_wgdata_from_db( self, process: Optional[ProcessNode] = None ) -> Optional[Dict]: - from aiida.orm.utils.serialize import deserialize_unsafe process = self.process if process is None else process wgdata = process.base.extras.get("_workgraph", None) diff --git a/aiida_workgraph/workgraph.py b/aiida_workgraph/workgraph.py index 94ae07d4..3de36ba6 100644 --- a/aiida_workgraph/workgraph.py +++ b/aiida_workgraph/workgraph.py @@ -235,69 +235,40 @@ def update(self) -> None: linked to the current process, and data nodes linked to the current process. 
""" # from aiida_workgraph.utils import get_executor - from aiida_workgraph.utils import get_nested_dict + from aiida_workgraph.utils import get_nested_dict, get_processes_latest if self.process is None: return self.state = self.process.process_state.value.upper() - outgoing = self.process.base.links.get_outgoing() - for link in outgoing.all(): - node = link.node - # the link is added in order - # so the restarted node will be the last one - # thus the task is correct - if isinstance(node, aiida.orm.ProcessNode) and getattr( - node, "process_state", False - ): - self.tasks[link.link_label].process = node - self.tasks[link.link_label].state = node.process_state.value.upper() - self.tasks[link.link_label].node = node - self.tasks[link.link_label].pk = node.pk - self.tasks[link.link_label].ctime = node.ctime - self.tasks[link.link_label].mtime = node.mtime - if self.tasks[link.link_label].state == "FINISHED": - # update the output sockets - i = 0 - for socket in self.tasks[link.link_label].outputs: - socket.value = get_nested_dict( - node.outputs, socket.name, allow_none=True - ) - i += 1 - elif isinstance(node, aiida.orm.Data): - if link.link_label.startswith("new_data__"): - label = link.link_label.split("__", 1)[1] - if label in self.tasks.keys(): - self.tasks[label].state = "FINISHED" - self.tasks[label].node = node - self.tasks[label].pk = node.pk - elif link.link_label == "execution_count": - self.execution_count = node.value - # read results from the process outputs - for task in self.tasks: - if task.node_type.upper() == "DATA": - if not getattr(self.process.outputs, "new_data", False): - continue - task.outputs[0].value = getattr( - self.process.outputs.new_data, task.name, None - ) - # for normal tasks, we try to read the results from the extras of the task - # this is disabled for now - # if task.node_type.upper() == "NORMAL": - # results = self.process.base.extras.get( - # f"nodes__results__{task.name}", {} - # ) - # for key, value in results.items(): - # 
# if value is an AiiDA data node, we don't need to deserialize it - # deserializer = node.outputs[key].get_deserialize() - # executor = get_executor(deserializer)[0] - # try: - # value = executor(bytes(value)) - # except Exception: - # pass - # node.outputs[key].value = value + processes_data = get_processes_latest(self.pk) + for name, data in processes_data.items(): + self.tasks[name].state = data["state"] + self.tasks[name].ctime = data["ctime"] + self.tasks[name].mtime = data["mtime"] + self.tasks[name].pk = data["pk"] + if data["pk"] is not None: + node = aiida.orm.load_node(data["pk"]) + self.tasks[name].process = self.tasks[name].node = node + if isinstance(node, aiida.orm.ProcessNode) and getattr( + node, "process_state", False + ): + if self.tasks[name].state == "FINISHED": + # update the output sockets + i = 0 + for socket in self.tasks[name].outputs: + socket.value = get_nested_dict( + node.outputs, socket.name, allow_none=True + ) + i += 1 + # read results from the process outputs + elif isinstance(node, aiida.orm.Data): + self.tasks[name].outputs[0].value = node + execution_count = getattr(self.process.outputs, "execution_count", None) + self.execution_count = execution_count if execution_count else 0 if self._widget is not None: - self._widget.states = {task.name: node.state for node in self.tasks} + states = {name: data["state"] for name, data in processes_data.items()} + self._widget.states = states @property def pk(self) -> Optional[int]: diff --git a/docs/source/quick_start.ipynb b/docs/source/quick_start.ipynb index 4fa7dc01..13da2c96 100644 --- a/docs/source/quick_start.ipynb +++ b/docs/source/quick_start.ipynb @@ -17,7 +17,7 @@ "To run this tutorial, you need to install `aiida-workgraph`. 
Open a terminal and run:\n", "\n", "```console\n", - "pip install aiida-workgraph\n", + "pip install aiida-workgraph[widget]\n", "```\n", "\n", "Start (or restart) the AiiDA daemon if needed:\n", @@ -122,7 +122,7 @@ " " ], "text/plain": [ - "" + "" ] }, "execution_count": 3, @@ -168,20 +168,13 @@ "outputs": [ { "data": { - "text/html": [ - "\n", - " \n", - " " - ], + "application/vnd.jupyter.widget-view+json": { + "model_id": "72c8e6fc42f644b8a1a672a0daec1d80", + "version_major": 2, + "version_minor": 1 + }, "text/plain": [ - "" + "NodeGraphWidget(settings={'minimap': True}, style={'width': '90%', 'height': '600px'}, value={'name': 'add_mul…" ] }, "execution_count": 4, @@ -200,7 +193,7 @@ "# export the workgraph to html file so that it can be visualized in a browser\n", "wg.to_html()\n", "# comment out the following line to visualize the workgraph in jupyter-notebook\n", - "# wg" + "wg" ] }, { @@ -221,10 +214,16 @@ "name": "stdout", "output_type": "stream", "text": [ - "WorkGraph process created, PK: 87594\n", + "WorkGraph process created, PK: 92155\n", + "Time to deserialize data: 0.03729963302612305\n", + "Time to deserialize data: 0.035559654235839844\n", + "Time to deserialize data: 0.032629966735839844\n", + "Time to deserialize data: 0.034082889556884766\n", + "Time to deserialize data: 0.04560279846191406\n", + "Time to deserialize data: 0.04309391975402832\n", "State of WorkGraph: FINISHED\n", - "Result of add : uuid: 81d4b1a9-e7d6-479e-9820-6459bf15efb9 (pk: 87596) value: 5\n", - "Result of multiply : uuid: 27f87728-b771-4d8a-a13f-a2b115a16353 (pk: 87598) value: 20\n" + "Result of add : uuid: d86318af-3390-4ffb-afe4-025d0bf61d62 (pk: 92157) value: 5\n", + "Result of multiply : uuid: 12c16973-c6a3-4769-8db3-5e960c313069 (pk: 92159) value: 20\n" ] } ], diff --git a/docs/source/tutorial/zero_to_hero.ipynb b/docs/source/tutorial/zero_to_hero.ipynb index a79a2d46..0bebdf09 100644 --- a/docs/source/tutorial/zero_to_hero.ipynb +++ 
b/docs/source/tutorial/zero_to_hero.ipynb @@ -20,7 +20,7 @@ "To run this tutorial, you need to install `aiida-workgraph`, `aiida-quantumespresso`. Open a terminal and run:\n", "\n", "```console\n", - "pip install aiida-workgraph aiida-quantumespresso\n", + "pip install aiida-workgraph[widget] aiida-quantumespresso\n", "```\n", "\n", "Restart (or start) the AiiDA daemon if needed:\n",