Skip to content

Commit

Permalink
Typing: Add overload signatures for open
Browse files Browse the repository at this point in the history
Added for the `FolderData` and `NodeRepository` classes. The signature
of the `SinglefileData` was actually incorrect as it defined:

    t.Iterator[t.BinaryIO | t.TextIO]

as the return type, but which should really be:

    t.Iterator[t.BinaryIO] | t.Iterator[t.TextIO]

The former will cause `mypy` to raise an error.
  • Loading branch information
sphuber committed Nov 27, 2023
1 parent d18eedc commit 0986f6b
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 13 deletions.
12 changes: 11 additions & 1 deletion aiida/orm/nodes/data/folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,18 @@ def list_object_names(self, path: str | None = None) -> list[str]:
"""
return self.base.repository.list_object_names(path)

@t.overload
@contextlib.contextmanager
def open(self, path: FilePath, mode: t.Literal['r']) -> t.Iterator[t.TextIO]:
...

@t.overload
@contextlib.contextmanager
def open(self, path: FilePath, mode: t.Literal['rb']) -> t.Iterator[t.BinaryIO]:
...

@contextlib.contextmanager
def open(self, path: str, mode='r') -> t.Iterator[t.BinaryIO | t.TextIO]:
def open(self, path: FilePath, mode: t.Literal['r', 'rb'] = 'r') -> t.Iterator[t.BinaryIO] | t.Iterator[t.TextIO]:
"""Open a file handle to an object stored under the given key.
.. note:: this should only be used to open a handle to read an existing file. To write a new file use the method
Expand Down
22 changes: 15 additions & 7 deletions aiida/orm/nodes/data/singlefile.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

__all__ = ('SinglefileData',)

FilePath = t.Union[str, pathlib.PurePosixPath]


