Skip to content

Commit

Permalink
Typing: Add overload signatures for get_object_content
Browse files Browse the repository at this point in the history
Added for the `FolderData` and `NodeRepository` classes.
  • Loading branch information
sphuber committed Nov 27, 2023
1 parent 33dffb0 commit d18eedc
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 20 deletions.
4 changes: 2 additions & 2 deletions aiida/engine/processes/calcjobs/calcjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,8 +757,8 @@ def parse_scheduler_output(self, retrieved: orm.Node) -> Optional[ExitCode]:
try:
exit_code = scheduler.parse_output(
detailed_job_info,
scheduler_stdout or '', # type: ignore[arg-type]
scheduler_stderr or '', # type: ignore[arg-type]
scheduler_stdout or '',
scheduler_stderr or '',
)
except exceptions.FeatureNotAvailable:
self.logger.info(f'`{scheduler.__class__.__name__}` does not implement scheduler output parsing')
Expand Down
26 changes: 17 additions & 9 deletions aiida/orm/nodes/data/folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@
import contextlib
import io
import pathlib
from typing import BinaryIO, Iterable, Iterator, Optional, TextIO, Union
import typing as t

from aiida.repository import File

from .data import Data

__all__ = ('FolderData',)

FilePath = Union[str, pathlib.PurePosixPath]
FilePath = t.Union[str, pathlib.PurePosixPath]


class FolderData(Data):
Expand Down Expand Up @@ -72,7 +72,7 @@ def list_object_names(self, path: str | None = None) -> list[str]:
return self.base.repository.list_object_names(path)

