diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml index f16a4a30..6e7a4cdb 100644 --- a/neural_lam/data_config.yaml +++ b/neural_lam/data_config.yaml @@ -1,64 +1,70 @@ dataset: - name: meps_example + name: cosmo + + # Value definitions var_names: - - pres_0g - - pres_0s - - nlwrs_0 - - nswrs_0 - - r_2 - - r_65 - - t_2 - - t_65 - - t_500 - - t_850 - - u_65 - - u_850 - - v_65 - - v_850 - - wvint_0 - - z_1000 - - z_500 + - "T" + - "U" + - "V" + - "RELHUM" + - "PMSL" + - "PP" var_units: - - Pa - - Pa - - r"$\mathrm{W}/\mathrm{m}^2$" - - r"$\mathrm{W}/\mathrm{m}^2$" - - "" - - "" - - K - - K - K - - K - - m/s - m/s - m/s - - m/s - - r"$\mathrm{kg}/\mathrm{m}^2$" - - r"$\mathrm{m}^2/\mathrm{s}^2$" - - r"$\mathrm{m}^2/\mathrm{s}^2$" + - Perc. + - Pa + - hPa var_longnames: - - pres_heightAboveGround_0_instant - - pres_heightAboveSea_0_instant - - nlwrs_heightAboveGround_0_accum - - nswrs_heightAboveGround_0_accum - - r_heightAboveGround_2_instant - - r_hybrid_65_instant - - t_heightAboveGround_2_instant - - t_hybrid_65_instant - - t_isobaricInhPa_500_instant - - t_isobaricInhPa_850_instant - - u_hybrid_65_instant - - u_isobaricInhPa_850_instant - - v_hybrid_65_instant - - v_isobaricInhPa_850_instant - - wvint_entireAtmosphere_0_instant - - z_isobaricInhPa_1000_instant - - z_isobaricInhPa_500_instant + - "Temperature" + - "Zonal wind component" + - "Meridional wind component" + - "Relative humidity" + - "Pressure at Mean Sea Level" + - "Pressure Perturbation" + var_is_3d: + - 1 + - 1 + - 1 + - 1 + - 0 + - 1 + grib_names: + PP: "pres" + QV: "q" + RELHUM: "r" + T: "t" + U: "u" + V: "v" + W: "wz" + CLCT: "ccl" + PMSL: "prmsl" + PS: "sp" + T_2M: "2t" + TOT_PREC: "tp" + U_10M: "10u" + V_10M: "10v" + vertical_levels: [1, 5, 13, 22, 38, 41, 60] num_forcing_features: 16 -grid_shape_state: [268, 238] -projection: - class: LambertConformal - kwargs: - central_longitude: 15.0 - central_latitude: 63.3 - standard_parallels: [63.3, 63.3] + + # Plotting + eval_plot_vars: ["TQV"] + grid_shape_state: [390, 582] + projection: + class: LambertConformal + kwargs: + central_longitude: 15.0 + central_latitude: 63.3 + standard_parallels: [63.3, 63.3] + sample_grib: + "templates/lfff02180000" + sample_z_grib: + "templates/lfff02180000z" + eval_datetime: + ["2020100215"] + + # Time step prediction during training / prediction (eval) + train_horizon: 3 + eval_horizon: 25 + diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 29b169d4..1b4642f0 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -1,10 +1,15 @@ # Standard library +from datetime import datetime, timedelta +import glob import os # Third-party +import earthkit.data +import imageio import matplotlib.pyplot as plt import numpy as np import pytorch_lightning as pl +from pytorch_lightning.utilities import rank_zero_only import torch import wandb @@ -93,6 +98,20 @@ def __init__(self, args): # For storing spatial loss maps during evaluation self.spatial_loss_maps = [] + self.inference_output = [] + "Storage for the output of individual inference steps" + + self.variable_indices = self.pre_compute_variable_indices() + "Index mapping of variable names to their levels in the array." + self.selected_vars_units = [ + (var_name, var_unit) + for var_name, var_unit in zip( + self.config_loader.dataset.var_names, + self.config_loader.dataset.var_units, + ) + if var_name in self.config_loader.dataset.eval_plot_vars + ] + def configure_optimizers(self): opt = torch.optim.AdamW( self.parameters(), lr=self.args.lr, betas=(0.9, 0.95) @@ -106,6 +125,34 @@ def interior_mask_bool(self): """ return self.interior_mask[:, 0].to(torch.bool) + def pre_compute_variable_indices(self): + """ + Pre-compute indices for each variable in the input tensor + """ + variable_indices = {} + all_vars = [] + index = 0 + # Create a list of tuples for all variables, using level 0 for 2D + # variables + for var_name in self.config_loader.dataset.var_names: + if self.config_loader.dataset.var_is_3d: + for level in self.config_loader.dataset.vertical_levels: + all_vars.append((var_name, level)) + else: + all_vars.append((var_name, 0)) # Use level 0 for 2D variables + + # Sort the variables based on the tuples + sorted_vars = sorted(all_vars) + + for var in sorted_vars: + var_name, level = var + if var_name not in variable_indices: + variable_indices[var_name] = [] + variable_indices[var_name].append(index) + index += 1 + + return variable_indices + @staticmethod def expand_to_batch(x, batch_size): """ @@ -113,7 +160,7 @@ def expand_to_batch(x, batch_size): """ return x.unsqueeze(0).expand(batch_size, -1, -1) - def predict_step(self, prev_state, prev_prev_state, forcing): + def single_prediction(self, prev_state, prev_prev_state, forcing): """ Step state one step ahead using prediction model, X_{t-1}, X_t -> X_t+1 prev_state: (B, num_grid_nodes, feature_dim), X_t @@ -122,6 +169,48 @@ def predict_step(self, prev_state, prev_prev_state, forcing): """ raise NotImplementedError("No prediction step implemented") + def predict_step(self, batch, batch_idx): + """ + Run the inference on batch. + """ + prediction, target, pred_std = self.common_step(batch) + + # Compute all evaluation metrics for error maps + # Note: explicitly list metrics here, as test_metrics can contain + # additional ones, computed differently, but that should be aggregated + # on_predict_epoch_end + for metric_name in ("mse", "mae"): + metric_func = metrics.get_metric(metric_name) + batch_metric_vals = metric_func( + prediction, + target, + pred_std, + mask=self.interior_mask_bool, + sum_vars=False, + ) # (B, pred_steps, d_f) + self.test_metrics[metric_name].append(batch_metric_vals) + + if self.output_std: + # Store output std. per variable, spatially averaged + mean_pred_std = torch.mean( + pred_std[..., self.interior_mask_bool, :], dim=-2 + ) # (B, pred_steps, d_f) + self.test_metrics["output_std"].append(mean_pred_std) + + # Save per-sample spatial loss for specific times + spatial_loss = self.loss( + prediction, target, pred_std, average_grid=False + ) # (B, pred_steps, num_grid_nodes) + log_spatial_losses = spatial_loss[ + :, [step - 1 for step in self.args.val_steps_to_log] + ] + self.spatial_loss_maps.append(log_spatial_losses) + # (B, N_log, num_grid_nodes) + + if self.trainer.global_rank == 0: + self.plot_examples(batch, batch_idx, prediction=prediction) + self.inference_output.append(prediction) + def unroll_prediction(self, init_states, forcing_features, true_states): """ Roll out prediction taking multiple autoregressive steps with model @@ -139,7 +228,7 @@ def unroll_prediction(self, init_states, forcing_features, true_states): forcing = forcing_features[:, i] border_state = true_states[:, i] - pred_state, pred_std = self.predict_step( + pred_state, pred_std = self.single_prediction( prev_state, prev_prev_state, forcing ) # state: (B, num_grid_nodes, d_f) @@ -345,20 +434,50 @@ def test_step(self, batch, batch_idx): batch, n_additional_examples, prediction=prediction ) - def plot_examples(self, batch, n_examples, prediction=None): + @rank_zero_only + def plot_examples(self, batch, n_examples, batch_idx: int, prediction=None): """ - Plot the first n_examples forecasts from batch - - batch: batch with data to plot corresponding forecasts for - n_examples: number of forecasts to plot - prediction: (B, pred_steps, num_grid_nodes, d_f), existing prediction. - Generate if None. + Plot the first n_examples forecasts from batch. + + The function checks for the presence of test_dataset or + predict_dataset within the trainer's data module, + handles indexing within the batch for targeted analysis, + performs prediction rescaling, and plots results. + + Parameters: + - batch: batch with data to plot corresponding forecasts for + - n_examples: number of forecasts to plot + - batch_idx (int): index of the batch being processed + - prediction: (B, pred_steps, num_grid_nodes, d_f), existing prediction. + Generate if None. """ if prediction is None: prediction, target = self.common_step(batch) target = batch[1] + # Determine the dataset to work with (test_dataset or predict_dataset) + dataset = None + if ( + hasattr(self.trainer.datamodule, "test_dataset") + and self.trainer.datamodule.test_dataset + ): + dataset = self.trainer.datamodule.test_dataset + plot_name = "test" + elif ( + hasattr(self.trainer.datamodule, "predict_dataset") + and self.trainer.datamodule.predict_dataset + ): + dataset = self.trainer.datamodule.predict_dataset + plot_name = "prediction" + + if ( + dataset + and self.trainer.global_rank == 0 + and dataset.batch_index == batch_idx + ): + index_within_batch = dataset.index_within_batch + # Rescale to original data scale prediction_rescaled = prediction * self.data_std + self.data_mean target_rescaled = target * self.data_std + self.data_mean @@ -415,7 +534,7 @@ def plot_examples(self, batch, n_examples, prediction=None): example_i = self.plotted_examples wandb.log( { - f"{var_name}_example_{example_i}": wandb.Image(fig) + f"{var_name}_{plot_name}_{example_i}": wandb.Image(fig) for var_name, fig in zip( self.config_loader.dataset.var_names, var_figs ) @@ -573,6 +692,144 @@ def on_test_epoch_end(self): self.spatial_loss_maps.clear() + @rank_zero_only + def on_predict_epoch_end(self): + """ + Compute test metrics and make plots at the end of test epoch. + Will gather stored tensors and perform plotting and logging on rank 0. + """ + + plot_dir_path = f"{wandb.run.dir}/media/images" + value_dir_path = f"{wandb.run.dir}/results/inference" + # Ensure the directory for saving numpy arrays exists + os.makedirs(plot_dir_path, exist_ok=True) + os.makedirs(value_dir_path, exist_ok=True) + + # For values + for i, prediction in enumerate(self.inference_output): + + # Rescale to original data scale + prediction_rescaled = prediction * self.data_std + self.data_mean + + # Process and save the prediction + prediction_array = prediction_rescaled.cpu().numpy() + file_path = os.path.join(value_dir_path, f"prediction_{i}.npy") + np.save(file_path, prediction_array) + self.save_pred_as_grib(file_path, value_dir_path) + + dir_path = f"{wandb.run.dir}/media/images" + for var_name, _ in self.selected_vars_units: + var_indices = self.variable_indices[var_name] + for lvl_i, _ in enumerate(var_indices): + # Calculate var_vrange for each index + lvl = self.config_loader.dataset.vertical_levels[lvl_i] + + # Get all the images for the current variable and index + images = sorted( + glob.glob( + f"{dir_path}/{var_name}_test_lvl_{lvl:02}_t_*.png" + ) + ) + # Generate the GIF + with imageio.get_writer( + f"{dir_path}/{var_name}_lvl_{lvl:02}.gif", + mode="I", + fps=1, + ) as writer: + for filename in images: + image = imageio.imread(filename) + writer.append_data(image) + + self.spatial_loss_maps.clear() + + def _generate_time_steps(self): + """Generate a list with all time steps in inference.""" + # Parse the times + base_time = self.config_loader.dataset.eval_datetime[0] + + if isinstance(base_time, str): + base_time = datetime.strptime(base_time, "%Y%m%d%H") + time_steps = {} + # Generate dates for each step + for i in range(self.config_loader.dataset.eval_horizon - 2): + # Compute the new date by adding the step interval in hours - 3 + new_date = base_time + timedelta(hours=i * self.config_loader.dataset.train_horizon) + # Format the date back + time_steps[i] = new_date.strftime("%Y%m%d%H") + + def save_pred_as_grib(self, file_path: str, value_dir_path: str): + """Save the prediction values into GRIB format.""" + # Initialize the lists to loop over + indices = self.precompute_variable_indices() + time_steps = self._generate_time_steps() + # Loop through all the time steps and all the variables + for time_idx, date_str in time_steps.items(): + # Initialize final data object + final_data = earthkit.data.FieldList() + for variable, grib_code in self.config_loader.dataset.grib_names.items(): + # here find the key of the cariable in constants.is_3D + # and if == 7, assign a cut of 7 on the reshape. Else 1 + if self.config_loader.dataset.var_is_3d[variable]: + shape_val = len(self.config_loader.dataset.vertical_levels) + vertical = self.config_loader.dataset.vertical_levels + else: + # Special handling for T_2M and *_10M variables + if variable == "T_2M": + shape_val = 1 + vertical = 2 + elif variable.endswith("_10M"): + shape_val = 1 + vertical = 10 + else: + shape_val = 1 + vertical = 0 + # Find the value range to sample + value_range = indices[variable] + + sample_file = self.config_loader.dataset.sample_grib + if variable == "RELHUM": + variable = "r" + sample_file = self.config_loader.dataset.sample_z_grib + + # Load the sample grib file + original_data = earthkit.data.from_source("file", sample_file) + + subset = original_data.sel(shortName=grib_code, level=vertical) + md = subset.metadata() + + # Cut the datestring into date and time and then override all + # values in md + date = date_str[:8] + time = date_str[8:] + + for index, item in enumerate(md): + md[index] = item.override({"date": date}).override( + {"time": time} + ) + if len(md) > 0: + # Load the array to replace the values with + replacement_data = np.load(file_path) + original_cut = replacement_data[ + 0, time_idx, :, min(value_range) : max(value_range) + 1 + ].reshape( + self.config_loader.dataset.grib_shape_state[1], + self.config_loader.dataset.grib_shape_state[0], + shape_val, + ) + cut_values = np.moveaxis( + original_cut, [-3, -2, -1], [-1, -2, -3] + ) + # Can we stack Fieldlists? + data_new = earthkit.data.FieldList.from_array( + cut_values, md + ) + final_data += data_new + # Create the modified GRIB file with the predicted data + grib_path = os.path.join( + value_dir_path, f"prediction_{date_str}_grib" + ) + final_data.save(grib_path) + def on_load_checkpoint(self, checkpoint): """ Perform any changes to state dict before loading checkpoint diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py index 256d4adc..dbe15a02 100644 --- a/neural_lam/models/base_graph_model.py +++ b/neural_lam/models/base_graph_model.py @@ -98,7 +98,7 @@ def process_step(self, mesh_rep): """ raise NotImplementedError("process_step not implemented") - def predict_step(self, prev_state, prev_prev_state, forcing): + def single_prediction(self, prev_state, prev_prev_state, forcing): """ Step state one step ahead using prediction model, X_{t-1}, X_t -> X_t+1 prev_state: (B, num_grid_nodes, feature_dim), X_t diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index a782806b..686bffcd 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -35,7 +35,7 @@ def __init__( ): super().__init__() - assert split in ("train", "val", "test"), "Unknown dataset split" + assert split in ("train", "val", "test", "pred"), "Unknown dataset split" self.sample_dir_path = os.path.join( "data", dataset_name, "samples", split ) diff --git a/requirements.txt b/requirements.txt index f381d54f..0c1b09fd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,6 +10,8 @@ Cartopy>=0.22.0 pyproj>=3.4.1 tueplots>=0.0.8 plotly>=5.15.0 +earthkit-data>=0.7.0 +eccodes>=1.7.0 # for dev pre-commit>=2.15.0 diff --git a/train_model.py b/train_model.py index fe064384..2cfed246 100644 --- a/train_model.py +++ b/train_model.py @@ -217,6 +217,7 @@ def main(): None, "val", "test", + "predict", ), f"Unknown eval setting: {args.eval}" # Get an (actual) random run id as a unique identifier @@ -294,6 +295,7 @@ def main(): callbacks=[checkpoint_callback], check_val_every_n_epoch=args.val_interval, precision=args.precision, + limit_predict_batches=1 ) # Only init once, on rank 0 only @@ -305,7 +307,7 @@ def main(): if args.eval: if args.eval == "val": eval_loader = val_loader - else: # Test + elif args.eval == "test": eval_loader = torch.utils.data.DataLoader( WeatherDataset( config_loader.dataset.name, @@ -318,9 +320,34 @@ def main(): shuffle=False, num_workers=args.n_workers, ) + elif args.eval == "predict": + pred_loader = torch.utils.data.DataLoader( + WeatherDataset( + config_loader.dataset.name, + pred_length=max_pred_length, + split="predict", + subsample_step=args.step_length, + subset=bool(args.subset_ds), + ), + args.batch_size, + shuffle=False, + num_workers=args.n_workers, + ) + print(f"Running prediction on {args.eval}") + trainer.predict( + model=model, + dataloaders=pred_loader, + return_predictions=True, + ckpt_path=args.load, + ) + else: + print(f"Unknown evaluation mode: {args.eval}") + raise ValueError(f"Unknown evaluation mode: {args.eval}") - print(f"Running evaluation on {args.eval}") - trainer.test(model=model, dataloaders=eval_loader, ckpt_path=args.load) + if args.eval in ["val", "test"]: + print(f"Running evaluation on {args.eval}") + trainer.test(model=model, dataloaders=eval_loader, ckpt_path=args.load) + else: # Train model trainer.fit(