Skip to content

Commit

Permalink
jax.jit training_util.py::make_predictions.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 617565334
  • Loading branch information
ThomasColthurst authored and tensorflower-gardener committed Mar 20, 2024
1 parent b0abbc7 commit f7f04c8
Showing 1 changed file with 2 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,8 @@ def _plot_loss_fn(losses, ax=None, log_scale=True) -> plt.Figure:

def make_predictions(params, net: bnn.BNN, x_test: jax.Array) -> jax.Array:
"""Use a (batch of) parameters to make a prediction on x_test data."""
return jax.vmap(lambda p: net.apply(p, x_test))(params)
f = jax.jit(lambda p: net.apply(p, x_test))
return jax.vmap(f)(params)


def make_results_dataframe(
Expand Down

0 comments on commit f7f04c8

Please sign in to comment.