@contextlib.contextmanager
def open(self, path: str, mode='r') -> Iterator[BinaryIO | TextIO]:
def open(self, path: str, mode='r') -> t.Iterator[t.BinaryIO | t.TextIO]:
"""Open a file handle to an object stored under the given key.
.. note:: this should only be used to open a handle to read an existing file. To write a new file use the method
Expand All @@ -89,7 +89,7 @@ def open(self, path: str, mode='r') -> Iterator[BinaryIO | TextIO]:
yield handle

@contextlib.contextmanager
def as_path(self, path: FilePath | None = None) -> Iterator[pathlib.Path]:
def as_path(self, path: FilePath | None = None) -> t.Iterator[pathlib.Path]:
"""Make the contents of the repository available as a normal filepath on the local file system.
:param path: optional relative path of the object within the repository.
Expand All @@ -110,7 +110,15 @@ def get_object(self, path: FilePath | None = None) -> File:
"""
return self.base.repository.get_object(path)

def get_object_content(self, path: str, mode='r') -> str | bytes:
@t.overload
def get_object_content(self, path: str, mode: t.Literal['r']) -> str:
...

@t.overload
def get_object_content(self, path: str, mode: t.Literal['rb']) -> bytes:
...

def get_object_content(self, path: str, mode: t.Literal['r', 'rb'] = 'r') -> str | bytes:
"""Return the content of a object identified by key.
:param path: the relative path of the object within the repository.
Expand Down Expand Up @@ -151,7 +159,7 @@ def put_object_from_file(self, filepath: str, path: str) -> None:
"""
return self.base.repository.put_object_from_file(filepath, path)

def put_object_from_tree(self, filepath: str, path: Optional[str] = None) -> None:
def put_object_from_tree(self, filepath: str, path: str | None = None) -> None:
"""Store the entire contents of `filepath` on the local file system in the repository with under given `path`.
:param filepath: absolute path of the directory whose contents to copy to the repository.
Expand All @@ -161,7 +169,7 @@ def put_object_from_tree(self, filepath: str, path: Optional[str] = None) -> Non
"""
return self.base.repository.put_object_from_tree(filepath, path)

def walk(self, path: Optional[FilePath] = None) -> Iterable[tuple[pathlib.PurePosixPath, list[str], list[str]]]:
def walk(self, path: FilePath | None = None) -> t.Iterable[tuple[pathlib.PurePosixPath, list[str], list[str]]]:
"""Walk over the directories and files contained within this repository.
.. note:: the order of the dirname and filename lists that are returned is not necessarily sorted. This is in
Expand All @@ -174,11 +182,11 @@ def walk(self, path: Optional[FilePath] = None) -> Iterable[tuple[pathlib.PurePo
"""
yield from self.base.repository.walk(path)

def glob(self) -> Iterable[pathlib.PurePosixPath]:
def glob(self) -> t.Iterable[pathlib.PurePosixPath]:
"""Yield a recursive list of all paths (files and directories)."""
yield from self.base.repository.glob()

def copy_tree(self, target: str | pathlib.Path, path: Optional[FilePath] = None) -> None:
def copy_tree(self, target: str | pathlib.Path, path: FilePath | None = None) -> None:
"""Copy the contents of the entire node repository to another location on the local file system.
:param target: absolute path of the directory where to copy the contents to.
Expand Down
26 changes: 17 additions & 9 deletions aiida/orm/nodes/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,19 @@
import pathlib
import shutil
import tempfile
from typing import TYPE_CHECKING, Any, BinaryIO, Iterable, Iterator, TextIO, Union
import typing as t

from aiida.common import exceptions
from aiida.manage import get_config_option
from aiida.repository import File, Repository
from aiida.repository.backend import SandboxRepositoryBackend

if TYPE_CHECKING:
if t.TYPE_CHECKING:
from .node import Node

__all__ = ('NodeRepository',)

FilePath = Union[str, pathlib.PurePosixPath]
FilePath = t.Union[str, pathlib.PurePosixPath]


class NodeRepository:
Expand All @@ -47,7 +47,7 @@ def __init__(self, node: 'Node') -> None:
self._repository_instance: Repository | None = None

@property
def metadata(self) -> dict[str, Any]:
def metadata(self) -> dict[str, t.Any]:
"""Return the repository metadata, representing the virtual file hierarchy.
Note, this is only accurate if the node is stored.
Expand Down Expand Up @@ -165,7 +165,7 @@ def list_object_names(self, path: str | None = None) -> list[str]:
return self._repository.list_object_names(path)

@contextlib.contextmanager
def open(self, path: FilePath, mode='r') -> Iterator[BinaryIO | TextIO]:
def open(self, path: FilePath, mode='r') -> t.Iterator[t.BinaryIO | t.TextIO]:
"""Open a file handle to an object stored under the given key.
.. note:: this should only be used to open a handle to read an existing file. To write a new file use the method
Expand All @@ -188,7 +188,7 @@ def open(self, path: FilePath, mode='r') -> Iterator[BinaryIO | TextIO]:
yield handle

@contextlib.contextmanager
def as_path(self, path: FilePath | None = None) -> Iterator[pathlib.Path]:
def as_path(self, path: FilePath | None = None) -> t.Iterator[pathlib.Path]:
"""Make the contents of the repository available as a normal filepath on the local file system.
:param path: optional relative path of the object within the repository.
Expand Down Expand Up @@ -223,7 +223,15 @@ def get_object(self, path: FilePath | None = None) -> File:
"""
return self._repository.get_object(path)

def get_object_content(self, path: str, mode='r') -> str | bytes:
@t.overload
def get_object_content(self, path: str, mode: t.Literal['r']) -> str:
...

@t.overload
def get_object_content(self, path: str, mode: t.Literal['rb']) -> bytes:
...

def get_object_content(self, path: str, mode: t.Literal['r', 'rb'] = 'r') -> str | bytes:
"""Return the content of a object identified by key.
:param path: the relative path of the object within the repository.
Expand Down Expand Up @@ -298,7 +306,7 @@ def put_object_from_tree(self, filepath: str, path: str | None = None):
self._repository.put_object_from_tree(filepath, path)
self._update_repository_metadata()

def walk(self, path: FilePath | None = None) -> Iterable[tuple[pathlib.PurePosixPath, list[str], list[str]]]:
def walk(self, path: FilePath | None = None) -> t.Iterable[tuple[pathlib.PurePosixPath, list[str], list[str]]]:
"""Walk over the directories and files contained within this repository.
.. note:: the order of the dirname and filename lists that are returned is not necessarily sorted. This is in
Expand All @@ -311,7 +319,7 @@ def walk(self, path: FilePath | None = None) -> Iterable[tuple[pathlib.PurePosix
"""
yield from self._repository.walk(path)

def glob(self) -> Iterable[pathlib.PurePosixPath]:
def glob(self) -> t.Iterable[pathlib.PurePosixPath]:
"""Yield a recursive list of all paths (files and directories)."""
for dirpath, dirnames, filenames in self.walk():
for dirname in dirnames:
Expand Down

0 comments on commit d18eedc

Please sign in to comment.