Skip to content

Commit

Permalink
Fix mypy daemon/execmanager.py
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Nov 25, 2024
1 parent bef39ec commit 5d95d38
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 33 deletions.
2 changes: 1 addition & 1 deletion src/aiida/common/datastructures.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ class CalcInfo(DefaultFieldsAttributeDict):
email: None | str
email_on_started: bool
email_on_terminated: bool
uuid: None | str
uuid: str
prepend_text: None | str
append_text: None | str
num_machines: None | int
Expand Down
73 changes: 46 additions & 27 deletions src/aiida/engine/daemon/execmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from logging import LoggerAdapter
from pathlib import Path
from tempfile import NamedTemporaryFile, TemporaryDirectory
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Optional, Sequence, cast
from typing import Mapping as MappingType

from aiida.common import AIIDA_LOGGER, exceptions
Expand All @@ -29,7 +29,7 @@
from aiida.common.links import LinkType
from aiida.engine.processes.exit_code import ExitCode
from aiida.manage.configuration import get_config_option
from aiida.orm import CalcJobNode, Code, FolderData, Node, PortableCode, RemoteData, load_node
from aiida.orm import CalcJobNode, Code, Computer, FolderData, Node, PortableCode, RemoteData, load_node
from aiida.orm.utils.log import get_dblogger_extra
from aiida.repository.common import FileType
from aiida.schedulers.datastructures import JobState
Expand Down Expand Up @@ -84,12 +84,13 @@ def upload_calculation(
link_label = 'remote_folder'
if node.base.links.get_outgoing(RemoteData, link_label_filter=link_label).first():
EXEC_LOGGER.warning(f'CalcJobNode<{node.pk}> already has a `{link_label}` output: skipping upload')
return calc_info
return None

computer = node.computer
# Cast is safe: a stored CalcJobNode is guaranteed to have a computer attached.
computer = cast(Computer, node.computer)

codes_info = calc_info.codes_info
input_codes = [load_node(_.code_uuid, sub_classes=(Code,)) for _ in codes_info]
input_codes = [load_node(_.code_uuid, sub_classes=(Code,)) for _ in codes_info] if codes_info else []

logger_extra = get_dblogger_extra(node)
transport.set_logger_extra(logger_extra)
Expand Down Expand Up @@ -182,7 +183,7 @@ def upload_calculation(
# Since the content of the node could potentially be binary, we read the raw bytes and pass them on
for filename in filenames:
with NamedTemporaryFile(mode='wb+') as handle:
content = code.base.repository.get_object_content(Path(root) / filename, mode='rb')
content = code.base.repository.get_object_content(root / filename, mode='rb')
handle.write(content)
handle.flush()
transport.put(handle.name, str(workdir.joinpath(root, filename)))
Expand Down Expand Up @@ -222,7 +223,7 @@ def upload_calculation(
if dry_run:
if remote_copy_list:
filepath = os.path.join(str(workdir), '_aiida_remote_copy_list.txt')
with open(filepath, 'w', encoding='utf-8') as handle: # type: ignore[assignment]
with open(filepath, 'w', encoding='utf-8') as handle:
for _, remote_abs_path, dest_rel_path in remote_copy_list:
handle.write(
f'would have copied {remote_abs_path} to {dest_rel_path} in working '
Expand All @@ -231,7 +232,7 @@ def upload_calculation(

if remote_symlink_list:
filepath = os.path.join(str(workdir), '_aiida_remote_symlink_list.txt')
with open(filepath, 'w', encoding='utf-8') as handle: # type: ignore[assignment]
with open(filepath, 'w', encoding='utf-8') as handle:
for _, remote_abs_path, dest_rel_path in remote_symlink_list:
handle.write(
f'would have created symlinks from {remote_abs_path} to {dest_rel_path} in working'
Expand Down Expand Up @@ -265,7 +266,7 @@ def upload_calculation(
if relpath not in provenance_exclude_list and all(
dirname not in provenance_exclude_list for dirname in dirnames
):
with open(filepath, 'rb') as handle: # type: ignore[assignment]
with open(filepath, 'rb') as handle:
node.base.repository._repository.put_object_from_filelike(handle, relpath)

# Since the node is already stored, we cannot use the normal repository interface since it will raise a
Expand Down Expand Up @@ -333,14 +334,15 @@ def _copy_local_files(logger, node, transport, inputs, local_copy_list, workdir:
for uuid, filename, target in local_copy_list:
logger.debug(f'[submission of calculation {node.uuid}] copying local file/folder to {target}')

data_node = None
try:
data_node = load_node(uuid=uuid)
except exceptions.NotExistent:
data_node = _find_data_node(inputs, uuid) if inputs else None

if data_node is None:
logger.warning(f'failed to load Node<{uuid}> specified in the `local_copy_list`')
continue
finally:
if data_node is None:
logger.warning(f'failed to load Node<{uuid}> specified in the `local_copy_list`')
continue

# The transport class can only copy files directly from the file system, so the files in the source node's repo
# have to first be copied to a temporary directory on disk.
Expand Down Expand Up @@ -410,20 +412,27 @@ def submit_calculation(calculation: CalcJobNode, transport: Transport) -> str |
if job_id is not None:
return job_id

scheduler = calculation.computer.get_scheduler()
computer = cast(Computer, calculation.computer)
scheduler = computer.get_scheduler()
scheduler.set_transport(transport)

submit_script_filename = calculation.get_option('submit_script_filename')
# Taken from the ``metadata.options.submit_script_filename`` input of the CalcJob.
submit_script_filename: str = cast(str, calculation.get_option('submit_script_filename'))
workdir = calculation.get_remote_workdir()
result = scheduler.submit_job(workdir, submit_script_filename)
if workdir is not None:
result = scheduler.submit_job(workdir, submit_script_filename)
else:
# FIXME: a dedicated exit code is needed for the case where the calculation's
# ``remote_workdir`` is not set. Return an ``ExitCode`` since this is a condition the user can fix.
return ExitCode(-1)

if isinstance(result, str):
calculation.set_job_id(result)

return result


def stash_calculation(calculation: CalcJobNode, transport: Transport) -> None:
def stash_calculation(calculation: CalcJobNode, transport: Transport) -> None | ExitCode:
"""Stash files from the working directory of a completed calculation to a permanent remote folder.
After a calculation has been completed, optionally stash files from the work directory to a storage location on the
Expand All @@ -439,23 +448,29 @@ def stash_calculation(calculation: CalcJobNode, transport: Transport) -> None:

logger_extra = get_dblogger_extra(calculation)

stash_options = calculation.get_option('stash')
stash_options = cast(dict[str, Any], calculation.get_option('stash'))
stash_mode = stash_options.get('mode', StashMode.COPY.value)
source_list = stash_options.get('source_list', [])

if not source_list:
return
return None

if stash_mode != StashMode.COPY.value:
EXEC_LOGGER.warning(f'stashing mode {stash_mode} is not implemented yet.')
return
return None

cls = RemoteStashFolderData

EXEC_LOGGER.debug(f'stashing files for calculation<{calculation.pk}>: {source_list}', extra=logger_extra)

uuid = calculation.uuid
source_basepath = Path(calculation.get_remote_workdir())
workdir = calculation.get_remote_workdir()
if workdir is not None:
source_basepath = Path(workdir)
else:
# FIXME: a dedicated exit code is needed for the case where the calculation's
# ``remote_workdir`` is not set. Return an ``ExitCode`` since this is a condition the user can fix.
return ExitCode(-1)
target_basepath = Path(stash_options['target_base']) / uuid[:2] / uuid[2:4] / uuid[4:]

for source_filename in source_list:
Expand Down Expand Up @@ -487,6 +502,8 @@ def stash_calculation(calculation: CalcJobNode, transport: Transport) -> None:
).store()
remote_stash.base.links.add_incoming(calculation, link_type=LinkType.CREATE, link_label='remote_stash')

return None


def retrieve_calculation(
calculation: CalcJobNode, transport: Transport, retrieved_temporary_folder: str
Expand Down Expand Up @@ -518,7 +535,7 @@ def retrieve_calculation(
EXEC_LOGGER.warning(
f'CalcJobNode<{calculation.pk}> already has a `{link_label}` output folder: skipping retrieval'
)
return
return None

# Create the FolderData node into which to store the files that are to be retrieved
retrieved_files = FolderData()
Expand Down Expand Up @@ -567,7 +584,8 @@ def kill_calculation(calculation: CalcJobNode, transport: Transport) -> None:
return

# Get the scheduler plugin class and initialize it with the correct transport
scheduler = calculation.computer.get_scheduler()
computer = cast(Computer, calculation.computer)
scheduler = computer.get_scheduler()
scheduler.set_transport(transport)

# Call the proper kill method for the job ID of this calculation
Expand All @@ -576,7 +594,7 @@ def kill_calculation(calculation: CalcJobNode, transport: Transport) -> None:
if result is not True:
# Failed to kill because the job might have already been completed
running_jobs = scheduler.get_jobs(jobs=[job_id], as_dict=True)
job = running_jobs.get(job_id, None)
job = running_jobs.get(job_id, None) # type: ignore[union-attr]

# If the job is returned it is still running and the kill really failed, so we raise
if job is not None and job.job_state != JobState.DONE:
Expand All @@ -591,7 +609,7 @@ def retrieve_files_from_list(
calculation: CalcJobNode,
transport: Transport,
folder: str,
retrieve_list: List[Union[str, Tuple[str, str, int], list]],
retrieve_list: Sequence[str | tuple[str, str, int | None]] | None,
) -> None:
"""Retrieve all the files in the retrieve_list from the remote into the
local folder instance through the transport. The entries in the retrieve_list
Expand Down Expand Up @@ -621,7 +639,7 @@ def retrieve_files_from_list(
tmp_rname, tmp_lname, depth = item
# if there are more than one file I do something differently
if transport.has_magic(tmp_rname):
remote_names = transport.glob(str(workdir.joinpath(tmp_rname)))
remote_names = transport.glob(str(workdir / tmp_rname))
local_names = []
for rem in remote_names:
# get the relative path so to make local_names relative
Expand All @@ -633,6 +651,7 @@ def retrieve_files_from_list(
local_names.append(os.path.sep.join([tmp_lname] + to_append))
else:
remote_names = [tmp_rname]
# FIXME: this will raise a TypeError if ``depth`` is None
to_append = tmp_rname.split(os.path.sep)[-depth:] if depth > 0 else []
local_names = [os.path.sep.join([tmp_lname] + to_append)]
if depth is None or depth > 1: # create directories in the folder, if needed
Expand All @@ -641,7 +660,7 @@ def retrieve_files_from_list(
if not os.path.exists(new_folder):
os.makedirs(new_folder)
else:
abs_item = item if item.startswith('/') else str(workdir.joinpath(item))
abs_item = item if item.startswith('/') else str(workdir / item)

if transport.has_magic(abs_item):
remote_names = transport.glob(abs_item)
Expand Down
4 changes: 2 additions & 2 deletions src/aiida/orm/nodes/process/calculation/calcjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def set_retrieve_list(self, retrieve_list: Sequence[Union[str, Tuple[str, str, s
self._validate_retrieval_directive(retrieve_list)
self.base.attributes.set(self.RETRIEVE_LIST_KEY, retrieve_list)

def get_retrieve_list(self) -> Optional[Sequence[Union[str, Tuple[str, str, str]]]]:
def get_retrieve_list(self) -> Sequence[str | tuple[str, str, int | None]] | None:
"""Return the list of files/directories to be retrieved on the cluster after the calculation has completed.
:return: a list of file directives
Expand All @@ -330,7 +330,7 @@ def set_retrieve_temporary_list(self, retrieve_temporary_list: Sequence[Union[st
self._validate_retrieval_directive(retrieve_temporary_list)
self.base.attributes.set(self.RETRIEVE_TEMPORARY_LIST_KEY, retrieve_temporary_list)

def get_retrieve_temporary_list(self) -> Optional[Sequence[Union[str, Tuple[str, str, str]]]]:
def get_retrieve_temporary_list(self) -> Sequence[str | tuple[str, str, int | None]] | None:
"""Return list of files to be retrieved from the cluster which will be available during parsing.
:return: a list of file directives
Expand Down
6 changes: 3 additions & 3 deletions src/aiida/orm/nodes/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,12 +237,12 @@ def get_object(self, path: FilePath | None = None) -> File:
return self._repository.get_object(path)

@t.overload
def get_object_content(self, path: str, mode: t.Literal['r']) -> str: ...
def get_object_content(self, path: FilePath, mode: t.Literal['r']) -> str: ...

@t.overload
def get_object_content(self, path: str, mode: t.Literal['rb']) -> bytes: ...
def get_object_content(self, path: FilePath, mode: t.Literal['rb']) -> bytes: ...

def get_object_content(self, path: str, mode: t.Literal['r', 'rb'] = 'r') -> str | bytes:
def get_object_content(self, path: FilePath, mode: t.Literal['r', 'rb'] = 'r') -> str | bytes:
"""Return the content of a object identified by key.
:param path: the relative path of the object within the repository.
Expand Down

0 comments on commit 5d95d38

Please sign in to comment.