Commit

update PV site subnetwork to minimal attention mechanism
dfulu committed Oct 23, 2023
1 parent 049dcef commit d8883fa
Showing 4 changed files with 233 additions and 17 deletions.
20 changes: 9 additions & 11 deletions configs/model/multimodal.yaml
@@ -35,13 +35,13 @@ add_image_embedding_channel: False
#--------------------------------------------

pv_encoder:
_target_: pvnet.models.multimodal.site_encoders.encoders.SimpleLearnedAggregator
_target_: pvnet.models.multimodal.site_encoders.encoders.SingleAttentionNetwork
_partial_: True
num_sites: 349
out_features: 64
value_dim: 64
value_enc_resblocks: 2
final_resblocks: 2
out_features: 40
num_heads: 4
kdim: 40
pv_id_embed_dim: 20

#--------------------------------------------
# Tabular network settings
@@ -79,12 +79,10 @@ pv_history_minutes: 180
# Optimizer
# ----------------------------------------------
optimizer:
_target_: pvnet.optimizers.AdamWReduceLROnPlateau
lr:
pv_encoder: 0.002
default: 0.0001
weight_decay: 0.02
_target_: pvnet.optimizers.EmbAdamWReduceLROnPlateau
lr: 0.0001
weight_decay: 0.01
amsgrad: True
patience: 5
factor: 0.1
threshold: 0.002
threshold: 0.002
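
As a reading aid, here is a hedged sketch of what the new pv_encoder block resolves to when the model is built, assuming the Hydra `_partial_: True` convention used throughout PVNet; the completion call at the end is illustrative only.

# Minimal sketch, assuming Hydra-style instantiation of the config above.
from functools import partial

from pvnet.models.multimodal.site_encoders.encoders import SingleAttentionNetwork

# Roughly what hydra.utils.instantiate(config.pv_encoder) produces:
pv_encoder_partial = partial(
    SingleAttentionNetwork,
    num_sites=349,        # number of PV sites in each batch
    out_features=40,      # encoder output width (also the attention value dimension)
    num_heads=4,          # kdim and out_features (both 40) are divisible by num_heads
    kdim=40,              # key/query dimension inside the attention layer
    pv_id_embed_dim=20,   # size of the PV ID embedding used to build the keys
)

# The parent multimodal model later completes the call with the sequence length it
# derives from pv_history_minutes, e.g. pv_encoder_partial(sequence_length=37).
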
3 changes: 3 additions & 0 deletions pvnet/models/multimodal/site_encoders/basic_blocks.py
@@ -25,8 +25,11 @@ def __init__(
out_features: Number of output features.
"""
super().__init__()
self.sequence_length = sequence_length
self.num_sites = num_sites
self.out_features = out_features


@abstractmethod
def forward(self):
"""Run model forward"""
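
A minimal sketch of the pattern this change enables, assuming the three-argument constructor implied by the `super().__init__(...)` calls in the encoders; the subclass below is purely illustrative.

# With these attributes stored on the abstract base class, concrete encoders no
# longer need to re-assign sequence_length / num_sites / out_features themselves.
from pvnet.models.multimodal.site_encoders.basic_blocks import AbstractPVSitesEncoder


class ToySitesEncoder(AbstractPVSitesEncoder):
    def __init__(self, sequence_length, num_sites, out_features):
        super().__init__(sequence_length, num_sites, out_features)
        # self.sequence_length, self.num_sites and self.out_features are now available

    def forward(self, x):
        """Run model forward"""
        ...
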
148 changes: 145 additions & 3 deletions pvnet/models/multimodal/site_encoders/encoders.py
@@ -47,9 +47,6 @@ def __init__(

super().__init__(sequence_length, num_sites, out_features)

self.sequence_length = sequence_length
self.num_sites = num_sites

# Network used to encode each PV site sequence
self._value_encoder = nn.Sequential(
ResFCNet2(
@@ -109,3 +106,148 @@ def forward(self, x):
x_out = self.output_network(value_weighted_avg)

return x_out


class SingleAttentionNetwork(AbstractPVSitesEncoder):
"""A simple attention-based model with a single multihead attention layer
For the attention layer the query is based on the target GSP alone, the key is based on the PV
ID and the recent PV data, the value is based on the recent PV data.
"""
def __init__(
self,
sequence_length: int,
num_sites: int,
out_features: int,
kdim: int = 10,
num_heads: int = 2,
pv_id_embed_dim: int = 10,
n_kv_res_blocks: int = 2,
kv_res_block_layers: int = 2,
use_pv_id_in_value: bool = False,
):
"""A simple attention-based model with a single multihead attention layer
Args:
sequence_length: The time sequence length of the data.
num_sites: Number of PV sites in the input data.
out_features: Number of output features. In this network this is also the value
dimension in the multi-head attention layer.
kdim: The dimension used for both the keys and the queries.
num_heads: Number of parallel attention heads. Note that the attention embedding
dimension `kdim` is split across `num_heads`, so `kdim` must be a multiple of `num_heads`.
pv_id_embed_dim: The dimension of the PV ID embedding used in calculating the key.
n_kv_res_blocks: Number of residual blocks to use in the key and value encoders.
kv_res_block_layers: Number of fully-connected layers used in each residual block within
the key and value encoders.
use_pv_id_in_value: Whether to use the PV ID in the network used to produce the value for
the attention layer.
"""
super().__init__(sequence_length, num_sites, out_features)

self.gsp_id_embedding = nn.Embedding(318, kdim)
self.pv_id_embedding = nn.Embedding(num_sites, pv_id_embed_dim)
self._pv_ids = nn.parameter.Parameter(torch.arange(num_sites), requires_grad=False)
self.use_pv_id_in_value = use_pv_id_in_value

if use_pv_id_in_value:
self.value_pv_id_embedding = nn.Embedding(num_sites, pv_id_embed_dim)

self._value_encoder = nn.Sequential(
ResFCNet2(
in_features=sequence_length+int(use_pv_id_in_value)*pv_id_embed_dim,
out_features=out_features,
fc_hidden_features=sequence_length,
n_res_blocks=n_kv_res_blocks,
res_block_layers=kv_res_block_layers,
dropout_frac=0,
),
nn.Linear(out_features, kdim),
)

self._key_encoder = nn.Sequential(
ResFCNet2(
in_features=pv_id_embed_dim+sequence_length,
out_features=kdim,
fc_hidden_features=pv_id_embed_dim+sequence_length,
n_res_blocks=n_kv_res_blocks,
res_block_layers=kv_res_block_layers,
dropout_frac=0,
),
nn.Linear(kdim, kdim)
)

self.multihead_attn = nn.MultiheadAttention(
embed_dim=kdim,
num_heads=num_heads,
batch_first=True,
vdim=out_features,
)


def _encode_query(self, x):
gsp_ids = x[BatchKey.gsp_id].squeeze().int()
query = self.gsp_id_embedding(gsp_ids).unsqueeze(1)
return query

def _encode_key(self, x):
# Shape: [batch size, sequence length, PV site]
pv_site_seqs = x[BatchKey.pv].float()
batch_size = pv_site_seqs.shape[0]

# PV ID embeddings are the same for each sample
pv_id_embed = torch.tile(self.pv_id_embedding(self._pv_ids), (batch_size, 1, 1))

# Each concatenated (PV sequence, PV ID embedding) pair is processed by the encoder
x_seq_in = torch.cat((pv_site_seqs.swapaxes(1,2), pv_id_embed), dim=2).flatten(0,1)
key = self._key_encoder(x_seq_in)

# Reshape to [batch size, PV site, kdim]
key = key.unflatten(0, (batch_size, self.num_sites))
return key

def _encode_value(self, x):
# Shape: [batch size, sequence length, PV site]
pv_site_seqs = x[BatchKey.pv].float()
batch_size = pv_site_seqs.shape[0]

if self.use_pv_id_in_value:
# PV ID embeddings are the same for each sample
pv_id_embed = torch.tile(self.value_pv_id_embedding(self._pv_ids), (batch_size, 1, 1))
# Each concatenated (PV sequence, PV ID embedding) pair is processed by the encoder
x_seq_in = torch.cat((pv_site_seqs.swapaxes(1,2), pv_id_embed), dim=2).flatten(0,1)
else:
# Encode each PV sequence independently
x_seq_in = pv_site_seqs.swapaxes(1,2).flatten(0,1)

value = self._value_encoder(x_seq_in)

# Reshape to [batch size, PV site, vdim]
value = value.unflatten(0, (batch_size, self.num_sites))
return value

def _attention_forward(self, x, average_attn_weights=True):

query = self._encode_query(x)
key = self._encode_key(x)
value = self._encode_value(x)

attn_output, attn_weights = self.multihead_attn(
query,
key,
value,
average_attn_weights=average_attn_weights
)

return attn_output, attn_weights

def forward(self, x):
"""Run model forward"""
attn_output, attn_output_weights = self._attention_forward(x)

# Reshape from [batch_size, 1, vdim] to [batch_size, vdim]
x_out = attn_output.squeeze()

return x_out
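
For orientation, a hedged smoke test of the new encoder: the tensor shapes follow the comments in `_encode_key`/`_encode_value`, the `BatchKey` import path is assumed to match the one PVNet uses elsewhere, and the sequence length of 37 is illustrative.

# Toy example: one query per sample (from the GSP ID), keys/values from the PV site sequences.
import torch
from ocf_datapipes.utils.consts import BatchKey  # assumed import path

from pvnet.models.multimodal.site_encoders.encoders import SingleAttentionNetwork

batch_size, sequence_length, num_sites = 4, 37, 349

encoder = SingleAttentionNetwork(
    sequence_length=sequence_length,
    num_sites=num_sites,
    out_features=40,
    kdim=40,
    num_heads=4,
    pv_id_embed_dim=20,
)

x = {
    BatchKey.pv: torch.rand(batch_size, sequence_length, num_sites),  # [batch, time, site]
    BatchKey.gsp_id: torch.randint(0, 318, (batch_size, 1)),          # target GSP per sample
}

out = encoder(x)
print(out.shape)  # expected: torch.Size([4, 40]) since out_features == kdim == 40 here
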
79 changes: 76 additions & 3 deletions pvnet/optimizers.py
@@ -50,16 +50,89 @@ def __call__(self, model):
"""Return optimizer"""
return torch.optim.AdamW(model.parameters(), lr=self.lr, **self.kwargs)


def find_submodule_parameters(model, search_modules):
"""Return the parameters of all submodules of `model` which are instances of `search_modules`."""
if isinstance(model, search_modules):
return model.parameters()

children = list(model.children())
if len(children)==0:
return []
else:
params = []
for c in children:
params += find_submodule_parameters(c, search_modules)
return params


def find_other_than_submodule_parameters(model, ignore_modules):
"""Return the parameters of `model`, excluding those in submodules which are instances of `ignore_modules`."""
if isinstance(model, ignore_modules):
return []

children = list(model.children())
if len(children)==0:
return model.parameters()
else:
params = []
for c in children:
params += find_other_than_submodule_parameters(c, ignore_modules)
return params


class EmbAdamWReduceLROnPlateau(AbstractOptimizer):
"""AdamW optimizer and reduce on plateau scheduler"""

def __init__(self, lr=0.0005, weight_decay=0.01, patience=3, factor=0.5, threshold=2e-4,
**opt_kwargs
):
"""AdamW optimizer and reduce on plateau scheduler"""
self.lr = lr
self.weight_decay = weight_decay
self.patience = patience
self.factor = factor
self.threshold = threshold
self.opt_kwargs = opt_kwargs


def __call__(self, model):
"""Return optimizer"""

search_modules = (torch.nn.Embedding, )

no_decay = find_submodule_parameters(model, search_modules)
decay = find_other_than_submodule_parameters(model, search_modules)

optim_groups = [
{"params": decay, "weight_decay": self.weight_decay},
{"params": no_decay, "weight_decay": 0.0},
]
opt = torch.optim.AdamW(optim_groups, lr=self.lr, **self.opt_kwargs)

sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
opt,
factor=self.factor,
patience=self.patience,
threshold=self.threshold,
)
sch = {
"scheduler": sch,
"monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val",
}
return [opt], [sch]


class AdamWReduceLROnPlateau(AbstractOptimizer):
"""AdamW optimizer and reduce on plateau scheduler"""

def __init__(self, lr=0.0005, patience=3, factor=0.5, threshold=2e-4, **opt_kwargs):
def __init__(self, lr=0.0005, patience=3, factor=0.5, threshold=2e-4, step_freq=None,
**opt_kwargs
):
"""AdamW optimizer and reduce on plateau scheduler"""
self._lr = lr
self.patience = patience
self.factor = factor
self.threshold = threshold
self.step_freq = step_freq
self.opt_kwargs = opt_kwargs

def _call_multi(self, model):
@@ -94,7 +167,7 @@ def _call_multi(self, model):
patience=self.patience,
threshold=self.threshold,
),
"monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/train"
"monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val",
}

return [opt], [sch]
@@ -116,6 +189,6 @@ def __call__(self, model):
)
sch = {
"scheduler": sch,
"monitor": "quantile_loss/train" if model.use_quantile_regression else "MAE/train",
"monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val",
}
return [opt], [sch]
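
For context, a small hedged example of the decay / no-decay split that the new EmbAdamWReduceLROnPlateau performs: embedding tables (e.g. the GSP and PV ID embeddings above) go into a zero-weight-decay group while everything else keeps weight decay. The toy model is illustrative only.

# Illustrative check of the parameter grouping used by EmbAdamWReduceLROnPlateau.
import torch

from pvnet.optimizers import find_other_than_submodule_parameters, find_submodule_parameters

toy = torch.nn.Sequential(
    torch.nn.Embedding(318, 16),  # stands in for an ID embedding
    torch.nn.Linear(16, 8),
)

no_decay = list(find_submodule_parameters(toy, (torch.nn.Embedding,)))
decay = list(find_other_than_submodule_parameters(toy, (torch.nn.Embedding,)))

optim_groups = [
    {"params": decay, "weight_decay": 0.01},    # ordinary weights keep weight decay
    {"params": no_decay, "weight_decay": 0.0},  # embedding weights are excluded
]
opt = torch.optim.AdamW(optim_groups, lr=1e-4)

print(len(no_decay), len(decay))  # 1 embedding weight vs. 2 linear parameters
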
