1
+ import numpy as np
2
+ import xarray as xr
3
+
4
+ from util import *
5
+
6
+ metadata_attrs = {
7
+ 'u' : {
8
+ 'units' : 'm/s'
9
+ },
10
+ 'w' : {
11
+ 'units' : 'm/s'
12
+ },
13
+ 'theta_p' : {
14
+ 'units' : 'K'
15
+ },
16
+ 'pi' : {
17
+ 'units' : 'dimensionless'
18
+ },
19
+ 'x' : {
20
+ 'units' : 'm'
21
+ },
22
+ 'x_stag' : {
23
+ 'units' : 'm'
24
+ },
25
+ 'z' : {
26
+ 'units' : 'm'
27
+ },
28
+ 'z_stag' : {
29
+ 'units' : 'm'
30
+ },
31
+ 't' : {
32
+ 'units' : 's'
33
+ }
34
+ }
35
+
36
+ class ModelDriver :
37
+
38
+ coords = {}
39
+ prognostic_arrays = {}
40
+ base_state_arrays = {}
41
+ diagnostic_arrays = {}
42
+ params = {}
43
+
44
+ def __init__ (self , nx , nz , dx , dz , dt , ** kwargs ):
45
+ # Set parameters
46
+ self .nx = nx
47
+ self .nz = nz
48
+ self .dx = dx
49
+ self .dz = dz
50
+ self .dt = dt
51
+ for k , v in kwargs .items ():
52
+ if k .endswith ('_tendency' ):
53
+ setattr (self , k , v )
54
+ else :
55
+ self .params [k ] = v
56
+ self .dtype = dtype = getattr (self , 'dtype' , np .float32 )
57
+ self .t_count = 0
58
+
59
+ # Define arrays
60
+ self .coords ['x' ] = np .arange (self .nx ) * self .dx - self .nx * self .dx / 2
61
+ self .coords ['x_stag' ] = np .arange (self .nx + 1 ) * self .dx - (self .nx + 1 ) * self .dx / 2
62
+ self .coords ['z' ] = np .arange (self .nz ) * self .dz
63
+ self .coords ['z_stag' ] = np .arange (self .nz + 1 ) * self .dz - self .dz / 2
64
+ self .prognostic_arrays ['u' ] = np .zeros ((3 , nz , nx + 1 ), dtype = dtype )
65
+ self .prognostic_arrays ['w' ] = np .zeros ((3 , nz + 1 , nx ), dtype = dtype )
66
+ self .prognostic_arrays ['theta_p' ] = np .zeros ((3 , nz , nx ), dtype = dtype )
67
+ self .prognostic_arrays ['pi' ] = np .zeros ((3 , nz , nx ), dtype = dtype )
68
+ self .active_prognostic_variables = ['u' , 'w' , 'theta_p' , 'pi' ]
69
+ self .base_state_arrays ['theta_base' ] = np .zeros (nz , dtype = dtype )
70
+ self .base_state_arrays ['PI_base' ] = np .zeros (nz , dtype = dtype )
71
+ self .base_state_arrays ['rho_base' ] = np .zeros (nz , dtype = dtype )
72
+ ## Todo do we need others??
73
+
74
+ def initialize_isentropic_base_state (self , theta , pressure_surface ):
75
+ # Set uniform potential temperature
76
+ self .base_state_arrays ['theta_base' ] = np .full (
77
+ self .base_state_arrays ['theta_base' ].shape , theta , dtype = self .dtype
78
+ )
79
+ # Calculate pi based on hydrostatic balance (from surface)
80
+ self .base_state_arrays ['PI_base' ] = nondimensional_pressure_hydrostatic (
81
+ self .base_state_arrays ['theta_base' ],
82
+ self .coords ['z' ],
83
+ pressure_surface
84
+ )
85
+ # Calculate density from theta and pi
86
+ self .base_state_arrays ['rho_base' ] = density_from_ideal_gas_law (
87
+ self .base_state_arrays ['theta_base' ],
88
+ self .base_state_arrays ['PI_base' ]
89
+ )
90
+
91
+ def initialize_warm_bubble (self , amplitude , x_radius , z_radius , z_center ):
92
+ if np .min (self .base_state_arrays ['theta_base' ]) <= 0. :
93
+ raise ValueError ("Base state theta must be initialized as positive definite" )
94
+
95
+ # Create thermal bubble (2D)
96
+ theta_p , pi = create_thermal_bubble (
97
+ amplitude , self .coords ['x' ], self .coords ['z' ], x_radius , z_radius , 0.0 , z_center ,
98
+ self .base_state_arrays ['theta_base' ]
99
+ )
100
+ # Ensure boundary conditions, and add time stacking (future, current, past)
101
+ self .prognostic_arrays ['theta_p' ] = np .stack ([apply_periodic_lateral_zerograd_vertical (theta_p )] * 3 )
102
+ self .prognostic_arrays ['pi' ] = np .stack ([apply_periodic_lateral_zerograd_vertical (pi )] * 3 )
103
+
104
+ def prep_new_timestep (self ):
105
+ for var in self .active_prognostic_variables :
106
+ # Future-current to current-past
107
+ self .prognostic_arrays [var ][0 :2 ] = self .prognostic_arrays [var ][1 :3 ]
108
+
109
+ def take_first_timestep (self ):
110
+ # check for needed parameters and methods
111
+ if not 'c_s_sqr' in self .params :
112
+ raise ValueError ("Must set squared speed of sound prior to first timestep" )
113
+ if not (
114
+ getattr (self , 'u_tendency' )
115
+ and getattr (self , 'w_tendency' )
116
+ and getattr (self , 'theta_p_tendency' )
117
+ and getattr (self , 'pi_tendency' )
118
+ ):
119
+ raise ValueError ("Must set tendency equations prior to first timestep" )
120
+
121
+ # Increment
122
+ self .t_count = 1
123
+
124
+ # Integrate forward-in-time
125
+ self .prognostic_arrays ['u' ][2 ] = (
126
+ self .prognostic_arrays ['u' ][1 ]
127
+ + self .dt * apply_periodic_lateral_zerograd_vertical (self .u_tendency (
128
+ self .prognostic_arrays ['u' ][1 ], self .prognostic_arrays ['w' ][1 ],
129
+ self .prognostic_arrays ['pi' ][1 ], self .base_state_arrays ['theta_base' ], self .dx , self .dz
130
+ ))
131
+ )
132
+ self .prognostic_arrays ['w' ][2 ] = (
133
+ self .prognostic_arrays ['w' ][1 ]
134
+ + self .dt * apply_periodic_lateral_zerow_vertical (self .w_tendency (
135
+ self .prognostic_arrays ['u' ][1 ], self .prognostic_arrays ['w' ][1 ],
136
+ self .prognostic_arrays ['pi' ][1 ], self .prognostic_arrays ['theta_p' ][1 ],
137
+ self .base_state_arrays ['theta_base' ], self .dx , self .dz
138
+ ))
139
+ )
140
+ self .prognostic_arrays ['theta_p' ][2 ] = (
141
+ self .prognostic_arrays ['theta_p' ][1 ]
142
+ + self .dt * apply_periodic_lateral_zerograd_vertical (self .theta_p_tendency (
143
+ self .prognostic_arrays ['u' ][1 ], self .prognostic_arrays ['w' ][1 ],
144
+ self .prognostic_arrays ['theta_p' ][1 ], self .base_state_arrays ['theta_base' ], self .dx , self .dz
145
+ ))
146
+ )
147
+ self .prognostic_arrays ['pi' ][2 ] = (
148
+ self .prognostic_arrays ['pi' ][1 ]
149
+ + self .dt * apply_periodic_lateral_zerograd_vertical (self .pi_tendency (
150
+ self .prognostic_arrays ['u' ][1 ], self .prognostic_arrays ['w' ][1 ],
151
+ self .prognostic_arrays ['pi' ][1 ], self .base_state_arrays ['theta_base' ],
152
+ self .base_state_arrays ['rho_base' ], self .params ['c_s_sqr' ], self .dx , self .dz
153
+ ))
154
+ )
155
+
156
+ self .prep_new_timestep ()
157
+
158
+ def take_single_timestep (self ):
159
+ # Check if initialized
160
+ if self .t_count == 0 :
161
+ raise RuntimeError ("Must run initial timestep!" )
162
+ self .t_count += 1
163
+
164
+ # Integrate leapfrog
165
+ self .prognostic_arrays ['u' ][2 ] = (
166
+ self .prognostic_arrays ['u' ][0 ]
167
+ + 2 * self .dt * apply_periodic_lateral_zerograd_vertical (self .u_tendency (
168
+ self .prognostic_arrays ['u' ][1 ], self .prognostic_arrays ['w' ][1 ],
169
+ self .prognostic_arrays ['pi' ][1 ], self .base_state_arrays ['theta_base' ], self .dx , self .dz
170
+ ))
171
+ )
172
+ self .prognostic_arrays ['w' ][2 ] = (
173
+ self .prognostic_arrays ['w' ][0 ]
174
+ + 2 * self .dt * apply_periodic_lateral_zerow_vertical (self .w_tendency (
175
+ self .prognostic_arrays ['u' ][1 ], self .prognostic_arrays ['w' ][1 ],
176
+ self .prognostic_arrays ['pi' ][1 ], self .prognostic_arrays ['theta_p' ][1 ],
177
+ self .base_state_arrays ['theta_base' ], self .dx , self .dz
178
+ ))
179
+ )
180
+ self .prognostic_arrays ['theta_p' ][2 ] = (
181
+ self .prognostic_arrays ['theta_p' ][0 ]
182
+ + 2 * self .dt * apply_periodic_lateral_zerograd_vertical (self .theta_p_tendency (
183
+ self .prognostic_arrays ['u' ][1 ], self .prognostic_arrays ['w' ][1 ],
184
+ self .prognostic_arrays ['theta_p' ][1 ], self .base_state_arrays ['theta_base' ], self .dx , self .dz
185
+ ))
186
+ )
187
+ self .prognostic_arrays ['pi' ][2 ] = (
188
+ self .prognostic_arrays ['pi' ][0 ]
189
+ + 2 * self .dt * apply_periodic_lateral_zerograd_vertical (self .pi_tendency (
190
+ self .prognostic_arrays ['u' ][1 ], self .prognostic_arrays ['w' ][1 ],
191
+ self .prognostic_arrays ['pi' ][1 ], self .base_state_arrays ['theta_base' ],
192
+ self .base_state_arrays ['rho_base' ], self .params ['c_s_sqr' ], self .dx , self .dz
193
+ ))
194
+ )
195
+
196
+ self .prep_new_timestep ()
197
+
198
+ def integrate (self , n_steps ):
199
+ for _ in range (n_steps ):
200
+ self .take_single_timestep ()
201
+
202
+ def current_state (self ):
203
+ """Export the prognostic variables, with coordinates, at current time."""
204
+ data_vars = {}
205
+ for var in self .active_prognostic_variables :
206
+ if var == 'u' :
207
+ dims = ('t' , 'z' , 'x_stag' )
208
+ elif var == 'w' :
209
+ dims = ('t' , 'z_stag' , 'x' )
210
+ else :
211
+ dims = ('t' , 'z' , 'x' )
212
+ data_vars [var ] = xr .Variable (dims , self .prognostic_arrays [var ][1 :2 ].copy (), metadata_attrs [var ])
213
+ data_vars ['x' ] = xr .Variable ('x' , self .coords ['x' ], metadata_attrs ['x' ])
214
+ data_vars ['x_stag' ] = xr .Variable ('x_stag' , self .coords ['x_stag' ], metadata_attrs ['x_stag' ])
215
+ data_vars ['z' ] = xr .Variable ('z' , self .coords ['z' ], metadata_attrs ['z' ])
216
+ data_vars ['z_stag' ] = xr .Variable ('z_stag' , self .coords ['z_stag' ], metadata_attrs ['z_stag' ])
217
+ data_vars ['t' ] = xr .Variable ('t' , [self .t_count * self .dt ], metadata_attrs ['t' ])
218
+ return xr .Dataset (data_vars )
0 commit comments