diff --git a/.github/workflows/fabm.yml b/.github/workflows/fabm.yml index f7381c2f..1ac15134 100644 --- a/.github/workflows/fabm.yml +++ b/.github/workflows/fabm.yml @@ -303,6 +303,11 @@ jobs: conda install ipython jupyterlab ipympl matplotlib scipy cd testcases/python for f in *.ipynb; do ipython --gui qt -c "%run $f"; done + - name: Type checking with mypy + if: matrix.python-version != '3.7' && matrix.python-version != '3.8' + run: | + conda install mypy pyyaml types-pyyaml netcdf4 pyqt-stubs + mypy src/pyfabm - name: Install with customization via command line arguments run: | rm -rf build diff --git a/src/fabm_coupling.F90 b/src/fabm_coupling.F90 index 853af911..1211c581 100644 --- a/src/fabm_coupling.F90 +++ b/src/fabm_coupling.F90 @@ -339,7 +339,8 @@ recursive subroutine process_coupling_tasks(self, final, log_unit) standard_variable=standard_variable, presence=presence_external_optional, link=link) class is (type_horizontal_standard_variable) call root%add_horizontal_variable(standard_variable%name, standard_variable%units, standard_variable%name, & - standard_variable=standard_variable, presence=presence_external_optional, link=link) + standard_variable=standard_variable, presence=presence_external_optional, link=link, & + domain=standard_variable2domain(standard_variable)) class is (type_global_standard_variable) call root%add_scalar_variable(standard_variable%name, standard_variable%units, standard_variable%name, & standard_variable=standard_variable, presence=presence_external_optional, link=link) diff --git a/src/pyfabm/__init__.py b/src/pyfabm/__init__.py index a998c526..2cd61886 100644 --- a/src/pyfabm/__init__.py +++ b/src/pyfabm/__init__.py @@ -16,8 +16,19 @@ TypeVar, List, Dict, + Type, + overload, + cast, ) +# typing.Final not available in Python 3.7 +try: + from typing import Final, SupportsIndex +except ImportError: + from typing import Any + + Final = SupportsIndex = Any # type: ignore + try: import importlib.metadata @@ -35,12 +46,24 @@ LOG_CALLBACK = ctypes.CFUNCTYPE(None, ctypes.c_char_p) -name2lib: MutableMapping[str, ctypes.CDLL] = {} + +class FABMDLL(ctypes.CDLL): + dtype: Union[Type[ctypes.c_float], Type[ctypes.c_double]] + numpy_dtype: np.dtype + ndim_int: int + ndim_hz: int + idepthdim: int + mask_type: int + variable_bottom_index: bool + + +name2lib: MutableMapping[str, FABMDLL] = {} def _find_library(name: str) -> str: # Determine potential names of dynamic library. libdir, name = os.path.split(name) + names: Tuple[str, ...] if os.name == "nt": names = (f"{name}.dll", f"lib{name}.dll") elif os.name == "posix" and sys.platform == "darwin": @@ -68,7 +91,7 @@ def _find_library(name: str) -> str: ) -def get_lib(name: str) -> ctypes.CDLL: +def get_lib(name: str) -> FABMDLL: if name in name2lib: return name2lib[name] @@ -78,7 +101,7 @@ def get_lib(name: str) -> ctypes.CDLL: path = _find_library(name) # Load FABM library. - lib = ctypes.CDLL(path) + lib = FABMDLL(path) lib.dtype = ctypes.c_double lib.numpy_dtype = np.dtype(lib.dtype).newbyteorder("=") @@ -109,7 +132,7 @@ def get_lib(name: str) -> ctypes.CDLL: lib.mask_type = mask_type.value lib.variable_bottom_index = variable_bottom_index.value != 0 - CONTIGUOUS = "CONTIGUOUS" + CONTIGUOUS: Final = "CONTIGUOUS" arrtype0D = np.ctypeslib.ndpointer(dtype=lib.dtype, ndim=0, flags=CONTIGUOUS) arrtype1D = np.ctypeslib.ndpointer(dtype=lib.dtype, ndim=1, flags=CONTIGUOUS) arrtypeInterior = np.ctypeslib.ndpointer( @@ -545,10 +568,12 @@ def getError() -> Optional[str]: strmessage = ctypes.create_string_buffer(1024) lib.get_error(1024, strmessage) return strmessage.value.decode("ascii") + return None NodeValue = TypeVar("NodeValue") NodeType = Mapping[str, Union["NodeType", NodeValue]] +EditableNodeType = Dict[str, Union["EditableNodeType", NodeValue]] def printTree( @@ -571,24 +596,24 @@ def __init__(self, model: "Model", variable_pointer: ctypes.c_void_p): self._pvariable = variable_pointer def __getitem__(self, key: str) -> Union[float, int, bool]: - typecode = self.model.fabm.variable_get_property_type( + typecode: int = self.model.fabm.variable_get_property_type( self._pvariable, key.encode("ascii") ) if typecode == DataType.REAL: - return self.model.fabm.variable_get_real_property( + value = self.model.fabm.variable_get_real_property( self._pvariable, key.encode("ascii"), -1.0 ) + return cast(float, value) elif typecode == DataType.INTEGER: - return self.model.fabm.variable_get_integer_property( + value = self.model.fabm.variable_get_integer_property( self._pvariable, key.encode("ascii"), 0 ) + return cast(int, value) elif typecode == DataType.LOGICAL: - return ( - self.model.fabm.variable_get_logical_property( - self._pvariable, key.encode("ascii"), 0 - ) - != 0 + value = self.model.fabm.variable_get_logical_property( + self._pvariable, key.encode("ascii"), 0 ) + return cast(int, value) != 0 raise KeyError @@ -647,7 +672,8 @@ def __init__(self, model: "Model", variable_pointer: ctypes.c_void_p): @property def long_path(self) -> str: - """Long model instance name, followed by a slash, followed by long variable name.""" + """Long model instance name, followed by a slash, followed by long + variable name.""" strlong_name = ctypes.create_string_buffer(ATTRIBUTE_LENGTH) self.model.fabm.variable_get_long_path( self._pvariable, ATTRIBUTE_LENGTH, strlong_name @@ -657,12 +683,14 @@ def long_path(self) -> str: @property def missing_value(self) -> float: """Value that indicates missing data, for instance, on land.""" - return self.model.fabm.variable_get_missing_value(self._pvariable) + value: float = self.model.fabm.variable_get_missing_value(self._pvariable) + return value - def getRealProperty(self, name, default=-1.0) -> float: - return self.model.fabm.variable_get_real_property( + def getRealProperty(self, name: str, default: float = -1.0) -> float: + value: float = self.model.fabm.variable_get_real_property( self._pvariable, name.encode("ascii"), default ) + return value class Dependency(VariableFromPointer): @@ -670,26 +698,27 @@ def __init__( self, model: "Model", variable_pointer: ctypes.c_void_p, - shape: Tuple[int], + shape: Tuple[int, ...], link_function: Callable[[ctypes.c_void_p, ctypes.c_void_p, np.ndarray], None], ): super().__init__(model, variable_pointer) - self._is_set = False self._link_function = link_function self._shape = shape + self._data: Optional[np.ndarray] = None @property def value(self) -> Optional[np.ndarray]: - return None if not self._is_set else self._data + return self._data @value.setter def value(self, value: npt.ArrayLike): - if not self._is_set: + if self._data is None: self.link(np.empty(self._shape, dtype=self.model.fabm.numpy_dtype)) + assert self._data is not None self._data[...] = value @property - def shape(self) -> Tuple[int]: + def shape(self) -> Tuple[int, ...]: return self._shape def link(self, data: np.ndarray): @@ -699,11 +728,11 @@ def link(self, data: np.ndarray): ) self._data = data self._link_function(self.model.pmodel, self._pvariable, self._data) - self._is_set = True @property def required(self) -> bool: - return self.model.fabm.variable_is_required(self._pvariable) != 0 + value: int = self.model.fabm.variable_is_required(self._pvariable) + return value != 0 class StateVariable(VariableFromPointer): @@ -723,21 +752,25 @@ def value(self, value: npt.ArrayLike): @property def background_value(self) -> float: - return self.model.fabm.variable_get_background_value(self._pvariable) + value: float = self.model.fabm.variable_get_background_value(self._pvariable) + return value @property def output(self) -> bool: - return self.model.fabm.variable_get_output(self._pvariable) != 0 + value: int = self.model.fabm.variable_get_output(self._pvariable) + return value != 0 @property def no_river_dilution(self) -> bool: - return self.model.fabm.variable_get_no_river_dilution(self._pvariable) != 0 + value: int = self.model.fabm.variable_get_no_river_dilution(self._pvariable) + return value != 0 @property def no_precipitation_dilution(self) -> bool: - return ( - self.model.fabm.variable_get_no_precipitation_dilution(self._pvariable) != 0 + value: int = self.model.fabm.variable_get_no_precipitation_dilution( + self._pvariable ) + return value != 0 class DiagnosticVariable(VariableFromPointer): @@ -749,7 +782,7 @@ def __init__( horizontal: bool, ): super().__init__(model, variable_pointer) - self._data = None + self._data: Optional[np.ndarray] = None self._horizontal = horizontal self._index = index + 1 @@ -760,9 +793,17 @@ def value(self) -> Optional[np.ndarray]: @property def output(self) -> bool: """Whether this diagnostic is meant to be included in output by default""" - return self.model.fabm.variable_get_output(self._pvariable) != 0 + value: int = self.model.fabm.variable_get_output(self._pvariable) + return value != 0 + + @property + def save(self) -> bool: + """Whether the value of this diagnostic must be calculated, + for instance, because it will be included in model output""" + raise AttributeError("save attribute is write-only") - def _set_save(self, value: bool): + @save.setter + def save(self, value: bool): vartype = ( HORIZONTAL_DIAGNOSTIC_VARIABLE if self._horizontal @@ -772,9 +813,6 @@ def _set_save(self, value: bool): self.model.pmodel, vartype, self._index, 1 if value else 0 ) - #: Whether the value of this diagnostic must be calculated, for instance, for output - save: bool = property(fset=_set_save) - class Parameter(Variable): def __init__( @@ -782,37 +820,37 @@ def __init__( model: "Model", name: str, index: int, - units: Optional[str] = None, - long_name: Optional[str] = None, - type: Optional[DataType] = None, - has_default: bool = False, + units: str, + long_name: str, + type: DataType, + has_default: bool, ): super().__init__(model, name, units, long_name) self._type = type self._index = index + 1 self._has_default = has_default - def _get_value(self, *, default: bool = False): - default = 1 if default else 0 + def _get_value(self, *, default: bool = False) -> Union[float, int, bool, str]: + idefault = 1 if default else 0 if self._type == DataType.REAL: - return self.model.fabm.get_real_parameter( - self.model.pmodel, self._index, default + value = self.model.fabm.get_real_parameter( + self.model.pmodel, self._index, idefault ) + return cast(float, value) elif self._type == DataType.INTEGER: - return self.model.fabm.get_integer_parameter( - self.model.pmodel, self._index, default + value = self.model.fabm.get_integer_parameter( + self.model.pmodel, self._index, idefault ) + return cast(int, value) elif self._type == DataType.LOGICAL: - return ( - self.model.fabm.get_logical_parameter( - self.model.pmodel, self._index, default - ) - != 0 + value = self.model.fabm.get_logical_parameter( + self.model.pmodel, self._index, idefault ) + return cast(int, value) != 0 elif self._type == DataType.STRING: result = ctypes.create_string_buffer(ATTRIBUTE_LENGTH) self.model.fabm.get_string_parameter( - self.model.pmodel, self._index, default, ATTRIBUTE_LENGTH, result + self.model.pmodel, self._index, idefault, ATTRIBUTE_LENGTH, result ) return result.value.decode("ascii") @@ -838,7 +876,7 @@ def value(self, value: Union[float, int, bool, str]): ) elif self._type == DataType.STRING: self.model.fabm.set_string_parameter( - self.model.pmodel, self.name.encode("ascii"), value.encode("ascii") + self.model.pmodel, self.name.encode("ascii"), str(value).encode("ascii") ) # Update the model configuration @@ -852,7 +890,7 @@ def default(self) -> Union[float, int, bool, str, None]: return None return self._get_value(default=True) - def reset(self): + def reset(self) -> None: """Reset this parameter to its default value""" settings = self.model._save_state() self.model.fabm.reset_parameter(self.model.pmodel, self._index) @@ -869,7 +907,8 @@ def __init__( @property def missing_value(self) -> float: """Value that indicates missing data, for instance, on land.""" - return self.model.fabm.variable_get_missing_value(self._pvariable) + value: float = self.model.fabm.variable_get_missing_value(self._pvariable) + return value class StandardVariable: @@ -886,12 +925,12 @@ def value(self) -> np.ndarray: if horizontal.value == 0: shape = self.model.interior_domain_shape else: - shape = self.model.horizontal_domain.shape + shape = self.model.horizontal_domain_shape arr = np.ctypeslib.as_array(pdata, shape) return arr.view(dtype=self.model.fabm.numpy_dtype) -T = TypeVar("T") +T = TypeVar("T", bound=Variable) class NamedObjectList(Sequence[T]): @@ -905,12 +944,18 @@ def __init__(self, *data: Iterable[T]): def __len__(self) -> int: return len(self._data) - def __getitem__(self, key: Union[int, str]) -> T: + @overload + def __getitem__(self, key: Union[int, str]) -> T: ... + + @overload + def __getitem__(self, key: slice) -> Sequence[T]: ... + + def __getitem__(self, key: Union[int, slice, str]) -> Union[T, Sequence[T]]: if isinstance(key, str): return self.find(key) return self._data[key] - def __contains__(self, key: Union[T, str]) -> bool: + def __contains__(self, key: object) -> bool: if isinstance(key, str): try: self.find(key) @@ -919,7 +964,7 @@ def __contains__(self, key: Union[T, str]) -> bool: return False return key in self._data - def index(self, key: Union[T, str], *args) -> int: + def index(self, key: Union[T, str], *args: SupportsIndex) -> int: if isinstance(key, str): try: key = self.find(key) @@ -943,7 +988,7 @@ def find(self, name: str, case_insensitive: bool = False) -> T: self._lookup = {obj.name: obj for obj in self._data} return self._lookup[name] - def clear(self): + def clear(self) -> None: self._data.clear() self._lookup = None self._lookup_ci = None @@ -960,12 +1005,12 @@ def __init__(self, model: "Model", index: int): ctypes.byref(self._ptarget), ) super().__init__(model, self._psource) - self._options = None + self._options: Optional[List[str]] = None @property def value(self) -> Optional[str]: if self._psource.value == self._ptarget.value: - return + return None strlong_name = ctypes.create_string_buffer(ATTRIBUTE_LENGTH) self.model.fabm.variable_get_long_path( self._ptarget, ATTRIBUTE_LENGTH, strlong_name @@ -980,7 +1025,7 @@ def value(self, value: str): @property def options(self) -> Sequence[str]: if self._options is None: - self._options: List[str] = [] + self._options = [] plist = self.model.fabm.variable_get_suitable_masters( self.model.pmodel, self._psource ) @@ -1002,7 +1047,7 @@ def __init__(self, model: "Model", name: str): model.fabm.get_model_metadata( model.pmodel, name.encode("ascii"), ATTRIBUTE_LENGTH, strlong_name, iuser ) - self.long_name = strlong_name.value.decode("ascii") + self.long_name: str = strlong_name.value.decode("ascii") self.user_created = iuser.value != 0 @@ -1010,10 +1055,10 @@ class Model(object): def __init__( self, path: Union[str, dict] = "fabm.yaml", - shape: Tuple[int] = (), + shape: Tuple[int, ...] = (), libname: Optional[str] = None, - start: Optional[Tuple[int]] = None, - stop: Optional[Tuple[int]] = None, + start: Optional[Tuple[int, ...]] = None, + stop: Optional[Tuple[int, ...]] = None, ): delete = False if isinstance(path, dict): @@ -1049,7 +1094,7 @@ def __init__( ) self.fabm.reset_error_state() - self._cell_thickness = None + self._cell_thickness: Optional[np.ndarray] = None self.pmodel = self.fabm.create_model(path.encode("ascii"), *shape[::-1]) if hasError(): raise FABMException( @@ -1080,10 +1125,10 @@ def __init__( # fmt: on self._update_configuration() - self._mask = None - self._bottom_index = None + self._mask: Optional[Tuple[npt.NDArray[np.intc], ...]] = None + self._bottom_index: Optional[npt.NDArray[np.intc]] = None - def link_mask(self, *masks: np.ndarray): + def link_mask(self, *masks: npt.NDArray[np.intc]): if self.fabm.mask_type == 0: raise FABMException( "the underlying FABM library has been compiled without support for masks" @@ -1107,28 +1152,32 @@ def link_mask(self, *masks: np.ndarray): self.fabm.set_mask(self.pmodel, *self._mask) @property - def mask(self) -> Union[np.ndarray, Sequence[np.ndarray], None]: - mask = self._mask - if mask is not None and len(mask) == 1: - mask = mask[0] - return mask + def mask( + self, + ) -> Union[npt.NDArray[np.intc], Tuple[npt.NDArray[np.intc], ...], None]: + if self._mask is not None and len(self._mask) == 1: + return self._mask[0] + return self._mask @mask.setter def mask(self, values: Union[npt.ArrayLike, Sequence[npt.ArrayLike]]): if self.fabm.mask_type == 1: - values = (values,) + values = cast(Sequence[npt.ArrayLike], (values,)) + assert isinstance(values, Sequence) if len(values) != self.fabm.mask_type: raise FABMException(f"mask must be set to {self.fabm.mask_type} values") if self._mask is None: + masks: Tuple[npt.NDArray[np.intc], ...] masks = (np.ones(self.horizontal_domain_shape, dtype=np.intc),) if self.fabm.mask_type > 1: masks = (np.ones(self.interior_domain_shape, dtype=np.intc),) + masks self.link_mask(*masks) + assert self._mask is not None for value, mask in zip(values, self._mask): if value is not mask: mask[...] = value - def link_bottom_index(self, indices: np.ndarray): + def link_bottom_index(self, indices: npt.NDArray[np.intc]): if not self.fabm.variable_bottom_index: raise FABMException( "the underlying FABM library has been compiled without support for variable bottom indices" @@ -1142,13 +1191,14 @@ def link_bottom_index(self, indices: np.ndarray): self.fabm.set_bottom_index(self.pmodel, self._bottom_index) @property - def bottom_index(self) -> Optional[np.ndarray]: + def bottom_index(self) -> Optional[npt.NDArray[np.intc]]: return self._bottom_index @bottom_index.setter def bottom_index(self, indices: npt.ArrayLike): if self._bottom_index is None: self.link_bottom_index(np.ones(self.horizontal_domain_shape, dtype=np.intc)) + assert self._bottom_index is not None if indices is not self._bottom_index: self._bottom_index[...] = indices @@ -1211,6 +1261,7 @@ def link_cell_thickness(self, data: np.ndarray): def setCellThickness(self, value: npt.ArrayLike): if self._cell_thickness is None: self.link_cell_thickness(np.empty(self.interior_domain_shape)) + assert self._cell_thickness is not None self._cell_thickness[...] = value cell_thickness = property(fset=setCellThickness) @@ -1222,7 +1273,7 @@ def save_settings(self, path: str, display: int = DISPLAY_NORMAL): """Write model configuration to yaml file""" self.fabm.save_settings(self.pmodel, path.encode("ascii"), display) - def _save_state(self) -> Tuple: + def _save_state(self) -> Tuple[Mapping[str, np.ndarray], Mapping[str, np.ndarray]]: environment = {} for dependency in self.dependencies: if dependency.value is not None: @@ -1230,7 +1281,9 @@ def _save_state(self) -> Tuple: state = {variable.name: variable.value for variable in self.state_variables} return environment, state - def _restore_state(self, data: Tuple): + def _restore_state( + self, data: Tuple[Mapping[str, np.ndarray], Mapping[str, np.ndarray]] + ): environment, state = data for dependency in self.dependencies: if dependency.name in environment: @@ -1239,7 +1292,12 @@ def _restore_state(self, data: Tuple): if variable.name in state: variable.value = state[variable.name] - def _update_configuration(self, settings: Optional[Tuple] = None): + def _update_configuration( + self, + settings: Optional[ + Tuple[Mapping[str, np.ndarray], Mapping[str, np.ndarray]] + ] = None, + ): # Get number of model variables per category nstate_interior = ctypes.c_int() nstate_surface = ctypes.c_int() @@ -1269,6 +1327,7 @@ def _update_configuration(self, settings: Optional[Tuple] = None): # Allocate memory for state variable values, and send ctypes.pointer to # this memory to FABM. + self._state: Optional[np.ndarray] if self.fabm.idepthdim == -1: # No depth dimension, so interior and surface/bottom variables have # the same shape. Store values for all together in one contiguous array @@ -1406,7 +1465,7 @@ def _update_configuration(self, settings: Optional[Tuple] = None): self, strname.value.decode("ascii"), i, - type=typecode.value, + type=DataType(typecode.value), units=strunits.value.decode("ascii"), long_name=strlong_name.value.decode("ascii"), has_default=has_default.value != 0, @@ -1431,8 +1490,8 @@ def _update_configuration(self, settings: Optional[Tuple] = None): + self.horizontal_dependencies + self.scalar_dependencies ) - self.variables: NamedObjectList[VariableFromPointer] = ( - self.state_variables + self.diagnostic_variables + self.dependencies + self.variables = NamedObjectList( + self.state_variables, self.diagnostic_variables, self.dependencies ) if settings is not None: @@ -1537,10 +1596,10 @@ def get_conserved_quantities(self, out: Optional[np.ndarray] = None) -> np.ndarr return out def check_state(self, repair: bool = False) -> bool: - valid = self.fabm.check_state(self.pmodel, repair) != 0 + valid: int = self.fabm.check_state(self.pmodel, repair) if hasError(): raise FABMException(getError()) - return valid + return valid != 0 checkState = check_state @@ -1590,14 +1649,15 @@ def findCoupling(self, name: str, case_insensitive: bool = False): def find_standard_variable(self, name: str) -> Optional[StandardVariable]: pointer = self.fabm.find_standard_variable(name.encode("ascii")) - if pointer: - return StandardVariable(self, pointer) + if not pointer: + return None + return StandardVariable(self, pointer) def require_data(self, standard_variable: StandardVariable): return self.fabm.require_data(self.pmodel, standard_variable._pvariable) - def _get_parameter_tree(self) -> Mapping: - root = {} + def _get_parameter_tree(self) -> NodeType: + root: EditableNodeType = {} for parameter in self.parameters: pathcomps = parameter.name.split("/") parent = root @@ -1655,7 +1715,10 @@ def updateTime(self, nsec: float): def printInformation(self): """Show information about the model.""" - def printArray(classname: str, array: Sequence[Variable]): + def print_array( + classname: str, + array: Sequence[Union[StateVariable, DiagnosticVariable, Dependency]], + ): if not array: return log(f" {len(array)} {classname}:") @@ -1666,14 +1729,14 @@ def parameter2str(p: Parameter) -> str: return f"{p.value} {p.units}" log("FABM model contains the following:") - printArray("interior state variables", self.interior_state_variables) - printArray("bottom state variables", self.bottom_state_variables) - printArray("surface state variables", self.surface_state_variables) - printArray("interior diagnostic variables", self.interior_diagnostic_variables) - printArray( + print_array("interior state variables", self.interior_state_variables) + print_array("bottom state variables", self.bottom_state_variables) + print_array("surface state variables", self.surface_state_variables) + print_array("interior diagnostic variables", self.interior_diagnostic_variables) + print_array( "horizontal diagnostic variables", self.horizontal_diagnostic_variables ) - printArray("external variables", self.dependencies) + print_array("external variables", self.dependencies) log(f" {len(self.parameters)} parameters:") printTree(self._get_parameter_tree(), parameter2str, " ") @@ -1704,19 +1767,19 @@ def integrate( dt, surface, bottom, - ctypes.byref(self.model._cell_thickness), + self.model._cell_thickness, ) if hasError(): raise FABMException(getError()) return y -def unload(): +def unload() -> None: global ctypes for lib in name2lib.values(): handle = lib._handle - if os.name == "nt": + if sys.platform == "win32": import ctypes.wintypes ctypes.windll.kernel32.FreeLibrary.argtypes = [ctypes.wintypes.HMODULE] @@ -1735,3 +1798,4 @@ def get_version() -> Optional[str]: strversion = ctypes.create_string_buffer(version_length) lib.get_version(version_length, strversion) return strversion.value.decode("ascii") + return None diff --git a/src/pyfabm/complete_yaml.py b/src/pyfabm/complete_yaml.py index 9c135d9c..b9e142a0 100644 --- a/src/pyfabm/complete_yaml.py +++ b/src/pyfabm/complete_yaml.py @@ -1,7 +1,12 @@ import pyfabm -def processFile(infile, outfile, subtract_background=False, add_missing=False): +def processFile( + infile: str, + outfile: str, + subtract_background: bool = False, + add_missing: bool = False, +): # Create model object from YAML file. model = pyfabm.Model(infile) model.save_settings( diff --git a/src/pyfabm/gui_qt.py b/src/pyfabm/gui_qt.py index df4502f6..a00b55d4 100644 --- a/src/pyfabm/gui_qt.py +++ b/src/pyfabm/gui_qt.py @@ -1,5 +1,5 @@ import sys -from typing import Iterable, Union, List, Optional +from typing import Iterable, Union, List, Optional, Tuple import numpy as np import pyfabm @@ -7,9 +7,9 @@ from PyQt5 import QtCore, QtGui, QtWidgets except ImportError as e1: try: - from PySide import QtCore, QtGui + from PySide import QtCore, QtGui # type: ignore - QtWidgets = QtGui + QtWidgets = QtGui # type: ignore except ImportError as e2: print(e1) print(e2) @@ -18,45 +18,57 @@ class Delegate(QtWidgets.QStyledItemDelegate): - def __init__(self, parent=None): + def __init__(self, parent: Optional[QtCore.QObject] = None): QtWidgets.QStyledItemDelegate.__init__(self, parent) - def createEditor(self, parent, option, index): + def createEditor( + self, + parent: QtWidgets.QWidget, + option: QtWidgets.QStyleOptionViewItem, + index: QtCore.QModelIndex, + ): assert index.isValid() - data: Union[str, pyfabm.Variable] = index.internalPointer().object - if not isinstance(data, str): + entry: Entry = index.internalPointer() + data = entry.object + if isinstance(data, pyfabm.Variable): if data.options is not None: - widget = QtWidgets.QComboBox(parent) - widget.addItems(data.options) - return widget + combobox = QtWidgets.QComboBox(parent) + combobox.addItems(data.options) + return combobox elif data.value is None or isinstance(data.value, (float, np.ndarray)): - widget = ScientificDoubleEditor(parent) + editor = ScientificDoubleEditor(parent) if data.units: - widget.setSuffix(f" {data.units_unicode}") - return widget + editor.setSuffix(f" {data.units_unicode}") + return editor return QtWidgets.QStyledItemDelegate.createEditor(self, parent, option, index) - def setEditorData(self, editor, index): + def setEditorData(self, editor: QtWidgets.QWidget, index: QtCore.QModelIndex): + entry: Entry = index.internalPointer() + data = entry.object if isinstance(editor, QtWidgets.QComboBox): - data: Union[str, pyfabm.Variable] = index.internalPointer().object - if not isinstance(data, str): - if data.options is not None and data.value is not None: - editor.setCurrentIndex(data.options.index(data.value)) - return + assert isinstance(data, pyfabm.Variable) and data.options is not None + if data.value is not None: + editor.setCurrentIndex(data.options.index(data.value)) + return elif isinstance(editor, ScientificDoubleEditor): - data = index.internalPointer().object + assert isinstance(data, pyfabm.Variable) editor.setValue(data.value if data.value is None else float(data.value)) return return QtWidgets.QStyledItemDelegate.setEditorData(self, editor, index) - def setModelData(self, editor, model, index): + def setModelData( + self, + editor: QtWidgets.QWidget, + model: QtCore.QAbstractItemModel, + index: QtCore.QModelIndex, + ): if isinstance(editor, QtWidgets.QComboBox): - data = index.internalPointer().object - if not isinstance(data, str): - if data.options is not None: - i = editor.currentIndex() - model.setData(index, data.options[i], QtCore.Qt.EditRole) - return + entry: Entry = index.internalPointer() + data = entry.object + assert isinstance(data, pyfabm.Variable) and data.options is not None + i = editor.currentIndex() + model.setData(index, data.options[i], QtCore.Qt.EditRole) + return elif isinstance(editor, ScientificDoubleEditor): model.setData(index, editor.value(), QtCore.Qt.EditRole) return @@ -64,11 +76,21 @@ def setModelData(self, editor, model, index): class Entry: - def __init__(self, object: Union[None, str, "Entry"] = None, name: str = ""): - if name == "" and object is not None: - name = object - self.object = object + def __init__( + self, + name: str = "", + object: Union[ + None, + "Entry", + pyfabm.Parameter, + pyfabm.StateVariable, + pyfabm.Dependency, + pyfabm.Coupling, + "Submodel", + ] = None, + ): self.name = name + self.object = object self.parent: Optional["Entry"] = None self.children: List["Entry"] = [] assert isinstance(self.name, str) @@ -86,13 +108,24 @@ def removeChild(self, index: int): def findChild(self, name: str): for child in self.children: - if isinstance(child.object, str) and child.object == name: + if child.object is None and child.name == name: return child child = Entry(name) self.addChild(child) return child - def addTree(self, arr: Iterable[pyfabm.Variable], category: Optional[str] = None): + def addTree( + self, + arr: Iterable[ + Union[ + pyfabm.Parameter, + pyfabm.StateVariable, + pyfabm.Dependency, + pyfabm.Coupling, + ] + ], + category: Optional[str] = None, + ): for variable in arr: pathcomps = variable.path.split("/") if len(pathcomps) < 2: @@ -103,12 +136,12 @@ def addTree(self, arr: Iterable[pyfabm.Variable], category: Optional[str] = None parent = self for component in pathcomps[:-1]: parent = parent.findChild(component) - entry = Entry(variable, pathcomps[-1]) + entry = Entry(pathcomps[-1], variable) parent.addChild(entry) class Submodel: - def __init__(self, long_name): + def __init__(self, long_name: str): self.units = None self.units_unicode = None self.value = None @@ -116,28 +149,28 @@ def __init__(self, long_name): class ItemModel(QtCore.QAbstractItemModel): - def __init__(self, model: pyfabm.Model, parent): + def __init__(self, model: pyfabm.Model, parent: Optional[QtCore.QObject]): QtCore.QAbstractItemModel.__init__(self, parent) - self.root = None self.model = model - self.rebuild() + self.root = self._build_tree() - def rebuild(self): + def _build_tree(self) -> Entry: root = Entry() env = Entry("environment") for d in self.model.dependencies: - env.addChild(Entry(d, d.name)) + env.addChild(Entry(d.name, d)) root.addTree(self.model.parameters, "parameters") root.addTree(self.model.state_variables, "initialization") root.addTree(self.model.couplings, "coupling") root.addChild(env) # For all models, create an object that returns appropriate metadata. - def processNode(n, path=()): + def processNode(n: Entry, path: Tuple[str, ...] = ()): for i in range(len(n.children) - 1, -1, -1): child = n.children[i] childpath = path + (child.name,) - if isinstance(child.object, str): + if child.object is None: + # Child without underlying object: remove redundant subheaders if childpath[-1] in ( "parameters", "initialization", @@ -151,97 +184,100 @@ def processNode(n, path=()): child.object = Submodel(submodel.long_name) else: n.removeChild(i) - child = None - if child is not None: - processNode(child, childpath) + continue + processNode(child, childpath) processNode(root) - - if self.root is not None: - # We already have an old tree - compare and amend model. - def processChange(newnode, oldnode, parent): - oldnode.object = newnode.object - if not newnode.children: - return - ioldstart = 0 - for node in newnode.children: - iold = -1 - for i in range(ioldstart, len(oldnode.children)): - if node.name == oldnode.children[i].name: - iold = i - break - if iold != -1: - # New node was found among old nodes; - # remove any unused old nodes that precede it. - if iold > ioldstart: - self.beginRemoveRows(parent, ioldstart, iold - 1) - for i in range(iold - 1, ioldstart - 1, -1): - oldnode.removeChild(i) - self.endRemoveRows() - # Process changes to children of node. - processChange( - node, - oldnode.children[ioldstart], - self.createIndex(ioldstart, 0, oldnode.children[ioldstart]), - ) - else: - # New node not found; insert it. - self.beginInsertRows(parent, ioldstart, ioldstart) - oldnode.insertChild(ioldstart, node) - self.endInsertRows() - ioldstart = ioldstart + 1 - if ioldstart < len(oldnode.children): - # Remove any trailing unused old nodes. - self.beginRemoveRows(parent, ioldstart, len(oldnode.children) - 1) - for i in range(len(oldnode.children) - 1, ioldstart - 1, -1): - oldnode.removeChild(i) - self.endRemoveRows() - - processChange(root, self.root, QtCore.QModelIndex()) - else: - # First time a tree was created - store it and move on. - self.root = root - # self.reset() - - def rowCount(self, index): + return root + + def rebuild(self) -> None: + # We already have an old tree - compare and amend model. + def processChange(newnode: Entry, oldnode: Entry, parent: QtCore.QModelIndex): + oldnode.object = newnode.object + if not newnode.children: + return + ioldstart = 0 + for node in newnode.children: + iold = -1 + for i in range(ioldstart, len(oldnode.children)): + if node.name == oldnode.children[i].name: + iold = i + break + if iold != -1: + # New node was found among old nodes; + # remove any unused old nodes that precede it. + if iold > ioldstart: + self.beginRemoveRows(parent, ioldstart, iold - 1) + for i in range(iold - 1, ioldstart - 1, -1): + oldnode.removeChild(i) + self.endRemoveRows() + # Process changes to children of node. + processChange( + node, + oldnode.children[ioldstart], + self.createIndex(ioldstart, 0, oldnode.children[ioldstart]), + ) + else: + # New node not found; insert it. + self.beginInsertRows(parent, ioldstart, ioldstart) + oldnode.insertChild(ioldstart, node) + self.endInsertRows() + ioldstart = ioldstart + 1 + if ioldstart < len(oldnode.children): + # Remove any trailing unused old nodes. + self.beginRemoveRows(parent, ioldstart, len(oldnode.children) - 1) + for i in range(len(oldnode.children) - 1, ioldstart - 1, -1): + oldnode.removeChild(i) + self.endRemoveRows() + + processChange(self._build_tree(), self.root, QtCore.QModelIndex()) + + def rowCount(self, index: QtCore.QModelIndex = QtCore.QModelIndex()) -> int: if not index.isValid(): return len(self.root.children) elif index.column() == 0: - return len(index.internalPointer().children) + entry: Entry = index.internalPointer() + return len(entry.children) return 0 - def columnCount(self, index): + def columnCount(self, index: QtCore.QModelIndex = QtCore.QModelIndex()) -> int: return 4 - def index(self, row, column, parent): + def index( + self, row: int, column: int, parent: QtCore.QModelIndex = QtCore.QModelIndex() + ) -> QtCore.QModelIndex: if not parent.isValid(): # top-level children = self.root.children else: # not top-level - children = parent.internalPointer().children + entry: Entry = parent.internalPointer() + children = entry.children if row < 0 or row >= len(children) or column < 0 or column >= 4: return QtCore.QModelIndex() return self.createIndex(row, column, children[row]) - def parent(self, index): + def parent(self, index: QtCore.QModelIndex = QtCore.QModelIndex()): if not index.isValid(): return QtCore.QModelIndex() - parent = index.internalPointer().parent - if parent is self.root: + entry: Entry = index.internalPointer() + parent = entry.parent + assert parent is not None + if parent.parent is None: return QtCore.QModelIndex() irow = parent.parent.children.index(parent) return self.createIndex(irow, 0, parent) - def data(self, index, role): + def data(self, index: QtCore.QModelIndex, role: int = QtCore.Qt.DisplayRole): if not index.isValid(): return - entry = index.internalPointer() + entry: Entry = index.internalPointer() data = entry.object + assert not isinstance(data, Entry) if role == QtCore.Qt.DisplayRole: if index.column() == 0: - return entry.name if isinstance(data, str) else data.long_name - if not isinstance(data, (str, Submodel)): + return entry.name if data is None else data.long_name + if isinstance(data, pyfabm.Variable): if index.column() == 1: value = data.value if value is None: @@ -256,12 +292,11 @@ def data(self, index, role): elif index.column() == 3: return entry.name elif role == QtCore.Qt.ToolTipRole and index.parent().isValid(): - if not isinstance(data, str): + if isinstance(data, pyfabm.Variable): return data.long_path elif role == QtCore.Qt.EditRole: - if not isinstance(data, (str, Submodel)): - # print data.options - return data.value + assert isinstance(data, pyfabm.Variable) + return data.value elif role == QtCore.Qt.FontRole and index.column() == 1: if isinstance(data, pyfabm.Parameter) and data.value != data.default: font = QtGui.QFont() @@ -270,37 +305,51 @@ def data(self, index, role): elif ( role == QtCore.Qt.CheckStateRole and index.column() == 1 - and not isinstance(data, str) + and isinstance(data, pyfabm.Variable) and isinstance(data.value, bool) ): return QtCore.Qt.Checked if data.value else QtCore.Qt.Unchecked return None - def setData(self, index, value, role): + def setData( + self, + index: QtCore.QModelIndex, + value: object, + role: int = QtCore.Qt.EditRole, + ) -> bool: if role == QtCore.Qt.CheckStateRole: value = value == QtCore.Qt.Checked if role in (QtCore.Qt.EditRole, QtCore.Qt.CheckStateRole): - data = index.internalPointer().object + assert isinstance(value, (float, int, bool, str)) + entry: Entry = index.internalPointer() + data = entry.object + assert isinstance(data, pyfabm.Variable) data.value = value if isinstance(data, pyfabm.Parameter): self.rebuild() return True return False - def flags(self, index): - flags = QtCore.Qt.NoItemFlags + def flags(self, index: QtCore.QModelIndex): + flags: QtCore.Qt.ItemFlags = 0 | QtCore.Qt.NoItemFlags if not index.isValid(): return flags if index.column() == 1: - entry = index.internalPointer().object - if not isinstance(entry, (str, Submodel)): - if isinstance(entry.value, bool): + entry: Entry = index.internalPointer() + data = entry.object + if isinstance(data, pyfabm.Variable): + if isinstance(data.value, bool): flags |= QtCore.Qt.ItemIsUserCheckable else: flags |= QtCore.Qt.ItemIsEditable return flags | QtCore.Qt.ItemIsEnabled | QtCore.Qt.ItemIsSelectable - def headerData(self, section, orientation, role): + def headerData( + self, + section: int, + orientation: QtCore.Qt.Orientation, + role: int = QtCore.Qt.EditRole, + ): if ( orientation == QtCore.Qt.Horizontal and role == QtCore.Qt.DisplayRole @@ -311,7 +360,7 @@ def headerData(self, section, orientation, role): class TreeView(QtWidgets.QTreeView): - def __init__(self, model, parent): + def __init__(self, model: pyfabm.Model, parent: Optional[QtWidgets.QWidget]): QtWidgets.QTreeView.__init__(self, parent) itemmodel = pyfabm.gui_qt.ItemModel(model, parent) self.setItemDelegate(Delegate(parent)) @@ -319,17 +368,18 @@ def __init__(self, model, parent): self.setUniformRowHeights(True) self.expandAll() - def onTreeViewContextMenu(pos): + def onTreeViewContextMenu(pos: QtCore.QPoint): index = self.indexAt(pos) if index.isValid() and index.column() == 1: - data = index.internalPointer().object + entry: Entry = index.internalPointer() + data = entry.object if ( isinstance(data, pyfabm.Parameter) and data.value != data.default and data.default is not None ): - def reset(): + def reset() -> None: data.reset() itemmodel.rebuild() @@ -357,13 +407,13 @@ class ScientificDoubleValidator(QtGui.QValidator): fix-up. """ - def __init__(self, parent=None): + def __init__(self, parent: Optional[QtCore.QObject] = None): QtGui.QValidator.__init__(self, parent) - self.minimum = None - self.maximum = None + self.minimum: Optional[float] = None + self.maximum: Optional[float] = None self.suffix = "" - def validate(self, input, pos): + def validate(self, input: str, pos: int) -> Tuple[QtGui.QValidator.State, str, int]: assert isinstance(input, str), "input argument is not a string (old PyQt4 API?)" # Check for suffix (if ok, cut it off for further value checking) @@ -390,7 +440,7 @@ def validate(self, input, pos): return (QtGui.QValidator.Acceptable, input, pos) - def fixup(self, input): + def fixup(self, input: str): assert isinstance(input, str), "input argument is not a string (old PyQt4 API?)" if not input.endswith(self.suffix): return input @@ -407,14 +457,14 @@ def fixup(self, input): print(repr(input)) return input - def setSuffix(self, suffix): + def setSuffix(self, suffix: str): self.suffix = suffix class ScientificDoubleEditor(QtWidgets.QLineEdit): """Editor for a floating point value.""" - def __init__(self, parent): + def __init__(self, parent: Optional[QtWidgets.QWidget]): QtWidgets.QLineEdit.__init__(self, parent) self.curvalidator = ScientificDoubleValidator(self) @@ -422,20 +472,20 @@ def __init__(self, parent): self.suffix = "" # self.editingFinished.connect(self.onPropertyEditingFinished) - def setSuffix(self, suffix): + def setSuffix(self, suffix: str): value = self.value() self.suffix = suffix self.curvalidator.setSuffix(suffix) self.setValue(value) - def value(self): + def value(self) -> float: text = self.text() text = text[: len(text) - len(self.suffix)] if text == "": return 0 return float(text) - def setValue(self, value, format=None): + def setValue(self, value: Optional[float], format: Optional[str] = None): if value is None: strvalue = "" else: @@ -445,20 +495,20 @@ def setValue(self, value, format=None): strvalue = format % value self.setText(f"{strvalue}{self.suffix}") - def focusInEvent(self, e): + def focusInEvent(self, e: QtGui.QFocusEvent): QtWidgets.QLineEdit.focusInEvent(self, e) self.selectAll() - def selectAll(self): + def selectAll(self) -> None: QtWidgets.QLineEdit.setSelection(self, 0, len(self.text()) - len(self.suffix)) - def setMinimum(self, minimum): + def setMinimum(self, minimum: Optional[float]): self.curvalidator.minimum = minimum - def setMaximum(self, maximum): + def setMaximum(self, maximum: Optional[float]): self.curvalidator.maximum = maximum - def interpretText(self): + def interpretText(self) -> None: if not self.hasAcceptableInput(): text = self.text() textnew = self.curvalidator.fixup(text) diff --git a/src/pyfabm/utils/fabm_complete_yaml.py b/src/pyfabm/utils/fabm_complete_yaml.py index 9cbd5de1..266e1a36 100755 --- a/src/pyfabm/utils/fabm_complete_yaml.py +++ b/src/pyfabm/utils/fabm_complete_yaml.py @@ -15,7 +15,7 @@ sys.exit(1) -def main(): +def main() -> None: import argparse parser = argparse.ArgumentParser(description=__doc__) diff --git a/src/pyfabm/utils/fabm_configuration_gui.py b/src/pyfabm/utils/fabm_configuration_gui.py index f5fb7dbb..fd99d4a3 100755 --- a/src/pyfabm/utils/fabm_configuration_gui.py +++ b/src/pyfabm/utils/fabm_configuration_gui.py @@ -17,7 +17,7 @@ QtWidgets = pyfabm.gui_qt.QtWidgets -def main(): +def main() -> None: import argparse # Parse command line arguments. diff --git a/src/pyfabm/utils/fabm_describe_model.py b/src/pyfabm/utils/fabm_describe_model.py index df929987..b42e5320 100755 --- a/src/pyfabm/utils/fabm_describe_model.py +++ b/src/pyfabm/utils/fabm_describe_model.py @@ -14,7 +14,7 @@ sys.exit(1) -def main(): +def main() -> None: import argparse parser = argparse.ArgumentParser(description=__doc__) @@ -35,6 +35,8 @@ def main(): # Create model object from YAML file. model = pyfabm.Model(args.path) + variable: pyfabm.Variable + print("Interior state variables:") for variable in model.interior_state_variables: print(f" {variable.name} = {variable.long_name} ({variable.units})") diff --git a/src/pyfabm/utils/fabm_evaluate.py b/src/pyfabm/utils/fabm_evaluate.py index 93a5d18f..72e9b0ff 100755 --- a/src/pyfabm/utils/fabm_evaluate.py +++ b/src/pyfabm/utils/fabm_evaluate.py @@ -12,13 +12,9 @@ import sys import os +from typing import Union, MutableMapping, Mapping, Iterable, cast -try: - input = raw_input -except NameError: - pass - -import numpy +import numpy as np import netCDF4 import yaml @@ -30,20 +26,24 @@ def evaluate( - yaml_path, - sources=(), - location={}, - assignments={}, - verbose=True, - ignore_missing=False, - surface=True, - bottom=True, + yaml_path: str, + sources: Iterable[str] = (), + location: Mapping[str, int] = {}, + assignments: Mapping[str, float] = {}, + verbose: bool = True, + ignore_missing: bool = False, + surface: bool = True, + bottom: bool = True, ): # Create model object from YAML file. model = pyfabm.Model(yaml_path) - allvariables = list(model.state_variables) + list(model.dependencies) - name2variable = {} + allvariables: pyfabm.NamedObjectList[ + Union[pyfabm.StateVariable, pyfabm.Dependency] + ] = pyfabm.NamedObjectList(model.state_variables, model.dependencies) + name2variable: MutableMapping[ + str, Union[pyfabm.StateVariable, pyfabm.Dependency] + ] = {} for variable in allvariables: name2variable[variable.name] = variable if hasattr(variable, "output_name"): @@ -52,11 +52,17 @@ def evaluate( [(name.lower(), variable) for (name, variable) in name2variable.items()] ) - def set_state(**dim2index): + def set_state(**dim2index: int): missing = set(allvariables) - variable2source = {} + variable2source: MutableMapping[ + Union[pyfabm.StateVariable, pyfabm.Dependency], str + ] = {} - def set_variable(variable, value, source): + def set_variable( + variable: Union[pyfabm.StateVariable, pyfabm.Dependency], + value: float, + source: str, + ): missing.discard(variable) if variable in variable2source: print( @@ -65,12 +71,12 @@ def set_variable(variable, value, source): f" set by {source}" ) variable2source[variable] = source - variable.value = value + variable.value = cast(np.ndarray, value) for path in sources: if path.endswith("yaml"): with open(path) as f: - data = yaml.load(f) + data = yaml.safe_load(f) for name, value in data.items(): variable = name2variable.get(name) if variable is None: @@ -110,22 +116,16 @@ def set_variable(variable, value, source): variable = name2variable[name] missing.discard(variable) variable2source[variable] = "command line" - variable.value = float(value) + variable.value = cast(np.ndarray, float(value)) if verbose: print() print("State:") - for variable in sorted(model.state_variables, key=lambda x: x.name.lower()): - print( - f" {variable.name}: {variable.value}" - f" [{variable2source.get(variable)}]" - ) + for sv in sorted(model.state_variables, key=lambda x: x.name.lower()): + print(f" {sv.name}: {sv.value}" f" [{variable2source.get(sv)}]") print("Environment:") - for variable in sorted(model.dependencies, key=lambda x: x.name.lower()): - print( - f" {variable.name}: {variable.value}" - f" [{variable2source.get(variable)}]" - ) + for d in sorted(model.dependencies, key=lambda x: x.name.lower()): + print(f" {d.name}: {d.value} [{variable2source.get(d)}]") if missing: print("The following variables are still missing:") @@ -142,10 +142,10 @@ def set_variable(variable, value, source): sys.exit(1) print("State variables with largest value:") - for variable in sorted( - model.state_variables, key=lambda x: abs(x.value), reverse=True + for sv in sorted( + model.state_variables, key=lambda x: abs(float(x.value)), reverse=True )[:3]: - print(f" {variable.name}: {variable.value} {variable.units}") + print(f" {sv.name}: {sv.value} {sv.units}") # Get model rates rates = model.getRates(surface=surface, bottom=bottom) @@ -155,22 +155,20 @@ def set_variable(variable, value, source): if verbose: print("Diagnostics:") - for variable in sorted( - model.diagnostic_variables, key=lambda x: x.name.lower() - ): - if variable.output: - print(f" {variable.name}: {variable.value} {variable.units}") + for dv in sorted(model.diagnostic_variables, key=lambda x: x.name.lower()): + if dv.output: + print(f" {dv.name}: {dv.value} {dv.units}") # Check whether rates of change are valid numbers - valids = numpy.isfinite(rates) + valids = np.isfinite(rates) if not valids.all(): print("The following state variables have an invalid rate of change:") - for variable, rate, valid in zip(model.state_variables, rates, valids): + for sv, rate, valid in zip(model.state_variables, rates, valids): if not valid: - print(f" {variable.name}: {rate}") + print(f" {sv.name}: {rate}") eps = 1e-30 - relative_rates = numpy.array( + relative_rates = np.array( [ rate / (variable.value + eps) for variable, rate in zip(model.state_variables, rates) @@ -194,14 +192,14 @@ def set_variable(variable, value, source): )[:3]: print(f" {variable.name}: {86400 * relative_rate} d-1") - i = relative_rates.argmin() + i = int(relative_rates.argmin()) print( f"Minimum time step = {-1.0 / relative_rates[i]:%.3f} s due to decrease" f" in {model.state_variables[i].name}" ) -def main(): +def main() -> None: import argparse parser = argparse.ArgumentParser( diff --git a/src/pyfabm/utils/fabm_stress_test.py b/src/pyfabm/utils/fabm_stress_test.py index cc73e80f..fd8564b9 100755 --- a/src/pyfabm/utils/fabm_stress_test.py +++ b/src/pyfabm/utils/fabm_stress_test.py @@ -7,10 +7,23 @@ and running randomized or exhaustive tests. """ +from typing import ( + MutableSequence, + MutableSet, + Tuple, + Callable, + Mapping, + Sequence, + Iterable, + Union, + Dict, + cast, + Optional, +) import sys import yaml -import numpy +import numpy as np try: import pyfabm @@ -19,14 +32,24 @@ sys.exit(1) -def testRangePresence(ranges, variables, vary, found_variables): +VAR_TYPE = Union[pyfabm.StateVariable, pyfabm.Dependency] +VARY_TYPE = Sequence[Tuple[VAR_TYPE, Sequence[float]]] +MUTABLE_VARY_TYPE = MutableSequence[Tuple[VAR_TYPE, MutableSequence[float]]] + + +def testRangePresence( + ranges: Mapping[str, Union[float, int, MutableSequence[Union[int, float]]]], + variables: Iterable[VAR_TYPE], + vary: MUTABLE_VARY_TYPE, + found_variables: MutableSet[str], +): for variable in variables: if variable.name not in ranges: print(f"No range specified for variable {variable.name}.") sys.exit(1) variable_range = ranges[variable.name] if isinstance(variable_range, (int, float)): - variable.value = float(variable_range) + variable.value = cast(np.ndarray, float(variable_range)) else: if len(variable_range) != 2: print( @@ -58,22 +81,23 @@ def testRangePresence(ranges, variables, vary, found_variables): f"Raising default value of {variable.name} to prescribed" f" minimum {variable_range[0]} (previous default was {value})." ) - variable.value = variable_range[0] + variable.value = cast(np.ndarray, variable_range[0]) elif value > variable_range[1]: print( f"Lowering default value of {variable.name} to prescribed" f" maximum {variable_range[1]} (previous default was {value})." ) - variable.value = variable_range[1] + variable.value = cast(np.ndarray, variable_range[1]) else: - variable.value = 0.5 * (variable_range[0] + variable_range[1]) + center = 0.5 * (variable_range[0] + variable_range[1]) + variable.value = cast(np.ndarray, center) print( - f"Setting default value of {variable.name} to mean {variable.value}" + f"Setting default value of {variable.name} to center {center}" f" of prescribed range {variable_range[0]} - {variable_range[1]}" f" (no previous default was set)." ) if variable_range[0] == variable_range[1]: - variable.value = variable_range[0] + variable.value = cast(np.ndarray, variable_range[0]) else: vary.append((variable, variable_range)) found_variables.add(variable.name) @@ -82,26 +106,27 @@ def testRangePresence(ranges, variables, vary, found_variables): ndone = 0 -def check(model): +def check(model: pyfabm.Model): rates = model.getRates() assert len(rates) == len(model.state_variables) - valid = numpy.isfinite(rates) + valid = np.isfinite(rates) global ndone ndone += 1 if not valid.all(): print(f"Test {ndone} FAILED!") + variable: VAR_TYPE for variable, value in zip(model.state_variables, rates): - if not numpy.isfinite(value): + if not np.isfinite(value): print(f"Change in {variable.name} has invalid value {value}") - values = {} + values: Dict[str, np.ndarray] = {} print("MODEL STATE:") for variable in model.state_variables: print(f"- {variable.name} = {variable.value}") - values[variable.name] = float(variable.value) + values[variable.name] = np.array(variable.value) print("ENVIRONMENT:") for variable in model.dependencies: print(f"- {variable.name} = {variable.value}") - values[variable.name] = float(variable.value) + values[variable.name] = np.array(variable.value) with open("last_error.yaml", "w") as f: yaml.safe_dump(values, f, default_flow_style=False) print("This model state and environment has been saved in last_error.yaml.") @@ -112,46 +137,52 @@ def check(model): sys.exit(1) -def testRandomized(model, vary): +def testRandomized(model: pyfabm.Model, vary: VARY_TYPE): # Perpetual random test: # for each model input, pick a value from its valid range [minimum,maximum] while 1: - random_values = numpy.random.rand(len(vary)) + random_values = np.random.rand(len(vary)) for (variable, (minimum, maximum)), random_value in zip(vary, random_values): - variable.value = minimum + (maximum - minimum) * random_value + value = minimum + (maximum - minimum) * random_value + variable.value = cast(np.ndarray, value) check(model) if ndone % 1000 == 0: print(f"Test {ndone} completed.") -def testRandomizedExtremes(model, vary): +def testRandomizedExtremes(model: pyfabm.Model, vary: VARY_TYPE): # Perpetual random test: # for each model input, pick either its minimum or its maximum value. while 1: - pick_maxs = numpy.random.rand(len(vary)) > 0.5 + pick_maxs = np.random.rand(len(vary)) > 0.5 for (variable, (minimum, maximum)), pick_max in zip(vary, pick_maxs): - variable.value = maximum if pick_max else minimum + value = maximum if pick_max else minimum + variable.value = cast(np.ndarray, value) check(model) if ndone % 1000 == 0: print(f"Test {ndone} completed.") -def testExtremes(model, vary): +def testExtremes(model: pyfabm.Model, vary: VARY_TYPE): # Finite deterministic test: # for each model input, test minimum and maximum, leaving all other inputs # at their default value. for variable, (minimum, maximum) in vary: print(f"Testing {variable.name} = {minimum}...") - oldvalue = float(variable.value) - variable.value = minimum + oldvalue = np.array(variable.value) + variable.value = cast(np.ndarray, minimum) check(model) print(f"Testing {variable.name} = {maximum}...") - variable.value = maximum + variable.value = cast(np.ndarray, maximum) check(model) variable.value = oldvalue -def testExtremesRecursive(model, vary, ntot=None): +def testExtremesRecursive( + model: pyfabm.Model, + vary: VARY_TYPE, + ntot: Optional[int] = None, +): # Finite deterministic test: # test all possible combinations of minimum and maximm for each model input. if ntot is None: @@ -162,18 +193,18 @@ def testExtremesRecursive(model, vary, ntot=None): print(f"Completed {ndone} of {ntot} tests") return variable, (minimum, maximum) = vary[0] - oldvalue = float(variable.value) - variable.value = minimum + oldvalue = np.array(variable.value) + variable.value = cast(np.ndarray, minimum) testExtremesRecursive(model, vary[1:], ntot) - variable.value = maximum + variable.value = cast(np.ndarray, maximum) testExtremesRecursive(model, vary[1:], ntot) variable.value = oldvalue -def main(): +def main() -> None: import argparse - tests = { + tests: Mapping[str, Callable[[pyfabm.Model, VARY_TYPE], None]] = { "randomized": testRandomized, "extremes_per_variable": testExtremes, "extremes_randomized": testRandomizedExtremes, @@ -222,7 +253,7 @@ def main(): if args.write_ranges: with open(args.ranges_path, "w") as f: - def writeRanges(variables): + def writeRanges(variables: Iterable[VAR_TYPE]): for variable in variables: strmax = "?" if variable.value is None else 10 * variable.value f.write(f"{variable.name}: [0,{strmax}]\n") @@ -241,8 +272,8 @@ def writeRanges(variables): ) sys.exit(1) - vary = [] - found_variables = set() + vary: MUTABLE_VARY_TYPE = [] + found_variables: MutableSet[str] = set() testRangePresence(ranges, model.state_variables, vary, found_variables) testRangePresence(ranges, model.dependencies, vary, found_variables)