Replace deprecated jax.tree_* functions with jax.tree.*
The top-level `jax.tree_*` aliases have long been deprecated, and will soon be removed. Alternate APIs are in `jax.tree_util`, with shorter aliases in the `jax.tree` submodule, added in JAX version 0.4.25.

PiperOrigin-RevId: 634153264
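
For anyone applying the same migration elsewhere, here is a minimal sketch of the substitution this commit performs. It assumes JAX >= 0.4.25 (where the `jax.tree` aliases exist); the pytree and variable names below are illustrative only and do not come from this repository.

```python
import jax
import jax.numpy as jnp

# Illustrative pytree (not from this repository).
params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}

# Deprecated top-level aliases that this commit removes:
#   leaves, treedef = jax.tree_flatten(params)
#   scaled = jax.tree_map(lambda x: 0.5 * x, params)
#   restored = jax.tree_unflatten(treedef, leaves)

# Replacement spellings used in this commit (short aliases for jax.tree_util.*):
leaves, treedef = jax.tree.flatten(params)        # jax.tree_util.tree_flatten
scaled = jax.tree.map(lambda x: 0.5 * x, params)  # jax.tree_util.tree_map
restored = jax.tree.unflatten(treedef, leaves)    # jax.tree_util.tree_unflatten
```

If a codebase must also support JAX releases older than 0.4.25, the longer `jax.tree_util.tree_*` spellings are the safer target, since the `jax.tree` submodule is not available there.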
Jake VanderPlas authored and The precondition Authors committed May 16, 2024
1 parent 78e8a4f commit dec64a5
Showing 10 changed files with 72 additions and 72 deletions.
72 changes: 36 additions & 36 deletions precondition/distributed_shampoo.py
@@ -444,7 +444,7 @@ def init_training_metrics(
"""Initialize TrainingMetrics, masked if disabled."""
if not generate_training_metrics:
return optax.MaskedNode()
- return jax.tree_map(
+ return jax.tree.map(
functools.partial(jnp.repeat, repeats=num_statistics),
default_training_metrics(
generate_fd_metrics=generate_fd_metrics,
@@ -462,7 +462,7 @@ def init_training_metrics_shapes(
generate_training_metrics,
generate_fd_metrics=generate_fd_metrics,
)
- return jax.tree_map(lambda arr: [list(arr.shape), arr.dtype], seed)
+ return jax.tree.map(lambda arr: [list(arr.shape), arr.dtype], seed)


def init_training_metrics_pspec(
@@ -472,7 +472,7 @@ def init_training_metrics_pspec(
"""Initialize training metrics partition specification."""
if not generate_training_metrics:
return optax.MaskedNode()
- return jax.tree_map(
+ return jax.tree.map(
lambda _: jax.sharding.PartitionSpec(),
default_training_metrics(
generate_fd_metrics=generate_fd_metrics,
@@ -1810,7 +1810,7 @@ def _add_metrics_into_local_stats(local_stats, metrics, keep_old):
index_start = int(local_stat.index_start)
index_end = int(len(local_stat.sizes)) + index_start
# pylint:disable=cell-var-from-loop Used immediately.
- per_stat_metrics = jax.tree_map(lambda x: x[index_start:index_end], metrics)
+ per_stat_metrics = jax.tree.map(lambda x: x[index_start:index_end], metrics)
# We don't want to update the metrics if we didn't do a new inverse p-th
# root calculation to find a new preconditioner, so that TensorBoard curves
# look consistent (otherwise they'd oscillate between NaN and measured
@@ -2165,7 +2165,7 @@ def sharded_init_fn(params):
Args:
params: the parameters that should be updated.
"""
- params_flat, treedef = jax.tree_flatten(params)
+ params_flat, treedef = jax.tree.flatten(params)
# Find max size to pad to.
max_size = 0
for param in params_flat:
@@ -2226,7 +2226,7 @@ def sharded_init_fn(params):
index_start,
sizes))

- local_stats = jax.tree_unflatten(treedef, local_stats_flat)
+ local_stats = jax.tree.unflatten(treedef, local_stats_flat)
to_pad = -len(padded_statistics) % num_devices_for_pjit
if max_size == 0:
to_pad = num_devices_for_pjit
@@ -2285,9 +2285,9 @@ def sharded_init_partition_spec_fn(params, params_partition_spec,
partition_spec_for_statistics: PartitionSpec for the statistics.
"""
# Parallel lists of spec, and params.
- param_pspec_flat, _ = jax.tree_flatten(
+ param_pspec_flat, _ = jax.tree.flatten(
params_partition_spec, is_leaf=lambda x: x is None)
- params_flat, treedef = jax.tree_flatten(params)
+ params_flat, treedef = jax.tree.flatten(params)
assert param_pspec_flat
assert params_flat
# Step is replicated across cores.
@@ -2332,7 +2332,7 @@ def sharded_init_partition_spec_fn(params, params_partition_spec,
index_start,
sizes))

- local_stats = jax.tree_unflatten(treedef, local_stats_flat)
+ local_stats = jax.tree.unflatten(treedef, local_stats_flat)
global_stats = GlobalShardedParameterStats(partition_spec_for_statistics, # pytype: disable=wrong-arg-types # numpy-scalars
partition_spec_for_statistics,
jax.sharding.PartitionSpec())
@@ -2348,7 +2348,7 @@ def sharded_init_shape_and_dtype_fn(params):
params: A pytree with params.
"""
# Parallel lists of spec, and params.
- params_flat, treedef = jax.tree_flatten(params)
+ params_flat, treedef = jax.tree.flatten(params)
assert params_flat
# Step is replicated across cores.
# None means cores.
@@ -2394,7 +2394,7 @@ def sharded_init_shape_and_dtype_fn(params):
sizes,
))

- local_stats = jax.tree_unflatten(treedef, local_stats_flat)
+ local_stats = jax.tree.unflatten(treedef, local_stats_flat)
max_statistics_size = _max_statistics_size_from_params(params_flat)
to_pad = -num_statistics % num_devices_for_pjit
num_statistics += to_pad
@@ -2426,7 +2426,7 @@ def sharded_update_fn(grads, state, params):
Returns:
A tuple containing the new parameters and the new optimizer state.
"""
- params_flat, treedef = jax.tree_flatten(params)
+ params_flat, treedef = jax.tree.flatten(params)
grads_flat = treedef.flatten_up_to(grads)

global_stats = state.stats.global_stats
@@ -2440,16 +2440,16 @@ def sharded_update_fn(grads, state, params):
compression_rank,
))

- new_stats_flat = jax.tree_map(
+ new_stats_flat = jax.tree.map(
lambda g, s, p: _compute_stats(g, s, p, state.count), grads_flat,
stats_flat, params_flat)

- outputs = jax.tree_map(
+ outputs = jax.tree.map(
lambda g, s, p: _transform_grad(g, s, p, state.count), grads_flat,
new_stats_flat, params_flat)
updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ())

- updates = jax.tree_unflatten(treedef, updates_flat)
+ updates = jax.tree.unflatten(treedef, updates_flat)
new_local_stats_flat = []
for new_stat, local_stat in zip(new_stats_flat, local_stats_flat):
new_local_stats_flat.append(
@@ -2564,7 +2564,7 @@ def _update_preconditioners():
if generate_training_metrics:
new_local_stats_flat = _add_metrics_into_local_stats(
new_local_stats_flat, metrics, ~perform_step)
- new_local_stats = jax.tree_unflatten(treedef, new_local_stats_flat)
+ new_local_stats = jax.tree.unflatten(treedef, new_local_stats_flat)
errors = metrics.inverse_pth_root_errors
errors = errors.reshape((-1, 1, 1))
predicate = jnp.logical_or(
@@ -2622,7 +2622,7 @@ def _init(param):
))

return ShampooState(
- count=jnp.zeros([], jnp.int32), stats=jax.tree_map(_init, params))
+ count=jnp.zeros([], jnp.int32), stats=jax.tree.map(_init, params))

def _skip_preconditioning(param):
return len(param.shape) < skip_preconditioning_rank_lt or any(
@@ -2876,7 +2876,7 @@ def _internal_inverse_pth_root_all():
preconditioners = jax.lax.all_gather(preconditioners, batch_axis_name)
metrics = jax.lax.all_gather(metrics, batch_axis_name)
preconditioners_flat = unbatch(preconditioners)
- metrics_flat = jax.tree_map(unbatch, metrics)
+ metrics_flat = jax.tree.map(unbatch, metrics)
else:
preconditioners, metrics = _matrix_inverse_pth_root_vmap(
all_statistics[0],
@@ -2885,9 +2885,9 @@ def _internal_inverse_pth_root_all():
_maybe_ix(all_preconditioners, 0),
)
preconditioners_flat = unbatch(jnp.stack([preconditioners]))
- metrics = jax.tree_map(
+ metrics = jax.tree.map(
functools.partial(jnp.expand_dims, axis=0), metrics)
- metrics_flat = jax.tree_map(unbatch, metrics)
+ metrics_flat = jax.tree.map(unbatch, metrics)

return preconditioners_flat, metrics_flat

@@ -2916,7 +2916,7 @@ def _update_preconditioners():
s[:, :precond_dim(s.shape[0])] for s in packed_statistics
]
n = len(packed_statistics)
- metrics_init = jax.tree_map(
+ metrics_init = jax.tree.map(
lambda x: [x] * n,
default_training_metrics(
generate_fd_metrics
@@ -2973,12 +2973,12 @@ def _select_preconditioner(error, new_p, old_p):

if generate_training_metrics:
# pylint:disable=cell-var-from-loop Used immediately.
- metrics_for_state = jax.tree_map(
+ metrics_for_state = jax.tree.map(
lambda x: jnp.stack(x[idx:idx + num_statistics]),
metrics_flat,
is_leaf=lambda x: isinstance(x, list))
assert jax.tree_util.tree_all(
- jax.tree_map(lambda x: len(state.statistics) == len(x),
+ jax.tree.map(lambda x: len(state.statistics) == len(x),
metrics_for_state))
# If we skipped preconditioner computation, record old metrics.
metrics_for_state = efficient_cond(perform_step,
@@ -3123,7 +3123,7 @@ def _internal_inverse_pth_root_all():
quantized_preconditioners_flat = unbatch(quantized_preconditioners)
quantized_diagonals_flat = unbatch(quantized_diagonals)
quantized_bucket_sizes_flat = unbatch(quantized_bucket_sizes)
- metrics_flat = jax.tree_map(unbatch, metrics)
+ metrics_flat = jax.tree.map(unbatch, metrics)
return (quantized_preconditioners_flat, quantized_diagonals_flat,
quantized_bucket_sizes_flat, metrics_flat)

@@ -3155,7 +3155,7 @@ def _update_quantized_preconditioners():
quantized_diagonals_init = packed_quantized_diagonals
quantized_bucket_sizes_init = packed_quantized_bucket_sizes
n = len(quantized_preconditioners_init)
- metrics_init = jax.tree_map(
+ metrics_init = jax.tree.map(
lambda x: [x] * n,
default_training_metrics(
generate_fd_metrics
@@ -3231,7 +3231,7 @@ def _select_preconditioner(error, new_p, old_p):

if generate_training_metrics:
# pylint:disable=cell-var-from-loop Used immediately.
- metrics_for_state = jax.tree_map(
+ metrics_for_state = jax.tree.map(
lambda x: jnp.stack(x[idx:idx + num_statistics]),
metrics_flat,
is_leaf=lambda x: isinstance(x, list))
@@ -3241,7 +3241,7 @@ def _select_preconditioner(error, new_p, old_p):
assert len(state.statistics) == len(quantized_diagonals_for_state)
assert len(state.statistics) == len(quantized_bucket_sizes_for_state)
assert jax.tree_util.tree_all(
- jax.tree_map(lambda x: len(state.statistics) == len(x),
+ jax.tree.map(lambda x: len(state.statistics) == len(x),
metrics_for_state))

# If we skipped preconditioner computation, record old metrics.
@@ -3333,7 +3333,7 @@ def split(batched_values):
for v in jnp.split(batched_values, indices_or_sections=b1, axis=0)
]

- return split(preconditioners), jax.tree_map(split, metrics)
+ return split(preconditioners), jax.tree.map(split, metrics)

scheduled_preconditioning_compute_steps = (
decay_preconditioning_compute_steps
@@ -3357,7 +3357,7 @@ def _update_preconditioners():
pd = precond_dim(max_size)
preconditioners_init = [s[:, :pd] for s in padded_statistics]
n = len(padded_statistics)
- metrics_init = jax.tree_map(
+ metrics_init = jax.tree.map(
lambda x: [x] * n,
TrainingMetrics(inverse_pth_root_errors=inverse_failure_threshold))
init_state = [preconditioners_init, metrics_init]
@@ -3411,12 +3411,12 @@ def _select_preconditioner(error, new_p, old_p):

if generate_training_metrics:
# pylint:disable=cell-var-from-loop Used immediately.
- metrics_for_state = jax.tree_map(
+ metrics_for_state = jax.tree.map(
lambda x: jnp.stack(x[idx:idx + num_statistics]),
metrics_flat,
is_leaf=functools.partial(isinstance, list))
assert jax.tree_util.tree_all(
- jax.tree_map(lambda x: len(state.statistics) == len(x),
+ jax.tree.map(lambda x: len(state.statistics) == len(x),
metrics_for_state))
# pylint:enable=cell-var-from-loop
else:
@@ -3636,24 +3636,24 @@ def update_fn(grads, state, params):
Returns:
A tuple containing the new parameters and the new optimizer state.
"""
- params_flat, treedef = jax.tree_flatten(params)
+ params_flat, treedef = jax.tree.flatten(params)
stats_flat = treedef.flatten_up_to(state.stats)
grads_flat = treedef.flatten_up_to(grads)
stats_grads = grads_flat

- new_stats_flat = jax.tree_map(
+ new_stats_flat = jax.tree.map(
lambda g, s, p: _compute_stats(g, s, p, state.count), stats_grads,
stats_flat, params_flat)

new_stats_flat = _compute_preconditioners(new_stats_flat, params_flat,
state.count)
- outputs = jax.tree_map(
+ outputs = jax.tree.map(
lambda g, s, p: _transform_grad(g, s, p, state.count), grads_flat,
new_stats_flat, params_flat)
updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ())

- updates = jax.tree_unflatten(treedef, updates_flat)
- new_stats = jax.tree_unflatten(treedef, new_stats_flat)
+ updates = jax.tree.unflatten(treedef, updates_flat)
+ new_stats = jax.tree.unflatten(treedef, new_stats_flat)

new_state = ShampooState(count=state.count + 1, stats=new_stats)
return updates, new_state
22 changes: 11 additions & 11 deletions precondition/sm3.py
@@ -77,7 +77,7 @@ def _init(param):
return ParameterStats(accumulators, momentum) # pytype: disable=wrong-arg-types # numpy-scalars

return SM3State(
- count=jnp.zeros([], jnp.int32), stats=jax.tree_map(_init, params))
+ count=jnp.zeros([], jnp.int32), stats=jax.tree.map(_init, params))

def _get_expanded_shape(shape, i):
rank = len(shape)
@@ -110,11 +110,11 @@ def _sketch_diagonal_statistics(grad, updated_diagonal_statistics):
def update_fn(updates, state, params):
stats = state.stats
if normalize_grads:
- updates = jax.tree_map(
+ updates = jax.tree.map(
lambda g: g / (jnp.linalg.norm(g) + 1e-16), updates)
# Reshape all vectors into N-d tensors to compute min over them.
# [n], [m] -> [n, 1], [1, m]
- expanded_diagonal_statistics = jax.tree_map(
+ expanded_diagonal_statistics = jax.tree.map(
lambda grad, state: # pylint:disable=g-long-lambda
[
jnp.reshape(state.diagonal_statistics[i],
@@ -125,28 +125,28 @@ def update_fn(updates, state, params):
stats)

# Compute new diagonal statistics
- new_diagonal_statistics = jax.tree_map(_moving_averages, updates,
+ new_diagonal_statistics = jax.tree.map(_moving_averages, updates,
expanded_diagonal_statistics)

# Compute preconditioners (1/sqrt(s)) where s is the statistics.
- new_preconditioners = jax.tree_map(
+ new_preconditioners = jax.tree.map(
lambda t: 1.0 / jnp.sqrt(t + diagonal_epsilon), new_diagonal_statistics)
- preconditioned_grads = jax.tree_map(lambda g, p: g * p, updates,
+ preconditioned_grads = jax.tree.map(lambda g, p: g * p, updates,
new_preconditioners)

# Compute updated momentum (also handle quantization)
- updated_momentum = jax.tree_map(
+ updated_momentum = jax.tree.map(
lambda preconditioned_grad, state: # pylint:disable=g-long-lambda
_moving_averages_momentum(preconditioned_grad, state.diagonal_momentum),
preconditioned_grads,
stats)

# Update diagonal statistics.
- updated_diagonal_statistics = jax.tree_map(_sketch_diagonal_statistics,
+ updated_diagonal_statistics = jax.tree.map(_sketch_diagonal_statistics,
updates, new_diagonal_statistics)

# Update momentum.
- new_sm3_stats = jax.tree_map(
+ new_sm3_stats = jax.tree.map(
lambda momentum, diagonal_stats: # pylint:disable=g-long-lambda
ParameterStats(diagonal_stats, _quantize_momentum(momentum)),
updated_momentum,
Expand All @@ -155,14 +155,14 @@ def update_fn(updates, state, params):
# Apply weight decay
updated_momentum_with_wd = updated_momentum
if weight_decay > 0.0:
- updated_momentum_with_wd = jax.tree_map(lambda g, p: g + weight_decay * p,
+ updated_momentum_with_wd = jax.tree.map(lambda g, p: g + weight_decay * p,
updated_momentum, params)

lr = learning_rate
if callable(learning_rate):
lr = learning_rate(state.count)

- new_updates = jax.tree_map(lambda pg: -lr * pg, updated_momentum_with_wd)
+ new_updates = jax.tree.map(lambda pg: -lr * pg, updated_momentum_with_wd)
return new_updates, SM3State(count=state.count+1, stats=new_sm3_stats)

return optax.GradientTransformation(init_fn, update_fn)