-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 887b9d5
Showing
28 changed files
with
3,547 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
.DS_Store | ||
pull | ||
push | ||
experiment_tools.py | ||
scribe |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# PixelVAE | ||
|
||
Code for the models in [PixelVAE: A Latent Variable Model for Natural Images](https://arxiv.org/abs/1611.05013) | ||
|
||
## MNIST | ||
|
||
To train: | ||
|
||
``` | ||
python models/mnist_pixelvae_train.py -L 12 -fs 5 -algo cond_z_bias -dpx 16 -ldim 16 | ||
``` | ||
|
||
To evaluate, take the weights of the model with best validation score from the above training procedure and then run | ||
|
||
``` | ||
python models/mnist_pixelvae_evaluate.py -L 12 -fs 5 -algo cond_z_bias -dpx 16 -ldim 16 -w path/to/weights.pkl | ||
``` | ||
|
||
## Other datasets | ||
|
||
To train, evaluate, and generate samples: | ||
|
||
``` | ||
python pixelvae.py | ||
``` | ||
|
||
By default, this runs on real-valued MNIST. You can pecify different datasets or model settings within `pixelvae.py`. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
import numpy | ||
import theano | ||
import theano.tensor as T | ||
|
||
import cPickle as pickle | ||
import math | ||
import time | ||
import locale | ||
|
||
locale.setlocale(locale.LC_ALL, '') | ||
|
||
_params = {} | ||
def param(name, *args, **kwargs): | ||
""" | ||
A wrapper for `theano.shared` which enables parameter sharing in models. | ||
Creates and returns theano shared variables similarly to `theano.shared`, | ||
except if you try to create a param with the same name as a | ||
previously-created one, `param(...)` will just return the old one instead of | ||
making a new one. | ||
This constructor also adds a `param` attribute to the shared variables it | ||
creates, so that you can easily search a graph for all params. | ||
""" | ||
|
||
if name not in _params: | ||
kwargs['name'] = name | ||
param = theano.shared(*args, **kwargs) | ||
param.param = True | ||
_params[name] = param | ||
return _params[name] | ||
|
||
def delete_params_with_name(name): | ||
to_delete = [p_name for p_name in _params if name in p_name] | ||
for p_name in to_delete: | ||
del _params[p_name] | ||
|
||
def delete_all_params(): | ||
to_delete = [p_name for p_name in _params] | ||
for p_name in to_delete: | ||
del _params[p_name] | ||
|
||
def save_params(path): | ||
param_vals = {} | ||
for name, param in _params.iteritems(): | ||
param_vals[name] = param.get_value() | ||
# print name | ||
|
||
with open(path, 'wb') as f: | ||
pickle.dump(param_vals, f) | ||
|
||
def load_params(path): | ||
with open(path, 'rb') as f: | ||
param_vals = pickle.load(f) | ||
|
||
for name, val in param_vals.iteritems(): | ||
_params[name].set_value(val) | ||
# print name | ||
|
||
def search(node, critereon): | ||
""" | ||
Traverse the Theano graph starting at `node` and return a list of all nodes | ||
which match the `critereon` function. When optimizing a cost function, you | ||
can use this to get a list of all of the trainable params in the graph, like | ||
so: | ||
`lib.search(cost, lambda x: hasattr(x, "param"))` | ||
""" | ||
|
||
def _search(node, critereon, visited): | ||
if node in visited: | ||
return [] | ||
visited.add(node) | ||
|
||
results = [] | ||
if isinstance(node, T.Apply): | ||
for inp in node.inputs: | ||
results += _search(inp, critereon, visited) | ||
else: # Variable node | ||
if critereon(node): | ||
results.append(node) | ||
if node.owner is not None: | ||
results += _search(node.owner, critereon, visited) | ||
return results | ||
|
||
return _search(node, critereon, set()) | ||
|
||
def floatX(x): | ||
""" | ||
Convert `x` to the numpy type specified in `theano.config.floatX`. | ||
""" | ||
if theano.config.floatX == 'float16': | ||
return numpy.float16(x) | ||
elif theano.config.floatX == 'float32': | ||
return numpy.float32(x) | ||
else: # Theano's default float type is float64 | ||
print "Warning: lib.floatX using float64" | ||
return numpy.float64(x) | ||
|
||
def print_params_info(params): | ||
"""Print information about the parameters in the given param set.""" | ||
|
||
params = sorted(params, key=lambda p: p.name) | ||
values = [p.get_value(borrow=True) for p in params] | ||
shapes = [p.shape for p in values] | ||
print "Params for cost:" | ||
for param, value, shape in zip(params, values, shapes): | ||
print "\t{0} ({1})".format( | ||
param.name, | ||
",".join([str(x) for x in shape]) | ||
) | ||
|
||
total_param_count = 0 | ||
for shape in shapes: | ||
param_count = 1 | ||
for dim in shape: | ||
param_count *= dim | ||
total_param_count += param_count | ||
print "Total parameter count: {0}".format( | ||
locale.format("%d", total_param_count, grouping=True) | ||
) | ||
|
||
def print_model_settings(locals_): | ||
print "Model settings:" | ||
all_vars = [(k,v) for (k,v) in locals_.items() if (k.isupper() and k!='T')] | ||
all_vars = sorted(all_vars, key=lambda x: x[0]) | ||
for var_name, var_value in all_vars: | ||
print "\t{}: {}".format(var_name, var_value) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import numpy as np | ||
import theano | ||
from theano import gof | ||
|
||
class DebugOp(gof.Op): | ||
def __init__(self, name, fn): | ||
super(DebugOp, self).__init__() | ||
self._name = name | ||
self._fn = fn | ||
|
||
def make_node(self, x): | ||
return gof.Apply(self, [x], [x.type()]) | ||
|
||
def perform(self, node, inputs, output_storage): | ||
self._fn(self._name, inputs[0]) | ||
output_storage[0][0] = np.copy(inputs[0]) | ||
|
||
def grad(self, inputs, output_gradients): | ||
return [DebugOp(self._name+'.grad', self._fn)(output_gradients[0])] | ||
|
||
def print_shape(name, x): | ||
def fn(_name, _x): | ||
print "{} shape: {}".format(_name, _x.shape) | ||
return DebugOp(name, fn)(x) | ||
|
||
def print_stats(name, x): | ||
return x | ||
def fn(_name, _x): | ||
mean = np.mean(_x) | ||
std = np.std(_x) | ||
percentiles = np.percentile(_x, [0,25,50,75,100]) | ||
# percentiles = "skipping" | ||
print "{}\tmean:{}\tstd:{}\tpercentiles:{}\t".format(_name, mean, std, percentiles) | ||
return DebugOp(name, fn)(x) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from fuel.datasets import BinarizedMNIST | ||
import numpy as np | ||
|
||
from fuel.datasets import BinarizedMNIST | ||
from fuel.schemes import ShuffledScheme, SequentialScheme | ||
from fuel.streams import DataStream | ||
# from fuel.transformers.image import RandomFixedSizeCrop | ||
|
||
def _make_stream(stream, bs): | ||
def new_stream(): | ||
result = np.empty((bs, 1, 28, 28), dtype = 'float32') | ||
for (imb,) in stream.get_epoch_iterator(): | ||
for i, img in enumerate(imb): | ||
result[i] = img | ||
yield (result,) | ||
return new_stream | ||
|
||
def load(batch_size, test_batch_size): | ||
tr_data = BinarizedMNIST(which_sets=('train',)) | ||
val_data = BinarizedMNIST(which_sets=('valid',)) | ||
test_data = BinarizedMNIST(which_sets=('test',)) | ||
|
||
ntrain = tr_data.num_examples | ||
nval = val_data.num_examples | ||
ntest = test_data.num_examples | ||
|
||
tr_scheme = ShuffledScheme(examples=ntrain, batch_size=batch_size) | ||
tr_stream = DataStream(tr_data, iteration_scheme=tr_scheme) | ||
|
||
te_scheme = SequentialScheme(examples=ntest, batch_size=test_batch_size) | ||
te_stream = DataStream(test_data, iteration_scheme=te_scheme) | ||
|
||
val_scheme = SequentialScheme(examples=nval, batch_size=batch_size) | ||
val_stream = DataStream(val_data, iteration_scheme=val_scheme) | ||
|
||
return _make_stream(tr_stream, batch_size), \ | ||
_make_stream(val_stream, batch_size), \ | ||
_make_stream(te_stream, test_batch_size) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
import lib | ||
import lib.debug | ||
|
||
import numpy as np | ||
import theano | ||
import theano.tensor as T | ||
|
||
_default_weightnorm = False | ||
def enable_default_weightnorm(): | ||
global _default_weightnorm | ||
_default_weightnorm = True | ||
|
||
def Conv2D(name, input_dim, output_dim, filter_size, inputs, he_init=True, mask_type=None, mode = 'half', stride=1, weightnorm=None, biases=True): | ||
""" | ||
inputs: tensor of shape (batch size, num channels, height, width) | ||
mask_type: one of None, 'a', 'b', 'hstack_a', 'hstack', 'vstack' | ||
returns: tensor of shape (batch size, num channels, height, width) | ||
""" | ||
if mask_type is not None: | ||
mask_type, mask_n_channels = mask_type | ||
assert(mode == "half") | ||
|
||
if isinstance(filter_size, int): | ||
filter_size = (filter_size, filter_size) | ||
|
||
#else it is assumed to be a tuple | ||
|
||
def uniform(stdev, size): | ||
return np.random.uniform( | ||
low=-stdev * np.sqrt(3), | ||
high=stdev * np.sqrt(3), | ||
size=size | ||
).astype(theano.config.floatX) | ||
|
||
fan_in = input_dim * filter_size[0]*filter_size[1] | ||
fan_out = output_dim * filter_size[0]*filter_size[1] | ||
# TODO: shouldn't fan_out be divided by stride | ||
|
||
|
||
if mask_type is not None: # only approximately correct | ||
fan_in /= 2. | ||
fan_out /= 2. | ||
|
||
if he_init: | ||
filters_stdev = np.sqrt(4./(fan_in+fan_out)) | ||
else: # Normalized init (Glorot & Bengio) | ||
filters_stdev = np.sqrt(2./(fan_in+fan_out)) | ||
|
||
filter_values = uniform( | ||
filters_stdev, | ||
(output_dim, input_dim, filter_size[0], filter_size[1]) | ||
) | ||
|
||
filters = lib.param(name+'.Filters', filter_values) | ||
|
||
if weightnorm==None: | ||
weightnorm = _default_weightnorm | ||
if weightnorm: | ||
norm_values = np.linalg.norm(filter_values.reshape((filter_values.shape[0], -1)), axis=1) | ||
norms = lib.param( | ||
name + '.g', | ||
norm_values | ||
) | ||
filters = filters * (norms / filters.reshape((filters.shape[0],-1)).norm(2, axis=1)).dimshuffle(0,'x','x','x') | ||
|
||
if mask_type is not None: | ||
mask = np.ones( | ||
(output_dim, input_dim, filter_size[0], filter_size[1]), | ||
dtype=theano.config.floatX | ||
) | ||
center_row = filter_size[0] // 2 | ||
|
||
center_col = filter_size[1]//2 | ||
|
||
# Mask out future locations | ||
# filter shape is (out_channels, in_channels, height, width) | ||
if center_row == 0: | ||
mask[:, :, :, center_col+1:] = 0. | ||
elif center_col == 0: | ||
mask[:, :, center_row+1:, :] = 0. | ||
else: | ||
mask[:, :, center_row+1:, :] = 0. | ||
mask[:, :, center_row, center_col+1:] = 0. | ||
|
||
# Mask out future channels | ||
for i in xrange(mask_n_channels): | ||
for j in xrange(mask_n_channels): | ||
if ((mask_type=='a' or mask_type == 'hstack_a') and i >= j) or (mask_type=='b' and i > j): | ||
mask[ | ||
j::mask_n_channels, | ||
i::mask_n_channels, | ||
center_row, | ||
center_col | ||
] = 0. | ||
|
||
if mask_type == 'vstack': | ||
assert(center_col > 0 and center_row > 0) | ||
mask[:, :, center_row, :] = 1. | ||
|
||
# print mask[0,0,:,:] | ||
|
||
|
||
filters = filters * mask | ||
|
||
if biases: | ||
_biases = lib.param( | ||
name+'.Biases', | ||
np.zeros(output_dim, dtype=theano.config.floatX) | ||
) | ||
|
||
result = T.nnet.conv2d( | ||
inputs, | ||
filters, | ||
border_mode=mode, | ||
filter_flip=False, | ||
subsample=(stride,stride) | ||
) | ||
|
||
if biases: | ||
result = result + _biases[None, :, None, None] | ||
# result = lib.debug.print_stats(name, result) | ||
return result |
Oops, something went wrong.