diff --git a/tensorflow_probability/python/experimental/autobnn/training_util.py b/tensorflow_probability/python/experimental/autobnn/training_util.py index 8af6ccc4c2..36e6560a3b 100644 --- a/tensorflow_probability/python/experimental/autobnn/training_util.py +++ b/tensorflow_probability/python/experimental/autobnn/training_util.py @@ -211,7 +211,8 @@ def _plot_loss_fn(losses, ax=None, log_scale=True) -> plt.Figure: def make_predictions(params, net: bnn.BNN, x_test: jax.Array) -> jax.Array: """Use a (batch of) parameters to make a prediction on x_test data.""" - return jax.vmap(lambda p: net.apply(p, x_test))(params) + f = jax.jit(lambda p: net.apply(p, x_test)) + return jax.vmap(f)(params) def make_results_dataframe(