diff --git a/tensorflow_probability/python/experimental/autobnn/BUILD b/tensorflow_probability/python/experimental/autobnn/BUILD index d74d59673f..051f58b2ed 100644 --- a/tensorflow_probability/python/experimental/autobnn/BUILD +++ b/tensorflow_probability/python/experimental/autobnn/BUILD @@ -24,6 +24,23 @@ package( default_visibility = ["//visibility:public"], ) +py_library( + name = "autobnn.jax", + srcs = ["__init__.py"], + deps = [ + ":bnn", + ":bnn_tree", + ":estimators", + ":kernels", + ":likelihoods", + ":models", + ":operators", + ":training_util", + ":util", + "//tensorflow_probability/python/internal:all_util", + ], +) + py_library( name = "bnn", srcs = ["bnn.py"], @@ -97,6 +114,8 @@ py_library( py_test( name = "estimators_test", srcs = ["estimators_test.py"], + # TODO(b/322864412): enable this test in OSS. + tags = ["no-oss-ci"], deps = [ ":estimators", ":kernels", @@ -141,7 +160,6 @@ py_library( # flax:core dep, # jax dep, # jaxtyping dep, - "//tensorflow_probability:jax", "//tensorflow_probability/python/bijectors:softplus.jax", "//tensorflow_probability/python/distributions:distribution.jax", "//tensorflow_probability/python/distributions:inflated.jax", @@ -200,7 +218,6 @@ py_library( ":likelihoods", # flax:core dep, # jax dep, - "//tensorflow_probability:jax", "//tensorflow_probability/python/bijectors:chain.jax", "//tensorflow_probability/python/bijectors:scale.jax", "//tensorflow_probability/python/bijectors:shift.jax", @@ -238,7 +255,6 @@ py_library( # matplotlib dep, # numpy dep, # pandas dep, - "//tensorflow_probability:jax", "//tensorflow_probability/python/experimental/autobnn:bnn", "//tensorflow_probability/python/experimental/autobnn:util", "//tensorflow_probability/python/experimental/timeseries:metrics", @@ -248,6 +264,8 @@ py_library( py_test( name = "training_util_test", srcs = ["training_util_test.py"], + # TODO(b/322864412): enable this test in OSS. + tags = ["no-oss-ci"], deps = [ ":training_util", # chex dep, diff --git a/tensorflow_probability/python/experimental/autobnn/__init__.py b/tensorflow_probability/python/experimental/autobnn/__init__.py new file mode 100644 index 0000000000..d7e10a3d11 --- /dev/null +++ b/tensorflow_probability/python/experimental/autobnn/__init__.py @@ -0,0 +1,45 @@ +# Copyright 2024 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Package for training GP-like Bayesian Neural Nets w/ composite structure.""" + +from tensorflow_probability.python.experimental.autobnn import bnn +from tensorflow_probability.python.experimental.autobnn import bnn_tree +# estimators causes vectorized_stochastic_volatility_test to fail +# because it imports training_util +from tensorflow_probability.python.experimental.autobnn import estimators +from tensorflow_probability.python.experimental.autobnn import kernels +from tensorflow_probability.python.experimental.autobnn import likelihoods +from tensorflow_probability.python.experimental.autobnn import models +from tensorflow_probability.python.experimental.autobnn import operators +# training_util causes vectorized_stochastic_volatility_test to fail +# Suspects: JaxTyping, bayeux, matplotlib, pandas. +# And the culprit is ... bayeux. +from tensorflow_probability.python.experimental.autobnn import training_util +from tensorflow_probability.python.experimental.autobnn import util +from tensorflow_probability.python.internal import all_util + +_allowed_symbols = [ + 'bnn', + 'bnn_tree', + 'estimators', + 'kernels', + 'likelihoods', + 'models', + 'operators', + 'training_util', + 'util', +] + +all_util.remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow_probability/python/experimental/autobnn/training_util.py b/tensorflow_probability/python/experimental/autobnn/training_util.py index 20cc5da31f..a1d9d48663 100644 --- a/tensorflow_probability/python/experimental/autobnn/training_util.py +++ b/tensorflow_probability/python/experimental/autobnn/training_util.py @@ -27,10 +27,6 @@ from tensorflow_probability.python.experimental.autobnn import bnn from tensorflow_probability.python.experimental.autobnn import util from tensorflow_probability.python.experimental.timeseries import metrics -import tensorflow_probability.substrates.jax as tfp - -tfd = tfp.distributions -tfb = tfp.bijectors def _make_bayeux_model(