Skip to content

Commit

Permalink
Add __init__.py file for autobnn.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 602721585
  • Loading branch information
ThomasColthurst authored and tensorflower-gardener committed Jan 30, 2024
1 parent 64a70d0 commit eb67a7f
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 7 deletions.
24 changes: 21 additions & 3 deletions tensorflow_probability/python/experimental/autobnn/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,23 @@ package(
default_visibility = ["//visibility:public"],
)

py_library(
name = "autobnn.jax",
srcs = ["__init__.py"],
deps = [
":bnn",
":bnn_tree",
":estimators",
":kernels",
":likelihoods",
":models",
":operators",
":training_util",
":util",
"//tensorflow_probability/python/internal:all_util",
],
)

py_library(
name = "bnn",
srcs = ["bnn.py"],
Expand Down Expand Up @@ -97,6 +114,8 @@ py_library(
py_test(
name = "estimators_test",
srcs = ["estimators_test.py"],
# TODO(b/322864412): enable this test in OSS.
tags = ["no-oss-ci"],
deps = [
":estimators",
":kernels",
Expand Down Expand Up @@ -141,7 +160,6 @@ py_library(
# flax:core dep,
# jax dep,
# jaxtyping dep,
"//tensorflow_probability:jax",
"//tensorflow_probability/python/bijectors:softplus.jax",
"//tensorflow_probability/python/distributions:distribution.jax",
"//tensorflow_probability/python/distributions:inflated.jax",
Expand Down Expand Up @@ -200,7 +218,6 @@ py_library(
":likelihoods",
# flax:core dep,
# jax dep,
"//tensorflow_probability:jax",
"//tensorflow_probability/python/bijectors:chain.jax",
"//tensorflow_probability/python/bijectors:scale.jax",
"//tensorflow_probability/python/bijectors:shift.jax",
Expand Down Expand Up @@ -238,7 +255,6 @@ py_library(
# matplotlib dep,
# numpy dep,
# pandas dep,
"//tensorflow_probability:jax",
"//tensorflow_probability/python/experimental/autobnn:bnn",
"//tensorflow_probability/python/experimental/autobnn:util",
"//tensorflow_probability/python/experimental/timeseries:metrics",
Expand All @@ -248,6 +264,8 @@ py_library(
py_test(
name = "training_util_test",
srcs = ["training_util_test.py"],
# TODO(b/322864412): enable this test in OSS.
tags = ["no-oss-ci"],
deps = [
":training_util",
# chex dep,
Expand Down
45 changes: 45 additions & 0 deletions tensorflow_probability/python/experimental/autobnn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright 2024 The TensorFlow Probability Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Package for training GP-like Bayesian Neural Nets w/ composite structure."""

from tensorflow_probability.python.experimental.autobnn import bnn
from tensorflow_probability.python.experimental.autobnn import bnn_tree
# estimators causes vectorized_stochastic_volatility_test to fail
# because it imports training_util
from tensorflow_probability.python.experimental.autobnn import estimators
from tensorflow_probability.python.experimental.autobnn import kernels
from tensorflow_probability.python.experimental.autobnn import likelihoods
from tensorflow_probability.python.experimental.autobnn import models
from tensorflow_probability.python.experimental.autobnn import operators
# training_util causes vectorized_stochastic_volatility_test to fail
# Suspects: JaxTyping, bayeux, matplotlib, pandas.
# And the culprit is ... bayeux.
from tensorflow_probability.python.experimental.autobnn import training_util
from tensorflow_probability.python.experimental.autobnn import util
from tensorflow_probability.python.internal import all_util

_allowed_symbols = [
'bnn',
'bnn_tree',
'estimators',
'kernels',
'likelihoods',
'models',
'operators',
'training_util',
'util',
]

all_util.remove_undocumented(__name__, _allowed_symbols)
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,6 @@
from tensorflow_probability.python.experimental.autobnn import bnn
from tensorflow_probability.python.experimental.autobnn import util
from tensorflow_probability.python.experimental.timeseries import metrics
import tensorflow_probability.substrates.jax as tfp

tfd = tfp.distributions
tfb = tfp.bijectors


def _make_bayeux_model(
Expand Down

0 comments on commit eb67a7f

Please sign in to comment.