Skip to content

Commit

Permalink
pylint
Browse files Browse the repository at this point in the history
  • Loading branch information
aleslamitz committed Dec 10, 2023
1 parent cc799a5 commit 53ba255
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 62 deletions.
153 changes: 105 additions & 48 deletions tensorflow_probability/python/experimental/mcmc/particle_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@
from tensorflow_probability.python.internal import prefer_static as ps
from tensorflow_probability.python.internal import samplers
from tensorflow_probability.python.mcmc.internal import util as mcmc_util
from tensorflow_probability.python.distributions import batch_reshape
from tensorflow_probability.python.distributions import batch_broadcast
from tensorflow_probability.python.distributions import normal
from tensorflow_probability.python.distributions import uniform

Expand All @@ -51,7 +49,8 @@ def _default_trace_fn(state, kernel_results):

def _default_kernel(parameters):
mean, variance = tf.nn.moments(parameters, axes=[0])
proposal_distribution = normal.Normal(loc=tf.fill(parameters.shape, mean), scale=tf.sqrt(variance))
proposal_distribution = normal.Normal(loc=tf.fill(parameters.shape, mean),
scale=tf.sqrt(variance))
return proposal_distribution


Expand Down Expand Up @@ -499,15 +498,18 @@ def smc_squared(
unbiased_gradients=True,
seed=None,
):
init_seed, loop_seed, step_seed = samplers.split_seed(seed, n=3, salt='smc_squared')
_1, loop_seed, _2 = samplers.split_seed(seed, n=3, salt='smc_squared')

num_observation_steps = ps.size0(tf.nest.flatten(inner_observations)[0])

# TODO: The following two lines compensates for having the first empty step in smc2
# TODO: The following two lines compensates for having the
# first empty step in smc2
num_timesteps = (1 + num_transitions_per_observation *
(num_observation_steps - 1)) + 1
last_obs_expanded = tf.expand_dims(inner_observations[-1], axis=0)
inner_observations = tf.concat([inner_observations, last_obs_expanded], axis=0)
inner_observations = tf.concat([inner_observations,
last_obs_expanded],
axis=0)

