diff --git a/tensorflow_probability/python/experimental/mcmc/particle_filter.py b/tensorflow_probability/python/experimental/mcmc/particle_filter.py index b920b0db85..2af7a68999 100644 --- a/tensorflow_probability/python/experimental/mcmc/particle_filter.py +++ b/tensorflow_probability/python/experimental/mcmc/particle_filter.py @@ -25,8 +25,6 @@ from tensorflow_probability.python.internal import prefer_static as ps from tensorflow_probability.python.internal import samplers from tensorflow_probability.python.mcmc.internal import util as mcmc_util -from tensorflow_probability.python.distributions import batch_reshape -from tensorflow_probability.python.distributions import batch_broadcast from tensorflow_probability.python.distributions import normal from tensorflow_probability.python.distributions import uniform @@ -51,7 +49,8 @@ def _default_trace_fn(state, kernel_results): def _default_kernel(parameters): mean, variance = tf.nn.moments(parameters, axes=[0]) - proposal_distribution = normal.Normal(loc=tf.fill(parameters.shape, mean), scale=tf.sqrt(variance)) + proposal_distribution = normal.Normal(loc=tf.fill(parameters.shape, mean), + scale=tf.sqrt(variance)) return proposal_distribution @@ -499,15 +498,18 @@ def smc_squared( unbiased_gradients=True, seed=None, ): - init_seed, loop_seed, step_seed = samplers.split_seed(seed, n=3, salt='smc_squared') + _1, loop_seed, _2 = samplers.split_seed(seed, n=3, salt='smc_squared') num_observation_steps = ps.size0(tf.nest.flatten(inner_observations)[0]) - # TODO: The following two lines compensates for having the first empty step in smc2 + # TODO: The following two lines compensates for having the + # first empty step in smc2 num_timesteps = (1 + num_transitions_per_observation * (num_observation_steps - 1)) + 1 last_obs_expanded = tf.expand_dims(inner_observations[-1], axis=0) - inner_observations = tf.concat([inner_observations, last_obs_expanded], axis=0) + inner_observations = tf.concat([inner_observations, + last_obs_expanded], + axis=0) if outer_rejuvenation_criterion_fn is None: outer_rejuvenation_criterion_fn = lambda *_: tf.constant(False) @@ -543,7 +545,8 @@ def smc_squared( observation_fn=inner_observation_fn(initial_state), initial_state_prior=inner_initial_state_prior(0, initial_state), initial_state_proposal=(inner_initial_state_proposal(0, initial_state) - if inner_initial_state_proposal is not None else None), + if inner_initial_state_proposal is not None + else None), num_particles=num_inner_particles, particles_dim=1, seed=seed @@ -594,7 +597,8 @@ def smc_squared( traced_results = sequential_monte_carlo( initial_weighted_particles=initial_state, - propose_and_update_log_weights_fn=outer_propose_and_update_log_weights_fn, + propose_and_update_log_weights_fn= + outer_propose_and_update_log_weights_fn, resample_fn=outer_resample_fn, resample_criterion_fn=outer_resample_criterion_fn, trace_criterion_fn=outer_trace_criterion_fn, @@ -631,7 +635,8 @@ def _outer_particle_filter_propose_and_update_log_weights_fn( """Build a function specifying a particle filter update step.""" def _outer_propose_and_update_log_weights_fn(step, state, seed=None): outside_parameters = state.particles[0] - inner_weighted_particles, log_weights = state.particles[1], state.log_weights + inner_weighted_particles, log_weights = state.particles[1], \ + state.log_weights filter_results = smc_kernel.SequentialMonteCarloResults( steps=step, @@ -654,38 +659,57 @@ def _outer_propose_and_update_log_weights_fn(step, state, seed=None): ) kernel = smc_kernel.SequentialMonteCarlo( - propose_and_update_log_weights_fn=inner_propose_and_update_log_weights_fn, + propose_and_update_log_weights_fn= + inner_propose_and_update_log_weights_fn, resample_fn=inner_resample_fn, resample_criterion_fn=inner_resample_criterion_fn, particles_dim=1, unbiased_gradients=unbiased_gradients ) - inner_weighted_particles, filter_results = kernel.one_step(inner_weighted_particles, - filter_results, - seed=seed) + inner_weighted_particles, filter_results = kernel.one_step( + inner_weighted_particles, + filter_results, + seed=seed + ) - updated_log_weights = log_weights + filter_results.incremental_log_marginal_likelihood + updated_log_weights = log_weights + \ + filter_results.incremental_log_marginal_likelihood do_rejuvenation = outer_rejuvenation_criterion_fn(step, state) - def rejuvenate_particles(outside_parameters, updated_log_weights, inner_weighted_particles, filter_results): - proposed_parameters = parameter_proposal_kernel(outside_parameters).sample(seed=seed) + def rejuvenate_particles(outside_parameters, + updated_log_weights, + inner_weighted_particles, + filter_results): + proposed_parameters = parameter_proposal_kernel( + outside_parameters + ).sample(seed=seed) rej_params_log_weights = ps.zeros_like( initial_parameter_prior.log_prob(proposed_parameters) ) - rej_params_log_weights = tf.nn.log_softmax(rej_params_log_weights, axis=0) - - rej_inner_weighted_particles = _particle_filter_initial_weighted_particles( - observations=inner_observations, - observation_fn=inner_observation_fn(proposed_parameters), - initial_state_prior=inner_initial_state_prior(0, proposed_parameters), - initial_state_proposal=(inner_initial_state_proposal(0, proposed_parameters) - if inner_initial_state_proposal is not None else None), - num_particles=num_inner_particles, - particles_dim=1, - seed=seed) + rej_params_log_weights = tf.nn.log_softmax( + rej_params_log_weights, + axis=0 + ) + + rej_inner_weighted_particles = \ + _particle_filter_initial_weighted_particles( + observations=inner_observations, + observation_fn=inner_observation_fn(proposed_parameters), + initial_state_prior=inner_initial_state_prior( + 0, + proposed_parameters + ), + initial_state_proposal=( + inner_initial_state_proposal(0, proposed_parameters) + if inner_initial_state_proposal is not None + else None), + num_particles=num_inner_particles, + particles_dim=1, + seed=seed + ) batch_zeros = tf.zeros(ps.shape(log_weights)) @@ -709,11 +733,13 @@ def rejuvenate_particles(outside_parameters, updated_log_weights, inner_weighted observation_fn=inner_observation_fn(proposed_parameters), extra_fn=extra_fn, particles_dim=1, - num_transitions_per_observation=num_transitions_per_observation) + num_transitions_per_observation= + num_transitions_per_observation) ) rej_kernel = smc_kernel.SequentialMonteCarlo( - propose_and_update_log_weights_fn=rej_inner_propose_and_update_log_weights_fn, + propose_and_update_log_weights_fn= + rej_inner_propose_and_update_log_weights_fn, resample_fn=inner_resample_fn, resample_criterion_fn=inner_resample_criterion_fn, particles_dim=1, @@ -732,16 +758,27 @@ def body(i, rej_parameters_weights, rej_params_log_weights): - rej_inner_weighted_particles, rej_filter_results = rej_kernel.one_step( - rej_inner_weighted_particles, rej_filter_results, seed=seed - ) + rej_inner_weighted_particles, rej_filter_results = \ + rej_kernel.one_step( + rej_inner_weighted_particles, rej_filter_results, seed=seed + ) rej_parameters_weights += rej_inner_weighted_particles.log_weights - rej_params_log_weights = rej_params_log_weights + rej_filter_results.incremental_log_marginal_likelihood - return i + 1, rej_inner_weighted_particles, rej_filter_results, rej_parameters_weights, rej_params_log_weights - - i, rej_inner_weighted_particles, rej_filter_results, rej_inner_particles_weights, rej_params_log_weights = tf.while_loop( + rej_params_log_weights = \ + rej_params_log_weights + \ + rej_filter_results.incremental_log_marginal_likelihood + return i + 1, \ + rej_inner_weighted_particles, \ + rej_filter_results, \ + rej_parameters_weights, \ + rej_params_log_weights + + _, \ + rej_inner_weighted_particles, \ + rej_filter_results, \ + rej_inner_particles_weights, \ + rej_params_log_weights = tf.while_loop( condition, body, loop_vars=[0, @@ -754,19 +791,27 @@ def body(i, log_a = rej_filter_results.accumulated_log_marginal_likelihood - \ filter_results.accumulated_log_marginal_likelihood + \ - parameter_proposal_kernel(proposed_parameters).log_prob(outside_parameters) - \ - parameter_proposal_kernel(outside_parameters).log_prob(proposed_parameters) + parameter_proposal_kernel( + proposed_parameters).log_prob(outside_parameters) - \ + parameter_proposal_kernel( + outside_parameters).log_prob(proposed_parameters) acceptance_probs = tf.minimum(1., tf.exp(log_a)) - random_numbers = uniform.Uniform(0., 1.).sample(num_outer_particles, seed=seed) + random_numbers = uniform.Uniform(0., 1.).sample(num_outer_particles, + seed=seed) # Determine if the proposed particle should be accepted or reject accept = random_numbers > acceptance_probs - # Update the chosen particles and filter restults based on the acceptance step - outside_parameters = tf.where(accept, outside_parameters, proposed_parameters) - updated_log_weights = tf.where(accept, updated_log_weights, rej_params_log_weights) + # Update the chosen particles and filter restults + # based on the acceptance step + outside_parameters = tf.where(accept, + outside_parameters, + proposed_parameters) + updated_log_weights = tf.where(accept, + updated_log_weights, + rej_params_log_weights) inner_weighted_particles_particles = mcmc_util.choose( accept, @@ -786,17 +831,29 @@ def body(i, ) filter_results = tf.nest.map_structure( - lambda a, b: where_fn(accept, a, b, num_outer_particles, num_inner_particles), + lambda a, b: where_fn(accept, a, b, + num_outer_particles, + num_inner_particles), filter_results, rej_filter_results ) - return outside_parameters, updated_log_weights, inner_weighted_particles, filter_results + return outside_parameters, updated_log_weights, \ + inner_weighted_particles, filter_results - outside_parameters, updated_log_weights, inner_weighted_particles, filter_results = tf.cond( + outside_parameters, \ + updated_log_weights, \ + inner_weighted_particles, \ + filter_results = tf.cond( do_rejuvenation, - lambda: (rejuvenate_particles(outside_parameters, updated_log_weights, inner_weighted_particles, filter_results)), - lambda: (outside_parameters, updated_log_weights, inner_weighted_particles, filter_results) + lambda: (rejuvenate_particles(outside_parameters, + updated_log_weights, + inner_weighted_particles, + filter_results)), + lambda: (outside_parameters, + updated_log_weights, + inner_weighted_particles, + filter_results) ) return smc_kernel.WeightedParticles( @@ -1066,7 +1123,7 @@ def _compute_observation_log_weights(step, lambda x, step=step: tf.gather(x, observation_idx), observations) if particles_dim == 1: - observation = tf.expand_dims(observation, axis=0) + observation = tf.expand_dims(observation, axis=0) observation = tf.nest.map_structure( lambda x: tf.expand_dims(x, axis=particles_dim), observation) diff --git a/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py b/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py index e190c76bda..dc6a152e4b 100644 --- a/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py +++ b/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py @@ -615,8 +615,11 @@ def marginal_log_likelihood(level_scale, noise_scale): def test_smc_squared_rejuvenation_parameters(self): def particle_dynamics(params, _, previous_state): - reshaped_params = tf.reshape(params, [params.shape[0]] + [1] * (previous_state.shape.rank - 1)) - broadcasted_params = tf.broadcast_to(reshaped_params, previous_state.shape) + reshaped_params = tf.reshape(params, + [params.shape[0]] + + [1] * (previous_state.shape.rank - 1)) + broadcasted_params = tf.broadcast_to(reshaped_params, + previous_state.shape) return normal.Normal(previous_state + broadcasted_params + 1, 0.1) def rejuvenation_criterion(step, state): @@ -625,7 +628,8 @@ def rejuvenation_criterion(step, state): tf.equal(tf.math.mod(step, tf.constant(2)), tf.constant(0)), tf.not_equal(state.extra[0], tf.constant(0)) ) - return tf.cond(cond, lambda: tf.constant(True), lambda: tf.constant(False)) + return tf.cond(cond, lambda: tf.constant(True), + lambda: tf.constant(False)) inner_observations = tf.range(30, dtype=tf.float32) @@ -637,7 +641,8 @@ def rejuvenation_criterion(step, state): params, inner_pt = self.evaluate(particle_filter.smc_squared( inner_observations=inner_observations, - inner_initial_state_prior=lambda _, params: mvn_diag.MultivariateNormalDiag( + inner_initial_state_prior=lambda _, params: + mvn_diag.MultivariateNormalDiag( loc=loc, scale_diag=scale_diag ), initial_parameter_prior=normal.Normal(3., 1.), @@ -645,7 +650,9 @@ def rejuvenation_criterion(step, state): num_inner_particles=num_inner_particles, outer_rejuvenation_criterion_fn=rejuvenation_criterion, inner_transition_fn=lambda params: ( - lambda _, state: independent.Independent(particle_dynamics(params, _, state), 1)), + lambda _, state: independent.Independent( + particle_dynamics(params, _, state), 1) + ), inner_observation_fn=lambda params: ( lambda _, state: independent.Independent(normal.Normal(state, 2.), 1)), outer_trace_fn=lambda s, r: ( @@ -693,7 +700,8 @@ def observe_position(_, state): inner_initial_state_prior=lambda _, params: initial_state_prior, initial_parameter_prior=deterministic.Deterministic(0.), num_outer_particles=1, - inner_transition_fn=lambda params: simple_harmonic_motion_transition_fn, + inner_transition_fn=lambda params: + simple_harmonic_motion_transition_fn, inner_observation_fn=lambda params: observe_position, num_inner_particles=1024, outer_trace_fn=lambda s, r: ( @@ -704,7 +712,9 @@ def observe_position(_, state): seed=test_util.test_seed()) ) - self.assertAllEqual(ps.shape(particles['position']), tf.constant([102, 1, 1024])) + self.assertAllEqual(ps.shape(particles['position']), tf.constant([102, + 1, + 1024])) self.assertAllClose(tf.transpose(np.mean(particles['position'], axis=-1)), tf.reshape(tf.math.cos(dt * np.arange(102)), [1, -1]), @@ -733,8 +743,10 @@ def trace_fn(state, _): inner_observations=tf.convert_to_tensor([1., 3., 5., 7., 9.]), inner_initial_state_prior=lambda _, params: normal.Normal([0.], 1.), initial_parameter_prior=deterministic.Deterministic(0.), - inner_transition_fn=lambda params: (lambda _, state: normal.Normal(state, 1.)), - inner_observation_fn=lambda params: (lambda _, state: normal.Normal(state, 1.)), + inner_transition_fn=lambda params: (lambda _, state: + normal.Normal(state, 1.)), + inner_observation_fn=lambda params: (lambda _, state: + normal.Normal(state, 1.)), num_inner_particles=1024, num_outer_particles=1, outer_trace_fn=trace_fn, @@ -766,15 +778,21 @@ def rejuvenation_criterion(step, state): tf.equal(tf.math.mod(step, tf.constant(3)), tf.constant(0)), tf.not_equal(state.extra[0], tf.constant(0)) ) - return tf.cond(cond, lambda: tf.constant(True), lambda: tf.constant(False)) + return tf.cond(cond, lambda: tf.constant(True), + lambda: tf.constant(False)) - (parameters, weight_parameters, inner_particles, inner_log_weights, lp) = self.evaluate( + (parameters, weight_parameters, + inner_particles, inner_log_weights, lp) = self.evaluate( particle_filter.smc_squared( inner_observations=tf.convert_to_tensor([1., 3., 5., 7., 9.]), initial_parameter_prior=deterministic.Deterministic(0.), - inner_initial_state_prior=lambda _, params: normal.Normal([0.] * num_outer_particles, 1.), - inner_transition_fn=lambda params: (lambda _, state: normal.Normal(state, 10.)), - inner_observation_fn=lambda params: (lambda _, state: normal.Normal(state, 0.1)), + inner_initial_state_prior=lambda _, params: normal.Normal( + [0.] * num_outer_particles, 1. + ), + inner_transition_fn=lambda params: + (lambda _, state: normal.Normal(state, 10.)), + inner_observation_fn=lambda params: + (lambda _, state: normal.Normal(state, 0.1)), num_inner_particles=num_inner_particles, num_outer_particles=num_outer_particles, outer_rejuvenation_criterion_fn=rejuvenation_criterion,