
Commit 40e140c

Make accelerated functions equivalent to naive slow approaches and add tests to verify

1 parent 13d646f

File tree: 7 files changed (+187 -24 lines)

  huracanpy/_subset.py
  huracanpy/assess/_match.py
  huracanpy/assess/_overlap.py
  huracanpy/calc/_lifecycle.py
  huracanpy/calc/_track_stats.py
  tests/test_accel.py
  tests/test_subset.py

huracanpy/_subset.py (+21 -12)

@@ -1,28 +1,36 @@
+import numpy as np
 import xarray as xr

 __all__ = ["trackswhere", "sel_id"]


-def sel_id(data, track_id, track_ids):
+def sel_id(tracks, track_ids, track_id):
     """Select an individual track from a set of tracks by ID

     Parameters
     ----------
-    data : xarray.Dataset
-    track_id : scalar
+    tracks : xarray.Dataset
     track_ids : xarray.DataArray
+        The track_ids corresponding to the tracks Dataset
+    track_id : Any
+        The track ID to match in track_ids. Must be the same type as the track_ids.
+        Usually `int` or `str`

     Returns
     -------
     xarray.Dataset

     """
-    df = data.to_dataframe()
-    track = df[track_ids == track_id]
-    return track.to_xarray()
+    if track_ids.ndim != 1:
+        raise ValueError("track_ids must be 1d")
+
+    dim = track_ids.dims[0]
+    idx = np.where(track_ids == track_id)[0]
+
+    return tracks.isel(**{dim: idx})


-def trackswhere(tracks, condition):
+def trackswhere(tracks, track_ids, condition):
     """Subset tracks from the input

     e.g select all tracks that are solely in the Northern hemisphere
@@ -31,7 +39,8 @@ def trackswhere(tracks, condition):
     Parameters
     ----------
     tracks : xarray.Dataset
-    condition : function
+    track_ids : xarray.DataArray
+    condition : callable
         A function that takes an `xarray.Dataset` of an individual track and returns
         True or False

@@ -41,6 +50,9 @@ def trackswhere(tracks, condition):
         A dataset with the subset of tracks from the input that match the given criteria

     """
+    if track_ids.ndim != 1:
+        raise ValueError("track_ids must be 1d")
+
     track_groups = tracks.groupby("track_id")

     if callable(condition):
@@ -50,7 +62,4 @@ def trackswhere(tracks, condition):
         track for n, (track_id, track) in enumerate(track_groups) if is_match[n]
     ]

-    if len(tracks.time.dims) == 1:
-        raise ValueError("trackswhere input must have exactly 1 time dimension")
-
-    return xr.concat(track_groups, dim=tracks.time.dims[0])
+    return xr.concat(track_groups, dim=track_ids.dims[0])
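Both functions now take the tracks Dataset first and its track_id variable second. A minimal usage sketch of the new signatures, based on the calls exercised in the tests added by this commit (the Northern-hemisphere condition is an illustrative example from the docstring, not part of the commit):

    import huracanpy

    tracks = huracanpy.load(huracanpy.example_csv_file)

    # Select a single track by its ID
    track0 = huracanpy.sel_id(tracks, tracks.track_id, 0)

    # Keep only the tracks that stay entirely in the Northern hemisphere
    nh_tracks = huracanpy.trackswhere(
        tracks, tracks.track_id, lambda track: (track.lat > 0).all()
    )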

huracanpy/assess/_match.py (+2 -2)

@@ -12,7 +12,7 @@ def match(tracksets, names=["1", "2"], max_dist=300, min_overlap=0):

     Parameters
     ----------
-    tracksets : list
+    tracksets : list[xarray.Dataset]
        list of track datasets to match together. Must be of length two or more.
    names : list, optional
        list of track datasets names. Must be the same size as tracksets. The default is ['1','2'].
@@ -23,7 +23,7 @@ def match(tracksets, names=["1", "2"], max_dist=300, min_overlap=0):

    Returns
    -------
-    pd.DataFrame
+    pandas.DataFrame
        Dataframe containing the matching tracks with
            the id from corresponding datasets
            the number of matching time steps (if only two datasets provided)

huracanpy/assess/_overlap.py (+4 -4)

@@ -9,9 +9,9 @@ def overlap(tracks1, tracks2, matches=None):

    Parameters
    ----------
-    tracks1 (pd.Dataframe)
-    tracks2 (pd.Dataframe)
-    matches (pd.Dataframe): The output from match_tracks on tracks1 and tracks2.
+    tracks1 (xarray.Dataset)
+    tracks2 (xarray.Dataset)
+    matches (pandas.Dataframe): The output from match_tracks on tracks1 and tracks2.
        If None, match_tracks is run on tracks1 and tracks2.

    Returns
@@ -20,7 +20,7 @@ def overlap(tracks1, tracks2, matches=None):
        Match dataset with added deltas in days
    """
    if matches is None:
-        matches = match(tracks1, tracks2)
+        matches = match([tracks1, tracks2])
    c1, c2 = matches.columns[:2].str.slice(3)
    tracks1, tracks2 = tracks1.to_dataframe(), tracks2.to_dataframe()
    matches = (
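The fallback inside overlap now passes the two datasets as a list, matching the match signature shown above. A minimal sketch of the equivalent explicit calls, assuming tracks1 and tracks2 are loaded track datasets (the variable names here are illustrative):

    matches = huracanpy.assess.match([tracks1, tracks2], names=["1", "2"])
    deltas = huracanpy.assess.overlap(tracks1, tracks2, matches=matches)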

huracanpy/calc/_lifecycle.py (+7 -3)

@@ -29,7 +29,10 @@ def get_time_from_genesis(time, track_ids):
    )
    time_from_start = data_df.time_actual - data_df.time_gen
    return (
-        time_from_start.to_xarray().rename({"index": "obs"}).rename("time_from_genesis")
+        time_from_start.to_xarray()
+        .rename({"index": track_ids.dims[0]})
+        .drop(track_ids.dims[0])
+        .rename("time_from_genesis")
    )


@@ -40,7 +43,7 @@ def get_time_from_apex(time, track_ids, intensity_var, stat="max"):
    Parameters
    ----------
    time : array_like
-    track_ids : array_like
+    track_ids : xarray.DataArray
    intensity_var : array_like
    stat : str, optional
        Take either the maxima ("max") or minima ("min") of `intensity_var`. Default is
@@ -65,6 +68,7 @@ def get_time_from_apex(time, track_ids, intensity_var, stat="max"):
    time_from_extr = data_df.time_actual - data_df.time_extr
    return (
        time_from_extr.to_xarray()
-        .rename({"index": time.dims[0]})
+        .rename({"index": track_ids.dims[0]})
+        .drop(track_ids.dims[0])
        .rename("time_from_extremum")
    )
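Both lifecycle functions now take the output dimension name from track_ids.dims[0] and drop the pandas-derived index coordinate, so the result aligns with the input arrays, which carry no coordinate on that dimension. A sketch of the naive per-track equivalent that the new tests below compare against ("record" is the dimension name used in the example data):

    expected = xr.concat(
        [track.time - track.time[0] for _, track in tracks.groupby("track_id")],
        dim="record",
    ).rename("time_from_genesis")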

huracanpy/calc/_track_stats.py (+12 -2)

@@ -60,7 +60,12 @@ def get_gen_vals(tracks, time, track_id):
    # Could check that track_id is 1d, but the function would already have failed by now
    # if not
    dim = track_id.dims[0]
-    return tracks.isel(**{dim: idx})
+    tracks = tracks.isel(**{dim: idx})
+
+    # Promote track_id to a coordinate and remove record
+    return tracks.assign_coords(**{track_id.name: tracks[track_id.name]}).swap_dims(
+        **{dim: track_id.name}
+    )


def get_apex_vals(tracks, variable, track_id, stat="max"):
@@ -109,4 +114,9 @@ def get_apex_vals(tracks, variable, track_id, stat="max"):
    idx = np.array(df.sort_values("var", ascending=asc).groupby("track_id").first().idx)

    dim = track_id.dims[0]
-    return tracks.isel(**{dim: idx})
+    tracks = tracks.isel(**{dim: idx})
+
+    # Promote track_id to a coordinate and remove record
+    return tracks.assign_coords(**{track_id.name: tracks[track_id.name]}).swap_dims(
+        **{dim: track_id.name}
+    )

tests/test_accel.py (+140, new file)

@@ -0,0 +1,140 @@
+"""
+Test functions that use tricks to speed up their code produce the same result as the
+slower method
+"""
+
+from haversine import haversine_vector, Unit
+import numpy as np
+import xarray as xr
+
+import huracanpy
+
+
+def test_accel_sel_id(tracks_csv):
+    result = huracanpy.sel_id(tracks_csv, tracks_csv.track_id, 0)
+
+    expected = tracks_csv.groupby("track_id")[0]
+
+    xr.testing.assert_identical(result, expected)
+
+
+def test_accel_trackswhere():
+    # TODO accelerate trackswhere
+    pass
+
+
+def test_accel_get_gen_vals(tracks_csv):
+    result = huracanpy.calc.get_gen_vals(
+        tracks_csv, tracks_csv.time, tracks_csv.track_id
+    )
+
+    expected = tracks_csv.groupby("track_id").first()
+
+    xr.testing.assert_identical(result, expected)
+
+
+def test_accel_get_apex_vals(tracks_csv):
+    result = huracanpy.calc.get_apex_vals(
+        tracks_csv, tracks_csv.wind10, tracks_csv.track_id
+    )
+
+    expected = tracks_csv.sortby("wind10", ascending=False).groupby("track_id").first()
+
+    xr.testing.assert_identical(result, expected)
+
+
+def test_accel_get_time_from_genesis(tracks_csv):
+    result = huracanpy.calc.get_time_from_genesis(tracks_csv.time, tracks_csv.track_id)
+
+    track_groups = tracks_csv.groupby("track_id")
+    expected = []
+    for track_id, track in track_groups:
+        expected.append(track.time - track.time[0])
+
+    expected = xr.concat(expected, dim="record")
+    expected = expected.rename("time_from_genesis")
+
+    xr.testing.assert_identical(result, expected)
+
+
+def test_accel_get_time_from_apex(tracks_csv):
+    result = huracanpy.calc.get_time_from_apex(
+        tracks_csv.time, tracks_csv.track_id, tracks_csv.wind10
+    )
+
+    track_groups = tracks_csv.groupby("track_id")
+    expected = []
+    for track_id, track in track_groups:
+        idx = track.wind10.argmax()
+        expected.append(track.time - track.time[idx])
+
+    expected = xr.concat(expected, dim="record")
+    expected = expected.rename("time_from_extremum")
+
+    xr.testing.assert_identical(result, expected)
+
+
+def test_accel_match():
+    ref = huracanpy.load(huracanpy.example_csv_file)
+    tracks = ref.where(ref.track_id < 2, drop=True)
+    tracks = tracks.where(tracks.time.dt.hour == 0, drop=True)
+    tracks["lon"] = tracks.lon + 0.5
+    tracks["lat"] = tracks.lat + 0.5
+
+    result = huracanpy.assess.match([tracks, ref])
+
+    max_dist = 300
+    track_id1 = []
+    track_id2 = []
+    npoints = []
+    dist = []
+
+    for track_id, track in tracks.groupby("track_id"):
+        for track_id_ref, track_ref in ref.groupby("track_id"):
+            # Match times
+            track_ = track.where(track.time.isin(track_ref.time), drop=True)
+
+            if len(track_.time) > 0:
+                track_ref_ = track_ref.where(track_ref.time.isin(track.time), drop=True)
+
+                yx_track = np.array([track_.lat, track_.lon]).T
+                yx_ref = np.array([track_ref_.lat, track_ref_.lon]).T
+
+                dists = haversine_vector(yx_track, yx_ref, Unit.KILOMETERS)
+
+                matches = dists < max_dist
+                if matches.any():
+                    track_id1.append(track_id)
+                    track_id2.append(track_id_ref)
+
+                    dists_track = dists[matches]
+                    npoints.append(len(dists_track))
+                    dist.append(np.mean(dists_track))
+
+    np.testing.assert_equal(result.id_1, np.array(track_id1))
+    np.testing.assert_equal(result.id_2, np.array(track_id2))
+    np.testing.assert_equal(result.temp, np.array(npoints))
+    np.testing.assert_allclose(result.dist, np.array(dist), rtol=1e-12)
+
+
+def test_accel_overlap():
+    ref = huracanpy.load(huracanpy.example_csv_file)
+    tracks = ref.where(ref.track_id < 2, drop=True)
+    tracks = tracks.where(tracks.time.dt.hour == 0, drop=True)
+    tracks["lon"] = tracks.lon + 0.5
+    tracks["lat"] = tracks.lat + 0.5
+
+    result = huracanpy.assess.overlap(tracks, ref)
+
+    delta_start = []
+    delta_end = []
+
+    for n, row in result.iterrows():
+        track = tracks.where(tracks.track_id == row.id_1, drop=True)
+        track_ref = ref.where(ref.track_id == row.id_2, drop=True)
+
+        delta_start.append((track_ref.time[0] - track.time[0]) / np.timedelta64(1, "D"))
+        delta_end.append((track_ref.time[-1] - track.time[-1]) / np.timedelta64(1, "D"))
+
+    np.testing.assert_equal(result.delta_start, np.array(delta_start))
+    np.testing.assert_equal(result.delta_end, np.array(delta_end))
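The tests above use a tracks_csv fixture that is not defined in this commit; presumably it lives in the test suite's conftest.py. A hypothetical sketch of such a fixture, based on the loading call used in the match/overlap tests:

    # Hypothetical fixture; the real definition lives elsewhere in the test suite
    import pytest
    import huracanpy

    @pytest.fixture()
    def tracks_csv():
        return huracanpy.load(huracanpy.example_csv_file)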

tests/test_subset.py (+1 -1)

@@ -7,7 +7,7 @@ def test_trackswhere():
    tracks["category"] = huracanpy.tc.get_pressure_cat(tracks.slp, slp_units="Pa")

    tracks_subset = huracanpy.trackswhere(
-        tracks, lambda track: track.category.max() >= 2
+        tracks, tracks.track_id, lambda track: track.category.max() >= 2
    )

    assert set(tracks_subset.track_id.data) == {0, 2}
