In SNAPER, also rescale max_trajectory_length, along with the step size.
PiperOrigin-RevId: 607811880
SiegeLordEx authored and tensorflower-gardener committed Feb 16, 2024
1 parent 6467548 commit 045745e
Showing 2 changed files with 18 additions and 19 deletions.
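
For context, here is a minimal sketch (not part of the commit; all values are hypothetical) of what this change does. As the in-diff comment explains, SNAPER rescales the inner HMC kernel by max_ema_variance, and the default trace function divides reported quantities by sqrt(max_ema_variance) to undo that rescaling. Previously only step_size was un-rescaled; this commit applies the same factor to max_trajectory_length:

import numpy as np

# Hypothetical values for illustration only; not taken from the commit.
max_ema_variance = 4.0  # largest per-dimension EMA variance estimate
scale = 1.0 / np.sqrt(max_ema_variance)  # == 0.5

inner_step_size = 0.9              # step size in the rescaled space
inner_max_trajectory_length = 5.0  # trajectory length in the rescaled space

# The trace function now divides both quantities by sqrt(max_ema_variance),
# reporting them in the scale of the original target_log_prob_fn.
reported_step_size = inner_step_size * scale                          # 0.45
reported_max_trajectory_length = inner_max_trajectory_length * scale  # 2.5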
26 changes: 13 additions & 13 deletions tensorflow_probability/python/experimental/mcmc/snaper_hmc.py

@@ -1005,26 +1005,26 @@ def outer_trace_fn(all_state):
   return unconstrained_state, kernel_results, trace
 
 
-def default_snaper_trace_fn(state, is_burnin, kernel_results, reducer,
-                            reducer_state):
+def default_snaper_trace_fn(
+    state, is_burnin, kernel_results, reducer, reducer_state
+):
   """Default trace function for SNAPER."""
   del reducer, reducer_state
   kr = kernel_results
   energy_diff = unnest.get_innermost(kr, 'log_accept_ratio')
   # The ~ is here to catch NaNs.
-  has_divergence = ~(tf.math.abs(energy_diff) < 500.)
+  has_divergence = ~(tf.math.abs(energy_diff) < 500.0)
+  # SNAPER rescales the inner HMC kernel by max_ema_variance, so to aid
+  # comparisons with other algorithms which typically don't do this rescaling,
+  # we undo the rescaling here. This makes the step size consistent with the
+  # target_log_prob_fn scale implied by `variance_scaling` below.
+  scale = 1.0 / tf.sqrt(unnest.get_innermost(kr, 'max_ema_variance'))
   return state, {
-      # SNAPER rescales the inner HMC kernel by max_ema_variance, so to aid
-      # comparisons with other algorithms which typically don't do this
-      # rescaling, we undo the rescaling here. This makes the step size
-      # consistent with the target_log_prob_fn scale implied by
-      # `variance_scaling` below.
-      'step_size': unnest.get_innermost(kr, 'step_size') / tf.sqrt(
-          unnest.get_innermost(kr, 'max_ema_variance')
-      ),
+      'step_size': unnest.get_innermost(kr, 'step_size') * scale,
       'n_steps': unnest.get_innermost(kr, 'num_leapfrog_steps'),
       'tune': is_burnin,
-      'max_trajectory_length': unnest.get_innermost(
-          kr, 'max_trajectory_length'
+      'max_trajectory_length': (
+          unnest.get_innermost(kr, 'max_trajectory_length') * scale
       ),
       'variance_scaling': tf.nest.map_structure(
           lambda x: 1.0 / x, unnest.get_innermost(kr, 'ema_variance')
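
For reference, a hedged usage sketch of where default_snaper_trace_fn plugs in. This is not part of the commit: the model and argument values are invented, and it assumes the sample_snaper_hmc driver accepts a distribution as its model and takes this trace function via its trace_fn argument (as the "default" in its name suggests):

import tensorflow_probability as tfp

tfd = tfp.distributions

# Toy target; any distribution with a usable log_prob would do here.
model = tfd.MultivariateNormalDiag(loc=[0., 0.], scale_diag=[1., 2.])

results = tfp.experimental.mcmc.sample_snaper_hmc(
    model=model,
    num_results=500,
    trace_fn=tfp.experimental.mcmc.default_snaper_trace_fn,
)

# The returned results bundle includes the trace; with this commit, its
# 'max_trajectory_length' entry is reported in the same un-rescaled space as
# 'step_size', so the two can be compared directly.

The second changed file in this commit is the corresponding test module; its local trace_fn and expected adaptation values change to match the new rescaled reporting: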
@@ -106,12 +106,11 @@ def testEndToEndAdaptation(self):
     )
 
     def trace_fn(_, pkr):
+      scale = 1.0 / tf.sqrt(unnest.get_innermost(pkr, 'max_ema_variance'))
       return {
-          'step_size': unnest.get_innermost(pkr, 'step_size') / tf.sqrt(
-              unnest.get_innermost(pkr, 'max_ema_variance')
-          ),
+          'step_size': scale * unnest.get_innermost(pkr, 'step_size'),
           'mean_trajectory_length': (
-              unnest.get_innermost(pkr, 'max_trajectory_length') / 2.0
+              scale * unnest.get_innermost(pkr, 'max_trajectory_length') / 2.0
           ),
           'principal_component': unnest.get_innermost(
               pkr, 'ema_principal_component'
@@ -143,7 +142,7 @@ def trace_fn(_, pkr):
     # Adaptation results.
     # Obtained via a separate run of `windowed_adaptive_nuts`.
     self.assertAllClose(0.45, trace['step_size'][-1], rtol=0.25)
-    self.assertAllClose(4., trace['mean_trajectory_length'][-1], atol=1.)
+    self.assertAllClose(1.25, trace['mean_trajectory_length'][-1], rtol=0.3)
     self.assertAllClose(np.diag(covariance), trace['variance'][-1], rtol=0.2)
     self.assertAllClose(
         principal_component / np.sign(principal_component[0]),
@@ -312,7 +311,7 @@ def run(seed):
     self.assertEqual(self.dtype, chain.dtype)
     # Obtained via a separate run of `windowed_adaptive_nuts`.
     self.assertAllClose(0.45, trace['step_size'][-1], rtol=0.25)
-    self.assertAllClose(8., trace['max_trajectory_length'][-1], atol=2.)
+    self.assertAllClose(2.5, trace['max_trajectory_length'][-1], rtol=0.3)
    self.assertAllClose(chain.var((0, 1)), np.diag(covariance), rtol=0.2)
    self.assertAllClose(
        np.ones(num_dims, self.dtype), reduction_results, atol=0.1)
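
A quick consistency check on the updated expectations (illustrative arithmetic, not from the commit): both tests shrink their expected trajectory lengths by the same factor, which is exactly what a single division by sqrt(max_ema_variance) predicts:

# Ratios implied by the updated test expectations.
mean_traj_ratio = 1.25 / 4.0  # 0.3125, from the first test
max_traj_ratio = 2.5 / 8.0    # 0.3125, from the second test
assert mean_traj_ratio == max_traj_ratio

# The common ratio corresponds to sqrt(max_ema_variance) of roughly 3.2, i.e.
# a largest EMA variance of roughly 10, for these test targets. (The old
# expectations carried loose tolerances, so this is only approximate.)
implied_max_ema_variance = (1.0 / mean_traj_ratio) ** 2  # ~10.24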
