Merge trainer revamp

Linux-cpp-lisp authored Jul 10, 2024
2 parents d3a7763 + 145e138 commit d18dd2a
Showing 11 changed files with 270 additions and 77 deletions.
11 changes: 10 additions & 1 deletion CHANGELOG.md
@@ -7,8 +7,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
Most recent change on the bottom.


## Unreleased
## Unreleased - 0.7.0
### Added
- `--override` now supported as a `nequip-train` flag (similar to its use in `nequip-deploy`)
- add SoftAdapt (https://arxiv.org/abs/2403.18122) callback option

### Changed
- [Breaking] training restart behavior altered: file-wise consistency checks performed between original config and config passed to `nequip-train` on restart (instead of checking the config dicts)
- [Breaking] config format for callbacks changed (see `configs/full.yaml` for an example)

### Fixed
- fixed `wandb_watch` bug

## [0.6.1] - 2024-7-9
### Added
15 changes: 11 additions & 4 deletions configs/full.yaml
@@ -256,10 +256,17 @@ loss_coeffs:
# In the "schedule" key each entry is a two-element list of:
# - the 1-based epoch index at which to start the new loss coefficients
# - the new loss coefficients as a dict
#
# start_of_epoch_callbacks:
# - !!python/object:nequip.train.callbacks.loss_schedule.SimpleLossSchedule {"schedule": [[2, {"forces": 0.0, "total_energy": 1.0}]]}
#
# callbacks:
# start_of_epoch:
# - !!python/object:nequip.train.callbacks.SimpleLossSchedule {"schedule": [[2, {"forces": 0.0, "total_energy": 1.0}]]}

# You can also try using the SoftAdapt strategy for adaptively changing loss coefficients
# (see https://arxiv.org/abs/2403.18122)
#callbacks:
# end_of_batch:
# - !!python/object:nequip.train.callbacks.SoftAdapt {"batches_per_update": 5, "beta": 1.1}



# output metrics
metrics_components:
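The `!!python/object:` tag used in these commented examples tells PyYAML to instantiate the named callback class directly when the config is loaded with the full (unsafe) loader. A minimal sketch of that mechanism, assuming PyYAML is available and `nequip` is installed so that `nequip.train.callbacks.SoftAdapt` is importable:

```python
import yaml

snippet = """
callbacks:
  end_of_batch:
    - !!python/object:nequip.train.callbacks.SoftAdapt {"batches_per_update": 5, "beta": 1.1}
"""

# yaml.Loader (unlike yaml.SafeLoader) honors the !!python/object: tag and
# builds the SoftAdapt dataclass with the given fields set on it.
config = yaml.load(snippet, Loader=yaml.Loader)
callback = config["callbacks"]["end_of_batch"][0]
print(type(callback).__name__, callback.batches_per_update, callback.beta)  # SoftAdapt 5 1.1
```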
100 changes: 73 additions & 27 deletions nequip/scripts/train.py
@@ -3,6 +3,9 @@
import logging
import argparse
import warnings
import shutil
import difflib
import yaml

# This is a weird hack to avoid Intel MKL issues on the cluster when this is called as a subprocess of a process that has itself initialized PyTorch.
# Since numpy gets imported later anyway for dataset stuff, this shouldn't affect performance.
@@ -29,6 +32,8 @@
root="./",
tensorboard=False,
wandb=False,
wandb_watch=False,
wandb_watch_kwargs={},
model_builders=[
"SimpleIrrepsConfig",
"EnergyModel",
@@ -46,7 +51,7 @@
equivariance_test=False,
grad_anomaly_mode=False,
gpu_oom_offload=False,
append=False,
append=True,
warn_unused=False,
_jit_bailout_depth=2, # avoid 20 iters of pain, see https://github.com/pytorch/pytorch/issues/52286
# Quote from eelison in PyTorch slack:
@@ -68,32 +73,61 @@


def main(args=None, running_as_script: bool = True):
config = parse_command_line(args)
config, path_to_config, override_options = parse_command_line(args)

if running_as_script:
set_up_script_logger(config.get("log", None), config.verbose)

found_restart_file = exists(f"{config.root}/{config.run_name}/trainer.pth")
train_dir = f"{config.root}/{config.run_name}"
found_restart_file = exists(f"{train_dir}/trainer.pth")
if found_restart_file and not config.append:
raise RuntimeError(
f"Training instance exists at {config.root}/{config.run_name}; "
f"Training instance exists at {train_dir}; "
"either set append to True or use a different root or runname"
)
elif not found_restart_file and isdir(f"{config.root}/{config.run_name}"):
elif not found_restart_file and isdir(train_dir):
# output directory exists but no ``trainer.pth`` file, suggesting the previous run crashed during
# the first training epoch (usually due to memory):
warnings.warn(
f"Previous run folder at {config.root}/{config.run_name} exists, but a saved model "
f"Previous run folder at {train_dir} exists, but a saved model "
f"(trainer.pth file) was not found. This folder will be cleared and a fresh training run will "
f"be started."
)
rmtree(f"{config.root}/{config.run_name}")
rmtree(train_dir)

# for fresh new train
if not found_restart_file:
if not found_restart_file: # fresh start
# update config with override parameters for setting up train-dir
config.update(override_options)
trainer = fresh_start(config)
else:
trainer = restart(config)
# copy original config to training directory
shutil.copyfile(path_to_config, f"{train_dir}/original_config.yaml")
else: # restart
# perform string matching for original config and restart config
# throw error if they are different
with (
open(f"{train_dir}/original_config.yaml") as orig_f,
open(path_to_config) as current_f,
):
diffs = [
x
for x in difflib.Differ().compare(
orig_f.readlines(), current_f.readlines()
)
if x[0] in ("+", "-")
]
if diffs:
raise RuntimeError(
f"Config {path_to_config} used for restart differs from original config for training run in {train_dir}.\n"
+ "The following differences were found:\n\n"
+ "".join(diffs)
+ "\n"
+ "If you intend to override the original config parameters, use the --override flag. For example, use\n"
+ f'`nequip-train {path_to_config} --override "max_epochs: 42"`\n'
+ 'on the command line to override the config parameter "max_epochs"\n'
+ "BE WARNED that use of the --override flag is not protected by consistency checks performed by NequIP."
)
else:
trainer = restart(config, override_options)

# Train
trainer.save()
@@ -157,6 +191,12 @@ def parse_command_line(args=None):
help="Warn instead of error when the config contains unused keys",
action="store_true",
)
parser.add_argument(
"--override",
help="Override top-level configuration keys from the `--train-dir`/`--model`'s config YAML file. This should be a valid YAML string. Unless you know why you need to, do not use this option.",
type=str,
default=None,
)
args = parser.parse_args(args=args)

config = Config.from_file(args.config, defaults=default_config)
@@ -169,10 +209,26 @@
):
config[flag] = getattr(args, flag) or config[flag]

return config
# Set override options before _set_global_options so that things like allow_tf32 are correctly handled
if args.override is not None:
override_options = yaml.load(args.override, Loader=yaml.Loader)
assert isinstance(
override_options, dict
), "--override's YAML string must define a dictionary of top-level options"
overridden_keys = set(config.keys()).intersection(override_options.keys())
set_keys = set(override_options.keys()) - set(overridden_keys)
logging.info(
f"--override: overrode keys {list(overridden_keys)} and set new keys {list(set_keys)}"
)
del overridden_keys, set_keys
else:
override_options = {}

return config, args.config, override_options


def fresh_start(config):

# we use add_to_config because it's a fresh start and we need to record it
check_code_version(config, add_to_config=True)
_set_global_options(config)
Expand Down Expand Up @@ -267,7 +323,7 @@ def _unused_check():
return trainer


def restart(config):
def restart(config, override_options):
# load the dictionary
restart_file = f"{config.root}/{config.run_name}/trainer.pth"
dictionary = load_file(
@@ -276,20 +332,6 @@ def restart(config):
enforced_format="torch",
)

# compare dictionary to config and update stop condition related arguments
for k in config.keys():
if config[k] != dictionary.get(k, ""):
if k == "max_epochs":
dictionary[k] = config[k]
logging.info(f'Update "{k}" to {dictionary[k]}')
elif k.startswith("early_stop"):
dictionary[k] = config[k]
logging.info(f'Update "{k}" to {dictionary[k]}')
elif isinstance(config[k], type(dictionary.get(k, ""))):
raise ValueError(
f'Key "{k}" is different in config and the result trainer.pth file. Please double check'
)

# note, "trainer.pth"/dictionary also store code versions,
# which will not be stored in config and thus not checked here
check_code_version(config)
@@ -299,6 +341,10 @@ def restart(config):

config = Config(dictionary, exclude_keys=["state_dict", "progress"])

# override configs loaded from save
dictionary.update(override_options)
config.update(override_options)

# dtype, etc.
_set_global_options(config)

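The `--override` path above boils down to parsing the flag's YAML string into a dict of top-level options and merging it into the config (on fresh starts before `fresh_start`, on restarts on top of the dictionary loaded from `trainer.pth`). A minimal standalone restatement of that logic, using a made-up config dict for illustration:

```python
import yaml

override_str = "max_epochs: 42"  # as passed on the command line via --override "max_epochs: 42"

override_options = yaml.load(override_str, Loader=yaml.Loader)
assert isinstance(override_options, dict), "--override must be a YAML mapping of top-level options"

# stand-in for the loaded training config
config = {"root": "./results", "run_name": "example", "max_epochs": 10}

overridden_keys = set(config.keys()).intersection(override_options.keys())
new_keys = set(override_options.keys()) - overridden_keys
config.update(override_options)

print(sorted(overridden_keys), sorted(new_keys), config["max_epochs"])  # ['max_epochs'] [] 42
```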
49 changes: 49 additions & 0 deletions nequip/train/callback_manager.py
@@ -0,0 +1,49 @@
from nequip.utils import load_callable
import dataclasses


class CallbackManager:
"""Parent callback class
Centralized object to manage various callbacks that can be added-on.
"""

def __init__(
self,
callbacks={},
):
CALLBACK_TYPES = [
"init",
"start_of_epoch",
"end_of_epoch",
"end_of_batch",
"end_of_train",
"final",
]
# load all callbacks
self.callbacks = {callback_type: [] for callback_type in CALLBACK_TYPES}

for callback_type in callbacks:
if callback_type not in CALLBACK_TYPES:
raise ValueError(
f"{callback_type} is not a supported callback type.\nSupported callback types include "
+ str(CALLBACK_TYPES)
)
# make sure callbacks are either dataclasses or functions
for callback in callbacks[callback_type]:
if not (dataclasses.is_dataclass(callback) or callable(callback)):
raise ValueError(
f"Callbacks must be of type dataclass or callable. Error found on the callback {callback} of type {callback_type}"
)
self.callbacks[callback_type].append(load_callable(callback))

def apply(self, trainer, callback_type: str):

for callback in self.callbacks.get(callback_type):
callback(trainer)

def state_dict(self):
return {"callback_manager_obj_callbacks": self.callbacks}

def load_state_dict(self, state_dict):
self.callbacks = state_dict.get("callback_manager_obj_callbacks")
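A hypothetical usage sketch of the manager, registering a plain function (rather than a YAML-constructed dataclass) as an end-of-epoch callback. The trainer is stubbed out with a fake object here, and the import path below assumes the module location of this new file rather than any re-export from `nequip.train`:

```python
from nequip.train.callback_manager import CallbackManager

def print_epoch(trainer):
    # any callable taking the trainer is accepted
    print("finished epoch", trainer.iepoch)

manager = CallbackManager(callbacks={"end_of_epoch": [print_epoch]})

class FakeTrainer:  # stand-in for nequip.train.Trainer, for illustration only
    iepoch = 3

manager.apply(FakeTrainer(), callback_type="end_of_epoch")  # -> "finished epoch 3"
```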
4 changes: 4 additions & 0 deletions nequip/train/callbacks/__init__.py
@@ -0,0 +1,4 @@
from .adaptive_loss_weights import SoftAdapt
from .loss_schedule import SimpleLossSchedule

__all__ = ["SoftAdapt", "SimpleLossSchedule"]
78 changes: 78 additions & 0 deletions nequip/train/callbacks/adaptive_loss_weights.py
@@ -0,0 +1,78 @@
from dataclasses import dataclass

from nequip.train import Trainer

from nequip.train._key import ABBREV
import torch

# Making this a dataclass takes care of equality operators, handling restart consistency checks


@dataclass
class SoftAdapt:
"""Adaptively modify `loss_coeffs` through a training run using the SoftAdapt scheme (https://arxiv.org/abs/2403.18122)
To use this in a training run, set in your YAML file:
callbacks:
  end_of_batch:
    - !!python/object:nequip.train.callbacks.SoftAdapt {"batches_per_update": 20, "beta": 1.0}
This funny syntax tells PyYAML to construct an object of this class.
Main hyperparameters are:
- how often the loss weights are updated, `batches_per_update`
- how sensitive the new loss weights are to the change in loss components, `beta`
"""

# user-facing parameters
batches_per_update: int = None
beta: float = None
eps: float = 1e-8 # small epsilon to avoid division by zero
# attributes for internal tracking
batch_counter: int = -1
prev_losses: torch.Tensor = None
cached_weights = None

def __call__(self, trainer: Trainer):

# --- CORRECTNESS CHECKS ---
assert self in trainer.callback_manager.callbacks["end_of_batch"]
assert self.batches_per_update >= 1

# track batch number
self.batch_counter += 1

# empty list of cached weights to store for next cycle
if self.batch_counter % self.batches_per_update == 0:
self.cached_weights = []

# --- MAIN LOGIC THAT RUNS EVERY EPOCH ---

# collect loss for each training target
losses = []
for key in trainer.loss.coeffs.keys():
losses.append(trainer.batch_losses[f"loss_{ABBREV.get(key)}"])
new_losses = torch.tensor(losses)

# compute and cache new loss weights over the update cycle
if self.prev_losses is None:
self.prev_losses = new_losses
return
else:
# compute normalized loss change
loss_change = new_losses - self.prev_losses
loss_change = torch.nn.functional.normalize(
loss_change, dim=0, eps=self.eps
)
self.prev_losses = new_losses
# compute new weights with softmax
exps = torch.exp(self.beta * loss_change)
self.cached_weights.append(exps.div(exps.sum() + self.eps))

# --- average weights over previous cycle and update ---
if self.batch_counter % self.batches_per_update == 1:
softadapt_weights = torch.stack(self.cached_weights, dim=-1).mean(-1)
counter = 0
for key in trainer.loss.coeffs.keys():
trainer.loss.coeffs.update({key: softadapt_weights[counter]})
counter += 1
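Outside of a trainer, the weight update implemented above reduces to a softmax over the normalized change in the per-target losses, averaged over each `batches_per_update` cycle before being written back into `loss_coeffs`. A standalone numeric sketch of a single update step (illustrative values only):

```python
import torch

beta, eps = 1.0, 1e-8
prev_losses = torch.tensor([0.50, 0.20])  # e.g. forces and total_energy losses from the previous batch
new_losses = torch.tensor([0.45, 0.25])

# normalized change in the per-target losses
loss_change = torch.nn.functional.normalize(new_losses - prev_losses, dim=0, eps=eps)

# softmax-style reweighting: targets whose loss is rising get larger coefficients
exps = torch.exp(beta * loss_change)
new_weights = exps / (exps.sum() + eps)
print(new_weights)  # approximately tensor([0.196, 0.804])
```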
3 changes: 2 additions & 1 deletion nequip/train/callbacks/loss_schedule.py
@@ -25,9 +25,10 @@ class SimpleLossSchedule:

def __call__(self, trainer: Trainer):
assert (
self in trainer._start_of_epoch_callbacks
self in trainer.callback_manager.callbacks["start_of_epoch"]
), "must be start not end of epoch"
# user-facing 1 based indexing of epochs rather than internal zero based

iepoch: int = trainer.iepoch + 1
if iepoch < 1: # initial validation epoch is 0 in user-facing indexing
return
Expand Down
9 changes: 9 additions & 0 deletions nequip/train/loss.py
@@ -109,6 +109,15 @@ def __call__(self, pred: dict, ref: dict):

return loss, contrib

def state_dict(self):
# verbose key names to avoid repetition/clashes
dictionary = {"loss_obj_coeffs": self.coeffs}
return dictionary

def load_state_dict(self, state_dict):
# only need to save/load loss weights (or coefficients)
self.coeffs = state_dict.get("loss_obj_coeffs")


class LossStat:
"""