-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Justin Tan (unimelb)
committed
Jan 8, 2018
1 parent
64b6cc7
commit d15d1ff
Showing
11 changed files
with
147 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,3 @@ | ||
#!/usr/bin/python3 | ||
import tensorflow as tf | ||
import numpy as np | ||
import pandas as pd | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
#!/usr/bin/python3 | ||
import tensorflow as tf | ||
import numpy as np | ||
import pandas as pd | ||
import time, os, sys | ||
import argparse | ||
import horovod.tensorflow as hvd | ||
|
||
# User-defined | ||
from network import Network | ||
from diagnostics import Diagnostics | ||
from data import Data | ||
from model import Model | ||
from config import config_train, directories | ||
|
||
tf.logging.set_verbosity(tf.logging.ERROR) | ||
|
||
def train(config, architecture, args): | ||
|
||
print('Architecture: {}'.format(architecture)) | ||
start_time = time.time() | ||
global_step, n_checkpoints, v_acc_best = 0, 0, 0. | ||
ckpt = tf.train.get_checkpoint_state(directories.checkpoints) | ||
|
||
if args.name=='cifar100': | ||
config.n_classes = 100 | ||
config.L = args.langevin_iterations | ||
|
||
# Build graph | ||
cnn = Model(config, directories, name=args.name, optimizer=args.optimizer) | ||
saver = tf.train.Saver() | ||
|
||
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess: | ||
sess.run(tf.global_variables_initializer()) | ||
sess.run(tf.local_variables_initializer()) | ||
train_handle = sess.run(cnn.train_iterator.string_handle()) | ||
test_handle = sess.run(cnn.test_iterator.string_handle()) | ||
|
||
if args.restore_last and ckpt.model_checkpoint_path: | ||
# Continue training saved model | ||
saver.restore(sess, ckpt.model_checkpoint_path) | ||
print('{} restored.'.format(ckpt.model_checkpoint_path)) | ||
else: | ||
if args.restore_path: | ||
new_saver = tf.train.import_meta_graph('{}.meta'.format(args.restore_path)) | ||
new_saver.restore(sess, args.restore_path) | ||
print('{} restored.'.format(args.restore_path)) | ||
|
||
sess.run(cnn.test_iterator.initializer) | ||
|
||
for epoch in range(config.num_epochs): | ||
sess.run(cnn.train_iterator.initializer) | ||
# Run diagnostics | ||
v_acc_best = Diagnostics.run_diagnostics(cnn, config_train, directories, sess, saver, train_handle, | ||
test_handle, start_time, v_acc_best, epoch, args.name) | ||
while True: | ||
try: | ||
# Run SGLD iterations | ||
if args.optimizer=='entropy-sgd': | ||
for l in range(config.L): | ||
sess.run([cnn.sgld_op], feed_dict={cnn.training_phase: True, cnn.handle: train_handle}) | ||
|
||
# Update weights | ||
sess.run([cnn.train_op, cnn.update_accuracy], feed_dict={cnn.training_phase: True, | ||
cnn.handle: train_handle}) | ||
|
||
except tf.errors.OutOfRangeError: | ||
print('End of epoch!') | ||
break | ||
|
||
except KeyboardInterrupt: | ||
save_path = saver.save(sess, os.path.join(directories.checkpoints, | ||
'cnn_{}_last.ckpt'.format(args.name)), global_step=epoch) | ||
print('Interrupted, model saved to: ', save_path) | ||
sys.exit() | ||
|
||
|
||
save_path = saver.save(sess, os.path.join(directories.checkpoints, | ||
'cnn_{}_end.ckpt'.format(args.name)), | ||
global_step=epoch) | ||
|
||
print("Training Complete. Model saved to file: {} Time elapsed: {:.3f} s".format(save_path, time.time()-start_time)) | ||
|
||
def main(**kwargs): | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("-rl", "--restore_last", help="restore last saved model", action="store_true") | ||
parser.add_argument("-r", "--restore_path", help="path to model to be restored", type=str) | ||
parser.add_argument("-opt", "--optimizer", default="entropy-sgd", help="Selected optimizer", type=str) | ||
parser.add_argument("-n", "--name", default="entropy-sgd", help="Checkpoint/Tensorboard label") | ||
parser.add_argument("-d", "--dataset", default="cifar10", help="Dataset to train on (cifar10 || cifar100)", type=str) | ||
parser.add_argument("-L", "--langevin_iterations", default=0, help="Number of Langevin iterations in inner loop.", | ||
type=int) | ||
args = parser.parse_args() | ||
config = config_train | ||
|
||
architecture = 'Layers: {} | Conv dropout: {} | Base LR: {} | SGLD Iterations {} | Epochs: {} | Optimizer: {}'.format( | ||
config.n_layers, | ||
config.conv_keep_prob, | ||
config.learning_rate, | ||
config.L, | ||
config.num_epochs, | ||
args.optimizer | ||
) | ||
|
||
Diagnostics.setup_dataset(args.dataset) | ||
|
||
hvd.init() | ||
|
||
|
||
# Launch training | ||
train(config_train, architecture, args) | ||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters