diff --git a/mace/tools/train.py b/mace/tools/train.py index 32231acf..441df428 100644 --- a/mace/tools/train.py +++ b/mace/tools/train.py @@ -161,7 +161,7 @@ def train( ) valid_loss += valid_loss_head valid_err_log( - valid_loss, eval_metrics, logger, log_errors, None, valid_loader_name + valid_loss_head, eval_metrics, logger, log_errors, None, valid_loader_name ) while epoch < max_num_epochs: @@ -224,7 +224,7 @@ def train( ) valid_loss += valid_loss_head valid_err_log( - valid_loss, + valid_loss_head, eval_metrics, logger, log_errors,