Skip to content

Commit

Permalink
fixed test
Browse files Browse the repository at this point in the history
  • Loading branch information
aleslamitz committed Jan 21, 2024
1 parent 79ca7d9 commit 183fdf9
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 851 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ def seeded_one_step(seed_state_results, _):


def smc_squared(
inner_observations,
observations,
initial_parameter_prior,
num_outer_particles,
inner_initial_state_prior,
Expand Down Expand Up @@ -474,14 +474,14 @@ def smc_squared(
seed, n=3, salt='smc_squared'
)

num_observation_steps = ps.size0(tf.nest.flatten(inner_observations)[0])
num_observation_steps = ps.size0(tf.nest.flatten(observations)[0])

# TODO: The following two lines compensates for having the
# first empty step in smc2
num_timesteps = (1 + num_transitions_per_observation *
(num_observation_steps - 1)) + 1
last_obs_expanded = tf.expand_dims(inner_observations[-1], axis=0)
inner_observations = tf.concat([inner_observations,
last_obs_expanded = tf.expand_dims(observations[-1], axis=0)
inner_observations = tf.concat([observations,
last_obs_expanded],
axis=0)

Expand Down Expand Up @@ -1104,12 +1104,13 @@ def _compute_observation_log_weights(step,
observation = tf.nest.map_structure(
lambda x, step=step: tf.gather(x, observation_idx), observations)

if particles_dim == 1:
observation = tf.expand_dims(observation, axis=0)
observation = tf.nest.map_structure(
lambda x: tf.expand_dims(x, axis=particles_dim), observation)
if particles_dim != 1:
observation = tf.nest.map_structure(
lambda x: tf.expand_dims(x, axis=particles_dim), observation
)

log_weights = observation_fn(step, particles).log_prob(observation)

return tf.where(step_has_observation,
log_weights,
tf.zeros_like(log_weights))
Expand Down
Loading

0 comments on commit 183fdf9

Please sign in to comment.