Skip to content

Conversation

aleslamitz
Copy link
Contributor

@brianwa84 brianwa84 requested a review from jburnim December 13, 2023 19:33
Copy link
Member

@jburnim jburnim left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is looking great, thanks!

I'm still reviewing, but I left some initial comments.

lambda x, step=step: tf.gather(x, observation_idx), observations)

if particles_dim == 1:
observation = tf.expand_dims(observation, axis=0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this added code should be necessary. The lines right below this should already handle any necessary expand_dims for particles_dim.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is necessary due to the dimensions mismatch

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With these two lines included, ParticleFilterTestFloat32.test_batch_of_filters_particles_dim_1 fails when I run it. If I delete these two lines, then that tests passes.

lambda x, step=step: tf.gather(x, observation_idx), observations)

if particles_dim == 1:
observation = tf.expand_dims(observation, axis=0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With these two lines included, ParticleFilterTestFloat32.test_batch_of_filters_particles_dim_1 fails when I run it. If I delete these two lines, then that tests passes.

inner_initial_state_prior=lambda _, params:
mvn_diag.MultivariateNormalDiag(
loc=loc, scale_diag=scale_diag
),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should be:

inner_initial_state_prior=lambda _, params: mvn_diag.MultivariateNormalDiag(
    loc=tf.broadcast_to([0., 0.], params.shape + [2]),
    scale_diag=tf.broadcast_to([0.01, 0.01], params.shape + [2])

Because:

  • params.shape is just the batch shape (i.e., [num_outer_particles])
  • Which will give inner_inital_state_prior(0, params) the batch_shape of [num_outer_particles] and the correct event_shape of [2]. (And the SMC^2 code is then responsible for drawing num_inner_particles samples from this distribution and re-arranging the dimensions to get initial particles of shape [num_outer_particles, num_inner_particles, 2].)

broadcasted_params = tf.broadcast_to(reshaped_params,
previous_state.shape)
reshaped_dist = independent.Independent(
normal.Normal(previous_state + broadcasted_params + 1, 0.1),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

params is has shape [num_outer_particles] and previous_state has shape [num_outer_particles, num_inner_particles, 2], so here it should be sufficient to just do:

previous_state + params[..., tf.newaxis, tf.newaxis] + 1

tf.equal(tf.math.mod(step, tf.constant(2)), tf.constant(0)),
tf.not_equal(state.extra[0], tf.constant(0))
)
return tf.cond(cond, lambda: tf.constant(True),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should work to just return cond, instead of using tf.cond(cond, ...).

self.evaluate(
tf.reduce_sum(tf.exp(log_weights) *
particles['position'], axis=2)),
observed_positions,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would hope this assertion would pass if we changed this line to observed_positions[..., tf.newaxis, tf.newaxis].

But if that didn't work, we could also do: tf.broadcast_to(observation_positions, [num_timesteps] + batch_shape).

@copybara-service copybara-service bot merged commit aff7da4 into tensorflow:main Feb 23, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants