diff --git a/tensorflow_probability/python/experimental/autobnn/BUILD b/tensorflow_probability/python/experimental/autobnn/BUILD index 051f58b2ed..6cff94bc4c 100644 --- a/tensorflow_probability/python/experimental/autobnn/BUILD +++ b/tensorflow_probability/python/experimental/autobnn/BUILD @@ -56,8 +56,6 @@ py_library( py_test( name = "bnn_test", srcs = ["bnn_test.py"], - # TODO(b/320723712): enable this test in OSS. - tags = ["no-oss-ci"], deps = [ ":bnn", # absl/testing:absltest dep, @@ -232,12 +230,15 @@ py_library( py_test( name = "operators_test", srcs = ["operators_test.py"], + # TODO(b/322864412): enable this test in OSS. + tags = ["no-oss-ci"], deps = [ ":kernels", ":operators", ":util", # absl/testing:absltest dep, # absl/testing:parameterized dep, + # bayeux dep, # google/protobuf:use_fast_cpp_protos dep, # jax dep, # numpy dep, diff --git a/tensorflow_probability/python/experimental/autobnn/operators_test.py b/tensorflow_probability/python/experimental/autobnn/operators_test.py index 63f978f003..aab0c34e5b 100644 --- a/tensorflow_probability/python/experimental/autobnn/operators_test.py +++ b/tensorflow_probability/python/experimental/autobnn/operators_test.py @@ -15,6 +15,7 @@ """Tests for operators.py.""" from absl.testing import parameterized +import bayeux as bx import jax import jax.numpy as jnp import numpy as np @@ -186,7 +187,6 @@ def test_add_of_adds_has_penultimate(self): self.assertEqual((50,), h.shape) def test_multiply_can_be_trained(self): - # TODO(thomaswc): Restore this test when bayeux is available. seed = jax.random.PRNGKey(20231018) x_train, y_train = util.load_fake_dataset() @@ -202,16 +202,16 @@ def train_density(params): return bnn.log_prob(params, x_train, y_train) transform, inverse_transform, _ = util.make_transforms(bnn) - del transform - del inverse_transform - del init_params - del train_density - # mix_model = bx.Model(train_density, init_params, - # transform_fn=transform, - # inverse_transform_fn=inverse_transform) - - # self.assertTrue( - # mix_model.optimize.optax_adam.debug(seed=seed, verbosity=10)) + mix_model = bx.Model( + train_density, + init_params, + transform_fn=transform, + inverse_transform_fn=inverse_transform, + ) + + self.assertTrue( + mix_model.optimize.optax_adam.debug(seed=seed, verbosity=10) + ) if __name__ == "__main__": diff --git a/tensorflow_probability/python/experimental/autobnn/training_util.py b/tensorflow_probability/python/experimental/autobnn/training_util.py index a1d9d48663..1ff680ce4a 100644 --- a/tensorflow_probability/python/experimental/autobnn/training_util.py +++ b/tensorflow_probability/python/experimental/autobnn/training_util.py @@ -17,7 +17,6 @@ import functools from typing import Any, Dict, List, Optional, Sequence, Tuple, Union -import bayeux as bx import jax import jax.numpy as jnp from jaxtyping import PyTree # pylint: disable=g-importing-member @@ -38,6 +37,11 @@ def _make_bayeux_model( for_vi: bool = False, ): """Use a MAP estimator to fit a BNN.""" + # We can't import bayeux at the file level because it would create a + # circular dependency: autobnn imports bayeux imports tfp:jax + # which in turn imports (through __init__.py files) autobnn. + import bayeux as bx # pylint:disable=g-bad-import-order,g-import-not-at-top + test_seed, init_seed = jax.random.split(seed) test_point = net.init(test_seed, x_train) transform, inverse_transform, ildj = util.make_transforms(net)