plot_weatherbench.py
import matplotlib.pyplot as plt
import torch
# import and set up the typeguard
from typeguard.importhook import install_import_hook
try:
    import cartopy.crs as ccrs
except ModuleNotFoundError:
    pass
# comment these out when deploying:
install_import_hook('src.nn')
install_import_hook('src.scoring_rules')
install_import_hook('src.utils')
install_import_hook('src.parsers')
install_import_hook('src.calibration')
install_import_hook('src.weatherbench_utils')
from src.nn import ConditionalGenerativeModel, DiscardWindowSizeDim, UNet2D, DiscardNumberGenerationsInOutput
from src.utils import load_net
from src.parsers import parser_plot_weatherbench, setup
from src.weatherbench_utils import load_weatherbench_data, convert_tensor_to_da, plot_map_ax
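# This script loads a trained regression or generative network for the WeatherBench Z500 task,
# predicts for a single date of the test set and saves maps of the absolute fields, of the
# differences from the realization and (for generative models) of the ensemble mean and standard deviation.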
parser = parser_plot_weatherbench()
args = parser.parse_args()
method = args.method
scoring_rule = args.scoring_rule
kernel = args.kernel
patched = args.patched
base_measure = args.base_measure
root_folder = args.root_folder
model_folder = args.model_folder
datasets_folder = args.datasets_folder
weatherbench_data_folder = args.weatherbench_data_folder
weatherbench_small = args.weatherbench_small
unet_noise_method = args.unet_noise_method
unet_large = args.unet_large
lr = args.lr
lr_c = args.lr_c
batch_size = args.batch_size
no_early_stop = args.no_early_stop
critic_steps_every_generator_step = args.critic_steps_every_generator_step
save_plots = not args.no_save_plots
cuda = args.cuda
load_all_data_GPU = args.load_all_data_GPU
training_ensemble_size = args.training_ensemble_size
prediction_ensemble_size = args.prediction_ensemble_size
nonlinearity = args.nonlinearity
data_size = args.data_size
auxiliary_var_size = args.auxiliary_var_size
seed = args.seed
patch_size = args.patch_size
date = args.date
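# `date` is combined with a 12:00 timestamp below, so it should be an ISO day string, e.g. "2017-01-01".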
save_pdf = True
# notice this assumes the WeatherBench dataset is considered in the daily setup.
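# setup() resolves the data/nets folders and the architecture hyperparameters (data size, UNet depths,
# patch size, ...) corresponding to the chosen method and training options.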
datasets_folder, nets_folder, data_size, auxiliary_var_size, name_postfix, unet_depths, patch_size, method_is_gan, hidden_size_rnn = \
    setup("WeatherBench", root_folder, model_folder, datasets_folder, data_size, method, scoring_rule, kernel, patched,
          patch_size, training_ensemble_size, auxiliary_var_size, critic_steps_every_generator_step, base_measure, lr,
          lr_c, batch_size, no_early_stop, unet_noise_method, unet_large, "unet", None)
string = f"Plot WeatherBench results with {method}"
if not method_is_gan and method != "regression":
    string += f" using {scoring_rule} scoring rule"
print(string)
dataset_train, dataset_val, dataset_test = load_weatherbench_data(weatherbench_data_folder, cuda, load_all_data_GPU,
                                                                  return_test=True,
                                                                  weatherbench_small=weatherbench_small)
print("Validation set size:", len(dataset_val))
print("Test set size:", len(dataset_test))
if method == "regression":
net_class = UNet2D
unet_kwargs = {"in_channels": data_size[0], "out_channels": 1,
"noise_method": "no noise", "conv_depths": unet_depths}
net = DiscardWindowSizeDim(net_class(**unet_kwargs))
net = load_net(nets_folder + f"net{name_postfix}.pth", DiscardNumberGenerationsInOutput, net).net
else: # SR and GAN
# create generative net:
inner_net = UNet2D(in_channels=data_size[0], out_channels=1, noise_method=unet_noise_method,
number_generations_per_forward_call=prediction_ensemble_size, conv_depths=unet_depths)
if unet_noise_method in ["sum", "concat"]:
# here we overwrite the auxiliary_var_size above, as there is a precise constraint
downsampling_factor, n_channels = inner_net.calculate_downsampling_factor()
if weatherbench_small:
auxiliary_var_size = torch.Size(
[n_channels, 16 // downsampling_factor, 16 // downsampling_factor])
else:
auxiliary_var_size = torch.Size(
[n_channels, data_size[1] // downsampling_factor, data_size[2] // downsampling_factor])
net = load_net(nets_folder + f"net{name_postfix}.pth", ConditionalGenerativeModel, inner_net,
size_auxiliary_variable=auxiliary_var_size, base_measure=base_measure,
number_generations_per_forward_call=prediction_ensemble_size, seed=seed + 1)
elif unet_noise_method == "dropout":
net = load_net(nets_folder + f"net{name_postfix}.pth", DiscardWindowSizeDim, inner_net)
if cuda:
    net.cuda()
# predict for a given date and create the plot
with torch.no_grad():
    # obtain the target and context for the specified timestring
    timestring = date + "T12:00:00.000000000"
    context, realization = dataset_test.select_time(timestring)
    # predict the realization with the context:
    prediction = net(context.unsqueeze(1)).cpu()  # should specify how many we want

    # compute mean and standard deviation of the predictions:
    prediction_mean = prediction[0].mean(dim=0)
    prediction_std = prediction[0].std(dim=0)
    da_prediction_mean = convert_tensor_to_da(prediction_mean, realization)
    da_prediction_std = convert_tensor_to_da(prediction_std, realization)

    # prediction is shape ("batch_size", "number_generations", "height", "width", "fields"). Batch size should be 1.
    # convert to an xarray DataArray:
    da_predictions = []
    for i in range(prediction.shape[1]):
        da_predictions.append(convert_tensor_to_da(prediction[0, i], realization))
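    # each entry of da_predictions is one predicted field as an xarray DataArray, ready for plot_map_ax below.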
    if save_plots:
        global_projection = False
        # we do plots with 5 predictions if not deterministic
        if method == "regression":
            n_predictions_for_plots = 1
            kwargs_subplots = dict(ncols=2, nrows=1, figsize=(16 * 2.0 / 3, 2.5),
                                   subplot_kw=dict(projection=ccrs.PlateCarree(), facecolor="gray"))
        else:
            n_predictions_for_plots = 5
            kwargs_subplots = dict(ncols=3, nrows=2, figsize=(16, 4.5),
                                   subplot_kw=dict(projection=ccrs.PlateCarree(), facecolor="gray"))
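        # regression thus gets a 1x2 grid (realization + the single prediction), while generative
        # methods get a 2x3 grid (realization + 5 ensemble members).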

        # --- plot the absolute values ---
        fig, axes = plt.subplots(**kwargs_subplots)
        # need to find max and min values over all graphs to have coherent colorbars.
        vmax = max([prediction.max().detach().numpy(), realization.max()])
        vmin = min([prediction.min().detach().numpy(), realization.min()])
        # plot both the realization and the prediction:
        p_real = plot_map_ax(realization[:, :, 0], title="Realization", ax=axes.flatten()[0],
                             global_projection=global_projection, vmax=vmax, vmin=vmin)
        for i in range(n_predictions_for_plots):
            p_pred = plot_map_ax(da_predictions[i][:, :, 0],
                                 title="Prediction" + (f" {i + 1}" if method != "regression" else ""),
                                 ax=axes.flatten()[i + 1], global_projection=global_projection, vmax=vmax, vmin=vmin)
        # add now the colorbar:
        fig.subplots_adjust(right=0.9)
        cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
        fig.colorbar(p_pred, cax=cbar_ax)
        fig.suptitle("Z500, " + date, size=20)
        plt.savefig(nets_folder + f"map_absolute{name_postfix}." + ("pdf" if save_pdf else "png"))

        # --- plot the differences from the realization ---
        differences = [da_predictions[i] - realization for i in range(n_predictions_for_plots)]
        fig, axes = plt.subplots(**kwargs_subplots)
        # need to find max and min values over all graphs to have coherent colorbars.
        vmax = max([differences[i].max() for i in range(n_predictions_for_plots)])
        vmin = min([differences[i].min() for i in range(n_predictions_for_plots)])
        for i in range(n_predictions_for_plots):
            p_pred = plot_map_ax(differences[i][:, :, 0],
                                 title="Prediction" + (f" {i + 1}" if method != "regression" else ""),
                                 ax=axes.flatten()[i + 1], global_projection=global_projection, vmax=vmax, vmin=vmin)
        # add now the colorbar:
        fig.subplots_adjust(right=0.9)
        cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
        fig.colorbar(p_pred, cax=cbar_ax)
        fig.suptitle("Z500, predictions - realization, " + date, size=20)
        plt.savefig(nets_folder + f"map_differences{name_postfix}." + ("pdf" if save_pdf else "png"))
if method != "regression":
# --- plot the differences with respect to ensemble mean ---
differences = [da_predictions[i] - da_prediction_mean for i in range(n_predictions_for_plots)]
realization_diff = realization - da_prediction_mean
fig, axes = plt.subplots(**kwargs_subplots)
# need to find max and min values over all graphs to have coherent colorbars.
vmax = max([differences[i].max() for i in range(n_predictions_for_plots)])
vmin = min([differences[i].min() for i in range(n_predictions_for_plots)])
p_real = plot_map_ax(realization_diff[:, :, 0], title="Realization", ax=axes[0, 0],
global_projection=global_projection, vmax=vmax, vmin=vmin)
for i in range(n_predictions_for_plots):
p_pred = plot_map_ax(differences[i][:, :, 0], title=f"Prediction {i + 1}",
ax=axes.flatten()[i + 1], global_projection=global_projection, vmax=vmax,
vmin=vmin)
# add now the colorbar:
fig.subplots_adjust(right=0.9)
cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
fig.colorbar(p_pred, cax=cbar_ax)
fig.suptitle("Z500, Centered in mean prediction, " + date, size=20)
plt.savefig(nets_folder + f"map_differences_ens_mean{name_postfix}." + ("pdf" if save_pdf else "png"))
# --- plot the ensemble mean and std ---
fig, axes = plt.subplots(ncols=2, nrows=1, figsize=(16 * 2.0 / 3, 3),
subplot_kw=dict(projection=ccrs.PlateCarree(), facecolor="gray"))
p_real = plot_map_ax(da_prediction_mean[:, :, 0], title="Mean", ax=axes[0],
global_projection=global_projection)
p_pred = plot_map_ax(da_prediction_std[:, :, 0], title=f"Standard deviation",
ax=axes[1], global_projection=global_projection)
fig.suptitle("Z500, Prediction mean and standard deviation, " + date, size=20)
plt.savefig(nets_folder + f"map_differences_mean_std{name_postfix}." + ("pdf" if save_pdf else "png"))