Skip to content

Commit

Permalink
Replace deprecated jax.tree_* functions with jax.tree.*
Browse files Browse the repository at this point in the history
The top-level `jax.tree_*` aliases have long been deprecated, and will soon be removed. Alternate APIs are in `jax.tree_util`, with shorter aliases in the `jax.tree` submodule, added in JAX version 0.4.25.

PiperOrigin-RevId: 633690677
  • Loading branch information
Jake VanderPlas authored and copybara-github committed May 21, 2024
1 parent 143bd26 commit 7b30c3c
Show file tree
Hide file tree
Showing 5 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion vit_jax/input_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def _shard(data):
def prefetch(dataset, n_prefetch):
"""Prefetches data to device and converts to numpy array."""
ds_iter = iter(dataset)
ds_iter = map(lambda x: jax.tree_map(lambda t: np.asarray(memoryview(t)), x),
ds_iter = map(lambda x: jax.tree.map(lambda t: np.asarray(memoryview(t)), x),
ds_iter)
if n_prefetch:
ds_iter = flax.jax_utils.prefetch_to_device(ds_iter, n_prefetch)
Expand Down
2 changes: 1 addition & 1 deletion vit_jax/models_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def test_can_instantiate(self, name, size):
self.assertEqual((2, 196, 1000), outputs.shape)
else:
self.assertEqual((2, 1000), outputs.shape)
param_count = sum(p.size for p in jax.tree_flatten(variables)[0])
param_count = sum(p.size for p in jax.tree.flatten(variables)[0])
self.assertEqual(
size, param_count,
f'Expected {name} to have {size} params, found {param_count}.')
Expand Down
2 changes: 1 addition & 1 deletion vit_jax/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def _tree_flatten_with_names(tree):
Returns:
A list of values with names: [(name, value), ...]
"""
vals, tree_def = jax.tree_flatten(tree)
vals, tree_def = jax.tree.flatten(tree)

# "Fake" token tree that is use to track jax internal tree traversal and
# adjust our custom tree traversal to be compatible with it.
Expand Down
2 changes: 1 addition & 1 deletion vit_jax/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def loss_fn(params, images, labels):
l, g = utils.accumulate_gradient(
jax.value_and_grad(loss_fn), params, batch['image'], batch['label'],
accum_steps)
g = jax.tree_map(lambda x: jax.lax.pmean(x, axis_name='batch'), g)
g = jax.tree.map(lambda x: jax.lax.pmean(x, axis_name='batch'), g)
updates, opt_state = tx.update(g, opt_state)
params = optax.apply_updates(params, updates)
l = jax.lax.pmean(l, axis_name='batch')
Expand Down
4 changes: 2 additions & 2 deletions vit_jax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,9 @@ def acc_grad_and_loss(i, l_and_g):
(step_size, labels.shape[1]))
li, gi = loss_and_grad_fn(params, imgs, lbls)
l, g = l_and_g
return (l + li, jax.tree_map(lambda x, y: x + y, g, gi))
return (l + li, jax.tree.map(lambda x, y: x + y, g, gi))

l, g = jax.lax.fori_loop(1, accum_steps, acc_grad_and_loss, (l, g))
return jax.tree_map(lambda x: x / accum_steps, (l, g))
return jax.tree.map(lambda x: x / accum_steps, (l, g))
else:
return loss_and_grad_fn(params, images, labels)

0 comments on commit 7b30c3c

Please sign in to comment.