From 5d95d38ca41881e3bf3ceceab9c618a5b1ad4af9 Mon Sep 17 00:00:00 2001
From: Jusong Yu
Date: Tue, 26 Nov 2024 00:47:44 +0100
Subject: [PATCH] Fix mypy daemon/execmanager.py

---
 src/aiida/common/datastructures.py           |  2 +-
 src/aiida/engine/daemon/execmanager.py       | 73 ++++++++++++-------
 .../orm/nodes/process/calculation/calcjob.py |  4 +-
 src/aiida/orm/nodes/repository.py            |  6 +-
 4 files changed, 52 insertions(+), 33 deletions(-)

diff --git a/src/aiida/common/datastructures.py b/src/aiida/common/datastructures.py
index dc09712a7c..731b1a82b1 100644
--- a/src/aiida/common/datastructures.py
+++ b/src/aiida/common/datastructures.py
@@ -145,7 +145,7 @@ class CalcInfo(DefaultFieldsAttributeDict):
     email: None | str
     email_on_started: bool
     email_on_terminated: bool
-    uuid: None | str
+    uuid: str
     prepend_text: None | str
     append_text: None | str
     num_machines: None | int
diff --git a/src/aiida/engine/daemon/execmanager.py b/src/aiida/engine/daemon/execmanager.py
index 73c30cab61..ed141bb892 100644
--- a/src/aiida/engine/daemon/execmanager.py
+++ b/src/aiida/engine/daemon/execmanager.py
@@ -20,7 +20,7 @@
 from logging import LoggerAdapter
 from pathlib import Path
 from tempfile import NamedTemporaryFile, TemporaryDirectory
-from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Optional, Sequence, cast
 from typing import Mapping as MappingType
 
 from aiida.common import AIIDA_LOGGER, exceptions
@@ -29,7 +29,7 @@
 from aiida.common.links import LinkType
 from aiida.engine.processes.exit_code import ExitCode
 from aiida.manage.configuration import get_config_option
-from aiida.orm import CalcJobNode, Code, FolderData, Node, PortableCode, RemoteData, load_node
+from aiida.orm import CalcJobNode, Code, Computer, FolderData, Node, PortableCode, RemoteData, load_node
 from aiida.orm.utils.log import get_dblogger_extra
 from aiida.repository.common import FileType
 from aiida.schedulers.datastructures import JobState
