
SMC2 #1775

Merged
merged 24 commits into tensorflow:main on Feb 23, 2024
Conversation

aleslamitz (Contributor)

@brianwa84 brianwa84 requested a review from jburnim December 13, 2023 19:33
@jburnim (Member) left a comment:

This is looking great, thanks!

I'm still reviewing, but I left some initial comments.

@@ -670,6 +1122,8 @@ def _compute_observation_log_weights(step,
      observation = tf.nest.map_structure(
          lambda x, step=step: tf.gather(x, observation_idx), observations)

+     if particles_dim == 1:
+       observation = tf.expand_dims(observation, axis=0)
jburnim (Member):

I don't think this added code should be necessary. The lines right below this should already handle any necessary expand_dims for particles_dim.
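A minimal NumPy sketch of the shape issue under discussion (NumPy follows the same broadcasting rules as TF; the concrete shapes here are assumptions for illustration, not taken from the PR). If the code below the diff already inserts the particle axis, an extra expand_dims on axis 0 misaligns the batch axis:

```python
import numpy as np

# Assumed shapes: observations carry a batch axis, particles carry an
# extra particle axis at dim 1 (particles_dim=1).
observation = np.zeros([4])          # [batch_size]
particles = np.zeros([4, 1000])      # [batch_size, num_particles]

# The existing handling inserts the particle axis once:
obs_expanded = np.expand_dims(observation, axis=1)   # shape [4, 1]
log_weights = particles + obs_expanded               # broadcasts to [4, 1000]

# A prior expand_dims(observation, axis=0) would instead yield shape
# [1, 4], which no longer lines up with [4, 1000] and shifts the batch
# axis, so the extra expansion breaks alignment rather than helping.
assert log_weights.shape == (4, 1000)
```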

aleslamitz (Contributor, author):

I think it is necessary due to the dimension mismatch.

jburnim (Member):

With these two lines included, ParticleFilterTestFloat32.test_batch_of_filters_particles_dim_1 fails when I run it. If I delete these two lines, then that test passes.


inner_initial_state_prior=lambda _, params:
    mvn_diag.MultivariateNormalDiag(
        loc=loc, scale_diag=scale_diag
    ),
jburnim (Member):

I think this should be:

inner_initial_state_prior=lambda _, params: mvn_diag.MultivariateNormalDiag(
    loc=tf.broadcast_to([0., 0.], params.shape + [2]),
    scale_diag=tf.broadcast_to([0.01, 0.01], params.shape + [2]))

Because:

  • params.shape is just the batch shape (i.e., [num_outer_particles])
  • This gives inner_initial_state_prior(0, params) a batch_shape of [num_outer_particles] and the correct event_shape of [2]. (The SMC^2 code is then responsible for drawing num_inner_particles samples from this distribution and re-arranging the dimensions to get initial particles of shape [num_outer_particles, num_inner_particles, 2].)
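A NumPy sketch of the shapes this suggestion produces (NumPy's broadcast_to behaves like tf.broadcast_to here; num_outer_particles=8 is an assumed value for illustration):

```python
import numpy as np

num_outer_particles = 8
params = np.zeros([num_outer_particles])   # batch of parameter particles

# Broadcasting a [2]-vector against params.shape + [2] yields one
# 2-dimensional event location per outer particle:
loc = np.broadcast_to([0., 0.], list(params.shape) + [2])
scale_diag = np.broadcast_to([0.01, 0.01], list(params.shape) + [2])

# The resulting MultivariateNormalDiag would then have
# batch_shape=[num_outer_particles] and event_shape=[2].
assert loc.shape == (num_outer_particles, 2)
assert scale_diag.shape == (num_outer_particles, 2)
```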

broadcasted_params = tf.broadcast_to(reshaped_params,
                                     previous_state.shape)
reshaped_dist = independent.Independent(
    normal.Normal(previous_state + broadcasted_params + 1, 0.1),
jburnim (Member):

params has shape [num_outer_particles] and previous_state has shape [num_outer_particles, num_inner_particles, 2], so here it should be sufficient to just do:

previous_state + params[..., tf.newaxis, tf.newaxis] + 1
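A NumPy sketch of why the newaxis form suffices (same broadcasting rules as TF; the particle counts are assumed for illustration):

```python
import numpy as np

num_outer, num_inner = 8, 100
params = np.arange(num_outer, dtype=np.float64)        # [num_outer]
previous_state = np.zeros([num_outer, num_inner, 2])   # [num_outer, num_inner, 2]

# params[..., newaxis, newaxis] has shape [num_outer, 1, 1], which
# broadcasts against [num_outer, num_inner, 2] with no explicit
# reshape/broadcast_to needed:
shifted = previous_state + params[..., np.newaxis, np.newaxis] + 1

assert shifted.shape == (num_outer, num_inner, 2)
assert shifted[3, 0, 0] == 4.0   # params value 3 plus the constant shift 1
```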

    tf.equal(tf.math.mod(step, tf.constant(2)), tf.constant(0)),
    tf.not_equal(state.extra[0], tf.constant(0))
)
return tf.cond(cond, lambda: tf.constant(True),
jburnim (Member):

It should work to just return cond, instead of using tf.cond(cond, ...).
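A sketch of the point, using NumPy booleans as a stand-in for TF tensors (the function name and arguments are hypothetical): when both tf.cond branches merely re-emit True/False, the predicate itself already is the result.

```python
import numpy as np

def resample_criterion(step, extra0):
    # Original pattern: tf.cond(cond, lambda: tf.constant(True),
    #                                 lambda: tf.constant(False)).
    # Since both branches only restate the predicate's value, returning
    # cond directly is equivalent and avoids the conditional entirely.
    cond = np.logical_and(step % 2 == 0, extra0 != 0)
    return cond

assert resample_criterion(2, 1) == True    # even step, nonzero extra
assert resample_criterion(3, 1) == False   # odd step
assert resample_criterion(2, 0) == False   # zero extra
```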

self.evaluate(
    tf.reduce_sum(tf.exp(log_weights) *
                  particles['position'], axis=2)),
observed_positions,
jburnim (Member):

I would hope this assertion would pass if we changed this line to observed_positions[..., tf.newaxis, tf.newaxis].

But if that didn't work, we could also do: tf.broadcast_to(observed_positions, [num_timesteps] + batch_shape).
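A NumPy sketch of the two shape-alignment options (same broadcasting rules as TF; num_timesteps and the batch shape are assumed values for illustration):

```python
import numpy as np

num_timesteps, batch_size = 5, 3
# Hypothetical batched estimate, e.g. a weighted posterior mean with
# extra batch axes beyond the time axis:
estimated = np.zeros([num_timesteps, batch_size, 1])
observed_positions = np.zeros([num_timesteps])        # [num_timesteps]

# Option 1: add trailing axes so the comparison broadcasts implicitly.
target_a = observed_positions[..., np.newaxis, np.newaxis]   # [T, 1, 1]

# Option 2: broadcast explicitly to the full batched shape.
target_b = np.broadcast_to(observed_positions[:, np.newaxis, np.newaxis],
                           estimated.shape)                  # [T, 3, 1]

assert (estimated - target_a).shape == estimated.shape
assert target_b.shape == estimated.shape
```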

@copybara-service copybara-service bot merged commit aff7da4 into tensorflow:main Feb 23, 2024
6 of 7 checks passed