allow different learning rates for submodules + fixes
dfulu committed Oct 3, 2023
1 parent ab70636 commit 049dcef
Showing 4 changed files with 71 additions and 23 deletions.
configs/model/multimodal.yaml (10 changes: 6 additions & 4 deletions)
@@ -6,7 +6,7 @@ output_quantiles: [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98]
 # NWP encoder
 #--------------------------------------------
 
-nwp_encoder:
+nwp_encoder:
   _target_: pvnet.models.multimodal.encoders.encoders3d.DefaultPVNet
   _partial_: True
   in_channels: 2
@@ -28,7 +28,7 @@ sat_encoder:
   conv3d_channels: 32
   image_size_pixels: 24
 
-  add_image_embedding_channel: True
+  add_image_embedding_channel: False
 
 #--------------------------------------------
 # PV encoder settings
@@ -80,8 +80,10 @@ pv_history_minutes: 180
 # ----------------------------------------------
 optimizer:
   _target_: pvnet.optimizers.AdamWReduceLROnPlateau
-  lr: 0.0001
-  weight_decay: 0.25
+  lr:
+    pv_encoder: 0.002
+    default: 0.0001
+  weight_decay: 0.02
   amsgrad: True
   patience: 5
   factor: 0.1
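The reworked `optimizer` block above replaces the single float `lr` with a mapping from submodule name to learning rate. Once this Hydra-style config is instantiated through its `_target_`, it amounts to roughly the following sketch (illustrative, not part of the diff); as the `pvnet/optimizers.py` diff below shows, each key of the mapping is matched as a prefix of the model's parameter names, and `default` covers every parameter left over:

    # Rough equivalent of the config block after instantiation.
    # weight_decay and amsgrad are forwarded to AdamW via **opt_kwargs.
    from pvnet.optimizers import AdamWReduceLROnPlateau

    optimizer_generator = AdamWReduceLROnPlateau(
        lr={"pv_encoder": 0.002, "default": 0.0001},  # per-submodule rates
        weight_decay=0.02,
        amsgrad=True,
        patience=5,
        factor=0.1,
    )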
pvnet/models/base_model.py (2 changes: 1 addition & 1 deletion)
@@ -448,4 +448,4 @@ def configure_optimizers(self):
         if self.lr is not None:
             # Use learning rate found by learning rate finder callback
             self._optimizer.lr = self.lr
-        return self._optimizer(self.parameters())
+        return self._optimizer(self)
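This change means an optimizer generator now receives the whole LightningModule rather than just its parameters, which is what allows it to read `named_parameters()`, a learning rate found by the learning rate finder (`model.lr`), and `model.use_quantile_regression`. A minimal sketch of a generator written against the new convention (this `PlainSGD` class is hypothetical and only illustrates the interface):

    import torch

    from pvnet.optimizers import AbstractOptimizer

    class PlainSGD(AbstractOptimizer):
        """Hypothetical generator: called with the model, not its parameters."""

        def __init__(self, lr=0.01):
            self.lr = lr

        def __call__(self, model):
            # The full model is available, so implementations can build
            # per-submodule parameter groups from model.named_parameters().
            return torch.optim.SGD(model.parameters(), lr=self.lr)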
pvnet/models/multimodal/site_encoders/encoders.py (4 changes: 2 additions & 2 deletions)
@@ -85,7 +85,7 @@ def _calculate_attention(self, x):
 
     def _encode_value(self, x):
         # Shape: [batch size, sequence length, PV site]
-        pv_site_seqs = x[BatchKey.pv]
+        pv_site_seqs = x[BatchKey.pv].float()
         batch_size = pv_site_seqs.shape[0]
 
         pv_site_seqs = pv_site_seqs.swapaxes(1,2).flatten(0,1)
@@ -97,7 +97,7 @@ def _encode_value(self, x):
     def forward(self, x):
         """Run model forward"""
         # Output has shape: [batch size, num_sites, value_dim]
-        encodeded_seqs = self.encode_value(x)
+        encodeded_seqs = self._encode_value(x)
 
         # Calculate learned averaging weights
         attn_avg_weights = self._calculate_attention(x)
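Two fixes above: the PV batch is cast with `.float()` (presumably to guard against inputs arriving in a non-float32 dtype), and `forward` now calls the method by its actual name, `_encode_value`. A self-contained sketch of the reshaping around that cast, with toy sizes assumed:

    import torch

    batch_size, seq_len, n_sites = 4, 12, 10
    pv = torch.randn(batch_size, seq_len, n_sites, dtype=torch.float64)

    pv = pv.float()                       # the cast added by this commit
    pv = pv.swapaxes(1, 2).flatten(0, 1)  # [batch, seq, site] -> [batch * site, seq]
    assert pv.shape == (batch_size * n_sites, seq_len)
    assert pv.dtype == torch.float32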
pvnet/optimizers.py (78 changes: 62 additions & 16 deletions)
@@ -11,7 +11,7 @@ class AbstractOptimizer(ABC):
 
     Optimizer classes will be used by model like:
     > OptimizerGenerator = AbstractOptimizer()
-    > optimizer = OptimizerGenerator(model.parameters())
+    > optimizer = OptimizerGenerator(model)
     The returned object `optimizer` must be something that may be returned by `pytorch_lightning`'s
     `configure_optimizers()` method.
     See :
@@ -33,9 +33,9 @@ def __init__(self, lr=0.0005, **kwargs):
         self.lr = lr
         self.kwargs = kwargs
 
-    def __call__(self, model_parameters):
+    def __call__(self, model):
         """Return optimizer"""
-        return torch.optim.Adam(model_parameters, lr=self.lr, **self.kwargs)
+        return torch.optim.Adam(model.parameters(), lr=self.lr, **self.kwargs)
 
 
 class AdamW(AbstractOptimizer):
@@ -46,30 +46,76 @@ def __init__(self, lr=0.0005, **kwargs):
         self.lr = lr
         self.kwargs = kwargs
 
-    def __call__(self, model_parameters):
+    def __call__(self, model):
         """Return optimizer"""
-        return torch.optim.AdamW(model_parameters, lr=self.lr, **self.kwargs)
+        return torch.optim.AdamW(model.parameters(), lr=self.lr, **self.kwargs)
 
 
 class AdamWReduceLROnPlateau(AbstractOptimizer):
     """AdamW optimizer and reduce on plateau scheduler"""
 
     def __init__(self, lr=0.0005, patience=3, factor=0.5, threshold=2e-4, **opt_kwargs):
         """AdamW optimizer and reduce on plateau scheduler"""
-        self.lr = lr
+        self._lr = lr
         self.patience = patience
         self.factor = factor
         self.threshold = threshold
         self.opt_kwargs = opt_kwargs
 
-    def __call__(self, model_parameters):
-        """Return optimizer"""
-        opt = torch.optim.AdamW(model_parameters, lr=self.lr, **self.opt_kwargs)
-        sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
-            opt,
-            factor=self.factor,
-            patience=self.patience,
-            threshold=self.threshold,
-        )
-        sch = {"scheduler": sch, "monitor": "MAE/train"}
-        return [opt], [sch]
+    def _call_multi(self, model):
+
+        remaining_params = {k:p for k,p in model.named_parameters()}
+
+        group_args = []
+
+        for key in self._lr.keys():
+            if key=="default":
+                continue
+
+            submodule_params = []
+            for param_name in list(remaining_params.keys()):
+                if param_name.startswith(key):
+                    submodule_params += [remaining_params.pop(param_name)]
+
+            group_args += [{"params": submodule_params, "lr": self._lr[key]}]
+
+        remaining_params = [p for k, p in remaining_params.items()]
+        group_args += [{"params": remaining_params}]
+
+        opt = torch.optim.AdamW(
+            group_args,
+            lr=self._lr["default"] if model.lr is None else model.lr,
+            **self.opt_kwargs
+        )
+        sch = {
+            "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
+                opt,
+                factor=self.factor,
+                patience=self.patience,
+                threshold=self.threshold,
+            ),
+            "monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/train"
+        }
+
+        return [opt], [sch]
+
+    def __call__(self, model):
+        """Return optimizer"""
+        if not isinstance(self._lr, float):
+            return self._call_multi(model)
+        else:
+            assert False
+            default_lr = self._lr if model.lr is None else model.lr
+            opt = torch.optim.AdamW(model.parameters(), lr=default_lr, **self.opt_kwargs)
+            sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
+                opt,
+                factor=self.factor,
+                patience=self.patience,
+                threshold=self.threshold,
+            )
+            sch = {
+                "scheduler": sch,
+                "monitor": "quantile_loss/train" if model.use_quantile_regression else "MAE/train",
+            }
+            return [opt], [sch]
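To make the new grouping concrete, here is a self-contained sketch of the prefix matching that `_call_multi` performs when `lr` is a dict, applied to a toy module (`TinyModel` and its submodule names are assumptions chosen to mirror the config keys; the real method also consults `model.lr` and `model.use_quantile_regression`):

    import torch
    from torch import nn

    class TinyModel(nn.Module):
        """Stand-in for the multimodal model."""

        def __init__(self):
            super().__init__()
            self.pv_encoder = nn.Linear(4, 8)  # parameters named "pv_encoder.*"
            self.head = nn.Linear(8, 1)        # parameters named "head.*"

    model = TinyModel()
    lr = {"pv_encoder": 0.002, "default": 0.0001}

    # Pop every parameter whose name starts with a non-"default" key
    # into its own parameter group...
    remaining = dict(model.named_parameters())
    group_args = []
    for key in lr:
        if key == "default":
            continue
        params = [remaining.pop(n) for n in list(remaining) if n.startswith(key)]
        group_args.append({"params": params, "lr": lr[key]})

    # ...and let everything left fall through to the optimizer-level default.
    group_args.append({"params": list(remaining.values())})

    opt = torch.optim.AdamW(group_args, lr=lr["default"])
    # "pv_encoder.*" trains at 0.002; "head.*" at 0.0001.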
