momentum in sgld
Justin Tan (unimelb) committed Jan 8, 2018
1 parent 64b6cc7 commit d15d1ff
Showing 11 changed files with 147 additions and 19 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -27,7 +27,7 @@ Coming soon...
Both CIFAR-10/CIFAR-100 models are trained with the same hyperparameters and learning rate schedule specified in the original paper. The dataset is subjected to mean-std preprocessing and random rotations and reflections. Convergence when training on both datasets is compared with vanilla SGD and SGD with Nesterov momentum. The accuracy reported is the average of 5 runs with random weight initialization.

Models trained without entropy-SGD are run for 200 epochs, models trained with entropy-SGD are run with L=20 for 10
epochs, with the hyperparameters specified under arXiv 1611.01838, 5.3.
epochs, with the hyperparameters specified as in the CIFAR-10 run in the original paper.
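
The mean-std preprocessing and flip/rotation augmentation mentioned above amount to roughly the following; this is only a sketch, and the channel statistics and rotation range are illustrative placeholders rather than the repository's exact values:

```python
import numpy as np
import tensorflow as tf

# Commonly quoted CIFAR-10 channel statistics; placeholders, not the repo's exact constants.
CIFAR_MEAN = np.array([0.4914, 0.4822, 0.4465], dtype=np.float32)
CIFAR_STD = np.array([0.2470, 0.2435, 0.2616], dtype=np.float32)

def preprocess(image):
    # Mean-std normalisation followed by a random reflection and a small random rotation.
    image = (tf.cast(image, tf.float32) / 255.0 - CIFAR_MEAN) / CIFAR_STD
    image = tf.image.random_flip_left_right(image)
    angle = tf.random_uniform([], minval=-0.26, maxval=0.26)  # ~ +/- 15 degrees, in radians
    return tf.contrib.image.rotate(image, angle)
```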

### CIFAR-10
```
2 changes: 1 addition & 1 deletion config.py
@@ -4,7 +4,7 @@ class config_train(object):
mode = 'beta'
n_layers = 5
num_epochs = 512
batch_size = 256
batch_size = 128
ema_decay = 0.999
learning_rate = 1e-3
n_classes = 10
1 change: 0 additions & 1 deletion data.py
@@ -1,4 +1,3 @@
#!/usr/bin/python3
import tensorflow as tf
import numpy as np
import pandas as pd
8 changes: 4 additions & 4 deletions model.py
@@ -42,8 +42,8 @@ def __init__(self, config, directories, single_infer=False, name='', optimizer='
self.path = tf.placeholder(paths.dtype)
self.example = Data.preprocess_inference(self.path)

self.logits = Network.wrn(self.example, config, self.training_phase)
graph = ResNet(config, self.training_phase)
self.logits = Network.wrn(self.example, config, self.training_phase)
# self.logits = graph.wrn(self.example)

self.pred = tf.argmax(self.logits, 1)
@@ -61,10 +61,10 @@ def __init__(self, config, directories, single_infer=False, name='', optimizer='

if optimizer=='entropy-sgd':
epoch_bounds = [4, 8, 12, 16, 20, 24]
lr_values = [1.0, 2e-1, 4e-2, 8e-3, 1.6e-3, 3.2e-4]
lr_values = [1.0, 2e-1, 4e-2, 8e-3, 1.6e-3, 3.2e-4, 6.4e-5]
else:
epoch_bounds = [60, 120, 160, 200, 220, 240]
lr_values = [1e-1, 2e-2, 4e-3, 8e-4, 1.6e-4, 3.2e-5]
lr_values = [1e-1, 2e-2, 4e-3, 8e-4, 1.6e-4, 3.2e-5, 6.4e-6]

learning_rate = tf.train.piecewise_constant(self.global_step, boundaries=[s*config.steps_per_epoch for s in
epoch_bounds], values=lr_values)
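
For reference, `tf.train.piecewise_constant` expects one more value than boundary (the final value applies once the last boundary is passed), which is why both schedules above now carry seven learning rates for six epoch boundaries. A minimal sketch, with the boundaries written directly in steps rather than `epochs * steps_per_epoch`:

```python
import tensorflow as tf

# len(values) must equal len(boundaries) + 1; the last value is used after the final boundary.
global_step = tf.train.get_or_create_global_step()
boundaries = [60, 120, 160, 200, 220, 240]
values = [1e-1, 2e-2, 4e-3, 8e-4, 1.6e-4, 3.2e-5, 6.4e-6]
learning_rate = tf.train.piecewise_constant(global_step, boundaries, values)
```
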
@@ -77,7 +77,7 @@ def __init__(self, config, directories, single_infer=False, name='', optimizer='
# Ensures that we execute the update_ops before performing the train_step
if optimizer=='entropy-sgd':
opt = EntropySGD(self.iterator, self.training_phase, self.sgld_global_step,
config={'lr':learning_rate, 'gamma':gamma, 'lr_prime':0.1})
config={'lr':learning_rate, 'gamma':gamma, 'lr_prime':0.1, 'momentum':0.9})
self.sgld_op = opt.sgld_opt.minimize(self.cost, global_step=self.sgld_global_step)
self.opt_op = opt.minimize(self.cost, global_step=self.global_step)
elif optimizer=='adam':
11 changes: 8 additions & 3 deletions optimizer.py
@@ -39,6 +39,7 @@ def __init__(self, iterator, training_phase, sgld_global_step, config={},

self._learning_rate = config['lr']
self._gamma = config['gamma']
self._momentum = config['momentum']

# Scalar parameter tensors
self._lr_tensor = None
@@ -55,7 +56,7 @@ def _prepare(self):
name="learning_rate")
self._gamma_tensor = ops.convert_to_tensor(self._gamma,
name="gamma")
self._momentum_tensor = ops.convert_to_tensor(self.config['momentum'],
self._momentum_tensor = ops.convert_to_tensor(self._momentum,
name="momentum")

def _create_slots(self, var_list):
@@ -64,26 +65,30 @@ def _create_slots(self, var_list):
for v in var_list:
wc = self._zeros_slot(v, "wc", self._name)
mu = self._zeros_slot(v, "mu", self._name)
mv = self._zeros_slot(v, "mv", self._name)

def _apply_dense(self, grad, var):
# Apply weight updates
lr_t = math_ops.cast(self._lr_tensor, var.dtype.base_dtype)
gamma_t = math_ops.cast(self._gamma_tensor, var.dtype.base_dtype)
momentum_t = math_ops.cast(self._momentum_tensor, var.dtype.base_dtype)

wc = self.get_slot(var, "wc")
mu = self.get_slot(var, "mu")
mv = self.get_slot(var, "mv")

wc_t = wc.assign(self.sgld_opt.get_slot(var, "wc"))
mu_t = mu.assign(self.sgld_opt.get_slot(var, "mu"))
mv_t = mv.assign(momentum_t*mv + (var-mu_t))

# Reset weights to pre-SGLD state, then execute update
var_reset = state_ops.assign(var, wc_t)

with tf.control_dependencies([var_reset]):
# var_update = state_ops.assign_sub(var, lr_t*gamma_t*(var-mu_t))
var_update = state_ops.assign_sub(var, lr_t*(var-mu_t))
var_update = state_ops.assign_sub(var, lr_t*(var-mu_t)+lr_t*((var-mu_t)+momentum_t*mv_t))

return control_flow_ops.group(*[var_update, mu_t, wc_t])
return control_flow_ops.group(*[var_update, mu_t, wc_t, mv_t])

def _apply_sparse(self, grad, var_list):
raise NotImplementedError("Optimizer does not yet support sparse gradient updates.")
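
The outer entropy-SGD step above keeps the proximal update on `(var - mu)` and adds a heavy-ball style buffer `mv` that accumulates the same quantity. A rough NumPy sketch of one outer step, ignoring TensorFlow's slot/assign mechanics (all names here are illustrative):

```python
import numpy as np

def entropy_sgd_outer_step(x, wc, mu, mv, lr, beta):
    # x  : weights after the inner SGLD loop
    # wc : weights cached before the SGLD loop started
    # mu : running mean of the SGLD iterates (copied from the SGLD optimizer's slot)
    # mv : momentum buffer, beta : momentum coefficient, lr : outer learning rate
    mv = beta * mv + (x - mu)                            # accumulate the entropic gradient proxy
    x = wc.copy()                                        # reset weights to the pre-SGLD state
    x -= lr * (x - mu) + lr * ((x - mu) + beta * mv)     # proximal term plus momentum-corrected term
    return x, mv
```
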
20 changes: 15 additions & 5 deletions sgld.py
@@ -63,6 +63,7 @@ def _create_slots(self, var_list):
wc = self._zeros_slot(v, "wc", self._name)
xp = self._zeros_slot(v, "xp", self._name)
mu = self._zeros_slot(v, "mu", self._name)
mv = self._zeros_slot(v, "mv", self._name)

def _apply_dense(self, grad, var):
# Updates dummy weights during SGLD
@@ -76,6 +77,7 @@ def _apply_dense(self, grad, var):
wc = self.get_slot(var, 'wc')
xp = self.get_slot(var, 'xp')
mu = self.get_slot(var, 'mu')
mv = self.get_slot(var, 'mv')

wc_t = tf.cond(tf.logical_not(tf.cast(tf.mod(self.sgld_global_step, self._L_t), tf.bool)),
lambda: wc.assign(var),
@@ -85,13 +87,21 @@ def _apply_dense(self, grad, var):
eta_t = math_ops.cast(eta, var.dtype.base_dtype)

# update = -lr_prime_t*(grad-gamma_t*(wc-var)) + tf.sqrt(lr_prime)*epsilon_t*eta_t
xp_t = xp.assign(var-lr_prime_t*(grad-gamma_t*(wc-var))+tf.sqrt(lr_prime_t)*epsilon_t*eta_t)
mu_t = mu.assign((1.0-alpha_t)*mu + alpha_t*xp)
mv_t = mv.assign(momentum_t*mv + grad)

# Nesterov's momentum enabled by default
if self._momentum > 0:
xp_t = xp.assign(var-lr_prime_t*(grad-gamma_t*(wc-var))+tf.sqrt(lr_prime_t)*epsilon_t*eta_t-lr_prime_t*(grad+momentum_t*mv_t))
var_update = state_ops.assign_sub(var,
lr_prime_t*(grad-gamma_t*(wc-var))-tf.sqrt(lr_prime_t)*epsilon_t*eta_t+lr_prime_t*(grad+momentum_t*mv_t))
else:
xp_t = xp.assign(var-lr_prime_t*(grad-gamma_t*(wc-var))+tf.sqrt(lr_prime_t)*epsilon_t*eta_t)
var_update = state_ops.assign_sub(var,
lr_prime_t*(grad-gamma_t*(wc-var))-tf.sqrt(lr_prime_t)*epsilon_t*eta_t)

var_update = state_ops.assign_sub(var,
lr_prime_t*(grad-gamma_t*(wc-var))-tf.sqrt(lr_prime_t)*epsilon_t*eta_t)
mu_t = mu.assign((1.0-alpha_t)*mu + alpha_t*xp)

return control_flow_ops.group(*[var_update, wc_t, xp_t, mu_t])
return control_flow_ops.group(*[var_update, mv_t, wc_t, xp_t, mu_t])

def _apply_sparse(self, grad, var_list):
raise NotImplementedError("Optimizer does not yet support sparse gradient updates.")
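
On the inner loop side, the Langevin step gains a matching momentum buffer on the raw gradient. A rough NumPy sketch of a single SGLD iteration as modified above, again ignoring the slot/assign mechanics (names illustrative):

```python
import numpy as np

def sgld_inner_step(x, wc, mu, mv, grad, lr_p, gamma, eps, alpha, beta, rng):
    # x  : dummy weights updated during the Langevin loop
    # wc : weights cached at the start of the loop, mu : running average of the iterates
    # mv : gradient momentum buffer, eps : noise scale, beta : momentum coefficient
    eta = rng.standard_normal(x.shape)                             # Gaussian noise
    mv = beta * mv + grad                                          # heavy-ball accumulation of the gradient
    step = lr_p * (grad - gamma * (wc - x)) - np.sqrt(lr_p) * eps * eta
    if beta > 0:
        step += lr_p * (grad + beta * mv)                          # extra momentum term added in this commit
    x = x - step
    mu = (1.0 - alpha) * mu + alpha * x                            # exponential average read by the outer step
    return x, mu, mv
```
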
Binary file added tfrecords/cifar10/res.pdf
Binary file added tfrecords/cifar10/res.png
6 changes: 3 additions & 3 deletions train.py
@@ -55,9 +55,9 @@ def train(config, architecture, args):
while True:
try:
# Run SGLD iterations
# if optimizer=='entropy-sgd':
for l in range(config.L):
sess.run([cnn.sgld_op], feed_dict={cnn.training_phase: True, cnn.handle: train_handle})
if args.optimizer=='entropy-sgd':
for l in range(config.L):
sess.run([cnn.sgld_op], feed_dict={cnn.training_phase: True, cnn.handle: train_handle})

# Update weights
sess.run([cnn.train_op, cnn.update_accuracy], feed_dict={cnn.training_phase: True,
114 changes: 114 additions & 0 deletions train_distributed.py
@@ -0,0 +1,114 @@
#!/usr/bin/python3
import tensorflow as tf
import numpy as np
import pandas as pd
import time, os, sys
import argparse
import horovod.tensorflow as hvd

# User-defined
from network import Network
from diagnostics import Diagnostics
from data import Data
from model import Model
from config import config_train, directories

tf.logging.set_verbosity(tf.logging.ERROR)

def train(config, architecture, args):

print('Architecture: {}'.format(architecture))
start_time = time.time()
global_step, n_checkpoints, v_acc_best = 0, 0, 0.
ckpt = tf.train.get_checkpoint_state(directories.checkpoints)

if args.name=='cifar100':
config.n_classes = 100
config.L = args.langevin_iterations

# Build graph
cnn = Model(config, directories, name=args.name, optimizer=args.optimizer)
saver = tf.train.Saver()

with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
train_handle = sess.run(cnn.train_iterator.string_handle())
test_handle = sess.run(cnn.test_iterator.string_handle())

if args.restore_last and ckpt.model_checkpoint_path:
# Continue training saved model
saver.restore(sess, ckpt.model_checkpoint_path)
print('{} restored.'.format(ckpt.model_checkpoint_path))
else:
if args.restore_path:
new_saver = tf.train.import_meta_graph('{}.meta'.format(args.restore_path))
new_saver.restore(sess, args.restore_path)
print('{} restored.'.format(args.restore_path))

sess.run(cnn.test_iterator.initializer)

for epoch in range(config.num_epochs):
sess.run(cnn.train_iterator.initializer)
# Run diagnostics
v_acc_best = Diagnostics.run_diagnostics(cnn, config_train, directories, sess, saver, train_handle,
test_handle, start_time, v_acc_best, epoch, args.name)
while True:
try:
# Run SGLD iterations
if args.optimizer=='entropy-sgd':
for l in range(config.L):
sess.run([cnn.sgld_op], feed_dict={cnn.training_phase: True, cnn.handle: train_handle})

# Update weights
sess.run([cnn.train_op, cnn.update_accuracy], feed_dict={cnn.training_phase: True,
cnn.handle: train_handle})

except tf.errors.OutOfRangeError:
print('End of epoch!')
break

except KeyboardInterrupt:
save_path = saver.save(sess, os.path.join(directories.checkpoints,
'cnn_{}_last.ckpt'.format(args.name)), global_step=epoch)
print('Interrupted, model saved to: ', save_path)
sys.exit()


save_path = saver.save(sess, os.path.join(directories.checkpoints,
'cnn_{}_end.ckpt'.format(args.name)),
global_step=epoch)

print("Training Complete. Model saved to file: {} Time elapsed: {:.3f} s".format(save_path, time.time()-start_time))

def main(**kwargs):
parser = argparse.ArgumentParser()
parser.add_argument("-rl", "--restore_last", help="restore last saved model", action="store_true")
parser.add_argument("-r", "--restore_path", help="path to model to be restored", type=str)
parser.add_argument("-opt", "--optimizer", default="entropy-sgd", help="Selected optimizer", type=str)
parser.add_argument("-n", "--name", default="entropy-sgd", help="Checkpoint/Tensorboard label")
parser.add_argument("-d", "--dataset", default="cifar10", help="Dataset to train on (cifar10 || cifar100)", type=str)
parser.add_argument("-L", "--langevin_iterations", default=0, help="Number of Langevin iterations in inner loop.",
type=int)
args = parser.parse_args()
config = config_train

architecture = 'Layers: {} | Conv dropout: {} | Base LR: {} | SGLD Iterations {} | Epochs: {} | Optimizer: {}'.format(
config.n_layers,
config.conv_keep_prob,
config.learning_rate,
config.L,
config.num_epochs,
args.optimizer
)

Diagnostics.setup_dataset(args.dataset)

hvd.init()


# Launch training
train(config_train, architecture, args)

if __name__ == '__main__':
main()
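
The new script initialises Horovod but does not yet pin GPUs, wrap the optimizer, or broadcast initial variables. A minimal sketch of the usual Horovod-for-TensorFlow-1 wiring, with the optimizer choice here purely illustrative:

```python
import tensorflow as tf
import horovod.tensorflow as hvd

hvd.init()

# Pin each worker process to a single GPU.
session_config = tf.ConfigProto(allow_soft_placement=True)
session_config.gpu_options.visible_device_list = str(hvd.local_rank())

# Scale the learning rate by the worker count and average gradients across ranks.
opt = tf.train.MomentumOptimizer(1e-1 * hvd.size(), momentum=0.9)
opt = hvd.DistributedOptimizer(opt)

# Broadcast the initial variables from rank 0 so every worker starts from the same state.
bcast_hook = hvd.BroadcastGlobalVariablesHook(0)
```
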
2 changes: 1 addition & 1 deletion wrn_momentum.slurm
@@ -12,4 +12,4 @@
module load CUDA/8.0.44
export LD_LIBRARY_PATH=/data/projects/punim0011/cuda/lib64:$LD_LIBRARY_PATH

python3 train.py -opt momentum -n wrn_momentum_28_10
python3 train.py -opt momentum -n wrn_p2810
