Skip to content

Commit

Permalink
consolidate the code example
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenying-liu committed Dec 19, 2024
1 parent ab52c63 commit 3a57b72
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions docs/gradient-checkpointing.md
Original file line number Diff line number Diff line change
Expand Up @@ -379,20 +379,33 @@ One of JAX's checkpoint policies allows specified checkpoint names to be offload

```{code-cell}
from jax.ad_checkpoint import checkpoint, checkpoint_name
from jax._src import test_util as jtu
def g(self):
def checkpoint_names_saved_offloaded_recomputed(self):
mesh = jtu.create_mesh((2,), ("x",))
shape = (256, 128)
np_inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
s = NamedSharding(mesh, P("x"))
inp = jax.device_put(np_inp, s)
policy = jax.checkpoint_policies.save_and_offload_only_these_names(
names_which_can_be_saved=["y"], names_which_can_be_offloaded=["z"],
offload_src='device', offload_dst='pinned_host')
@functools.partial(checkpoint, policy=policy)
def f(x):
y = checkpoint_name(jnp.sin(y), "y")
z = checkpoint_name(jnp.sin(y), "z")
w = checkpoint_name(jnp.sin(z), "w")
return jnp.sum(w)
```
def g(ys, _):
y, _ = ys
y = checkpoint_name(jnp.sin(y), "y")
z = checkpoint_name(jnp.sin(y), "z")
z = z.T
w = checkpoint_name(jnp.sin(z), "w")
return (w.T, jnp.sum(w)), None
_, scan_out = jax.lax.scan(g, (x, np.array(1, dtype=np.float32)), [np_inp])[0]
return scan_out
```

The code defines a function `f` that uses checkpointing with a custom policy. Inside `f`, there is a nested function `g` that performs a series of computations using `jnp.sin` and checkpoint names. The `jax.lax.scan` function is used to apply `g` repeatedly over the input data.

#### List of policies

Expand Down

0 comments on commit 3a57b72

Please sign in to comment.