From 16a11d4c78be044e612a05d51a0e228a2dbe16ec Mon Sep 17 00:00:00 2001 From: Dushyant Rao Date: Tue, 25 Feb 2020 16:50:29 +0000 Subject: [PATCH] Fix omniglot case to reflect tfds interface. PiperOrigin-RevId: 297126345 --- curl/training.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/curl/training.py b/curl/training.py index 37cddbdb..0cce039b 100644 --- a/curl/training.py +++ b/curl/training.py @@ -193,14 +193,14 @@ def preprocess_data(x): name=dataset, split=tfds.Split.VALIDATION, **dataset_kwargs) num_valid_examples = ds_info.splits[tfds.Split.VALIDATION].num_examples assert (num_valid_examples % - test_batch_size == 0), ('test_batch_size must be a multiple of %d' % + test_batch_size == 0), ('test_batch_size must be a divisor of %d' % num_valid_examples) valid_dataset = valid_dataset.repeat(1).batch( test_batch_size, drop_remainder=True) valid_dataset = valid_dataset.map(preprocess_data) valid_iter = valid_dataset.make_initializable_iterator() valid_data = valid_iter.get_next() - except KeyError: + except (KeyError, ValueError): logging.warning('No validation set!!') valid_iter = None valid_data = None @@ -210,7 +210,7 @@ def preprocess_data(x): name=dataset, split=tfds.Split.TEST, **dataset_kwargs) num_test_examples = ds_info.splits['test'].num_examples assert (num_test_examples % - test_batch_size == 0), ('test_batch_size must be a multiple of %d' % + test_batch_size == 0), ('test_batch_size must be a divisor of %d' % num_test_examples) test_dataset = test_dataset.repeat(1).batch( test_batch_size, drop_remainder=True) @@ -542,8 +542,8 @@ def run_training( label_key = 'label' elif dataset == 'omniglot': batch_size = 15 - test_batch_size = 8115 - dataset_kwargs = {'split': 'instance', 'label': 'alphabet'} + test_batch_size = 1318 + dataset_kwargs = {} image_key = 'image' label_key = 'alphabet' else: