Merge pull request #420 from ACEsuit/distrib-checkpoint-fix
log errors and handle checkpoint io on rank 0 only
ilyes319 authored May 14, 2024
2 parents 1cc39eb + 68ad31a commit 6d7b5ed
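
The fix follows the usual single-writer pattern for distributed training: every rank computes the validation metrics, but only rank 0 writes logs, wandb entries, and checkpoint files, and torch.distributed.barrier() keeps the other ranks from running ahead of the write. Below is a minimal sketch of that pattern in isolation, assuming an already-initialized torch.distributed process group when distributed=True; end_of_epoch, save_checkpoint, and the checkpoint path are illustrative names, not part of the MACE API.

import torch
import torch.distributed as dist


def save_checkpoint(model: torch.nn.Module, path: str) -> None:
    # Illustrative helper: persist the model weights to disk.
    torch.save(model.state_dict(), path)


def end_of_epoch(model: torch.nn.Module, epoch: int, distributed: bool) -> None:
    rank = dist.get_rank() if distributed else 0
    if rank == 0:
        # Only one process logs and touches the checkpoint file, so the
        # ranks never race on the same path.
        print(f"epoch {epoch}: validation done, writing checkpoint")
        save_checkpoint(model, f"checkpoint_epoch_{epoch}.pt")
    if distributed:
        # Every rank waits here, so none of them starts the next epoch
        # (or tries to load the checkpoint) before rank 0 has written it.
        dist.barrier()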
Showing 1 changed file with 39 additions and 37 deletions.
76 changes: 39 additions & 37 deletions mace/tools/train.py
@@ -209,58 +209,60 @@ def train(
                     output_args=output_args,
                     device=device,
                 )
+            if rank == 0:
                 valid_err_log(
                     valid_loss,
                     eval_metrics,
                     logger,
                     log_errors,
                     epoch,
                 )
-
-            if log_wandb:
-                wandb_log_dict = {
-                    "epoch": epoch,
-                    "valid_loss": valid_loss,
-                    "valid_rmse_e_per_atom": eval_metrics["rmse_e_per_atom"],
-                    "valid_rmse_f": eval_metrics["rmse_f"],
-                }
-                wandb.log(wandb_log_dict)
-
-            if valid_loss >= lowest_loss:
-                patience_counter += 1
-                if patience_counter >= patience and epoch < swa.start:
-                    logging.info(
-                        f"Stopping optimization after {patience_counter} epochs without improvement and starting swa"
-                    )
-                    epoch = swa.start
-                elif patience_counter >= patience and epoch >= swa.start:
-                    logging.info(
-                        f"Stopping optimization after {patience_counter} epochs without improvement"
-                    )
-                    break
-                if save_all_checkpoints:
+                if log_wandb:
+                    wandb_log_dict = {
+                        "epoch": epoch,
+                        "valid_loss": valid_loss,
+                        "valid_rmse_e_per_atom": eval_metrics["rmse_e_per_atom"],
+                        "valid_rmse_f": eval_metrics["rmse_f"],
+                    }
+                    wandb.log(wandb_log_dict)
+
+                if valid_loss >= lowest_loss:
+                    patience_counter += 1
+                    if patience_counter >= patience and epoch < swa.start:
+                        logging.info(
+                            f"Stopping optimization after {patience_counter} epochs without improvement and starting swa"
+                        )
+                        epoch = swa.start
+                    elif patience_counter >= patience and epoch >= swa.start:
+                        logging.info(
+                            f"Stopping optimization after {patience_counter} epochs without improvement"
+                        )
+                        break
+                    if save_all_checkpoints:
+                        param_context = (
+                            ema.average_parameters()
+                            if ema is not None
+                            else nullcontext()
+                        )
+                        with param_context:
+                            checkpoint_handler.save(
+                                state=CheckpointState(model, optimizer, lr_scheduler),
+                                epochs=epoch,
+                                keep_last=True,
+                            )
+                else:
+                    lowest_loss = valid_loss
+                    patience_counter = 0
                     param_context = (
                         ema.average_parameters() if ema is not None else nullcontext()
                     )
                     with param_context:
                         checkpoint_handler.save(
                             state=CheckpointState(model, optimizer, lr_scheduler),
                             epochs=epoch,
-                            keep_last=True,
+                            keep_last=keep_last,
                         )
-            else:
-                lowest_loss = valid_loss
-                patience_counter = 0
-                param_context = (
-                    ema.average_parameters() if ema is not None else nullcontext()
-                )
-                with param_context:
-                    checkpoint_handler.save(
-                        state=CheckpointState(model, optimizer, lr_scheduler),
-                        epochs=epoch,
-                        keep_last=keep_last,
-                    )
-                keep_last = False or save_all_checkpoints
+                    keep_last = False or save_all_checkpoints
         if distributed:
             torch.distributed.barrier()
         epoch += 1
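
Both save calls in the hunk also wrap checkpoint_handler.save() in ema.average_parameters() when an EMA is configured, falling back to nullcontext() otherwise, so the checkpoint holds the averaged weights rather than the raw ones. Below is a minimal sketch of that conditional context-manager pattern, assuming an EMA object that exposes average_parameters() as a context manager (as the torch-ema package does); save_with_averaged_weights and checkpoint_path are illustrative names.

from contextlib import nullcontext

import torch


def save_with_averaged_weights(model, ema, checkpoint_path: str) -> None:
    # With an EMA configured, temporarily swap the averaged weights into the
    # model while saving; with ema=None, nullcontext() makes this a no-op.
    param_context = ema.average_parameters() if ema is not None else nullcontext()
    with param_context:
        torch.save(model.state_dict(), checkpoint_path)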