Two bugfixes, improved err msg, and whitespace removal #63

Open · wants to merge 3 commits into master
61 changes: 44 additions & 17 deletions gan_cifar.py
@@ -16,13 +16,20 @@
 import tflib.inception_score
 import tflib.plot
 
-# Download CIFAR-10 (Python version) at
-# https://www.cs.toronto.edu/~kriz/cifar.html and fill in the path to the
-# extracted files here!
-DATA_DIR = ''
+DATA_DIR = '/home/catherio/data/cifar10/cifar-10-batches-py/'
 if len(DATA_DIR) == 0:
-    raise Exception('Please specify path to data directory in gan_cifar.py!')
+    raise Exception('''
+Please specify path to data directory in gan_cifar.py!
+
+Download CIFAR-10 (Python version) at
+https://www.cs.toronto.edu/~kriz/cifar.html and fill in the path to the
+extracted files.
+
+> wget https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
+> tar -xvzf cifar-10-python.tar.gz
+''')
 
+DATASET = 'cifar' # (experimental) 'cifar' or 'svhn'
 MODE = 'wgan-gp' # Valid options are dcgan, wgan, or wgan-gp
 DIM = 128 # This overfits substantially; you're probably better off with 64
 LAMBDA = 10 # Gradient penalty lambda hyperparameter
@@ -110,7 +117,7 @@ def Discriminator(inputs):
         clip_bounds = [-.01, .01]
         clip_ops.append(
             tf.assign(
-                var, 
+                var,
                 tf.clip_by_value(var, clip_bounds[0], clip_bounds[1])
             )
         )
@@ -123,7 +130,7 @@ def Discriminator(inputs):
 
     # Gradient penalty
     alpha = tf.random_uniform(
-        shape=[BATCH_SIZE,1], 
+        shape=[BATCH_SIZE,1],
        minval=0.,
        maxval=1.
     )
@@ -167,17 +174,37 @@ def get_inception_score():
     all_samples = all_samples.reshape((-1, 3, 32, 32)).transpose(0,2,3,1)
     return lib.inception_score.get_inception_score(list(all_samples))
 
-# Dataset iterators
-train_gen, dev_gen = lib.cifar10.load(BATCH_SIZE, data_dir=DATA_DIR)
-def inf_train_gen():
-    while True:
-        for images,_ in train_gen():
-            yield images
-
 # Train loop
 with tf.Session() as session:
+    # Dataset iterators
+    if DATASET == 'cifar':
+        train_gen, dev_gen = lib.cifar10.load(BATCH_SIZE, data_dir=DATA_DIR)
+        def inf_gen(g):
+            while True:
+                for ims,_ in g:
+                    yield ims
+        train_gen = inf_gen(train_gen())
+        dev_gen = inf_gen(dev_gen())
+    elif DATASET == 'svhn':
+        from ganskill.svhn_data import input_fn
+        train_gen_tf = (input_fn(is_training=True,
+                                 batch_size=BATCH_SIZE)
+                        .make_one_shot_iterator())
+        dev_gen_tf = (input_fn(is_training=False,
+                               batch_size=BATCH_SIZE)
+                      .make_one_shot_iterator())
+
+        def dataset_iter(g_tf):
+            next_im, _ = g_tf.get_next()
+            reshape = tf.reshape(next_im, [-1, OUTPUT_DIM])
+            while True:
+                yield session.run(reshape)
+        train_gen = dataset_iter(train_gen_tf)
+        dev_gen = dataset_iter(dev_gen_tf)
+
     session.run(tf.initialize_all_variables())
-    gen = inf_train_gen()
 
+
     for iteration in xrange(ITERS):
         start_time = time.time()
@@ -190,7 +217,7 @@ def inf_train_gen():
         else:
             disc_iters = CRITIC_ITERS
         for i in xrange(disc_iters):
-            _data = gen.next()
+            _data = train_gen.next()
             _disc_cost, _ = session.run([disc_cost, disc_train_op], feed_dict={real_data_int: _data})
             if MODE == 'wgan':
                 _ = session.run(clip_disc_weights)
@@ -206,8 +233,8 @@ def inf_train_gen():
         # Calculate dev loss and generate samples every 100 iters
         if iteration % 100 == 99:
             dev_disc_costs = []
-            for images,_ in dev_gen():
-                _dev_disc_cost = session.run(disc_cost, feed_dict={real_data_int: images})
+            for images in dev_gen:
+                _dev_disc_cost = session.run(disc_cost, feed_dict={real_data_int: images})
                 dev_disc_costs.append(_dev_disc_cost)
             lib.plot.plot('dev disc cost', np.mean(dev_disc_costs))
             generate_image(iteration, _data)
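The gan_cifar.py changes above replace the module-level inf_train_gen() with per-dataset iterators built inside the session, so the 'cifar' and 'svhn' paths feed the same training loop through plain Python generators drained with train_gen.next(). A minimal sketch of that infinite-batch pattern, assuming a hypothetical load_epoch factory that returns a fresh one-epoch generator each call (lib.cifar10.load returns factories of this kind):

def inf_gen(epoch_factory):
    # Run forever, starting a fresh one-epoch generator on each pass;
    # a generator object is exhausted after one full iteration, so the
    # factory is re-invoked rather than re-iterated.
    while True:
        for images, _ in epoch_factory():
            yield images

# Hypothetical usage, mirroring the train loop (Python 2, as in the file):
# train_gen = inf_gen(load_epoch)
# batch = train_gen.next()  # one [BATCH_SIZE, OUTPUT_DIM] array per call

Note that the diff passes an already-created generator (inf_gen(train_gen())), which stays infinite only if that object can be iterated more than once; the factory form sketched above avoids relying on that.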
8 changes: 4 additions & 4 deletions tflib/inception_score.py
@@ -20,7 +20,7 @@
 DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
 softmax = None
 
-# Call this function with list of images. Each of elements should be a 
+# Call this function with list of images. Each of elements should be a
 # numpy array with values ranging from 0 to 255.
 def get_inception_score(images, splits=10):
     assert(type(images) == list)
@@ -32,7 +32,7 @@ def get_inception_score(images, splits=10):
     for img in images:
         img = img.astype(np.float32)
         inps.append(np.expand_dims(img, 0))
-    bs = 100
+    bs = 1
     with tf.Session() as sess:
         preds = []
         n_batches = int(math.ceil(float(len(inps)) / float(bs)))
@@ -88,9 +88,9 @@ def _progress(count, block_size, total_size):
                    new_shape.append(None)
                else:
                    new_shape.append(s)
-            o._shape = tf.TensorShape(new_shape)
+            o.set_shape(tf.TensorShape(new_shape))
    w = sess.graph.get_operation_by_name("softmax/logits/MatMul").inputs[1]
-    logits = tf.matmul(tf.squeeze(pool3), w)
+    logits = tf.matmul(tf.squeeze(pool3, [1, 2]), w)
    softmax = tf.nn.softmax(logits)
 
 if softmax is None:
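The two inception_score.py bugfixes, in brief: o._shape = ... writes a private Tensor attribute (liable to break across TensorFlow versions), while set_shape() is the public API for the same thing; and tf.squeeze with no axis argument drops every size-1 dimension, so with the new bs = 1 the batch axis of the [batch, 1, 1, 2048] pool3 activations would collapse and break the downstream matmul. A minimal repro sketch under the TF 1.x API the file already uses:

import numpy as np
import tensorflow as tf

# pool3-style activations for a single image: [batch, 1, 1, channels]
pool3 = tf.constant(np.zeros([1, 1, 1, 2048], dtype=np.float32))

with tf.Session() as sess:
    print(sess.run(tf.squeeze(pool3)).shape)          # (2048,)   batch axis lost
    print(sess.run(tf.squeeze(pool3, [1, 2])).shape)  # (1, 2048) batch axis kept

Pinning the squeeze to axes [1, 2] removes only the 1x1 spatial dimensions, so logits keeps its [batch, channels] shape for any batch size.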