diff --git a/train.py b/train.py index 00f2b4f..01a45a4 100644 --- a/train.py +++ b/train.py @@ -204,12 +204,12 @@ def test_step(images, targets): test_metric.reset_states() # training loop - for images, targets, true in train_ds: - train_step(images, targets, true) + for images, targets in train_ds: + train_step(images, targets) # test loop - for test_images, test_labels, test_true in test_ds: - test_step(test_images, test_labels, test_true) + for test_images, test_targets in test_ds: + test_step(test_images, test_targets) # log losses and metrics with summary_writer.as_default():