Skip to content

Commit

Permalink
160 148 defences (#177)
Browse files Browse the repository at this point in the history
* refactor: changed type defences to module

* tests: updated match error messages

* feat: replaced list functions to accept any iterable

* chore: replace function names in code

* refactor: changed tuple defence in uc

* fix: replaced import to prevent deprecation warning

* fix: typo in imported function

* fix: fixed function argument names and gtfs tests

* tests: changes to defence tests

* tests: finished changes to defence tests

* feat: added check iterable lenght func

* refactor: change centre tuple length check in uc

* tests: add tests for iterable length

* fix: minor corrections to match error messages

* Reorganise imports; Update comment specifying PosixPath to Path

* add space for pre-commit;

* Best practice changes;Update type hinting and associated docstring

* chore: run pre-commit

* style: changes to docs, style and defences as suggested in review

* fix: changed defences in cleaners.py

* fix: fixed check for type None

---------

Co-authored-by: Charlie Browning <[email protected]>
Co-authored-by: Ethan Moss <[email protected]>
  • Loading branch information
3 people authored Oct 17, 2023
1 parent 21a58ac commit e7c6e7f
Show file tree
Hide file tree
Showing 9 changed files with 376 additions and 200 deletions.
12 changes: 8 additions & 4 deletions src/transport_performance/gtfs/cleaners.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import numpy as np

from transport_performance.utils.defence import _gtfs_defence, _check_list
from transport_performance.utils.defence import _gtfs_defence, _check_iterable


def drop_trips(gtfs, trip_id: Union[str, list, np.ndarray]) -> None:
Expand Down Expand Up @@ -32,13 +32,17 @@ def drop_trips(gtfs, trip_id: Union[str, list, np.ndarray]) -> None:
if isinstance(trip_id, str):
trip_id = [trip_id]

# _check_list only takes lists, therefore convert numpy arrays
# _check_iterable only takes lists, therefore convert numpy arrays
if isinstance(trip_id, np.ndarray):
trip_id = list(trip_id)

# ensure trip ids are string
_check_list(
ls=trip_id, param_nm="trip_id", check_elements=True, exp_type=str
_check_iterable(
iterable=trip_id,
param_nm="trip_id",
iterable_type=list,
check_elements=True,
exp_type=str,
)

# drop relevant records from tables
Expand Down
9 changes: 7 additions & 2 deletions src/transport_performance/gtfs/gtfs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from transport_performance.utils.defence import (
_is_expected_filetype,
_check_list,
_check_iterable,
_type_defence,
)
from transport_performance.utils.constants import PKG_PATH
Expand Down Expand Up @@ -89,7 +89,12 @@ def bbox_filter_gtfs(
)

if isinstance(bbox, list):
_check_list(ls=bbox, param_nm="bbox", exp_type=float)
_check_iterable(
iterable=bbox,
param_nm="bbox_list",
iterable_type=list,
exp_type=float,
)
# create box polygon around provided coords, need to splat
bbox = box(*bbox)
# gtfs_kit expects gdf
Expand Down
12 changes: 6 additions & 6 deletions src/transport_performance/gtfs/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
_check_parent_dir_exists,
_check_column_in_df,
_type_defence,
_check_item_in_list,
_check_item_in_iter,
_check_attribute,
)

Expand Down Expand Up @@ -571,7 +571,7 @@ def viz_stops(
# geoms defence
geoms = geoms.lower().strip()
ACCEPT_VALS = ["point", "hull"]
_check_item_in_list(geoms, ACCEPT_VALS, "geoms")
_check_item_in_iter(geoms, ACCEPT_VALS, "geoms")

try:
m = self._produce_stops_map(
Expand Down Expand Up @@ -1077,8 +1077,8 @@ def _plot_summary(
which = which.lower()

# ensure 'which' is valid
_check_item_in_list(
item=which, _list=["trip", "route"], param_nm="which"
_check_item_in_iter(
item=which, iterable=["trip", "route"], param_nm="which"
)

raw_pth = os.path.join(
Expand All @@ -1088,8 +1088,8 @@ def _plot_summary(
_check_parent_dir_exists(raw_pth, "save_pth", create=True)

# orientation input defences
_check_item_in_list(
item=orientation, _list=["v", "h"], param_nm="orientation"
_check_item_in_iter(
item=orientation, iterable=["v", "h"], param_nm="orientation"
)

# assign the correct values depending on which breakdown has been
Expand Down
10 changes: 8 additions & 2 deletions src/transport_performance/osm/osm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from transport_performance.utils.defence import (
_type_defence,
_check_list,
_check_iterable,
_check_parent_dir_exists,
_is_expected_filetype,
)
Expand Down Expand Up @@ -77,7 +77,13 @@ def filter_osm(
)
)

_check_list(bbox, param_nm="bbox", check_elements=True, exp_type=float)
_check_iterable(
bbox,
param_nm="bbox",
iterable_type=list,
check_elements=True,
exp_type=float,
)
_check_parent_dir_exists(out_pth, param_nm="out_pth", create=True)
# Compile the osmosis command
cmd = [
Expand Down
151 changes: 49 additions & 102 deletions src/transport_performance/urban_centres/raster_uc.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,23 @@
"""Functions to calculate urban centres following Eurostat definition."""
from collections import Counter

from typing import Union
import pathlib

import affine
import geopandas as gpd
import numpy as np
import numpy.ma as ma
import pandas as pd
import pathlib
import rasterio
import xarray as xr

from geocube.vector import vectorize
from pyproj import Transformer
from rasterio.mask import raster_geometry_mask
from rasterio.transform import rowcol
from scipy.ndimage import generic_filter, label
from transport_performance.utils.defence import _is_expected_filetype
from typing import Union

import transport_performance.utils.defence as d


class UrbanCentre:
Expand Down Expand Up @@ -71,8 +72,9 @@ def __init__(
exp_ext: list = [".tif", ".tiff", ".tff"],
):

# check that path is str or PosixPath
_is_expected_filetype(path, "file", exp_ext=exp_ext)
# check that path is str or pathlib.Path
d._is_expected_filetype(path, "file", exp_ext=exp_ext)
d._check_iterable(exp_ext, "exp_ext", list, True, str)
self.file = path

def get_urban_centre(
Expand Down Expand Up @@ -175,11 +177,8 @@ def get_urban_centre(
)

# buffer
if not isinstance(buffer_size, int):
raise TypeError(
"`buffer_size` expected int, "
f"got {type(buffer_size).__name__}."
)
d._type_defence(buffer_size, "buffer_size", int)

if buffer_size <= 0:
raise ValueError(
"`buffer_size` expected positive non-zero integer"
Expand All @@ -204,13 +203,16 @@ def get_urban_centre(
return self.output

def _window_raster(
self, file: str, bbox: gpd.GeoDataFrame, band_n: int = 1
self,
file: Union[str, pathlib.Path],
bbox: gpd.GeoDataFrame,
band_n: int = 1,
) -> tuple:
"""Open file, load band and apply mask.
Parameters
----------
file : str
file : Union[str, pathlib.Path]
Path to geoTIFF file.
bbox : gpd.GeoDataFrame
A GeoPandas GeoDataFrame containing boundaries to filter the
Expand All @@ -229,14 +231,8 @@ def _window_raster(
crs string from the raster.
"""
if not isinstance(bbox, gpd.GeoDataFrame):
raise TypeError(
"`bbox` expected GeoDataFrame, " f"got {type(bbox).__name__}."
)
if not isinstance(band_n, int):
raise TypeError(
"`band_n` expected integer, " f"got {type(band_n).__name__}"
)
d._type_defence(bbox, "bbox", gpd.GeoDataFrame)
d._type_defence(band_n, "band_n", int)

with rasterio.open(file) as src:
if src.crs != bbox.crs:
Expand Down Expand Up @@ -276,16 +272,8 @@ def _flag_cells(
If cell_pop_threshold is too high and all cells are filtered out.
"""
if not isinstance(masked_rst, np.ndarray):
raise TypeError(
"`masked_rst` expected numpy array, "
f"got {type(masked_rst).__name__}."
)
if not isinstance(cell_pop_threshold, int):
raise TypeError(
"`cell_pop_threshold` expected integer, "
f"got {type(cell_pop_threshold).__name__}."
)
d._type_defence(masked_rst, "masked_rst", np.ndarray)
d._type_defence(cell_pop_threshold, "cell_pop_threshold", int)

flag_array = masked_rst >= cell_pop_threshold

Expand Down Expand Up @@ -317,18 +305,13 @@ def _cluster_cells(
Number of clusters identified.
"""
if not isinstance(flag_array, np.ndarray):
raise TypeError(
"`flag_array` expected numpy array, "
f"got {type(flag_array).__name__}."
)
if not isinstance(diag, bool):
raise TypeError("`diag` must be a boolean.")
d._type_defence(flag_array, "flag_array", np.ndarray)
d._type_defence(diag, "diag", bool)

if diag is False:
s = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]])
elif diag is True:
if diag:
s = np.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]])
else:
s = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]])

labelled_array, num_clusters = label(flag_array, s)

Expand Down Expand Up @@ -364,25 +347,10 @@ def _check_cluster_pop(
Array including only clusters with population over the threshold.
"""
if not isinstance(band, np.ndarray):
raise TypeError(
"`band` expected numpy array, " f"got {type(band).__name__}."
)
if not isinstance(labelled_array, np.ndarray):
raise TypeError(
"`labelled_array` expected numpy array, "
f"got {type(labelled_array).__name__}."
)
if not isinstance(num_clusters, int):
raise TypeError(
"`num_clusters` expected integer, "
f"got {type(num_clusters).__name__}"
)
if not isinstance(cluster_pop_threshold, int):
raise TypeError(
"`cluster_pop_threshold` expected integer, "
f"got {type(cluster_pop_threshold).__name__}"
)
d._type_defence(band, "band", np.ndarray)
d._type_defence(labelled_array, "labelled_array", np.ndarray)
d._type_defence(num_clusters, "num_clusters", int)
d._type_defence(cluster_pop_threshold, "cluster_pop_threshold", int)

urban_centres = labelled_array.copy()
for n in range(1, num_clusters + 1):
Expand Down Expand Up @@ -424,10 +392,8 @@ def _custom_filter(self, win: np.ndarray, threshold: int) -> int:
counter = Counter(win)
mode_count = counter.most_common(1)[0]
if (mode_count[1] >= threshold) & (win[len(win) // 2] == 0):
r = mode_count[0]
else:
r = win[len(win) // 2]
return r
return mode_count[0]
return win[len(win) // 2]

def _fill_gaps(
self, urban_centres: np.ndarray, cell_fill_threshold: int = 5
Expand All @@ -453,16 +419,9 @@ def _fill_gaps(
Array including urban centres with gaps filled.
"""
if not isinstance(urban_centres, np.ndarray):
raise TypeError(
"`urban_centres` expected numpy array, "
f"got {type(urban_centres).__name__}."
)
if not isinstance(cell_fill_threshold, int):
raise TypeError(
"`cell_fill_threshold` expected integer, "
f"got {type(cell_fill_threshold).__name__}"
)
d._type_defence(urban_centres, "urban_centres", np.ndarray)
d._type_defence(cell_fill_threshold, "cell_fill_threshold", int)

if not (5 <= cell_fill_threshold <= 8):
raise ValueError(
"Wrong value for `cell_fill_threshold`, "
Expand All @@ -482,8 +441,7 @@ def _fill_gaps(
extra_keywords={"threshold": cell_fill_threshold},
)
if np.array_equal(filled, check):
break
return filled
return filled

def _get_x_y(
self,
Expand Down Expand Up @@ -511,13 +469,15 @@ def _get_x_y(
(row, col) position for provided parameters.
"""
if len(coords) != 2:
raise ValueError("`coords` expected a tuple of lenght 2.")

if (not isinstance(coords[0], float)) and (
not isinstance(coords[1], float)
):
raise TypeError("Elements of `coords` need to be float.")
d._check_iterable(
iterable=coords,
param_nm="coords",
iterable_type=tuple,
check_elements=True,
exp_type=float,
check_length=True,
length=2,
)

transformer = Transformer.from_crs(coords_crs, raster_crs)
x, y = transformer.transform(*coords)
Expand Down Expand Up @@ -567,25 +527,12 @@ def _vectorize_uc(
If centre coordinates are not included within any cluster.
"""
if not isinstance(uc_array, np.ndarray):
raise TypeError(
"`uc_array` expected numpy array, "
f"got {type(uc_array).__name__}."
)
if not isinstance(centre, tuple):
raise TypeError(
"`centre` expected tuple, " f"got {type(centre).__name__}"
)
if not isinstance(aff, affine.Affine):
raise TypeError("`aff` must be a valid Affine object")
if not isinstance(raster_crs, rasterio.crs.CRS):
raise TypeError(
"`raster_crs` must be a valid rasterio.crs.CRS " "object"
)
if not isinstance(nodata, int):
raise TypeError(
"`nodata` expected integer, " f"got {type(nodata).__name__}"
)
d._type_defence(uc_array, "uc_array", np.ndarray)
d._type_defence(centre, "centre", tuple)
d._type_defence(aff, "aff", affine.Affine)
d._type_defence(raster_crs, "raster_crs", rasterio.crs.CRS)
d._type_defence(nodata, "nodata", int)
d._type_defence(centre_crs, "centre_crs", (type(None), str))

if centre_crs is None:
centre_crs = raster_crs
Expand Down
Loading

0 comments on commit e7c6e7f

Please sign in to comment.