In fit_surrogate_posterior, support targets that take random seeds.
PiperOrigin-RevId: 563749304
jburnim authored and tensorflower-gardener committed Sep 8, 2023
1 parent 0fc3b0c commit 558b618
Showing 4 changed files with 43 additions and 14 deletions.
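
For context, here is a minimal sketch (not part of this commit) of the calling pattern the change enables: a target log-prob function that accepts a `seed` keyword argument, trained with `tfp.vi.fit_surrogate_posterior`. The target, surrogate, and hyperparameters below are illustrative assumptions, not code from the repository.

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

def noisy_target_log_prob(x, seed=None):
  # A stochastic target: a standard-normal log-density plus seeded noise.
  # With this commit, the variational loss splits its seed and passes one
  # half to this function via the `seed` kwarg.
  lp = tfd.Normal(0., 1.).log_prob(x)
  return lp + tf.random.stateless_normal(tf.shape(lp), seed=seed, stddev=0.01)

surrogate_posterior = tfd.Normal(
    loc=tf.Variable(0., name='loc'),
    scale=tfp.util.TransformedVariable(1., tfp.bijectors.Softplus(), name='scale'))

losses = tfp.vi.fit_surrogate_posterior(
    noisy_target_log_prob,
    surrogate_posterior=surrogate_posterior,
    optimizer=tf.optimizers.Adam(learning_rate=0.1),
    num_steps=200,
    seed=(0, 1))  # Deterministically seeds both the sampling and the target.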
1 change: 1 addition & 0 deletions tensorflow_probability/python/vi/BUILD
@@ -55,6 +55,7 @@ multi_substrate_py_library(
"//tensorflow_probability/python/internal:dtype_util",
"//tensorflow_probability/python/internal:nest_util",
"//tensorflow_probability/python/internal:reparameterization",
"//tensorflow_probability/python/internal:samplers",
"//tensorflow_probability/python/monte_carlo",
"//tensorflow_probability/python/stats:leave_one_out",
],
34 changes: 24 additions & 10 deletions tensorflow_probability/python/vi/csiszar_divergence.py
@@ -15,6 +15,7 @@
"""Csiszar f-Divergence and helpers."""

import enum
import functools
import warnings

# Dependency imports
@@ -26,6 +27,7 @@
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import nest_util
from tensorflow_probability.python.internal import prefer_static as ps
from tensorflow_probability.python.internal import samplers
from tensorflow_probability.python.internal.reparameterization import FULLY_REPARAMETERIZED
from tensorflow_probability.python.stats.leave_one_out import log_soomean_exp

@@ -55,6 +57,13 @@
]


def _call_fn_maybe_with_seed(fn, args, *, seed=None):
try:
return nest_util.call_fn(functools.partial(fn, seed=seed), args)
except: # pylint: disable=bare-except
return nest_util.call_fn(fn, args)


class GradientEstimators(enum.Enum):
"""Gradient estimators for variational losses.
@@ -1045,6 +1054,7 @@ def monte_carlo_variational_loss(
raise TypeError('`target_log_prob_fn` must be a Python `callable`'
'function.')

sample_seed, target_seed = samplers.split_seed(seed, 2)
reparameterization_types = tf.nest.flatten(
surrogate_posterior.reparameterization_type)
if gradient_estimator is None:
@@ -1067,7 +1077,7 @@
'losses with `importance_sample_size != 1`.')
# Score fn objective requires explicit gradients of `log_prob`.
q_samples = surrogate_posterior.sample(
[sample_size * importance_sample_size], seed=seed)
[sample_size * importance_sample_size], seed=sample_seed)
q_lp = None
else:
if any(reparameterization_type != FULLY_REPARAMETERIZED
@@ -1080,7 +1090,7 @@
# Attempt to avoid bijector inverses by computing the surrogate log prob
# during the forward sampling pass.
q_samples, q_lp = surrogate_posterior.experimental_sample_and_log_prob(
[sample_size * importance_sample_size], seed=seed)
[sample_size * importance_sample_size], seed=sample_seed)

return monte_carlo.expectation(
f=_make_importance_weighted_divergence_fn(
@@ -1090,8 +1100,8 @@
precomputed_surrogate_log_prob=q_lp,
importance_sample_size=importance_sample_size,
gradient_estimator=gradient_estimator,
stopped_surrogate_posterior=(
stopped_surrogate_posterior)),
stopped_surrogate_posterior=stopped_surrogate_posterior,
seed=target_seed),
samples=q_samples,
# Log-prob is only used if `gradient_estimator == SCORE_FUNCTION`.
log_prob=surrogate_posterior.log_prob,
@@ -1106,18 +1116,19 @@ def _make_importance_weighted_divergence_fn(
precomputed_surrogate_log_prob=None,
importance_sample_size=1,
gradient_estimator=GradientEstimators.REPARAMETERIZATION,
stopped_surrogate_posterior=None):
stopped_surrogate_posterior=None,
seed=None):
"""Defines a function to compute an importance-weighted divergence."""

def divergence_fn(q_samples):
q_lp = precomputed_surrogate_log_prob
target_log_prob = nest_util.call_fn(target_log_prob_fn, q_samples)
target_log_prob = _call_fn_maybe_with_seed(
target_log_prob_fn, q_samples, seed=seed)

if gradient_estimator == GradientEstimators.DOUBLY_REPARAMETERIZED:
# Sticking-the-landing is the special case of doubly-reparameterized
# gradients with `importance_sample_size=1`.
q_lp = stopped_surrogate_posterior.log_prob(q_samples)
log_weights = target_log_prob - q_lp
else:
if q_lp is None:
q_lp = surrogate_posterior.log_prob(q_samples)
@@ -1128,7 +1139,8 @@ def importance_weighted_divergence_fn(q_samples):
q_lp = precomputed_surrogate_log_prob
if q_lp is None:
q_lp = surrogate_posterior.log_prob(q_samples)
target_log_prob = nest_util.call_fn(target_log_prob_fn, q_samples)
target_log_prob = _call_fn_maybe_with_seed(
target_log_prob_fn, q_samples, seed=seed)
log_weights = target_log_prob - q_lp

# Explicitly break out `importance_sample_size` as a separate axis.
Expand Down Expand Up @@ -1243,10 +1255,12 @@ def csiszar_vimco(f,
raise ValueError('Must specify num_draws > 1.')
stop = tf.stop_gradient # For readability.

q_sample = q.sample(sample_shape=[num_draws, num_batch_draws], seed=seed)
sample_seed, target_seed = samplers.split_seed(seed, 2)
q_sample = q.sample(sample_shape=[num_draws, num_batch_draws],
seed=sample_seed)
x = tf.nest.map_structure(stop, q_sample)
logqx = q.log_prob(x)
logu = nest_util.call_fn(p_log_prob, x) - logqx
logu = _call_fn_maybe_with_seed(p_log_prob, x, seed=target_seed) - logqx
f_log_sooavg_u, f_log_avg_u = map(f, log_soomean_exp(logu, axis=0))

dotprod = tf.reduce_sum(
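
The seed handling above hinges on a try/except dispatch: `_call_fn_maybe_with_seed` first invokes the target with a `seed` keyword and retries without it if the call raises. Below is a simplified, standalone sketch of that behavior; the real helper routes through `nest_util.call_fn` so structured (nested) arguments are unpacked correctly, and it deliberately uses a bare `except` rather than catching only `TypeError`.

import functools

def call_fn_maybe_with_seed(fn, args, *, seed=None):
  # Try the seeded call first; fall back to the unseeded form if the
  # callable does not accept a `seed` keyword.
  try:
    return functools.partial(fn, seed=seed)(args)
  except TypeError:
    return fn(args)

call_fn_maybe_with_seed(lambda x, seed=None: (x, seed), 1.0, seed=42)  # -> (1.0, 42)
call_fn_maybe_with_seed(lambda x: 2.0 * x, 1.0, seed=42)               # -> 2.0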
15 changes: 12 additions & 3 deletions tensorflow_probability/python/vi/csiszar_divergence_test.py
@@ -907,7 +907,10 @@ def target_log_prob_fn(x):

# Manually estimate the expected multi-sample / IWAE loss.
zs, q_lp = surrogate_posterior.experimental_sample_and_log_prob(
[sample_size, importance_sample_size], seed=seed)
[sample_size, importance_sample_size],
# Brittle hack to ensure that the q samples match those
# drawn in `monte_carlo_variational_loss`.
seed=samplers.split_seed(seed, 2)[0])
log_weights = target_log_prob_fn(zs) - q_lp
iwae_loss = -tf.reduce_mean(
tf.math.reduce_logsumexp(log_weights, axis=1) - tf.math.log(
@@ -988,7 +991,10 @@ def vimco_loss(s):

def logu(s):
q = build_q(s)
x = q.sample(sample_shape=[num_draws, num_batch_draws], seed=seed)
x = q.sample(sample_shape=[num_draws, num_batch_draws],
# Brittle hack to ensure that the q samples match those
# drawn in `monte_carlo_variational_loss`.
seed=samplers.split_seed(seed, 2)[0])
x = tf.stop_gradient(x)
return p.log_prob(x) - q.log_prob(x)

@@ -997,7 +1003,10 @@ def f_log_sum_u(s):

def q_log_prob_x(s):
q = build_q(s)
x = q.sample(sample_shape=[num_draws, num_batch_draws], seed=seed)
x = q.sample(sample_shape=[num_draws, num_batch_draws],
# Brittle hack to ensure that the q samples match those
# drawn in `monte_carlo_variational_loss`.
seed=samplers.split_seed(seed, 2)[0])
x = tf.stop_gradient(x)
return q.log_prob(x)

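
A note on the `samplers.split_seed(seed, 2)[0]` expression in these tests: `monte_carlo_variational_loss` and `csiszar_vimco` now split the incoming seed into a `sample_seed` and a `target_seed`, and since the split is deterministic, the tests can recover the exact `sample_seed` by performing the same split. A rough illustration, assuming the public `tfp.random` aliases behave the same as the internal `samplers` module:

import tensorflow_probability as tfp

seed = tfp.random.sanitize_seed((0, 1))
sample_seed, target_seed = tfp.random.split_seed(seed, 2)
# Re-splitting the same seed reproduces the same sub-seeds, which is why the
# test can draw the same q samples as the loss function does internally.
sample_seed_again, _ = tfp.random.split_seed(seed, 2)
# `sample_seed` and `sample_seed_again` are equal shape-[2] seed tensors.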
7 changes: 6 additions & 1 deletion tensorflow_probability/python/vi/optimization_test.py
@@ -351,9 +351,14 @@ def variational_model_fn():
return
import optax # pylint: disable=g-import-not-at-top

def seeded_target_log_prob_fn(*xs, seed=None):
# Add a tiny amount of noise to the target log-prob to see if it works.
ret = pinned.unnormalized_log_prob(xs)
return ret + samplers.normal(ret.shape, stddev=0.01, seed=seed)

[optimized_parameters,
(losses, _, sample_path)] = optimization.fit_surrogate_posterior_stateless(
target_log_prob_fn=pinned.unnormalized_log_prob,
target_log_prob_fn=seeded_target_log_prob_fn,
build_surrogate_posterior_fn=build_surrogate_posterior_fn,
initial_parameters=initial_parameters,
optimizer=optax.adam(learning_rate=0.1),
