diff --git a/tobac/feature_detection.py b/tobac/feature_detection.py index fc491ff0..2555c45a 100644 --- a/tobac/feature_detection.py +++ b/tobac/feature_detection.py @@ -627,9 +627,9 @@ def feature_detection_threshold( # find the updated label, and overwrite all of label_ind indices with # updated label labels_2_alt = labels_2[label_z, y_val_alt, x_val_alt] - labels_2[ - label_locs_v, label_locs_h1, label_locs_h2 - ] = labels_2_alt + labels_2[label_locs_v, label_locs_h1, label_locs_h2] = ( + labels_2_alt + ) skip_list = np.append(skip_list, label_ind) break @@ -673,9 +673,9 @@ def feature_detection_threshold( # find the updated label, and overwrite all of label_ind indices with # updated label labels_2_alt = labels_2[label_z, y_val_alt, label_x] - labels_2[ - label_locs_v, label_locs_h1, label_locs_h2 - ] = labels_2_alt + labels_2[label_locs_v, label_locs_h1, label_locs_h2] = ( + labels_2_alt + ) new_label_ind = labels_2_alt skip_list = np.append(skip_list, label_ind) @@ -717,9 +717,9 @@ def feature_detection_threshold( # find the updated label, and overwrite all of label_ind indices with # updated label labels_2_alt = labels_2[label_z, label_y, x_val_alt] - labels_2[ - label_locs_v, label_locs_h1, label_locs_h2 - ] = labels_2_alt + labels_2[label_locs_v, label_locs_h1, label_locs_h2] = ( + labels_2_alt + ) new_label_ind = labels_2_alt skip_list = np.append(skip_list, label_ind) @@ -912,6 +912,7 @@ def feature_detection_multithreshold_timestep( wavelength_filtering: tuple[float] = None, strict_thresholding: bool = False, statistic: Union[dict[str, Union[Callable, tuple[Callable, dict]]], None] = None, + statistics_unsmoothed: bool = False, ) -> pd.DataFrame: """Find features in each timestep. @@ -984,6 +985,9 @@ def feature_detection_multithreshold_timestep( Default is None. Optional parameter to calculate bulk statistics within feature detection. Dictionary with callable function(s) to apply over the region of each detected feature and the name of the statistics to appear in the feature ou tput dataframe. The functions should be the values and the names of the metric the keys (e.g. {'mean': np.mean}) + statistics_unsmoothed: bool, optional + Default is False. If True, calculate the statistics on the raw data instead of the smoothed input data. + Returns ------- features_threshold : pandas DataFrame @@ -1005,6 +1009,14 @@ def feature_detection_multithreshold_timestep( # get actual numpy array and make a copy so as not to change the data in the iris cube track_data = data_i.core_data().copy() + # keep a copy of the unsmoothed data (that can be used for calculating stats) + if statistics_unsmoothed: + if not statistic: + raise ValueError( + "Please provide the input parameter statistic to determine what statistics to calculate." + ) + + track_data = gaussian_filter( track_data, sigma=sigma_threshold ) # smooth data slightly to create rounded, continuous field @@ -1117,14 +1129,24 @@ def feature_detection_multithreshold_timestep( labels.ravel()[regions_old[key]] = key # apply function to get statistics based on labeled regions and functions provided by the user # the feature dataframe is updated by appending a column for each metric - features_thresholds = get_statistics( - features_thresholds, - labels, - track_data, - statistic=statistic, - index=np.unique(labels[labels > 0]), - id_column="idx", - ) + if statistics_unsmoothed: + features_thresholds = get_statistics( + features_thresholds, + labels, + data_i.core_data(), + statistic=statistic, + index=np.unique(labels[labels > 0]), + id_column="idx", + ) + else: + features_thresholds = get_statistics( + features_thresholds, + labels, + track_data, + statistic=statistic, + index=np.unique(labels[labels > 0]), + id_column="idx", + ) logging.debug( "Finished feature detection for threshold " @@ -1158,6 +1180,7 @@ def feature_detection_multithreshold( dz: Union[float, None] = None, strict_thresholding: bool = False, statistic: Union[dict[str, Union[Callable, tuple[Callable, dict]]], None] = None, + statistics_unsmoothed: bool = False, ) -> pd.DataFrame: """Perform feature detection based on contiguous regions. @@ -1370,6 +1393,7 @@ def feature_detection_multithreshold( wavelength_filtering=wavelength_filtering, strict_thresholding=strict_thresholding, statistic=statistic, + statistics_unsmoothed=statistics_unsmoothed, ) # check if list of features is not empty, then merge features from different threshold # values into one DataFrame and append to list for individual timesteps: diff --git a/tobac/segmentation.py b/tobac/segmentation.py index fe2eda2e..4697a25d 100644 --- a/tobac/segmentation.py +++ b/tobac/segmentation.py @@ -824,15 +824,15 @@ def segmentation_timestep( ) # edit value in buddy_features dataframe - buddy_features.hdim_1.values[ - buddy_looper - ] = pbc_utils.transfm_pbc_point( - float(buddy_feat.hdim_1), hdim1_min, hdim1_max + buddy_features.hdim_1.values[buddy_looper] = ( + pbc_utils.transfm_pbc_point( + float(buddy_feat.hdim_1), hdim1_min, hdim1_max + ) ) - buddy_features.hdim_2.values[ - buddy_looper - ] = pbc_utils.transfm_pbc_point( - float(buddy_feat.hdim_2), hdim2_min, hdim2_max + buddy_features.hdim_2.values[buddy_looper] = ( + pbc_utils.transfm_pbc_point( + float(buddy_feat.hdim_2), hdim2_min, hdim2_max + ) ) buddy_looper = buddy_looper + 1 @@ -1010,9 +1010,9 @@ def segmentation_timestep( segmentation_mask_3[z_val_o, y_val_o, x_val_o] != segmentation_mask_4.data[z_seg, y_seg, x_seg] ): - segmentation_mask_3[ - z_val_o, y_val_o, x_val_o - ] = segmentation_mask_4.data[z_seg, y_seg, x_seg] + segmentation_mask_3[z_val_o, y_val_o, x_val_o] = ( + segmentation_mask_4.data[z_seg, y_seg, x_seg] + ) if not is_3D_seg: segmentation_mask_3 = segmentation_mask_3[0] diff --git a/tobac/tests/test_utils_bulk_statistics.py b/tobac/tests/test_utils_bulk_statistics.py index 1db036b0..c62ee821 100644 --- a/tobac/tests/test_utils_bulk_statistics.py +++ b/tobac/tests/test_utils_bulk_statistics.py @@ -8,6 +8,37 @@ import tobac.testing as tb_test +@pytest.mark.parametrize("statistics_unsmoothed", [(False), (True)]) +def test_bulk_statistics_fd(statistics_unsmoothed): + """ + Assure that bulk statistics in feature detection work, both on smoothed and raw data + """ + ### Test 2D data with time dimension + test_data = tb_test.make_simple_sample_data_2D().core_data() + common_dset_opts = { + "in_arr": test_data, + "data_type": "iris", + } + test_data_iris = tb_test.make_dataset_from_arr( + time_dim_num=0, y_dim_num=1, x_dim_num=2, **common_dset_opts + ) + stats = {"feature_max": np.max} + + # detect features + threshold = 7 + fd_output = tobac.feature_detection.feature_detection_multithreshold( + test_data_iris, + dxy=1000, + threshold=[threshold], + n_min_threshold=100, + target="maximum", + statistic=stats, + statistics_unsmoothed=statistics_unsmoothed, + ) + + assert "feature_max" in fd_output.columns + + @pytest.mark.parametrize( "id_column, index", [("feature", [1]), ("feature_id", [1]), ("cell", [1])] ) diff --git a/tobac/utils/decorators.py b/tobac/utils/decorators.py index 90e600b5..8d304e71 100644 --- a/tobac/utils/decorators.py +++ b/tobac/utils/decorators.py @@ -75,9 +75,7 @@ def _conv_kwargs_irispandas_to_xarray(conv_kwargs: dict): key: ( convert_cube_to_dataarray(arg) if isinstance(arg, iris.cube.Cube) - else arg.to_xarray() - if isinstance(arg, pd.DataFrame) - else arg + else arg.to_xarray() if isinstance(arg, pd.DataFrame) else arg ) for key, arg in zip(conv_kwargs.keys(), conv_kwargs.values()) } @@ -123,9 +121,7 @@ def _conv_kwargs_xarray_to_irispandas(conv_kwargs: dict): key: ( xr.DataArray.to_iris(arg) if isinstance(arg, xr.DataArray) - else arg.to_dataframe() - if isinstance(arg, xr.Dataset) - else arg + else arg.to_dataframe() if isinstance(arg, xr.Dataset) else arg ) for key, arg in zip(conv_kwargs.keys(), conv_kwargs.values()) } @@ -340,9 +336,7 @@ def wrapper(*args, **kwargs): ( convert_cube_to_dataarray(arg) if type(arg) == iris.cube.Cube - else arg.to_xarray() - if type(arg) == pd.DataFrame - else arg + else arg.to_xarray() if type(arg) == pd.DataFrame else arg ) for arg in args ]