Merge pull request #449 from JuliaKukulies/bulk_stats_on_raw_input
Optional calculation of bulk statistics in feature detection on raw input (before smoothing)
JuliaKukulies authored Oct 14, 2024
2 parents c38ae1c + 132ea24 commit d9f2df8
Showing 4 changed files with 86 additions and 37 deletions.
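For context, the new statistics_unsmoothed option is used together with the existing statistic parameter of feature_detection_multithreshold. A minimal usage sketch (illustrative only; the sample-data helper and parameter values are borrowed from the new test below, and the variable names are ours, not part of this PR):

import numpy as np
import tobac
import tobac.testing as tb_test

# any iris cube of gridded input data works here; we borrow tobac's 2D sample field
input_field = tb_test.make_simple_sample_data_2D()

features = tobac.feature_detection.feature_detection_multithreshold(
    input_field,
    dxy=1000,  # grid spacing in m
    threshold=[7],
    n_min_threshold=100,
    target="maximum",
    statistic={"feature_max": np.max},  # bulk statistics to compute per feature
    statistics_unsmoothed=True,  # new flag: use the raw input, not the smoothed field
)
print(features["feature_max"].head())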
58 changes: 41 additions & 17 deletions tobac/feature_detection.py
@@ -627,9 +627,9 @@ def feature_detection_threshold(
 # find the updated label, and overwrite all of label_ind indices with
 # updated label
 labels_2_alt = labels_2[label_z, y_val_alt, x_val_alt]
-labels_2[
-    label_locs_v, label_locs_h1, label_locs_h2
-] = labels_2_alt
+labels_2[label_locs_v, label_locs_h1, label_locs_h2] = (
+    labels_2_alt
+)
 skip_list = np.append(skip_list, label_ind)
 break

@@ -673,9 +673,9 @@ def feature_detection_threshold(
 # find the updated label, and overwrite all of label_ind indices with
 # updated label
 labels_2_alt = labels_2[label_z, y_val_alt, label_x]
-labels_2[
-    label_locs_v, label_locs_h1, label_locs_h2
-] = labels_2_alt
+labels_2[label_locs_v, label_locs_h1, label_locs_h2] = (
+    labels_2_alt
+)
 new_label_ind = labels_2_alt
 skip_list = np.append(skip_list, label_ind)

@@ -717,9 +717,9 @@ def feature_detection_threshold(
 # find the updated label, and overwrite all of label_ind indices with
 # updated label
 labels_2_alt = labels_2[label_z, label_y, x_val_alt]
-labels_2[
-    label_locs_v, label_locs_h1, label_locs_h2
-] = labels_2_alt
+labels_2[label_locs_v, label_locs_h1, label_locs_h2] = (
+    labels_2_alt
+)
 new_label_ind = labels_2_alt
 skip_list = np.append(skip_list, label_ind)

@@ -912,6 +912,7 @@ def feature_detection_multithreshold_timestep(
     wavelength_filtering: tuple[float] = None,
     strict_thresholding: bool = False,
     statistic: Union[dict[str, Union[Callable, tuple[Callable, dict]]], None] = None,
+    statistics_unsmoothed: bool = False,
 ) -> pd.DataFrame:
     """Find features in each timestep.
@@ -984,6 +985,9 @@ def feature_detection_multithreshold_timestep(
         Default is None. Optional parameter to calculate bulk statistics within feature detection.
         Dictionary with callable function(s) to apply over the region of each detected feature and the names of the statistics to appear in the feature output dataframe. The functions should be the values, and the names of the metrics the keys (e.g. {'mean': np.mean}).
+    statistics_unsmoothed: bool, optional
+        Default is False. If True, calculate the statistics on the raw data instead of the smoothed input data.

     Returns
     -------
     features_threshold : pandas DataFrame
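As a pointer for readers of the docstring above: per the type hint, each value in the statistic mapping is either a bare callable or, presumably, a (callable, kwargs) tuple for functions that need extra arguments. A hedged sketch (the tuple convention is inferred from the annotation, not spelled out in this diff):

import numpy as np

statistic = {
    "mean": np.mean,  # column name -> function applied over each feature region
    "p90": (np.percentile, {"q": 90}),  # assumed (callable, kwargs) form
}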
Expand All @@ -1005,6 +1009,14 @@ def feature_detection_multithreshold_timestep(
# get actual numpy array and make a copy so as not to change the data in the iris cube
track_data = data_i.core_data().copy()

# keep a copy of the unsmoothed data (that can be used for calculating stats)
if statistics_unsmoothed:
if not statistic:
raise ValueError(
"Please provide the input parameter statistic to determine what statistics to calculate."
)


track_data = gaussian_filter(
track_data, sigma=sigma_threshold
) # smooth data slightly to create rounded, continuous field
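Why guard this case at all: the line that follows overwrites track_data with its gaussian_filter-smoothed version, and smoothing systematically flattens extrema, so statistics such as a feature maximum computed on the smoothed field underestimate the raw peak. A standalone illustration (not tobac code):

import numpy as np
from scipy.ndimage import gaussian_filter

field = np.zeros((50, 50))
field[25, 25] = 10.0  # one sharp peak
smoothed = gaussian_filter(field, sigma=0.5)

print(field.max())     # 10.0
print(smoothed.max())  # strictly less than 10.0: the peak mass is spread out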
@@ -1117,14 +1129,24 @@ def feature_detection_multithreshold_timestep(
     labels.ravel()[regions_old[key]] = key
 # apply function to get statistics based on labeled regions and functions provided by the user
 # the feature dataframe is updated by appending a column for each metric
-features_thresholds = get_statistics(
-    features_thresholds,
-    labels,
-    track_data,
-    statistic=statistic,
-    index=np.unique(labels[labels > 0]),
-    id_column="idx",
-)
+if statistics_unsmoothed:
+    features_thresholds = get_statistics(
+        features_thresholds,
+        labels,
+        data_i.core_data(),
+        statistic=statistic,
+        index=np.unique(labels[labels > 0]),
+        id_column="idx",
+    )
+else:
+    features_thresholds = get_statistics(
+        features_thresholds,
+        labels,
+        track_data,
+        statistic=statistic,
+        index=np.unique(labels[labels > 0]),
+        id_column="idx",
+    )

 logging.debug(
     "Finished feature detection for threshold "
@@ -1158,6 +1180,7 @@ def feature_detection_multithreshold(
     dz: Union[float, None] = None,
     strict_thresholding: bool = False,
     statistic: Union[dict[str, Union[Callable, tuple[Callable, dict]]], None] = None,
+    statistics_unsmoothed: bool = False,
 ) -> pd.DataFrame:
     """Perform feature detection based on contiguous regions.
@@ -1370,6 +1393,7 @@ def feature_detection_multithreshold(
     wavelength_filtering=wavelength_filtering,
     strict_thresholding=strict_thresholding,
     statistic=statistic,
+    statistics_unsmoothed=statistics_unsmoothed,
 )
 # check if list of features is not empty, then merge features from different threshold
 # values into one DataFrame and append to list for individual timesteps:
22 changes: 11 additions & 11 deletions tobac/segmentation.py
@@ -824,15 +824,15 @@ def segmentation_timestep(
 )

 # edit value in buddy_features dataframe
-buddy_features.hdim_1.values[
-    buddy_looper
-] = pbc_utils.transfm_pbc_point(
-    float(buddy_feat.hdim_1), hdim1_min, hdim1_max
-)
+buddy_features.hdim_1.values[buddy_looper] = (
+    pbc_utils.transfm_pbc_point(
+        float(buddy_feat.hdim_1), hdim1_min, hdim1_max
+    )
+)
-buddy_features.hdim_2.values[
-    buddy_looper
-] = pbc_utils.transfm_pbc_point(
-    float(buddy_feat.hdim_2), hdim2_min, hdim2_max
-)
+buddy_features.hdim_2.values[buddy_looper] = (
+    pbc_utils.transfm_pbc_point(
+        float(buddy_feat.hdim_2), hdim2_min, hdim2_max
+    )
+)

 buddy_looper = buddy_looper + 1
@@ -1010,9 +1010,9 @@ def segmentation_timestep(
     segmentation_mask_3[z_val_o, y_val_o, x_val_o]
     != segmentation_mask_4.data[z_seg, y_seg, x_seg]
 ):
-    segmentation_mask_3[
-        z_val_o, y_val_o, x_val_o
-    ] = segmentation_mask_4.data[z_seg, y_seg, x_seg]
+    segmentation_mask_3[z_val_o, y_val_o, x_val_o] = (
+        segmentation_mask_4.data[z_seg, y_seg, x_seg]
+    )
 if not is_3D_seg:
     segmentation_mask_3 = segmentation_mask_3[0]
31 changes: 31 additions & 0 deletions tobac/tests/test_utils_bulk_statistics.py
@@ -8,6 +8,37 @@
 import tobac.testing as tb_test


+@pytest.mark.parametrize("statistics_unsmoothed", [False, True])
+def test_bulk_statistics_fd(statistics_unsmoothed):
+    """
+    Ensure that bulk statistics in feature detection work on both smoothed and raw data.
+    """
+    ### Test 2D data with time dimension
+    test_data = tb_test.make_simple_sample_data_2D().core_data()
+    common_dset_opts = {
+        "in_arr": test_data,
+        "data_type": "iris",
+    }
+    test_data_iris = tb_test.make_dataset_from_arr(
+        time_dim_num=0, y_dim_num=1, x_dim_num=2, **common_dset_opts
+    )
+    stats = {"feature_max": np.max}
+
+    # detect features
+    threshold = 7
+    fd_output = tobac.feature_detection.feature_detection_multithreshold(
+        test_data_iris,
+        dxy=1000,
+        threshold=[threshold],
+        n_min_threshold=100,
+        target="maximum",
+        statistic=stats,
+        statistics_unsmoothed=statistics_unsmoothed,
+    )
+
+    assert "feature_max" in fd_output.columns
+
+
 @pytest.mark.parametrize(
     "id_column, index", [("feature", [1]), ("feature_id", [1]), ("cell", [1])]
 )
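The new test only checks that the statistics column exists. Because the flag changes nothing about detection itself (features are always found on the smoothed field), both parametrizations produce row-aligned outputs, which would admit a stronger, hypothetical follow-up check for this single-peak sample data, since smoothing can only lower a maximum:

# hypothetical extension, not part of this PR: compare the two modes directly
out_smooth = tobac.feature_detection.feature_detection_multithreshold(
    test_data_iris, dxy=1000, threshold=[7], n_min_threshold=100,
    target="maximum", statistic={"feature_max": np.max},
    statistics_unsmoothed=False,
)
out_raw = tobac.feature_detection.feature_detection_multithreshold(
    test_data_iris, dxy=1000, threshold=[7], n_min_threshold=100,
    target="maximum", statistic={"feature_max": np.max},
    statistics_unsmoothed=True,
)
assert (out_raw["feature_max"].values >= out_smooth["feature_max"].values).all()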
12 changes: 3 additions & 9 deletions tobac/utils/decorators.py
@@ -75,9 +75,7 @@ def _conv_kwargs_irispandas_to_xarray(conv_kwargs: dict):
     key: (
         convert_cube_to_dataarray(arg)
         if isinstance(arg, iris.cube.Cube)
-        else arg.to_xarray()
-        if isinstance(arg, pd.DataFrame)
-        else arg
+        else arg.to_xarray() if isinstance(arg, pd.DataFrame) else arg
     )
     for key, arg in zip(conv_kwargs.keys(), conv_kwargs.values())
 }
@@ -123,9 +121,7 @@ def _conv_kwargs_xarray_to_irispandas(conv_kwargs: dict):
     key: (
         xr.DataArray.to_iris(arg)
         if isinstance(arg, xr.DataArray)
-        else arg.to_dataframe()
-        if isinstance(arg, xr.Dataset)
-        else arg
+        else arg.to_dataframe() if isinstance(arg, xr.Dataset) else arg
     )
     for key, arg in zip(conv_kwargs.keys(), conv_kwargs.values())
 }
@@ -340,9 +336,7 @@ def wrapper(*args, **kwargs):
     (
         convert_cube_to_dataarray(arg)
         if type(arg) == iris.cube.Cube
-        else arg.to_xarray()
-        if type(arg) == pd.DataFrame
-        else arg
+        else arg.to_xarray() if type(arg) == pd.DataFrame else arg
     )
     for arg in args
 ]
