Skip to content

Commit

Permalink
update tests and fix attention model
Browse files Browse the repository at this point in the history
  • Loading branch information
dfulu committed Oct 24, 2023
1 parent 69c1884 commit e052bdc
Show file tree
Hide file tree
Showing 8 changed files with 125 additions and 38 deletions.
29 changes: 14 additions & 15 deletions pvnet/models/multimodal/site_encoders/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@ def __init__(
num_sites: int,
out_features: int,
kdim: int = 10,
num_heads: int = 2,
pv_id_embed_dim: int = 10,
num_heads: int = 2,
n_kv_res_blocks: int = 2,
kv_res_block_layers: int = 2,
use_pv_id_in_value: bool = False,
Expand All @@ -131,22 +131,22 @@ def __init__(
Args:
sequence_length: The time sequence length of the data.
num_sites: Number of PV sites in the input data.
out_features: Number of output features. In this network this is also the the value
dimension in the multi-head attention layer.
kdim: The dimensions used in both the keys and queries.
out_features: Number of output features. In this network this is also the embed and
value dimension in the multi-head attention layer.
kdim: The dimensions used the keys.
pv_id_embed_dim: Number of dimensiosn used in the PD ID embedding layer(s).
num_heads: Number of parallel attention heads. Note that `out_features` will be split
across `num_heads` so `out_features` must be a multiple of `num_heads`.
pv_id_embed_dim: The dimension of the PV ID embedding used in calculating the key.
n_kv_res_blocks: Number of residual blocks to use in the key and value encoders.
kv_res_block_layers: Number of fully-connected layers used in each residual block within
the key and value encoders.
use_pv_id_in_value: Whether to use the PV ID in network used to produce the value for
the attention layer.
use_pv_id_in_value: Whether to use a PV ID embedding in network used to produce the
value for the attention layer.
"""
super().__init__(sequence_length, num_sites, out_features)

self.gsp_id_embedding = nn.Embedding(318, kdim)
self.gsp_id_embedding = nn.Embedding(318, out_features)
self.pv_id_embedding = nn.Embedding(num_sites, pv_id_embed_dim)
self._pv_ids = nn.parameter.Parameter(torch.arange(num_sites), requires_grad=False)
self.use_pv_id_in_value = use_pv_id_in_value
Expand All @@ -163,26 +163,25 @@ def __init__(
res_block_layers=kv_res_block_layers,
dropout_frac=0,
),
nn.Linear(out_features, kdim),
)

self._key_encoder = nn.Sequential(
ResFCNet2(
in_features=pv_id_embed_dim + sequence_length,
in_features=sequence_length + pv_id_embed_dim,
out_features=kdim,
fc_hidden_features=pv_id_embed_dim + sequence_length,
n_res_blocks=n_kv_res_blocks,
res_block_layers=kv_res_block_layers,
dropout_frac=0,
),
nn.Linear(kdim, kdim),
)

self.multihead_attn = nn.MultiheadAttention(
embed_dim=kdim,
embed_dim=out_features,
kdim=kdim,
vdim=out_features,
num_heads=num_heads,
batch_first=True,
vdim=out_features,
)

def _encode_query(self, x):
Expand Down Expand Up @@ -233,8 +232,8 @@ def _attention_forward(self, x, average_attn_weights=True):

attn_output, attn_weights = self.multihead_attn(
query, key, value, average_attn_weights=average_attn_weights
)

)
return attn_output, attn_weights

def forward(self, x):
Expand Down
81 changes: 64 additions & 17 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,17 @@
import numpy as np
import xarray as xr
import torch
import hydra

from ocf_datapipes.utils.consts import BatchKey
from datetime import timedelta

import pvnet
from pvnet.data.datamodule import DataModule

import pvnet.models.multimodal.encoders.encoders3d
import pvnet.models.multimodal.linear_networks.networks


xr.set_options(keep_attrs=True)

Expand Down Expand Up @@ -114,6 +118,12 @@ def sample_satellite_batch(sample_batch):
return torch.swapaxes(sat_image, 1, 2)


@pytest.fixture()
def sample_pv_batch(sample_batch):
pv_data = sample_batch[BatchKey.pv]
return pv_data


@pytest.fixture()
def model_minutes_kwargs():
kwargs = dict(
Expand All @@ -134,37 +144,74 @@ def encoder_model_kwargs():
)
return kwargs

@pytest.fixture()
def site_encoder_model_kwargs():
# Used to test site encoder model on PV data
kwargs = dict(
sequence_length=180 // 5 +1,
num_sites=349,
out_features=128,
)
return kwargs

@pytest.fixture()
def multimodal_model_kwargs(model_minutes_kwargs):

kwargs = dict(
image_encoder=pvnet.models.multimodal.encoders.encoders3d.DefaultPVNet,
encoder_out_features=128,
encoder_kwargs=dict(
number_of_conv3d_layers=6,
conv3d_channels=32,

sat_encoder=dict(
_target_=pvnet.models.multimodal.encoders.encoders3d.DefaultPVNet,
_partial_=True,
in_channels=11,
out_features=128,
number_of_conv3d_layers=6,
conv3d_channels=32,
image_size_pixels=24,
),

nwp_encoder=dict(
_target_=pvnet.models.multimodal.encoders.encoders3d.DefaultPVNet,
_partial_=True,
in_channels=2,
out_features=128,
number_of_conv3d_layers=6,
conv3d_channels=32,
image_size_pixels=24,
),
include_sat=True,
include_nwp=True,

add_image_embedding_channel=True,
sat_image_size_pixels=24,
nwp_image_size_pixels=24,
number_sat_channels=11,
number_nwp_channels=2,
output_network=pvnet.models.multimodal.linear_networks.networks.ResFCNet2,
output_network_kwargs=dict(
fc_hidden_features=128,
n_res_blocks=6,
res_block_layers=2,
dropout_frac=0.0,

pv_encoder=dict(
_target_=pvnet.models.multimodal.site_encoders.encoders.SingleAttentionNetwork,
_partial_=True,
num_sites=349,
out_features=40,
num_heads=4,
kdim=40,
pv_id_embed_dim=20,
),

output_network=dict(
_target_=pvnet.models.multimodal.linear_networks.networks.ResFCNet2,
_partial_=True,
fc_hidden_features=128,
n_res_blocks=6,
res_block_layers=2,
dropout_frac=0.0,
),

embedding_dim=16,
include_sun=True,
include_gsp_yield_history=True,
sat_history_minutes=90,
nwp_history_minutes=120,
nwp_forecast_minutes=480,
pv_history_minutes=180,
min_sat_delay_minutes=30,

)

kwargs = hydra.utils.instantiate(kwargs)

kwargs.update(model_minutes_kwargs)
return kwargs
Binary file modified tests/data/sample_batches/train/000000.pt
Binary file not shown.
Binary file modified tests/data/sample_batches/train/000001.pt
Binary file not shown.
41 changes: 41 additions & 0 deletions tests/models/multimodal/site_encoders/test_encoders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import torch
from ocf_datapipes.utils.consts import BatchKey
from torch import nn

from pvnet.models.multimodal.site_encoders.encoders import (
SimpleLearnedAggregator,
SingleAttentionNetwork,
)

import pytest


def _test_model_forward(batch, model_class, kwargs):
model = model_class(**kwargs)
y = model(batch)
assert tuple(y.shape) == (2, kwargs["out_features"]), y.shape


def _test_model_backward(batch, model_class, kwargs):
model = model_class(**kwargs)
y = model(batch)
# Backwards on sum drives sum to zero
y.sum().backward()


# Test model forward on all models
def test_simplelearnedaggregator_forward(sample_batch, site_encoder_model_kwargs):
_test_model_forward(sample_batch, SimpleLearnedAggregator, site_encoder_model_kwargs)


def test_singleattentionnetwork_forward(sample_batch, site_encoder_model_kwargs):
_test_model_forward(sample_batch, SingleAttentionNetwork, site_encoder_model_kwargs)


# Test model backward on all models
def test_simplelearnedaggregator_backward(sample_batch, site_encoder_model_kwargs):
_test_model_backward(sample_batch, SimpleLearnedAggregator, site_encoder_model_kwargs)


def test_singleattentionnetwork_backward(sample_batch, site_encoder_model_kwargs):
_test_model_backward(sample_batch, SingleAttentionNetwork, site_encoder_model_kwargs)
4 changes: 2 additions & 2 deletions tests/models/multimodal/test_deep_supervision.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@ def deepsupervision_model(multimodal_model_kwargs):
model = Model(**multimodal_model_kwargs)
return model


@pytest.mark.skip(reason="This model is no longer maintained")
def test_model_forward(deepsupervision_model, sample_batch):
y = deepsupervision_model(sample_batch)

# check output is the correct shape
# batch size=2, forecast_len=15
assert tuple(y.shape) == (2, 16), y.shape


@pytest.mark.skip(reason="This model is no longer maintained")
def test_model_backwards(deepsupervision_model, sample_batch):
opt = SGD(deepsupervision_model.parameters(), lr=0.001)

Expand Down
4 changes: 2 additions & 2 deletions tests/models/multimodal/test_nwp_weighting.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@ def nwp_weighting_model(model_minutes_kwargs):
)
return model


@pytest.mark.skip(reason="This model is no longer maintained")
def test_model_forward(nwp_weighting_model, sample_batch):
y = nwp_weighting_model(sample_batch)

# check output is the correct shape
# batch size=2, forecast_len=15
assert tuple(y.shape) == (2, 16), y.shape


@pytest.mark.skip(reason="This model is no longer maintained")
def test_model_backwards(nwp_weighting_model, sample_batch):
opt = SGD(nwp_weighting_model.parameters(), lr=0.001)

Expand Down
4 changes: 2 additions & 2 deletions tests/models/multimodal/test_weather_residual.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@ def weather_residual_model(multimodal_model_kwargs):
model = Model(**multimodal_model_kwargs)
return model


@pytest.mark.skip(reason="This model is no longer maintained")
def test_model_forward(weather_residual_model, sample_batch):
y = weather_residual_model(sample_batch)

# check output is the correct shape
# batch size=2, forecast_len=15
assert tuple(y.shape) == (2, 16), y.shape


@pytest.mark.skip(reason="This model is no longer maintained")
def test_model_backwards(weather_residual_model, sample_batch):
opt = SGD(weather_residual_model.parameters(), lr=0.001)

Expand Down

0 comments on commit e052bdc

Please sign in to comment.