diff --git a/tensorflow_probability/python/vi/BUILD b/tensorflow_probability/python/vi/BUILD
index 9ffb679efd..3f9b74522b 100644
--- a/tensorflow_probability/python/vi/BUILD
+++ b/tensorflow_probability/python/vi/BUILD
@@ -55,6 +55,7 @@ multi_substrate_py_library(
         "//tensorflow_probability/python/internal:dtype_util",
         "//tensorflow_probability/python/internal:nest_util",
         "//tensorflow_probability/python/internal:reparameterization",
+        "//tensorflow_probability/python/internal:samplers",
         "//tensorflow_probability/python/monte_carlo",
         "//tensorflow_probability/python/stats:leave_one_out",
     ],
diff --git a/tensorflow_probability/python/vi/csiszar_divergence.py b/tensorflow_probability/python/vi/csiszar_divergence.py
index 0904bce01c..06bae4dd49 100644
--- a/tensorflow_probability/python/vi/csiszar_divergence.py
+++ b/tensorflow_probability/python/vi/csiszar_divergence.py
@@ -15,6 +15,7 @@
 """Csiszar f-Divergence and helpers."""
 
 import enum
+import functools
 import warnings
 
 # Dependency imports
@@ -26,6 +27,7 @@
 from tensorflow_probability.python.internal import dtype_util
 from tensorflow_probability.python.internal import nest_util
 from tensorflow_probability.python.internal import prefer_static as ps
+from tensorflow_probability.python.internal import samplers
 from tensorflow_probability.python.internal.reparameterization import FULLY_REPARAMETERIZED
 from tensorflow_probability.python.stats.leave_one_out import log_soomean_exp
 
@@ -55,6 +57,13 @@
 ]
 
 
+def _call_fn_maybe_with_seed(fn, args, *, seed=None):
+  try:
+    return nest_util.call_fn(functools.partial(fn, seed=seed), args)
+  except:  # pylint: disable=bare-except
+    return nest_util.call_fn(fn, args)
+
+
 class GradientEstimators(enum.Enum):
   """Gradient estimators for variational losses.
 
@@ -1045,6 +1054,7 @@ def monte_carlo_variational_loss(
       raise TypeError('`target_log_prob_fn` must be a Python `callable`'
                       'function.')
 
+    sample_seed, target_seed = samplers.split_seed(seed, 2)
    reparameterization_types = tf.nest.flatten(
         surrogate_posterior.reparameterization_type)
     if gradient_estimator is None:
@@ -1067,7 +1077,7 @@ def monte_carlo_variational_loss(
             'losses with `importance_sample_size != 1`.')
       # Score fn objective requires explicit gradients of `log_prob`.
       q_samples = surrogate_posterior.sample(
-          [sample_size * importance_sample_size], seed=seed)
+          [sample_size * importance_sample_size], seed=sample_seed)
       q_lp = None
     else:
       if any(reparameterization_type != FULLY_REPARAMETERIZED
@@ -1080,7 +1090,7 @@ def monte_carlo_variational_loss(
       # Attempt to avoid bijector inverses by computing the surrogate log prob
       # during the forward sampling pass.
       q_samples, q_lp = surrogate_posterior.experimental_sample_and_log_prob(
-          [sample_size * importance_sample_size], seed=seed)
+          [sample_size * importance_sample_size], seed=sample_seed)
 
     return monte_carlo.expectation(
         f=_make_importance_weighted_divergence_fn(
@@ -1090,8 +1100,8 @@ def monte_carlo_variational_loss(
             precomputed_surrogate_log_prob=q_lp,
             importance_sample_size=importance_sample_size,
             gradient_estimator=gradient_estimator,
-            stopped_surrogate_posterior=(
-                stopped_surrogate_posterior)),
+            stopped_surrogate_posterior=stopped_surrogate_posterior,
+            seed=target_seed),
         samples=q_samples,
         # Log-prob is only used if `gradient_estimator == SCORE_FUNCTION`.
         log_prob=surrogate_posterior.log_prob,
@@ -1106,18 +1116,19 @@ def _make_importance_weighted_divergence_fn(
     precomputed_surrogate_log_prob=None,
     importance_sample_size=1,
     gradient_estimator=GradientEstimators.REPARAMETERIZATION,
-    stopped_surrogate_posterior=None):
+    stopped_surrogate_posterior=None,
+    seed=None):
   """Defines a function to compute an importance-weighted divergence."""
 
   def divergence_fn(q_samples):
     q_lp = precomputed_surrogate_log_prob
-    target_log_prob = nest_util.call_fn(target_log_prob_fn, q_samples)
+    target_log_prob = _call_fn_maybe_with_seed(
+        target_log_prob_fn, q_samples, seed=seed)
 
     if gradient_estimator == GradientEstimators.DOUBLY_REPARAMETERIZED:
       # Sticking-the-landing is the special case of doubly-reparameterized
       # gradients with `importance_sample_size=1`.
       q_lp = stopped_surrogate_posterior.log_prob(q_samples)
-      log_weights = target_log_prob - q_lp
     else:
       if q_lp is None:
         q_lp = surrogate_posterior.log_prob(q_samples)
@@ -1128,7 +1139,8 @@ def importance_weighted_divergence_fn(q_samples):
     q_lp = precomputed_surrogate_log_prob
     if q_lp is None:
       q_lp = surrogate_posterior.log_prob(q_samples)
-    target_log_prob = nest_util.call_fn(target_log_prob_fn, q_samples)
+    target_log_prob = _call_fn_maybe_with_seed(
+        target_log_prob_fn, q_samples, seed=seed)
     log_weights = target_log_prob - q_lp
 
     # Explicitly break out `importance_sample_size` as a separate axis.
@@ -1243,10 +1255,12 @@ def csiszar_vimco(f,
       raise ValueError('Must specify num_draws > 1.')
     stop = tf.stop_gradient  # For readability.
 
-    q_sample = q.sample(sample_shape=[num_draws, num_batch_draws], seed=seed)
+    sample_seed, target_seed = samplers.split_seed(seed, 2)
+    q_sample = q.sample(sample_shape=[num_draws, num_batch_draws],
+                        seed=sample_seed)
     x = tf.nest.map_structure(stop, q_sample)
     logqx = q.log_prob(x)
-    logu = nest_util.call_fn(p_log_prob, x) - logqx
+    logu = _call_fn_maybe_with_seed(p_log_prob, x, seed=target_seed) - logqx
     f_log_sooavg_u, f_log_avg_u = map(f, log_soomean_exp(logu, axis=0))
 
     dotprod = tf.reduce_sum(
diff --git a/tensorflow_probability/python/vi/csiszar_divergence_test.py b/tensorflow_probability/python/vi/csiszar_divergence_test.py
index 34d3656812..5862e6c4b5 100644
--- a/tensorflow_probability/python/vi/csiszar_divergence_test.py
+++ b/tensorflow_probability/python/vi/csiszar_divergence_test.py
@@ -907,7 +910,10 @@ def target_log_prob_fn(x):
 
     # Manually estimate the expected multi-sample / IWAE loss.
     zs, q_lp = surrogate_posterior.experimental_sample_and_log_prob(
-        [sample_size, importance_sample_size], seed=seed)
+        [sample_size, importance_sample_size],
+        # Brittle hack to ensure that the q samples match those
+        # drawn in `monte_carlo_variational_loss`.
+        seed=samplers.split_seed(seed, 2)[0])
     log_weights = target_log_prob_fn(zs) - q_lp
     iwae_loss = -tf.reduce_mean(
         tf.math.reduce_logsumexp(log_weights, axis=1) - tf.math.log(
@@ -988,7 +991,10 @@ def vimco_loss(s):
 
     def logu(s):
       q = build_q(s)
-      x = q.sample(sample_shape=[num_draws, num_batch_draws], seed=seed)
+      x = q.sample(sample_shape=[num_draws, num_batch_draws],
+                   # Brittle hack to ensure that the q samples match those
+                   # drawn in `monte_carlo_variational_loss`.
+                   seed=samplers.split_seed(seed, 2)[0])
       x = tf.stop_gradient(x)
       return p.log_prob(x) - q.log_prob(x)
 
@@ -997,7 +1003,10 @@ def f_log_sum_u(s):
 
     def q_log_prob_x(s):
      q = build_q(s)
-      x = q.sample(sample_shape=[num_draws, num_batch_draws], seed=seed)
+      x = q.sample(sample_shape=[num_draws, num_batch_draws],
+                   # Brittle hack to ensure that the q samples match those
+                   # drawn in `monte_carlo_variational_loss`.
+                   seed=samplers.split_seed(seed, 2)[0])
       x = tf.stop_gradient(x)
       return q.log_prob(x)
 
diff --git a/tensorflow_probability/python/vi/optimization_test.py b/tensorflow_probability/python/vi/optimization_test.py
index 65e75d1fe2..1b4a872bfe 100644
--- a/tensorflow_probability/python/vi/optimization_test.py
+++ b/tensorflow_probability/python/vi/optimization_test.py
@@ -351,9 +351,14 @@ def variational_model_fn():
       return
     import optax  # pylint: disable=g-import-not-at-top
 
+    def seeded_target_log_prob_fn(*xs, seed=None):
+      # Add a tiny amount of noise to the target log-prob to see if it works.
+      ret = pinned.unnormalized_log_prob(xs)
+      return ret + samplers.normal(ret.shape, stddev=0.01, seed=seed)
+
     [optimized_parameters,
      (losses, _, sample_path)] = optimization.fit_surrogate_posterior_stateless(
-         target_log_prob_fn=pinned.unnormalized_log_prob,
+         target_log_prob_fn=seeded_target_log_prob_fn,
          build_surrogate_posterior_fn=build_surrogate_posterior_fn,
          initial_parameters=initial_parameters,
          optimizer=optax.adam(learning_rate=0.1),
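
For context, a minimal usage sketch of the behavior this patch enables. The toy Normal distributions, parameter values, and seeds below are illustrative assumptions, not part of the change; the API names (`tfp.vi.monte_carlo_variational_loss`, the optional `seed` keyword on `target_log_prob_fn`) come from the patch itself. The loss splits its own `seed` into a sampling seed and a target seed, and the new `_call_fn_maybe_with_seed` helper falls back to a plain call for targets that take no `seed`, so existing deterministic targets keep working unchanged.

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Illustrative trainable surrogate; any fully reparameterized distribution works.
surrogate_posterior = tfd.Normal(loc=tf.Variable(0.), scale=1.)

def target_log_prob_fn(x):
  # Deterministic target: takes no `seed`, so the helper falls back to a
  # plain call and behavior is identical to before the patch.
  return tfd.Normal(0.5, 1.).log_prob(x)

def noisy_target_log_prob_fn(x, seed=None):
  # Stochastic target: receives a key split off from the loss's `seed`.
  lp = tfd.Normal(0.5, 1.).log_prob(x)
  return lp + tfd.Normal(0., 0.01).sample(tf.shape(lp), seed=seed)

# Existing, seedless targets continue to work via the fallback path.
loss_deterministic = tfp.vi.monte_carlo_variational_loss(
    target_log_prob_fn=target_log_prob_fn,
    surrogate_posterior=surrogate_posterior,
    sample_size=10,
    seed=42)

# A target that accepts `seed` now gets deterministic per-call randomness.
loss_noisy = tfp.vi.monte_carlo_variational_loss(
    target_log_prob_fn=noisy_target_log_prob_fn,
    surrogate_posterior=surrogate_posterior,
    sample_size=10,
    seed=42)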
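
`csiszar_vimco` follows the same seed-splitting pattern: the first key drives `q.sample` and the second is offered to `p_log_prob`. A minimal sketch, again with illustrative toy densities rather than anything from the patch:

import tensorflow_probability as tfp

tfd = tfp.distributions

p = tfd.Normal(0.5, 1.)
q = tfd.Normal(0., 1.)

def p_log_prob(x, seed=None):
  del seed  # A deterministic target may simply ignore the seed.
  return p.log_prob(x)

vimco_loss = tfp.vi.csiszar_vimco(
    f=tfp.vi.kl_reverse,
    p_log_prob=p_log_prob,
    q=q,
    num_draws=5,
    seed=42)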