class SinglefileData(Data):
"""Data class that can be used to store a single file in its repository."""
Expand All @@ -37,7 +39,9 @@ def from_string(cls, content: str, filename: str | pathlib.Path | None = None, *
"""
return cls(io.StringIO(content), filename, **kwargs)

def __init__(self, file: str | t.IO, filename: str | pathlib.Path | None = None, **kwargs: t.Any) -> None:
def __init__(
self, file: str | pathlib.Path | t.IO, filename: str | pathlib.Path | None = None, **kwargs: t.Any
) -> None:
"""Construct a new instance and set the contents to that of the file.
:param file: an absolute filepath or filelike object whose contents to copy.
Expand All @@ -60,26 +64,30 @@ def filename(self) -> str:

@t.overload
@contextlib.contextmanager
def open(self, path: str, mode: t.Literal['r']) -> t.Iterator[t.TextIO]:
def open(self, path: FilePath, mode: t.Literal['r'] = ...) -> t.Iterator[t.TextIO]:
...

@t.overload
@contextlib.contextmanager
def open(self, path: None, mode: t.Literal['r']) -> t.Iterator[t.TextIO]:
def open(self, path: FilePath, mode: t.Literal['rb']) -> t.Iterator[t.BinaryIO]:
...

@t.overload
@contextlib.contextmanager
def open(self, path: str, mode: t.Literal['rb']) -> t.Iterator[t.BinaryIO]:
def open( # type: ignore[overload-overlap]
self, path: None = None, mode: t.Literal['r'] = ...
) -> t.Iterator[t.TextIO]:
...

@t.overload
@contextlib.contextmanager
def open(self, path: None, mode: t.Literal['rb']) -> t.Iterator[t.BinaryIO]:
def open(self, path: None = None, mode: t.Literal['rb'] = ...) -> t.Iterator[t.BinaryIO]:
...

@contextlib.contextmanager
def open(self, path: str | None = None, mode: t.Literal['r', 'rb'] = 'r') -> t.Iterator[t.BinaryIO | t.TextIO]:
def open(self,
path: FilePath | None = None,
mode: t.Literal['r', 'rb'] = 'r') -> t.Iterator[t.BinaryIO] | t.Iterator[t.TextIO]:
"""Return an open file handle to the content of this data node.
:param path: the relative path of the object within the repository.
Expand Down Expand Up @@ -113,7 +121,7 @@ def get_content(self, mode: str = 'r') -> str | bytes:
with self.open(mode=mode) as handle: # type: ignore[call-overload]
return handle.read()

def set_file(self, file: str | t.IO, filename: str | pathlib.Path | None = None) -> None:
def set_file(self, file: str | pathlib.Path | t.IO, filename: str | pathlib.Path | None = None) -> None:
"""Store the content of the file in the node's repository, deleting any other existing objects.
:param file: an absolute filepath or filelike object whose contents to copy
Expand Down
14 changes: 12 additions & 2 deletions aiida/orm/nodes/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,18 @@ def list_object_names(self, path: str | None = None) -> list[str]:
"""
return self._repository.list_object_names(path)

@t.overload
@contextlib.contextmanager
def open(self, path: FilePath, mode: t.Literal['r']) -> t.Iterator[t.TextIO]:
...

@t.overload
@contextlib.contextmanager
def open(self, path: FilePath, mode: t.Literal['rb']) -> t.Iterator[t.BinaryIO]:
...

@contextlib.contextmanager
def open(self, path: FilePath, mode='r') -> t.Iterator[t.BinaryIO | t.TextIO]:
def open(self, path: FilePath, mode: t.Literal['r', 'rb'] = 'r') -> t.Iterator[t.BinaryIO] | t.Iterator[t.TextIO]:
"""Open a file handle to an object stored under the given key.
.. note:: this should only be used to open a handle to read an existing file. To write a new file use the method
Expand Down Expand Up @@ -210,7 +220,7 @@ def as_path(self, path: FilePath | None = None) -> t.Iterator[pathlib.Path]:
assert path is not None
with self.open(path, mode='rb') as source:
with filepath.open('wb') as target:
shutil.copyfileobj(source, target) # type: ignore[misc]
shutil.copyfileobj(source, target)
yield filepath

def get_object(self, path: FilePath | None = None) -> File:
Expand Down
2 changes: 1 addition & 1 deletion aiida/parsers/plugins/arithmetic/add.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
###########################################################################
# Warning: this implementation is used directly in the documentation as a literal-include, which means that if any part
# of this code is changed, the snippets in the file `docs/source/howto/codes.rst` have to be checked for consistency.
# mypy: disable_error_code=arg-type
# mypy: disable_error_code=call-overload
"""Parser for an `ArithmeticAddCalculation` job."""
from aiida.parsers.parser import Parser

Expand Down
5 changes: 3 additions & 2 deletions aiida/parsers/plugins/diff_tutorial/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Register parsers via the "aiida.parsers" entry point in the pyproject.toml file.
"""
# mypy: disable_error_code=call-overload
# START PARSER HEAD
from aiida.engine import ExitCode
from aiida.orm import SinglefileData
Expand Down Expand Up @@ -38,7 +39,7 @@ def parse(self, **kwargs):

# add output file
self.logger.info(f"Parsing '{output_filename}'")
with self.retrieved.open(output_filename, 'rb') as handle: # type: ignore[arg-type]
with self.retrieved.open(output_filename, 'rb') as handle:
output_node = SinglefileData(file=handle)
self.out('diff', output_node)

Expand All @@ -59,7 +60,7 @@ def parse(self, **kwargs):

# add output file
self.logger.info(f"Parsing '{output_filename}'")
with self.retrieved.open(output_filename, 'rb') as handle: # type: ignore[arg-type]
with self.retrieved.open(output_filename, 'rb') as handle:
output_node = SinglefileData(file=handle)
self.out('diff', output_node)

Expand Down
1 change: 1 addition & 0 deletions docs/source/nitpick-exceptions
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ py:class BinaryIO
py:class EntryPoint
py:class EntryPoints
py:class IO
py:class FilePath
py:class Path
py:class str | list[str]
py:class str | Path
Expand Down

0 comments on commit 0986f6b

Please sign in to comment.