diff --git a/requirements.txt b/requirements.txt index 2f79a67c..22b4a2b4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,3 +9,4 @@ iris xarray cartopy trackpy +typing_extensions \ No newline at end of file diff --git a/tobac/feature_detection.py b/tobac/feature_detection.py index 969a0150..59e38a1f 100644 --- a/tobac/feature_detection.py +++ b/tobac/feature_detection.py @@ -16,36 +16,46 @@ diverse datasets. Geoscientific Model Development, 12(11), 4551-4570. """ - +from __future__ import annotations +from typing import Union, Callable +import warnings import logging + +from typing_extensions import Literal + import numpy as np import pandas as pd from scipy.spatial import KDTree from sklearn.neighbors import BallTree +import iris +import xarray as xr + from tobac.tracking import build_distance_function from tobac.utils import internal as internal_utils from tobac.utils import periodic_boundaries as pbc_utils -from tobac.utils import spectral_filtering +from tobac.utils.general import spectral_filtering from tobac.utils import get_statistics import warnings def feature_position( - hdim1_indices, - hdim2_indices, - vdim_indices=None, - region_small=None, - region_bbox=None, - track_data=None, - threshold_i=None, - position_threshold="center", - target=None, - PBC_flag="none", - hdim1_min=0, - hdim1_max=0, - hdim2_min=0, - hdim2_max=0, -): + hdim1_indices: list[int], + hdim2_indices: list[int], + vdim_indices: Union[list[int], None] = None, + region_small: np.ndarray = None, + region_bbox: Union[list[int], tuple[int]] = None, + track_data: np.ndarray = None, + threshold_i: float = None, + position_threshold: Literal[ + "center", "extreme", "weighted_diff", "weighted_abs" + ] = "center", + target: Literal["maximum", "minimum"] = None, + PBC_flag: Literal["none", "hdim_1", "hdim_2", "both"] = "none", + hdim1_min: int = 0, + hdim1_max: int = 0, + hdim2_min: int = 0, + hdim2_max: int = 0, +) -> tuple[float]: """Determine feature position with regard to the 
horizontal dimensions in pixels from the identified region above threshold values @@ -142,8 +152,8 @@ def feature_position( # First, if necessary, run PBC processing. # processing of PBC indices # checks to see if minimum and maximum values are present in dimensional array - # then if true, adds max value to any indices past the halfway point of their respective dimension. - # this, in essence, shifts the set of points to the high side. + # then if true, adds max value to any indices past the halfway point of their + # respective dimension. this, in essence, shifts the set of points to the high side. pbc_options = ["hdim_1", "hdim_2", "both"] if len(region_bbox) == 4: @@ -214,13 +224,13 @@ def feature_position( ) if run_mean: - if PBC_flag == "hdim_1" or PBC_flag == "both": + if PBC_flag in ("hdim_1", "both"): hdim1_index = pbc_utils.weighted_circmean( hdim1_indices, weights=hdim1_weights, high=hdim1_max + 1, low=hdim1_min ) else: hdim1_index = np.average(hdim1_indices, weights=hdim1_weights) - if PBC_flag == "hdim_2" or PBC_flag == "both": + if PBC_flag in ("hdim_2", "both"): hdim2_index = pbc_utils.weighted_circmean( hdim2_indices, weights=hdim2_weights, high=hdim2_max + 1, low=hdim2_min ) @@ -235,7 +245,9 @@ def feature_position( return hdim1_index, hdim2_index -def test_overlap(region_inner, region_outer): +def test_overlap( + region_inner: list[tuple[int]], region_outer: list[tuple[int]] +) -> bool: """Test for overlap between two regions Parameters @@ -260,8 +272,11 @@ def test_overlap(region_inner, region_outer): def remove_parents( - features_thresholds, regions_i, regions_old, strict_thresholding=False -): + features_thresholds: pd.DataFrame, + regions_i: dict, + regions_old: dict, + strict_thresholding: bool = False, +) -> pd.DataFrame: """Remove parents of newly detected feature regions. 
Remove features where its regions surround newly @@ -363,26 +378,28 @@ def remove_parents( def feature_detection_threshold( - data_i, - i_time, - threshold=None, - min_num=0, - target="maximum", - position_threshold="center", - sigma_threshold=0.5, - n_erosion_threshold=0, - n_min_threshold=0, - min_distance=0, - idx_start=0, - PBC_flag="none", - vertical_axis=0, -): + data_i: np.array, + i_time: int, + threshold: float = None, + min_num: int = 0, + target: Literal["maximum", "minimum"] = "maximum", + position_threshold: Literal[ + "center", "extreme", "weighted_diff", "weighted_abs" + ] = "center", + sigma_threshold: float = 0.5, + n_erosion_threshold: int = 0, + n_min_threshold: int = 0, + min_distance: float = 0, + idx_start: int = 0, + PBC_flag: Literal["none", "hdim_1", "hdim_2", "both"] = "none", + vertical_axis: int = 0, +) -> tuple[pd.DataFrame, dict]: """Find features based on individual threshold value. Parameters ---------- - data_i : iris.cube.Cube - 2D field to perform the feature detection (single timestep) on. + data_i : np.array + 2D or 3D field to perform the feature detection (single timestep) on. i_time : int Number of the current timestep. @@ -442,7 +459,8 @@ def feature_detection_threshold( if min_num != 0: warnings.warn( - "min_num parameter has no effect and will be deprecated in a future version of tobac. Please use n_min_threshold instead", + "min_num parameter has no effect and will be deprecated in a future version of tobac. " + "Please use n_min_threshold instead", FutureWarning, ) @@ -568,8 +586,8 @@ def feature_detection_threshold( ~np.any(label_on_corner == skip_list) ): # alt_inds = np.where(labels==alt_label_3) - # get a list of indices where the label on the corner is so we can switch them - # in the new list. + # get a list of indices where the label on the corner is so we can switch + # them in the new list. 
labels_2[ all_label_locs_v[label_on_corner], @@ -596,7 +614,8 @@ def feature_detection_threshold( and (np.any(label_on_corner == skip_list)) and (~np.any(label_on_corner == skip_list_thisind)) ): - # find the updated label, and overwrite all of label_ind indices with updated label + # find the updated label, and overwrite all of label_ind indices with + # updated label labels_2_alt = labels_2[label_z, y_val_alt, x_val_alt] labels_2[ label_locs_v, label_locs_h1, label_locs_h2 @@ -615,7 +634,8 @@ def feature_detection_threshold( # if it's labeled and not already been dealt with if (label_alt != 0) and (~np.any(label_alt == skip_list)): - # find the indices where it has the label value on opposite side and change their value to original side + # find the indices where it has the label value on opposite side and change + # their value to original side # print(all_label_locs_v[label_alt], alt_inds[0]) labels_2[ all_label_locs_v[label_alt], @@ -640,7 +660,8 @@ def feature_detection_threshold( and (np.any(label_alt == skip_list)) and (~np.any(label_alt == skip_list_thisind)) ): - # find the updated label, and overwrite all of label_ind indices with updated label + # find the updated label, and overwrite all of label_ind indices with + # updated label labels_2_alt = labels_2[label_z, y_val_alt, label_x] labels_2[ label_locs_v, label_locs_h1, label_locs_h2 @@ -658,7 +679,8 @@ def feature_detection_threshold( # if it's labeled and not already been dealt with if (label_alt != 0) and (~np.any(label_alt == skip_list)): - # find the indices where it has the label value on opposite side and change their value to original side + # find the indices where it has the label value on opposite side and change + # their value to original side labels_2[ all_label_locs_v[label_alt], all_label_locs_h1[label_alt], @@ -682,7 +704,8 @@ def feature_detection_threshold( and (np.any(label_alt == skip_list)) and (~np.any(label_alt == skip_list_thisind)) ): - # find the updated label, and overwrite 
all of label_ind indices with updated label + # find the updated label, and overwrite all of label_ind indices with + # updated label labels_2_alt = labels_2[label_z, label_y, x_val_alt] labels_2[ label_locs_v, label_locs_h1, label_locs_h2 ] = labels_2_alt @@ -860,24 +883,26 @@ def feature_detection_threshold( def feature_detection_multithreshold_timestep( - data_i, - i_time, - threshold=None, - min_num=0, - target="maximum", - position_threshold="center", - sigma_threshold=0.5, - n_erosion_threshold=0, - n_min_threshold=0, - min_distance=0, - feature_number_start=1, - PBC_flag="none", - vertical_axis=None, - dxy=-1, - wavelength_filtering=None, - strict_thresholding=False, - statistics=None, -): + data_i: np.array, + i_time: int, + threshold: list[float] = None, + min_num: int = 0, + target: Literal["maximum", "minimum"] = "maximum", + position_threshold: Literal[ + "center", "extreme", "weighted_diff", "weighted_abs" + ] = "center", + sigma_threshold: float = 0.5, + n_erosion_threshold: int = 0, + n_min_threshold: int = 0, + min_distance: float = 0, + feature_number_start: int = 1, + PBC_flag: Literal["none", "hdim_1", "hdim_2", "both"] = "none", + vertical_axis: int = None, + dxy: float = -1, + wavelength_filtering: tuple[float] = None, + strict_thresholding: bool = False, + statistics: Union[dict[str, Union[Callable, tuple[Callable, dict]]], None] = None, +) -> pd.DataFrame: """Find features in each timestep. Based on iteratively finding regions above/below a set of @@ -888,9 +913,12 @@ def feature_detection_multithreshold_timestep( ---------- data_i : iris.cube.Cube - 2D field to perform the feature detection (single timestep) on. + 2D or 3D field to perform the feature detection (single timestep) on. - threshold : float, optional + i_time : int + Number of the current timestep. + + threshold : list of floats, optional Threshold value used to select target regions to track. Default is None. 
@@ -977,13 +1005,15 @@ def feature_detection_multithreshold_timestep( ) # sort thresholds from least extreme to most extreme - threshold_sorted = sorted(threshold, reverse=(target == "minimum")) + threshold_sorted = sorted(threshold, reverse=target == "minimum") - # check if each threshold has a n_min_threshold (minimum nr. of grid cells associated with thresholds), if multiple n_min_threshold are given + # check if each threshold has a n_min_threshold (minimum nr. of grid cells associated with + # thresholds), if multiple n_min_threshold are given if isinstance(n_min_threshold, list) or isinstance(n_min_threshold, dict): if len(n_min_threshold) is not len(threshold): raise ValueError( - "Number of elements in n_min_threshold needs to be the same as thresholds, if n_min_threshold is given as dict or list." + "Number of elements in n_min_threshold needs to be the same as thresholds, if " + "n_min_threshold is given as dict or list." ) # check if thresholds in dict correspond to given thresholds @@ -992,15 +1022,18 @@ def feature_detection_multithreshold_timestep( n_min_threshold.keys(), reverse=(target == "minimum") ): raise ValueError( - "Ambiguous input for threshold values. If n_min_threshold is given as a dict, the keys not to correspond to the values in threshold." + "Ambiguous input for threshold values. If n_min_threshold is given as a dict," + " the keys do not correspond to the values in threshold." 
) - # sort dictionary by keys (threshold values) so that they match sorted thresholds and get values for n_min_threshold + # sort dictionary by keys (threshold values) so that they match sorted thresholds and + # get values for n_min_threshold n_min_threshold = [ n_min_threshold[threshold] for threshold in threshold_sorted ] elif isinstance(n_min_threshold, list): - # if n_min_threshold is a list, sort it such that it still matches with the sorted threshold values + # if n_min_threshold is a list, sort it such that it still matches with the sorted + # threshold values n_min_threshold = [ x for _, x in sorted( @@ -1013,7 +1046,8 @@ def feature_detection_multithreshold_timestep( and not isinstance(n_min_threshold, int) ): raise ValueError( - "N_min_threshold must be an integer. If multiple values for n_min_threshold are given, please provide a dictionary or list." + "N_min_threshold must be an integer. If multiple values for n_min_threshold are given," + " please provide a dictionary or list." 
) # create empty lists to store regions and features for individual timestep @@ -1050,9 +1084,11 @@ def feature_detection_multithreshold_timestep( [features_thresholds, features_threshold_i], ignore_index=True ) - # For multiple threshold, and features found both in the current and previous step, remove "parent" features from Dataframe + # For multiple threshold, and features found both in the current and previous step, remove + # "parent" features from Dataframe if i_threshold > 0 and not features_thresholds.empty: - # for each threshold value: check if newly found features are surrounded by feature based on less restrictive threshold + # for each threshold value: check if newly found features are surrounded by feature + # based on less restrictive threshold features_thresholds, regions_old = remove_parents( features_thresholds, regions_i, @@ -1089,26 +1125,28 @@ def feature_detection_multithreshold( - field_in, - dxy=None, - threshold=None, - min_num=0, - target="maximum", - position_threshold="center", - sigma_threshold=0.5, - n_erosion_threshold=0, - n_min_threshold=0, - min_distance=0, - feature_number_start=1, - PBC_flag="none", - vertical_coord=None, - vertical_axis=None, - detect_subset=None, - wavelength_filtering=None, - dz=None, - strict_thresholding=False, - statistics=None, -): + field_in: iris.cube.Cube, + dxy: float = None, + threshold: list[float] = None, + min_num: int = 0, + target: Literal["maximum", "minimum"] = "maximum", + position_threshold: Literal[ + "center", "extreme", "weighted_diff", "weighted_abs" + ] = "center", + sigma_threshold: float = 0.5, + n_erosion_threshold: int = 0, + n_min_threshold: int = 0, + min_distance: float = 0, + feature_number_start: int = 1, + PBC_flag: Literal["none", "hdim_1", "hdim_2", "both"] = "none", + vertical_coord: str = None, + vertical_axis: int = None, + detect_subset: dict = None, + wavelength_filtering: tuple = None, + dz: Union[float, None] = None, + 
strict_thresholding: bool = False, + statistics: Union[dict[str, Union[Callable, tuple[Callable, dict]]], None] = None, +) -> pd.DataFrame: """Perform feature detection based on contiguous regions. The regions are above/below a threshold. @@ -1270,8 +1308,8 @@ def feature_detection_multithreshold( if type(threshold) in [int, float]: threshold = [threshold] - # if wavelength_filtering is given, check that value cannot be larger than distances along x and y, - # that the value cannot be smaller or equal to the grid spacing + # if wavelength_filtering is given, check that value cannot be larger than distances along + # x and y, that the value cannot be smaller or equal to the grid spacing # and throw a warning if dxy and wavelengths have about the same order of magnitude if wavelength_filtering is not None: if is_3D: @@ -1287,17 +1325,21 @@ def feature_detection_multithreshold( if lambda_min > distance or lambda_max > distance: raise ValueError( - "The given wavelengths cannot be larger than the total distance in m along the axes of the domain." + "The given wavelengths cannot be larger than the total distance in m along the axes" + " of the domain." ) elif lambda_min <= dxy: raise ValueError( - "The given minimum wavelength cannot be smaller than gridspacing dxy. Please note that both dxy and the values for wavelength_filtering should be given in meter." + "The given minimum wavelength cannot be smaller than gridspacing dxy. Please note " + "that both dxy and the values for wavelength_filtering should be given in meter." ) elif np.floor(np.log10(lambda_min)) - np.floor(np.log10(dxy)) > 1: warnings.warn( - "Warning: The values for dxy and the minimum wavelength are close in order of magnitude. Please note that both dxy and for wavelength_filtering should be given in meter." + "Warning: The values for dxy and the minimum wavelength are close in order of " + "magnitude. Please note that both dxy and for wavelength_filtering should be " + "given in meter." 
) for i_time, data_i in enumerate(data_time): @@ -1322,15 +1364,16 @@ def feature_detection_multithreshold( strict_thresholding=strict_thresholding, statistics=statistics, ) - # check if list of features is not empty, then merge features from different threshold values - # into one DataFrame and append to list for individual timesteps: + # check if list of features is not empty, then merge features from different threshold + # values into one DataFrame and append to list for individual timesteps: if not features_thresholds.empty: hdim1_ax, hdim2_ax = internal_utils.find_hdim_axes_3D( field_in, vertical_coord=vertical_coord ) hdim1_max = field_in.shape[hdim1_ax] - 1 hdim2_max = field_in.shape[hdim2_ax] - 1 - # Loop over DataFrame to remove features that are closer than distance_min to each other: + # Loop over DataFrame to remove features that are closer than distance_min to each + # other: if min_distance > 0: features_thresholds = filter_min_distance( features_thresholds, @@ -1352,7 +1395,8 @@ def feature_detection_multithreshold( ) logging.debug("feature detection: merging DataFrames") - # Check if features are detected and then concatenate features from different timesteps into one pandas DataFrame + # Check if features are detected and then concatenate features from different timesteps into + # one pandas DataFrame # If no features are detected raise error if any([not x.empty for x in list_features_timesteps]): features = pd.concat(list_features_timesteps, ignore_index=True) @@ -1373,20 +1417,20 @@ def feature_detection_multithreshold( def filter_min_distance( - features, - dxy=None, - dz=None, - min_distance=None, - x_coordinate_name=None, - y_coordinate_name=None, - z_coordinate_name=None, - target="maximum", - PBC_flag="none", - min_h1=0, - max_h1=0, - min_h2=0, - max_h2=0, -): + features: pd.DataFrame, + dxy: float = None, + dz: float = None, + min_distance: float = None, + x_coordinate_name: str = None, + y_coordinate_name: str = None, + z_coordinate_name: 
str = None, + target: Literal["maximum", "minimum"] = "maximum", + PBC_flag: Literal["none", "hdim_1", "hdim_2", "both"] = "none", + min_h1: int = 0, + max_h1: int = 0, + min_h2: int = 0, + max_h2: int = 0, +) -> pd.DataFrame: """Function to remove features that are too close together. If two features are closer than `min_distance`, it keeps the larger feature. @@ -1471,7 +1515,8 @@ def filter_min_distance( "Both " + z_coordinate_name + " and dz available to filter_min_distance; using constant dz. " - "Set dz to none if you want to use altitude or set `z_coordinate_name` to None to use constant dz." + "Set dz to none if you want to use altitude or set `z_coordinate_name` to None to use " + "constant dz." ) # As optional coordinate names are not yet implemented, set to defaults here: