TFP: fix key reuse issue in monte_carlo_variational_loss
Detected with JAX's enable_key_reuse_checks. We can avoid falling afoul of the reuse checker by splitting only once we determine it to be necessary. This should not have any user-visible change.
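The pattern behind the fix can be sketched outside TFP. The following is a hypothetical, minimal example using `jax.random` directly (not the TFP `samplers` API): when one branch consumes the whole key, splitting the key up front means the parent key gets consumed twice, which trips JAX's key-reuse checking; splitting lazily, only on the branch that needs two subkeys, avoids this.

```python
import jax


def loss_sketch(key, use_importance_path):
    """Hypothetical stand-in for the seed handling in the patched function."""
    if use_importance_path:
        # This early branch consumes the whole key itself. Splitting it
        # before this point would mean the parent key is used twice
        # (once by split, once here), which the reuse checker flags.
        return jax.random.normal(key)
    # Split only once we know both subkeys will actually be consumed.
    sample_key, target_key = jax.random.split(key, 2)
    return jax.random.normal(sample_key) + jax.random.normal(target_key)


# Each top-level key is consumed exactly once on either path.
loss_sketch(jax.random.key(0), True)
loss_sketch(jax.random.key(1), False)
```

Behaviorally nothing changes for callers: the same draws happen on each path, only the point at which the key is split moves.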

PiperOrigin-RevId: 613666004
vanderplas authored and tensorflower-gardener committed Mar 7, 2024
1 parent 6098700 commit 898cfe9
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion tensorflow_probability/python/vi/csiszar_divergence.py
@@ -1074,7 +1074,6 @@ def monte_carlo_variational_loss(
     raise TypeError('`target_log_prob_fn` must be a Python `callable`'
                     'function.')
 
-  sample_seed, target_seed = samplers.split_seed(seed, 2)
   reparameterization_types = tf.nest.flatten(
       surrogate_posterior.reparameterization_type)
   if gradient_estimator is None:
@@ -1089,6 +1088,8 @@ def monte_carlo_variational_loss(
         num_draws=importance_sample_size,
         num_batch_draws=sample_size,
         seed=seed)
 
+  sample_seed, target_seed = samplers.split_seed(seed, 2)
   if gradient_estimator == GradientEstimators.SCORE_FUNCTION:
     if tf.get_static_value(importance_sample_size) != 1:
       # TODO(b/213378570): Support score function gradients for
