From 5f34f9c1e0557b2463de5478392399a48812ca8b Mon Sep 17 00:00:00 2001 From: Jonathan So Date: Tue, 10 Jul 2018 20:43:46 +0100 Subject: [PATCH 1/4] Update to use latest autograd (65c21e2). --- experiments/gmm_svae_synth.py | 2 +- svae/distributions/mniw.py | 6 +++--- svae/distributions/niw.py | 1 - svae/lds/lds_inference.py | 8 ++++---- svae/models/slds_svae.py | 4 ++-- svae/nnet.py | 10 +++++----- svae/svae.py | 2 +- svae/util.py | 6 +++--- 8 files changed, 19 insertions(+), 20 deletions(-) diff --git a/experiments/gmm_svae_synth.py b/experiments/gmm_svae_synth.py index 28432c4..393b95d 100644 --- a/experiments/gmm_svae_synth.py +++ b/experiments/gmm_svae_synth.py @@ -2,7 +2,7 @@ import matplotlib.pyplot as plt import autograd.numpy as np import autograd.numpy.random as npr -from autograd.optimizers import adam, sgd +from autograd.misc.optimizers import adam, sgd from svae.svae import make_gradfun from svae.nnet import init_gresnet, make_loglike, gaussian_mean, gaussian_info from svae.models.gmm import (run_inference, init_pgm_param, make_encoder_decoder, diff --git a/svae/distributions/mniw.py b/svae/distributions/mniw.py index 2f673a2..fc1e589 100644 --- a/svae/distributions/mniw.py +++ b/svae/distributions/mniw.py @@ -4,7 +4,7 @@ from autograd.scipy.special import multigammaln, digamma from autograd.scipy.linalg import solve_triangular from autograd import grad -from autograd.util import make_tuple +from autograd.builtins import tuple as tuple_ from scipy.stats import chi2 from svae.util import symmetrize @@ -51,8 +51,8 @@ def expectedstats_standard(nu, S, M, K, fudge=1e-8): assert is_posdef(E_Sigmainv) assert is_posdef(E_AT_Sigmainv_A) - return make_tuple( - -1./2*E_AT_Sigmainv_A, E_Sigmainv_A.T, -1./2*E_Sigmainv, 1./2*E_logdetSigmainv) + return tuple_(( + -1./2*E_AT_Sigmainv_A, E_Sigmainv_A.T, -1./2*E_Sigmainv, 1./2*E_logdetSigmainv)) def expectedstats_autograd(natparam): return grad(logZ)(natparam) diff --git a/svae/distributions/niw.py b/svae/distributions/niw.py index 6bede2a..f387215 100644 --- a/svae/distributions/niw.py +++ b/svae/distributions/niw.py @@ -2,7 +2,6 @@ import autograd.numpy as np from autograd.scipy.special import multigammaln, digamma from autograd import grad -from autograd.util import make_tuple from svae.util import symmetrize, outer from gaussian import pack_dense, unpack_dense diff --git a/svae/lds/lds_inference.py b/svae/lds/lds_inference.py index 7285d43..404d87e 100644 --- a/svae/lds/lds_inference.py +++ b/svae/lds/lds_inference.py @@ -3,7 +3,7 @@ import autograd.numpy.random as npr from autograd import grad from autograd.convenience_wrappers import grad_and_aux as agrad, value_and_grad as vgrad -from autograd.util import make_tuple +from autograd.builtins import tuple as tuple_ from autograd.core import primitive, primitive_with_aux from operator import itemgetter, attrgetter @@ -135,12 +135,12 @@ def unit(filtered_message): J, h = filtered_message mu, Sigma = natural_to_mean(filtered_message) ExxT = Sigma + np.outer(mu, mu) - return make_tuple(J, h, mu), [(mu, ExxT, 0.)] + return tuple_((J, h, mu)), [(mu, ExxT, 0.)] def bind(result, step): next_smooth, stats = result J, h, (mu, ExxT, ExxnT) = step(next_smooth) - return make_tuple(J, h, mu), [(mu, ExxT, ExxnT)] + stats + return tuple_((J, h, mu)), [(mu, ExxT, ExxnT)] + stats rts = lambda next_pred, filtered, pair_param: lambda next_smooth: \ natural_rts_backward_step(next_smooth, next_pred, filtered, pair_param) @@ -211,7 +211,7 @@ def lds_log_normalizer(all_natparams): init_params, pair_params, 
node_params) return lognorm, (lognorm, forward_messages) - all_natparams = make_tuple(init_params, pair_params, node_params) + all_natparams = tuple_((init_params, pair_params, node_params)) expected_stats, (lognorm, forward_messages) = agrad(lds_log_normalizer)(all_natparams) samples = natural_sample_backward_general(forward_messages, pair_params, num_samples) diff --git a/svae/models/slds_svae.py b/svae/models/slds_svae.py index 047bf9c..ee0d601 100644 --- a/svae/models/slds_svae.py +++ b/svae/models/slds_svae.py @@ -2,7 +2,7 @@ import autograd.numpy as np import autograd.numpy.random as npr from autograd import grad -from autograd.util import make_tuple +from autograd.builtins import tuple as tuple_ from functools import partial import sys @@ -150,7 +150,7 @@ def get_arhmm_local_nodeparams(lds_global_natparam, lds_expected_stats): def get_hmm_vlb(lds_global_natparam, hmm_local_natparam, lds_expected_stats): init_params, pair_params, _ = hmm_local_natparam node_params = get_arhmm_local_nodeparams(lds_global_natparam, lds_expected_stats) - local_natparam = make_tuple(init_params, pair_params, node_params) + local_natparam = tuple_((init_params, pair_params, node_params)) return hmm_logZ(local_natparam) diff --git a/svae/nnet.py b/svae/nnet.py index 00752a2..e234a1d 100644 --- a/svae/nnet.py +++ b/svae/nnet.py @@ -1,7 +1,7 @@ from __future__ import division import autograd.numpy as np import autograd.numpy.random as npr -from autograd.util import make_tuple +from autograd.builtins import tuple as tuple_ from toolz import curry from collections import defaultdict @@ -38,13 +38,13 @@ def gaussian_mean(inputs, sigmoid_mean=False): mu_input, sigmasq_input = np.split(inputs, 2, axis=-1) mu = sigmoid(mu_input) if sigmoid_mean else mu_input sigmasq = log1pexp(sigmasq_input) - return make_tuple(mu, sigmasq) + return tuple_((mu, sigmasq)) @curry def gaussian_info(inputs): J_input, h = np.split(inputs, 2, axis=-1) J = -1./2 * log1pexp(J_input) - return make_tuple(J, h) + return tuple_((J, h)) ### multi-layer perceptrons (MLPs) @@ -96,12 +96,12 @@ def _gresnet(mlp_type, mlp, params, inputs): mu_mlp, sigmasq_mlp = mlp(mlp_params, inputs) mu_res = unravel(np.dot(ravel(inputs), W) + b1) sigmasq_res = log1pexp(b2) - return make_tuple(mu_mlp + mu_res, sigmasq_mlp + sigmasq_res) + return tuple_((mu_mlp + mu_res, sigmasq_mlp + sigmasq_res)) else: J_mlp, h_mlp = mlp(mlp_params, inputs) J_res = -1./2 * log1pexp(b2) h_res = unravel(np.dot(ravel(inputs), W) + b1) - return make_tuple(J_mlp + J_res, h_mlp + h_res) + return tuple_((J_mlp + J_res, h_mlp + h_res)) def init_gresnet(d_in, layer_specs): d_out = layer_specs[-1][0] // 2 diff --git a/svae/svae.py b/svae/svae.py index c35aed2..326ff4b 100644 --- a/svae/svae.py +++ b/svae/svae.py @@ -1,7 +1,7 @@ from __future__ import division, print_function from toolz import curry from autograd import value_and_grad as vgrad -from autograd.util import flatten +from autograd.misc import flatten from util import split_into_batches, get_num_datapoints callback = lambda i, val, params, grad: print('{}: {}'.format(i, val)) diff --git a/svae/util.py b/svae/util.py index 1e3ccb8..55170f6 100644 --- a/svae/util.py +++ b/svae/util.py @@ -2,14 +2,14 @@ import autograd.numpy as np import autograd.numpy.random as npr import autograd.scipy.linalg as spla -from autograd.util import flatten +from autograd.misc import flatten from itertools import islice, imap, cycle import operator from functools import partial from toolz import curry # autograd internals -from autograd.container_types 
import TupleNode, ListNode +from autograd.builtins import SequenceBox from autograd.core import getval, primitive @@ -127,7 +127,7 @@ def split_into_batches(data, seq_len, num_seqs=None, permute=True): ### basic math on (nested) tuples -istuple = lambda x: isinstance(x, (tuple, TupleNode, list, ListNode)) +istuple = lambda x: isinstance(x, (tuple, list, SequenceBox)) ensuretuple = lambda x: x if istuple(x) else (x,) concat = lambda *args: reduce(operator.add, map(ensuretuple, args)) inner = lambda a, b: np.dot(np.ravel(a), np.ravel(b)) From 62d84900fdd54f444f3664c1fc82880e7ac962b6 Mon Sep 17 00:00:00 2001 From: Jonathan So Date: Wed, 22 Aug 2018 16:51:37 +0100 Subject: [PATCH 2/4] Fix natgrads and use autograd.misc.fixed_points for gmm example. --- experiments/gmm_svae_synth.py | 4 +- svae/models/gmm.py | 93 ++++++++++++++++++++--------------- svae/svae.py | 17 +++---- 3 files changed, 62 insertions(+), 52 deletions(-) diff --git a/experiments/gmm_svae_synth.py b/experiments/gmm_svae_synth.py index 393b95d..5a773f7 100644 --- a/experiments/gmm_svae_synth.py +++ b/experiments/gmm_svae_synth.py @@ -6,7 +6,7 @@ from svae.svae import make_gradfun from svae.nnet import init_gresnet, make_loglike, gaussian_mean, gaussian_info from svae.models.gmm import (run_inference, init_pgm_param, make_encoder_decoder, - make_plotter_2d) + make_plotter_2d, pgm_expectedstats) def make_pinwheel_data(radial_std, tangential_std, num_classes, num_per_class, rate): rads = np.linspace(0, 2*np.pi, num_classes, endpoint=False) @@ -54,7 +54,7 @@ def make_pinwheel_data(radial_std, tangential_std, num_classes, num_per_class, r plot = make_plotter_2d(recognize, decode, data, num_clusters, params, plot_every=100) # instantiate svae gradient function - gradfun = make_gradfun(run_inference, recognize, loglike, pgm_prior_params, data) + gradfun = make_gradfun(run_inference, recognize, loglike, pgm_prior_params, pgm_expectedstats, data) # optimize params = sgd(gradfun(batch_size=50, num_samples=1, natgrad_scale=1e4, callback=plot), diff --git a/svae/models/gmm.py b/svae/models/gmm.py index 522d33a..385ecf6 100644 --- a/svae/models/gmm.py +++ b/svae/models/gmm.py @@ -1,16 +1,15 @@ from __future__ import division import autograd.numpy as np import autograd.numpy.random as npr +from autograd.misc.fixed_points import fixed_point from itertools import repeat -from functools import partial - from svae.util import unbox, getval, flat, normalize from svae.distributions import dirichlet, categorical, niw, gaussian ### inference functions for the SVAE interface -def run_inference(prior_natparam, global_natparam, nn_potentials, num_samples): - _, stats, local_natparam, local_kl = local_meanfield(global_natparam, nn_potentials) +def run_inference(prior_natparam, global_natparam, global_stats, nn_potentials, num_samples): + _, stats, local_natparam, local_kl = local_meanfield(global_stats, nn_potentials) samples = gaussian.natural_sample(local_natparam[1], num_samples) global_kl = prior_kl(global_natparam, prior_natparam) return samples, unbox(stats), global_kl, local_kl @@ -18,7 +17,7 @@ def run_inference(prior_natparam, global_natparam, nn_potentials, num_samples): def make_encoder_decoder(recognize, decode): def encode_mean(data, natparam, recogn_params): nn_potentials = recognize(recogn_params, data) - (_, gaussian_stats), _, _, _ = local_meanfield(natparam, nn_potentials) + (_, gaussian_stats), _, _, _ = local_meanfield(pgm_expectedstats(natparam), nn_potentials) _, Ex, _, _ = gaussian.unpack_dense(gaussian_stats) return Ex @@ -41,6 
+40,10 @@ def init_niw_natparam(N): return dirichlet_natparam, niw_natparam +def pgm_expectedstats(global_natparam): + dirichlet_natparam, niw_natparams = global_natparam + return dirichlet.expectedstats(dirichlet_natparam), niw.expectedstats(niw_natparams) + def prior_logZ(gmm_natparam): dirichlet_natparam, niw_natparams = gmm_natparam return dirichlet.logZ(dirichlet_natparam) + niw.logZ(niw_natparams) @@ -57,24 +60,34 @@ def prior_kl(global_natparam, prior_natparam): logZ_difference = prior_logZ(global_natparam) - prior_logZ(prior_natparam) return np.dot(natparam_difference, expected_stats) - logZ_difference +def local_kl(gaussian_globals, label_global, label_natparam, gaussian_natparam, label_stats, gaussian_stats): + return label_kl(label_global, label_natparam, label_stats) + \ + gaussian_kl(gaussian_globals, label_stats, gaussian_natparam, gaussian_stats) + +def gaussian_kl(gaussian_globals, label_stats, natparam, stats): + global_potentials = np.tensordot(label_stats, gaussian_globals, [1, 0]) + return np.tensordot(natparam - global_potentials, stats, 3) - gaussian.logZ(natparam) + +def label_kl(label_global, natparam, stats): + return np.tensordot(stats, natparam - label_global) - categorical.logZ(natparam) + ### GMM mean field functions -def local_meanfield(global_natparam, node_potentials): - dirichlet_natparam, niw_natparams = global_natparam + +def local_meanfield(global_stats, node_potentials): + label_global, gaussian_globals = global_stats node_potentials = gaussian.pack_dense(*node_potentials) - # compute expected global parameters using current global factors - label_global = dirichlet.expectedstats(dirichlet_natparam) - gaussian_globals = niw.expectedstats(niw_natparams) + def make_fpfun((label_global, gaussian_globals, node_potentials)): + return lambda (local_natparam, local_stats, kl): \ + meanfield_update(label_global, gaussian_globals, node_potentials, local_stats[0]) - # compute mean field fixed point using unboxed node_potentials - label_stats = meanfield_fixed_point(label_global, gaussian_globals, getval(node_potentials)) + x0 = initialize_meanfield(label_global, gaussian_globals, node_potentials) - # compute values that depend directly on boxed node_potentials at optimum - gaussian_natparam, gaussian_stats, gaussian_kl = \ - gaussian_meanfield(gaussian_globals, node_potentials, label_stats) - label_natparam, label_stats, label_kl = \ - label_meanfield(label_global, gaussian_globals, gaussian_stats) + kl_diff = lambda a, b: abs(a[2]-b[2]) + + (label_natparam, gaussian_natparam), (label_stats, gaussian_stats), _ = \ + fixed_point(make_fpfun, (label_global, gaussian_globals, node_potentials), x0, kl_diff, tol=1e-3) # collect sufficient statistics for gmm prior (sum across conditional iid) dirichlet_stats = np.sum(label_stats, 0) @@ -83,31 +96,24 @@ def local_meanfield(global_natparam, node_potentials): local_stats = label_stats, gaussian_stats prior_stats = dirichlet_stats, niw_stats natparam = label_natparam, gaussian_natparam - kl = label_kl + gaussian_kl + kl = local_kl(getval(gaussian_globals), getval(label_global), + label_natparam, gaussian_natparam, label_stats, gaussian_stats) return local_stats, prior_stats, natparam, kl -def meanfield_fixed_point(label_global, gaussian_globals, node_potentials, tol=1e-3, max_iter=100): - kl = np.inf - label_stats = initialize_meanfield(label_global, node_potentials) - for i in xrange(max_iter): - gaussian_natparam, gaussian_stats, gaussian_kl = \ - gaussian_meanfield(gaussian_globals, node_potentials, label_stats) - 
label_natparam, label_stats, label_kl = \ - label_meanfield(label_global, gaussian_globals, gaussian_stats) - - # recompute gaussian_kl linear term with new label_stats b/c labels were updated - gaussian_global_potentials = np.tensordot(label_stats, gaussian_globals, [1, 0]) - linear_difference = gaussian_natparam - gaussian_global_potentials - node_potentials - gaussian_kl = gaussian_kl + np.tensordot(linear_difference, gaussian_stats, 3) - - kl, prev_kl = label_kl + gaussian_kl, kl - if abs(kl - prev_kl) < tol: - break - else: - print 'iteration limit reached' - - return label_stats +def meanfield_update(label_global, gaussian_globals, node_potentials, label_stats): + gaussian_natparam, gaussian_stats, gaussian_kl = \ + gaussian_meanfield(gaussian_globals, node_potentials, label_stats) + label_natparam, label_stats, label_kl = \ + label_meanfield(label_global, gaussian_globals, gaussian_stats) + + # recompute gaussian_kl linear term with new label_stats b/c labels were updated + gaussian_global_potentials = np.tensordot(label_stats, gaussian_globals, [1, 0]) + linear_difference = gaussian_natparam - gaussian_global_potentials - node_potentials + gaussian_kl = gaussian_kl + np.tensordot(linear_difference, gaussian_stats, 3) + kl = label_kl + gaussian_kl + + return (label_natparam, gaussian_natparam), (label_stats, gaussian_stats), kl def gaussian_meanfield(gaussian_globals, node_potentials, label_stats): global_potentials = np.tensordot(label_stats, gaussian_globals, [1, 0]) @@ -123,9 +129,14 @@ def label_meanfield(label_global, gaussian_globals, gaussian_stats): kl = np.tensordot(stats, node_potentials) - categorical.logZ(natparam) return natparam, stats, kl -def initialize_meanfield(label_global, node_potentials): +def initialize_meanfield(label_global, gaussian_globals, node_potentials): T, K = node_potentials.shape[0], label_global.shape[0] - return normalize(npr.rand(T, K)) + label_stats = normalize(npr.rand(T, K)) + label_natparam = np.zeros(label_stats.shape) + gaussian_stats = np.zeros(gaussian_globals.shape) + gaussian_natparam = np.zeros(gaussian_stats.shape) + kl = np.inf + return (label_natparam, gaussian_natparam), (label_stats, gaussian_stats), kl ### plotting util for 2D diff --git a/svae/svae.py b/svae/svae.py index 326ff4b..d98a779 100644 --- a/svae/svae.py +++ b/svae/svae.py @@ -8,7 +8,7 @@ flat = lambda struct: flatten(struct)[0] @curry -def make_gradfun(run_inference, recognize, loglike, pgm_prior, data, +def make_gradfun(run_inference, recognize, loglike, pgm_prior, pgm_expectedstats, data, batch_size, num_samples, natgrad_scale=1., callback=callback): _, unflat = flatten(pgm_prior) num_datapoints = get_num_datapoints(data) @@ -16,22 +16,21 @@ def make_gradfun(run_inference, recognize, loglike, pgm_prior, data, get_batch = lambda i: data_batches[i % num_batches] saved = lambda: None - def mc_elbo(pgm_params, loglike_params, recogn_params, i): + def mc_elbo(pgm_params, pgm_stats, loglike_params, recogn_params, i): nn_potentials = recognize(recogn_params, get_batch(i)) samples, saved.stats, global_kl, local_kl = \ - run_inference(pgm_prior, pgm_params, nn_potentials, num_samples) + run_inference(pgm_prior, pgm_params, pgm_stats, nn_potentials, num_samples) return (num_batches * loglike(loglike_params, samples, get_batch(i)) - global_kl - num_batches * local_kl) / num_datapoints def gradfun(params, i): pgm_params, loglike_params, recogn_params = params - objective = lambda (loglike_params, recogn_params): \ - -mc_elbo(pgm_params, loglike_params, recogn_params, i) - val, 
(loglike_grad, recogn_grad) = vgrad(objective)((loglike_params, recogn_params)) - # this expression for pgm_natgrad drops a term that can be computed using - # the function autograd.misc.fixed_points.fixed_point + objective = lambda (pgm_stats, loglike_params, recogn_params): \ + -mc_elbo(pgm_params, pgm_stats, loglike_params, recogn_params, i) + pgm_stats = pgm_expectedstats(pgm_params) + val, (pgm_stats_grad, loglike_grad, recogn_grad) = vgrad(objective)((pgm_stats, loglike_params, recogn_params)) pgm_natgrad = -natgrad_scale / num_datapoints * \ - (flat(pgm_prior) + num_batches*flat(saved.stats) - flat(pgm_params)) + (flat(pgm_prior) + num_batches*(flat(saved.stats) + flat(pgm_stats_grad)) - flat(pgm_params)) grad = unflat(pgm_natgrad), loglike_grad, recogn_grad if callback: callback(i, val, params, grad) return grad From 8540ce6789cb446589cc88bf69ecd276bac659ba Mon Sep 17 00:00:00 2001 From: Jonathan So Date: Wed, 22 Aug 2018 21:46:20 +0100 Subject: [PATCH 3/4] Minor fix in gmm.make_encoder. --- svae/models/gmm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/svae/models/gmm.py b/svae/models/gmm.py index 385ecf6..aee3097 100644 --- a/svae/models/gmm.py +++ b/svae/models/gmm.py @@ -21,8 +21,8 @@ def encode_mean(data, natparam, recogn_params): _, Ex, _, _ = gaussian.unpack_dense(gaussian_stats) return Ex - def decode_mean(z, phi): - mu, _ = decode(z, phi) + def decode_mean(phi, z): + mu, _ = decode(phi, z) return mu.mean(axis=1) return encode_mean, decode_mean From 03160a6013d7aef0df5d621d06d123cd6aa69d03 Mon Sep 17 00:00:00 2001 From: Jonathan So Date: Wed, 22 Aug 2018 21:49:28 +0100 Subject: [PATCH 4/4] Undelete custom adam optimizer. --- experiments/gmm_svae_synth.py | 4 ++-- svae/optimizers.py | 27 +++++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 2 deletions(-) create mode 100644 svae/optimizers.py diff --git a/experiments/gmm_svae_synth.py b/experiments/gmm_svae_synth.py index 5a773f7..eb73516 100644 --- a/experiments/gmm_svae_synth.py +++ b/experiments/gmm_svae_synth.py @@ -2,11 +2,11 @@ import matplotlib.pyplot as plt import autograd.numpy as np import autograd.numpy.random as npr -from autograd.misc.optimizers import adam, sgd from svae.svae import make_gradfun from svae.nnet import init_gresnet, make_loglike, gaussian_mean, gaussian_info from svae.models.gmm import (run_inference, init_pgm_param, make_encoder_decoder, make_plotter_2d, pgm_expectedstats) +from svae.optimizers import adam def make_pinwheel_data(radial_std, tangential_std, num_classes, num_per_class, rate): rads = np.linspace(0, 2*np.pi, num_classes, endpoint=False) @@ -57,5 +57,5 @@ def make_pinwheel_data(radial_std, tangential_std, num_classes, num_per_class, r gradfun = make_gradfun(run_inference, recognize, loglike, pgm_prior_params, pgm_expectedstats, data) # optimize - params = sgd(gradfun(batch_size=50, num_samples=1, natgrad_scale=1e4, callback=plot), + params = adam(gradfun(batch_size=50, num_samples=1, natgrad_scale=1e4, callback=plot), params, num_iters=1000, step_size=1e-3) diff --git a/svae/optimizers.py b/svae/optimizers.py new file mode 100644 index 0000000..032c4c0 --- /dev/null +++ b/svae/optimizers.py @@ -0,0 +1,27 @@ +from util import add, sub, scale, zeros_like, square, sqrt, div, add_scalar, concat + +# TODO make optimizers into monads! 
+# TODO track grad statistics + +def adam(gradfun, allparams, num_iters, step_size, b1=0.9, b2=0.999, eps=1e-8): + natparams, params = allparams[:1], allparams[1:] + m = zeros_like(params) + v = zeros_like(params) + i = 0 + accumulate = lambda rho, a, b: add(scale(1-rho, a), scale(rho, b)) + + for i in xrange(num_iters): + grad = gradfun(allparams, i) + natgrad, grad = grad[:1], grad[1:] + + m = accumulate(b1, grad, m) # first moment estimate + v = accumulate(b2, square(grad), v) # second moment estimate + mhat = scale(1./(1 - b1**(i+1)), m) # bias correction + vhat = scale(1./(1 - b2**(i+1)), v) + update = scale(step_size, div(mhat, add_scalar(eps, sqrt(vhat)))) + + natparams = sub(natparams, scale(step_size, natgrad)) + params = sub(params, update) + allparams = concat(natparams, params) + + return allparams
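
For reference, the central API used by patch 2 is autograd.misc.fixed_points.fixed_point, which replaces the hand-rolled mean-field loop in gmm.local_meanfield. As called there, its signature is fixed_point(f, a, x0, distance, tol), where f(a) returns the update map and gradients with respect to a flow back through the converged point. Below is a minimal standalone sketch of that same calling pattern; the names newton_sqrt_iter and sqrt_fp are illustrative only and do not appear in the patches.

import autograd.numpy as np
from autograd import grad
from autograd.misc.fixed_points import fixed_point

def newton_sqrt_iter(a):
    # f(a) returns the iteration map x -> (x + a/x) / 2, whose fixed point is sqrt(a)
    return lambda x: 0.5 * (x + a / x)

def sqrt_fp(a):
    # same calling pattern as local_meanfield: fixed_point(f, a, x0, distance, tol)
    return fixed_point(newton_sqrt_iter, a, 1., lambda x, y: np.abs(x - y), tol=1e-6)

print(sqrt_fp(2.))        # ~1.4142135
print(grad(sqrt_fp)(2.))  # ~0.3535534 = 1 / (2 * sqrt(2))

Differentiating through the fixed point in this way is what produces the pgm_stats_grad term that patch 2 adds to the natural-gradient expression in svae.py, the term the previous code noted it was dropping.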