Skip to content

Commit

Permalink
Merge pull request #102 from openclimatefix/multiple_nwp
Browse files Browse the repository at this point in the history
Allow multiple NWPs
#minor
  • Loading branch information
dfulu authored Dec 12, 2023
2 parents 5714fbf + 3737e48 commit fe30469
Show file tree
Hide file tree
Showing 54 changed files with 586 additions and 187 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ jobs:
sudo_apt_install: "libgeos++-dev libproj-dev proj-data proj-bin"
# brew_install: "proj geos librttopo"
os_list: '["ubuntu-latest"]'
python-version: "['3.10', '3.11']"
5 changes: 3 additions & 2 deletions configs/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@ defaults:
- experiment: null
- hparams_search: null
- hydra: default.yaml

renewable: "pv"

renewable:
"pv"

# enable color logging
# - override hydra/hydra_logging: colorlog
Expand Down
46 changes: 29 additions & 17 deletions configs/datamodule/configuration/gcp_configuration.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ input_data:
default_forecast_minutes: 480

gsp:
gsp_zarr_path: /mnt/disks/nwp/pv_gsp.zarr
gsp_zarr_path: /mnt/disks/nwp_rechunk/pv_gsp_temp.zarr
history_minutes: 120
forecast_minutes: 480
time_resolution_minutes: 30
Expand All @@ -16,8 +16,8 @@ input_data:
pv:
pv_files_groups:
- label: solar_sheffield_passiv
pv_filename: /mnt/disks/nwp/passive/v0/passiv.netcdf
pv_metadata_filename: /mnt/disks/nwp/passive/v0/system_metadata_OCF_ONLY.csv
pv_filename: /mnt/disks/nwp_rechunk/passive/v1.1/passiv.netcdf
pv_metadata_filename: /mnt/disks/nwp_rechunk/passive/v0/system_metadata_OCF_ONLY.csv
pv_ml_ids:
[
154,
Expand Down Expand Up @@ -375,23 +375,35 @@ input_data:
time_resolution_minutes: 5

nwp:
nwp_zarr_path: /mnt/disks/nwp_rechunk/UKV_intermediate_version_7.1.zarr
history_minutes: 120
forecast_minutes: 480
time_resolution_minutes: 60
nwp_channels:
- t # live = t2m
- dswrf
nwp_image_size_pixels_height: 24
nwp_image_size_pixels_width: 24
ukv:
nwp_zarr_path:
- /mnt/disks/nwp_rechunk/UKV_intermediate_version_7.1.zarr
- /mnt/disks/nwp_rechunk/UKV_2021_NWP_missing_chunked.zarr
- /mnt/disks/nwp_rechunk/UKV_2022_NWP_chunked.zarr
- /mnt/disks/nwp_rechunk/UKV_2023_chunked.zarr
history_minutes: 120
forecast_minutes: 480
time_resolution_minutes: 60
nwp_channels:
- t # live = t2m
- dswrf
#- lcc
#- mcc
#- hcc
#- dlwrf
nwp_image_size_pixels_height: 24
nwp_image_size_pixels_width: 24
nwp_provider: ukv

satellite:
satellite_zarr_path:
- /mnt/disks/sat/2017_nonhrv.zarr
- /mnt/disks/sat/2018_nonhrv.zarr
- /mnt/disks/sat/2019_nonhrv.zarr
- /mnt/disks/sat/2020_nonhrv.zarr
- /mnt/disks/sat/2021_nonhrv.zarr
- /mnt/disks/nwp_rechunk/filled_sat/2017_nonhrv.zarr
- /mnt/disks/nwp_rechunk/filled_sat/2018_nonhrv.zarr
- /mnt/disks/nwp_rechunk/filled_sat/2019_nonhrv.zarr
- /mnt/disks/nwp_rechunk/filled_sat/2020_nonhrv.zarr
- /mnt/disks/nwp_rechunk/filled_sat/2021_nonhrv.zarr
- /mnt/disks/nwp_rechunk/filled_sat/2022_nonhrv.zarr
- /mnt/disks/nwp_rechunk/filled_sat/2023_nonhrv.zarr
history_minutes: 90
forecast_minutes: 0
live_delay_minutes: 30
Expand Down
25 changes: 15 additions & 10 deletions configs/model/multimodal.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@ output_quantiles: [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98]
#--------------------------------------------

nwp_encoder:
_target_: pvnet.models.multimodal.encoders.encoders3d.DefaultPVNet
_partial_: True
in_channels: 2
out_features: 256
number_of_conv3d_layers: 6
conv3d_channels: 32
image_size_pixels: 24
ukv:
_target_: pvnet.models.multimodal.encoders.encoders3d.DefaultPVNet
_partial_: True
in_channels: 2
out_features: 256
number_of_conv3d_layers: 6
conv3d_channels: 32
image_size_pixels: 24

#--------------------------------------------
# Sat encoder settings
Expand Down Expand Up @@ -69,12 +70,16 @@ history_minutes: 120

min_sat_delay_minutes: 60

# --- set to null if same as history_minutes ---
# These must also be set even if identical to forecast_minutes and history_minutes
sat_history_minutes: 90
nwp_history_minutes: 120
nwp_forecast_minutes: 480
pv_history_minutes: 180

# These must be set for each NWP encoder
nwp_history_minutes:
ukv: 120
nwp_forecast_minutes:
ukv: 480

# ----------------------------------------------
# Optimizer
# ----------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion pvnet/data/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

import torch
from lightning.pytorch import LightningDataModule
from ocf_datapipes.batch import stack_np_examples_into_batch
from ocf_datapipes.training.pvnet import pvnet_datapipe
from ocf_datapipes.utils.utils import stack_np_examples_into_batch
from torch.utils.data import DataLoader
from torch.utils.data.datapipes.iter import FileLister

Expand Down
19 changes: 2 additions & 17 deletions pvnet/data/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Utils common between Wind and PV datamodules"""
import numpy as np
import torch
from ocf_datapipes.batch import unstack_np_batch_into_examples
from ocf_datapipes.utils.consts import BatchKey
from torch.utils.data import IterDataPipe, functional_datapipe

Expand All @@ -24,22 +25,6 @@ def batch_to_tensor(batch):
return batch


def split_batches(batch, splitting_key=BatchKey.gsp):
"""Splits a single batch of data."""

n_samples = batch[splitting_key].shape[0]
keys = list(batch.keys())
examples = [{} for _ in range(n_samples)]
for i in range(n_samples):
b = examples[i]
for k in keys:
if ("idx" in k.name) or ("channel_names" in k.name):
b[k] = batch[k]
else:
b[k] = batch[k][i]
return examples


@functional_datapipe("split_batches")
class BatchSplitter(IterDataPipe):
"""Pipeline step to split batches of data and yield single examples"""
Expand All @@ -52,5 +37,5 @@ def __init__(self, source_datapipe: IterDataPipe, splitting_key: BatchKey = Batc
def __iter__(self):
"""Opens the NWP data"""
for batch in self.source_datapipe:
for example in split_batches(batch, splitting_key=self.splitting_key):
for example in unstack_np_batch_into_examples(batch):
yield example
2 changes: 1 addition & 1 deletion pvnet/data/wind_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from datetime import datetime

from lightning.pytorch import LightningDataModule
from ocf_datapipes.batch import stack_np_examples_into_batch
from ocf_datapipes.training.windnet import windnet_netcdf_datapipe
from ocf_datapipes.utils.utils import stack_np_examples_into_batch
from torch.utils.data import DataLoader

from pvnet.data.utils import batch_to_tensor
Expand Down
80 changes: 48 additions & 32 deletions pvnet/models/multimodal/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Optional

import torch
from ocf_datapipes.utils.consts import BatchKey
from ocf_datapipes.utils.consts import BatchKey, NWPBatchKey
from torch import nn

import pvnet
Expand Down Expand Up @@ -39,7 +39,7 @@ def __init__(
self,
output_network: AbstractLinearNetwork,
output_quantiles: Optional[list[float]] = None,
nwp_encoder: Optional[AbstractNWPSatelliteEncoder] = None,
nwp_encoders_dict: Optional[dict[AbstractNWPSatelliteEncoder]] = None,
sat_encoder: Optional[AbstractNWPSatelliteEncoder] = None,
pv_encoder: Optional[AbstractPVSitesEncoder] = None,
add_image_embedding_channel: bool = False,
Expand Down Expand Up @@ -69,8 +69,8 @@ def __init__(
features to produce the forecast.
output_quantiles: A list of float (0.0, 1.0) quantiles to predict values for. If set to
None the output is a single value.
nwp_encoder: A partially instatiated pytorch Module class used to encode the NWP data
from 4D into an 1D feature vector.
nwp_encoders_dict: A dictionary of partially instatiated pytorch Module class used to
encode the NWP data from 4D into an 1D feature vector from different sources.
sat_encoder: A partially instatiated pytorch Module class used to encode the satellite
data from 4D into an 1D feature vector.
pv_encoder: A partially instatiated pytorch Module class used to encode the site-level
Expand All @@ -95,9 +95,10 @@ def __init__(
`history_minutes` if not provided.
optimizer: Optimizer factory function used for network.
"""

self.include_gsp_yield_history = include_gsp_yield_history
self.include_sat = sat_encoder is not None
self.include_nwp = nwp_encoder is not None
self.include_nwp = nwp_encoders_dict is not None and len(nwp_encoders_dict) != 0
self.include_pv = pv_encoder is not None
self.include_sun = include_sun
self.embedding_dim = embedding_dim
Expand All @@ -110,10 +111,9 @@ def __init__(
fusion_input_features = 0

if self.include_sat:
# We limit the history to have a delay of 15 mins in satellite data

if sat_history_minutes is None:
sat_history_minutes = history_minutes
# Param checks
assert sat_history_minutes is not None
assert nwp_forecast_minutes is not None

self.sat_sequence_len = (sat_history_minutes - min_sat_delay_minutes) // 5 + 1

Expand All @@ -130,27 +130,41 @@ def __init__(
fusion_input_features += self.sat_encoder.out_features

if self.include_nwp:
if nwp_history_minutes is None:
nwp_history_minutes = history_minutes
if nwp_forecast_minutes is None:
nwp_forecast_minutes = forecast_minutes
nwp_sequence_len = nwp_history_minutes // 60 + nwp_forecast_minutes // 60 + 1

self.nwp_encoder = nwp_encoder(
sequence_length=nwp_sequence_len,
in_channels=nwp_encoder.keywords["in_channels"] + add_image_embedding_channel,
)
# Param checks
assert nwp_forecast_minutes is not None
assert nwp_history_minutes is not None
# For each NWP encoder the forecast and history minutes must be set
assert set(nwp_encoders_dict.keys()) == set(nwp_forecast_minutes.keys())
assert set(nwp_encoders_dict.keys()) == set(nwp_history_minutes.keys())

self.nwp_encoders_dict = torch.nn.ModuleDict()
if add_image_embedding_channel:
self.nwp_embed = ImageEmbedding(
318, nwp_sequence_len, self.nwp_encoder.image_size_pixels
self.nwp_embed_dict = torch.nn.ModuleDict()

for nwp_source in nwp_encoders_dict.keys():
nwp_sequence_len = (
nwp_history_minutes[nwp_source] // 60
+ nwp_forecast_minutes[nwp_source] // 60
+ 1
)

# Update num features
fusion_input_features += self.nwp_encoder.out_features
self.nwp_encoders_dict[nwp_source] = nwp_encoders_dict[nwp_source](
sequence_length=nwp_sequence_len,
in_channels=(
nwp_encoders_dict[nwp_source].keywords["in_channels"]
+ add_image_embedding_channel
),
)
if add_image_embedding_channel:
self.nwp_embed_dict[nwp_source] = ImageEmbedding(
318, nwp_sequence_len, self.nwp_encoders_dict[nwp_source].image_size_pixels
)

# Update num features
fusion_input_features += self.nwp_encoders_dict[nwp_source].out_features

if self.include_pv:
if pv_history_minutes is None:
pv_history_minutes = history_minutes
assert pv_history_minutes is not None

self.pv_encoder = pv_encoder(
sequence_length=pv_history_minutes // 5 + 1,
Expand Down Expand Up @@ -201,13 +215,15 @@ def forward(self, x):

# *********************** NWP Data ************************************
if self.include_nwp:
# shape: batch_size, seq_len, n_chans, height, width
nwp_data = x[BatchKey.nwp].float()
nwp_data = torch.swapaxes(nwp_data, 1, 2) # switch time and channels
if self.add_image_embedding_channel:
id = x[BatchKey.gsp_id][:, 0].int()
nwp_data = self.nwp_embed(nwp_data, id)
modes["nwp"] = self.nwp_encoder(nwp_data)
# Loop through potentially many NMPs
for nwp_source in self.nwp_encoders_dict:
# shape: batch_size, seq_len, n_chans, height, width
nwp_data = x[BatchKey.nwp][nwp_source][NWPBatchKey.nwp].float()
nwp_data = torch.swapaxes(nwp_data, 1, 2) # switch time and channels
if self.add_image_embedding_channel:
id = x[BatchKey.gsp_id][:, 0].int()
nwp_data = self.nwp_embed_dict[nwp_source](nwp_data, id)
modes[f"nwp/{nwp_source}"] = self.nwp_encoders_dict[nwp_source](nwp_data)

# *********************** PV Data *************************************
# Add site-level PV yield
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
ocf_datapipes>=2.2.5
ocf_datapipes>=2.3.0
ocf_ml_metrics
numpy
pandas
matplotlib
xarray
ipykernel
h5netcdf
torch>=2.1.1
torch>=2.0.0
lightning>=2.0.1
torchvision
pytest
Expand Down
9 changes: 4 additions & 5 deletions scripts/save_batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,16 @@

import hydra
import torch
from ocf_datapipes.batch import stack_np_examples_into_batch
from ocf_datapipes.training.pvnet import pvnet_datapipe
from ocf_datapipes.training.windnet import windnet_datapipe
from ocf_datapipes.utils.utils import stack_np_examples_into_batch
from omegaconf import DictConfig, OmegaConf
from sqlalchemy import exc as sa_exc
from torch.utils.data import DataLoader
from torch.utils.data.datapipes.iter import IterableWrapper
from tqdm import tqdm

from pvnet.data.datamodule import batch_to_tensor
from pvnet.data.utils import batch_to_tensor
from pvnet.utils import print_config

warnings.filterwarnings("ignore", category=sa_exc.SAWarning)
Expand Down Expand Up @@ -107,9 +107,6 @@ def main(config: DictConfig):

shutil.copyfile(config_dm.configuration, f"{config.batch_output_dir}/data_configuration.yaml")

os.mkdir(f"{config.batch_output_dir}/train")
os.mkdir(f"{config.batch_output_dir}/val")

dataloader_kwargs = dict(
shuffle=False,
batch_size=None, # batched in datapipe step
Expand All @@ -126,6 +123,7 @@ def main(config: DictConfig):
)

if config.num_val_batches > 0:
os.mkdir(f"{config.batch_output_dir}/val")
print("----- Saving val batches -----")

val_batch_pipe = _get_datapipe(
Expand All @@ -143,6 +141,7 @@ def main(config: DictConfig):
)

if config.num_train_batches > 0:
os.mkdir(f"{config.batch_output_dir}/train")
print("----- Saving train batches -----")

train_batch_pipe = _get_datapipe(
Expand Down
Loading

0 comments on commit fe30469

Please sign in to comment.