Commit df39cc4

Enforce that the input to huracanpy.save is not modified, but sort by track_id for the saved file to ensure the contiguous ragged data makes sense
1 parent: b4df40a

3 files changed, +33 -21 lines

3 files changed

+33
-21
lines changed

huracanpy/_data/_load.py (+2)

```diff
@@ -14,6 +14,8 @@
     lat_track="lat",
     # Names for CHAZ netCDF
     stormID="track_id",
+    # Names for TRACK netCDF
+    TRACK_ID="track_id",
 )
 
 
```
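For context, this table maps source-specific variable names onto huracanpy's standard names. A rename step along these lines would apply it when loading; this is a minimal sketch, and `rename_map`/`normalise_names` are illustrative names, not huracanpy's actual internals:

```python
import xarray as xr

# Illustrative rename table in the style of the one in _load.py:
# source-specific variable names -> huracanpy's standard names
rename_map = dict(
    stormID="track_id",   # CHAZ netCDF
    TRACK_ID="track_id",  # TRACK netCDF
)


def normalise_names(dataset: xr.Dataset) -> xr.Dataset:
    # Only rename the variables actually present in this file, since each
    # source provides a different subset of the names in the table
    return dataset.rename(
        {old: new for old, new in rename_map.items() if old in dataset.variables}
    )
```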

huracanpy/_data/_netcdf.py (+8 -9)

```diff
@@ -47,11 +47,10 @@ def save(dataset, filename):
             f"{trajectory_id.name} spans multiple dimensions, should be 1d"
         )
 
-    # np.unique returns a sorted array, so return the index so that the trajectory_ids
-    # can be put back in the same order as they are in the original dataset otherwise
-    # the ordering of data can be messed up if the trajectories ids aren't monotonic
-    trajectory_ids, idx = np.unique(trajectory_id, return_index=True)
-    trajectory_ids = trajectory_id[sorted(idx)].values
+    # Sort by trajectory_id so each track can be described by the first index and
+    # number of elements of the unique trajectory id
+    dataset = dataset.sortby(trajectory_id.name)
+    trajectory_ids = np.unique(trajectory_id)
     rowsize = [np.count_nonzero(trajectory_id == x) for x in trajectory_ids]
 
     dataset[trajectory_id.name] = ("trajectory", trajectory_ids)
@@ -75,9 +74,10 @@ def stretch_trid(dataset):
 
     dataset = dataset.drop_vars([trajectory_id.name, rowsize.name])
 
-    dataset["track_id"] = (sample_dimension, trajectory_id_stretched)
-    # Keep attributes (including cf_role)
-    dataset["track_id"].attrs = trajectory_id.attrs
+    dataset[trajectory_id.name] = (sample_dimension, trajectory_id_stretched)
+    # Keep attributes (add cf_role if not already there)
+    dataset[trajectory_id.name].attrs = trajectory_id.attrs
+    dataset[trajectory_id.name].attrs["cf_role"] = "trajectory_id"
 
     return dataset
 
@@ -118,7 +118,6 @@ def _find_trajectory_id(dataset):
         return trajectory_id[0]
     else:
         if "track_id" in dataset:
-            dataset["track_id"].attrs["cf_role"] = "trajectory_id"
             return dataset["track_id"]
         else:
             raise ValueError(
```
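To see why the sort is needed: in the CF contiguous ragged array layout, each trajectory's points must occupy one contiguous block along the sample dimension, with a per-trajectory rowsize giving the block lengths. A standalone NumPy sketch of the bookkeeping (not huracanpy's code, but using the same `np.unique`/`np.count_nonzero` pattern as the diff above):

```python
import numpy as np

# One track id per point along the sample dimension. If the ids were not
# sorted, e.g. [2, 2, 0, 1, 1, 0], track 0 would be split across two blocks
# and a (start index, rowsize) pair could no longer describe it
track_id = np.array([0, 0, 1, 1, 1, 2])

trajectory_ids = np.unique(track_id)  # array([0, 1, 2])
rowsize = [np.count_nonzero(track_id == x) for x in trajectory_ids]  # [2, 3, 1]

# The start of each track's block follows from the cumulative sum of rowsize
starts = np.concatenate(([0], np.cumsum(rowsize)[:-1]))  # array([0, 2, 5])

# The inverse, as in stretch_trid: expand the unique ids back to one id per
# point (np.repeat stands in for however huracanpy builds the stretched array)
trajectory_id_stretched = np.repeat(trajectory_ids, rowsize)
assert (trajectory_id_stretched == track_id).all()
```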

tests/test_huracanpy.py (+23 -12)

```diff
@@ -65,29 +65,40 @@ def test_save(filename, source, extension, muddle, tmp_path):
     # Caused an issue because they got sorted before
     if muddle:
         data = data.sortby("track_id", ascending=False)
+
     # Copy the data because save modifies the dataset at the moment
-    huracanpy.save(data.copy(), str(tmp_path / f"tmp_file.{extension}"))
+    data_orig = data.copy()
+    huracanpy.save(data, str(tmp_path / f"tmp_file.{extension}"))
+
+    # Check that the original data is not modified by the save function
+    _assert_dataset_identical(data_orig, data)
 
     # Reload the data and check it is still the same
-    data_ = huracanpy.load(str(tmp_path / f"tmp_file.{extension}"))
+    # Saving as netcdf does force sorting by track_id so apply this
+    if extension == "nc":
+        data = data.sortby("track_id")
+    data_reload = huracanpy.load(str(tmp_path / f"tmp_file.{extension}"))
+    _assert_dataset_identical(data, data_reload)
+
 
-    assert len(data.variables) == len(data_.variables)
-    assert len(data.coords) == len(data_.coords)
-    for var in list(data.variables) + list(data.coords):
+def _assert_dataset_identical(ds1, ds2):
+    assert len(ds1.variables) == len(ds2.variables)
+    assert len(ds1.coords) == len(ds2.coords)
+    for var in list(ds1.variables) + list(ds1.coords):
         # Work around for xarray inconsistent loading the data as float or double
         # depending on fill_value and scale_factor
         # np.testing.assert_allclose doesn't work for datetime64, object, or string
-        if np.issubdtype(data[var].dtype, np.number):
-            if data[var].dtype != data_[var].dtype:
+        if np.issubdtype(ds1[var].dtype, np.number):
+            if ds1[var].dtype != ds2[var].dtype:
                 rtol = 1e-6
             else:
                 rtol = 0
             np.testing.assert_allclose(
-                data[var].data.astype(data_[var].dtype), data_[var].data, rtol=rtol
+                ds1[var].data.astype(ds2[var].dtype), ds2[var].data, rtol=rtol
             )
         else:
-            assert (data[var].data == data_[var].data).all()
+            assert (ds1[var].data == ds2[var].data).all()
 
-    assert len(data.attrs) == len(data_.attrs)
-    for attr in data.attrs:
-        assert data.attrs[attr] == data_.attrs[attr]
+    assert len(ds1.attrs) == len(ds2.attrs)
+    for attr in ds1.attrs:
+        assert ds1.attrs[attr] == ds2.attrs[attr]
```
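From the user's side, the behaviour this test pins down looks roughly like the following. The file names are illustrative; `Dataset.identical` is xarray's built-in deep comparison, whereas the test above uses its own helper so it can tolerate float/double dtype round-tripping:

```python
import huracanpy

data = huracanpy.load("tracks.csv")  # illustrative input file
data_orig = data.copy()

# save must leave `data` untouched...
huracanpy.save(data, "tracks.nc")
assert data.identical(data_orig)

# ...but the netCDF file itself is stored sorted by track_id, so the
# round-trip compares against the sorted input, not the original order
data_sorted = data.sortby("track_id")
reloaded = huracanpy.load("tracks.nc")
# reloaded should match data_sorted, up to the numeric-dtype differences
# that the test's helper allows for with a small rtol
```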
