
particle filter update #1745

Closed · wants to merge 0 commits into from

Conversation

@aleslamitz (Contributor)

Updating particle filter, adding rejuvenation and extra

google-cla bot commented Aug 3, 2023

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@jburnim (Member) left a comment:

Thanks!

Could you please also update weighted_resampling_test to include some tests for when particles_dim > 0?
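For example, such a test might look like the following rough sketch (the helper names `resample`, `resample_systematic`, and `test_util.test_seed` follow the existing test file; the shapes and assertion tolerances here are placeholders, not the final test):

    def test_resample_with_particles_dim_one(self):
      # Two identical batch rows; particles are indexed along axis 1.
      particles = np.tile(
          np.linspace(0., 500., num=2500, dtype=np.float32), (2, 1))
      log_weights = np.zeros([2, 2500], dtype=np.float32)
      new_particles, _, new_log_weights = self.evaluate(resample(
          particles, log_weights,
          resample_fn=resample_systematic,
          particles_dim=1,
          seed=test_util.test_seed()))
      # Resampling with uniform weights should preserve the per-row mean.
      self.assertAllClose(
          np.mean(new_particles, axis=1), [250., 250.], atol=10.)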

@@ -24,6 +26,8 @@
 from tensorflow_probability.python.internal import prefer_static as ps
 from tensorflow_probability.python.internal import samplers
 from tensorflow_probability.python.mcmc.internal import util as mcmc_util
+from tensorflow_probability.python.mcmc.internal.util import choose
Member:

Please remove this import. Per https://google.github.io/styleguide/pyguide.html#22-imports , imports should be for packages instead of individual functions or classes.

proposed_particles,
log_weights
) = tf.nest.map_structure(
lambda r, p: choose(do_rejuvenation, r, p),
Member:

Please use mcmc_util.choose instead of directly importing the function choose.
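For instance, using the module alias already imported in this file (the toy inputs below are assumed, for illustration only):

    import tensorflow as tf
    from tensorflow_probability.python.mcmc.internal import util as mcmc_util

    # Toy inputs (assumed values):
    do_rejuvenation = tf.constant(True)
    rejuvenated = {'x': tf.ones([3])}
    proposed = {'x': tf.zeros([3])}

    # Preferred style: call `choose` through the module alias instead of
    # importing the bare function.
    result = tf.nest.map_structure(
        lambda r, p: mcmc_util.choose(do_rejuvenation, r, p),
        rejuvenated, proposed)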

seed=resample_seed)
(resampled_particles,
resample_indices,
log_weights) = tf.nest.map_structure(
-lambda r, p: tf.where(do_resample, r, p),
+lambda r, p: choose(do_resample, r, p),
Member:

Please use mcmc_util.choose instead of directly importing the function choose.

return extra


def identity(state, new_particles, new_indices, log_weights, extra, step):
Member:

It looks like this function is unused. Please remove.

@@ -115,19 +118,35 @@ def _dummy_indices_like(indices):
indices_shape)


-def log_ess_from_log_weights(log_weights):
+def log_ess_from_log_weights(log_weights, particles_dim=0):
"""Computes log-ESS estimate from log-weights along axis=0."""
Member:

"along axis=particles_dim"

@@ -49,6 +49,9 @@ class WeightedParticles(collections.namedtuple(
`exp(reduce_logsumexp(log_weights, axis=0)) == 1.`. These must be used in
conjunction with `particles` to compute expectations under the target
distribution.
extra: a (structure of) Tensor(s) each of shape
`concat([[num_particles, b1, ..., bN], event_shape])`, where `event_shape`
Member:

I think num_particles shouldn't be in the shape here. (And could the comment mention that this represents global state of the sampling process that is not associated with individual particles?)
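For example, the revised docstring entry might read (wording assumed, following the reviewer's suggestion):

    extra: a (structure of) Tensor(s) each of shape
      `concat([[b1, ..., bN], event_shape])`, representing global state of
      the sampling process that is not associated with individual particles.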

(resampled_particles, resample_indices, weights_after_resampling),
(state.particles, _dummy_indices_like(resample_indices),
normalized_log_weights))

proposed_extra = self.extra_fn(
Member:

Please remove the extra_fn argument and property.

The propose_and_update_log_weights_fn is responsible for updating extra in WeightedParticles.

And proposed_state.extra is unaffected by resampling, so below we can just use proposed_state.extra.

(Alternatively, we could skip adding extra in this PR, as it turns out we don't yet need it for SMC^2.)

@jburnim (Member) left a comment:

This is looking great!

I would like to chat about how we want to handle extra and rejuvenation when we next meet. Until then, could we simplify this first PR to just the following changes:

  1. Handling particles_dim in weighted_resampling.py, sequential_monte_carlo_kernel.py, and particle_filter.py.

  2. Introducing the sequential_monte_carlo function in particle_filter.py.

(In particular, we would leave out adding extra to WeightedParticles, as well as any new code for doing rejuvenation in particle_filter.py.)

(We could still replace the rejuvenation_kernel_fn=None argument with rejuvenation_fn=None and rejuvenation_criterion_fn=None, but let's leave these arguments unused until the next PR.)

I think we could get the simplified PR in quickly, and then be able to focus on rejuvenation, extra, and SMC^2.

if trace_criterion_fn is never_trace:
  # Return results from just the final step.
  traced_results = trace_fn(*final_seed_state_result[1:])

def sample_at_dim(initial_state_prior, dim, num_samples, seed=None):
Member:

Can you please remove sample_at_dim from this PR? Let's put the definition in a later PR that also adds code that calls sample_at_dim.

@@ -33,7 +33,7 @@
]


-def resample(particles, log_weights, resample_fn, target_log_weights=None,
+def resample(particles, log_weights, resample_fn, target_log_weights=None, particles_dim=0,
Member:

Could you please reformat this line to be under 80 characters long?

The Google Python style guide imposes a maximum line length of 80 characters -- https://google.github.io/styleguide/pyguide.html#32-line-length .

@@ -293,7 +302,7 @@ def resample_systematic(log_probs, event_size, sample_shape,
"""
with tf.name_scope(name or 'resample_systematic') as name:
log_probs = tf.convert_to_tensor(log_probs, dtype_hint=tf.float32)
-log_probs = dist_util.move_dimension(log_probs, source_idx=0, dest_idx=-1)
+log_probs = dist_util.move_dimension(log_probs, source_idx=particles_dim, dest_idx=-1)
Member:

Could you please reformat this line to be under 80 characters long?

The Google Python style guide imposes a maximum line length of 80 characters -- https://google.github.io/styleguide/pyguide.html#32-line-length .
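One way to wrap the call so each line stays under 80 characters (a sketch; the surrounding values are assumed for illustration):

    import tensorflow as tf
    from tensorflow_probability.python.internal import distribution_util as dist_util

    log_probs = tf.math.log(tf.fill([2, 5], 0.2))  # assumed [batch, particles]
    particles_dim = 1

    # Wrapped call, under 80 characters per line:
    log_probs = dist_util.move_dimension(
        log_probs, source_idx=particles_dim, dest_idx=-1)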

@@ -310,7 +319,7 @@ def resample_systematic(log_probs, event_size, sample_shape,
log_points = tf.broadcast_to(tf.math.log(even_spacing), points_shape)

resampled = _resample_using_log_points(log_probs, sample_shape, log_points)
-return dist_util.move_dimension(resampled, source_idx=-1, dest_idx=0)
+return dist_util.move_dimension(resampled, source_idx=-1, dest_idx=particles_dim)
Member:

Could you please reformat this line to be under 80 characters long?

The Google Python style guide imposes a maximum line length of 80 characters -- https://google.github.io/styleguide/pyguide.html#32-line-length .

@@ -241,7 +250,7 @@ def resample_independent(log_probs, event_size, sample_shape,


# TODO(b/153689734): rewrite so as not to use `move_dimension`.
-def resample_systematic(log_probs, event_size, sample_shape,
+def resample_systematic(log_probs, event_size, sample_shape, particles_dim=0,
Member:

Please update the docstring to include an entry for particles_dim.
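A possible docstring entry (wording assumed):

    particles_dim: Python `int` dimension of `log_probs` along which the
      particles to be resampled are indexed. Default value: `0`.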

[20., 20.],
axis=1,
atol=1e-2)
self.assertAllClose(
Member:

This might be more clear as:

    self.assertAllClose(
        tf.reduce_sum(tf.nn.softmax(new_log_weights) * new_particles, axis=1),
        [20., 20.],
        atol=1e-2)

axis=1,
atol=1e-2)

self.assertAllClose(
Member:

This might be more clear as:

    self.assertAllClose(
        tf.reduce_sum(tf.nn.softmax(new_log_weights) * new_particles, axis=1),
        [30., 30.],
        atol=1.)

@@ -299,6 +299,46 @@ def resample_with_target_distribution(self):
tf.reduce_sum(tf.nn.softmax(new_log_weights) * new_particles),
30., atol=1.)

def test_with_target_distribution_dim_one(self):
particles = np.linspace(0., 500., num=2500, dtype=np.float32)
stacked_particles = np.stack([particles, particles], axis=0)
Member:

It may be worth making the two different batch dimensions contain different particles -- e.g., something like:

stacked_particles = np.stack([
    np.linspace(0., 500., num=2500, dtype=np.float32),
    np.linspace(0.17, 433., num=2500, dtype=np.float32),
], axis=0)

stacked_log_weights,
resample_fn=resample_systematic,
particles_dim=1,
target_log_weights=poisson.Poisson(30).log_prob(particles),
Member:

particles -> stacked_particles



def _particle_filter_initial_weighted_particles(observations,
                                                observation_fn,
                                                initial_state_prior,
                                                initial_state_proposal,
                                                num_particles,
                                                num_inner_particles,
Member:

I think this argument should still be called num_particles.

Callers of this function might have different particle dimensions for different nested SMCs / particle filters -- e.g., "inner" and "outer". But from the perspective of this function, I don't think the concept of "inner particles" makes sense -- we're just creating the initial state for some particle filter that has num_particles particles, indexed by dimension particles_dim.

# Return results from just the final step.
traced_results = trace_fn(*final_seed_state_result[1:])

return traced_results
Member:

This line needs to be un-indented -- so we return traced_results whether or not trace_criterion_fn is never_trace.
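That is, the tail of the function would read (an indentation sketch based on the context above):

    if trace_criterion_fn is never_trace:
      # Return results from just the final step.
      traced_results = trace_fn(*final_seed_state_result[1:])

    return traced_results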


@@ -145,6 +148,7 @@ def __init__(self,
                resample_fn=weighted_resampling.resample_systematic,
                resample_criterion_fn=ess_below_threshold,
                unbiased_gradients=True,
+               particles_dim=0,
Member:

It looks like the bootstrap method also needs to be updated to handle particles_dim.

That is, instead of initializing incremental_log_marginal_likelihood and accumulated_log_marginal_likelihood to:

batch_zeros = tf.zeros(
    ps.shape(init_state.log_weights)[1:],
    dtype=init_state.log_weights.dtype)

we need to initialize them to something like:

particles_shape = ps.shape(init_state.log_weights)
weights_shape = ps.concat([
    particles_shape[:self.particles_dim],
    particles_shape[self.particles_dim+1:]
], axis=0)
batch_zeros = tf.zeros(
    weights_shape, dtype=init_state.log_weights.dtype)

To verify that we have done this correctly, could you also please add a version of test_batch_of_filters in particle_filter_test.py with a particles_dim of 1 or 2?

@@ -474,7 +585,8 @@ def propose_and_update_log_weights_fn(step, state, seed=None):
particles=proposed_particles,
log_weights=log_weights + _compute_observation_log_weights(
Member:

It looks like _compute_observation_log_weights needs to be updated to add a particles_dim argument.

And that function will have to do something like:

observation = tf.expand_dims(observation, axis=particles_dim)

so that the computed log-weights have the right shape.
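A minimal sketch of how the helper might change (the existing function's full signature and observation-gathering details are simplified/assumed here):

    import tensorflow as tf

    def _compute_observation_log_weights(step, particles, observations,
                                         observation_fn, particles_dim=0):
      """Sketch: per-particle log-weights for the observation at `step`."""
      observation = tf.nest.map_structure(
          lambda x: tf.gather(x, step), observations)
      if particles_dim != 0:
        # Insert a length-1 axis at particles_dim so that log_prob broadcasts
        # the observation against the particles axis.
        observation = tf.nest.map_structure(
            lambda x: tf.expand_dims(x, axis=particles_dim), observation)
      return observation_fn(step, particles).log_prob(observation)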

true_initial_positions)

# Set particles dimension in position 1
observed_positions_transposed = np.transpose(observed_positions, (1, 0, 2))
Member:

I don't think this should be necessary. The observations passed to particle_filter should have shape [40, 3, 2] -- i.e., [num_steps] + batch_shape.

@@ -344,7 +461,8 @@ def particle_filter(observations,

init_seed, loop_seed = samplers.split_seed(seed, salt='particle_filter')
with tf.name_scope(name or 'particle_filter'):
num_observation_steps = ps.size0(tf.nest.flatten(observations)[0])
Member:

I think this should still be getting the size of the 0-th dimension. I don't believe the shape of the observations is affected by particles_dim.

-lambda x, step=step: tf.gather(x, observation_idx), observations)
+lambda x, step=step: tf.gather(x,
+                               observation_idx,
+                               axis=particles_dim),
Member:

I think this line should be unchanged -- because the shape of observations is not affected by particles_dim.

observations)

# Expand dimensions of observation at the particles_dim
observation = tf.expand_dims(observation, axis=particles_dim)
Member:

NOTE: The goal of this expand_dims is something like:

  • observations has shape like [num_steps] + batch_shape.

  • so observation has shape like batch_shape.

  • particles has shape like batch_shape_1 + [num_particles] + batch_shape_2 + event_shape

  • observation_fn(step, particles) is a distribution with batch shape like batch_shape_1 + [num_particles] + batch_shape_2 and event shape event_shape.

  • When particles_dim=0, we can just apply observation_fn(step, particles).log_prob(observations), because the log_prob is required to broadcast the observations of shape batch_shape + event_shape up to [num_particles] + batch_shape + event_shape.

  • But when particles_dim>0, this won't work. Instead, we have to expand_dims on observations so it has shape like batch_shape_1 + [1] + batch_shape_2 + event_shape. This will allow the log_prob to broadcast observations up to shape batch_shape_1 + [num_particles] + batch_shape_2 + event_shape, so that the log_prob can run and will return something with shape batch_shape_1 + [num_particles] + batch_shape_2.
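A small shape check illustrating that last bullet (toy shapes, assumed values):

    import tensorflow as tf

    batch_shape_1, batch_shape_2 = [3], [2]
    num_particles, event_shape = 4, [5]
    particles_dim = 1  # the particles axis sits between the two batch shapes

    observation = tf.zeros(batch_shape_1 + batch_shape_2 + event_shape)  # [3, 2, 5]
    particles_like = tf.zeros(
        batch_shape_1 + [num_particles] + batch_shape_2 + event_shape)   # [3, 4, 2, 5]

    # [3, 2, 5] does not broadcast against [3, 4, 2, 5], but after inserting
    # a length-1 axis at particles_dim it does:
    observation = tf.expand_dims(observation, axis=particles_dim)        # [3, 1, 2, 5]
    print((observation + particles_like).shape)                          # [3, 4, 2, 5]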

jburnim (Member) commented Nov 21, 2023

It looks like there are merge conflicts we need to resolve before we can submit this PR.

To resolve these merge conflicts, I recommend either:

  1. Create a brand-new fork and a new branch in that fork, copy over the six modified files (particle_filter.py, particle_filter_test.py, sequential_monte_carlo_kernel.py, sequential_monte_carlo_kernel_test.py, weighted_resampling.py, and weighted_resampling_test.py), and open a new PR.

  2. Or:

    a. First get the main branch of your fork to match the main branch of https://github.com/tensorflow/probability , by following the steps under Option B at https://scribu.net/blog/resetting-your-github-fork.html .

    b. Then get this branch in a clean state. It may work to merge the updated main branch into this branch. But it might also be easier to delete and recreate the branch with a single commit containing your changes to the six modified files.

aleslamitz closed this Nov 21, 2023
@aleslamitz (Contributor, Author):

#1771

jburnim mentioned this pull request Dec 7, 2023