Skip to content

Commit

Permalink
Fixed cifar 10 tests
Browse files Browse the repository at this point in the history
  • Loading branch information
k-w-w committed Nov 7, 2017
1 parent e5f88ad commit 807d6bd
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 10 deletions.
24 changes: 15 additions & 9 deletions official/resnet/cifar10_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ def get_filenames(is_training, data_dir):
return [os.path.join(data_dir, 'test_batch.bin')]


def parse_and_preprocess_record(raw_record, is_training):
"""Parse and preprocess a CIFAR-10 image and label from a raw record."""
def parse_record(raw_record):
"""Parse CIFAR-10 image and label from a raw record."""
# Every record consists of a label followed by the image, with a fixed number
# of bytes for each.
label_bytes = 1
Expand All @@ -120,12 +120,6 @@ def parse_and_preprocess_record(raw_record, is_training):
# float32.
image = tf.cast(tf.transpose(depth_major, [1, 2, 0]), tf.float32)

if is_training:
image = train_preprocess_fn(image)

# Subtract off the mean and divide by the variance of the pixels.
image = tf.image.per_image_standardization(image)

return image, tf.one_hot(label, _NUM_CLASSES)


Expand All @@ -143,6 +137,18 @@ def train_preprocess_fn(image):
return image


def parse_and_preprocess(record, is_training):
"""Parse and preprocess records in the CIFAR-10 dataset."""
image, label = parse_record(record)

if is_training:
image = train_preprocess_fn(image)

# Subtract off the mean and divide by the variance of the pixels.
image = tf.image.per_image_standardization(image)
return image, label


def input_fn(is_training, data_dir, batch_size, num_epochs=1):
"""Input_fn using the tf.data input pipeline for CIFAR-10 dataset.
Expand All @@ -163,7 +169,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)

dataset = dataset.map(
lambda record: parse_and_preprocess_record(record, is_training))
lambda record: parse_and_preprocess(record, is_training))
dataset = dataset.prefetch(2 * batch_size)

# We call repeat after shuffling, rather than before, to prevent separate
Expand Down
2 changes: 1 addition & 1 deletion official/resnet/cifar10_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_dataset_input_fn(self):
data_file.close()

fake_dataset = cifar10_main.record_dataset(filename)
fake_dataset = fake_dataset.map(cifar10_main.dataset_parser)
fake_dataset = fake_dataset.map(cifar10_main.parse_record)
image, label = fake_dataset.make_one_shot_iterator().get_next()

self.assertEqual(label.get_shape().as_list(), [10])
Expand Down

0 comments on commit 807d6bd

Please sign in to comment.