From f7f04c8c2fbf7c7956f25faabc2fdae9a1221db0 Mon Sep 17 00:00:00 2001 From: thomaswc Date: Wed, 20 Mar 2024 10:55:25 -0700 Subject: [PATCH] jax.jit training_util.py::make_predictions. PiperOrigin-RevId: 617565334 --- .../python/experimental/autobnn/training_util.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow_probability/python/experimental/autobnn/training_util.py b/tensorflow_probability/python/experimental/autobnn/training_util.py index 8af6ccc4c2..36e6560a3b 100644 --- a/tensorflow_probability/python/experimental/autobnn/training_util.py +++ b/tensorflow_probability/python/experimental/autobnn/training_util.py @@ -211,7 +211,8 @@ def _plot_loss_fn(losses, ax=None, log_scale=True) -> plt.Figure: def make_predictions(params, net: bnn.BNN, x_test: jax.Array) -> jax.Array: """Use a (batch of) parameters to make a prediction on x_test data.""" - return jax.vmap(lambda p: net.apply(p, x_test))(params) + f = jax.jit(lambda p: net.apply(p, x_test)) + return jax.vmap(f)(params) def make_results_dataframe(