Skip to content

Commit

Permalink
Move autobnn to spinoffs.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 619626204
  • Loading branch information
ThomasColthurst authored and tensorflower-gardener committed Mar 27, 2024
1 parent 7e1ce86 commit a77f8dd
Show file tree
Hide file tree
Showing 22 changed files with 55 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -250,14 +250,14 @@ py_library(
name = "training_util",
srcs = ["training_util.py"],
deps = [
":bnn",
":util",
# bayeux dep,
# jax dep,
# jaxtyping dep,
# matplotlib dep,
# numpy dep,
# pandas dep,
"//tensorflow_probability/python/experimental/autobnn:bnn",
"//tensorflow_probability/python/experimental/autobnn:util",
"//tensorflow_probability/python/experimental/timeseries:metrics",
],
)
Expand All @@ -268,14 +268,14 @@ py_test(
# TODO(b/322864412): enable this test in OSS.
tags = ["no-oss-ci"],
deps = [
":kernels",
":operators",
":training_util",
":util",
# chex dep,
# google/protobuf:use_fast_cpp_protos dep,
# jax dep,
# numpy dep,
"//tensorflow_probability/python/experimental/autobnn:kernels",
"//tensorflow_probability/python/experimental/autobnn:operators",
"//tensorflow_probability/python/experimental/autobnn:util",
"//tensorflow_probability/python/internal:test_util",
],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,16 @@
# ============================================================================
"""Package for training GP-like Bayesian Neural Nets w/ composite structure."""

from tensorflow_probability.python.experimental.autobnn import bnn
from tensorflow_probability.python.experimental.autobnn import bnn_tree
from tensorflow_probability.python.experimental.autobnn import estimators
from tensorflow_probability.python.experimental.autobnn import kernels
from tensorflow_probability.python.experimental.autobnn import likelihoods
from tensorflow_probability.python.experimental.autobnn import models
from tensorflow_probability.python.experimental.autobnn import operators
from tensorflow_probability.python.experimental.autobnn import training_util
from tensorflow_probability.python.experimental.autobnn import util
from tensorflow_probability.python.internal import all_util
from tensorflow_probability.spinoffs.autobnn import bnn
from tensorflow_probability.spinoffs.autobnn import bnn_tree
from tensorflow_probability.spinoffs.autobnn import estimators
from tensorflow_probability.spinoffs.autobnn import kernels
from tensorflow_probability.spinoffs.autobnn import likelihoods
from tensorflow_probability.spinoffs.autobnn import models
from tensorflow_probability.spinoffs.autobnn import operators
from tensorflow_probability.spinoffs.autobnn import training_util
from tensorflow_probability.spinoffs.autobnn import util

_allowed_symbols = [
'bnn',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float, PyTree # pylint: disable=g-importing-member,g-multiple-import
from tensorflow_probability.python.experimental.autobnn import likelihoods
from tensorflow_probability.spinoffs.autobnn import likelihoods
from tensorflow_probability.substrates.jax.distributions import distribution as distribution_lib


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from flax import linen as nn
import jax
import jax.numpy as jnp
from tensorflow_probability.python.experimental.autobnn import bnn
from tensorflow_probability.spinoffs.autobnn import bnn
from tensorflow_probability.substrates.jax.distributions import lognormal as lognormal_lib
from tensorflow_probability.substrates.jax.distributions import normal as normal_lib
from absl.testing import absltest
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@
from flax import linen as nn
import jax
import jax.numpy as jnp
from tensorflow_probability.python.experimental.autobnn import bnn
from tensorflow_probability.python.experimental.autobnn import kernels
from tensorflow_probability.python.experimental.autobnn import operators
from tensorflow_probability.python.experimental.autobnn import util
from tensorflow_probability.spinoffs.autobnn import bnn
from tensorflow_probability.spinoffs.autobnn import kernels
from tensorflow_probability.spinoffs.autobnn import operators
from tensorflow_probability.spinoffs.autobnn import util

Array = jnp.ndarray

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
from flax import linen as nn
import jax
import jax.numpy as jnp
from tensorflow_probability.python.experimental.autobnn import bnn_tree
from tensorflow_probability.python.experimental.autobnn import kernels
from tensorflow_probability.spinoffs.autobnn import bnn_tree
from tensorflow_probability.spinoffs.autobnn import kernels
from absl.testing import absltest


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@
import jax
import jax.numpy as jnp
from jaxtyping import ArrayLike, PyTree # pylint: disable=g-importing-member,g-multiple-import
from tensorflow_probability.python.experimental.autobnn import bnn
from tensorflow_probability.python.experimental.autobnn import likelihoods
from tensorflow_probability.python.experimental.autobnn import models
from tensorflow_probability.python.experimental.autobnn import training_util
from tensorflow_probability.spinoffs.autobnn import bnn
from tensorflow_probability.spinoffs.autobnn import likelihoods
from tensorflow_probability.spinoffs.autobnn import models
from tensorflow_probability.spinoffs.autobnn import training_util


class _AutoBnnEstimator:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@

import jax
import numpy as np
from tensorflow_probability.python.experimental.autobnn import estimators
from tensorflow_probability.python.experimental.autobnn import kernels
from tensorflow_probability.python.experimental.autobnn import operators
from tensorflow_probability.python.experimental.autobnn import util
from tensorflow_probability.python.internal import test_util
from tensorflow_probability.spinoffs.autobnn import estimators
from tensorflow_probability.spinoffs.autobnn import kernels
from tensorflow_probability.spinoffs.autobnn import operators
from tensorflow_probability.spinoffs.autobnn import util


class AutoBNNTest(test_util.TestCase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from flax.linen import initializers
import jax
import jax.numpy as jnp
from tensorflow_probability.python.experimental.autobnn import bnn
from tensorflow_probability.spinoffs.autobnn import bnn
from tensorflow_probability.substrates.jax.distributions import lognormal as lognormal_lib
from tensorflow_probability.substrates.jax.distributions import normal as normal_lib
from tensorflow_probability.substrates.jax.distributions import student_t as student_t_lib
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
import jax
import jax.numpy as jnp
import numpy as np
from tensorflow_probability.python.experimental.autobnn import kernels
from tensorflow_probability.python.experimental.autobnn import util
from tensorflow_probability.spinoffs.autobnn import kernels
from tensorflow_probability.spinoffs.autobnn import util
from tensorflow_probability.substrates.jax.distributions import lognormal as lognormal_lib

from absl.testing import absltest
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from absl.testing import parameterized
import jax.numpy as jnp
from tensorflow_probability.python.experimental.autobnn import likelihoods
from tensorflow_probability.spinoffs.autobnn import likelihoods
from absl.testing import absltest


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@
import functools
from typing import Sequence, Union
import jax.numpy as jnp
from tensorflow_probability.python.experimental.autobnn import bnn
from tensorflow_probability.python.experimental.autobnn import bnn_tree
from tensorflow_probability.python.experimental.autobnn import kernels
from tensorflow_probability.python.experimental.autobnn import likelihoods
from tensorflow_probability.python.experimental.autobnn import operators
from tensorflow_probability.spinoffs.autobnn import bnn
from tensorflow_probability.spinoffs.autobnn import bnn_tree
from tensorflow_probability.spinoffs.autobnn import kernels
from tensorflow_probability.spinoffs.autobnn import likelihoods
from tensorflow_probability.spinoffs.autobnn import operators


Array = jnp.ndarray
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
from absl.testing import parameterized
import jax
import jax.numpy as jnp
from tensorflow_probability.python.experimental.autobnn import likelihoods
from tensorflow_probability.python.experimental.autobnn import models
from tensorflow_probability.spinoffs.autobnn import likelihoods
from tensorflow_probability.spinoffs.autobnn import models
from absl.testing import absltest


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from flax import linen as nn
import jax
import jax.numpy as jnp
from tensorflow_probability.python.experimental.autobnn import bnn
from tensorflow_probability.python.experimental.autobnn import likelihoods
from tensorflow_probability.spinoffs.autobnn import bnn
from tensorflow_probability.spinoffs.autobnn import likelihoods
from tensorflow_probability.substrates.jax.bijectors import chain as chain_lib
from tensorflow_probability.substrates.jax.bijectors import scale as scale_lib
from tensorflow_probability.substrates.jax.bijectors import shift as shift_lib
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
import jax
import jax.numpy as jnp
import numpy as np
from tensorflow_probability.python.experimental.autobnn import kernels
from tensorflow_probability.python.experimental.autobnn import operators
from tensorflow_probability.python.experimental.autobnn import util
from tensorflow_probability.spinoffs.autobnn import kernels
from tensorflow_probability.spinoffs.autobnn import operators
from tensorflow_probability.spinoffs.autobnn import util
from tensorflow_probability.substrates.jax.distributions import distribution as distribution_lib
from absl.testing import absltest

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tensorflow_probability.python.experimental.autobnn import bnn
from tensorflow_probability.python.experimental.autobnn import util
from tensorflow_probability.python.experimental.timeseries import metrics
from tensorflow_probability.spinoffs.autobnn import bnn
from tensorflow_probability.spinoffs.autobnn import util


def _make_bayeux_model(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
import jax
import jax.numpy as jnp
import numpy as np
from tensorflow_probability.python.experimental.autobnn import kernels
from tensorflow_probability.python.experimental.autobnn import operators
from tensorflow_probability.python.experimental.autobnn import training_util
from tensorflow_probability.python.experimental.autobnn import util
from tensorflow_probability.python.internal import test_util
from tensorflow_probability.spinoffs.autobnn import kernels
from tensorflow_probability.spinoffs.autobnn import operators
from tensorflow_probability.spinoffs.autobnn import training_util
from tensorflow_probability.spinoffs.autobnn import util


class TrainingUtilTest(test_util.TestCase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import jax
import jax.numpy as jnp
import scipy
from tensorflow_probability.python.experimental.autobnn import bnn
from tensorflow_probability.spinoffs.autobnn import bnn
from tensorflow_probability.substrates.jax.distributions import distribution as distribution_lib


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
import jax
import jax.numpy as jnp
import numpy as np
from tensorflow_probability.python.experimental.autobnn import kernels
from tensorflow_probability.python.experimental.autobnn import util
from tensorflow_probability.python.internal import test_util
from tensorflow_probability.spinoffs.autobnn import kernels
from tensorflow_probability.spinoffs.autobnn import util


class UtilTest(test_util.TestCase):
Expand Down

0 comments on commit a77f8dd

Please sign in to comment.