1
1
import pytest
2
2
3
- import huracanpy
4
-
5
3
import numpy as np
4
+ import xarray as xr
5
+
6
+ import huracanpy
6
7
7
8
8
9
# %% DataArrayAccessor
@@ -77,6 +78,13 @@ def test_nunique():
77
78
),
78
79
(huracanpy .diags .get_freq , ["track_id" ], "freq" , {}),
79
80
(huracanpy .diags .get_tc_days , ["time" , "track_id" ], "tc_days" , {}),
81
+ # (huracanpy.diags.get_gen_vals, ["all", "time", "track_id"], "gen_vals", {}),
82
+ # (
83
+ # huracanpy.diags.get_apex_vals,
84
+ # ["all", "wind10", "track_id"],
85
+ # "apex_vals",
86
+ # {"varname": "wind10"},
87
+ # ),
80
88
],
81
89
)
82
90
def test_accessor_methods_match_functions (
@@ -93,13 +101,20 @@ def test_accessor_methods_match_functions(
93
101
"track_duration" ,
94
102
"freq" ,
95
103
"tc_days" ,
104
+ "gen_vals" ,
105
+ "apex_vals" ,
96
106
]:
97
107
pytest .skip (f"Accessor function add_{ accessor_name } does not exist" )
98
108
elif accessor_name in ["ace" ] and "sum_by" in accessor_function_kwargs :
99
109
pytest .skip (f"sum_by not a valid argument for add_{ accessor_name } " )
100
110
101
111
# Call the huracanpy function
102
- result = function (* [tracks_csv [var ] for var in function_args ])
112
+ # Get the function arguments as arrays. Use "all" as a wildcard for the full dataset
113
+ function_args = [
114
+ tracks_csv [var ] if not var == "all" else tracks_csv for var in function_args
115
+ ]
116
+ result = function (* function_args )
117
+
103
118
# Call the accessor method
104
119
result_accessor = getattr (tracks_csv .hrcn , f"{ call_type } _{ accessor_name } " )(
105
120
** accessor_function_kwargs
@@ -127,14 +142,18 @@ def test_accessor_methods_match_functions(
127
142
128
143
129
144
# %% DatasetAccessor
145
+ # Currently keeping tests here that return more than just a DataArray as the testing is
146
+ # less generic
130
147
def test_get_methods (tracks_csv ):
131
148
"""Test get_ accessors output is same as function"""
132
149
data = tracks_csv
133
150
134
151
## - pace
135
- pace_acc = data .hrcn .get_pace (pressure_name = "slp" , wind_name = "wind10" )
152
+ pace_acc , _ = data .hrcn .get_pace (pressure_name = "slp" , wind_name = "wind10" )
136
153
pace_fct , model_fct = huracanpy .tc .pace (data .slp , data .wind10 )
137
- assert not any (pace_acc != pace_fct ), "accessor output differs from function output"
154
+ np .testing .assert_array_equal (
155
+ pace_acc , pace_fct , err_msg = "accessor output differs from function output"
156
+ )
138
157
139
158
## - time components
140
159
year_acc , month_acc , day_acc , hour_acc = data .hrcn .get_time_components (
@@ -143,10 +162,18 @@ def test_get_methods(tracks_csv):
143
162
year_fct , month_fct , day_fct , hour_fct = huracanpy .utils .get_time_components (
144
163
data .time
145
164
)
146
- assert all (year_acc == year_fct ), "Year component does not match"
147
- assert all (month_acc == month_fct ), "Month component does not match"
148
- assert all (day_acc == day_fct ), "Day component does not match"
149
- assert all (hour_acc == hour_fct ), "Hour component does not match"
165
+ np .testing .assert_array_equal (
166
+ year_acc , year_fct , err_msg = "Year component does not match"
167
+ )
168
+ np .testing .assert_array_equal (
169
+ month_acc , month_fct , err_msg = "Month component does not match"
170
+ )
171
+ np .testing .assert_array_equal (
172
+ day_acc , day_fct , err_msg = "Day component does not match"
173
+ )
174
+ np .testing .assert_array_equal (
175
+ hour_acc , hour_fct , err_msg = "Hour component does not match"
176
+ )
150
177
151
178
## - track pace
152
179
pace_acc , _ = data .hrcn .get_pace (wind_name = "wind10" , sum_by = "track_id" )
@@ -162,21 +189,17 @@ def test_get_methods(tracks_csv):
162
189
time_name = "time" ,
163
190
track_id_name = "track_id" ,
164
191
)
165
- gen_vals_fct = huracanpy .diags .get_gen_vals (
166
- data ,
167
- )
168
- assert gen_vals_acc .equals (
169
- gen_vals_fct
170
- ), "Genesis Values accessor output differs from function output"
192
+ gen_vals_fct = huracanpy .diags .get_gen_vals (data , data .time , data .track_id )
193
+ xr .testing .assert_equal (gen_vals_acc , gen_vals_fct )
171
194
172
195
## - Apex Values
173
196
apex_vals_acc = data .hrcn .get_apex_vals (
174
197
track_id_name = "track_id" , varname = "wind10" , stat = "max"
175
198
)
176
- apex_vals_fct = huracanpy .diags .get_apex_vals (data , varname = "wind10" , stat = "max" )
177
- assert apex_vals_acc . equals (
178
- apex_vals_fct
179
- ), "Genesis Values accessor output differs from function output"
199
+ apex_vals_fct = huracanpy .diags .get_apex_vals (
200
+ data , data . wind10 , data . track_id , stat = "max"
201
+ )
202
+ xr . testing . assert_equal ( apex_vals_acc , apex_vals_fct )
180
203
181
204
182
205
def test_interp_methods ():
0 commit comments