Skip to content

Commit e77d5d0

Browse files
committed
Make the gen_vals and apex_vals call style consistent with other huracanpy functions: pass arrays when called as a function, and pass variable names when called as an accessor method.
1 parent 8686488 commit e77d5d0

File tree

4 files changed

+92
-44
lines changed

4 files changed

+92
-44
lines changed

huracanpy/_accessor.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def get_pace(
170170
wind_units=wind_units,
171171
**kwargs,
172172
)
173-
return pace_values
173+
return pace_values, model
174174

175175
def add_pace(
176176
self,
@@ -556,11 +556,16 @@ def get_track_duration(self, time_name="time", track_id_name="track_id"):
556556
)
557557

558558
def get_gen_vals(self, time_name="time", track_id_name="track_id"):
559-
return diags.get_gen_vals(self._dataset, time_name, track_id_name)
559+
return diags.get_gen_vals(
560+
self._dataset, self._dataset[time_name], self._dataset[track_id_name]
561+
)
560562

561563
def get_apex_vals(self, varname, track_id_name="track_id", stat="max"):
562564
return diags.get_apex_vals(
563-
self._dataset, varname, track_id_name=track_id_name, stat=stat
565+
self._dataset,
566+
variable=self._dataset[varname],
567+
track_id=self._dataset[track_id_name],
568+
stat=stat,
564569
)
565570

566571
# ---- climato

huracanpy/diags/_track_stats.py

+37-17
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
Module containing functions to compute track statistics
33
"""
44

5+
import numpy as np
6+
import pandas as pd
7+
58

69
def get_track_duration(time, track_ids):
710
"""
@@ -25,39 +28,50 @@ def get_track_duration(time, track_ids):
2528
return duration
2629

2730

28-
def get_gen_vals(tracks, time_name="time", track_id_name="track_id"):
31+
def get_gen_vals(tracks, time, track_id):
2932
"""
3033
Shows the attributes for the genesis point of each track
3134
3235
Parameters
3336
----------
3437
tracks : xarray.Dataset
38+
time : array_like
39+
track_id : xarray.DataArray
3540
3641
Returns
3742
-------
3843
xarray.Dataset
3944
Dataset containing only genesis points, with track_id as index.
4045
4146
"""
47+
# It is 470 times faster to switch to a dataframe...
48+
# Use pandas sort_values/groupby to find the relevant indices in the original
49+
# Dataset by passing an index (named idx to not clash with "index")
50+
df = pd.DataFrame(
51+
data=dict(
52+
idx=np.arange(len(track_id)),
53+
time=np.array(time),
54+
track_id=np.array(track_id),
55+
)
56+
)
57+
idx = np.array(df.sort_values("time").groupby("track_id").first().idx)
4258

43-
return (
44-
tracks.to_dataframe()
45-
.sort_values(time_name)
46-
.groupby(track_id_name)
47-
.first()
48-
.to_xarray()
49-
) # It is 470 times much faster to switch to a dataframe...
59+
# Could check that track_id is 1d, but the function would already have failed by now
60+
# if not
61+
dim = track_id.dims[0]
62+
return tracks.isel(**{dim: idx})
5063

5164

