Skip to content

Commit d051127

Browse files
committed
Check accelerated functions work when the dataset has an extra dimension for some variables
1 parent 5c26a8d commit d051127

File tree

2 files changed

+71
-23
lines changed

2 files changed

+71
-23
lines changed

tests/conftest.py

+18
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,24 @@ def tracks_csv():
1414
return huracanpy.load(huracanpy.example_csv_file)
1515

1616

17+
@pytest.fixture()
18+
def tracks_with_extra_coord(tracks_csv):
19+
# Test that the same results apply if a variable has an additional dimension to the
20+
# time/track_id dimension (e.g. if each point had a profile on pressure levels)
21+
# Most functions should work fine but using pandas can cause the data to be
22+
# broadcast across the dimensions to be able to represent it as 1d
23+
return tracks_csv.assign(
24+
thing=(
25+
(
26+
("record", "level"),
27+
np.array(
28+
[np.ones_like(tracks_csv.lon), np.ones_like(tracks_csv.lon) * 2],
29+
).T,
30+
)
31+
)
32+
)
33+
34+
1735
@pytest.fixture()
1836
def tracks_minus180_plus180():
1937
return xr.Dataset(

tests/test_accel.py

+53-23
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,23 @@
33
slower method
44
"""
55

6+
import pytest
67
from haversine import haversine_vector, Unit
78
import numpy as np
89
import xarray as xr
910

1011
import huracanpy
1112

1213

13-
def test_accel_sel_id(tracks_csv):
14-
result = huracanpy.sel_id(tracks_csv, tracks_csv.track_id, 0)
14+
@pytest.mark.parametrize(
15+
("tracks",),
16+
(["tracks_csv"], ["tracks_with_extra_coord"]),
17+
)
18+
def test_accel_sel_id(tracks, request):
19+
tracks = request.getfixturevalue(tracks)
20+
result = huracanpy.sel_id(tracks, tracks.track_id, 0)
1521

16-
expected = tracks_csv.groupby("track_id")[0]
22+
expected = tracks.groupby("track_id")[0]
1723

1824
xr.testing.assert_identical(result, expected)
1925

@@ -23,30 +29,41 @@ def test_accel_trackswhere():
2329
pass
2430

2531

26-
def test_accel_get_gen_vals(tracks_csv):
27-
result = huracanpy.calc.get_gen_vals(
28-
tracks_csv, tracks_csv.time, tracks_csv.track_id
29-
)
32+
@pytest.mark.parametrize(
33+
("tracks",),
34+
(["tracks_csv"], ["tracks_with_extra_coord"]),
35+
)
36+
def test_accel_get_gen_vals(tracks, request):
37+
tracks = request.getfixturevalue(tracks)
38+
result = huracanpy.calc.get_gen_vals(tracks, tracks.time, tracks.track_id)
3039

31-
expected = tracks_csv.groupby("track_id").first()
40+
expected = tracks.groupby("track_id").first()
3241

3342
xr.testing.assert_identical(result, expected)
3443

3544

36-
def test_accel_get_apex_vals(tracks_csv):
37-
result = huracanpy.calc.get_apex_vals(
38-
tracks_csv, tracks_csv.wind10, tracks_csv.track_id
39-
)
45+
@pytest.mark.parametrize(
46+
("tracks",),
47+
(["tracks_csv"], ["tracks_with_extra_coord"]),
48+
)
49+
def test_accel_get_apex_vals(tracks, request):
50+
tracks = request.getfixturevalue(tracks)
51+
result = huracanpy.calc.get_apex_vals(tracks, tracks.wind10, tracks.track_id)
4052

41-
expected = tracks_csv.sortby("wind10", ascending=False).groupby("track_id").first()
53+
expected = tracks.sortby("wind10", ascending=False).groupby("track_id").first()
4254

4355
xr.testing.assert_identical(result, expected)
4456

4557

46-
def test_accel_get_time_from_genesis(tracks_csv):
47-
result = huracanpy.calc.get_time_from_genesis(tracks_csv.time, tracks_csv.track_id)
58+
@pytest.mark.parametrize(
59+
("tracks",),
60+
(["tracks_csv"], ["tracks_with_extra_coord"]),
61+
)
62+
def test_accel_get_time_from_genesis(tracks, request):
63+
tracks = request.getfixturevalue(tracks)
64+
result = huracanpy.calc.get_time_from_genesis(tracks.time, tracks.track_id)
4865

49-
track_groups = tracks_csv.groupby("track_id")
66+
track_groups = tracks.groupby("track_id")
5067
expected = []
5168
for track_id, track in track_groups:
5269
expected.append(track.time - track.time[0])
@@ -57,12 +74,17 @@ def test_accel_get_time_from_genesis(tracks_csv):
5774
xr.testing.assert_identical(result, expected)
5875

5976

60-
def test_accel_get_time_from_apex(tracks_csv):
77+
@pytest.mark.parametrize(
78+
("tracks",),
79+
(["tracks_csv"], ["tracks_with_extra_coord"]),
80+
)
81+
def test_accel_get_time_from_apex(tracks, request):
82+
tracks = request.getfixturevalue(tracks)
6183
result = huracanpy.calc.get_time_from_apex(
62-
tracks_csv.time, tracks_csv.track_id, tracks_csv.wind10
84+
tracks.time, tracks.track_id, tracks.wind10
6385
)
6486

65-
track_groups = tracks_csv.groupby("track_id")
87+
track_groups = tracks.groupby("track_id")
6688
expected = []
6789
for track_id, track in track_groups:
6890
idx = track.wind10.argmax()
@@ -74,8 +96,12 @@ def test_accel_get_time_from_apex(tracks_csv):
7496
xr.testing.assert_identical(result, expected)
7597

7698

77-
def test_accel_match():
78-
ref = huracanpy.load(huracanpy.example_csv_file)
99+
@pytest.mark.parametrize(
100+
("tracks",),
101+
(["tracks_csv"], ["tracks_with_extra_coord"]),
102+
)
103+
def test_accel_match(tracks, request):
104+
ref = request.getfixturevalue(tracks)
79105
tracks = ref.where(ref.track_id < 2, drop=True)
80106
tracks = tracks.where(tracks.time.dt.hour == 0, drop=True)
81107
tracks["lon"] = tracks.lon + 0.5
@@ -117,8 +143,12 @@ def test_accel_match():
117143
np.testing.assert_allclose(result.dist, np.array(dist), rtol=1e-12)
118144

119145

120-
def test_accel_overlap():
121-
ref = huracanpy.load(huracanpy.example_csv_file)
146+
@pytest.mark.parametrize(
147+
("tracks",),
148+
(["tracks_csv"], ["tracks_with_extra_coord"]),
149+
)
150+
def test_accel_overlap(tracks, request):
151+
ref = request.getfixturevalue(tracks)
122152
tracks = ref.where(ref.track_id < 2, drop=True)
123153
tracks = tracks.where(tracks.time.dt.hour == 0, drop=True)
124154
tracks["lon"] = tracks.lon + 0.5

0 commit comments

Comments
 (0)