Fix SNAPER step size reporting.
Internally, SNAPER absorbs part of the diagonal preconditioner into the step
size rather than placing it entirely in the mass matrix, which makes the final
step sizes hard to compare with those obtained via, e.g., NUTS. This change
undoes that scaling when constructing the default trace.

Also fix GradientBasedTrajectoryLengthAdaptation not storing the seed in its
kernel results.

PiperOrigin-RevId: 607064555
SiegeLordEx authored and tensorflower-gardener committed Feb 14, 2024
1 parent f6211b0 commit c9a22d6
Showing 3 changed files with 54 additions and 38 deletions.
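
The reporting fix itself is a single division: the step size stored in the kernel results is divided by the square root of the largest exponential-moving-average variance before it is traced. Below is a minimal sketch of that arithmetic in plain NumPy; `raw_step_size` and `ema_variance` are hypothetical stand-ins for the values held in the kernel results, and the numbers are illustrative only.

import numpy as np

def reported_step_size(raw_step_size, ema_variance):
  """Undoes the preconditioner scale SNAPER folds into the step size."""
  # SNAPER rescales the inner HMC kernel by the largest EMA variance, so the
  # NUTS-comparable step size is the raw one divided by its square root.
  return raw_step_size / np.sqrt(np.max(ema_variance))

# Illustrative numbers only.
print(reported_step_size(1.4, np.array([0.3, 1.0, 9.7])))  # 1.4 / sqrt(9.7) ≈ 0.45
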
@@ -834,7 +834,8 @@ def one_step(self, current_state, previous_kernel_results, seed=None):
new_kernel_results = new_kernel_results._replace(
inner_results=new_inner_results,
step=previous_kernel_results.step + 1,
criterion=criterion)
criterion=criterion,
seed=seed)

return new_state, new_kernel_results

48 changes: 28 additions & 20 deletions tensorflow_probability/python/experimental/mcmc/snaper_hmc.py
@@ -67,6 +67,7 @@ class SNAPERHamiltonianMonteCarloResults(
'inner_results',
'ema_mean',
'ema_variance',
'max_ema_variance',
'state_ema_points',
'ema_principal_component',
'principal_component_ema_points',
@@ -80,6 +81,7 @@
`GradientBasedTrajectoryLengthAdaptationResults`.
ema_mean: Exponential moving average cross-chain state mean.
ema_variance: Exponential moving average cross-chain state variance.
max_ema_variance: Maximum of `ema_variance`.
state_ema_points: Approximate number of points used to compute the
exponential moving averages.
ema_principal_component: Exponential moving average cross-chain state
@@ -422,7 +424,7 @@ def _max_part(x, named_axis):
validate_args=self.validate_args,
**gbtla_kwargs,
)
return kernel
return kernel, max_variance

def _update_state_ema(
self,
@@ -539,7 +541,7 @@ def one_step(self, current_state, previous_kernel_results, seed=None):
step = inner_results.step
state_ema_points = previous_kernel_results.state_ema_points

kernel = self._make_kernel(
kernel, max_variance = self._make_kernel(
batch_shape=batch_shape,
step=step,
state_ema_points=state_ema_points,
@@ -588,6 +590,7 @@ def one_step(self, current_state, previous_kernel_results, seed=None):
inner_results=inner_results,
ema_mean=ema_mean,
ema_variance=ema_variance,
max_ema_variance=max_variance,
state_ema_points=state_ema_points,
ema_principal_component=ema_principal_component,
principal_component_ema_points=principal_component_ema_points,
@@ -659,7 +662,7 @@ def bootstrap_results(self, init_state):
state_ema_points = tf.ones([], tf.int32)
principal_component_ema_points = tf.ones([], tf.int32)

kernel = self._make_kernel(
kernel, max_variance = self._make_kernel(
batch_shape=batch_shape,
step=tf.zeros([], tf.int32),
state_ema_points=state_ema_points,
@@ -675,6 +678,7 @@ def bootstrap_results(self, init_state):
inner_results=inner_results,
ema_mean=ema_mean,
ema_variance=ema_variance,
max_ema_variance=max_variance,
state_ema_points=state_ema_points,
ema_principal_component=ema_principal_component,
principal_component_ema_points=principal_component_ema_points,
@@ -1009,23 +1013,27 @@ def default_snaper_trace_fn(state, is_burnin, kernel_results, reducer,
# The ~ is here to catch NaNs.
has_divergence = ~(tf.math.abs(energy_diff) < 500.)
return state, {
'step_size':
unnest.get_innermost(kr, 'step_size'),
'n_steps':
unnest.get_innermost(kr, 'num_leapfrog_steps'),
'tune':
is_burnin,
'max_trajectory_length':
unnest.get_innermost(kr, 'max_trajectory_length'),
'variance_scaling':
tf.nest.map_structure(lambda x: 1. / x,
unnest.get_innermost(kr, 'ema_variance')),
'diverging':
has_divergence,
'accept_ratio':
tf.minimum(tf.ones_like(energy_diff), tf.exp(energy_diff)),
'is_accepted':
unnest.get_innermost(kr, 'is_accepted'),
# SNAPER rescales the inner HMC kernel by max_ema_variance, so to aid
# comparisons with other algorithms which typically don't do this
# rescaling, we undo the rescaling here. This makes the step size
# consistent with the target_log_prob_fn scale implied by
# `variance_scaling` below.
'step_size': unnest.get_innermost(kr, 'step_size') / tf.sqrt(
unnest.get_innermost(kr, 'max_ema_variance')
),
'n_steps': unnest.get_innermost(kr, 'num_leapfrog_steps'),
'tune': is_burnin,
'max_trajectory_length': unnest.get_innermost(
kr, 'max_trajectory_length'
),
'variance_scaling': tf.nest.map_structure(
lambda x: 1.0 / x, unnest.get_innermost(kr, 'ema_variance')
),
'diverging': has_divergence,
'accept_ratio': tf.minimum(
tf.ones_like(energy_diff), tf.exp(energy_diff)
),
'is_accepted': unnest.get_innermost(kr, 'is_accepted'),
}
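
For reference, here is a minimal sketch (not part of this commit) of applying the same correction in a user-supplied trace function; it assumes the kernel results expose `step_size` and the new `max_ema_variance` field, and uses TFP's internal `unnest` helper the same way the tests below do.

import tensorflow as tf
from tensorflow_probability.python.internal import unnest

def trace_fn(_, pkr):
  raw_step_size = unnest.get_innermost(pkr, 'step_size')
  max_ema_variance = unnest.get_innermost(pkr, 'max_ema_variance')
  # Divide out the preconditioner scale folded into the inner HMC step size.
  return {'step_size': raw_step_size / tf.sqrt(max_ema_variance)}
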


41 changes: 24 additions & 17 deletions tensorflow_probability/python/experimental/mcmc/snaper_hmc_test.py
@@ -81,7 +81,7 @@ def testEndToEndAdaptation(self):
num_mala_steps = 100

eigenvalues = np.exp(np.linspace(0., 3., num_dims))
q, r = np.linalg.qr(np.random.randn(num_dims, num_dims))
q, r = np.linalg.qr(np.random.RandomState(0).randn(num_dims, num_dims))
q *= np.sign(np.diag(r))
covariance = (q * eigenvalues).dot(q.T).astype(self.dtype)

@@ -100,20 +100,24 @@
num_mala_steps=num_mala_steps,
)
kernel = dassa.DualAveragingStepSizeAdaptation(
kernel, num_adaptation_steps=num_adaptation_steps)
kernel,
num_adaptation_steps=num_adaptation_steps,
target_accept_prob=0.8,
)

def trace_fn(_, pkr):
return {
'step_size':
unnest.get_innermost(pkr, 'step_size'),
'mean_trajectory_length':
unnest.get_innermost(pkr, 'max_trajectory_length') / 2.,
'principal_component':
unnest.get_innermost(pkr, 'ema_principal_component'),
'variance':
unnest.get_innermost(pkr, 'ema_variance'),
'num_leapfrog_steps':
unnest.get_innermost(pkr, 'num_leapfrog_steps'),
'step_size': unnest.get_innermost(pkr, 'step_size') / tf.sqrt(
unnest.get_innermost(pkr, 'max_ema_variance')
),
'mean_trajectory_length': (
unnest.get_innermost(pkr, 'max_trajectory_length') / 2.0
),
'principal_component': unnest.get_innermost(
pkr, 'ema_principal_component'
),
'variance': unnest.get_innermost(pkr, 'ema_variance'),
'num_leapfrog_steps': unnest.get_innermost(pkr, 'num_leapfrog_steps'),
}

init_x = tf.zeros([num_chains, num_dims], self.dtype)
@@ -137,7 +141,8 @@ def trace_fn(_, pkr):
self.assertEqual(self.dtype, trace['principal_component'].dtype)

# Adaptation results.
self.assertAllClose(1.75, trace['step_size'][-1], rtol=0.2)
# Obtained via a separate run of `windowed_adaptive_nuts`.
self.assertAllClose(0.45, trace['step_size'][-1], rtol=0.25)
self.assertAllClose(4., trace['mean_trajectory_length'][-1], atol=1.)
self.assertAllClose(np.diag(covariance), trace['variance'][-1], rtol=0.2)
self.assertAllClose(
@@ -280,7 +285,7 @@ def testEndToEnd(self):
num_dims = 8

eigenvalues = np.exp(np.linspace(0., 3., num_dims))
q, r = np.linalg.qr(np.random.randn(num_dims, num_dims))
q, r = np.linalg.qr(np.random.RandomState(0).randn(num_dims, num_dims))
q *= np.sign(np.diag(r))
covariance = (q * eigenvalues).dot(q.T).astype(self.dtype)

@@ -305,7 +310,8 @@ def run(seed):
run(test_util.test_seed(sampler_type='stateless')))

self.assertEqual(self.dtype, chain.dtype)
self.assertAllClose(1.4, trace['step_size'][-1], rtol=0.2)
# Obtained via a separate run of `windowed_adaptive_nuts`.
self.assertAllClose(0.45, trace['step_size'][-1], rtol=0.25)
self.assertAllClose(8., trace['max_trajectory_length'][-1], atol=2.)
self.assertAllClose(chain.var((0, 1)), np.diag(covariance), rtol=0.2)
self.assertAllClose(
@@ -518,7 +524,7 @@ def testShardedChainAxes(self):
num_dims = 8

eigenvalues = np.exp(np.linspace(0., 3., num_dims))
q, r = np.linalg.qr(np.random.randn(num_dims, num_dims))
q, r = np.linalg.qr(np.random.RandomState(0).randn(num_dims, num_dims))
q *= np.sign(np.diag(r))
covariance = (q * eigenvalues).dot(q.T).astype(np.float32)

@@ -549,7 +555,8 @@ def run(_):
)))

# Adaptation results.
self.assertAllClose(1.4, trace['step_size'][0, -1], rtol=0.2)
# Obtained via a separate run of `windowed_adaptive_nuts`.
self.assertAllClose(0.45, trace['step_size'][0, -1], rtol=0.25)
self.assertAllClose(chain.var((0, 1, 2)), np.diag(covariance), rtol=0.2)

# Shard consistency.
