Skip to content

Commit

Permalink
Merge pull request #1775 from aleslamitz:smc2
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 609789859
  • Loading branch information
tensorflower-gardener committed Feb 23, 2024
2 parents e3a03b4 + 222c197 commit aff7da4
Show file tree
Hide file tree
Showing 4 changed files with 410 additions and 55 deletions.
2 changes: 2 additions & 0 deletions tensorflow_probability/python/experimental/mcmc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -559,11 +559,13 @@ multi_substrate_py_test(
shard_count = 3,
deps = [
":particle_filter",
":sequential_monte_carlo_kernel",
# numpy dep,
# tensorflow dep,
"//tensorflow_probability/python/bijectors:shift",
"//tensorflow_probability/python/distributions:bernoulli",
"//tensorflow_probability/python/distributions:deterministic",
"//tensorflow_probability/python/distributions:independent",
"//tensorflow_probability/python/distributions:joint_distribution_auto_batched",
"//tensorflow_probability/python/distributions:joint_distribution_named",
"//tensorflow_probability/python/distributions:linear_gaussian_ssm",
Expand Down
304 changes: 294 additions & 10 deletions tensorflow_probability/python/experimental/mcmc/particle_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import numpy as np
import tensorflow.compat.v2 as tf
from tensorflow_probability.python.distributions import uniform
from tensorflow_probability.python.experimental.mcmc import sequential_monte_carlo_kernel as smc_kernel
from tensorflow_probability.python.experimental.mcmc import weighted_resampling
from tensorflow_probability.python.internal import assert_util
Expand Down Expand Up @@ -549,6 +550,7 @@ def _particle_filter_initial_weighted_particles(observations,
initial_state_proposal,
num_particles,
particles_dim=0,
extra=(),
seed=None):
"""Initialize a set of weighted particles including the first observation."""
# Propose an initial state.
Expand All @@ -574,14 +576,16 @@ def _particle_filter_initial_weighted_particles(observations,
axis=particles_dim)

# Return particles weighted by the initial observation.
if observations is not None:
initial_log_weights += _compute_observation_log_weights(
step=0,
particles=initial_state,
observations=observations,
observation_fn=observation_fn,
particles_dim=particles_dim)

return smc_kernel.WeightedParticles(
particles=initial_state,
log_weights=initial_log_weights + _compute_observation_log_weights(
step=0,
particles=initial_state,
observations=observations,
observation_fn=observation_fn,
particles_dim=particles_dim))
particles=initial_state, log_weights=initial_log_weights, extra=extra)


def _particle_filter_propose_and_update_log_weights_fn(
Expand Down Expand Up @@ -625,7 +629,8 @@ def propose_and_update_log_weights_fn(step, state, seed=None):
log_weights=log_weights + _compute_observation_log_weights(
step + 1, proposed_particles, observations, observation_fn,
num_transitions_per_observation=num_transitions_per_observation,
particles_dim=particles_dim))
particles_dim=particles_dim),
extra=state.extra)
return propose_and_update_log_weights_fn


