diff --git a/gan_cifar.py b/gan_cifar.py
index 1dca0e39..4d7073b0 100644
--- a/gan_cifar.py
+++ b/gan_cifar.py
@@ -16,13 +16,20 @@
 import tflib.inception_score
 import tflib.plot
 
-# Download CIFAR-10 (Python version) at
-# https://www.cs.toronto.edu/~kriz/cifar.html and fill in the path to the
-# extracted files here!
-DATA_DIR = ''
+DATA_DIR = '/home/catherio/data/cifar10/cifar-10-batches-py/'
 if len(DATA_DIR) == 0:
-    raise Exception('Please specify path to data directory in gan_cifar.py!')
+    raise Exception('''
+Please specify path to data directory in gan_cifar.py!
+Download CIFAR-10 (Python version) at
+https://www.cs.toronto.edu/~kriz/cifar.html and fill in the path to the
+extracted files.
+
+> wget https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
+> tar -xvzf cifar-10-python.tar.gz
+''')
+
+DATASET = 'cifar' # (experimental) 'cifar' or 'svhn'
 
 MODE = 'wgan-gp' # Valid options are dcgan, wgan, or wgan-gp
 DIM = 128 # This overfits substantially; you're probably better off with 64
 LAMBDA = 10 # Gradient penalty lambda hyperparameter
@@ -110,7 +117,7 @@ def Discriminator(inputs):
         clip_bounds = [-.01, .01]
         clip_ops.append(
             tf.assign(
-                var, 
+                var,
                 tf.clip_by_value(var, clip_bounds[0], clip_bounds[1])
             )
         )
@@ -123,7 +130,7 @@ def Discriminator(inputs):
 
     # Gradient penalty
    alpha = tf.random_uniform(
-        shape=[BATCH_SIZE,1], 
+        shape=[BATCH_SIZE,1],
         minval=0.,
         maxval=1.
     )
@@ -167,17 +174,37 @@ def get_inception_score():
     all_samples = all_samples.reshape((-1, 3, 32, 32)).transpose(0,2,3,1)
     return lib.inception_score.get_inception_score(list(all_samples))
 
-# Dataset iterators
-train_gen, dev_gen = lib.cifar10.load(BATCH_SIZE, data_dir=DATA_DIR)
-def inf_train_gen():
-    while True:
-        for images,_ in train_gen():
-            yield images
 
 # Train loop
 with tf.Session() as session:
+    # Dataset iterators
+    if DATASET == 'cifar':
+        train_gen, dev_gen = lib.cifar10.load(BATCH_SIZE, data_dir=DATA_DIR)
+        def inf_gen(gen_fn): # call gen_fn() each epoch for a fresh generator
+            while True:
+                for ims,_ in gen_fn():
+                    yield ims
+        train_gen = inf_gen(train_gen)
+        dev_gen = inf_gen(dev_gen)
+    elif DATASET == 'svhn':
+        from ganskill.svhn_data import input_fn
+        train_gen_tf = (input_fn(is_training=True,
+                                 batch_size=BATCH_SIZE)
+                        .make_one_shot_iterator())
+        dev_gen_tf = (input_fn(is_training=False,
+                               batch_size=BATCH_SIZE)
+                      .make_one_shot_iterator())
+
+        def dataset_iter(g_tf):
+            next_im, _ = g_tf.get_next()
+            reshape = tf.reshape(next_im, [-1, OUTPUT_DIM])
+            while True:
+                yield session.run(reshape)
+        train_gen = dataset_iter(train_gen_tf)
+        dev_gen = dataset_iter(dev_gen_tf)
+
     session.run(tf.initialize_all_variables())
-    gen = inf_train_gen()
+
     for iteration in xrange(ITERS):
         start_time = time.time()
@@ -190,7 +217,7 @@ def inf_train_gen():
         else:
             disc_iters = CRITIC_ITERS
         for i in xrange(disc_iters):
-            _data = gen.next()
+            _data = train_gen.next()
             _disc_cost, _ = session.run([disc_cost, disc_train_op], feed_dict={real_data_int: _data})
             if MODE == 'wgan':
                 _ = session.run(clip_disc_weights)
@@ -206,8 +233,9 @@
         # Calculate dev loss and generate samples every 100 iters
         if iteration % 100 == 99:
             dev_disc_costs = []
-            for images,_ in dev_gen():
-                _dev_disc_cost = session.run(disc_cost, feed_dict={real_data_int: images}) 
+            for _ in xrange(100): # dev_gen is now endless; evaluate a fixed number of batches
+                images = dev_gen.next()
+                _dev_disc_cost = session.run(disc_cost, feed_dict={real_data_int: images})
                 dev_disc_costs.append(_dev_disc_cost)
             lib.plot.plot('dev disc cost', np.mean(dev_disc_costs))
             generate_image(iteration, _data)
diff --git a/tflib/inception_score.py b/tflib/inception_score.py
index 38e805d5..66293ec3 100644
--- a/tflib/inception_score.py
+++ b/tflib/inception_score.py
@@ -20,7 +20,7 @@
 DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
 softmax = None
 
-# Call this function with list of images. Each of elements should be a 
+# Call this function with a list of images. Each element should be a
 # numpy array with values ranging from 0 to 255.
 def get_inception_score(images, splits=10):
   assert(type(images) == list)
@@ -32,7 +32,7 @@ def get_inception_score(images, splits=10):
   for img in images:
     img = img.astype(np.float32)
     inps.append(np.expand_dims(img, 0))
-  bs = 100
+  bs = 1
   with tf.Session() as sess:
     preds = []
     n_batches = int(math.ceil(float(len(inps)) / float(bs)))
@@ -88,9 +88,9 @@ def _progress(count, block_size, total_size):
                     new_shape.append(None)
                 else:
                     new_shape.append(s)
-            o._shape = tf.TensorShape(new_shape)
+            o.set_shape(tf.TensorShape(new_shape))
     w = sess.graph.get_operation_by_name("softmax/logits/MatMul").inputs[1]
-    logits = tf.matmul(tf.squeeze(pool3), w)
+    logits = tf.matmul(tf.squeeze(pool3, [1, 2]), w)
     softmax = tf.nn.softmax(logits)
 
 if softmax is None:
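
Note on the SVHN branch above: ganskill.svhn_data.input_fn is imported but not
part of this diff. A minimal sketch of a compatible module, assuming the
standard SVHN cropped-digits .mat files (the path, file names, and tf.data
pipeline here are assumptions, not the actual ganskill code):

# ganskill/svhn_data.py (hypothetical sketch, not part of the diff)
import os
import numpy as np
import scipy.io
import tensorflow as tf

SVHN_DIR = '/path/to/svhn'  # assumed location of train_32x32.mat / test_32x32.mat

def input_fn(is_training, batch_size):
    fname = 'train_32x32.mat' if is_training else 'test_32x32.mat'
    mat = scipy.io.loadmat(os.path.join(SVHN_DIR, fname))
    # 'X' comes as (32, 32, 3, N); move the sample axis first and use NCHW so
    # each image flattens to OUTPUT_DIM = 3*32*32, matching gan_cifar.py.
    images = mat['X'].transpose(3, 2, 0, 1).astype(np.int32)
    labels = mat['y'].flatten().astype(np.int32)
    dataset = tf.data.Dataset.from_tensor_slices((images, labels))
    if is_training:
        dataset = dataset.shuffle(buffer_size=10000)
    # Repeat indefinitely: gan_cifar.py treats both iterators as endless
    # streams and draws a fixed number of batches for dev evaluation.
    return dataset.repeat().batch(batch_size)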
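
Note on tf.squeeze(pool3, [1, 2]) above: pool3 has shape [batch, 1, 1, 2048],
and with bs = 1 an axis-less tf.squeeze also drops the batch dimension,
producing a rank-1 tensor that breaks the matmul against w. A quick check of
the two behaviors:

import tensorflow as tf

pool3 = tf.zeros([1, 1, 1, 2048])             # one image's pool_3 activations
print(tf.squeeze(pool3).get_shape())          # (2048,)   batch axis lost
print(tf.squeeze(pool3, [1, 2]).get_shape())  # (1, 2048) batch axis kept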