Skip to content

Commit

Permalink
Fix omniglot case to reflect tfds interface.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 297126345
  • Loading branch information
drao2 authored and diegolascasas committed Feb 25, 2020
1 parent d0efbec commit 16a11d4
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions curl/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,14 +193,14 @@ def preprocess_data(x):
name=dataset, split=tfds.Split.VALIDATION, **dataset_kwargs)
num_valid_examples = ds_info.splits[tfds.Split.VALIDATION].num_examples
assert (num_valid_examples %
test_batch_size == 0), ('test_batch_size must be a multiple of %d' %
test_batch_size == 0), ('test_batch_size must be a divisor of %d' %
num_valid_examples)
valid_dataset = valid_dataset.repeat(1).batch(
test_batch_size, drop_remainder=True)
valid_dataset = valid_dataset.map(preprocess_data)
valid_iter = valid_dataset.make_initializable_iterator()
valid_data = valid_iter.get_next()
except KeyError:
except (KeyError, ValueError):
logging.warning('No validation set!!')
valid_iter = None
valid_data = None
Expand All @@ -210,7 +210,7 @@ def preprocess_data(x):
name=dataset, split=tfds.Split.TEST, **dataset_kwargs)
num_test_examples = ds_info.splits['test'].num_examples
assert (num_test_examples %
test_batch_size == 0), ('test_batch_size must be a multiple of %d' %
test_batch_size == 0), ('test_batch_size must be a divisor of %d' %
num_test_examples)
test_dataset = test_dataset.repeat(1).batch(
test_batch_size, drop_remainder=True)
Expand Down Expand Up @@ -542,8 +542,8 @@ def run_training(
label_key = 'label'
elif dataset == 'omniglot':
batch_size = 15
test_batch_size = 8115
dataset_kwargs = {'split': 'instance', 'label': 'alphabet'}
test_batch_size = 1318
dataset_kwargs = {}
image_key = 'image'
label_key = 'alphabet'
else:
Expand Down

0 comments on commit 16a11d4

Please sign in to comment.