Skip to content

Commit

Permalink
Add sample saving for Site Dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
Sukhil Patel authored and Sukhil Patel committed Dec 17, 2024
1 parent 82b6009 commit 4fbbf15
Show file tree
Hide file tree
Showing 10 changed files with 91 additions and 321 deletions.
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,20 +145,20 @@ This is also where you can update the train, val & test periods to cover the dat

### Running the batch creation script

Run the `save_batches.py` script to create batches with the parameters specified in the datamodule config (`streamed_batches.yaml` in this example):
Run the `save_samples.py` script to create batches with the parameters specified in the datamodule config (`streamed_batches.yaml` in this example):

```bash
python scripts/save_batches.py
python scripts/save_samples.py
```
PVNet uses
[hydra](https://hydra.cc/) which enables us to pass variables via the command
line that will override the configuration defined in the `./configs` directory, like this:

```bash
python scripts/save_batches.py datamodule=streamed_batches datamodule.batch_output_dir="./output" datamodule.num_train_batches=10 datamodule.num_val_batches=5
python scripts/save_samples.py datamodule=streamed_batches datamodule.sample_output_dir="./output" datamodule.num_train_batches=10 datamodule.num_val_batches=5
```

`scripts/save_batches.py` needs a config under `PVNet/configs/datamodule`. You can adapt `streamed_batches.yaml` or create your own in the same folder.
`scripts/save_samples.py` needs a config under `PVNet/configs/datamodule`. You can adapt `streamed_batches.yaml` or create your own in the same folder.

If downloading private data from a GCP bucket make sure to authenticate gcloud (the public satellite data does not need authentication):

Expand Down Expand Up @@ -197,7 +197,7 @@ Make sure to update the following config files before training your model:
2. In `configs/model/local_multimodal.yaml`:
- update the list of encoders to reflect the data sources you are using. If you are using different NWP sources, the encoders for these should follow the same structure with two important updates:
- `in_channels`: number of variables your NWP source supplies
- `image_size_pixels`: spatial crop of your NWP data. It depends on the spatial resolution of your NWP; should match `nwp_image_size_pixels_height` and/or `nwp_image_size_pixels_width` in `datamodule/example_configs.yaml`, unless transformations such as coarsening was applied (e. g. as for ECMWF data)
- `image_size_pixels`: spatial crop of your NWP data. It depends on the spatial resolution of your NWP; should match `image_size_pixels_height` and/or `image_size_pixels_width` in `datamodule/configuration/site_example_configuration.yaml` for the NWP, unless transformations such as coarsening was applied (e. g. as for ECMWF data)
3. In `configs/local_trainer.yaml`:
- set `accelerator: 0` if running on a system without a supported GPU

Expand All @@ -216,7 +216,7 @@ defaults:
- hydra: default.yaml
```

Assuming you ran the `save_batches.py` script to generate some premade train and
Assuming you ran the `save_samples.py` script to generate some premade train and
val data batches, you can now train PVNet by running:

```
Expand Down
172 changes: 0 additions & 172 deletions configs.example/datamodule/configuration/example_configuration.yaml

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
general:
description: Example config for producing PVNet samples for a reneweble generation site
name: site_example_config

input_data:

site:
time_resolution_minutes: 15
interval_start_minutes: -60
interval_end_minutes: 480
file_path: PLACEHOLDER.nc
metadata_file_path: PLACEHOLDER.csv
dropout_timedeltas_minutes: null
dropout_fraction: 0 # Fraction of samples with dropout

nwp:
ecmwf:
provider: ecmwf
# Path to ECMWF NWP data in zarr format
# n.b. It is not necessary to use multiple or any NWP data. These entries can be removed
zarr_path: PLACEHOLDER
interval_start_minutes: -60
interval_end_minutes: 480
time_resolution_minutes: 60
channels:
- t2m # 2-metre temperature
- dswrf # downwards short-wave radiation flux
- dlwrf # downwards long-wave radiation flux
- hcc # high cloud cover
- mcc # medium cloud cover
- lcc # low cloud cover
- tcc # total cloud cover
- sde # snow depth water equivalent
- sr # direct solar radiation
- duvrs # downwards UV radiation at surface
- prate # precipitation rate
- u10 # 10-metre U component of wind speed
- u100 # 100-metre U component of wind speed
- u200 # 200-metre U component of wind speed
- v10 # 10-metre V component of wind speed
- v100 # 100-metre V component of wind speed
- v200 # 200-metre V component of wind speed
image_size_pixels_height: 24
image_size_pixels_width: 24
dropout_timedeltas_minutes: [-360]
dropout_fraction: 1.0
max_staleness_minutes: null

satellite:
zarr_path: PLACEHOLDER.zarr
interval_start_minutes: -30
interval_end_minutes: 0
time_resolution_minutes: 5
channels:
# Uses for each channel taken from https://resources.eumetrain.org/data/3/311/bsc_s4.pdf
- IR_016 # Surface, cloud phase
- IR_039 # Surface, clouds, wind fields
- IR_087 # Surface, clouds, atmospheric instability
- IR_097 # Ozone
- IR_108 # Surface, clouds, wind fields, atmospheric instability
- IR_120 # Surface, clouds, atmospheric instability
- IR_134 # Cirrus cloud height, atmospheric instability
- VIS006 # Surface, clouds, wind fields
- VIS008 # Surface, clouds, wind fields
- WV_062 # Water vapor, high level clouds, upper air analysis
- WV_073 # Water vapor, atmospheric instability, upper-level dynamics
image_size_pixels_height: 24
image_size_pixels_width: 24
dropout_timedeltas_minutes: null
dropout_fraction: 0.
2 changes: 1 addition & 1 deletion configs.example/datamodule/premade_batches.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ _target_: pvnet.data.datamodule.DataModule
configuration: null
# The batch_dir is the location batches were saved to using the save_batches.py script
# The batch_dir should contain train and val subdirectories with batches
batch_dir: "PLACEHOLDER"
sample_dir: "PLACEHOLDER"
num_workers: 10
prefetch_factor: 2
batch_size: 8
10 changes: 3 additions & 7 deletions configs.example/datamodule/streamed_batches.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,13 @@ configuration: "PLACEHOLDER.yaml"
num_workers: 20
prefetch_factor: 2
batch_size: 8
batch_output_dir: "PLACEHOLDER"
num_train_batches: 2
num_val_batches: 1

sample_output_dir: "PLACEHOLDER"
num_train_samples: 2
num_val_samples: 1

train_period:
- null
- "2022-05-07"
val_period:
- "2022-05-08"
- "2023-05-08"
test_period:
- "2022-05-08"
- "2023-05-08"
9 changes: 7 additions & 2 deletions pvnet/data/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import torch
from lightning.pytorch import LightningDataModule
from ocf_data_sampler.torch_datasets.pvnet_uk_regional import PVNetUKRegionalDataset
from ocf_data_sampler.torch_datasets import PVNetUKRegionalDataset, SitesDataset
from ocf_datapipes.batch import (
NumpyBatch,
TensorBatch,
Expand Down Expand Up @@ -93,7 +93,12 @@ def __init__(
)

def _get_streamed_samples_dataset(self, start_time, end_time) -> Dataset:
return PVNetUKRegionalDataset(self.configuration, start_time=start_time, end_time=end_time)
if self.configuration.renewable == "pv":
return PVNetUKRegionalDataset(self.configuration, start_time=start_time, end_time=end_time)
elif self.configuration.renewable in ["wind", "pv_india", "pv_site"]:
return SitesDataset(self.configuration, start_time=start_time, end_time=end_time)
else:
raise ValueError(f"Unknown renewable: {self.configuration.renewable}")

def _get_premade_samples_dataset(self, subdir) -> Dataset:
split_dir = f"{self.sample_dir}/{subdir}"
Expand Down
Loading

0 comments on commit 4fbbf15

Please sign in to comment.