Skip to content

Commit

Permalink
Re-enable bnn_test.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 603113976
  • Loading branch information
ThomasColthurst authored and tensorflower-gardener committed Jan 31, 2024
1 parent 25da072 commit dcd820c
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 14 deletions.
5 changes: 3 additions & 2 deletions tensorflow_probability/python/experimental/autobnn/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,6 @@ py_library(
py_test(
name = "bnn_test",
srcs = ["bnn_test.py"],
# TODO(b/320723712): enable this test in OSS.
tags = ["no-oss-ci"],
deps = [
":bnn",
# absl/testing:absltest dep,
Expand Down Expand Up @@ -232,12 +230,15 @@ py_library(
py_test(
name = "operators_test",
srcs = ["operators_test.py"],
# TODO(b/322864412): enable this test in OSS.
tags = ["no-oss-ci"],
deps = [
":kernels",
":operators",
":util",
# absl/testing:absltest dep,
# absl/testing:parameterized dep,
# bayeux dep,
# google/protobuf:use_fast_cpp_protos dep,
# jax dep,
# numpy dep,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Tests for operators.py."""

from absl.testing import parameterized
import bayeux as bx
import jax
import jax.numpy as jnp
import numpy as np
Expand Down Expand Up @@ -186,7 +187,6 @@ def test_add_of_adds_has_penultimate(self):
self.assertEqual((50,), h.shape)

def test_multiply_can_be_trained(self):
# TODO(thomaswc): Restore this test when bayeux is available.
seed = jax.random.PRNGKey(20231018)
x_train, y_train = util.load_fake_dataset()

Expand All @@ -202,16 +202,16 @@ def train_density(params):
return bnn.log_prob(params, x_train, y_train)

transform, inverse_transform, _ = util.make_transforms(bnn)
del transform
del inverse_transform
del init_params
del train_density
# mix_model = bx.Model(train_density, init_params,
# transform_fn=transform,
# inverse_transform_fn=inverse_transform)

# self.assertTrue(
# mix_model.optimize.optax_adam.debug(seed=seed, verbosity=10))
mix_model = bx.Model(
train_density,
init_params,
transform_fn=transform,
inverse_transform_fn=inverse_transform,
)

self.assertTrue(
mix_model.optimize.optax_adam.debug(seed=seed, verbosity=10)
)


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import functools
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import bayeux as bx
import jax
import jax.numpy as jnp
from jaxtyping import PyTree # pylint: disable=g-importing-member
Expand All @@ -38,6 +37,11 @@ def _make_bayeux_model(
for_vi: bool = False,
):
"""Use a MAP estimator to fit a BNN."""
# We can't import bayeux at the file level because it would create a
# circular dependency: autobnn imports bayeux imports tfp:jax
# which in turn imports (through __init__.py files) autobnn.
import bayeux as bx # pylint:disable=g-bad-import-order,g-import-not-at-top

test_seed, init_seed = jax.random.split(seed)
test_point = net.init(test_seed, x_train)
transform, inverse_transform, ildj = util.make_transforms(net)
Expand Down

0 comments on commit dcd820c

Please sign in to comment.