Update to use latest autograd (65c21e2). #15

Open · wants to merge 4 commits into base: master
8 changes: 4 additions & 4 deletions experiments/gmm_svae_synth.py
@@ -2,11 +2,11 @@
import matplotlib.pyplot as plt
import autograd.numpy as np
import autograd.numpy.random as npr
from autograd.optimizers import adam, sgd
from svae.svae import make_gradfun
from svae.nnet import init_gresnet, make_loglike, gaussian_mean, gaussian_info
from svae.models.gmm import (run_inference, init_pgm_param, make_encoder_decoder,
make_plotter_2d)
make_plotter_2d, pgm_expectedstats)
from svae.optimizers import adam

def make_pinwheel_data(radial_std, tangential_std, num_classes, num_per_class, rate):
rads = np.linspace(0, 2*np.pi, num_classes, endpoint=False)
@@ -54,8 +54,8 @@ def make_pinwheel_data(radial_std, tangential_std, num_classes, num_per_class, rate):
plot = make_plotter_2d(recognize, decode, data, num_clusters, params, plot_every=100)

# instantiate svae gradient function
gradfun = make_gradfun(run_inference, recognize, loglike, pgm_prior_params, data)
gradfun = make_gradfun(run_inference, recognize, loglike, pgm_prior_params, pgm_expectedstats, data)

# optimize
params = sgd(gradfun(batch_size=50, num_samples=1, natgrad_scale=1e4, callback=plot),
params = adam(gradfun(batch_size=50, num_samples=1, natgrad_scale=1e4, callback=plot),
params, num_iters=1000, step_size=1e-3)
6 changes: 3 additions & 3 deletions svae/distributions/mniw.py
@@ -4,7 +4,7 @@
from autograd.scipy.special import multigammaln, digamma
from autograd.scipy.linalg import solve_triangular
from autograd import grad
from autograd.util import make_tuple
from autograd.builtins import tuple as tuple_
from scipy.stats import chi2

from svae.util import symmetrize
@@ -51,8 +51,8 @@ def expectedstats_standard(nu, S, M, K, fudge=1e-8):
assert is_posdef(E_Sigmainv)
assert is_posdef(E_AT_Sigmainv_A)

return make_tuple(
-1./2*E_AT_Sigmainv_A, E_Sigmainv_A.T, -1./2*E_Sigmainv, 1./2*E_logdetSigmainv)
return tuple_((
-1./2*E_AT_Sigmainv_A, E_Sigmainv_A.T, -1./2*E_Sigmainv, 1./2*E_logdetSigmainv))

def expectedstats_autograd(natparam):
return grad(logZ)(natparam)
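Note (reviewer sketch, not part of the diff): recent autograd versions removed autograd.util.make_tuple; differentiable container construction moved to autograd.builtins, which is what this hunk and the later ones switch to. A minimal, self-contained illustration of the replacement pattern, assuming an autograd version at or after the referenced commit:

```python
import autograd.numpy as np
from autograd import grad
from autograd.builtins import tuple as tuple_  # replaces autograd.util.make_tuple

def f(x):
    # Old API: make_tuple(x**2, 3*x); new API wraps an already-built tuple.
    pair = tuple_((x**2, 3*x))
    return pair[0] + pair[1]

print(grad(f)(2.0))  # d/dx (x**2 + 3*x) at x=2 gives 7.0
```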
1 change: 0 additions & 1 deletion svae/distributions/niw.py
@@ -2,7 +2,6 @@
import autograd.numpy as np
from autograd.scipy.special import multigammaln, digamma
from autograd import grad
from autograd.util import make_tuple

from svae.util import symmetrize, outer
from gaussian import pack_dense, unpack_dense
8 changes: 4 additions & 4 deletions svae/lds/lds_inference.py
@@ -3,7 +3,7 @@
import autograd.numpy.random as npr
from autograd import grad
from autograd.convenience_wrappers import grad_and_aux as agrad, value_and_grad as vgrad
from autograd.util import make_tuple
from autograd.builtins import tuple as tuple_
from autograd.core import primitive, primitive_with_aux
from operator import itemgetter, attrgetter

@@ -135,12 +135,12 @@ def unit(filtered_message):
J, h = filtered_message
mu, Sigma = natural_to_mean(filtered_message)
ExxT = Sigma + np.outer(mu, mu)
return make_tuple(J, h, mu), [(mu, ExxT, 0.)]
return tuple_((J, h, mu)), [(mu, ExxT, 0.)]

def bind(result, step):
next_smooth, stats = result
J, h, (mu, ExxT, ExxnT) = step(next_smooth)
return make_tuple(J, h, mu), [(mu, ExxT, ExxnT)] + stats
return tuple_((J, h, mu)), [(mu, ExxT, ExxnT)] + stats

rts = lambda next_pred, filtered, pair_param: lambda next_smooth: \
natural_rts_backward_step(next_smooth, next_pred, filtered, pair_param)
@@ -211,7 +211,7 @@ def lds_log_normalizer(all_natparams):
init_params, pair_params, node_params)
return lognorm, (lognorm, forward_messages)

all_natparams = make_tuple(init_params, pair_params, node_params)
all_natparams = tuple_((init_params, pair_params, node_params))
expected_stats, (lognorm, forward_messages) = agrad(lds_log_normalizer)(all_natparams)
samples = natural_sample_backward_general(forward_messages, pair_params, num_samples)

97 changes: 54 additions & 43 deletions svae/models/gmm.py
@@ -1,29 +1,28 @@
from __future__ import division
import autograd.numpy as np
import autograd.numpy.random as npr
from autograd.misc.fixed_points import fixed_point
from itertools import repeat
from functools import partial

from svae.util import unbox, getval, flat, normalize
from svae.distributions import dirichlet, categorical, niw, gaussian

### inference functions for the SVAE interface

def run_inference(prior_natparam, global_natparam, nn_potentials, num_samples):
_, stats, local_natparam, local_kl = local_meanfield(global_natparam, nn_potentials)
def run_inference(prior_natparam, global_natparam, global_stats, nn_potentials, num_samples):
_, stats, local_natparam, local_kl = local_meanfield(global_stats, nn_potentials)
samples = gaussian.natural_sample(local_natparam[1], num_samples)
global_kl = prior_kl(global_natparam, prior_natparam)
return samples, unbox(stats), global_kl, local_kl

def make_encoder_decoder(recognize, decode):
def encode_mean(data, natparam, recogn_params):
nn_potentials = recognize(recogn_params, data)
(_, gaussian_stats), _, _, _ = local_meanfield(natparam, nn_potentials)
(_, gaussian_stats), _, _, _ = local_meanfield(pgm_expectedstats(natparam), nn_potentials)
_, Ex, _, _ = gaussian.unpack_dense(gaussian_stats)
return Ex

def decode_mean(z, phi):
mu, _ = decode(z, phi)
def decode_mean(phi, z):
mu, _ = decode(phi, z)
return mu.mean(axis=1)

return encode_mean, decode_mean
@@ -41,6 +40,10 @@ def init_niw_natparam(N):

return dirichlet_natparam, niw_natparam

def pgm_expectedstats(global_natparam):
dirichlet_natparam, niw_natparams = global_natparam
return dirichlet.expectedstats(dirichlet_natparam), niw.expectedstats(niw_natparams)

def prior_logZ(gmm_natparam):
dirichlet_natparam, niw_natparams = gmm_natparam
return dirichlet.logZ(dirichlet_natparam) + niw.logZ(niw_natparams)
@@ -57,24 +60,34 @@ def prior_kl(global_natparam, prior_natparam):
logZ_difference = prior_logZ(global_natparam) - prior_logZ(prior_natparam)
return np.dot(natparam_difference, expected_stats) - logZ_difference

def local_kl(gaussian_globals, label_global, label_natparam, gaussian_natparam, label_stats, gaussian_stats):
return label_kl(label_global, label_natparam, label_stats) + \
gaussian_kl(gaussian_globals, label_stats, gaussian_natparam, gaussian_stats)

def gaussian_kl(gaussian_globals, label_stats, natparam, stats):
global_potentials = np.tensordot(label_stats, gaussian_globals, [1, 0])
return np.tensordot(natparam - global_potentials, stats, 3) - gaussian.logZ(natparam)

def label_kl(label_global, natparam, stats):
return np.tensordot(stats, natparam - label_global) - categorical.logZ(natparam)

### GMM mean field functions

def local_meanfield(global_natparam, node_potentials):
dirichlet_natparam, niw_natparams = global_natparam

def local_meanfield(global_stats, node_potentials):
label_global, gaussian_globals = global_stats
node_potentials = gaussian.pack_dense(*node_potentials)

# compute expected global parameters using current global factors
label_global = dirichlet.expectedstats(dirichlet_natparam)
gaussian_globals = niw.expectedstats(niw_natparams)
def make_fpfun((label_global, gaussian_globals, node_potentials)):
return lambda (local_natparam, local_stats, kl): \
meanfield_update(label_global, gaussian_globals, node_potentials, local_stats[0])

# compute mean field fixed point using unboxed node_potentials
label_stats = meanfield_fixed_point(label_global, gaussian_globals, getval(node_potentials))
x0 = initialize_meanfield(label_global, gaussian_globals, node_potentials)

# compute values that depend directly on boxed node_potentials at optimum
gaussian_natparam, gaussian_stats, gaussian_kl = \
gaussian_meanfield(gaussian_globals, node_potentials, label_stats)
label_natparam, label_stats, label_kl = \
label_meanfield(label_global, gaussian_globals, gaussian_stats)
kl_diff = lambda a, b: abs(a[2]-b[2])

(label_natparam, gaussian_natparam), (label_stats, gaussian_stats), _ = \
fixed_point(make_fpfun, (label_global, gaussian_globals, node_potentials), x0, kl_diff, tol=1e-3)

# collect sufficient statistics for gmm prior (sum across conditional iid)
dirichlet_stats = np.sum(label_stats, 0)
@@ -83,31 +96,24 @@ def local_meanfield(global_natparam, node_potentials):
local_stats = label_stats, gaussian_stats
prior_stats = dirichlet_stats, niw_stats
natparam = label_natparam, gaussian_natparam
kl = label_kl + gaussian_kl
kl = local_kl(getval(gaussian_globals), getval(label_global),
label_natparam, gaussian_natparam, label_stats, gaussian_stats)

return local_stats, prior_stats, natparam, kl

def meanfield_fixed_point(label_global, gaussian_globals, node_potentials, tol=1e-3, max_iter=100):
kl = np.inf
label_stats = initialize_meanfield(label_global, node_potentials)
for i in xrange(max_iter):
gaussian_natparam, gaussian_stats, gaussian_kl = \
gaussian_meanfield(gaussian_globals, node_potentials, label_stats)
label_natparam, label_stats, label_kl = \
label_meanfield(label_global, gaussian_globals, gaussian_stats)

# recompute gaussian_kl linear term with new label_stats b/c labels were updated
gaussian_global_potentials = np.tensordot(label_stats, gaussian_globals, [1, 0])
linear_difference = gaussian_natparam - gaussian_global_potentials - node_potentials
gaussian_kl = gaussian_kl + np.tensordot(linear_difference, gaussian_stats, 3)

kl, prev_kl = label_kl + gaussian_kl, kl
if abs(kl - prev_kl) < tol:
break
else:
print 'iteration limit reached'

return label_stats
def meanfield_update(label_global, gaussian_globals, node_potentials, label_stats):
gaussian_natparam, gaussian_stats, gaussian_kl = \
gaussian_meanfield(gaussian_globals, node_potentials, label_stats)
label_natparam, label_stats, label_kl = \
label_meanfield(label_global, gaussian_globals, gaussian_stats)

# recompute gaussian_kl linear term with new label_stats b/c labels were updated
gaussian_global_potentials = np.tensordot(label_stats, gaussian_globals, [1, 0])
linear_difference = gaussian_natparam - gaussian_global_potentials - node_potentials
gaussian_kl = gaussian_kl + np.tensordot(linear_difference, gaussian_stats, 3)
kl = label_kl + gaussian_kl

return (label_natparam, gaussian_natparam), (label_stats, gaussian_stats), kl

def gaussian_meanfield(gaussian_globals, node_potentials, label_stats):
global_potentials = np.tensordot(label_stats, gaussian_globals, [1, 0])
@@ -123,9 +129,14 @@ def label_meanfield(label_global, gaussian_globals, gaussian_stats):
kl = np.tensordot(stats, node_potentials) - categorical.logZ(natparam)
return natparam, stats, kl

def initialize_meanfield(label_global, node_potentials):
def initialize_meanfield(label_global, gaussian_globals, node_potentials):
T, K = node_potentials.shape[0], label_global.shape[0]
return normalize(npr.rand(T, K))
label_stats = normalize(npr.rand(T, K))
label_natparam = np.zeros(label_stats.shape)
gaussian_stats = np.zeros(gaussian_globals.shape)
gaussian_natparam = np.zeros(gaussian_stats.shape)
kl = np.inf
return (label_natparam, gaussian_natparam), (label_stats, gaussian_stats), kl

### plotting util for 2D

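Note (reviewer sketch, not part of the diff): the hand-rolled meanfield_fixed_point loop above is replaced by autograd.misc.fixed_points.fixed_point, which iterates x <- f(a)(x) until distance(x_new, x_old) falls below tol and supplies gradients with respect to a via implicit differentiation. A toy example of that API, assuming an autograd version at or after the referenced commit (the update map here is illustrative and stands in for meanfield_update):

```python
from autograd.misc.fixed_points import fixed_point
from autograd import grad
import autograd.numpy as np

def make_fpfun(a):
    # f(a) returns the update map whose fixed point we want, analogous to the
    # closure over (label_global, gaussian_globals, node_potentials) in gmm.py.
    return lambda x: np.tanh(a * x) + 1.0   # a contraction for modest a

distance = lambda x1, x2: np.abs(x1 - x2)
solve = lambda a: fixed_point(make_fpfun, a, 0.5, distance, tol=1e-6)

x_star = solve(0.3)        # fixed point of x -> tanh(0.3*x) + 1
dx_da = grad(solve)(0.3)   # gradient through the fixed point
print(x_star, dx_da)
```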
4 changes: 2 additions & 2 deletions svae/models/slds_svae.py
@@ -2,7 +2,7 @@
import autograd.numpy as np
import autograd.numpy.random as npr
from autograd import grad
from autograd.util import make_tuple
from autograd.builtins import tuple as tuple_
from functools import partial
import sys

@@ -150,7 +150,7 @@ def get_arhmm_local_nodeparams(lds_global_natparam, lds_expected_stats):
def get_hmm_vlb(lds_global_natparam, hmm_local_natparam, lds_expected_stats):
init_params, pair_params, _ = hmm_local_natparam
node_params = get_arhmm_local_nodeparams(lds_global_natparam, lds_expected_stats)
local_natparam = make_tuple(init_params, pair_params, node_params)
local_natparam = tuple_((init_params, pair_params, node_params))
return hmm_logZ(local_natparam)


10 changes: 5 additions & 5 deletions svae/nnet.py
@@ -1,7 +1,7 @@
from __future__ import division
import autograd.numpy as np
import autograd.numpy.random as npr
from autograd.util import make_tuple
from autograd.builtins import tuple as tuple_
from toolz import curry
from collections import defaultdict

@@ -38,13 +38,13 @@ def gaussian_mean(inputs, sigmoid_mean=False):
mu_input, sigmasq_input = np.split(inputs, 2, axis=-1)
mu = sigmoid(mu_input) if sigmoid_mean else mu_input
sigmasq = log1pexp(sigmasq_input)
return make_tuple(mu, sigmasq)
return tuple_((mu, sigmasq))

@curry
def gaussian_info(inputs):
J_input, h = np.split(inputs, 2, axis=-1)
J = -1./2 * log1pexp(J_input)
return make_tuple(J, h)
return tuple_((J, h))


### multi-layer perceptrons (MLPs)
@@ -96,12 +96,12 @@ def _gresnet(mlp_type, mlp, params, inputs):
mu_mlp, sigmasq_mlp = mlp(mlp_params, inputs)
mu_res = unravel(np.dot(ravel(inputs), W) + b1)
sigmasq_res = log1pexp(b2)
return make_tuple(mu_mlp + mu_res, sigmasq_mlp + sigmasq_res)
return tuple_((mu_mlp + mu_res, sigmasq_mlp + sigmasq_res))
else:
J_mlp, h_mlp = mlp(mlp_params, inputs)
J_res = -1./2 * log1pexp(b2)
h_res = unravel(np.dot(ravel(inputs), W) + b1)
return make_tuple(J_mlp + J_res, h_mlp + h_res)
return tuple_((J_mlp + J_res, h_mlp + h_res))

def init_gresnet(d_in, layer_specs):
d_out = layer_specs[-1][0] // 2
27 changes: 27 additions & 0 deletions svae/optimizers.py
@@ -0,0 +1,27 @@
from util import add, sub, scale, zeros_like, square, sqrt, div, add_scalar, concat

# TODO make optimizers into monads!
# TODO track grad statistics

def adam(gradfun, allparams, num_iters, step_size, b1=0.9, b2=0.999, eps=1e-8):
natparams, params = allparams[:1], allparams[1:]
m = zeros_like(params)
v = zeros_like(params)
i = 0
accumulate = lambda rho, a, b: add(scale(1-rho, a), scale(rho, b))

for i in xrange(num_iters):
grad = gradfun(allparams, i)
natgrad, grad = grad[:1], grad[1:]

m = accumulate(b1, grad, m) # first moment estimate
v = accumulate(b2, square(grad), v) # second moment estimate
mhat = scale(1./(1 - b1**(i+1)), m) # bias correction
vhat = scale(1./(1 - b2**(i+1)), v)
update = scale(step_size, div(mhat, add_scalar(eps, sqrt(vhat))))

natparams = sub(natparams, scale(step_size, natgrad))
params = sub(params, update)
allparams = concat(natparams, params)

return allparams
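Note (reviewer sketch, not part of the diff): the new svae/optimizers.py adam applies a plain step to the leading (natural-gradient) block of allparams and a bias-corrected Adam update to the rest, using the tuple-structured helpers from svae.util. For reference, the same Adam update on a single flat array, which may make the moment bookkeeping easier to read:

```python
import autograd.numpy as np

def adam_flat(gradfun, x, num_iters, step_size, b1=0.9, b2=0.999, eps=1e-8):
    # Same first/second-moment estimates and bias correction as the tuple
    # version above, but for one flat parameter array.
    m = np.zeros_like(x)
    v = np.zeros_like(x)
    for i in range(num_iters):
        g = gradfun(x, i)
        m = (1 - b1) * g + b1 * m           # first moment estimate
        v = (1 - b2) * g**2 + b2 * v        # second moment estimate
        mhat = m / (1 - b1**(i + 1))        # bias correction
        vhat = v / (1 - b2**(i + 1))
        x = x - step_size * mhat / (np.sqrt(vhat) + eps)
    return x
```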
19 changes: 9 additions & 10 deletions svae/svae.py
@@ -1,37 +1,36 @@
from __future__ import division, print_function
from toolz import curry
from autograd import value_and_grad as vgrad
from autograd.util import flatten
from autograd.misc import flatten
from util import split_into_batches, get_num_datapoints

callback = lambda i, val, params, grad: print('{}: {}'.format(i, val))
flat = lambda struct: flatten(struct)[0]

@curry
def make_gradfun(run_inference, recognize, loglike, pgm_prior, data,
def make_gradfun(run_inference, recognize, loglike, pgm_prior, pgm_expectedstats, data,
batch_size, num_samples, natgrad_scale=1., callback=callback):
_, unflat = flatten(pgm_prior)
num_datapoints = get_num_datapoints(data)
data_batches, num_batches = split_into_batches(data, batch_size)
get_batch = lambda i: data_batches[i % num_batches]
saved = lambda: None

def mc_elbo(pgm_params, loglike_params, recogn_params, i):
def mc_elbo(pgm_params, pgm_stats, loglike_params, recogn_params, i):
nn_potentials = recognize(recogn_params, get_batch(i))
samples, saved.stats, global_kl, local_kl = \
run_inference(pgm_prior, pgm_params, nn_potentials, num_samples)
run_inference(pgm_prior, pgm_params, pgm_stats, nn_potentials, num_samples)
return (num_batches * loglike(loglike_params, samples, get_batch(i))
- global_kl - num_batches * local_kl) / num_datapoints

def gradfun(params, i):
pgm_params, loglike_params, recogn_params = params
objective = lambda (loglike_params, recogn_params): \
-mc_elbo(pgm_params, loglike_params, recogn_params, i)
val, (loglike_grad, recogn_grad) = vgrad(objective)((loglike_params, recogn_params))
# this expression for pgm_natgrad drops a term that can be computed using
# the function autograd.misc.fixed_points.fixed_point
objective = lambda (pgm_stats, loglike_params, recogn_params): \
-mc_elbo(pgm_params, pgm_stats, loglike_params, recogn_params, i)
pgm_stats = pgm_expectedstats(pgm_params)
val, (pgm_stats_grad, loglike_grad, recogn_grad) = vgrad(objective)((pgm_stats, loglike_params, recogn_params))
pgm_natgrad = -natgrad_scale / num_datapoints * \
(flat(pgm_prior) + num_batches*flat(saved.stats) - flat(pgm_params))
(flat(pgm_prior) + num_batches*(flat(saved.stats) + flat(pgm_stats_grad)) - flat(pgm_params))
grad = unflat(pgm_natgrad), loglike_grad, recogn_grad
if callback: callback(i, val, params, grad)
return grad
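Note (editorial, not part of the diff): read schematically, the updated pgm_natgrad line computes a stochastic-variational-inference style natural gradient in which flat(pgm_prior), flat(saved.stats), flat(pgm_stats_grad), and flat(pgm_params) play the roles of the prior natural parameter, the expected sufficient statistics, the gradient through the expected statistics (the term the old comment noted was being dropped), and the current natural parameter:

$$\tilde{\nabla}_{\eta}\,\mathcal{L} \;\propto\; \eta_{0} \;+\; B\left(\mathbb{E}_{q}[t] + \frac{\partial \mathcal{L}}{\partial\,\mathbb{E}_{q}[t]}\right) \;-\; \eta$$

with B the number of minibatches. This is a reading of the code under the conjugate exponential-family assumptions of the SVAE setup, not a claim about the exact sign and scaling conventions used here.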