Expand Down Expand Up @@ -670,8 +675,18 @@ def _compute_observation_log_weights(step,
observation = tf.nest.map_structure(
lambda x, step=step: tf.gather(x, observation_idx), observations)

observation = tf.nest.map_structure(
lambda x: tf.expand_dims(x, axis=particles_dim), observation)
# For now, when particles_dim > 0, we do not support the observations
# having batch shape. (This is not needed for SMC^2.)
#
# In JAX, particles_dim > 0 can be handled like:
# vmap(lambda p: observation_fn(step, p).log_prob(observations),
# in_axes=particles_dim, out_axes=particles_dim)
#
# In TF, we could re-arrange dimensions here. Or we could left-pad the
# observations with additional dimensions until they have rank one less
# than the batch-and-event rank of observation_fn(step, particles), and
# then we could expand_dims at dimension particles_dim.
del particles_dim

log_weights = observation_fn(step, particles).log_prob(observation)
return tf.where(step_has_observation,
Expand Down Expand Up @@ -741,3 +756,272 @@ def _assert_batch_shape_matches_weights(distribution, weights_shape, diststr):
assertions = [assert_util.assert_equal(a, b, message=msg)
for a, b in zip(shapes[1:], shapes[:-1])]
return assertions


def _default_rejuvenation_criterion_fn(step, weighted_particles):
del step
return smc_kernel.ess_below_threshold(weighted_particles, particles_dim=0)[0]


def smc_squared(
observations,
initial_parameter_prior,
inner_initial_state_prior,
inner_transition_fn,
observation_fn,
num_outer_particles,
num_inner_particles,
initial_parameter_proposal=None,
parameter_proposal_kernel=None,
inner_initial_state_proposal=None,
inner_proposal_fn=None,
outer_rejuvenation_criterion_fn=_default_rejuvenation_criterion_fn,
inner_resample_criterion_fn=smc_kernel.ess_below_threshold,
inner_resample_fn=weighted_resampling.resample_systematic,
outer_trace_fn=_default_trace_fn,
outer_trace_criterion_fn=_always_trace,
parallel_iterations=1,
num_transitions_per_observation=1,
static_trace_allocation_size=None,
unbiased_gradients=True,
seed=None):
"""SMC^2."""
init_seed, smc_seed = samplers.split_seed(seed, salt='smc_squared')

num_observation_steps = ps.size0(tf.nest.flatten(observations)[0])
num_timesteps = (
1 + num_transitions_per_observation * (num_observation_steps - 1))

initial_state = _smc_squared_intial_weighted_particles(
observations, observation_fn, initial_parameter_prior,
initial_parameter_proposal, num_outer_particles,
inner_initial_state_prior, inner_initial_state_proposal,
num_inner_particles, seed=init_seed)

outer_propose_and_update_log_weights_fn = (
_smc_squared_propose_and_update_log_weights_fn(
outer_rejuvenation_criterion_fn=outer_rejuvenation_criterion_fn,
observations=observations,
inner_transition_fn=inner_transition_fn,
inner_proposal_fn=inner_proposal_fn,
observation_fn=observation_fn,
inner_resample_fn=inner_resample_fn,
inner_resample_criterion_fn=inner_resample_criterion_fn,
parameter_proposal_kernel=parameter_proposal_kernel,
initial_parameter_prior=initial_parameter_prior,
num_transitions_per_observation=num_transitions_per_observation,
unbiased_gradients=unbiased_gradients,
inner_initial_state_prior=inner_initial_state_prior,
inner_initial_state_proposal=inner_initial_state_proposal,
num_inner_particles=num_inner_particles,
num_outer_particles=num_outer_particles))

return sequential_monte_carlo(
initial_weighted_particles=initial_state,
propose_and_update_log_weights_fn=
outer_propose_and_update_log_weights_fn,
resample_fn=None,
resample_criterion_fn=None,
trace_criterion_fn=outer_trace_criterion_fn,
static_trace_allocation_size=static_trace_allocation_size,
parallel_iterations=parallel_iterations,
unbiased_gradients=unbiased_gradients,
num_steps=num_timesteps,
particles_dim=0,
trace_fn=outer_trace_fn,
seed=smc_seed)


def _smc_squared_intial_weighted_particles(
observations,
observation_fn,
initial_parameter_prior,
initial_parameter_proposal,
num_outer_particles,
inner_initial_state_prior,
inner_initial_state_proposal,
num_inner_particles,
seed=None):
"""Initialize particles for SMC^2, including the first observation."""
params_seed, particles_seed = samplers.split_seed(
seed, n=2, salt='smc_squared_init_particles')

initial_params, initial_log_weights, _ = (
_particle_filter_initial_weighted_particles(
observations=None,
observation_fn=None,
initial_state_prior=initial_parameter_prior,
initial_state_proposal=initial_parameter_proposal,
num_particles=num_outer_particles,
seed=params_seed))

inner_weighted_particles = _particle_filter_initial_weighted_particles(
observations=observations,
observation_fn=observation_fn(initial_params),
initial_state_prior=inner_initial_state_prior(0, initial_params),
initial_state_proposal=(inner_initial_state_proposal(0, initial_params)
if inner_initial_state_proposal is not None
else None),
num_particles=num_inner_particles,
particles_dim=1,
seed=particles_seed)

inner_filter_results = smc_kernel.SequentialMonteCarlo(
None, None, particles_dim=1).bootstrap_results(inner_weighted_particles)

return smc_kernel.WeightedParticles(
particles=(initial_params,
inner_weighted_particles,
inner_filter_results.parent_indices,
inner_filter_results.incremental_log_marginal_likelihood,
inner_filter_results.accumulated_log_marginal_likelihood),
log_weights=initial_log_weights,
extra=inner_filter_results.seed)


def _smc_squared_propose_and_update_log_weights_fn(
observations,
inner_transition_fn,
inner_proposal_fn,
observation_fn,
initial_parameter_prior,
inner_initial_state_prior,
inner_initial_state_proposal,
num_transitions_per_observation,
inner_resample_fn,
inner_resample_criterion_fn,
outer_rejuvenation_criterion_fn,
unbiased_gradients,
parameter_proposal_kernel,
num_inner_particles,
num_outer_particles):
"""Build a function specifying an SMC^2 update step."""
def _rejuvenate_particles(
outer_params, log_weights, inner_particles, filter_results, seed=None):
seeds = samplers.split_seed(seed, n=4)
step = filter_results.steps

proposal_kernel = parameter_proposal_kernel(outer_params, log_weights)
rej_outer_params = proposal_kernel(outer_params).sample(seed=seeds[0])

rej_inner_particles = _particle_filter_initial_weighted_particles(
observations=observations,
observation_fn=observation_fn(rej_outer_params),
initial_state_prior=inner_initial_state_prior(0, rej_outer_params),
initial_state_proposal=(
inner_initial_state_proposal(0, rej_outer_params)
if inner_initial_state_proposal is not None else None),
num_particles=num_inner_particles,
particles_dim=1,
seed=seeds[1])

rej_kernel = smc_kernel.SequentialMonteCarlo(
propose_and_update_log_weights_fn=(
_particle_filter_propose_and_update_log_weights_fn(
observations=observations,
transition_fn=inner_transition_fn(rej_outer_params),
proposal_fn=(inner_proposal_fn(rej_outer_params)
if inner_proposal_fn is not None else None),
observation_fn=observation_fn(rej_outer_params),
particles_dim=1,
num_transitions_per_observation=(
num_transitions_per_observation))),
resample_fn=inner_resample_fn,
resample_criterion_fn=inner_resample_criterion_fn,
particles_dim=1,
unbiased_gradients=unbiased_gradients)

rej_inner_filter_results = (
rej_kernel.bootstrap_results(rej_inner_particles))

def body(i, state, results):
state, results = rej_kernel.one_step(
state, results, seed=samplers.fold_in(seeds[2], i))
return (i + 1, state, results)

(_, rej_inner_particles, rej_filter_results) = tf.while_loop(
lambda i, *_: tf.less_equal(i, step),
body,
[0, rej_inner_particles, rej_inner_filter_results])

log_a = (rej_filter_results.accumulated_log_marginal_likelihood
- filter_results.accumulated_log_marginal_likelihood
+ initial_parameter_prior.log_prob(rej_outer_params)
- initial_parameter_prior.log_prob(outer_params)
+ proposal_kernel(rej_outer_params).log_prob(outer_params)
- proposal_kernel(outer_params).log_prob(rej_outer_params))
u = uniform.Uniform(0., 1.).sample(num_outer_particles, seed=seeds[3])
accept = tf.math.log(u) <= log_a

def _choose(a, b):
if len(a.shape) >= 1 and a.shape[0] == accept.shape[0]:
return mcmc_util.choose(accept, a, b)
return b
outer_params, inner_particles, filter_results = (
tf.nest.map_structure(
_choose,
(rej_outer_params, rej_inner_particles, rej_filter_results),
(outer_params, inner_particles, filter_results)))

return (outer_params, tf.zeros_like(log_weights),
inner_particles, filter_results)

def _outer_propose_and_update_log_weights_fn(step, state, seed=None):
step_seed, rejuvenation_seed = samplers.split_seed(seed, 2)

outer_params, inner_particles, *_ = state.particles
filter_results = smc_kernel.SequentialMonteCarloResults(
steps=step,
parent_indices=state.particles[2],
incremental_log_marginal_likelihood=state.particles[3],
accumulated_log_marginal_likelihood=state.particles[4],
seed=state.extra)

kernel = smc_kernel.SequentialMonteCarlo(
propose_and_update_log_weights_fn=(
_particle_filter_propose_and_update_log_weights_fn(
observations=observations,
transition_fn=inner_transition_fn(outer_params),
proposal_fn=(inner_proposal_fn(outer_params)
if inner_proposal_fn is not None else None),
observation_fn=observation_fn(outer_params),
particles_dim=1,
num_transitions_per_observation=(
num_transitions_per_observation))),
resample_fn=inner_resample_fn,
resample_criterion_fn=inner_resample_criterion_fn,
particles_dim=1,
unbiased_gradients=unbiased_gradients)

inner_particles, filter_results = kernel.one_step(
inner_particles, filter_results, seed=step_seed)
log_weights = (
state.log_weights + filter_results.incremental_log_marginal_likelihood)

do_rejuvenation = outer_rejuvenation_criterion_fn(
step, smc_kernel.WeightedParticles(
particles=(outer_params,
inner_particles,
filter_results.parent_indices,
filter_results.incremental_log_marginal_likelihood,
filter_results.accumulated_log_marginal_likelihood),
log_weights=log_weights,
extra=filter_results.seed))

(outer_params, log_weights, inner_particles, filter_results) = tf.cond(
do_rejuvenation,
lambda: _rejuvenate_particles(
outer_params, log_weights, inner_particles, filter_results,
seed=rejuvenation_seed),
lambda: (outer_params, log_weights, inner_particles, filter_results))

return smc_kernel.WeightedParticles(
particles=(outer_params,
inner_particles,
filter_results.parent_indices,
filter_results.incremental_log_marginal_likelihood,
filter_results.accumulated_log_marginal_likelihood),
log_weights=log_weights,
extra=filter_results.seed)

return _outer_propose_and_update_log_weights_fn
Loading

0 comments on commit aff7da4

Please sign in to comment.