Skip to content

Commit

Permalink
Cherrypick PR#2878 into r1.8 (#2880)
Browse files Browse the repository at this point in the history
  • Loading branch information
jysohn23 authored Apr 14, 2021
1 parent 66e6cd1 commit f2f8f44
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions test/test_train_mp_imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
'--lr_scheduler_divisor': {
'type': int,
},
'--test_only_at_end': {
'action': 'store_true',
},
}

FLAGS = args_parse.parse_common_options(
Expand Down Expand Up @@ -236,15 +239,16 @@ def test_loop_fn(loader, epoch):
xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now()))
train_loop_fn(train_device_loader, epoch)
xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now()))
accuracy = test_loop_fn(test_device_loader, epoch)
xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(
epoch, test_utils.now(), accuracy))
max_accuracy = max(accuracy, max_accuracy)
test_utils.write_to_summary(
writer,
epoch,
dict_to_write={'Accuracy/test': accuracy},
write_xla_metrics=True)
if not FLAGS.test_only_at_end or epoch == FLAGS.num_epochs:
accuracy = test_loop_fn(test_device_loader, epoch)
xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(
epoch, test_utils.now(), accuracy))
max_accuracy = max(accuracy, max_accuracy)
test_utils.write_to_summary(
writer,
epoch,
dict_to_write={'Accuracy/test': accuracy},
write_xla_metrics=True)
if FLAGS.metrics_debug:
xm.master_print(met.metrics_report())

Expand Down

0 comments on commit f2f8f44

Please sign in to comment.