diff --git a/tensorflow_probability/python/experimental/mcmc/snaper_hmc.py b/tensorflow_probability/python/experimental/mcmc/snaper_hmc.py
index 49080e4f4a..063e43cc83 100644
--- a/tensorflow_probability/python/experimental/mcmc/snaper_hmc.py
+++ b/tensorflow_probability/python/experimental/mcmc/snaper_hmc.py
@@ -1005,26 +1005,26 @@ def outer_trace_fn(all_state):
   return unconstrained_state, kernel_results, trace
 
 
-def default_snaper_trace_fn(state, is_burnin, kernel_results, reducer,
-                            reducer_state):
+def default_snaper_trace_fn(
+    state, is_burnin, kernel_results, reducer, reducer_state
+):
+  """Default trace function for SNAPER."""
   del reducer, reducer_state
   kr = kernel_results
   energy_diff = unnest.get_innermost(kr, 'log_accept_ratio')
   # The ~ is here to catch NaNs.
-  has_divergence = ~(tf.math.abs(energy_diff) < 500.)
+  has_divergence = ~(tf.math.abs(energy_diff) < 500.0)
+  # SNAPER rescales the inner HMC kernel by max_ema_variance, so to aid
+  # comparisons with other algorithms which typically don't do this rescaling,
+  # we undo the rescaling here. This makes the step size consistent with the
+  # target_log_prob_fn scale implied by `variance_scaling` below.
+  scale = 1.0 / tf.sqrt(unnest.get_innermost(kr, 'max_ema_variance'))
   return state, {
-      # SNAPER rescales the inner HMC kernel by max_ema_variance, so to aid
-      # comparisons with other algorithms which typically don't do this
-      # rescaling, we undo the rescaling here. This makes the step size
-      # consistent with the target_log_prob_fn scale implied by
-      # `variance_scaling` below.
-      'step_size': unnest.get_innermost(kr, 'step_size') / tf.sqrt(
-          unnest.get_innermost(kr, 'max_ema_variance')
-      ),
+      'step_size': unnest.get_innermost(kr, 'step_size') * scale,
       'n_steps': unnest.get_innermost(kr, 'num_leapfrog_steps'),
       'tune': is_burnin,
-      'max_trajectory_length': unnest.get_innermost(
-          kr, 'max_trajectory_length'
+      'max_trajectory_length': (
+          unnest.get_innermost(kr, 'max_trajectory_length') * scale
       ),
       'variance_scaling': tf.nest.map_structure(
           lambda x: 1.0 / x, unnest.get_innermost(kr, 'ema_variance')
diff --git a/tensorflow_probability/python/experimental/mcmc/snaper_hmc_test.py b/tensorflow_probability/python/experimental/mcmc/snaper_hmc_test.py
index 1d367ebe37..58d7bc9fd6 100644
--- a/tensorflow_probability/python/experimental/mcmc/snaper_hmc_test.py
+++ b/tensorflow_probability/python/experimental/mcmc/snaper_hmc_test.py
@@ -106,12 +106,11 @@ def testEndToEndAdaptation(self):
     )
 
     def trace_fn(_, pkr):
+      scale = 1.0 / tf.sqrt(unnest.get_innermost(pkr, 'max_ema_variance'))
       return {
-          'step_size': unnest.get_innermost(pkr, 'step_size') / tf.sqrt(
-              unnest.get_innermost(pkr, 'max_ema_variance')
-          ),
+          'step_size': scale * unnest.get_innermost(pkr, 'step_size'),
           'mean_trajectory_length': (
-              unnest.get_innermost(pkr, 'max_trajectory_length') / 2.0
+              scale * unnest.get_innermost(pkr, 'max_trajectory_length') / 2.0
           ),
           'principal_component': unnest.get_innermost(
               pkr, 'ema_principal_component'
@@ -143,7 +142,7 @@ def trace_fn(_, pkr):
     # Adaptation results.
     # Obtained via a separate run of `windowed_adaptive_nuts`.
     self.assertAllClose(0.45, trace['step_size'][-1], rtol=0.25)
-    self.assertAllClose(4., trace['mean_trajectory_length'][-1], atol=1.)
+    self.assertAllClose(1.25, trace['mean_trajectory_length'][-1], rtol=0.3)
     self.assertAllClose(np.diag(covariance), trace['variance'][-1], rtol=0.2)
     self.assertAllClose(
         principal_component / np.sign(principal_component[0]),
@@ -312,7 +311,7 @@ def run(seed):
     self.assertEqual(self.dtype, chain.dtype)
     # Obtained via a separate run of `windowed_adaptive_nuts`.
     self.assertAllClose(0.45, trace['step_size'][-1], rtol=0.25)
-    self.assertAllClose(8., trace['max_trajectory_length'][-1], atol=2.)
+    self.assertAllClose(2.5, trace['max_trajectory_length'][-1], rtol=0.3)
     self.assertAllClose(chain.var((0, 1)), np.diag(covariance), rtol=0.2)
     self.assertAllClose(
         np.ones(num_dims, self.dtype), reduction_results, atol=0.1)
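Note (not part of the patch): a minimal standalone sketch of the rescaling this change traces. SNAPER-HMC runs its inner HMC kernel on coordinates rescaled by sqrt(max_ema_variance), so a step size or trajectory length reported in that rescaled space must be divided by sqrt(max_ema_variance) to express it on the original target_log_prob_fn scale; both traced quantities need the same shared `scale` factor, which is why the patch factors it out. The numeric values below are hypothetical, chosen only to illustrate the arithmetic.

import tensorflow as tf

# Hypothetical adapted quantities, standing in for the values that
# `unnest.get_innermost` pulls out of the kernel results in the patch.
max_ema_variance = tf.constant(4.0)
step_size = tf.constant(0.9)              # step size in the rescaled space
max_trajectory_length = tf.constant(5.0)  # trajectory length in the rescaled space

# Shared factor, as in the patch: undo the sqrt(max_ema_variance) rescaling.
scale = 1.0 / tf.sqrt(max_ema_variance)   # 0.5 for these inputs

print(float(step_size * scale))               # 0.45 on the original scale
print(float(max_trajectory_length * scale))   # 2.5 on the original scale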