diff --git a/src/transport_performance/gtfs/cleaners.py b/src/transport_performance/gtfs/cleaners.py index 6f0ff73f..3d4d7157 100644 --- a/src/transport_performance/gtfs/cleaners.py +++ b/src/transport_performance/gtfs/cleaners.py @@ -3,7 +3,7 @@ import numpy as np -from transport_performance.utils.defence import _gtfs_defence, _check_list +from transport_performance.utils.defence import _gtfs_defence, _check_iterable def drop_trips(gtfs, trip_id: Union[str, list, np.ndarray]) -> None: @@ -32,13 +32,17 @@ def drop_trips(gtfs, trip_id: Union[str, list, np.ndarray]) -> None: if isinstance(trip_id, str): trip_id = [trip_id] - # _check_list only takes lists, therefore convert numpy arrays + # _check_iterable only takes lists, therefore convert numpy arrays if isinstance(trip_id, np.ndarray): trip_id = list(trip_id) # ensure trip ids are string - _check_list( - ls=trip_id, param_nm="trip_id", check_elements=True, exp_type=str + _check_iterable( + iterable=trip_id, + param_nm="trip_id", + iterable_type=list, + check_elements=True, + exp_type=str, ) # drop relevant records from tables diff --git a/src/transport_performance/gtfs/gtfs_utils.py b/src/transport_performance/gtfs/gtfs_utils.py index f8330889..85f9a059 100644 --- a/src/transport_performance/gtfs/gtfs_utils.py +++ b/src/transport_performance/gtfs/gtfs_utils.py @@ -13,7 +13,7 @@ from transport_performance.utils.defence import ( _is_expected_filetype, - _check_list, + _check_iterable, _type_defence, ) from transport_performance.utils.constants import PKG_PATH @@ -89,7 +89,12 @@ def bbox_filter_gtfs( ) if isinstance(bbox, list): - _check_list(ls=bbox, param_nm="bbox", exp_type=float) + _check_iterable( + iterable=bbox, + param_nm="bbox_list", + iterable_type=list, + exp_type=float, + ) # create box polygon around provided coords, need to splat bbox = box(*bbox) # gtfs_kit expects gdf diff --git a/src/transport_performance/gtfs/validation.py b/src/transport_performance/gtfs/validation.py index c8828a13..fbe13186 
100644 --- a/src/transport_performance/gtfs/validation.py +++ b/src/transport_performance/gtfs/validation.py @@ -34,7 +34,7 @@ _check_parent_dir_exists, _check_column_in_df, _type_defence, - _check_item_in_list, + _check_item_in_iter, _check_attribute, ) @@ -571,7 +571,7 @@ def viz_stops( # geoms defence geoms = geoms.lower().strip() ACCEPT_VALS = ["point", "hull"] - _check_item_in_list(geoms, ACCEPT_VALS, "geoms") + _check_item_in_iter(geoms, ACCEPT_VALS, "geoms") try: m = self._produce_stops_map( @@ -1077,8 +1077,8 @@ def _plot_summary( which = which.lower() # ensure 'which' is valid - _check_item_in_list( - item=which, _list=["trip", "route"], param_nm="which" + _check_item_in_iter( + item=which, iterable=["trip", "route"], param_nm="which" ) raw_pth = os.path.join( @@ -1088,8 +1088,8 @@ def _plot_summary( _check_parent_dir_exists(raw_pth, "save_pth", create=True) # orientation input defences - _check_item_in_list( - item=orientation, _list=["v", "h"], param_nm="orientation" + _check_item_in_iter( + item=orientation, iterable=["v", "h"], param_nm="orientation" ) # assign the correct values depending on which breakdown has been diff --git a/src/transport_performance/osm/osm_utils.py b/src/transport_performance/osm/osm_utils.py index bd444847..f523438f 100644 --- a/src/transport_performance/osm/osm_utils.py +++ b/src/transport_performance/osm/osm_utils.py @@ -6,7 +6,7 @@ from transport_performance.utils.defence import ( _type_defence, - _check_list, + _check_iterable, _check_parent_dir_exists, _is_expected_filetype, ) @@ -77,7 +77,13 @@ def filter_osm( ) ) - _check_list(bbox, param_nm="bbox", check_elements=True, exp_type=float) + _check_iterable( + bbox, + param_nm="bbox", + iterable_type=list, + check_elements=True, + exp_type=float, + ) _check_parent_dir_exists(out_pth, param_nm="out_pth", create=True) # Compile the osmosis command cmd = [ diff --git a/src/transport_performance/urban_centres/raster_uc.py b/src/transport_performance/urban_centres/raster_uc.py 
index 5a7dd3e9..50e9b35b 100644 --- a/src/transport_performance/urban_centres/raster_uc.py +++ b/src/transport_performance/urban_centres/raster_uc.py @@ -1,22 +1,23 @@ """Functions to calculate urban centres following Eurostat definition.""" from collections import Counter +from typing import Union +import pathlib + import affine import geopandas as gpd import numpy as np import numpy.ma as ma import pandas as pd -import pathlib import rasterio import xarray as xr - from geocube.vector import vectorize from pyproj import Transformer from rasterio.mask import raster_geometry_mask from rasterio.transform import rowcol from scipy.ndimage import generic_filter, label -from transport_performance.utils.defence import _is_expected_filetype -from typing import Union + +import transport_performance.utils.defence as d class UrbanCentre: @@ -71,8 +72,9 @@ def __init__( exp_ext: list = [".tif", ".tiff", ".tff"], ): - # check that path is str or PosixPath - _is_expected_filetype(path, "file", exp_ext=exp_ext) + # check that path is str or pathlib.Path + d._is_expected_filetype(path, "file", exp_ext=exp_ext) + d._check_iterable(exp_ext, "exp_ext", list, True, str) self.file = path def get_urban_centre( @@ -175,11 +177,8 @@ def get_urban_centre( ) # buffer - if not isinstance(buffer_size, int): - raise TypeError( - "`buffer_size` expected int, " - f"got {type(buffer_size).__name__}." - ) + d._type_defence(buffer_size, "buffer_size", int) + if buffer_size <= 0: raise ValueError( "`buffer_size` expected positive non-zero integer" @@ -204,13 +203,16 @@ def get_urban_centre( return self.output def _window_raster( - self, file: str, bbox: gpd.GeoDataFrame, band_n: int = 1 + self, + file: Union[str, pathlib.Path], + bbox: gpd.GeoDataFrame, + band_n: int = 1, ) -> tuple: """Open file, load band and apply mask. Parameters ---------- - file : str + file : Union[str, pathlib.Path] Path to geoTIFF file. 
bbox : gpd.GeoDataFrame A GeoPandas GeoDataFrame containing boundaries to filter the @@ -229,14 +231,8 @@ def _window_raster( crs string from the raster. """ - if not isinstance(bbox, gpd.GeoDataFrame): - raise TypeError( - "`bbox` expected GeoDataFrame, " f"got {type(bbox).__name__}." - ) - if not isinstance(band_n, int): - raise TypeError( - "`band_n` expected integer, " f"got {type(band_n).__name__}" - ) + d._type_defence(bbox, "bbox", gpd.GeoDataFrame) + d._type_defence(band_n, "band_n", int) with rasterio.open(file) as src: if src.crs != bbox.crs: @@ -276,16 +272,8 @@ def _flag_cells( If cell_pop_threshold is too high and all cells are filtered out. """ - if not isinstance(masked_rst, np.ndarray): - raise TypeError( - "`masked_rst` expected numpy array, " - f"got {type(masked_rst).__name__}." - ) - if not isinstance(cell_pop_threshold, int): - raise TypeError( - "`cell_pop_threshold` expected integer, " - f"got {type(cell_pop_threshold).__name__}." - ) + d._type_defence(masked_rst, "masked_rst", np.ndarray) + d._type_defence(cell_pop_threshold, "cell_pop_threshold", int) flag_array = masked_rst >= cell_pop_threshold @@ -317,18 +305,13 @@ def _cluster_cells( Number of clusters identified. """ - if not isinstance(flag_array, np.ndarray): - raise TypeError( - "`flag_array` expected numpy array, " - f"got {type(flag_array).__name__}." - ) - if not isinstance(diag, bool): - raise TypeError("`diag` must be a boolean.") + d._type_defence(flag_array, "flag_array", np.ndarray) + d._type_defence(diag, "diag", bool) - if diag is False: - s = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]]) - elif diag is True: + if diag: s = np.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]]) + else: + s = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]]) labelled_array, num_clusters = label(flag_array, s) @@ -364,25 +347,10 @@ def _check_cluster_pop( Array including only clusters with population over the threshold. 
""" - if not isinstance(band, np.ndarray): - raise TypeError( - "`band` expected numpy array, " f"got {type(band).__name__}." - ) - if not isinstance(labelled_array, np.ndarray): - raise TypeError( - "`labelled_array` expected numpy array, " - f"got {type(labelled_array).__name__}." - ) - if not isinstance(num_clusters, int): - raise TypeError( - "`num_clusters` expected integer, " - f"got {type(num_clusters).__name__}" - ) - if not isinstance(cluster_pop_threshold, int): - raise TypeError( - "`cluster_pop_threshold` expected integer, " - f"got {type(cluster_pop_threshold).__name__}" - ) + d._type_defence(band, "band", np.ndarray) + d._type_defence(labelled_array, "labelled_array", np.ndarray) + d._type_defence(num_clusters, "num_clusters", int) + d._type_defence(cluster_pop_threshold, "cluster_pop_threshold", int) urban_centres = labelled_array.copy() for n in range(1, num_clusters + 1): @@ -424,10 +392,8 @@ def _custom_filter(self, win: np.ndarray, threshold: int) -> int: counter = Counter(win) mode_count = counter.most_common(1)[0] if (mode_count[1] >= threshold) & (win[len(win) // 2] == 0): - r = mode_count[0] - else: - r = win[len(win) // 2] - return r + return mode_count[0] + return win[len(win) // 2] def _fill_gaps( self, urban_centres: np.ndarray, cell_fill_threshold: int = 5 @@ -453,16 +419,9 @@ def _fill_gaps( Array including urban centres with gaps filled. """ - if not isinstance(urban_centres, np.ndarray): - raise TypeError( - "`urban_centres` expected numpy array, " - f"got {type(urban_centres).__name__}." 
- ) - if not isinstance(cell_fill_threshold, int): - raise TypeError( - "`cell_fill_threshold` expected integer, " - f"got {type(cell_fill_threshold).__name__}" - ) + d._type_defence(urban_centres, "urban_centres", np.ndarray) + d._type_defence(cell_fill_threshold, "cell_fill_threshold", int) + if not (5 <= cell_fill_threshold <= 8): raise ValueError( "Wrong value for `cell_fill_threshold`, " @@ -482,8 +441,7 @@ def _fill_gaps( extra_keywords={"threshold": cell_fill_threshold}, ) if np.array_equal(filled, check): - break - return filled + return filled def _get_x_y( self, @@ -511,13 +469,15 @@ def _get_x_y( (row, col) position for provided parameters. """ - if len(coords) != 2: - raise ValueError("`coords` expected a tuple of lenght 2.") - - if (not isinstance(coords[0], float)) and ( - not isinstance(coords[1], float) - ): - raise TypeError("Elements of `coords` need to be float.") + d._check_iterable( + iterable=coords, + param_nm="coords", + iterable_type=tuple, + check_elements=True, + exp_type=float, + check_length=True, + length=2, + ) transformer = Transformer.from_crs(coords_crs, raster_crs) x, y = transformer.transform(*coords) @@ -567,25 +527,12 @@ def _vectorize_uc( If centre coordinates are not included within any cluster. """ - if not isinstance(uc_array, np.ndarray): - raise TypeError( - "`uc_array` expected numpy array, " - f"got {type(uc_array).__name__}." 
- ) - if not isinstance(centre, tuple): - raise TypeError( - "`centre` expected tuple, " f"got {type(centre).__name__}" - ) - if not isinstance(aff, affine.Affine): - raise TypeError("`aff` must be a valid Affine object") - if not isinstance(raster_crs, rasterio.crs.CRS): - raise TypeError( - "`raster_crs` must be a valid rasterio.crs.CRS " "object" - ) - if not isinstance(nodata, int): - raise TypeError( - "`nodata` expected integer, " f"got {type(nodata).__name__}" - ) + d._type_defence(uc_array, "uc_array", np.ndarray) + d._type_defence(centre, "centre", tuple) + d._type_defence(aff, "aff", affine.Affine) + d._type_defence(raster_crs, "raster_crs", rasterio.crs.CRS) + d._type_defence(nodata, "nodata", int) + d._type_defence(centre_crs, "centre_crs", (type(None), str)) if centre_crs is None: centre_crs = raster_crs diff --git a/src/transport_performance/utils/defence.py b/src/transport_performance/utils/defence.py index d426c758..b495231d 100644 --- a/src/transport_performance/utils/defence.py +++ b/src/transport_performance/utils/defence.py @@ -1,5 +1,6 @@ """Defensive check utility funcs. Internals only.""" from typing import Union +from collections.abc import Iterable import pathlib import numpy as np @@ -232,36 +233,111 @@ def _type_defence(some_object, param_nm, types) -> None: return None -def _check_list(ls, param_nm, check_elements=True, exp_type=str): - """Check a list and its elements for type. +def _check_iter_length(iterable: Iterable, param_nm: str, length: int) -> None: + """Check the length of an iterable. Parameters ---------- - ls : list - List to check. + iterable : Iterable + Iterable to check. param_nm : str Name of the parameter being checked. - check_elements : (bool, optional) - Whether to check the list element types. Defaults to True. - exp_type : (_type_, optional): - The expected type of the elements. Defaults to str. + length: int + Expected length of the iterable to check. Raises ------ - TypeError: `ls` is not a list. 
- TypeError: Elements of `ls` are not of the expected type. + ValueError: length of iterable does not match `length`. Returns ------- None """ - if not isinstance(ls, list): - raise TypeError( - f"`{param_nm}` should be a list. Instead found {type(ls)}" + # check if iterable + _type_defence(iterable, param_nm, Iterable) + # check if length is int + _type_defence(length, "length", int) + + if len(iterable) != length: + raise ValueError( + f"`{param_nm}` is of length {len(iterable)}. " + f"Expected length {length}." ) + + return None + + +def _check_iterable( + iterable: Iterable, + param_nm: str, + iterable_type: type, + check_elements: bool = True, + exp_type: Union[tuple, type] = str, + check_length: bool = False, + length: int = 0, +) -> None: + """Check an iterable and its elements for type. + + Parameters + ---------- + iterable : Iterable + Iterable to check. + param_nm : str + Name of the parameter being checked. + iterable_type : type + Expected iterable type. + check_elements : bool, optional + Whether to check the list element types. Defaults to True. + exp_type : Union[tuple, type], optional + The expected type of the elements. If using a tuple, it should be a + tuple of types. Defaults to str. + check_length: bool, optional + Whether to check the length of the iterable. Defaults to False. + length: int, optional + Expected length of the iterable. Defaults to 0. + + Raises + ------ + TypeError + `iterable` is not iterable. + TypeError + `exp_type` contains elements that are not type. + TypeError + Elements of `iterable` are not of the expected type(s). 
+ + Returns + ------- + None + + """ + # check if iterable + _type_defence(iterable, param_nm, Iterable) + # check if iterable_type is type + _type_defence(iterable_type, "iterable_type", type) + # check if iterable type matches expected + _type_defence(iterable, param_nm, iterable_type) + + # check if expected types tuple includes only types + if isinstance(exp_type, tuple): + for i in exp_type: + if not isinstance(i, type): + raise TypeError( + ( + f"`exp_type` must contain types only. " + f"Found {type(i)} : {i}" + ) + ) + else: + _type_defence(exp_type, "exp_type", (type, tuple)) + + # check length + if check_length: + _check_iter_length(iterable, param_nm, length) + + # check if elements are of the expected types if check_elements: - for i in ls: + for i in iterable: if not isinstance(i, exp_type): raise TypeError( ( @@ -309,15 +385,15 @@ def _check_column_in_df(df: pd.DataFrame, column_name: str) -> None: return None -def _check_item_in_list(item: str, _list: list, param_nm: str) -> None: - """Defence to check if an item is present in a list. +def _check_item_in_iter(item: str, iterable: Iterable, param_nm: str) -> None: + """Defence to check if an item is present in an iterable. Parameters ---------- item : str The item to check the list for - _list : list - The list to check that the item is in + iterable : Iterable + The iterable to check that the item is in param_nm : str The name of the param that the item has been passed to @@ -331,10 +407,13 @@ def _check_item_in_list(item: str, _list: list, param_nm: str) -> None: Error raised when item not in the list. """ - if item not in _list: + # check if iterable + _type_defence(iterable, param_nm, Iterable) + + if item not in iterable: raise ValueError( - f"'{param_nm}' expected one of the following:" - f"{_list} Got {item}" + f"'{param_nm}' expected one of the following: " + f"{iterable}. 
Got {item}: {type(item)}" ) return None diff --git a/tests/gtfs/test_validation.py b/tests/gtfs/test_validation.py index 44fcea1b..b0fbda34 100644 --- a/tests/gtfs/test_validation.py +++ b/tests/gtfs/test_validation.py @@ -354,8 +354,8 @@ def test_viz_stops_defence(self, tmpdir, gtfs_fixture): with pytest.raises( ValueError, match=re.escape( - "'geoms' expected one of the following:" - "['point', 'hull'] Got foobar" + "'geoms' expected one of the following: " + "['point', 'hull']. Got foobar: " ), ): gtfs_fixture.viz_stops(out_pth=tmp, geoms="foobar") @@ -863,8 +863,8 @@ def test__plot_summary_defences(self, tmp_path, gtfs_fixture): with pytest.raises( ValueError, match=re.escape( - "'orientation' expected one of the following:" - f"{options} Got i" + "'orientation' expected one of the following: " + f"{options}. Got i: " ), ): gtfs_fixture._plot_summary( @@ -894,8 +894,8 @@ def test__plot_summary_defences(self, tmp_path, gtfs_fixture): with pytest.raises( ValueError, match=re.escape( - "'which' expected one of the following:" - "['trip', 'route'] Got tester" + "'which' expected one of the following: " + "['trip', 'route']. 
Got tester: " ), ): gtfs_fixture._plot_summary(which="tester", target_column="tester") diff --git a/tests/urban_centres/test_urban_centres.py b/tests/urban_centres/test_urban_centres.py index 9d23bfd5..e82293a4 100644 --- a/tests/urban_centres/test_urban_centres.py +++ b/tests/urban_centres/test_urban_centres.py @@ -215,7 +215,7 @@ def cluster_centre(): "num", pytest.raises( TypeError, - match=(r"`pth` expected .*'str'.*Path'.* Got .*'int'.*"), + match=(r"`pth` expected .*str.*Path.* Got .*int.*"), ), ), ], @@ -264,7 +264,8 @@ def test_file( ( "string", pytest.raises( - TypeError, match=(r"`bbox` expected GeoDataFrame, got str") + TypeError, + match=(r"`bbox` expected .*GeoDataFrame.* Got " r".*str.*"), ), ), # wrong format bbox @@ -272,7 +273,7 @@ def test_file( pd.DataFrame(), pytest.raises( TypeError, - match=(r"`bbox` expected GeoDataFrame, got DataFrame"), + match=(r"`bbox` expected .*GeoDataFrame.* Got .*DataFrame.*"), ), ), # badly defined bbox @@ -370,28 +371,33 @@ def test_bbox( (50, 3), None, pytest.raises( - TypeError, match=(r"Elements of `coords` need to be float") + TypeError, + match=( + r"`coords` must contain .*float.* only" + r".*Found .*int.*: 50" + ), ), ), ( (50, 3, 3), None, pytest.raises( - ValueError, match=(r"`coords` expected a tuple of lenght 2") + ValueError, + match=(r"`coords` is of length 3.* Expected " r"length 2"), ), ), ( 50, None, pytest.raises( - TypeError, match=(r"`centre` expected tuple, got int") + TypeError, match=(r"`centre` expected .*tuple.* Got .*int.*") ), ), ( "(50, 3)", None, pytest.raises( - TypeError, match=(r"`centre` expected tuple, got str") + TypeError, match=(r"`centre` expected .*tuple.* Got .*str.*") ), ), ], @@ -436,14 +442,14 @@ def test_centre( ( 1.5, pytest.raises( - TypeError, match=(r"`band_n` expected integer, got float") + TypeError, match=(r"`band_n` expected .*int.* Got .*float.*") ), ), (2, pytest.raises(IndexError, match=(r"band index 2 out of range"))), ( "2", pytest.raises( - TypeError, 
match=(r"`band_n` expected integer, got str") + TypeError, match=(r"`band_n` expected .*int.* Got .*str.*") ), ), ], @@ -490,7 +496,7 @@ def test_band_n( 1500.5, pytest.raises( TypeError, - match=(r"`cell_pop_threshold` expected integer, got float"), + match=(r"`cell_pop_threshold` expected .*int.* Got .*float.*"), ), [], ), @@ -498,7 +504,7 @@ def test_band_n( "1500", pytest.raises( TypeError, - match=(r"`cell_pop_threshold` expected integer, got str"), + match=(r"`cell_pop_threshold` expected .*int.* Got .*str.*"), ), [], ), @@ -612,13 +618,17 @@ def test_cell_pop_t_output( (False, does_not_raise(), 3, 4), ( 1, - pytest.raises(TypeError, match=(r"`diag` must be a boolean")), + pytest.raises( + TypeError, match=(r"`diag` expected .*bool.* Got " r".*int.*") + ), 0, 0, ), ( "True", - pytest.raises(TypeError, match=(r"`diag` must be a boolean")), + pytest.raises( + TypeError, match=(r"`diag` expected .*bool.* Got " r".*str.*") + ), 0, 0, ), @@ -725,7 +735,10 @@ def test_diag_output( 50000.5, pytest.raises( TypeError, - match=(r"`cluster_pop_threshold` expected integer, got float"), + match=( + r"`cluster_pop_threshold` expected .*int.* Got " + r".*float.*" + ), ), [], ), @@ -733,7 +746,9 @@ def test_diag_output( "50000", pytest.raises( TypeError, - match=(r"`cluster_pop_threshold` expected integer, got str"), + match=( + r"`cluster_pop_threshold` expected .*int.* Got " r".*str.*" + ), ), [], ), @@ -849,7 +864,9 @@ def test_cluster_pop_t_output( 5.5, pytest.raises( TypeError, - match=(r"`cell_fill_threshold` expected integer, got float"), + match=( + r"`cell_fill_threshold` expected .*int.* Got " r".*float.*" + ), ), [], ), @@ -857,7 +874,9 @@ def test_cluster_pop_t_output( "5", pytest.raises( TypeError, - match=(r"`cell_fill_threshold` expected integer, got str"), + match=( + r"`cell_fill_threshold` expected .*int.* Got " r".*str.*" + ), ), [], ), @@ -988,13 +1007,13 @@ def test_cell_fill_output( ( -200.5, pytest.raises( - TypeError, match=(r"`nodata` expected 
integer, got float") + TypeError, match=(r"`nodata` expected .*int.* Got .*float.*") ), ), ( "str", pytest.raises( - TypeError, match=(r"`nodata` expected integer, got str") + TypeError, match=(r"`nodata` expected .*int.* Got .*str.*") ), ), ], @@ -1046,13 +1065,15 @@ def test_v_nodata( ( 10000.5, pytest.raises( - TypeError, match=(r"`buffer_size` expected int, got float") + TypeError, + match=(r"`buffer_size` expected .*int.* Got " r".*float.*"), ), ), ( "str", pytest.raises( - TypeError, match=(r"`buffer_size` expected int, got str") + TypeError, + match=(r"`buffer_size` expected .*int.* Got " r".*str.*"), ), ), ], @@ -1180,7 +1201,7 @@ def test_final_output( assert out.loc[2][1] == Polygon(bbox_coords) # type of output - assert type(out) == gpd.GeoDataFrame + assert type(out) is gpd.GeoDataFrame # shape of output assert out.shape == (3, 2) @@ -1200,7 +1221,8 @@ def test__flag_cells_raises(dummy_pop_array): """Test _flag_cells raises expected exception.""" uc = ucc.UrbanCentre(dummy_pop_array) with pytest.raises( - TypeError, match="`masked_rst` expected numpy array, got str." + TypeError, + match=(r"`masked_rst` expected .*numpy.ndarray.* Got " r".*str.*"), ): uc._flag_cells("not an array") @@ -1209,7 +1231,8 @@ def test__cluster_cells_raises(dummy_pop_array): """Test _cluster_cells raises.""" uc = ucc.UrbanCentre(dummy_pop_array) with pytest.raises( - TypeError, match="`flag_array` expected numpy array, got str." + TypeError, + match=(r"`flag_array` expected .*numpy.ndarray.* Got " r".*str.*"), ): uc._cluster_cells("not an array") @@ -1218,13 +1241,14 @@ def test__check_cluster_pop_raises(dummy_pop_array): """Test _check_cluster_pop raises.""" uc = ucc.UrbanCentre(dummy_pop_array) with pytest.raises( - TypeError, match="`band` expected numpy array, got str." 
+ TypeError, match=(r"`band` expected .*numpy.ndarray.* Got .*str.*") ): uc._check_cluster_pop( band="not an array", labelled_array=1, num_clusters=2 ) with pytest.raises( - TypeError, match="`labelled_array` expected numpy array, got str." + TypeError, + match=(r"`labelled_array` expected .*numpy.ndarray.* Got " r".*str.*"), ): uc._check_cluster_pop( band=np.array([0, 1, 2]), @@ -1232,7 +1256,7 @@ def test__check_cluster_pop_raises(dummy_pop_array): num_clusters=2, ) with pytest.raises( - TypeError, match="`num_clusters` expected integer, got float" + TypeError, match=(r"`num_clusters` expected .*int.* Got .*float.*") ): uc._check_cluster_pop( band=np.array([0, 1, 2]), @@ -1245,7 +1269,8 @@ def test__fill_gaps_raises(dummy_pop_array): """Test _fill_gaps raises.""" uc = ucc.UrbanCentre(dummy_pop_array) with pytest.raises( - TypeError, match="`urban_centres` expected numpy array, got str." + TypeError, + match=(r"`urban_centres` expected .*numpy.ndarray.* Got " r".*str.*"), ): uc._fill_gaps(urban_centres="not an array") @@ -1255,12 +1280,14 @@ def test__vectorize_uc_raises(dummy_pop_array): uc = ucc.UrbanCentre(dummy_pop_array) # crs = rio.crs.CRS.from_epsg(3005) with pytest.raises( - TypeError, match="`uc_array` expected numpy array, got str." 
+ TypeError, match=(r"`uc_array` expected .*numpy.ndarray.* Got .*str.*") ): uc._vectorize_uc( uc_array="not an array", aff=1, raster_crs=2, centre=3 ) - with pytest.raises(TypeError, match="`aff` must be a valid Affine object"): + with pytest.raises( + TypeError, match=(r"`aff` expected .*affine.Affine.* Got .*str.*") + ): uc._vectorize_uc( uc_array=np.array([0, 1, 2]), aff="not Affine", @@ -1268,7 +1295,8 @@ def test__vectorize_uc_raises(dummy_pop_array): centre=(1, 2), ) with pytest.raises( - TypeError, match="`raster_crs` must be a valid rasterio.crs.CRS object" + TypeError, + match=(r"`raster_crs` expected .*rasterio.crs.CRS.* Got .*str.*"), ): uc._vectorize_uc( uc_array=np.array([0, 1, 2]), diff --git a/tests/utils/test_defence.py b/tests/utils/test_defence.py index d5618e74..daf0dd9b 100644 --- a/tests/utils/test_defence.py +++ b/tests/utils/test_defence.py @@ -10,12 +10,12 @@ from pyprojroot import here from transport_performance.utils.defence import ( - _check_list, + _check_iterable, _check_parent_dir_exists, _gtfs_defence, _type_defence, _check_column_in_df, - _check_item_in_list, + _check_item_in_iter, _check_attribute, _handle_path_like, _is_expected_filetype, @@ -23,19 +23,50 @@ ) -class Test_CheckList(object): - """Test internal _check_list.""" +class Test_CheckIter(object): + """Test internal _check_iterable.""" - def test__check_list_only(self): - """Func raises as expected when not checking list elements.""" + def test__check_iter_only(self): + """Func raises as expected when not checking iterable elements.""" + # not iterable with pytest.raises( TypeError, - match="`some_bool` should be a list. 
Instead found ", + match="`some_bool` expected .*Iterable.* Got .*bool.*", ): - _check_list(ls=True, param_nm="some_bool", check_elements=False) + _check_iterable( + iterable=True, + param_nm="some_bool", + iterable_type=list, + check_elements=False, + ) + + # iterable does not match provided type + with pytest.raises( + TypeError, + match="`some_tuple` expected .*list.* Got .*tuple.*", + ): + _check_iterable( + iterable=(1, 2, 3), + param_nm="some_tuple", + iterable_type=list, + check_elements=False, + ) + + # iterable_type is not type + with pytest.raises( + TypeError, + match="`iterable_type` expected .*type.* Got .*str.*", + ): + _check_iterable( + iterable=(1, 2, 3), + param_nm="some_tuple", + iterable_type="tuple", + check_elements=False, + ) - def test__check_list_elements(self): + def test__check_iter_elements(self): """Func raises as expected when checking list elements.""" + # mixed types with pytest.raises( TypeError, match=( @@ -43,26 +74,102 @@ def test__check_list_elements(self): " : 2" ), ): - _check_list( - ls=[1, "2", 3], + _check_iterable( + iterable=[1, "2", 3], param_nm="mixed_list", + iterable_type=list, check_elements=True, exp_type=int, ) - def test__check_list_passes(self): + # wrong expected types + with pytest.raises( + TypeError, + match=("`exp_type` expected .*type.*tuple.*" "Got .*str.*"), + ): + _check_iterable( + iterable=["1", "2", "3"], + param_nm="param", + iterable_type=list, + check_elements=True, + exp_type="str", + ) + + # wrong types in exp_type tuple + with pytest.raises( + TypeError, + match=("`exp_type` must contain types only.* Found .*str.*: str"), + ): + _check_iterable( + iterable=[1, "2", 3], + param_nm="param", + iterable_type=list, + check_elements=True, + exp_type=(int, "str"), + ) + + def test__check_iter_passes(self): """Test returns None when pass conditions met.""" + # check list and element type assert ( - _check_list(ls=[1, 2, 3], param_nm="int_list", exp_type=int) + _check_iterable( + iterable=[1, 2, 3], + 
param_nm="int_list", + iterable_type=list, + exp_type=int, + ) is None ) + + # check list and multiple element types assert ( - _check_list( - ls=[False, True], param_nm="bool_list", check_elements=False + _check_iterable( + iterable=[1, "2", 3], + param_nm="int_list", + iterable_type=list, + exp_type=(int, str), ) is None ) + # check tuple + assert ( + _check_iterable( + iterable=(False, True), + param_nm="bool_list", + iterable_type=tuple, + check_elements=False, + ) + is None + ) + + def test__check_iter_length(self): + """Func raises as expected when length of iterable does not match.""" + # wrong length + with pytest.raises( + ValueError, + match=("`list_3` is of length 3. Expected length 2."), + ): + _check_iterable( + iterable=[1, 2, 3], + param_nm="list_3", + iterable_type=list, + check_elements=False, + check_length=True, + length=2, + ) + + def test__check_iter_length_pass(self): + """Test returns None when pass conditions met.""" + _check_iterable( + iterable=[1, 2, 3], + param_nm="list_3", + iterable_type=list, + check_elements=False, + check_length=True, + length=3, + ) + class Test_CheckParentDirExists(object): """Assertions for check_parent_dir_exists.""" @@ -299,24 +406,24 @@ def test_list(): class TestCheckItemInList(object): - """Tests for _check_item_in_list().""" + """Tests for _check_item_in_iter().""" - def test_check_item_in_list_defence(self, test_list): - """Defensive tests for check_item_in_list().""" + def test_check_item_in_iter_defence(self, test_list): + """Defensive tests for check_item_in_iter().""" with pytest.raises( ValueError, match=re.escape( - "'test' expected one of the following:" - f"{test_list} Got not_in_test" + "'test' expected one of the following: " + f"{test_list}. 
Got not_in_test: " ), ): - _check_item_in_list( - item="not_in_test", _list=test_list, param_nm="test" + _check_item_in_iter( + item="not_in_test", iterable=test_list, param_nm="test" ) - def test_check_item_in_list_on_pass(self, test_list): - """General tests for check_item_in_list().""" - _check_item_in_list(item="test", _list=test_list, param_nm="test") + def test_check_item_in_iter_on_pass(self, test_list): + """General tests for check_item_in_iter().""" + _check_item_in_iter(item="test", iterable=test_list, param_nm="test") @pytest.fixture(scope="function") @@ -336,7 +443,7 @@ def __init__(self) -> None: class TestCheckAttribute(object): - """Tests for _check_item_in_list().""" + """Tests for _check_item_in_iter().""" def test_check_attribute_defence(self, dummy_obj): """Defensive tests for check_attribute."""