52-
def get_apex_vals(tracks, varname, stat="max", track_id_name="track_id"):
65+
def get_apex_vals(tracks, variable, track_id, stat="max"):
5366
"""
5467
Shows the attribute for the extremum point of each track
5568
5669
Parameters
5770
----------
5871
tracks : xarray.Dataset
59-
var : str
72+
variable : array_like
6073
The extremum variable
74+
track_id : xarray.DataArray
6175
stat : str, optional
6276
Type of extremum. Can be "min" or "max". The default is "max".
6377
@@ -82,10 +96,16 @@ def get_apex_vals(tracks, varname, stat="max", track_id_name="track_id"):
8296
else:
8397
raise NotImplementedError("stat not recognized. Please use one of {min, max}")
8498

85-
return (
86-
tracks.to_dataframe()
87-
.sort_values(varname, ascending=asc)
88-
.groupby(track_id_name)
89-
.first()
90-
.to_xarray()
91-
) # It is 350 times much faster to switch to a dataframe..
99+
# It is 350 times faster to switch to a dataframe.
100+
# Use the same trick as with gen_vals
101+
df = pd.DataFrame(
102+
data=dict(
103+
idx=np.arange(len(variable)),
104+
var=np.array(variable),
105+
track_id=np.array(track_id),
106+
)
107+
)
108+
idx = np.array(df.sort_values("var", ascending=asc).groupby("track_id").first().idx)
109+
110+
dim = track_id.dims[0]
111+
return tracks.isel(**{dim: idx})

tests/test_accessor.py

+42-19
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import pytest
22

3-
import huracanpy
4-
53
import numpy as np
4+
import xarray as xr
5+
6+
import huracanpy
67

78

89
# %% DataArrayAccessor
@@ -77,6 +78,13 @@ def test_nunique():
7778
),
7879
(huracanpy.diags.get_freq, ["track_id"], "freq", {}),
7980
(huracanpy.diags.get_tc_days, ["time", "track_id"], "tc_days", {}),
81+
# (huracanpy.diags.get_gen_vals, ["all", "time", "track_id"], "gen_vals", {}),
82+
# (
83+
# huracanpy.diags.get_apex_vals,
84+
# ["all", "wind10", "track_id"],
85+
# "apex_vals",
86+
# {"varname": "wind10"},
87+
# ),
8088
],
8189
)
8290
def test_accessor_methods_match_functions(
@@ -93,13 +101,20 @@ def test_accessor_methods_match_functions(
93101
"track_duration",
94102
"freq",
95103
"tc_days",
104+
"gen_vals",
105+
"apex_vals",
96106
]:
97107
pytest.skip(f"Accessor function add_{accessor_name} does not exist")
98108
elif accessor_name in ["ace"] and "sum_by" in accessor_function_kwargs:
99109
pytest.skip(f"sum_by not a valid argument for add_{accessor_name}")
100110

101111
# Call the huracanpy function
102-
result = function(*[tracks_csv[var] for var in function_args])
112+
# Get the function arguments as arrays. Use "all" as a wildcard for the full dataset
113+
function_args = [
114+
tracks_csv[var] if not var == "all" else tracks_csv for var in function_args
115+
]
116+
result = function(*function_args)
117+
103118
# Call the accessor method
104119
result_accessor = getattr(tracks_csv.hrcn, f"{call_type}_{accessor_name}")(
105120
**accessor_function_kwargs
@@ -127,14 +142,18 @@ def test_accessor_methods_match_functions(
127142

128143

129144
# %% DatasetAccessor
145+
# Currently keeping tests here that return more than just a DataArray as the testing is
146+
# less generic
130147
def test_get_methods(tracks_csv):
131148
"""Test get_ accessors output is same as function"""
132149
data = tracks_csv
133150

134151
## - pace
135-
pace_acc = data.hrcn.get_pace(pressure_name="slp", wind_name="wind10")
152+
pace_acc, _ = data.hrcn.get_pace(pressure_name="slp", wind_name="wind10")
136153
pace_fct, model_fct = huracanpy.tc.pace(data.slp, data.wind10)
137-
assert not any(pace_acc != pace_fct), "accessor output differs from function output"
154+
np.testing.assert_array_equal(
155+
pace_acc, pace_fct, err_msg="accessor output differs from function output"
156+
)
138157

139158
## - time components
140159
year_acc, month_acc, day_acc, hour_acc = data.hrcn.get_time_components(
@@ -143,10 +162,18 @@ def test_get_methods(tracks_csv):
143162
year_fct, month_fct, day_fct, hour_fct = huracanpy.utils.get_time_components(
144163
data.time
145164
)
146-
assert all(year_acc == year_fct), "Year component does not match"
147-
assert all(month_acc == month_fct), "Month component does not match"
148-
assert all(day_acc == day_fct), "Day component does not match"
149-
assert all(hour_acc == hour_fct), "Hour component does not match"
165+
np.testing.assert_array_equal(
166+
year_acc, year_fct, err_msg="Year component does not match"
167+
)
168+
np.testing.assert_array_equal(
169+
month_acc, month_fct, err_msg="Month component does not match"
170+
)
171+
np.testing.assert_array_equal(
172+
day_acc, day_fct, err_msg="Day component does not match"
173+
)
174+
np.testing.assert_array_equal(
175+
hour_acc, hour_fct, err_msg="Hour component does not match"
176+
)
150177

151178
## - track pace
152179
pace_acc, _ = data.hrcn.get_pace(wind_name="wind10", sum_by="track_id")
@@ -162,21 +189,17 @@ def test_get_methods(tracks_csv):
162189
time_name="time",
163190
track_id_name="track_id",
164191
)
165-
gen_vals_fct = huracanpy.diags.get_gen_vals(
166-
data,
167-
)
168-
assert gen_vals_acc.equals(
169-
gen_vals_fct
170-
), "Genesis Values accessor output differs from function output"
192+
gen_vals_fct = huracanpy.diags.get_gen_vals(data, data.time, data.track_id)
193+
xr.testing.assert_equal(gen_vals_acc, gen_vals_fct)
171194

172195
## - Apex Values
173196
apex_vals_acc = data.hrcn.get_apex_vals(
174197
track_id_name="track_id", varname="wind10", stat="max"
175198
)
176-
apex_vals_fct = huracanpy.diags.get_apex_vals(data, varname="wind10", stat="max")
177-
assert apex_vals_acc.equals(
178-
apex_vals_fct
179-
), "Genesis Values accessor output differs from function output"
199+
apex_vals_fct = huracanpy.diags.get_apex_vals(
200+
data, data.wind10, data.track_id, stat="max"
201+
)
202+
xr.testing.assert_equal(apex_vals_acc, apex_vals_fct)
180203

181204

182205
def test_interp_methods():

tests/test_diags/test_track_stats.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,13 @@ def test_duration():
1111

1212
def test_gen_vals():
1313
data = huracanpy.load(huracanpy.example_csv_file, source="csv")
14-
G = huracanpy.diags.get_gen_vals(data)
14+
G = huracanpy.diags.get_gen_vals(data, data.time, data.track_id)
1515
assert G.time.dt.day.mean() == 10
1616

1717

18-
def test_extremum_vals():
19-
data = huracanpy.load(huracanpy.example_csv_file, source="csv")
20-
M = huracanpy.diags.get_apex_vals(data, "wind10", "max")
21-
m = huracanpy.diags.get_apex_vals(data, "slp", "min")
18+
def test_apex_vals():
19+
data = huracanpy.load(huracanpy.example_csv_file)
20+
M = huracanpy.diags.get_apex_vals(data, data.wind10, data.track_id, "max")
21+
m = huracanpy.diags.get_apex_vals(data, data.slp, data.track_id, "min")
2222
assert M.time.dt.day.mean() == 15
2323
assert m.lat.mean() == -27

0 commit comments

Comments
 (0)