-
Hi JAX team, I'm working on switching our code to JAX (using Flax as the NN library) and I'm amazed by jit and vmap. I'm wondering if there are best practices for when and where to apply these. For example: should jit wrap the outermost function (the whole train step) or inner functions like the loss? And should vmap go around the model's forward pass, around the loss, or somewhere else?
Maybe these things don't matter at all and the jit and vmap magic is super robust against amateurs like me, but otherwise a simple best-practices page could help. Thanks!
-
@GJBoth I'd also love to know the best practices - good questions. Maybe this issue should be moved to GitHub Discussions by the JAX admins (there's a guide for this: https://docs.github.com/en/free-pro-team@latest/discussions/managing-discussions-for-your-community/moderating-discussions#converting-an-issue-to-a-discussion). cc @avital from Flax
-
Hi @GJBoth 👋 I did some research on Flax and Haiku examples to try to find some answers.
Working on cutting edge stuff 🔥 Nice!
It's like magical 🦄 dust wrapped around XLA (...I don't know what this means and how those transforms work).
It looks like it depends on how the train step and loss functions are defined. I included some examples below - supervised classification, RL, generative - from Flax Linen and Haiku code. Some apply @jax.jit directly as a decorator, while others use functools.partial(jax.jit, static_argnums=...).
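In these first two Flax Linen snippets, @jax.jit wraps the entire train_step, with loss_fn defined inside and differentiated via jax.value_and_grad: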
@jax.jit
def train_step(optimizer, batch, z_rng):
def loss_fn(params):
...
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
...
optimizer = optimizer.apply_gradient(grad)
return optimizer
@jax.jit
def train_step(optimizer, batch, masks, key):
...
def loss_fn(params):
...
return loss, logits
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
...
return optimizer, metrics
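Other examples apply jit through functools.partial with static_argnums, so that certain arguments are treated as compile-time constants (changing them triggers recompilation):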
@functools.partial(jax.jit, static_argnums=1)
def loss_fn(
...):
...
return PPO_loss + vf_coeff*value_loss - entropy_coeff*entropy
@functools.partial(jax.jit, static_argnums=(0,7))
def train_step(
...):
...
for batch in zip(*trajectories):
grad_fn = jax.value_and_grad(loss_fn)
...
optimizer = optimizer.apply_gradient(grad, learning_rate=lr)
return optimizer, loss
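In this Haiku VAE example, the loss and the update step are each jitted at the top level, with optax applying the parameter updates: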
@jax.jit
def loss_fn(params: hk.Params, rng_key: PRNGKey, batch: Batch) -> jnp.ndarray:
...
outputs: VAEOutput = model.apply(params, rng_key, batch["image"])
log_likelihood = -binary_cross_entropy(batch["image"], outputs.logits)
kl = kl_gaussian(outputs.mean, outputs.stddev**2)
elbo = log_likelihood - kl
return -jnp.mean(elbo)
@jax.jit
def update(
...
) -> Tuple[hk.Params, OptState]:
...
grads = jax.grad(loss_fn)(params, rng_key, batch)
updates, new_opt_state = optimizer.update(grads, opt_state)
new_params = optax.apply_updates(params, updates)
return new_params, new_opt_state
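Here, a jitted method marks self as a static argument, and vmap is used inside the loss helpers to map the rlax losses over axis 1 of the inputs: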
@functools.partial(jax.jit, static_argnums=0)
def update(self, params, opt_state, batch: util.Transition):
"""The actual update function."""
(_, logs), grads = jax.value_and_grad(
self._loss, has_aux=True)(params, batch)
grad_norm_unclipped = optimizers.l2_norm(grads)
...
return params, updated_opt_state, logs
...
def policy_gradient_loss(logits, *args):
...
mean_per_batch = jax.vmap(rlax.policy_gradient_loss, in_axes=1)(logits, *args)
total_loss_per_batch = mean_per_batch * logits.shape[0]
return jnp.sum(total_loss_per_batch)
def entropy_loss(logits, *args):
...
mean_per_batch = jax.vmap(rlax.entropy_loss, in_axes=1)(logits, *args)
total_loss_per_batch = mean_per_batch * logits.shape[0]
return jnp.sum(total_loss_per_batch)
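Model initialization can be jitted too, as in this Flax example: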
def initialized(key, image_size, model):
input_shape = (1, image_size, image_size, 3)
@jax.jit
def init(*args):
return model.init(*args)
variables = init({'params': key}, jnp.ones(input_shape, model.dtype))
model_state, params = variables.pop('params')
return params, model_state
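jit and vmap can also be stacked directly as decorators, here vectorizing a GAE advantage computation over axis 1 of the inputs: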
@jax.jit
@functools.partial(jax.vmap, in_axes=(1, 1, 1, None, None), out_axes=1)
def gae_advantages(
...):
...
for t in reversed(range(len(rewards))):
...
gae = delta + discount * gae_param * terminal_masks[t] * gae
advantages.append(gae)
advantages = advantages[::-1]
return jnp.array(advantages)
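Finally, jitted methods on an agent class, with self (and in one case batch_size) marked as static: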
@functools.partial(jax.jit, static_argnums=(0, 1))
def initial_state(self, batch_size: Optional[int]):
...
return self._initial_state_apply_fn(None, batch_size)
@functools.partial(jax.jit, static_argnums=(0,))
def step(
...
) -> Tuple[AgentOutput, Nest]:
...
action = hk.multinomial(rng_key, net_out.policy_logits, num_samples=1)
...
return AgentOutput(net_out.policy_logits, net_out.value, action), next_state
-
Generally, you can always jit the top-most level (but it's fine to jit inner functions -- that doesn't impact the end result). vmap changes the function signature, so put it where it makes sense -- if you want a forward-pass function that works on batches but is defined on single elements, then you should vmap that function. But I agree we should have a simple page with best practices for jit, vmap, etc. I'll file an issue for that.
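To make that concrete, here is a minimal sketch (a tiny one-layer model; the names predict, loss_fn and train_step are made up for illustration, not taken from the examples above). The forward pass is written for a single example, vmap lifts it to batches, and jit wraps only the top-level train step:

import jax
import jax.numpy as jnp

def predict(params, x):
    # Forward pass defined on a single example x of shape (features,).
    w, b = params
    return jnp.tanh(x @ w + b)

# vmap changes the signature: batched_predict expects x of shape (batch, features).
batched_predict = jax.vmap(predict, in_axes=(None, 0))

def loss_fn(params, xs, ys):
    preds = batched_predict(params, xs)
    return jnp.mean((preds - ys) ** 2)

@jax.jit  # jit the outermost step; everything it calls is compiled together
def train_step(params, xs, ys):
    lr = 1e-3
    grads = jax.grad(loss_fn)(params, xs, ys)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

# Example usage:
# params = (jnp.zeros((3, 2)), jnp.zeros(2))
# params = train_step(params, jnp.ones((8, 3)), jnp.ones((8, 2)))

Jitting loss_fn or predict as well wouldn't change the result; the part that matters is that vmap sits around the function that is defined per example.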