Skip to content

Commit

Permalink
NumPy 2.0 related fixes.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 666851860
  • Loading branch information
brianwa84 authored and tensorflower-gardener committed Aug 23, 2024
1 parent 8a4555e commit 4f8251a
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
2 changes: 1 addition & 1 deletion tensorflow_probability/python/experimental/mcmc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,7 @@ multi_substrate_py_test(
size = "large",
srcs = ["particle_filter_test.py"],
numpy_tags = ["notap"],
shard_count = 3,
shard_count = 5,
deps = [
":particle_filter",
":sequential_monte_carlo_kernel",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -476,19 +476,21 @@ def test_proposal_weights_dont_affect_marginal_likelihood(self):
_, lps = self.evaluate(
particle_filter.infer_trajectories(
observation,
initial_state_prior=normal.Normal(loc=0., scale=1.),
initial_state_prior=normal.Normal(loc=self.dtype(0.), scale=1.),
transition_fn=lambda _, x: normal.Normal(loc=x, scale=1.),
observation_fn=lambda _, x: normal.Normal(loc=x, scale=1.),
initial_state_proposal=normal.Normal(loc=0., scale=5.),
initial_state_proposal=normal.Normal(loc=self.dtype(0.), scale=5.),
proposal_fn=lambda _, x: normal.Normal(loc=x, scale=5.),
num_particles=2048,
seed=test_util.test_seed()))

# Compare marginal likelihood against that
# from the true (jointly normal) marginal distribution.
y1_marginal_dist = normal.Normal(loc=0., scale=np.sqrt(1. + 1.))
y1_marginal_dist = normal.Normal(loc=0.,
scale=np.sqrt(1. + 1.).astype(self.dtype))
y2_conditional_dist = (
lambda y1: normal.Normal(loc=y1 / 2., scale=np.sqrt(5. / 2.)))
lambda y1: normal.Normal(
loc=y1 / self.dtype(2.), scale=np.sqrt(5. / 2.).astype(self.dtype)))
true_lps = tf.stack(
[y1_marginal_dist.log_prob(observation[0]),
y2_conditional_dist(observation[0]).log_prob(observation[1])],
Expand Down

0 comments on commit 4f8251a

Please sign in to comment.