Skip to content

Commit

Permalink
When printing validation loss during training, pass loss for head, no…
Browse files Browse the repository at this point in the history
…t partial sum over heads calculated so far
  • Loading branch information
bernstei committed Jun 6, 2024
1 parent a92dba2 commit 9616a41
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions mace/tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def train(
)
valid_loss += valid_loss_head
valid_err_log(
valid_loss, eval_metrics, logger, log_errors, None, valid_loader_name
valid_loss_head, eval_metrics, logger, log_errors, None, valid_loader_name
)

while epoch < max_num_epochs:
Expand Down Expand Up @@ -224,7 +224,7 @@ def train(
)
valid_loss += valid_loss_head
valid_err_log(
valid_loss,
valid_loss_head,
eval_metrics,
logger,
log_errors,
Expand Down

0 comments on commit 9616a41

Please sign in to comment.