Add SimpleLossSchedule
Linux-cpp-lisp committed May 1, 2024
1 parent 7fcd45d commit 9ba1d5f
Showing 4 changed files with 73 additions and 5 deletions.
CHANGELOG.md (2 additions, 0 deletions)
@@ -22,6 +22,8 @@ Most recent change on the bottom.
- `include_file_as_baseline_config` for simple modifications of existing configs
- `nequip-deploy --using-dataset` to support data-dependent deployment steps
- Support for Gaussian Mixture Model uncertainty quantification (https://doi.org/10.1063/5.0136574)
+- `start_of_epoch_callbacks`
+- `nequip.train.callbacks.loss_schedule.SimpleLossSchedule` for changing the loss coefficients at specified epochs
- `nequip-deploy build --checkpoint` and `--override` to avoid many largely duplicated YAML files
- matscipy neighborlist support enabled with `NEQUIP_MATSCIPY_NL` environment variable

configs/full.yaml (11 additions, 5 deletions)
@@ -212,9 +212,9 @@ early_stopping_upper_bounds:

# loss function
loss_coeffs: # different weights to use in a weighted loss function
-  forces: 1 # if using PerAtomMSELoss, a default weight of 1:1 on each should work well
+  forces: 1.0 # if using PerAtomMSELoss, a default weight of 1:1 on each should work well
   total_energy:
-    - 1
+    - 1.0
     - PerAtomMSELoss
# note that the ratio between force and energy loss matters for the training process. One may consider using 1:1 with the PerAtomMSELoss. If the energy loss still significantly dominates the loss function in the initial epochs, tuning the energy loss weight lower helps the training a lot.

@@ -249,6 +249,15 @@ loss_coeffs:
# - L1Loss
# forces: 1.0

+# You can schedule changes in the loss coefficients using a callback:
+# In the "schedule" key each entry is a two-element list of:
+# - the 1-based epoch index at which to start the new loss coefficients
+# - the new loss coefficients as a dict
+#
+# start_of_epoch_callbacks:
+#  - !!python/object:nequip.train.callbacks.loss_schedule.SimpleLossSchedule {"schedule": [[2, {"forces": 0.0, "total_energy": 1.0}]]}
+#

# output metrics
metrics_components:
- - forces # key
@@ -371,6 +380,3 @@ global_rescale_scale_trainable: false
# per_species_rescale_shifts: null
# per_species_rescale_scales: null

-# Options for e3nn's set_optimization_defaults. A dict:
-# e3nn_optimization_defaults:
-#   explicit_backward: True
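
The !!python/object syntax above simply tells PyYAML to construct a SimpleLossSchedule instance directly. For Python-driven training, a hypothetical equivalent sketch (epoch numbers and weights are illustrative, not part of this commit):

    from nequip.train.callbacks.loss_schedule import SimpleLossSchedule

    # Hypothetical two-stage schedule: emphasize forces early, then shift
    # weight toward the total energy starting at (1-based) epoch 30.
    schedule = SimpleLossSchedule(
        schedule=[
            (2, {"forces": 1.0, "total_energy": 0.1}),
            (30, {"forces": 0.5, "total_energy": 1.0}),
        ]
    )
    # Passed when constructing the trainer, e.g.
    # Trainer(..., start_of_epoch_callbacks=[schedule])

Note that each coefficient dict must contain every loss term, and the start epochs must be strictly ascending, beginning at epoch 2 or later; the callback asserts both (see loss_schedule.py below).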
nequip/train/callbacks/loss_schedule.py (new file, 54 additions)
@@ -0,0 +1,54 @@
from typing import Dict, List, Tuple
from dataclasses import dataclass
import numpy as np

from nequip.train import Trainer, Loss

# Making this a dataclass takes care of equality operators, handling restart consistency checks


@dataclass
class SimpleLossSchedule:
    """Schedule `loss_coeffs` through a training run.
    To use this in a training run, set in your YAML file:
    start_of_epoch_callbacks:
     - !!python/object:nequip.train.callbacks.loss_schedule.SimpleLossSchedule {"schedule": [[2, {"forces": 1.0, "total_energy": 0.0}], [30, {"forces": 0.0, "total_energy": 1.0}]]}
    This funny syntax tells PyYAML to construct an object of this class.
    Each entry in the schedule is a pair of the 1-based epoch index at which that set of loss coefficients takes effect, and a dict of loss coefficients.
    """

    schedule: List[Tuple[int, Dict[str, float]]] = None

    def __call__(self, trainer: Trainer):
        assert (
            self in trainer._start_of_epoch_callbacks
        ), "must be start not end of epoch"
        # user-facing 1-based indexing of epochs rather than internal zero-based
        iepoch: int = trainer.iepoch + 1
        if iepoch < 1:  # initial validation epoch is 0 in user-facing indexing
            return
        loss_function: Loss = trainer.loss

        assert self.schedule is not None
        schedule_start_epochs = np.asarray([e[0] for e in self.schedule])
        # make sure they are ascending
        assert len(schedule_start_epochs) >= 1
        assert schedule_start_epochs[0] >= 2, "schedule must start at epoch 2 or later"
        assert np.all(
            (schedule_start_epochs[1:] - schedule_start_epochs[:-1]) > 0
        ), "schedule start epochs must be strictly ascending"
        # we are running at the _start_ of the epoch, so we need to apply the right change for the current epoch
        current_change_idx = np.searchsorted(schedule_start_epochs, iepoch + 1) - 1
        # ^ searchsorted 3 in [2, 10, 19] would return 1, for example
        # but searching 2 in [2, 10, 19] gives 0, so we actually search iepoch + 1 to always be ahead of the start
        # apply the current change to handle restarts
        if current_change_idx >= 0:
            new_coeffs = self.schedule[current_change_idx][1]
            assert (
                loss_function.coeffs.keys() == new_coeffs.keys()
            ), "all coeff schedules must contain all loss terms"
            loss_function.coeffs.update(new_coeffs)
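
The searchsorted arithmetic is the subtle part: because the callback runs at the start of every epoch and re-applies whichever coefficient set is currently active, restarts are handled for free. A minimal standalone sketch of the index selection (plain numpy, hypothetical start epochs):

    import numpy as np

    # Hypothetical schedule start epochs (1-based), strictly ascending.
    starts = np.asarray([2, 10, 19])

    for iepoch in [1, 2, 3, 10, 25]:
        # Same arithmetic as the callback: searching iepoch + 1 makes an epoch
        # that equals a start epoch select that entry (searching iepoch itself
        # would land one entry too early when they are equal).
        idx = np.searchsorted(starts, iepoch + 1) - 1
        print(iepoch, idx)  # 1 -> -1, 2 -> 0, 3 -> 0, 10 -> 1, 25 -> 2

An index of -1 means no schedule entry applies yet, which is exactly what the `if current_change_idx >= 0` guard checks.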
nequip/train/trainer.py (6 additions, 0 deletions)
@@ -258,6 +258,7 @@ def __init__(
        val_idcs: Optional[list] = None,
        train_val_split: str = "random",
        init_callbacks: list = [],
+       start_of_epoch_callbacks: list = [],
        end_of_epoch_callbacks: list = [],
        end_of_batch_callbacks: list = [],
        end_of_train_callbacks: list = [],
@@ -348,6 +349,9 @@

        # load all callbacks
        self._init_callbacks = [load_callable(callback) for callback in init_callbacks]
+       self._start_of_epoch_callbacks = [
+           load_callable(callback) for callback in start_of_epoch_callbacks
+       ]
        self._end_of_epoch_callbacks = [
            load_callable(callback) for callback in end_of_epoch_callbacks
        ]
@@ -887,6 +891,8 @@ def reset_metrics(self):
        self.metrics.to(self.torch_device)

    def epoch_step(self):
+       for callback in self._start_of_epoch_callbacks:
+           callback(self)

        dataloaders = {TRAIN: self.dl_train, VALIDATION: self.dl_val}
        categories = [TRAIN, VALIDATION] if self.iepoch >= 0 else [VALIDATION]
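
Since epoch_step now invokes each start-of-epoch callback with the trainer itself as the only argument, any callable with that signature can be registered, not just SimpleLossSchedule. A toy sketch of a custom callback (hypothetical, not part of this commit):

    class PrintEpochCallback:
        # Any callable taking the trainer works as a start-of-epoch callback.
        def __call__(self, trainer) -> None:
            # trainer.iepoch is zero-based internally; +1 gives the
            # user-facing epoch number.
            print(f"Starting epoch {trainer.iepoch + 1}")

    # e.g. Trainer(..., start_of_epoch_callbacks=[PrintEpochCallback()])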
