From ed8d4c2c9008bdcdf674865170e867b7373cc282 Mon Sep 17 00:00:00 2001 From: Hofer-Julian <30049909+Hofer-Julian@users.noreply.github.com> Date: Mon, 28 Mar 2022 14:34:07 +0200 Subject: [PATCH] Add typing annotations (#72) * Add typing annotations * Update CI * Pymake issue is not yet fixed in the latest release * Sort imports * Use typing syntax that is compatible with <3.10 * Extend typing * CI: Ignore missing imports * Fix types * Add type annotation * Ignore last type errors * Install types before running mypy * Finish up CI * Export types * Add more annotations * Add more types --- .github/workflows/CI.yml | 17 +++--- setup.py | 4 +- xmipy/__init__.py | 6 +- xmipy/py.typed | 0 xmipy/timers/timer.py | 4 +- xmipy/utils.py | 7 ++- xmipy/xmi.py | 8 +-- xmipy/xmiwrapper.py | 119 ++++++++++++++++++--------------------- 8 files changed, 79 insertions(+), 86 deletions(-) create mode 100644 xmipy/py.typed diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index a8161e6..676c644 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -30,6 +30,10 @@ jobs: run: | sudo ln -fs /usr/bin/gfortran-10 /usr/local/bin/gfortran gfortran --version + # Remove as soon as https://github.com/modflowpy/pymake/issues/111 is fixed + - name: Install numpy + run: | + pip install numpy - name: Install and print system dependencies (macOS) shell: bash if: runner.os == 'macOS' @@ -40,10 +44,6 @@ jobs: if: runner.os == 'Windows' run: | gfortran --version - # Remove as soon as https://github.com/modflowpy/pymake/issues/111 is fixed - - name: Install numpy - run: | - pip install numpy - name: Install test dependencies run: | pip install -e ".[tests]" @@ -59,10 +59,8 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - - name: Set up Python 3.8 + - name: Set up Python uses: actions/setup-python@v1 - with: - python-version: 3.8 - name: Install lint dependencies run: | pip install -e ".[lint]" @@ -70,5 +68,6 @@ jobs: run: black --check . - name: Run isort run: isort --check . - - name: Run flake8 - run: flake8 + - name: Run mypy + run: | + mypy --install-types --non-interactive --ignore-missing-imports . diff --git a/setup.py b/setup.py index a536516..a154b17 100755 --- a/setup.py +++ b/setup.py @@ -46,13 +46,15 @@ def get_version(rel_path): extras_require={ "tests": ["pytest", "pytest-cov", "requests", "mfpymake", "flopy"], "lint": [ - "flake8", + "mypy", "black", "isort", ], }, python_requires=">=3.7", packages=find_namespace_packages(exclude=("tests", "examples")), + package_data={"xmipy": ["py.typed"]}, version=get_version("xmipy/__init__.py"), classifiers=["Topic :: Scientific/Engineering :: Hydrology"], + zip_safe=False, ) diff --git a/xmipy/__init__.py b/xmipy/__init__.py index 17bd166..be190c0 100644 --- a/xmipy/__init__.py +++ b/xmipy/__init__.py @@ -1,5 +1,5 @@ # imports -from xmipy.xmi import Xmi -from xmipy.xmiwrapper import XmiWrapper +from xmipy.xmi import Xmi as Xmi +from xmipy.xmiwrapper import XmiWrapper as XmiWrapper -__version__ = "1.0.0" +__version__ = "1.1" diff --git a/xmipy/py.typed b/xmipy/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/xmipy/timers/timer.py b/xmipy/timers/timer.py index 2f5c75f..c5e1a15 100644 --- a/xmipy/timers/timer.py +++ b/xmipy/timers/timer.py @@ -3,8 +3,6 @@ Adapted from https://pypi.org/project/codetiming/. """ -# Standard library imports -import functools import logging import math import time @@ -21,7 +19,7 @@ class Timer: def __init__(self, name: str, text: str): self.name = name self.timers = Timers() - self._start_time = {} + self._start_time: dict[str, float] = {} self.text = text self.last = math.nan diff --git a/xmipy/utils.py b/xmipy/utils.py index 5575477..bebc7e4 100644 --- a/xmipy/utils.py +++ b/xmipy/utils.py @@ -1,11 +1,12 @@ import os from contextlib import contextmanager +from pathlib import Path @contextmanager -def cd(newdir): - prevdir = os.getcwd() - os.chdir(os.path.expanduser(newdir)) +def cd(newdir: Path): + prevdir = Path().cwd() + os.chdir(newdir) try: yield finally: diff --git a/xmipy/xmi.py b/xmipy/xmi.py index 4f72111..ac8b00c 100644 --- a/xmipy/xmi.py +++ b/xmipy/xmi.py @@ -21,7 +21,7 @@ class Xmi(Bmi): """ @abstractmethod - def prepare_time_step(self, dt) -> None: + def prepare_time_step(self, dt: float) -> None: """ """ ... @@ -41,16 +41,16 @@ def get_subcomponent_count(self) -> int: ... @abstractmethod - def prepare_solve(self, component_id) -> None: + def prepare_solve(self, component_id: int) -> None: """ """ ... @abstractmethod - def solve(self, component_id) -> bool: + def solve(self, component_id: int) -> bool: """ """ ... @abstractmethod - def finalize_solve(self, component_id) -> None: + def finalize_solve(self, component_id: int) -> None: """ """ ... diff --git a/xmipy/xmiwrapper.py b/xmipy/xmiwrapper.py index 4ec9c75..3bd160d 100644 --- a/xmipy/xmiwrapper.py +++ b/xmipy/xmiwrapper.py @@ -13,16 +13,12 @@ cdll, create_string_buffer, ) - -FreeLibrary = None -try: - import _ctypes.FreeLibrary -except Exception: - pass from enum import Enum, IntEnum, unique -from typing import Iterable, Tuple +from pathlib import Path +from typing import Any, Callable, Union import numpy as np +from numpy.typing import NDArray from xmipy.errors import InputError, TimerError, XMIError from xmipy.timers.timer import Timer @@ -52,25 +48,29 @@ class XmiWrapper(Xmi): def __init__( self, - lib_path: str, - lib_dependency: str = None, - working_directory: str = ".", + lib_path: Union[str, Path], + lib_dependency: Union[str, Path, None] = None, + working_directory: Union[str, Path, None] = None, timing: bool = False, ): - self._add_lib_dependency(lib_dependency) + if lib_dependency: + self._add_lib_dependency(lib_dependency) if sys.version_info[0:2] < (3, 8): # Python version < 3.8 - self.lib = CDLL(lib_path) + self.lib = CDLL(str(lib_path)) else: # LoadLibraryEx flag: LOAD_WITH_ALTERED_SEARCH_PATH 0x08 - # -> uses the altered search path for resolving ddl dependencies + # -> uses the altered search path for resolving dll dependencies # `winmode` has no effect while running on Linux or macOS # Note: this could make xmipy less secure (dll-injection) # Can we get it to work without this flag? - self.lib = CDLL(lib_path, winmode=0x08) + self.lib = CDLL(str(lib_path), winmode=0x08) - self.working_directory = working_directory + if working_directory: + self.working_directory = Path(working_directory) + else: + self.working_directory = Path().cwd() self._state = State.UNINITIALIZED self.timing = timing self.libname = os.path.basename(lib_path) @@ -82,20 +82,20 @@ def __init__( ) @staticmethod - def _add_lib_dependency(lib_dependency): - if lib_dependency: - if platform.system() == "Windows": - os.environ["PATH"] = lib_dependency + os.pathsep + os.environ["PATH"] + def _add_lib_dependency(lib_dependency: Union[str, Path]) -> None: + lib_dependency = str(lib_dependency) + if platform.system() == "Windows": + os.environ["PATH"] = lib_dependency + os.pathsep + os.environ["PATH"] + else: + # Assume a Unix-like system + if "LD_LIBRARY_PATH" in os.environ: + os.environ["LD_LIBRARY_PATH"] = ( + lib_dependency + os.pathsep + os.environ["LD_LIBRARY_PATH"] + ) else: - # Assume a Unix-like system - if "LD_LIBRARY_PATH" in os.environ: - os.environ["LD_LIBRARY_PATH"] = ( - lib_dependency + os.pathsep + os.environ["LD_LIBRARY_PATH"] - ) - else: - os.environ["LD_LIBRARY_PATH"] = lib_dependency + os.environ["LD_LIBRARY_PATH"] = lib_dependency - def report_timing_totals(self): + def report_timing_totals(self) -> float: if self.timing: total = self.timer.report_totals() logger.info(f"Total elapsed time for {self.libname}: {total:0.4f} seconds") @@ -131,10 +131,6 @@ def finalize(self) -> None: with cd(self.working_directory): self.execute_function(self.lib.finalize) self._state = State.UNINITIALIZED - try: - FreeLibrary(self.lib._handle) - except Exception: - pass else: raise InputError("The library is not initialized yet") @@ -174,7 +170,7 @@ def get_output_item_count(self) -> int: self.execute_function(self.lib.get_output_item_count, byref(count)) return count.value - def get_input_var_names(self) -> Tuple[str]: + def get_input_var_names(self): len_address = self.get_constant_int("BMI_LENVARADDRESS") nr_input_vars = self.get_input_item_count() len_names = nr_input_vars * len_address @@ -185,15 +181,15 @@ def get_input_var_names(self) -> Tuple[str]: self.execute_function(self.lib.get_input_var_names, byref(names)) # decode - input_vars = [ - names[i * len_address : (i + 1) * len_address] + input_vars = ( + names[i * len_address : (i + 1) * len_address] # type: ignore .split(b"\0", 1)[0] .decode("ascii") for i in range(nr_input_vars) - ] + ) return tuple(input_vars) - def get_output_var_names(self) -> Tuple[str]: + def get_output_var_names(self): len_address = self.get_constant_int("BMI_LENVARADDRESS") nr_output_vars = self.get_output_item_count() len_names = nr_output_vars * len_address @@ -205,7 +201,7 @@ def get_output_var_names(self) -> Tuple[str]: # decode output_vars = [ - names[i * len_address : (i + 1) * len_address] + names[i * len_address : (i + 1) * len_address] # type: ignore .split(b"\0", 1)[0] .decode("ascii") for i in range(nr_output_vars) @@ -234,7 +230,7 @@ def get_var_type(self, name: str) -> str: return var_type.value.decode() # strictly speaking not BMI... - def get_var_shape(self, name: str) -> np.ndarray: + def get_var_shape(self, name: str) -> NDArray: rank = self.get_var_rank(name) array = np.zeros(rank, dtype=np.int32) self.execute_function( @@ -284,7 +280,7 @@ def get_var_location(self, name: str) -> str: def get_time_units(self) -> str: raise NotImplementedError - def get_value(self, name: str, dest: np.ndarray = None) -> np.ndarray: + def get_value(self, name: str, dest: Union[NDArray, None] = None) -> NDArray: # make sure that optional array is of correct layout: if dest is not None: if not dest.flags["C"]: @@ -326,8 +322,7 @@ def get_value(self, name: str, dest: np.ndarray = None) -> np.ndarray: return dest - def get_value_ptr(self, name: str) -> np.ndarray: - + def get_value_ptr(self, name: str) -> NDArray: # first scalars rank = self.get_var_rank(name) if rank == 0: @@ -376,8 +371,10 @@ def get_value_ptr(self, name: str) -> np.ndarray: detail="for variable " + name, ) return values.contents + else: + raise InputError(f"Given {vartype=} is invalid.") - def get_value_ptr_scalar(self, name: str) -> np.ndarray: + def get_value_ptr_scalar(self, name: str) -> NDArray: vartype = self.get_var_type(name) if vartype.lower().startswith("double"): arraytype = np.ctypeslib.ndpointer( @@ -392,7 +389,7 @@ def get_value_ptr_scalar(self, name: str) -> np.ndarray: ) elif vartype.lower().startswith("float"): arraytype = np.ctypeslib.ndpointer( - dtype=np.float, ndim=1, shape=(1,), flags="C" + dtype=float, ndim=1, shape=(1,), flags="C" ) values = arraytype() self.execute_function( @@ -417,12 +414,10 @@ def get_value_ptr_scalar(self, name: str) -> np.ndarray: return values.contents - def get_value_at_indices( - self, name: str, dest: np.ndarray, inds: np.ndarray - ) -> np.ndarray: + def get_value_at_indices(self, name: str, dest: NDArray, inds: NDArray) -> NDArray: raise NotImplementedError - def set_value(self, name: str, values: np.ndarray) -> None: + def set_value(self, name: str, values: NDArray) -> None: if not values.flags["C"]: raise InputError("Array should have C layout") vartype = self.get_var_type(name) @@ -447,9 +442,7 @@ def set_value(self, name: str, values: np.ndarray) -> None: else: raise InputError("Unsupported value type") - def set_value_at_indices( - self, name: str, inds: np.ndarray, src: np.ndarray - ) -> None: + def set_value_at_indices(self, name: str, inds: NDArray, src: NDArray) -> None: raise NotImplementedError def get_grid_rank(self, grid: int) -> int: @@ -486,7 +479,7 @@ def get_grid_type(self, grid: int) -> str: ) return grid_type.value.decode() - def get_grid_shape(self, grid: int, shape: np.ndarray) -> np.ndarray: + def get_grid_shape(self, grid: int, shape: NDArray) -> NDArray: c_grid = c_int(grid) self.execute_function( self.lib.get_grid_shape, @@ -496,13 +489,13 @@ def get_grid_shape(self, grid: int, shape: np.ndarray) -> np.ndarray: ) return shape - def get_grid_spacing(self, grid: int, spacing: np.ndarray) -> np.ndarray: + def get_grid_spacing(self, grid: int, spacing: NDArray) -> NDArray: raise NotImplementedError - def get_grid_origin(self, grid: int, origin: np.ndarray) -> np.ndarray: + def get_grid_origin(self, grid: int, origin: NDArray) -> NDArray: raise NotImplementedError - def get_grid_x(self, grid: int, x: np.ndarray) -> np.ndarray: + def get_grid_x(self, grid: int, x: NDArray) -> NDArray: c_grid = c_int(grid) self.execute_function( self.lib.get_grid_x, @@ -512,7 +505,7 @@ def get_grid_x(self, grid: int, x: np.ndarray) -> np.ndarray: ) return x - def get_grid_y(self, grid: int, y: np.ndarray) -> np.ndarray: + def get_grid_y(self, grid: int, y: NDArray) -> NDArray: c_grid = c_int(grid) self.execute_function( self.lib.get_grid_y, @@ -522,7 +515,7 @@ def get_grid_y(self, grid: int, y: np.ndarray) -> np.ndarray: ) return y - def get_grid_z(self, grid: int, z: np.ndarray) -> np.ndarray: + def get_grid_z(self, grid: int, z: NDArray) -> NDArray: c_grid = c_int(grid) self.execute_function( self.lib.get_grid_z, @@ -557,13 +550,13 @@ def get_grid_face_count(self, grid: int) -> int: ) return grid_face_count.value - def get_grid_edge_nodes(self, grid: int, edge_nodes: np.ndarray) -> np.ndarray: + def get_grid_edge_nodes(self, grid: int, edge_nodes: NDArray) -> NDArray: raise NotImplementedError - def get_grid_face_edges(self, grid: int, face_edges: np.ndarray) -> np.ndarray: + def get_grid_face_edges(self, grid: int, face_edges: NDArray) -> NDArray: raise NotImplementedError - def get_grid_face_nodes(self, grid: int, face_nodes: np.ndarray) -> np.ndarray: + def get_grid_face_nodes(self, grid: int, face_nodes: NDArray) -> NDArray: c_grid = c_int(grid) self.execute_function( self.lib.get_grid_face_nodes, @@ -573,9 +566,7 @@ def get_grid_face_nodes(self, grid: int, face_nodes: np.ndarray) -> np.ndarray: ) return face_nodes - def get_grid_nodes_per_face( - self, grid: int, nodes_per_face: np.ndarray - ) -> np.ndarray: + def get_grid_nodes_per_face(self, grid: int, nodes_per_face: NDArray) -> NDArray: c_grid = c_int(grid) self.execute_function( self.lib.get_grid_nodes_per_face, @@ -643,7 +634,9 @@ def get_var_address( return var_address.value.decode() - def execute_function(self, function, *args, detail=""): + def execute_function( + self, function: Callable[[Any], int], *args, detail="" + ) -> None: """ Utility function to execute a BMI function in the kernel and checks its status """