3
3
slower method
4
4
"""
5
5
6
+ import pytest
6
7
from haversine import haversine_vector , Unit
7
8
import numpy as np
8
9
import xarray as xr
9
10
10
11
import huracanpy
11
12
12
13
13
- def test_accel_sel_id (tracks_csv ):
14
- result = huracanpy .sel_id (tracks_csv , tracks_csv .track_id , 0 )
14
+ @pytest .mark .parametrize (
15
+ ("tracks" ,),
16
+ (["tracks_csv" ], ["tracks_with_extra_coord" ]),
17
+ )
18
+ def test_accel_sel_id (tracks , request ):
19
+ tracks = request .getfixturevalue (tracks )
20
+ result = huracanpy .sel_id (tracks , tracks .track_id , 0 )
15
21
16
- expected = tracks_csv .groupby ("track_id" )[0 ]
22
+ expected = tracks .groupby ("track_id" )[0 ]
17
23
18
24
xr .testing .assert_identical (result , expected )
19
25
@@ -23,30 +29,41 @@ def test_accel_trackswhere():
23
29
pass
24
30
25
31
26
- def test_accel_get_gen_vals (tracks_csv ):
27
- result = huracanpy .calc .get_gen_vals (
28
- tracks_csv , tracks_csv .time , tracks_csv .track_id
29
- )
32
+ @pytest .mark .parametrize (
33
+ ("tracks" ,),
34
+ (["tracks_csv" ], ["tracks_with_extra_coord" ]),
35
+ )
36
+ def test_accel_get_gen_vals (tracks , request ):
37
+ tracks = request .getfixturevalue (tracks )
38
+ result = huracanpy .calc .get_gen_vals (tracks , tracks .time , tracks .track_id )
30
39
31
- expected = tracks_csv .groupby ("track_id" ).first ()
40
+ expected = tracks .groupby ("track_id" ).first ()
32
41
33
42
xr .testing .assert_identical (result , expected )
34
43
35
44
36
- def test_accel_get_apex_vals (tracks_csv ):
37
- result = huracanpy .calc .get_apex_vals (
38
- tracks_csv , tracks_csv .wind10 , tracks_csv .track_id
39
- )
45
+ @pytest .mark .parametrize (
46
+ ("tracks" ,),
47
+ (["tracks_csv" ], ["tracks_with_extra_coord" ]),
48
+ )
49
+ def test_accel_get_apex_vals (tracks , request ):
50
+ tracks = request .getfixturevalue (tracks )
51
+ result = huracanpy .calc .get_apex_vals (tracks , tracks .wind10 , tracks .track_id )
40
52
41
- expected = tracks_csv .sortby ("wind10" , ascending = False ).groupby ("track_id" ).first ()
53
+ expected = tracks .sortby ("wind10" , ascending = False ).groupby ("track_id" ).first ()
42
54
43
55
xr .testing .assert_identical (result , expected )
44
56
45
57
46
- def test_accel_get_time_from_genesis (tracks_csv ):
47
- result = huracanpy .calc .get_time_from_genesis (tracks_csv .time , tracks_csv .track_id )
58
+ @pytest .mark .parametrize (
59
+ ("tracks" ,),
60
+ (["tracks_csv" ], ["tracks_with_extra_coord" ]),
61
+ )
62
+ def test_accel_get_time_from_genesis (tracks , request ):
63
+ tracks = request .getfixturevalue (tracks )
64
+ result = huracanpy .calc .get_time_from_genesis (tracks .time , tracks .track_id )
48
65
49
- track_groups = tracks_csv .groupby ("track_id" )
66
+ track_groups = tracks .groupby ("track_id" )
50
67
expected = []
51
68
for track_id , track in track_groups :
52
69
expected .append (track .time - track .time [0 ])
@@ -57,12 +74,17 @@ def test_accel_get_time_from_genesis(tracks_csv):
57
74
xr .testing .assert_identical (result , expected )
58
75
59
76
60
- def test_accel_get_time_from_apex (tracks_csv ):
77
+ @pytest .mark .parametrize (
78
+ ("tracks" ,),
79
+ (["tracks_csv" ], ["tracks_with_extra_coord" ]),
80
+ )
81
+ def test_accel_get_time_from_apex (tracks , request ):
82
+ tracks = request .getfixturevalue (tracks )
61
83
result = huracanpy .calc .get_time_from_apex (
62
- tracks_csv .time , tracks_csv .track_id , tracks_csv .wind10
84
+ tracks .time , tracks .track_id , tracks .wind10
63
85
)
64
86
65
- track_groups = tracks_csv .groupby ("track_id" )
87
+ track_groups = tracks .groupby ("track_id" )
66
88
expected = []
67
89
for track_id , track in track_groups :
68
90
idx = track .wind10 .argmax ()
@@ -74,8 +96,12 @@ def test_accel_get_time_from_apex(tracks_csv):
74
96
xr .testing .assert_identical (result , expected )
75
97
76
98
77
- def test_accel_match ():
78
- ref = huracanpy .load (huracanpy .example_csv_file )
99
+ @pytest .mark .parametrize (
100
+ ("tracks" ,),
101
+ (["tracks_csv" ], ["tracks_with_extra_coord" ]),
102
+ )
103
+ def test_accel_match (tracks , request ):
104
+ ref = request .getfixturevalue (tracks )
79
105
tracks = ref .where (ref .track_id < 2 , drop = True )
80
106
tracks = tracks .where (tracks .time .dt .hour == 0 , drop = True )
81
107
tracks ["lon" ] = tracks .lon + 0.5
@@ -117,8 +143,12 @@ def test_accel_match():
117
143
np .testing .assert_allclose (result .dist , np .array (dist ), rtol = 1e-12 )
118
144
119
145
120
- def test_accel_overlap ():
121
- ref = huracanpy .load (huracanpy .example_csv_file )
146
+ @pytest .mark .parametrize (
147
+ ("tracks" ,),
148
+ (["tracks_csv" ], ["tracks_with_extra_coord" ]),
149
+ )
150
+ def test_accel_overlap (tracks , request ):
151
+ ref = request .getfixturevalue (tracks )
122
152
tracks = ref .where (ref .track_id < 2 , drop = True )
123
153
tracks = tracks .where (tracks .time .dt .hour == 0 , drop = True )
124
154
tracks ["lon" ] = tracks .lon + 0.5
0 commit comments