Optional calculation of bulk statistics in feature detection on raw input (before smoothing) #449

Merged
Changes from 4 commits
58 changes: 41 additions & 17 deletions tobac/feature_detection.py
@@ -627,9 +627,9 @@
# find the updated label, and overwrite all of label_ind indices with
# updated label
labels_2_alt = labels_2[label_z, y_val_alt, x_val_alt]
labels_2[
label_locs_v, label_locs_h1, label_locs_h2
] = labels_2_alt
labels_2[label_locs_v, label_locs_h1, label_locs_h2] = (
labels_2_alt
)
skip_list = np.append(skip_list, label_ind)
break

@@ -673,9 +673,9 @@
# find the updated label, and overwrite all of label_ind indices with
# updated label
labels_2_alt = labels_2[label_z, y_val_alt, label_x]
labels_2[
label_locs_v, label_locs_h1, label_locs_h2
] = labels_2_alt
labels_2[label_locs_v, label_locs_h1, label_locs_h2] = (
labels_2_alt
)
new_label_ind = labels_2_alt
skip_list = np.append(skip_list, label_ind)

@@ -717,9 +717,9 @@
# find the updated label, and overwrite all of label_ind indices with
# updated label
labels_2_alt = labels_2[label_z, label_y, x_val_alt]
labels_2[
label_locs_v, label_locs_h1, label_locs_h2
] = labels_2_alt
labels_2[label_locs_v, label_locs_h1, label_locs_h2] = (
labels_2_alt
)
new_label_ind = labels_2_alt
skip_list = np.append(skip_list, label_ind)

@@ -912,6 +912,7 @@
wavelength_filtering: tuple[float] = None,
strict_thresholding: bool = False,
statistic: Union[dict[str, Union[Callable, tuple[Callable, dict]]], None] = None,
statistics_unsmoothed: bool = False,
) -> pd.DataFrame:
"""Find features in each timestep.

@@ -984,6 +985,9 @@
Default is None. Optional parameter to calculate bulk statistics within feature detection.
Dictionary with callable function(s) to apply over the region of each detected feature and the names of the statistics to appear in the feature output dataframe. The functions should be the values, and the names of the metrics the keys (e.g. {'mean': np.mean})

statistics_unsmoothed: bool, optional
Default is False. If True, calculate the statistics on the raw data instead of the smoothed input data.

Returns
-------
features_threshold : pandas DataFrame
@@ -1005,6 +1009,14 @@
# get actual numpy array and make a copy so as not to change the data in the iris cube
track_data = data_i.core_data().copy()

# keep a copy of the unsmoothed data (that can be used for calculating stats)
if statistics_unsmoothed:
if not statistic:
raise ValueError(
"Please provide the input parameter statistic to determine what statistics to calculate."
)
raw_data = data_i.core_data().copy()

track_data = gaussian_filter(
track_data, sigma=sigma_threshold
) # smooth data slightly to create rounded, continuous field
@@ -1117,14 +1129,24 @@
labels.ravel()[regions_old[key]] = key
# apply function to get statistics based on labeled regions and functions provided by the user
# the feature dataframe is updated by appending a column for each metric
features_thresholds = get_statistics(
features_thresholds,
labels,
track_data,
statistic=statistic,
index=np.unique(labels[labels > 0]),
id_column="idx",
)
if statistics_unsmoothed:
features_thresholds = get_statistics(
features_thresholds,
labels,
raw_data,
statistic=statistic,
index=np.unique(labels[labels > 0]),
id_column="idx",
)
else:
features_thresholds = get_statistics(
features_thresholds,
labels,
track_data,
statistic=statistic,
index=np.unique(labels[labels > 0]),
id_column="idx",
)

logging.debug(
"Finished feature detection for threshold "
@@ -1158,6 +1180,7 @@
dz: Union[float, None] = None,
strict_thresholding: bool = False,
statistic: Union[dict[str, Union[Callable, tuple[Callable, dict]]], None] = None,
statistics_unsmoothed: bool = False,
) -> pd.DataFrame:
"""Perform feature detection based on contiguous regions.

@@ -1370,6 +1393,7 @@
wavelength_filtering=wavelength_filtering,
strict_thresholding=strict_thresholding,
statistic=statistic,
statistics_unsmoothed=statistics_unsmoothed,
)
# check if list of features is not empty, then merge features from different threshold
# values into one DataFrame and append to list for individual timesteps:
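For reference, a minimal usage sketch of the new statistics_unsmoothed option, mirroring the call pattern of the new test below. This is not part of the diff; field_iris is a placeholder for a user-provided iris Cube with a leading time dimension, and the threshold, grid spacing, and minimum size are illustrative values only.

# Illustrative sketch: compute the per-feature maximum on the raw field,
# i.e. before the Gaussian smoothing applied ahead of thresholding.
import numpy as np
import tobac

features = tobac.feature_detection.feature_detection_multithreshold(
    field_iris,                       # placeholder: iris Cube with a time dimension
    dxy=1000,                         # grid spacing in m (illustrative)
    threshold=[7],
    n_min_threshold=100,
    target="maximum",
    statistic={"feature_max": np.max},
    statistics_unsmoothed=True,       # new option: statistics use the unsmoothed data
)

Note that statistics_unsmoothed=True without a statistic dictionary raises a ValueError, since there is nothing to compute on the raw data.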
22 changes: 11 additions & 11 deletions tobac/segmentation.py
@@ -824,15 +824,15 @@ def segmentation_timestep(
)

# edit value in buddy_features dataframe
buddy_features.hdim_1.values[
buddy_looper
] = pbc_utils.transfm_pbc_point(
float(buddy_feat.hdim_1), hdim1_min, hdim1_max
buddy_features.hdim_1.values[buddy_looper] = (
pbc_utils.transfm_pbc_point(
float(buddy_feat.hdim_1), hdim1_min, hdim1_max
)
)
buddy_features.hdim_2.values[
buddy_looper
] = pbc_utils.transfm_pbc_point(
float(buddy_feat.hdim_2), hdim2_min, hdim2_max
buddy_features.hdim_2.values[buddy_looper] = (
pbc_utils.transfm_pbc_point(
float(buddy_feat.hdim_2), hdim2_min, hdim2_max
)
)

buddy_looper = buddy_looper + 1
@@ -1010,9 +1010,9 @@ def segmentation_timestep(
segmentation_mask_3[z_val_o, y_val_o, x_val_o]
!= segmentation_mask_4.data[z_seg, y_seg, x_seg]
):
segmentation_mask_3[
z_val_o, y_val_o, x_val_o
] = segmentation_mask_4.data[z_seg, y_seg, x_seg]
segmentation_mask_3[z_val_o, y_val_o, x_val_o] = (
segmentation_mask_4.data[z_seg, y_seg, x_seg]
)
if not is_3D_seg:
segmentation_mask_3 = segmentation_mask_3[0]

31 changes: 31 additions & 0 deletions tobac/tests/test_utils_bulk_statistics.py
@@ -8,6 +8,37 @@
import tobac.testing as tb_test


@pytest.mark.parametrize("statistics_unsmoothed", [(False), (True)])
def test_bulk_statistics_fd(statistics_unsmoothed):
"""
Ensure that bulk statistics in feature detection work on both smoothed and raw data
"""
### Test 2D data with time dimension
test_data = tb_test.make_simple_sample_data_2D().core_data()
common_dset_opts = {
"in_arr": test_data,
"data_type": "iris",
}
test_data_iris = tb_test.make_dataset_from_arr(
time_dim_num=0, y_dim_num=1, x_dim_num=2, **common_dset_opts
)
stats = {"feature_max": np.max}

# detect features
threshold = 7
fd_output = tobac.feature_detection.feature_detection_multithreshold(
test_data_iris,
dxy=1000,
threshold=[threshold],
n_min_threshold=100,
target="maximum",
statistic=stats,
statistics_unsmoothed=statistics_unsmoothed,
)

assert "feature_max" in fd_output.columns


@pytest.mark.parametrize(
"id_column, index", [("feature", [1]), ("feature_id", [1]), ("cell", [1])]
)
12 changes: 3 additions & 9 deletions tobac/utils/decorators.py
@@ -75,9 +75,7 @@ def _conv_kwargs_irispandas_to_xarray(conv_kwargs: dict):
key: (
convert_cube_to_dataarray(arg)
if isinstance(arg, iris.cube.Cube)
else arg.to_xarray()
if isinstance(arg, pd.DataFrame)
else arg
else arg.to_xarray() if isinstance(arg, pd.DataFrame) else arg
)
for key, arg in zip(conv_kwargs.keys(), conv_kwargs.values())
}
@@ -123,9 +121,7 @@ def _conv_kwargs_xarray_to_irispandas(conv_kwargs: dict):
key: (
xr.DataArray.to_iris(arg)
if isinstance(arg, xr.DataArray)
else arg.to_dataframe()
if isinstance(arg, xr.Dataset)
else arg
else arg.to_dataframe() if isinstance(arg, xr.Dataset) else arg
)
for key, arg in zip(conv_kwargs.keys(), conv_kwargs.values())
}
@@ -340,9 +336,7 @@ def wrapper(*args, **kwargs):
(
convert_cube_to_dataarray(arg)
if type(arg) == iris.cube.Cube
else arg.to_xarray()
if type(arg) == pd.DataFrame
else arg
else arg.to_xarray() if type(arg) == pd.DataFrame else arg
)
for arg in args
]