Skip to content

Commit

Permalink
Change .data to .values call in get_statistics_from_mask to prevent p…
Browse files Browse the repository at this point in the history
…assing dask array to get_statistics, and added test for dask array input
  • Loading branch information
w-k-jones committed Dec 13, 2024
1 parent 706b3ee commit 0b75b46
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 2 deletions.
87 changes: 86 additions & 1 deletion tobac/tests/test_utils_bulk_statistics.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from datetime import datetime

import numpy as np
import dask.array as da
import pandas as pd
import pytest
import xarray as xr
import pytest

import tobac
import tobac.utils as tb_utils
import tobac.testing as tb_test
Expand Down Expand Up @@ -702,3 +705,85 @@ def test_get_statistics_from_mask_collapse_dim():
statistic=statistics_sum,
collapse_dim="not_a_dim",
)


def test_bulk_statistics_dask():
"""
Test dask input for labels and fields is handled correctly
"""

test_labels = da.array(
[
[
[0, 0, 0, 0, 0],
[0, 1, 0, 2, 0],
[0, 1, 0, 2, 0],
[0, 1, 0, 0, 0],
[0, 0, 0, 0, 0],
],
[
[0, 0, 0, 0, 0],
[0, 3, 0, 0, 0],
[0, 3, 0, 4, 0],
[0, 3, 0, 4, 0],
[0, 0, 0, 0, 0],
],
],
dtype=int,
)

test_labels = xr.DataArray(
test_labels,
dims=("time", "y", "x"),
coords={
"time": [datetime(2000, 1, 1), datetime(2000, 1, 1, 0, 5)],
"y": np.arange(5),
"x": np.arange(5),
},
)

test_values = da.array(
[
[
[0, 0, 0, 0, 0],
[0, 1, 0, 2, 0],
[0, 2, 0, 2, 0],
[0, 3, 0, 0, 0],
[0, 0, 0, 0, 0],
],
[
[0, 0, 0, 0, 0],
[0, 2, 0, 0, 0],
[0, 3, 0, 3, 0],
[0, 4, 0, 2, 0],
[0, 0, 0, 0, 0],
],
]
)

test_values = xr.DataArray(
test_values, dims=test_labels.dims, coords=test_labels.coords
)

test_features = pd.DataFrame(
{
"feature": [1, 2, 3, 4],
"frame": [0, 0, 1, 1],
"time": [
datetime(2000, 1, 1),
datetime(2000, 1, 1),
datetime(2000, 1, 1, 0, 5),
datetime(2000, 1, 1, 0, 5),
],
}
)

statistics_size = {"size": np.size}

expected_size_result = np.array([3, 2, 3, 2])

bulk_statistics_output = tb_utils.get_statistics_from_mask(
test_features, test_labels, test_values, statistic=statistics_size
)

assert np.all(bulk_statistics_output["size"] == expected_size_result)
2 changes: 1 addition & 1 deletion tobac/utils/bulk_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def get_statistics_from_mask(

for tt in pd.to_datetime(segmentation_mask.time):
# select specific timestep
segmentation_mask_t = segmentation_mask.sel(time=tt, method="nearest").data
segmentation_mask_t = segmentation_mask.sel(time=tt, method="nearest").values
fields_t = (
(
field.sel(
Expand Down

0 comments on commit 0b75b46

Please sign in to comment.