7
7
import numpy as np
8
8
import xarray as xr
9
9
from metpy .units import units
10
+ from metpy .xarray import preprocess_and_wrap
10
11
12
+ from .._metpy import dequantify_results
11
13
12
- def get_delta (var , track_ids = None , var_units = None , centering = "forward" ):
14
+
15
+ @dequantify_results
16
+ @preprocess_and_wrap (wrap_like = "var" )
17
+ def get_delta (var , track_ids = None , centering = "forward" ):
13
18
"""Take the differences across var, without including differences between the end
14
19
and start of different tracks
15
20
@@ -30,46 +35,41 @@ def get_delta(var, track_ids=None, var_units=None, centering="forward"):
30
35
# Curate input
31
36
# If track_id is not provided, all points are considered to belong to the same track
32
37
if track_ids is None :
33
- track_ids = xr . DataArray ([ 0 ] * len ( var ), dims = var . dims )
38
+ track_ids = np . zeros ( var . shape )
34
39
warnings .warn (
35
40
"track_id is not provided, all points are considered to come from the same"
36
41
"track"
37
42
)
38
- ## If time is provided, convert to numeric ns
39
- if var .dtype == "<M8[ns]" :
40
- var = var .astype (float )
41
- var_units = "ns"
42
- ## Check that centering is supported
43
+
44
+ # Check that centering is supported
43
45
if centering not in ["forward" , "backward" ]:
44
46
raise ValueError ("centering must be one of ['forward', 'backward']" )
45
47
46
48
# Compute delta
47
49
delta = var [1 :] - var [:- 1 ]
48
50
49
51
# Mask points where track_id changes
50
- tid_switch = track_ids [1 :] == track_ids [:- 1 ]
51
- delta = delta .where (tid_switch )
52
+ # Multiplying np.nan by an array element gives us the correct type of nan for both
53
+ # np.timedelta and pint.Quantity
54
+ delta [track_ids [1 :] != track_ids [:- 1 ]] = np .nan * delta [0 ]
52
55
53
56
# Apply centering
54
57
if centering == "forward" :
55
- delta = xr . concat ([delta , xr . DataArray ( [np .nan ], dims = "record" )], dim = "record" )
58
+ delta = np . concatenate ([delta , [np .nan * delta [ 0 ]]] )
56
59
elif centering == "backward" :
57
- delta = xr .concat (
58
- [
59
- xr .DataArray ([np .nan ], dims = "record" ),
60
- delta ,
61
- ],
62
- dim = "record" ,
63
- )
60
+ delta = np .concatenate ([[np .nan * delta [0 ]], delta ])
64
61
65
- # return with units # TODO: If var has units, retrieve those
66
- if var_units is None :
67
- return xr .DataArray (delta , dims = var .dims )
68
- else :
69
- return xr .DataArray (delta , dims = var .dims ) * units (var_units )
62
+ # Fix for timedeltas
63
+ if np .issubdtype (delta .magnitude .dtype , np .timedelta64 ):
64
+ delta = delta / np .timedelta64 (1 , "s" )
65
+ delta = delta .magnitude * units ("s" )
70
66
67
+ return delta
71
68
72
- def get_rate (var , time , track_ids = None , var_units = None , centering = "forward" ):
69
+
70
+ @dequantify_results
71
+ @preprocess_and_wrap (wrap_like = "var" )
72
+ def get_rate (var , time , track_ids = None , centering = "forward" ):
73
73
"""Compute rate of change of var, without including differences between the end
74
74
and start of different tracks
75
75
@@ -78,7 +78,6 @@ def get_rate(var, time, track_ids=None, var_units=None, centering="forward"):
78
78
var : xarray.DataArray
79
79
time : xarray.DataArray
80
80
track_ids : array_like, optional
81
- var_units : str, optional
82
81
centering : str, optional
83
82
84
83
Returns
@@ -101,8 +100,7 @@ def get_rate(var, time, track_ids=None, var_units=None, centering="forward"):
101
100
# TODO: If var has units, retrieve those
102
101
103
102
# Compute deltas
104
- dx = get_delta (var , track_ids , var_units = var_units , centering = centering )
103
+ dx = get_delta (var , track_ids , centering = centering )
105
104
dt = get_delta (time , track_ids , centering = centering )
106
- dt = dt .metpy .convert_units ("s" ) # Convert to seconds
107
105
108
106
return dx / dt
0 commit comments