forked from tensorflow/models
-
Notifications
You must be signed in to change notification settings - Fork 0
/
helper_utils.py
149 lines (123 loc) · 5.25 KB
/
helper_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper functions used for training AutoAugment models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
def setup_loss(logits, labels):
  """Builds the softmax predictions and the cross-entropy loss.

  Args:
    logits: Unnormalized class scores; fed both to the softmax and to the
      loss.
    labels: Labels forwarded to the loss as `onehot_labels`.

  Returns:
    A `(predictions, cost)` tuple, where `predictions` is
    `tf.nn.softmax(logits)` and `cost` is the value returned by
    `tf.losses.softmax_cross_entropy`.
  """
  class_probabilities = tf.nn.softmax(logits)
  cross_entropy = tf.losses.softmax_cross_entropy(
      onehot_labels=labels, logits=logits)
  return class_probabilities, cross_entropy
def decay_weights(cost, weight_decay_rate):
  """Adds an L2 weight-decay penalty to `cost`.

  The penalty is the sum of `tf.nn.l2_loss` over every variable returned by
  `tf.trainable_variables()`, scaled by `weight_decay_rate`.

  Args:
    cost: Loss value the penalty is added to.
    weight_decay_rate: Multiplier applied to the summed L2 losses.

  Returns:
    `cost` plus the scaled L2 penalty.
  """
  l2_terms = [tf.nn.l2_loss(weight) for weight in tf.trainable_variables()]
  l2_penalty = tf.multiply(weight_decay_rate, tf.add_n(l2_terms))
  cost += l2_penalty
  return cost
def eval_child_model(session, model, data_loader, mode):
  """Evaluates `model` on the held-out split selected by `mode`.

  The chosen split is fed through `model.eval_op` one batch of
  `model.batch_size` examples at a time; afterwards `model.accuracy` is read
  from the session and returned.
  NOTE(review): presumably `eval_op` updates a streaming metric that
  `model.accuracy` reports -- confirm against the model definition.

  Args:
    session: TensorFlow session the model will be run with.
    model: Model exposing `batch_size`, `images`, `labels`, `eval_op` and
      `accuracy`.
    data_loader: Object exposing `val_images`/`val_labels` and
      `test_images`/`test_labels`.
    mode: Either 'val' or 'test'; selects which split is evaluated.

  Returns:
    The result of running `model.accuracy` after all batches were fed.

  Raises:
    ValueError: if `mode` is neither 'val' nor 'test'.
    AssertionError: if images and labels differ in length, or the split size
      is not a multiple of `model.batch_size`.
  """
  if mode == 'val':
    images, labels = data_loader.val_images, data_loader.val_labels
  elif mode == 'test':
    images, labels = data_loader.test_images, data_loader.test_labels
  else:
    raise ValueError('Not valid eval mode')

  num_examples = len(images)
  assert num_examples == len(labels)

  batch_size = model.batch_size
  tf.logging.info('model.batch_size is {}'.format(batch_size))
  assert num_examples % batch_size == 0

  num_batches = num_examples // batch_size
  for batch_index in range(num_batches):
    start = batch_index * batch_size
    stop = start + batch_size
    batch_feed = {
        model.images: images[start:stop],
        model.labels: labels[start:stop],
    }
    session.run(model.eval_op, feed_dict=batch_feed)

  return session.run(model.accuracy)
def cosine_lr(learning_rate, epoch, iteration, batches_per_epoch, total_epochs):
  """Computes the cosine-annealed learning rate for one training batch.

  The schedule equals `learning_rate` when `epoch` and `iteration` are both 0
  and reaches 0 once `epoch * batches_per_epoch + iteration` equals
  `total_epochs * batches_per_epoch`.

  Args:
    learning_rate: Initial learning rate.
    epoch: Index of the epoch we are currently on.
    iteration: Index of the current batch within this epoch.
    batches_per_epoch: Number of batches in one epoch.
    total_epochs: Total number of epochs being trained for.

  Returns:
    The learning rate to be used for the current batch.
  """
  total_steps = total_epochs * batches_per_epoch
  steps_done = float(epoch * batches_per_epoch + iteration)
  angle = np.pi * steps_done / total_steps
  cosine_factor = 1 + np.cos(angle)
  return 0.5 * learning_rate * cosine_factor
def get_lr(curr_epoch, hparams, iteration=None):
  """Returns the cosine-schedule learning rate for the given training position.

  Args:
    curr_epoch: Epoch index forwarded to `cosine_lr`.
    hparams: Hyperparameters providing `train_size`, `batch_size`, `lr` and
      `num_epochs`.
    iteration: Batch index within the epoch. Must be supplied despite the
      `None` default.

  Returns:
    The learning rate computed by `cosine_lr`.

  Raises:
    AssertionError: if `iteration` is None.
  """
  assert iteration is not None
  steps_in_epoch = int(hparams.train_size / hparams.batch_size)
  return cosine_lr(
      learning_rate=hparams.lr,
      epoch=curr_epoch,
      iteration=iteration,
      batches_per_epoch=steps_in_epoch,
      total_epochs=hparams.num_epochs)
def run_epoch_training(session, model, data_loader, curr_epoch):
  """Runs one epoch of training for the model passed in.

  For every step the learning rate is recomputed with `get_lr`, loaded into
  `model.lr_rate_ph`, and one batch from `data_loader.next_batch()` is fed
  through `model.train_op` and `model.eval_op`.

  Args:
    session: TensorFlow session the model will be run with.
    model: TensorFlow model that will be trained. Must expose `hparams`,
      `global_step`, `lr_rate_ph`, `train_op`, `eval_op`, `accuracy`,
      `images` and `labels`.
    data_loader: DataSet object whose `next_batch()` returns an
      `(images, labels)` pair for one training step.
    curr_epoch: How many epochs of training have been done so far.

  Returns:
    The result of running `model.accuracy` once the epoch has finished.
    NOTE(review): presumably the training accuracy accumulated by `eval_op`
    -- confirm against the model definition.

  Raises:
    AssertionError: if the global step is not at an epoch boundary when the
      function is entered.
  """
  hparams = model.hparams
  steps_per_epoch = int(hparams.train_size / hparams.batch_size)
  tf.logging.info('steps per epoch: {}'.format(steps_per_epoch))

  # Training must resume exactly at an epoch boundary.
  curr_step = session.run(model.global_step)
  assert curr_step % steps_per_epoch == 0

  # Learning rate at the start of the epoch; only used for logging.
  curr_lr = get_lr(curr_epoch, hparams, iteration=0)
  tf.logging.info('lr of {} for epoch {}'.format(curr_lr, curr_epoch))

  # `range` rather than the Python 2-only `xrange`, which raises NameError on
  # Python 3.
  for step in range(steps_per_epoch):
    curr_lr = get_lr(curr_epoch, hparams, iteration=(step + 1))
    # Update the lr rate variable to the current LR.
    model.lr_rate_ph.load(curr_lr, session=session)
    if step % 20 == 0:
      tf.logging.info('Training {}/{}'.format(step, steps_per_epoch))

    train_images, train_labels = data_loader.next_batch()
    # The fetched values are discarded on purpose: binding the fetched
    # global step to `step` would clobber the loop variable.
    session.run(
        [model.train_op, model.global_step, model.eval_op],
        feed_dict={
            model.images: train_images,
            model.labels: train_labels,
        })

  train_accuracy = session.run(model.accuracy)
  tf.logging.info('Train accuracy: {}'.format(train_accuracy))
  return train_accuracy