@@ -84,12 +84,13 @@ def upload_calculation(
     link_label = 'remote_folder'
     if node.base.links.get_outgoing(RemoteData, link_label_filter=link_label).first():
         EXEC_LOGGER.warning(f'CalcJobNode<{node.pk}> already has a `{link_label}` output: skipping upload')
-        return calc_info
+        return None
 
-    computer = node.computer
+    # The cast is safe: a `CalcJobNode` that is being uploaded always has a computer attached.
+    computer = cast(Computer, node.computer)
 
     codes_info = calc_info.codes_info
-    input_codes = [load_node(_.code_uuid, sub_classes=(Code,)) for _ in codes_info]
+    input_codes = [load_node(_.code_uuid, sub_classes=(Code,)) for _ in codes_info] if codes_info else []
 
     logger_extra = get_dblogger_extra(node)
     transport.set_logger_extra(logger_extra)
@@ -182,7 +183,7 @@ def upload_calculation(
             # Since the content of the node could potentially be binary, we read the raw bytes and pass them on
             for filename in filenames:
                 with NamedTemporaryFile(mode='wb+') as handle:
-                    content = code.base.repository.get_object_content(Path(root) / filename, mode='rb')
+                    content = code.base.repository.get_object_content(root / filename, mode='rb')
                     handle.write(content)
                     handle.flush()
                     transport.put(handle.name, str(workdir.joinpath(root, filename)))
@@ -222,7 +223,7 @@ def upload_calculation(
     if dry_run:
         if remote_copy_list:
             filepath = os.path.join(str(workdir), '_aiida_remote_copy_list.txt')
-            with open(filepath, 'w', encoding='utf-8') as handle:  # type: ignore[assignment]
+            with open(filepath, 'w', encoding='utf-8') as handle:
                 for _, remote_abs_path, dest_rel_path in remote_copy_list:
                     handle.write(
                        f'would have copied {remote_abs_path} to {dest_rel_path} in working '
@@ -231,7 +232,7 @@ def upload_calculation(
 
         if remote_symlink_list:
             filepath = os.path.join(str(workdir), '_aiida_remote_symlink_list.txt')
-            with open(filepath, 'w', encoding='utf-8') as handle:  # type: ignore[assignment]
+            with open(filepath, 'w', encoding='utf-8') as handle:
                 for _, remote_abs_path, dest_rel_path in remote_symlink_list:
                     handle.write(
                         f'would have created symlinks from {remote_abs_path} to {dest_rel_path} in working'
@@ -265,7 +266,7 @@ def upload_calculation(
                 if relpath not in provenance_exclude_list and all(
                     dirname not in provenance_exclude_list for dirname in dirnames
                 ):
-                    with open(filepath, 'rb') as handle:  # type: ignore[assignment]
+                    with open(filepath, 'rb') as handle:
                         node.base.repository._repository.put_object_from_filelike(handle, relpath)
 
     # Since the node is already stored, we cannot use the normal repository interface since it will raise a
@@ -333,14 +334,15 @@ def _copy_local_files(logger, node, transport, inputs, local_copy_list, workdir:
     for uuid, filename, target in local_copy_list:
         logger.debug(f'[submission of calculation {node.uuid}] copying local file/folder to {target}')
 
+        data_node = None
         try:
             data_node = load_node(uuid=uuid)
         except exceptions.NotExistent:
             data_node = _find_data_node(inputs, uuid) if inputs else None
-
-        if data_node is None:
-            logger.warning(f'failed to load Node<{uuid}> specified in the `local_copy_list`')
-            continue
+        finally:
+            if data_node is None:
+                logger.warning(f'failed to load Node<{uuid}> specified in the `local_copy_list`')
+                continue
 
         # The transport class can only copy files directly from the file system, so the files in the source node's repo
         # have to first be copied to a temporary directory on disk.
@@ -410,12 +412,19 @@ def submit_calculation(calculation: CalcJobNode, transport: Transport) -> str |
     if job_id is not None:
         return job_id
 
-    scheduler = calculation.computer.get_scheduler()
+    computer = cast(Computer, calculation.computer)
+    scheduler = computer.get_scheduler()
     scheduler.set_transport(transport)
 
-    submit_script_filename = calculation.get_option('submit_script_filename')
+    # Corresponds to `metadata.options.submit_script_filename` of the `CalcJob` inputs.
+    submit_script_filename: str = cast(str, calculation.get_option('submit_script_filename'))
     workdir = calculation.get_remote_workdir()
-    result = scheduler.submit_job(workdir, submit_script_filename)
+    if workdir is not None:
+        result = scheduler.submit_job(workdir, submit_script_filename)
+    else:
+        # FIXME: an unset `remote_workdir` deserves a dedicated exit code.
+        # Return a generic `ExitCode` for now, since that is something the user can act on.
+        return ExitCode(-1)
 
     if isinstance(result, str):
         calculation.set_job_id(result)
@@ -423,7 +432,7 @@ def submit_calculation(calculation: CalcJobNode, transport: Transport) -> str |
     return result
 
 
-def stash_calculation(calculation: CalcJobNode, transport: Transport) -> None:
+def stash_calculation(calculation: CalcJobNode, transport: Transport) -> None | ExitCode:
     """Stash files from the working directory of a completed calculation to a permanent remote folder.
 
     After a calculation has been completed, optionally stash files from the work directory to a storage location on the
@@ -439,22 +448,28 @@ def stash_calculation(calculation: CalcJobNode, transport: Transport) -> None:
     logger_extra = get_dblogger_extra(calculation)
 
-    stash_options = calculation.get_option('stash')
+    stash_options = cast(dict[str, Any], calculation.get_option('stash'))
     stash_mode = stash_options.get('mode', StashMode.COPY.value)
     source_list = stash_options.get('source_list', [])
 
     if not source_list:
-        return
+        return None
 
     if stash_mode != StashMode.COPY.value:
         EXEC_LOGGER.warning(f'stashing mode {stash_mode} is not implemented yet.')
-        return
+        return None
 
     cls = RemoteStashFolderData
 
     EXEC_LOGGER.debug(f'stashing files for calculation<{calculation.pk}>: {source_list}', extra=logger_extra)
 
     uuid = calculation.uuid
-    source_basepath = Path(calculation.get_remote_workdir())
+    workdir = calculation.get_remote_workdir()
+    if workdir is not None:
+        source_basepath = Path(workdir)
+    else:
+        # FIXME: an unset `remote_workdir` deserves a dedicated exit code.
+        # Return a generic `ExitCode` for now, since that is something the user can act on.
+        return ExitCode(-1)
     target_basepath = Path(stash_options['target_base']) / uuid[:2] / uuid[2:4] / uuid[4:]
 
     for source_filename in source_list:
@@ -487,6 +502,8 @@ def stash_calculation(calculation: CalcJobNode, transport: Transport) -> None:
     ).store()
     remote_stash.base.links.add_incoming(calculation, link_type=LinkType.CREATE, link_label='remote_stash')
 
+    return None
+
 
 def retrieve_calculation(
     calculation: CalcJobNode, transport: Transport, retrieved_temporary_folder: str
@@ -518,7 +535,7 @@ def retrieve_calculation(
         EXEC_LOGGER.warning(
             f'CalcJobNode<{calculation.pk}> already has a `{link_label}` output folder: skipping retrieval'
         )
-        return
+        return None
 
     # Create the FolderData node into which to store the files that are to be retrieved
     retrieved_files = FolderData()
@@ -567,7 +584,8 @@ def kill_calculation(calculation: CalcJobNode, transport: Transport) -> None:
         return
 
     # Get the scheduler plugin class and initialize it with the correct transport
-    scheduler = calculation.computer.get_scheduler()
+    computer = cast(Computer, calculation.computer)
+    scheduler = computer.get_scheduler()
     scheduler.set_transport(transport)
 
     # Call the proper kill method for the job ID of this calculation
@@ -576,7 +594,7 @@ def kill_calculation(calculation: CalcJobNode, transport: Transport) -> None:
     if result is not True:
         # Failed to kill because the job might have already been completed
         running_jobs = scheduler.get_jobs(jobs=[job_id], as_dict=True)
-        job = running_jobs.get(job_id, None)
+        job = running_jobs.get(job_id, None)  # type: ignore[union-attr]
 
         # If the job is returned it is still running and the kill really failed, so we raise
         if job is not None and job.job_state != JobState.DONE:
@@ -591,7 +609,7 @@ def retrieve_files_from_list(
     calculation: CalcJobNode,
     transport: Transport,
     folder: str,
-    retrieve_list: List[Union[str, Tuple[str, str, int], list]],
+    retrieve_list: Sequence[str | tuple[str, str, int | None]] | None,
 ) -> None:
     """Retrieve all the files in the retrieve_list from the remote into the local folder instance through the transport.
     The entries in the retrieve_list
@@ -621,7 +639,7 @@ def retrieve_files_from_list(
             tmp_rname, tmp_lname, depth = item
             # if there are more than one file I do something differently
             if transport.has_magic(tmp_rname):
-                remote_names = transport.glob(str(workdir.joinpath(tmp_rname)))
+                remote_names = transport.glob(str(workdir / tmp_rname))
                 local_names = []
                 for rem in remote_names:
                     # get the relative path so to make local_names relative
@@ -633,6 +651,7 @@ def retrieve_files_from_list(
                     local_names.append(os.path.sep.join([tmp_lname] + to_append))
             else:
                 remote_names = [tmp_rname]
+                # FIXME: this will raise a TypeError if depth is None
                 to_append = tmp_rname.split(os.path.sep)[-depth:] if depth > 0 else []
                 local_names = [os.path.sep.join([tmp_lname] + to_append)]
             if depth is None or depth > 1:  # create directories in the folder, if needed
@@ -641,7 +660,7 @@ def retrieve_files_from_list(
                 if not os.path.exists(new_folder):
                     os.makedirs(new_folder)
         else:
-            abs_item = item if item.startswith('/') else str(workdir.joinpath(item))
+            abs_item = item if item.startswith('/') else str(workdir / item)
 
             if transport.has_magic(abs_item):
                 remote_names = transport.glob(abs_item)
diff --git a/src/aiida/orm/nodes/process/calculation/calcjob.py b/src/aiida/orm/nodes/process/calculation/calcjob.py
index a7cd20c88e..e1f7ff7352 100644
--- a/src/aiida/orm/nodes/process/calculation/calcjob.py
+++ b/src/aiida/orm/nodes/process/calculation/calcjob.py
@@ -312,7 +312,7 @@ def set_retrieve_list(self, retrieve_list: Sequence[Union[str, Tuple[str, str, s
         self._validate_retrieval_directive(retrieve_list)
         self.base.attributes.set(self.RETRIEVE_LIST_KEY, retrieve_list)
 
-    def get_retrieve_list(self) -> Optional[Sequence[Union[str, Tuple[str, str, str]]]]:
+    def get_retrieve_list(self) -> Sequence[str | tuple[str, str, int | None]] | None:
         """Return the list of files/directories to be retrieved on the cluster after the calculation has completed.
 
         :return: a list of file directives
@@ -330,7 +330,7 @@ def set_retrieve_temporary_list(self, retrieve_temporary_list: Sequence[Union[st
         self._validate_retrieval_directive(retrieve_temporary_list)
         self.base.attributes.set(self.RETRIEVE_TEMPORARY_LIST_KEY, retrieve_temporary_list)
 
-    def get_retrieve_temporary_list(self) -> Optional[Sequence[Union[str, Tuple[str, str, str]]]]:
+    def get_retrieve_temporary_list(self) -> Sequence[str | tuple[str, str, int | None]] | None:
         """Return list of files to be retrieved from the cluster which will be available during parsing.
 
         :return: a list of file directives
diff --git a/src/aiida/orm/nodes/repository.py b/src/aiida/orm/nodes/repository.py
index bc24fe1377..f6f5e2378a 100644
--- a/src/aiida/orm/nodes/repository.py
+++ b/src/aiida/orm/nodes/repository.py
@@ -237,12 +237,12 @@ def get_object(self, path: FilePath | None = None) -> File:
         return self._repository.get_object(path)
 
     @t.overload
-    def get_object_content(self, path: str, mode: t.Literal['r']) -> str: ...
+    def get_object_content(self, path: FilePath, mode: t.Literal['r']) -> str: ...
 
     @t.overload
-    def get_object_content(self, path: str, mode: t.Literal['rb']) -> bytes: ...
+    def get_object_content(self, path: FilePath, mode: t.Literal['rb']) -> bytes: ...
 
-    def get_object_content(self, path: str, mode: t.Literal['r', 'rb'] = 'r') -> str | bytes:
+    def get_object_content(self, path: FilePath, mode: t.Literal['r', 'rb'] = 'r') -> str | bytes:
         """Return the content of a object identified by key.
 
         :param path: the relative path of the object within the repository.