Skip to content

Commit

Permalink
Add training_util to experimental/autobnn.
Browse files Browse the repository at this point in the history
Add metrics to experimental/timeseries.

PiperOrigin-RevId: 595456896
  • Loading branch information
ThomasColthurst authored and tensorflower-gardener committed Jan 3, 2024
1 parent ec23c1f commit e1b7ccb
Show file tree
Hide file tree
Showing 8 changed files with 876 additions and 17 deletions.
65 changes: 49 additions & 16 deletions tensorflow_probability/python/experimental/autobnn/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,35 @@ py_test(
],
)

py_library(
name = "bnn_tree",
srcs = ["bnn_tree.py"],
deps = [
":bnn",
":kernels",
":operators",
":util",
# flax:core dep,
# jax dep,
],
)

py_test(
name = "bnn_tree_test",
timeout = "long",
srcs = ["bnn_tree_test.py"],
shard_count = 3,
deps = [
":bnn_tree",
":kernels",
# absl/testing:absltest dep,
# absl/testing:parameterized dep,
# flax dep,
# google/protobuf:use_fast_cpp_protos dep,
# jax dep,
],
)

py_library(
name = "kernels",
srcs = ["kernels.py"],
Expand Down Expand Up @@ -173,31 +202,35 @@ py_test(
)

py_library(
name = "bnn_tree",
srcs = ["bnn_tree.py"],
name = "training_util",
srcs = ["training_util.py"],
deps = [
":bnn",
":kernels",
":operators",
":util",
# flax:core dep,
# bayeux dep,
# jax dep,
# jaxtyping dep,
# matplotlib dep,
# numpy dep,
# pandas dep,
"//tensorflow_probability:jax",
"//tensorflow_probability/python/experimental/autobnn:bnn",
"//tensorflow_probability/python/experimental/autobnn:util",
"//tensorflow_probability/python/experimental/timeseries:metrics",
],
)

py_test(
name = "bnn_tree_test",
timeout = "long",
srcs = ["bnn_tree_test.py"],
shard_count = 3,
name = "training_util_test",
srcs = ["training_util_test.py"],
deps = [
":bnn_tree",
":kernels",
# absl/testing:absltest dep,
# absl/testing:parameterized dep,
# flax dep,
":training_util",
# chex dep,
# google/protobuf:use_fast_cpp_protos dep,
# jax dep,
# numpy dep,
"//tensorflow_probability/python/experimental/autobnn:kernels",
"//tensorflow_probability/python/experimental/autobnn:operators",
"//tensorflow_probability/python/experimental/autobnn:util",
"//tensorflow_probability/python/internal:test_util",
],
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,4 @@ else
PIP_FLAGS=""
fi

python -m pip install $PIP_FLAGS flax jaxtyping scipy
python -m pip install $PIP_FLAGS bayeux-ml chex flax jaxtyping matplotlib pandas scipy
Loading

0 comments on commit e1b7ccb

Please sign in to comment.