diff --git a/tensorflow_probability/python/experimental/mcmc/gradient_based_trajectory_length_adaptation.py b/tensorflow_probability/python/experimental/mcmc/gradient_based_trajectory_length_adaptation.py index 6f65e2ae5a..62bb1390a0 100644 --- a/tensorflow_probability/python/experimental/mcmc/gradient_based_trajectory_length_adaptation.py +++ b/tensorflow_probability/python/experimental/mcmc/gradient_based_trajectory_length_adaptation.py @@ -108,6 +108,39 @@ def _reduce_with_axes(index_op, name_op, x, axis_idx=None, axis_names=None): distribute_lib.pmean) +def _estimate_empirical_mean(x, accept_prob, safe, reduce_chain_axis_names): + """Estimates the empirical mean of x.""" + batch_ndims = ps.rank(accept_prob) + batch_axes = ps.range(batch_ndims, dtype=tf.int32) + + if safe: + # Note that we don't do a monte carlo average of the accepted chain + # position, but rather try to get an estimate of the underlying dynamics. + # This is done by only looking at proposed states where the integration + # error is low. + # TODO(mhoffman): Needs more experimentation. + expanded_accept_prob = bu.left_justified_expand_dims_like(accept_prob, x) + + # accept_prob is zero when x is NaN, but we still want to sanitize such + # values. + x_safe = tf.where(tf.math.is_finite(x), x, tf.zeros_like(x)) + # If all accept_prob's are zero, the x_center will have a nonsense value, + # but we'll discard the resultant gradients later on, so it's fine. + x_mean = _reduce_sum_with_axes( + expanded_accept_prob * x_safe, batch_axes, reduce_chain_axis_names + ) / ( + _reduce_sum_with_axes( + expanded_accept_prob, batch_axes, reduce_chain_axis_names + ) + + 1e-20 + ) + else: + x_mean = _reduce_mean_with_axes(x, batch_axes, reduce_chain_axis_names) + # The empirical mean here is a stand-in for the true mean, so we drop the + # gradient that flows through this term. 
+ return tf.stop_gradient(x_mean) + + def hmc_like_num_leapfrog_steps_getter_fn(kernel_results): """Getter for `num_leapfrog_steps` so it can be inspected.""" return unnest.get_innermost(kernel_results, 'num_leapfrog_steps') @@ -120,24 +153,32 @@ def hmc_like_num_leapfrog_steps_setter_fn(kernel_results, kernel_results, num_leapfrog_steps=new_num_leapfrog_steps) -def hmc_like_proposed_velocity_getter_fn(kernel_results): - """Getter for `proposed_velocity` so it can be inspected.""" - final_momentum = unnest.get_innermost(kernel_results, 'final_momentum') +def _hmc_like_velocity_getter_fn(kernel_results, momentum_name): + """Getter for a velocity so it can be inspected.""" + momentum = unnest.get_innermost(kernel_results, momentum_name) proposed_state = unnest.get_innermost(kernel_results, 'proposed_state') momentum_distribution = unnest.get_innermost( kernel_results, 'momentum_distribution', default=None) if momentum_distribution is None: - proposed_velocity = final_momentum + velocity = momentum else: momentum_log_prob = getattr(momentum_distribution, '_log_prob_unnormalized', momentum_distribution.log_prob) kinetic_energy_fn = lambda *args: -momentum_log_prob(*args) - _, proposed_velocity = mcmc_util.maybe_call_fn_and_grads( - kinetic_energy_fn, final_momentum) + _, velocity = mcmc_util.maybe_call_fn_and_grads( + kinetic_energy_fn, momentum) # proposed_velocity has the wrong structure when state is a scalar. 
return tf.nest.pack_sequence_as(proposed_state, - tf.nest.flatten(proposed_velocity)) + tf.nest.flatten(velocity)) + + +hmc_like_proposed_velocity_getter_fn = functools.partial( + _hmc_like_velocity_getter_fn, momentum_name='final_momentum' +) +hmc_like_initial_velocity_getter_fn = functools.partial( + _hmc_like_velocity_getter_fn, momentum_name='initial_momentum' +) def hmc_like_proposed_state_getter_fn(kernel_results): @@ -161,6 +202,7 @@ def chees_criterion(previous_state, proposed_state, accept_prob, trajectory_length, + forward=True, validate_args=False, experimental_shard_axis_names=None, experimental_reduce_chain_axis_names=None): @@ -196,6 +238,8 @@ def chees_criterion(previous_state, accept_prob: Floating `Tensor`. Probability of acceping the proposed state. trajectory_length: Floating `Tensor`. Mean trajectory length (not used in this criterion). + forward: Whether accept_prob refers to the proposed_state (True) or the + previous_state (False). validate_args: Whether to perform non-static argument validation. experimental_shard_axis_names: A structure of string names indicating how members of the state are sharded. @@ -218,7 +262,6 @@ def chees_criterion(previous_state, """ del trajectory_length batch_ndims = ps.rank(accept_prob) - batch_axes = ps.range(batch_ndims, dtype=tf.int32) reduce_chain_axis_names = distribute_lib.canonicalize_named_axis( experimental_reduce_chain_axis_names) @@ -230,33 +273,20 @@ def chees_criterion(previous_state, ) def _center_previous_state(x): - # The empirical mean here is a stand-in for the true mean, so we drop the - # gradient that flows through this term. 
- x_mean = _reduce_mean_with_axes(x, batch_axes, reduce_chain_axis_names) - return x - tf.stop_gradient(x_mean) + return x - _estimate_empirical_mean( + x, + accept_prob=accept_prob, + safe=not forward, + reduce_chain_axis_names=reduce_chain_axis_names, + ) def _center_proposed_state(x): - # Note that we don't do a monte carlo average of the accepted chain - # position, but rather try to get an estimate of the underlying dynamics. - # This is done by only looking at proposed states where the integration - # error is low. - # TODO(mhoffman): Needs more experimentation. - expanded_accept_prob = bu.left_justified_expand_dims_like(accept_prob, x) - - # accept_prob is zero when x is NaN, but we still want to sanitize such - # values. - x_safe = tf.where(tf.math.is_finite(x), x, tf.zeros_like(x)) - # If all accept_prob's are zero, the x_center will have a nonsense value, - # but we'll discard the resultant gradients later on, so it's fine. - x_center = ( - _reduce_sum_with_axes(expanded_accept_prob * x_safe, batch_axes, - reduce_chain_axis_names) / - (_reduce_sum_with_axes(expanded_accept_prob, batch_axes, - reduce_chain_axis_names) + 1e-20)) - - # The empirical mean here is a stand-in for the true mean, so we drop the - # gradient that flows through this term. - return x - tf.stop_gradient(x_center) + return x - _estimate_empirical_mean( + x, + accept_prob=accept_prob, + safe=forward, + reduce_chain_axis_names=reduce_chain_axis_names, + ) def _sum_event_part(x, shard_axes=None): event_axes = ps.range(batch_ndims, ps.rank(x)) @@ -287,6 +317,7 @@ def chees_rate_criterion(previous_state, proposed_state, accept_prob, trajectory_length, + forward=True, validate_args=False, experimental_shard_axis_names=None, experimental_reduce_chain_axis_names=None): @@ -306,6 +337,8 @@ def chees_rate_criterion(previous_state, state of the HMC chain. accept_prob: Floating `Tensor`. Probability of acceping the proposed state. trajectory_length: Floating `Tensor`. Trajectory length. 
+ forward: Whether accept_prob refers to the proposed_state (True) or the + previous_state (False). validate_args: Whether to perform non-static argument validation. experimental_shard_axis_names: A structure of string names indicating how members of the state are sharded. @@ -321,6 +354,7 @@ def chees_rate_criterion(previous_state, proposed_state=proposed_state, accept_prob=accept_prob, trajectory_length=trajectory_length, + forward=forward, validate_args=validate_args, experimental_shard_axis_names=experimental_shard_axis_names, experimental_reduce_chain_axis_names=experimental_reduce_chain_axis_names, @@ -334,6 +368,7 @@ def snaper_criterion(previous_state, direction, state_mean=None, state_mean_weight=0., + forward=True, validate_args=False, experimental_shard_axis_names=None, experimental_reduce_chain_axis_names=None): @@ -383,6 +418,8 @@ def snaper_criterion(previous_state, state_mean: Optional (Possibly nested) floating point `Tensor`. The estimated state mean. state_mean_weight: Floating point `Tensor`. The weight of the `state_mean`. + forward: Whether accept_prob refers to the proposed_state (True) or the + previous_state (False). validate_args: Whether to perform non-static argument validation. experimental_shard_axis_names: A structure of string names indicating how members of the state are sharded. 
@@ -400,7 +437,6 @@ def snaper_criterion(previous_state, """ batch_ndims = ps.rank(accept_prob) - batch_axes = ps.range(batch_ndims, dtype=tf.int32) reduce_chain_axis_names = distribute_lib.canonicalize_named_axis( experimental_reduce_chain_axis_names) @@ -411,7 +447,10 @@ def snaper_criterion(previous_state, accept_prob, reduce_chain_axis_names=reduce_chain_axis_names, validate_args=validate_args, - message='snaper_criterion requires at least 2 chains when `state_mean` is `None`' + message=( + 'snaper_criterion requires at least 2 chains when `state_mean` is' + ' `None`' + ), ) def _mix_in_state_mean(empirical_mean, state_mean): @@ -422,33 +461,22 @@ def _mix_in_state_mean(empirical_mean, state_mean): state_mean_weight * state_mean) def _center_previous_state(x, x_mean): - # The empirical mean here is a stand-in for the true mean, so we drop the - # gradient that flows through this term. - emp_x_mean = tf.stop_gradient( - distribute_lib.reduce_mean(x, batch_axes, reduce_chain_axis_names)) + emp_x_mean = _estimate_empirical_mean( + x, + accept_prob=accept_prob, + safe=not forward, + reduce_chain_axis_names=reduce_chain_axis_names, + ) x_mean = _mix_in_state_mean(emp_x_mean, x_mean) return x - x_mean def _center_proposed_state(x, x_mean): - # Note that we don't do a monte carlo average of the accepted chain - # position, but rather try to get an estimate of the underlying dynamics. - # This is done by only looking at proposed states where the integration - # error is low. - expanded_accept_prob = bu.left_justified_expand_dims_like(accept_prob, x) - - # accept_prob is zero when x is NaN, but we still want to sanitize such - # values. - x_safe = tf.where(tf.math.is_finite(x), x, tf.zeros_like(x)) - # The empirical mean here is a stand-in for the true mean, so we drop the - # gradient that flows through this term. - # If all accept_prob's are zero, the x_center will have a nonsense value, - # but we'll discard the resultant gradients later on, so it's fine. 
- emp_x_mean = tf.stop_gradient( - distribute_lib.reduce_sum(expanded_accept_prob * x_safe, batch_axes, - reduce_chain_axis_names) / - (distribute_lib.reduce_sum(expanded_accept_prob, batch_axes, - reduce_chain_axis_names) + 1e-20)) - + emp_x_mean = _estimate_empirical_mean( + x, + accept_prob=accept_prob, + safe=forward, + reduce_chain_axis_names=reduce_chain_axis_names, + ) x_mean = _mix_in_state_mean(emp_x_mean, x_mean) return x - x_mean @@ -505,6 +533,13 @@ class GradientBasedTrajectoryLengthAdaptation(kernel_base.TransitionKernel): value during development in order to inspect the behavior of the chain during adaptation. + Optionally, it is possible to use the improved gradient estimator from [3] by + setting `use_reverse_estimator` to `True`. This estimator relies on the + reversibility of the HMC proposal to reduce variance and thus improve the + adaptation speed and reliability. If this is set to `True`, `criterion_fn` + needs to also take the `forward` argument to distinguish the implied + integration direction. + #### Examples This implements something similar to ChEES HMC from [2]. @@ -535,7 +570,9 @@ class GradientBasedTrajectoryLengthAdaptation(kernel_base.TransitionKernel): kernel, num_adaptation_steps=num_adaptation_steps) kernel = tfp.mcmc.DualAveragingStepSizeAdaptation( - kernel, num_adaptation_steps=num_adaptation_steps) + kernel, + num_adaptation_steps=num_adaptation_steps, + reduce_fn=tfp.math.reduce_log_harmonic_mean_exp) kernel = tfp.mcmc.TransformedTransitionKernel( kernel, [tfb.Identity(), @@ -560,7 +597,8 @@ def trace_fn(_, pkr): kernel=kernel, trace_fn=trace_fn,)) - # ~0.75 + # ~0.95, because Exp bijector is really bad for HalfNormal. Use Softplus in + # practice. accept_prob = tf.math.exp(tfp.math.reduce_logmeanexp( tf.minimum(log_accept_ratio, 0.))) ``` @@ -574,6 +612,10 @@ def trace_fn(_, pkr): for Setting Trajectory Lengths in Hamiltonian Monte Carlo. + [3]: Riou-Durand, L., Sountsov, P., Vogrinc, J., Margossian, C., Power, S.
+ (2023) Adaptive Tuning for Metropolis Adjusted Langevin Trajectories. + + """ def __init__( @@ -589,9 +631,11 @@ def __init__( num_leapfrog_steps_getter_fn=hmc_like_num_leapfrog_steps_getter_fn, num_leapfrog_steps_setter_fn=hmc_like_num_leapfrog_steps_setter_fn, step_size_getter_fn=hmc_like_step_size_getter_fn, + initial_velocity_getter_fn=hmc_like_initial_velocity_getter_fn, proposed_velocity_getter_fn=hmc_like_proposed_velocity_getter_fn, log_accept_prob_getter_fn=hmc_like_log_accept_prob_getter_fn, proposed_state_getter_fn=hmc_like_proposed_state_getter_fn, + use_reverse_estimator=False, validate_args=False, experimental_shard_axis_names=None, experimental_reduce_chain_axis_names=None, @@ -636,11 +680,16 @@ def __init__( step_size_getter_fn: A callable with the signature `(kernel_results) -> step_size` where `kernel_results` are the results of the `inner_kernel`, and `step_size` is a floating point `Tensor`. + initial_velocity_getter_fn: A callable with the signature + `(kernel_results) -> initial_velocity` where `kernel_results` are the + results of the `inner_kernel`, and `initial_velocity` is a (possibly + nested) floating point `Tensor`. Velocity is the derivative of state + with respect to trajectory length. proposed_velocity_getter_fn: A callable with the signature `(kernel_results) -> proposed_velocity` where `kernel_results` are the results of the `inner_kernel`, and `proposed_velocity` is a (possibly - nested) floating point `Tensor`. Velocity is derivative of state with - respect to trajectory length. + nested) floating point `Tensor`. Velocity is the derivative of state + with respect to trajectory length. log_accept_prob_getter_fn: A callable with the signature `(kernel_results) -> log_accept_prob` where `kernel_results` are the results of the `inner_kernel`, and `log_accept_prob` is a floating point `Tensor`. 
@@ -649,6 +698,9 @@ def __init__( -> proposed_state` where `kernel_results` are the results of the `inner_kernel`, and `proposed_state` is a (possibly nested) floating point `Tensor`. + use_reverse_estimator: Whether to use an improved estimator to compute + trajectory length gradients. If `True`, `criterion_fn` needs to take a + `forward` kwarg. validate_args: Python `bool`. When `True` kernel parameters are checked for validity. When `False` invalid inputs may silently render incorrect outputs. @@ -690,9 +742,11 @@ class docstring). num_leapfrog_steps_getter_fn=num_leapfrog_steps_getter_fn, num_leapfrog_steps_setter_fn=num_leapfrog_steps_setter_fn, step_size_getter_fn=step_size_getter_fn, + initial_velocity_getter_fn=initial_velocity_getter_fn, proposed_velocity_getter_fn=proposed_velocity_getter_fn, log_accept_prob_getter_fn=log_accept_prob_getter_fn, proposed_state_getter_fn=hmc_like_proposed_state_getter_fn, + use_reverse_estimator=use_reverse_estimator, validate_args=validate_args, experimental_shard_axis_names=experimental_shard_axis_names, experimental_reduce_chain_axis_names=experimental_reduce_chain_axis_names, @@ -712,7 +766,7 @@ def num_adaptation_steps(self): return self._parameters['num_adaptation_steps'] def criterion_fn(self, previous_state, proposed_state, accept_prob, - trajectory_length): + trajectory_length, forward=True): kwargs = {} if self.experimental_reduce_chain_axis_names is not None: kwargs['experimental_reduce_chain_axis_names'] = ( @@ -720,6 +774,8 @@ def criterion_fn(self, previous_state, proposed_state, accept_prob, if self.experimental_shard_axis_names is not None: kwargs['experimental_shard_axis_names'] = ( self.experimental_shard_axis_names) + if self.use_reverse_estimator: + kwargs['forward'] = forward return self._parameters['criterion_fn'](previous_state, proposed_state, accept_prob, trajectory_length, **kwargs) @@ -743,6 +799,9 @@ def num_leapfrog_steps_setter_fn(self, kernel_results, def step_size_getter_fn(self, 
kernel_results): return self._parameters['step_size_getter_fn'](kernel_results) + def initial_velocity_getter_fn(self, kernel_results): + return self._parameters['initial_velocity_getter_fn'](kernel_results) + def proposed_velocity_getter_fn(self, kernel_results): return self._parameters['proposed_velocity_getter_fn'](kernel_results) @@ -752,6 +811,10 @@ def log_accept_prob_getter_fn(self, kernel_results): def proposed_state_getter_fn(self, kernel_results): return self._parameters['proposed_state_getter_fn'](kernel_results) + @property + def use_reverse_estimator(self): + return self._parameters['use_reverse_estimator'] + @property def validate_args(self): return self._parameters['validate_args'] @@ -806,6 +869,7 @@ def one_step(self, current_state, previous_kernel_results, seed=None): current_state, previous_kernel_results_with_jitter.inner_results, inner_seed) + initial_velocity = self.initial_velocity_getter_fn(new_inner_results) proposed_state = self.proposed_state_getter_fn(new_inner_results) proposed_velocity = self.proposed_velocity_getter_fn(new_inner_results) accept_prob = tf.exp(self.log_accept_prob_getter_fn(new_inner_results)) @@ -815,13 +879,15 @@ def one_step(self, current_state, previous_kernel_results, seed=None): previous_state=current_state, proposed_state=proposed_state, proposed_velocity=proposed_velocity, + initial_velocity=initial_velocity, trajectory_jitter=trajectory_jitter, accept_prob=accept_prob, step_size=step_size, criterion_fn=self.criterion_fn, max_leapfrog_steps=self.max_leapfrog_steps, experimental_shard_axis_names=self.experimental_shard_axis_names, - reduce_chain_axis_names=self.experimental_reduce_chain_axis_names) + reduce_chain_axis_names=self.experimental_reduce_chain_axis_names, + use_reverse_estimator=self.use_reverse_estimator) # Undo the effect of adaptation if we're not in the burnin phase. We keep # the criterion, however, as that's a diagnostic. 
We also keep the @@ -930,6 +996,7 @@ def _halton_sequence(i, max_bits=MAX_HALTON_SEQUENCE_BITS): def _update_trajectory_grad(previous_kernel_results, previous_state, + initial_velocity, proposed_state, proposed_velocity, trajectory_jitter, @@ -937,35 +1004,54 @@ def _update_trajectory_grad(previous_kernel_results, step_size, criterion_fn, max_leapfrog_steps, + use_reverse_estimator, experimental_shard_axis_names=None, reduce_chain_axis_names=None): """Updates the trajectory length.""" # Compute criterion grads. def leapfrog_action(dt): - # This represents the effect on the criterion value as the state follows the - # proposed velocity. This implicitly assumes an identity mass matrix. + fwd_start_end_vel = [ + (True, previous_state, proposed_state, proposed_velocity) + ] + if use_reverse_estimator: + fwd_start_end_vel.append(( + False, + proposed_state, + previous_state, + tf.nest.map_structure(lambda x: -x, initial_velocity), + )) + + # This represents the effect on the criterion value as the state follows + # the proposed velocity. This implicitly assumes an identity mass matrix. def adjust_state(x, v, shard_axes=None): broadcasted_dt = distribute_lib.pbroadcast( bu.left_justified_expand_dims_like(dt, v), shard_axes) return x + broadcasted_dt * v - adjusted_state = _map_structure_up_to_with_axes( - proposed_state, - adjust_state, - proposed_state, - proposed_velocity, - experimental_shard_axis_names=experimental_shard_axis_names) - return criterion_fn( - previous_state=previous_state, - proposed_state=adjusted_state, - accept_prob=accept_prob, - # We add the step size here because we effectively do `floor(traj + - # step_size) / step_size` when computing the number of leapfrog steps. 
- trajectory_length=( - trajectory_jitter * previous_kernel_results.max_trajectory_length + - step_size + dt), - ) + criterion_vals = [] + for forward, start, end, vel in fwd_start_end_vel: + adjusted_end = _map_structure_up_to_with_axes( + end, + adjust_state, + end, + vel, + experimental_shard_axis_names=experimental_shard_axis_names) + criterion_val = criterion_fn( + previous_state=start, + proposed_state=adjusted_end, + accept_prob=accept_prob, + # We add the step size here because we effectively do `floor(traj + + # step_size) / step_size` when computing the number of leapfrog steps. + trajectory_length=( + trajectory_jitter * previous_kernel_results.max_trajectory_length + + step_size + + dt + ), + forward=forward, + ) + criterion_vals.append(criterion_val) + return tf.reduce_mean(criterion_vals, axis=0) criterion, trajectory_grad = gradient.value_and_gradient( leapfrog_action, tf.zeros_like(accept_prob)) @@ -999,8 +1085,9 @@ def adjust_state(x, v, shard_axes=None): # Apply the gradient. Clip absolute value to ~log(2)/2. log_update = tf.clip_by_value(trajectory_step_size * trajectory_grad, -0.35, 0.35) - new_max_trajectory_length = previous_kernel_results.max_trajectory_length * tf.exp( - log_update) + new_max_trajectory_length = ( + previous_kernel_results.max_trajectory_length * tf.exp(log_update) + ) # Iterate averaging. 
average_weight = iteration_f**(-0.5) diff --git a/tensorflow_probability/python/experimental/mcmc/gradient_based_trajectory_length_adaptation_test.py b/tensorflow_probability/python/experimental/mcmc/gradient_based_trajectory_length_adaptation_test.py index 49054d6bf0..768c516084 100644 --- a/tensorflow_probability/python/experimental/mcmc/gradient_based_trajectory_length_adaptation_test.py +++ b/tensorflow_probability/python/experimental/mcmc/gradient_based_trajectory_length_adaptation_test.py @@ -159,9 +159,13 @@ def target_log_prob_fn(*x): num_leapfrog_steps=1, ) kernel = gbtla.GradientBasedTrajectoryLengthAdaptation( - kernel, num_adaptation_steps=num_adaptation_steps, validate_args=True) + kernel, num_adaptation_steps=num_adaptation_steps, validate_args=True, + use_reverse_estimator=True) kernel = dassa.DualAveragingStepSizeAdaptation( - kernel, num_adaptation_steps=num_adaptation_steps) + kernel, + num_adaptation_steps=num_adaptation_steps, + reduce_fn=generic.reduce_log_harmonic_mean_exp, + ) kernel = transformed_kernel.TransformedTransitionKernel( kernel, [identity.Identity(), exp.Exp()]) @@ -192,9 +196,9 @@ def trace_fn(_, pkr): mean_step_size = tf.reduce_mean(step_size) mean_max_trajectory_length = tf.reduce_mean(max_trajectory_length) - self.assertAllClose(0.75, p_accept, atol=0.1) - self.assertAllClose(0.52, mean_step_size, atol=0.2) - self.assertAllClose(46., mean_max_trajectory_length, atol=15) + self.assertAllClose(0.95, p_accept, rtol=0.2) + self.assertAllClose(0.3, mean_step_size, rtol=0.2) + self.assertAllClose(43., mean_max_trajectory_length, rtol=0.2) self.assertAllClose( target.mean(), [tf.reduce_mean(x, axis=[0, 1]) for x in chain], atol=1.5) @@ -328,10 +332,12 @@ def target_log_prob_fn(x, y): final_kernel_results.max_trajectory_length), 0.0005) @parameterized.named_parameters( - ('ChEES', gbtla.chees_rate_criterion), - ('SNAPER', snaper_criterion_2d_direction), + ('ChEES', gbtla.chees_rate_criterion, False), + ('SNAPER', 
snaper_criterion_2d_direction, False), + ('ChEES_reverse', gbtla.chees_rate_criterion, True), + ('SNAPER_reverse', snaper_criterion_2d_direction, True), ) - def testAdaptation(self, criterion_fn): + def testAdaptation(self, criterion_fn, use_reverse_estimator): if tf.executing_eagerly() and not JAX_MODE: self.skipTest('Too slow for TF Eager.') @@ -353,6 +359,7 @@ def testAdaptation(self, criterion_fn): kernel, num_adaptation_steps=num_adaptation_steps, criterion_fn=criterion_fn, + use_reverse_estimator=use_reverse_estimator, validate_args=True) kernel = dassa.DualAveragingStepSizeAdaptation( kernel, num_adaptation_steps=num_adaptation_steps) diff --git a/tensorflow_probability/python/experimental/mcmc/snaper_hmc.py b/tensorflow_probability/python/experimental/mcmc/snaper_hmc.py index 063e43cc83..d08d3f9014 100644 --- a/tensorflow_probability/python/experimental/mcmc/snaper_hmc.py +++ b/tensorflow_probability/python/experimental/mcmc/snaper_hmc.py @@ -406,6 +406,7 @@ def _max_part(x, named_axis): gbtla_kwargs = ( self.gradient_based_trajectory_length_adaptation_kwargs.copy()) gbtla_kwargs.setdefault('averaged_sq_grad_adaptation_rate', 0.5) + gbtla_kwargs.setdefault('use_reverse_estimator', True) kernel = gbtla.GradientBasedTrajectoryLengthAdaptation( kernel, num_adaptation_steps=self.num_adaptation_steps,