diff --git a/CODING.md b/CODING.md index 5d1f764..7bd6c55 100644 --- a/CODING.md +++ b/CODING.md @@ -36,6 +36,7 @@ pip install --editable . ```bash # use these linters, formatters, fixers: pip install -r requirements.code.txt +mypy --install-types # check configurations at [tool.*] in pyproject.toml # start with the template cp src/fire2template/template.py src/fire2a/.py diff --git a/requirements.code.txt b/requirements.code.txt index a29f7ee..f834e4a 100644 --- a/requirements.code.txt +++ b/requirements.code.txt @@ -8,4 +8,4 @@ python-lsp-black pytest pdoc -qgis-stubs @ git+https://github.com/leonhard-s/qgis-stubs.git +# qgis-stubs @ git+https://github.com/leonhard-s/qgis-stubs.git diff --git a/src/fire2a/__init__.py b/src/fire2a/__init__.py index bb74682..65ed87e 100644 --- a/src/fire2a/__init__.py +++ b/src/fire2a/__init__.py @@ -25,6 +25,7 @@ import logging from importlib.metadata import PackageNotFoundError, distribution from pathlib import Path +from typing import Any, Dict, List, Tuple, Union logger = logging.getLogger(__name__) @@ -43,7 +44,7 @@ logger.warning("%s Package version: %s, from %s", __name__, __version__, version_from) -def setup_logger(name: str = None, verbosity: str | int = "INFO", logfile: Path | None = None): +def setup_logger(name: Union[str, None] = None, verbosity: Union[str, int] = "INFO", logfile: Union[Path, None] = None): """ Users or developers not implementing their own logger should use this function to get enhanced program execution information. Capture the logger, setup its __name__ or root logger, verbosity, stream handler & rotating logfile. diff --git a/src/fire2a/clustering.py b/src/fire2a/clustering.py index abb78f2..e348028 100644 --- a/src/fire2a/clustering.py +++ b/src/fire2a/clustering.py @@ -8,6 +8,7 @@ import numpy as np from scipy.sparse import dok_matrix, lil_matrix from sklearn.cluster import AgglomerativeClustering +from typing import Union from .adjacency import adjacent_cells @@ -18,8 +19,8 @@ def raster_clusters( min_surface: float, max_surface: float, distance_threshold: float = 50.0, - total_clusters: int = None, - connectivity: int = None, + total_clusters: Union[int, None] = None, + connectivity: Union[int, None] = None, ) -> np.ndarray: """ This function receives as arguments: diff --git a/src/fire2a/downstream_protection_value.py b/src/fire2a/downstream_protection_value.py index 3e622ce..b9618f3 100644 --- a/src/fire2a/downstream_protection_value.py +++ b/src/fire2a/downstream_protection_value.py @@ -189,7 +189,7 @@ def recursion(G, i, pv, mdpv, i2n): return mdpv -def recursion2(G: DiGraph, i: np.int32, mdpv: ndarray, i2n: list[int]) -> ndarray: +def recursion2(G: DiGraph, i: int, mdpv: ndarray, i2n: list[int]) -> ndarray: for j in G.successors(i): mdpv[i2n.index(i)] += recursion2(G, j, mdpv, i2n) return mdpv[i2n.index(i)] @@ -326,10 +326,10 @@ def worker(data, pv, sid): def load_msg(afile: Path): try: - sim_id = search("\\d+", afile.stem).group(0) + sim_id = search(r"\d+", afile.stem).group(0) except: sim_id = "-1" - data = loadtxt( + data = np.loadtxt( afile, delimiter=",", dtype=[("i", np.int32), ("j", np.int32), ("t", np.int32)], usecols=(0, 1, 2), ndmin=1 ) return data, sim_id @@ -340,7 +340,7 @@ def get_data(files, callback=None): for count, afile in enumerate(files): sim_id = search("\\d+", afile.stem).group(0) data += [ - loadtxt( + np.loadtxt( afile, delimiter=",", dtype=[("i", np.int32), ("j", np.int32), ("t", np.int32)], diff --git a/src/fire2a/managedata.py b/src/fire2a/managedata.py index 4ee6693..fea059e 100644 --- a/src/fire2a/managedata.py +++ b/src/fire2a/managedata.py @@ -12,10 +12,12 @@ from numpy import max as npmax from numpy import nan as npnan from numpy import zeros as npzeros +from numpy import ndarray, dtype from pandas import DataFrame +from typing import Any, Dict, List, Optional, Tuple -def Lookupdict(filename: str) -> tuple[dict, dict]: +def Lookupdict(filename: Optional[Path,str]) -> Tuple[dict, dict]: """Reads lookup_table.csv and creates dictionaries for the fuel types and cells' colors Args: @@ -56,8 +58,9 @@ def Lookupdict(filename: str) -> tuple[dict, dict]: return row, colors - -def ForestGrid(filename: str, Lookupdict: dict) -> tuple[(list, list, int, int, list, list, int)]: +# Tuple[(list, list, int, int, list, list, int)] +# Tuple[list[Any], list[Any], int, int, list[Any], list[Any], int] +def ForestGrid(filename: str, Lookupdict: dict) -> Tuple[list[int], list[str], int, int, list[dict[str, Optional[list[int]]]], ndarray[Any, dtype[Any]], float]: """Reads fuels.asc file and returns an array with all the cells, and grid dimension nxm Args: @@ -324,8 +327,8 @@ def ForestGrid(filename: str, Lookupdict: dict) -> tuple[(list, list, int, int, return gridcell3, gridcell4, len(grid), tcols - 1, AdjCells, CoordCells, cellsize - -def DataGrids(InFolder: str, NCells: int) -> tuple[(list, list, list, list, list, list, list, list, list)]: +# Tuple[(list, list, list, list, list, list, list, list, list)] +def DataGrids(InFolder: str, NCells: int) -> Tuple[ndarray[Any, dtype[Any]], ndarray[Any, dtype[Any]], ndarray[Any, dtype[Any]], ndarray[Any, dtype[Any]], ndarray[Any, dtype[Any]], ndarray[Any, dtype[Any]], ndarray[Any, dtype[Any]], ndarray[Any, dtype[Any]], ndarray[Any, dtype[Any]]]: """ Reads *.asc files and returns an array per each ASCII file with the correspondant information per each cell. Currently supports elevation, ascpect, slope, curing, canopy bulk density, crown base height, conifer percent dead fir, probability of ignition and foliar moisture content. diff --git a/src/fire2a/raster.py b/src/fire2a/raster.py index 90fde96..0cf676f 100644 --- a/src/fire2a/raster.py +++ b/src/fire2a/raster.py @@ -14,6 +14,7 @@ import numpy as np from osgeo import gdal, ogr from qgis.core import QgsRasterLayer +from typing import Optional, Tuple, Union, Any, Dict, List from .utils import qgis2numpy_dtype @@ -73,7 +74,7 @@ def read_raster_band(filename: str, band: int = 1) -> tuple[np.ndarray, int, int return dataset.GetRasterBand(band).ReadAsArray(), dataset.RasterXSize, dataset.RasterYSize -def read_raster(filename: str, band: int = 1, data: bool = True, info: bool = True) -> tuple[np.ndarray, dict]: +def read_raster(filename: str, band: int = 1, data: bool = True, info: bool = True) -> tuple[Union[np.ndarray,None], Union[dict,None]]: """Read a raster file and return the data as a numpy array. Along raster info: transform, projection, raster count, raster width, raster height. @@ -238,24 +239,26 @@ def get_rlayer_data(layer: QgsRasterLayer): nodata = block.noDataValue() np_dtype = qgis2numpy_dtype(provider.dataType(1)) data = np.frombuffer(block.data(), dtype=np_dtype).reshape(layer.height(), layer.width()) + # return data, nodata, np_dtype else: data = [] nodata = [] - np_dtype = [] + np_dtypel = [] for i in range(layer.bandCount()): block = provider.block(i + 1, layer.extent(), layer.width(), layer.height()) nodata += [None] if block.hasNoDataValue(): nodata[-1] = block.noDataValue() - np_dtype += [qgis2numpy_dtype(provider.dataType(i + 1))] - data += [np.frombuffer(block.data(), dtype=np_dtype[-1]).reshape(layer.height(), layer.width())] + np_dtypel += [qgis2numpy_dtype(provider.dataType(i + 1))] + data += [np.frombuffer(block.data(), dtype=np_dtypel[-1]).reshape(layer.height(), layer.width())] # would different data types bug this next line? data = np.array(data) - # return data, nodata, np_dtype + # return data, nodata, np_dtypl return data def get_cell_sizeV2(filename: str, band: int = 1) -> tuple[float, float]: + # TODO: deprecate this function _, info = read_raster(filename, band=band, data=False, info=True) return info["RasterXSize"], info["RasterYSize"] @@ -289,7 +292,7 @@ def get_cell_size(raster: gdal.Dataset) -> tuple[float, float]: return cell_size -def mask_raster(raster_ds: gdal.Dataset, band: int, polygons: list[ogr.Geometry]) -> np.array: +def mask_raster(raster_ds: gdal.Dataset, band: int, polygons: list[ogr.Geometry]) -> np.ndarray: """Mask a raster with polygons using GDAL. Args: @@ -313,7 +316,7 @@ def mask_raster(raster_ds: gdal.Dataset, band: int, polygons: list[ogr.Geometry] return masked_data -def rasterize_polygons(polygons: list[ogr.Geometry], width: int, height: int) -> np.array: +def rasterize_polygons(polygons: list[ogr.Geometry], width: int, height: int) -> np.ndarray: """Rasterize polygons to a boolean array. Args: @@ -346,7 +349,7 @@ def rasterize_polygons(polygons: list[ogr.Geometry], width: int, height: int) -> return mask_array -def stack_rasters(file_list: list[Path], mask_polygon: list[ogr.Geometry] = None) -> tuple[np.ndarray, list[str]]: +def stack_rasters(file_list: list[Path], mask_polygon: Union[list[ogr.Geometry],None] = None) -> tuple[np.ndarray, list[str]]: """Stack raster files from a list into a 3D NumPy array. Args: diff --git a/src/fire2a/utils.py b/src/fire2a/utils.py index b72f650..230a556 100644 --- a/src/fire2a/utils.py +++ b/src/fire2a/utils.py @@ -8,11 +8,12 @@ import numpy as np from qgis.core import Qgis, QgsProcessingFeedback +from typing import Union, Any logger = logging.getLogger(__name__) -def loadtxt_nodata(fname, no_data=-9999, dtype=np.float32, **kwargs) -> np.ndarray: +def loadtxt_nodata(fname : str, no_data : int = -9999, dtype=np.float32, **kwargs) -> np.ndarray: """Load a text file into an array, casting safely to a specified data type, and replacing ValueError with a no_data value. Other arguments are passed to numpy.loadtxt. (delimiter=',' for example) @@ -49,7 +50,7 @@ def conv(no_data, dtype, val): return np.loadtxt(fname, converters=conv, dtype=dtype, **kwargs) -def qgis2numpy_dtype(qgis_dtype: Qgis.DataType) -> np.dtype: +def qgis2numpy_dtype(qgis_dtype: Qgis.DataType) -> Union[np.dtype[Any], None]: """Conver QGIS data type to corresponding numpy data type https://raw.githubusercontent.com/PUTvision/qgis-plugin-deepness/fbc99f02f7f065b2f6157da485bef589f611ea60/src/deepness/processing/processing_utils.py This is modified and extended copy of GDALDataType. @@ -80,6 +81,8 @@ def qgis2numpy_dtype(qgis_dtype: Qgis.DataType) -> np.dtype: return np.float32 if qgis_dtype == Qgis.DataType.Float64 or qgis_dtype == "Float64": return np.float64 + logger.error(f"QGIS data type {qgis_dtype} not matched to numpy data type.") + return None def getGDALdrivers(): diff --git a/src/fire2a/weathers.py b/src/fire2a/weathers.py index ddd1411..cd51ca9 100644 --- a/src/fire2a/weathers.py +++ b/src/fire2a/weathers.py @@ -71,8 +71,8 @@ def cut_weather_scenarios( weather_records: DataFrame, scenario_lengths: List[int], output_folder: Union[Path, str] = None, - n_output_files: int = None, -) -> None: + n_output_files: Union[int,None] = None, +) -> DataFrame: """Split weather records into smaller scenarios following specified scenario lengths. The number of output weather scenarios can be customized using the 'n_output_files' parameter. @@ -120,7 +120,7 @@ def cut_weather_scenarios( if any(length > total_data_length for length in sample): raise ValueError("Scenario length cannot be greater than the total length of weather records") - scenarios = [] # List to store weather scenarios + scenarios : DataFrame = [] # List to store weather scenarios # Generate scenarios based on specified lengths for index, length in enumerate(sample, start=1): @@ -150,7 +150,7 @@ def random_weather_scenario_generator( hr_limit: Optional[int] = None, lambda_ws: Optional[float] = None, lambda_wd: Optional[float] = None, - output_folder: Optional[str] = None, + output_folder: Optional[Union[Path,str]] = None, ): """Generates random weather scenarios and saves them as CSV files.