particle filter update #1745
Conversation
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up-to-date status, view the checks section at the bottom of the pull request.
Thanks! Could you please also update `weighted_resampling_test` to include some tests for when `particles_dim > 0`?
```diff
@@ -24,6 +26,8 @@
 from tensorflow_probability.python.internal import prefer_static as ps
 from tensorflow_probability.python.internal import samplers
 from tensorflow_probability.python.mcmc.internal import util as mcmc_util
+from tensorflow_probability.python.mcmc.internal.util import choose
```
Please remove this import. Per https://google.github.io/styleguide/pyguide.html#22-imports, imports should be for packages instead of individual functions or classes.
```
 proposed_particles,
 log_weights
) = tf.nest.map_structure(
    lambda r, p: choose(do_rejuvenation, r, p),
```
Please use `mcmc_util.choose` instead of directly importing the function `choose`.
```diff
     seed=resample_seed)
 (resampled_particles,
  resample_indices,
  log_weights) = tf.nest.map_structure(
-    lambda r, p: tf.where(do_resample, r, p),
+    lambda r, p: choose(do_resample, r, p),
```
Please use `mcmc_util.choose` instead of directly importing the function `choose`.
```
  return extra


def identity(state, new_particles, new_indices, log_weights, extra, step):
```
It looks like this function is unused. Please remove.
```diff
@@ -115,19 +118,35 @@ def _dummy_indices_like(indices):
                      indices_shape)


-def log_ess_from_log_weights(log_weights):
+def log_ess_from_log_weights(log_weights, particles_dim=0):
   """Computes log-ESS estimate from log-weights along axis=0."""
```
"along axis=particles_dim"
```diff
@@ -49,6 +49,9 @@ class WeightedParticles(collections.namedtuple(
     `exp(reduce_logsumexp(log_weights, axis=0)) == 1.`. These must be used in
     conjunction with `particles` to compute expectations under the target
     distribution.
+  extra: a (structure of) Tensor(s) each of shape
+    `concat([[num_particles, b1, ..., bN], event_shape])`, where `event_shape`
```
I think `num_particles` shouldn't be in the shape here. (And could the comment mention that this represents global state of the sampling process that is not associated with individual particles?)
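A possible rewording along those lines (an illustrative suggestion only, not the PR's final text):

```
  extra: a (structure of) Tensor(s) each of shape
    `concat([[b1, ..., bN], event_shape])`, representing global state of the
    sampling process that is not associated with any individual particle.
```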
```
        (resampled_particles, resample_indices, weights_after_resampling),
        (state.particles, _dummy_indices_like(resample_indices),
         normalized_log_weights))

    proposed_extra = self.extra_fn(
```
Please remove the `extra_fn` argument and property. The `propose_and_update_log_weights_fn` is responsible for updating `extra` in `WeightedParticles`. And `proposed_state.extra` is unaffected by resampling, so below we can just use `proposed_state.extra`.

(Alternatively, we could skip adding `extra` in this PR, as it turns out we don't yet need it for SMC^2.)
This is looking great!

I would like to chat about how we want to handle `extra` and rejuvenation when we next meet. Until then, could we simplify this first PR to just make the following changes?

- Handling `particles_dim` in weighted_resampling.py, sequential_monte_carlo_kernel.py, and particle_filter.py.
- Introducing the `sequential_monte_carlo` function in `particle_filter.py`.

(In particular, we would leave out adding `extra` to `WeightedParticles`, as well as any new code for doing rejuvenation in `particle_filter.py`.)

(We could still replace the `rejuvenation_kernel_fn=None` argument with `rejuvenation_fn=None` and `rejuvenation_criterion_fn=None`, but let's leave these arguments unused until the next PR.)

I think we could get the simplified PR in quickly, and then be able to focus on rejuvenation, extra, and SMC^2.
```
  if trace_criterion_fn is never_trace:
    # Return results from just the final step.
    traced_results = trace_fn(*final_seed_state_result[1:])

def sample_at_dim(initial_state_prior, dim, num_samples, seed=None):
```
Can you please remove `sample_at_dim` from this PR? Let's put the definition in a later PR that also adds code that calls `sample_at_dim`.
```diff
@@ -33,7 +33,7 @@
 ]


-def resample(particles, log_weights, resample_fn, target_log_weights=None,
+def resample(particles, log_weights, resample_fn, target_log_weights=None, particles_dim=0,
```
Could you please reformat this line to be under 80 characters long? The Google Python style guide imposes a maximum line length of 80 characters -- https://google.github.io/styleguide/pyguide.html#32-line-length.
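For illustration, one way the signature could be wrapped (the quoted diff truncates the parameter list, so the trailing `seed=None` here is an assumption):

```python
def resample(particles, log_weights, resample_fn,
             target_log_weights=None, particles_dim=0,
             seed=None):  # trailing parameters assumed; the diff is truncated
  ...
```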
```diff
@@ -293,7 +302,7 @@ def resample_systematic(log_probs, event_size, sample_shape,
   """
   with tf.name_scope(name or 'resample_systematic') as name:
     log_probs = tf.convert_to_tensor(log_probs, dtype_hint=tf.float32)
-    log_probs = dist_util.move_dimension(log_probs, source_idx=0, dest_idx=-1)
+    log_probs = dist_util.move_dimension(log_probs, source_idx=particles_dim, dest_idx=-1)
```
Could you please reformat this line to be under 80 characters long? The Google Python style guide imposes a maximum line length of 80 characters -- https://google.github.io/styleguide/pyguide.html#32-line-length.
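For instance, the assignment could be split across two lines:

```python
    log_probs = dist_util.move_dimension(
        log_probs, source_idx=particles_dim, dest_idx=-1)
```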
```diff
@@ -310,7 +319,7 @@ def resample_systematic(log_probs, event_size, sample_shape,
     log_points = tf.broadcast_to(tf.math.log(even_spacing), points_shape)

     resampled = _resample_using_log_points(log_probs, sample_shape, log_points)
-    return dist_util.move_dimension(resampled, source_idx=-1, dest_idx=0)
+    return dist_util.move_dimension(resampled, source_idx=-1, dest_idx=particles_dim)
```
Could you please reformat this line to be under 80 characters long? The Google Python style guide imposes a maximum line length of 80 characters -- https://google.github.io/styleguide/pyguide.html#32-line-length.
```diff
@@ -241,7 +250,7 @@ def resample_independent(log_probs, event_size, sample_shape,


 # TODO(b/153689734): rewrite so as not to use `move_dimension`.
-def resample_systematic(log_probs, event_size, sample_shape,
+def resample_systematic(log_probs, event_size, sample_shape, particles_dim=0,
```
Please update the docstring to include an entry for `particles_dim`.
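A sketch of such an entry, in the style of the neighboring Args descriptions (the wording is only a suggestion):

```
  particles_dim: Python `int` giving the axis of `log_probs` that indexes the
    particles to be resampled.
    Default value: `0`.
```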
```
        [20., 20.],
        axis=1,
        atol=1e-2)
    self.assertAllClose(
```
This might be more clear as:

```python
self.assertAllClose(
    tf.reduce_sum(tf.nn.softmax(new_log_weights) * new_particles, axis=1),
    [20., 20.],
    atol=1e-2)
```
```
        axis=1,
        atol=1e-2)

    self.assertAllClose(
```
This might be more clear as:

```python
self.assertAllClose(
    tf.reduce_sum(tf.nn.softmax(new_log_weights) * new_particles, axis=1),
    [30., 30.],
    atol=1.)
```
```diff
@@ -299,6 +299,46 @@ def resample_with_target_distribution(self):
         tf.reduce_sum(tf.nn.softmax(new_log_weights) * new_particles),
         30., atol=1.)

+  def test_with_target_distribution_dim_one(self):
+    particles = np.linspace(0., 500., num=2500, dtype=np.float32)
+    stacked_particles = np.stack([particles, particles], axis=0)
```
It may be worth making the two different batch dimensions contain different particles -- e.g., something like:

```python
stacked_particles = np.stack([
    np.linspace(0., 500., num=2500, dtype=np.float32),
    np.linspace(0.17, 433., num=2500, dtype=np.float32),
], axis=0)
```
```
        stacked_log_weights,
        resample_fn=resample_systematic,
        particles_dim=1,
        target_log_weights=poisson.Poisson(30).log_prob(particles),
```
`particles` -> `stacked_particles`
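In other words, a sketch of the corrected call (the surrounding arguments are taken from the diff above; the return unpacking and `seed` argument are assumptions based on the existing `resample` API):

```python
new_particles, _, new_log_weights = resample(
    stacked_particles,
    stacked_log_weights,
    resample_fn=resample_systematic,
    particles_dim=1,
    target_log_weights=poisson.Poisson(30).log_prob(stacked_particles),
    seed=seed)
```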
```diff
 def _particle_filter_initial_weighted_particles(observations,
                                                 observation_fn,
                                                 initial_state_prior,
                                                 initial_state_proposal,
-                                                num_particles,
+                                                num_inner_particles,
```
I think this argument should still be called `num_particles`. Callers of this function might have different particle dimensions for different nested SMCs / particle filters -- e.g., "inner" and "outer". But from the perspective of this function, I don't think the concept of "inner particles" makes sense -- we're just creating the initial state for some particle filter that has `num_particles` particles, indexed by dimension `particles_dim`.
```
    # Return results from just the final step.
    traced_results = trace_fn(*final_seed_state_result[1:])

    return traced_results
```
This line needs to be un-indented -- so we return `traced_results` whether or not `trace_criterion_fn is never_trace`.
```diff
@@ -145,6 +148,7 @@ def __init__(self,
                resample_fn=weighted_resampling.resample_systematic,
                resample_criterion_fn=ess_below_threshold,
                unbiased_gradients=True,
+               particles_dim=0,
```
It looks like the `bootstrap` method also needs to be updated to handle `particles_dim`. That is, instead of initializing `incremental_log_marginal_likelihood` and `accumulated_log_marginal_likelihood` to:

```python
batch_zeros = tf.zeros(
    ps.shape(init_state.log_weights)[1:],
    dtype=init_state.log_weights.dtype)
```

we need to initialize them to something like:

```python
particles_shape = ps.shape(init_state.log_weights)
weights_shape = ps.concat([
    particles_shape[:self.particles_dim],
    particles_shape[self.particles_dim+1:]
], axis=0)
batch_zeros = tf.zeros(
    weights_shape, dtype=init_state.log_weights.dtype)
```

To verify that we have done this correctly, could you also please add a version of `test_batch_of_filters` in `particle_filter_test.py` with a `particles_dim` of 1 or 2?
```diff
@@ -474,7 +585,8 @@ def propose_and_update_log_weights_fn(step, state, seed=None):
         particles=proposed_particles,
         log_weights=log_weights + _compute_observation_log_weights(
```
It looks like `_compute_observation_log_weights` needs to be updated to add a `particles_dim` argument. And that function will have to do something like:

```python
observation = tf.expand_dims(observation, axis=particles_dim)
```

so that the computed log-weights have the right shape.
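A rough sketch of the updated helper, modeled on the existing function (the exact signature and the `step_has_observation` logic here are assumptions, not the PR's final code):

```python
import tensorflow as tf
from tensorflow_probability.python.internal import prefer_static as ps


def _compute_observation_log_weights(step, particles, observations,
                                     observation_fn, particles_dim=0,
                                     num_transitions_per_observation=1):
  """Sketch: per-particle log-weights from the observation at `step`."""
  with tf.name_scope('compute_observation_log_weights'):
    step_has_observation = ps.equal(step % num_transitions_per_observation, 0)
    observation_idx = step // num_transitions_per_observation
    observation = tf.nest.map_structure(
        lambda x: tf.gather(x, observation_idx), observations)
    # Insert a length-1 axis so the observation broadcasts against the
    # particles axis; with particles_dim=0 this matches the old behavior.
    observation = tf.nest.map_structure(
        lambda x: tf.expand_dims(x, axis=particles_dim), observation)
    log_weights = observation_fn(step, particles).log_prob(observation)
    return tf.where(step_has_observation,
                    log_weights,
                    tf.zeros_like(log_weights))
```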
```
    true_initial_positions)

  # Set particles dimension in position 1
  observed_positions_transposed = np.transpose(observed_positions, (1, 0, 2))
```
I don't think this should be necessary. The observations passed to `particle_filter` should have shape `[40, 3, 2]` -- i.e., `[num_steps] + batch_shape + event_shape`.
```diff
@@ -344,7 +461,8 @@ def particle_filter(observations,

   init_seed, loop_seed = samplers.split_seed(seed, salt='particle_filter')
   with tf.name_scope(name or 'particle_filter'):
     num_observation_steps = ps.size0(tf.nest.flatten(observations)[0])
```
I think this should still be getting the size of the 0-th dimension. I don't believe the shape of the observations is affected by particles_dim.
```diff
-        lambda x, step=step: tf.gather(x, observation_idx), observations)
+        lambda x, step=step: tf.gather(x,
+                                       observation_idx,
+                                       axis=particles_dim),
```
I think this line should be unchanged -- because the shape of `observations` is not affected by `particles_dim`.
```
        observations)

    # Expand dimensions of observation at the particles_dim
    observation = tf.expand_dims(observation, axis=particles_dim)
```
NOTE: The goal of this expand_dims is something like:

- `observations` has shape like `[num_steps] + batch_shape + event_shape`.
- so `observation` (the value gathered for the current step) has shape like `batch_shape + event_shape`.
- `particles` has shape like `batch_shape_1 + [num_particles] + batch_shape_2 + event_shape`.
- `observation_fn(step, particles)` is a distribution with batch shape like `batch_shape_1 + [num_particles] + batch_shape_2` and event shape `event_shape`.
- When particles_dim=0, we can just apply `observation_fn(step, particles).log_prob(observation)`, because the `log_prob` is required to broadcast an observation of shape `batch_shape + event_shape` up to `[num_particles] + batch_shape + event_shape`.
- But when particles_dim>0, this won't work. Instead, we have to expand_dims on `observation` so it has shape like `batch_shape_1 + [1] + batch_shape_2 + event_shape`. This will allow the `log_prob` to broadcast `observation` up to shape `batch_shape_1 + [num_particles] + batch_shape_2 + event_shape`, so that the `log_prob` can run and will return something with shape `batch_shape_1 + [num_particles] + batch_shape_2`.
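To make the broadcasting concrete, here is a toy illustration (all shapes invented for this example: `batch_shape_1 = [3]`, `num_particles = 5`, `batch_shape_2 = []`, `event_shape = [2]`, `particles_dim = 1`):

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# particles: batch_shape_1 + [num_particles] + event_shape = [3, 5, 2].
particles = tf.zeros([3, 5, 2])
# observation_fn(step, particles): batch shape [3, 5], event shape [2].
dist = tfd.Independent(tfd.Normal(loc=particles, scale=1.),
                       reinterpreted_batch_ndims=1)

observation = tf.ones([3, 2])  # batch_shape_1 + event_shape
# Without expand_dims, log_prob would have to broadcast [3] against the
# batch shape [3, 5], which fails. With it, [3, 1] broadcasts to [3, 5]:
observation = tf.expand_dims(observation, axis=1)  # particles_dim=1 -> [3, 1, 2]
log_weights = dist.log_prob(observation)
print(log_weights.shape)  # (3, 5): one log-weight per batch member and particle
```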
It looks like there are merge conflicts we need to resolve before we can submit this PR. To resolve these merge conflicts, I recommend either:
Updating particle filter, adding rejuvenation and extra