if outer_rejuvenation_criterion_fn is None:
outer_rejuvenation_criterion_fn = lambda *_: tf.constant(False)
Expand Down Expand Up @@ -543,7 +545,8 @@ def smc_squared(
observation_fn=inner_observation_fn(initial_state),
initial_state_prior=inner_initial_state_prior(0, initial_state),
initial_state_proposal=(inner_initial_state_proposal(0, initial_state)
if inner_initial_state_proposal is not None else None),
if inner_initial_state_proposal is not None
else None),
num_particles=num_inner_particles,
particles_dim=1,
seed=seed
Expand Down Expand Up @@ -594,7 +597,8 @@ def smc_squared(

traced_results = sequential_monte_carlo(
initial_weighted_particles=initial_state,
propose_and_update_log_weights_fn=outer_propose_and_update_log_weights_fn,
propose_and_update_log_weights_fn=
outer_propose_and_update_log_weights_fn,
resample_fn=outer_resample_fn,
resample_criterion_fn=outer_resample_criterion_fn,
trace_criterion_fn=outer_trace_criterion_fn,
Expand Down Expand Up @@ -631,7 +635,8 @@ def _outer_particle_filter_propose_and_update_log_weights_fn(
"""Build a function specifying a particle filter update step."""
def _outer_propose_and_update_log_weights_fn(step, state, seed=None):
outside_parameters = state.particles[0]
inner_weighted_particles, log_weights = state.particles[1], state.log_weights
inner_weighted_particles, log_weights = state.particles[1], \
state.log_weights

filter_results = smc_kernel.SequentialMonteCarloResults(
steps=step,
Expand All @@ -654,38 +659,57 @@ def _outer_propose_and_update_log_weights_fn(step, state, seed=None):
)

kernel = smc_kernel.SequentialMonteCarlo(
propose_and_update_log_weights_fn=inner_propose_and_update_log_weights_fn,
propose_and_update_log_weights_fn=
inner_propose_and_update_log_weights_fn,
resample_fn=inner_resample_fn,
resample_criterion_fn=inner_resample_criterion_fn,
particles_dim=1,
unbiased_gradients=unbiased_gradients
)

inner_weighted_particles, filter_results = kernel.one_step(inner_weighted_particles,
filter_results,
seed=seed)
inner_weighted_particles, filter_results = kernel.one_step(
inner_weighted_particles,
filter_results,
seed=seed
)

updated_log_weights = log_weights + filter_results.incremental_log_marginal_likelihood
updated_log_weights = log_weights + \
filter_results.incremental_log_marginal_likelihood

do_rejuvenation = outer_rejuvenation_criterion_fn(step, state)

def rejuvenate_particles(outside_parameters, updated_log_weights, inner_weighted_particles, filter_results):
proposed_parameters = parameter_proposal_kernel(outside_parameters).sample(seed=seed)
def rejuvenate_particles(outside_parameters,
updated_log_weights,
inner_weighted_particles,
filter_results):
proposed_parameters = parameter_proposal_kernel(
outside_parameters
).sample(seed=seed)

rej_params_log_weights = ps.zeros_like(
initial_parameter_prior.log_prob(proposed_parameters)
)
rej_params_log_weights = tf.nn.log_softmax(rej_params_log_weights, axis=0)

rej_inner_weighted_particles = _particle_filter_initial_weighted_particles(
observations=inner_observations,
observation_fn=inner_observation_fn(proposed_parameters),
initial_state_prior=inner_initial_state_prior(0, proposed_parameters),
initial_state_proposal=(inner_initial_state_proposal(0, proposed_parameters)
if inner_initial_state_proposal is not None else None),
num_particles=num_inner_particles,
particles_dim=1,
seed=seed)
rej_params_log_weights = tf.nn.log_softmax(
rej_params_log_weights,
axis=0
)

rej_inner_weighted_particles = \
_particle_filter_initial_weighted_particles(
observations=inner_observations,
observation_fn=inner_observation_fn(proposed_parameters),
initial_state_prior=inner_initial_state_prior(
0,
proposed_parameters
),
initial_state_proposal=(
inner_initial_state_proposal(0, proposed_parameters)
if inner_initial_state_proposal is not None
else None),
num_particles=num_inner_particles,
particles_dim=1,
seed=seed
)

batch_zeros = tf.zeros(ps.shape(log_weights))

Expand All @@ -709,11 +733,13 @@ def rejuvenate_particles(outside_parameters, updated_log_weights, inner_weighted
observation_fn=inner_observation_fn(proposed_parameters),
extra_fn=extra_fn,
particles_dim=1,
num_transitions_per_observation=num_transitions_per_observation)
num_transitions_per_observation=
num_transitions_per_observation)
)

rej_kernel = smc_kernel.SequentialMonteCarlo(
propose_and_update_log_weights_fn=rej_inner_propose_and_update_log_weights_fn,
propose_and_update_log_weights_fn=
rej_inner_propose_and_update_log_weights_fn,
resample_fn=inner_resample_fn,
resample_criterion_fn=inner_resample_criterion_fn,
particles_dim=1,
Expand All @@ -732,16 +758,27 @@ def body(i,
rej_parameters_weights,
rej_params_log_weights):

rej_inner_weighted_particles, rej_filter_results = rej_kernel.one_step(
rej_inner_weighted_particles, rej_filter_results, seed=seed
)
rej_inner_weighted_particles, rej_filter_results = \
rej_kernel.one_step(
rej_inner_weighted_particles, rej_filter_results, seed=seed
)

rej_parameters_weights += rej_inner_weighted_particles.log_weights

rej_params_log_weights = rej_params_log_weights + rej_filter_results.incremental_log_marginal_likelihood
return i + 1, rej_inner_weighted_particles, rej_filter_results, rej_parameters_weights, rej_params_log_weights

i, rej_inner_weighted_particles, rej_filter_results, rej_inner_particles_weights, rej_params_log_weights = tf.while_loop(
rej_params_log_weights = \
rej_params_log_weights + \
rej_filter_results.incremental_log_marginal_likelihood
return i + 1, \
rej_inner_weighted_particles, \
rej_filter_results, \
rej_parameters_weights, \
rej_params_log_weights

_, \
rej_inner_weighted_particles, \
rej_filter_results, \
rej_inner_particles_weights, \
rej_params_log_weights = tf.while_loop(
condition,
body,
loop_vars=[0,
Expand All @@ -754,19 +791,27 @@ def body(i,

log_a = rej_filter_results.accumulated_log_marginal_likelihood - \
filter_results.accumulated_log_marginal_likelihood + \
parameter_proposal_kernel(proposed_parameters).log_prob(outside_parameters) - \
parameter_proposal_kernel(outside_parameters).log_prob(proposed_parameters)
parameter_proposal_kernel(
proposed_parameters).log_prob(outside_parameters) - \
parameter_proposal_kernel(
outside_parameters).log_prob(proposed_parameters)

acceptance_probs = tf.minimum(1., tf.exp(log_a))

random_numbers = uniform.Uniform(0., 1.).sample(num_outer_particles, seed=seed)
random_numbers = uniform.Uniform(0., 1.).sample(num_outer_particles,
seed=seed)

# Determine if the proposed particle should be accepted or reject
accept = random_numbers > acceptance_probs

# Update the chosen particles and filter restults based on the acceptance step
outside_parameters = tf.where(accept, outside_parameters, proposed_parameters)
updated_log_weights = tf.where(accept, updated_log_weights, rej_params_log_weights)
# Update the chosen particles and filter restults
# based on the acceptance step
outside_parameters = tf.where(accept,
outside_parameters,
proposed_parameters)
updated_log_weights = tf.where(accept,
updated_log_weights,
rej_params_log_weights)

inner_weighted_particles_particles = mcmc_util.choose(
accept,
Expand All @@ -786,17 +831,29 @@ def body(i,
)

filter_results = tf.nest.map_structure(
lambda a, b: where_fn(accept, a, b, num_outer_particles, num_inner_particles),
lambda a, b: where_fn(accept, a, b,
num_outer_particles,
num_inner_particles),
filter_results,
rej_filter_results
)

return outside_parameters, updated_log_weights, inner_weighted_particles, filter_results
return outside_parameters, updated_log_weights, \
inner_weighted_particles, filter_results

outside_parameters, updated_log_weights, inner_weighted_particles, filter_results = tf.cond(
outside_parameters, \
updated_log_weights, \
inner_weighted_particles, \
filter_results = tf.cond(
do_rejuvenation,
lambda: (rejuvenate_particles(outside_parameters, updated_log_weights, inner_weighted_particles, filter_results)),
lambda: (outside_parameters, updated_log_weights, inner_weighted_particles, filter_results)
lambda: (rejuvenate_particles(outside_parameters,
updated_log_weights,
inner_weighted_particles,
filter_results)),
lambda: (outside_parameters,
updated_log_weights,
inner_weighted_particles,
filter_results)
)

return smc_kernel.WeightedParticles(
Expand Down Expand Up @@ -1066,7 +1123,7 @@ def _compute_observation_log_weights(step,
lambda x, step=step: tf.gather(x, observation_idx), observations)

if particles_dim == 1:
observation = tf.expand_dims(observation, axis=0)
observation = tf.expand_dims(observation, axis=0)
observation = tf.nest.map_structure(
lambda x: tf.expand_dims(x, axis=particles_dim), observation)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -615,8 +615,11 @@ def marginal_log_likelihood(level_scale, noise_scale):

def test_smc_squared_rejuvenation_parameters(self):
def particle_dynamics(params, _, previous_state):
reshaped_params = tf.reshape(params, [params.shape[0]] + [1] * (previous_state.shape.rank - 1))
broadcasted_params = tf.broadcast_to(reshaped_params, previous_state.shape)
reshaped_params = tf.reshape(params,
[params.shape[0]] +
[1] * (previous_state.shape.rank - 1))
broadcasted_params = tf.broadcast_to(reshaped_params,
previous_state.shape)
return normal.Normal(previous_state + broadcasted_params + 1, 0.1)

def rejuvenation_criterion(step, state):
Expand All @@ -625,7 +628,8 @@ def rejuvenation_criterion(step, state):
tf.equal(tf.math.mod(step, tf.constant(2)), tf.constant(0)),
tf.not_equal(state.extra[0], tf.constant(0))
)
return tf.cond(cond, lambda: tf.constant(True), lambda: tf.constant(False))
return tf.cond(cond, lambda: tf.constant(True),
lambda: tf.constant(False))

inner_observations = tf.range(30, dtype=tf.float32)

Expand All @@ -637,15 +641,18 @@ def rejuvenation_criterion(step, state):

params, inner_pt = self.evaluate(particle_filter.smc_squared(
inner_observations=inner_observations,
inner_initial_state_prior=lambda _, params: mvn_diag.MultivariateNormalDiag(
inner_initial_state_prior=lambda _, params:
mvn_diag.MultivariateNormalDiag(
loc=loc, scale_diag=scale_diag
),
initial_parameter_prior=normal.Normal(3., 1.),
num_outer_particles=num_outer_particles,
num_inner_particles=num_inner_particles,
outer_rejuvenation_criterion_fn=rejuvenation_criterion,
inner_transition_fn=lambda params: (
lambda _, state: independent.Independent(particle_dynamics(params, _, state), 1)),
lambda _, state: independent.Independent(
particle_dynamics(params, _, state), 1)
),
inner_observation_fn=lambda params: (
lambda _, state: independent.Independent(normal.Normal(state, 2.), 1)),
outer_trace_fn=lambda s, r: (
Expand Down Expand Up @@ -693,7 +700,8 @@ def observe_position(_, state):
inner_initial_state_prior=lambda _, params: initial_state_prior,
initial_parameter_prior=deterministic.Deterministic(0.),
num_outer_particles=1,
inner_transition_fn=lambda params: simple_harmonic_motion_transition_fn,
inner_transition_fn=lambda params:
simple_harmonic_motion_transition_fn,
inner_observation_fn=lambda params: observe_position,
num_inner_particles=1024,
outer_trace_fn=lambda s, r: (
Expand All @@ -704,7 +712,9 @@ def observe_position(_, state):
seed=test_util.test_seed())
)

self.assertAllEqual(ps.shape(particles['position']), tf.constant([102, 1, 1024]))
self.assertAllEqual(ps.shape(particles['position']), tf.constant([102,
1,
1024]))

self.assertAllClose(tf.transpose(np.mean(particles['position'], axis=-1)),
tf.reshape(tf.math.cos(dt * np.arange(102)), [1, -1]),
Expand Down Expand Up @@ -733,8 +743,10 @@ def trace_fn(state, _):
inner_observations=tf.convert_to_tensor([1., 3., 5., 7., 9.]),
inner_initial_state_prior=lambda _, params: normal.Normal([0.], 1.),
initial_parameter_prior=deterministic.Deterministic(0.),
inner_transition_fn=lambda params: (lambda _, state: normal.Normal(state, 1.)),
inner_observation_fn=lambda params: (lambda _, state: normal.Normal(state, 1.)),
inner_transition_fn=lambda params: (lambda _, state:
normal.Normal(state, 1.)),
inner_observation_fn=lambda params: (lambda _, state:
normal.Normal(state, 1.)),
num_inner_particles=1024,
num_outer_particles=1,
outer_trace_fn=trace_fn,
Expand Down Expand Up @@ -766,15 +778,21 @@ def rejuvenation_criterion(step, state):
tf.equal(tf.math.mod(step, tf.constant(3)), tf.constant(0)),
tf.not_equal(state.extra[0], tf.constant(0))
)
return tf.cond(cond, lambda: tf.constant(True), lambda: tf.constant(False))
return tf.cond(cond, lambda: tf.constant(True),
lambda: tf.constant(False))

(parameters, weight_parameters, inner_particles, inner_log_weights, lp) = self.evaluate(
(parameters, weight_parameters,
inner_particles, inner_log_weights, lp) = self.evaluate(
particle_filter.smc_squared(
inner_observations=tf.convert_to_tensor([1., 3., 5., 7., 9.]),
initial_parameter_prior=deterministic.Deterministic(0.),
inner_initial_state_prior=lambda _, params: normal.Normal([0.] * num_outer_particles, 1.),
inner_transition_fn=lambda params: (lambda _, state: normal.Normal(state, 10.)),
inner_observation_fn=lambda params: (lambda _, state: normal.Normal(state, 0.1)),
inner_initial_state_prior=lambda _, params: normal.Normal(
[0.] * num_outer_particles, 1.
),
inner_transition_fn=lambda params:
(lambda _, state: normal.Normal(state, 10.)),
inner_observation_fn=lambda params:
(lambda _, state: normal.Normal(state, 0.1)),
num_inner_particles=num_inner_particles,
num_outer_particles=num_outer_particles,
outer_rejuvenation_criterion_fn=rejuvenation_criterion,
Expand Down

0 comments on commit 53ba255

Please sign in to comment.