-
Notifications
You must be signed in to change notification settings - Fork 142
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
270 additions
and
77 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
from nequip.utils import load_callable | ||
import dataclasses | ||
|
||
|
||
class CallbackManager:
    """Centralized manager for trainer callbacks.

    Holds lists of callables keyed by the phase of training at which they
    fire and invokes them on the trainer via :meth:`apply`.

    Args:
        callbacks (dict, optional): maps a callback type (one of
            ``CALLBACK_TYPES``) to a list of callbacks, each of which must
            be a dataclass or a callable.  Defaults to no callbacks.

    Raises:
        ValueError: if a key of ``callbacks`` is not a supported callback
            type, or if a callback is neither a dataclass nor callable.
    """

    # Supported hook points; a class-level constant so it is built once and
    # shared, rather than recreated on every construction.
    CALLBACK_TYPES = [
        "init",
        "start_of_epoch",
        "end_of_epoch",
        "end_of_batch",
        "end_of_train",
        "final",
    ]

    def __init__(
        self,
        callbacks=None,
    ):
        # `None` default avoids the shared-mutable-default-argument pitfall;
        # passing a dict (the original signature) still works unchanged.
        callbacks = {} if callbacks is None else callbacks

        # load all callbacks
        self.callbacks = {callback_type: [] for callback_type in self.CALLBACK_TYPES}

        for callback_type in callbacks:
            if callback_type not in self.CALLBACK_TYPES:
                raise ValueError(
                    f"{callback_type} is not a supported callback type.\nSupported callback types include "
                    + str(self.CALLBACK_TYPES)
                )
            # make sure callbacks are either dataclasses or functions
            for callback in callbacks[callback_type]:
                if not (dataclasses.is_dataclass(callback) or callable(callback)):
                    raise ValueError(
                        f"Callbacks must be of type dataclass or callable. Error found on the callback {callback} of type {callback_type}"
                    )
                self.callbacks[callback_type].append(load_callable(callback))

    def apply(self, trainer, callback_type: str):
        """Run every callback registered for ``callback_type`` on ``trainer``.

        Raises:
            KeyError: if ``callback_type`` is not a supported type.  (The
            original ``.get(...)`` returned ``None`` here and then failed
            with an opaque ``TypeError`` on iteration.)
        """
        for callback in self.callbacks[callback_type]:
            callback(trainer)

    def state_dict(self):
        """Return a checkpointable representation of the registered callbacks."""
        return {"callback_manager_obj_callbacks": self.callbacks}

    def load_state_dict(self, state_dict):
        """Restore callbacks from a dict produced by :meth:`state_dict`.

        Raises:
            KeyError: if the expected key is missing.  (The original ``.get``
            silently set ``self.callbacks = None`` on a malformed checkpoint,
            deferring the failure to a confusing later crash.)
        """
        self.callbacks = state_dict["callback_manager_obj_callbacks"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .adaptive_loss_weights import SoftAdapt
from .loss_schedule import SimpleLossSchedule

# `__all__` must be a list of *name strings*, not the objects themselves:
# listing the classes breaks tooling that expects strings (e.g. `help()`,
# linters, and the documented `from ... import *` contract).
__all__ = ["SoftAdapt", "SimpleLossSchedule"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
from dataclasses import dataclass | ||
|
||
from nequip.train import Trainer | ||
|
||
from nequip.train._key import ABBREV | ||
import torch | ||
|
||
# Making this a dataclass takes care of equality operators, handing restart consistency checks | ||
|
||
|
||
@dataclass
class SoftAdapt:
    """Adaptively modify `loss_coeffs` through a training run using the SoftAdapt scheme (https://arxiv.org/abs/2403.18122).

    To use this in a training, set in your YAML file:

        end_of_batch_callbacks:
          - !!python/object:nequip.train.callbacks.adaptive_loss_weights.SoftAdapt {"batches_per_update": 20, "beta": 1.0}

    This funny syntax tells PyYAML to construct an object of this class.
    (Being a dataclass gives us the equality operators used for restart
    consistency checks.)

    Main hyperparameters are:
    - how often the loss weights are updated, `batches_per_update`
    - how sensitive the new loss weights are to the change in loss components, `beta`
    """

    # user-facing parameters
    batches_per_update: int = None  # update period in batches; must be >= 1 when called
    beta: float = None  # softmax temperature: larger beta -> weights react more sharply
    eps: float = 1e-8  # small epsilon to avoid division by zero
    # attributes for internal tracking
    batch_counter: int = -1  # incremented at the start of every call; first call sees 0
    prev_losses: torch.Tensor = None  # per-target losses from the previous batch
    # No annotation on purpose: this stays a plain class attribute, NOT a
    # dataclass field, so it is excluded from __eq__/__init__.
    cached_weights = None

    def __call__(self, trainer: Trainer):
        """End-of-batch hook: track per-target losses and periodically rewrite
        `trainer.loss.coeffs` with softmax-normalized weights derived from the
        normalized change in each loss component."""

        # --- CORRECTNESS CHECKS ---
        # NOTE(review): `assert` is stripped under `python -O`; these are
        # sanity checks, not input validation.
        assert self in trainer.callback_manager.callbacks["end_of_batch"]
        assert self.batches_per_update >= 1

        # track batch number
        self.batch_counter += 1

        # empty list of cached weights to store for next cycle
        if self.batch_counter % self.batches_per_update == 0:
            self.cached_weights = []

        # --- MAIN LOGIC THAT RUNS EVERY EPOCH ---

        # collect loss for each training target, in the (stable) order of
        # `trainer.loss.coeffs` so indices line up with the update loop below
        losses = []
        for key in trainer.loss.coeffs.keys():
            losses.append(trainer.batch_losses[f"loss_{ABBREV.get(key)}"])
        new_losses = torch.tensor(losses)

        # compute and cache new loss weights over the update cycle
        if self.prev_losses is None:
            # first ever batch: nothing to diff against yet, just remember
            self.prev_losses = new_losses
            return
        else:
            # compute normalized loss change (unit vector of per-target deltas)
            loss_change = new_losses - self.prev_losses
            loss_change = torch.nn.functional.normalize(
                loss_change, dim=0, eps=self.eps
            )
            self.prev_losses = new_losses
            # compute new weights with softmax
            exps = torch.exp(self.beta * loss_change)
            self.cached_weights.append(exps.div(exps.sum() + self.eps))

        # --- average weights over previous cycle and update ---
        # NOTE(review): with batches_per_update == 1 this condition
        # (counter % 1 == 1) is never true, so the coefficients are never
        # updated — confirm whether N == 1 is meant to be supported.
        if self.batch_counter % self.batches_per_update == 1:
            softadapt_weights = torch.stack(self.cached_weights, dim=-1).mean(-1)
            counter = 0
            # overwrite each coefficient in place, relying on dict insertion
            # order matching the collection loop above
            for key in trainer.loss.coeffs.keys():
                trainer.loss.coeffs.update({key: softadapt_weights[counter]})
                counter += 1
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.