From dec64a574a196337bccff05f4850d7d0d548ea4d Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 15 May 2024 18:40:42 -0700 Subject: [PATCH] Replace deprecated `jax.tree_*` functions with `jax.tree.*` The top-level `jax.tree_*` aliases have long been deprecated, and will soon be removed. Alternate APIs are in `jax.tree_util`, with shorter aliases in the `jax.tree` submodule, added in JAX version 0.4.25. PiperOrigin-RevId: 634153264 --- precondition/distributed_shampoo.py | 72 +++++++++++++------------- precondition/sm3.py | 22 ++++---- precondition/tearfree/grafting.py | 12 ++--- precondition/tearfree/grafting_test.py | 2 +- precondition/tearfree/momentum.py | 2 +- precondition/tearfree/praxis_shim.py | 2 +- precondition/tearfree/reshaper.py | 8 +-- precondition/tearfree/reshaper_test.py | 4 +- precondition/tearfree/shampoo.py | 10 ++-- precondition/tearfree/sketchy_test.py | 10 ++-- 10 files changed, 72 insertions(+), 72 deletions(-) diff --git a/precondition/distributed_shampoo.py b/precondition/distributed_shampoo.py index 92884ab..8ad9655 100644 --- a/precondition/distributed_shampoo.py +++ b/precondition/distributed_shampoo.py @@ -444,7 +444,7 @@ def init_training_metrics( """Initialize TrainingMetrics, masked if disabled.""" if not generate_training_metrics: return optax.MaskedNode() - return jax.tree_map( + return jax.tree.map( functools.partial(jnp.repeat, repeats=num_statistics), default_training_metrics( generate_fd_metrics=generate_fd_metrics, @@ -462,7 +462,7 @@ def init_training_metrics_shapes( generate_training_metrics, generate_fd_metrics=generate_fd_metrics, ) - return jax.tree_map(lambda arr: [list(arr.shape), arr.dtype], seed) + return jax.tree.map(lambda arr: [list(arr.shape), arr.dtype], seed) def init_training_metrics_pspec( @@ -472,7 +472,7 @@ def init_training_metrics_pspec( """Initialize training metrics partition specification.""" if not generate_training_metrics: return optax.MaskedNode() - return jax.tree_map( + return jax.tree.map( lambda _: jax.sharding.PartitionSpec(), default_training_metrics( generate_fd_metrics=generate_fd_metrics, @@ -1810,7 +1810,7 @@ def _add_metrics_into_local_stats(local_stats, metrics, keep_old): index_start = int(local_stat.index_start) index_end = int(len(local_stat.sizes)) + index_start # pylint:disable=cell-var-from-loop Used immediately. - per_stat_metrics = jax.tree_map(lambda x: x[index_start:index_end], metrics) + per_stat_metrics = jax.tree.map(lambda x: x[index_start:index_end], metrics) # We don't want to update the metrics if we didn't do a new inverse p-th # root calculation to find a new preconditioner, so that TensorBoard curves # look consistent (otherwise they'd oscillate between NaN and measured @@ -2165,7 +2165,7 @@ def sharded_init_fn(params): Args: params: the parameters that should be updated. """ - params_flat, treedef = jax.tree_flatten(params) + params_flat, treedef = jax.tree.flatten(params) # Find max size to pad to. max_size = 0 for param in params_flat: @@ -2226,7 +2226,7 @@ def sharded_init_fn(params): index_start, sizes)) - local_stats = jax.tree_unflatten(treedef, local_stats_flat) + local_stats = jax.tree.unflatten(treedef, local_stats_flat) to_pad = -len(padded_statistics) % num_devices_for_pjit if max_size == 0: to_pad = num_devices_for_pjit @@ -2285,9 +2285,9 @@ def sharded_init_partition_spec_fn(params, params_partition_spec, partition_spec_for_statistics: PartitionSpec for the statistics. """ # Parallel lists of spec, and params. 
- param_pspec_flat, _ = jax.tree_flatten( + param_pspec_flat, _ = jax.tree.flatten( params_partition_spec, is_leaf=lambda x: x is None) - params_flat, treedef = jax.tree_flatten(params) + params_flat, treedef = jax.tree.flatten(params) assert param_pspec_flat assert params_flat # Step is replicated across cores. @@ -2332,7 +2332,7 @@ def sharded_init_partition_spec_fn(params, params_partition_spec, index_start, sizes)) - local_stats = jax.tree_unflatten(treedef, local_stats_flat) + local_stats = jax.tree.unflatten(treedef, local_stats_flat) global_stats = GlobalShardedParameterStats(partition_spec_for_statistics, # pytype: disable=wrong-arg-types # numpy-scalars partition_spec_for_statistics, jax.sharding.PartitionSpec()) @@ -2348,7 +2348,7 @@ def sharded_init_shape_and_dtype_fn(params): params: A pytree with params. """ # Parallel lists of spec, and params. - params_flat, treedef = jax.tree_flatten(params) + params_flat, treedef = jax.tree.flatten(params) assert params_flat # Step is replicated across cores. # None means cores. @@ -2394,7 +2394,7 @@ def sharded_init_shape_and_dtype_fn(params): sizes, )) - local_stats = jax.tree_unflatten(treedef, local_stats_flat) + local_stats = jax.tree.unflatten(treedef, local_stats_flat) max_statistics_size = _max_statistics_size_from_params(params_flat) to_pad = -num_statistics % num_devices_for_pjit num_statistics += to_pad @@ -2426,7 +2426,7 @@ def sharded_update_fn(grads, state, params): Returns: A tuple containing the new parameters and the new optimizer state. """ - params_flat, treedef = jax.tree_flatten(params) + params_flat, treedef = jax.tree.flatten(params) grads_flat = treedef.flatten_up_to(grads) global_stats = state.stats.global_stats @@ -2440,16 +2440,16 @@ def sharded_update_fn(grads, state, params): compression_rank, )) - new_stats_flat = jax.tree_map( + new_stats_flat = jax.tree.map( lambda g, s, p: _compute_stats(g, s, p, state.count), grads_flat, stats_flat, params_flat) - outputs = jax.tree_map( + outputs = jax.tree.map( lambda g, s, p: _transform_grad(g, s, p, state.count), grads_flat, new_stats_flat, params_flat) updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ()) - updates = jax.tree_unflatten(treedef, updates_flat) + updates = jax.tree.unflatten(treedef, updates_flat) new_local_stats_flat = [] for new_stat, local_stat in zip(new_stats_flat, local_stats_flat): new_local_stats_flat.append( @@ -2564,7 +2564,7 @@ def _update_preconditioners(): if generate_training_metrics: new_local_stats_flat = _add_metrics_into_local_stats( new_local_stats_flat, metrics, ~perform_step) - new_local_stats = jax.tree_unflatten(treedef, new_local_stats_flat) + new_local_stats = jax.tree.unflatten(treedef, new_local_stats_flat) errors = metrics.inverse_pth_root_errors errors = errors.reshape((-1, 1, 1)) predicate = jnp.logical_or( @@ -2622,7 +2622,7 @@ def _init(param): )) return ShampooState( - count=jnp.zeros([], jnp.int32), stats=jax.tree_map(_init, params)) + count=jnp.zeros([], jnp.int32), stats=jax.tree.map(_init, params)) def _skip_preconditioning(param): return len(param.shape) < skip_preconditioning_rank_lt or any( @@ -2876,7 +2876,7 @@ def _internal_inverse_pth_root_all(): preconditioners = jax.lax.all_gather(preconditioners, batch_axis_name) metrics = jax.lax.all_gather(metrics, batch_axis_name) preconditioners_flat = unbatch(preconditioners) - metrics_flat = jax.tree_map(unbatch, metrics) + metrics_flat = jax.tree.map(unbatch, metrics) else: preconditioners, metrics = _matrix_inverse_pth_root_vmap( 
all_statistics[0], @@ -2885,9 +2885,9 @@ def _internal_inverse_pth_root_all(): _maybe_ix(all_preconditioners, 0), ) preconditioners_flat = unbatch(jnp.stack([preconditioners])) - metrics = jax.tree_map( + metrics = jax.tree.map( functools.partial(jnp.expand_dims, axis=0), metrics) - metrics_flat = jax.tree_map(unbatch, metrics) + metrics_flat = jax.tree.map(unbatch, metrics) return preconditioners_flat, metrics_flat @@ -2916,7 +2916,7 @@ def _update_preconditioners(): s[:, :precond_dim(s.shape[0])] for s in packed_statistics ] n = len(packed_statistics) - metrics_init = jax.tree_map( + metrics_init = jax.tree.map( lambda x: [x] * n, default_training_metrics( generate_fd_metrics @@ -2973,12 +2973,12 @@ def _select_preconditioner(error, new_p, old_p): if generate_training_metrics: # pylint:disable=cell-var-from-loop Used immediately. - metrics_for_state = jax.tree_map( + metrics_for_state = jax.tree.map( lambda x: jnp.stack(x[idx:idx + num_statistics]), metrics_flat, is_leaf=lambda x: isinstance(x, list)) assert jax.tree_util.tree_all( - jax.tree_map(lambda x: len(state.statistics) == len(x), + jax.tree.map(lambda x: len(state.statistics) == len(x), metrics_for_state)) # If we skipped preconditioner computation, record old metrics. metrics_for_state = efficient_cond(perform_step, @@ -3123,7 +3123,7 @@ def _internal_inverse_pth_root_all(): quantized_preconditioners_flat = unbatch(quantized_preconditioners) quantized_diagonals_flat = unbatch(quantized_diagonals) quantized_bucket_sizes_flat = unbatch(quantized_bucket_sizes) - metrics_flat = jax.tree_map(unbatch, metrics) + metrics_flat = jax.tree.map(unbatch, metrics) return (quantized_preconditioners_flat, quantized_diagonals_flat, quantized_bucket_sizes_flat, metrics_flat) @@ -3155,7 +3155,7 @@ def _update_quantized_preconditioners(): quantized_diagonals_init = packed_quantized_diagonals quantized_bucket_sizes_init = packed_quantized_bucket_sizes n = len(quantized_preconditioners_init) - metrics_init = jax.tree_map( + metrics_init = jax.tree.map( lambda x: [x] * n, default_training_metrics( generate_fd_metrics @@ -3231,7 +3231,7 @@ def _select_preconditioner(error, new_p, old_p): if generate_training_metrics: # pylint:disable=cell-var-from-loop Used immediately. - metrics_for_state = jax.tree_map( + metrics_for_state = jax.tree.map( lambda x: jnp.stack(x[idx:idx + num_statistics]), metrics_flat, is_leaf=lambda x: isinstance(x, list)) @@ -3241,7 +3241,7 @@ def _select_preconditioner(error, new_p, old_p): assert len(state.statistics) == len(quantized_diagonals_for_state) assert len(state.statistics) == len(quantized_bucket_sizes_for_state) assert jax.tree_util.tree_all( - jax.tree_map(lambda x: len(state.statistics) == len(x), + jax.tree.map(lambda x: len(state.statistics) == len(x), metrics_for_state)) # If we skipped preconditioner computation, record old metrics. 
@@ -3333,7 +3333,7 @@ def split(batched_values): for v in jnp.split(batched_values, indices_or_sections=b1, axis=0) ] - return split(preconditioners), jax.tree_map(split, metrics) + return split(preconditioners), jax.tree.map(split, metrics) scheduled_preconditioning_compute_steps = ( decay_preconditioning_compute_steps @@ -3357,7 +3357,7 @@ def _update_preconditioners(): pd = precond_dim(max_size) preconditioners_init = [s[:, :pd] for s in padded_statistics] n = len(padded_statistics) - metrics_init = jax.tree_map( + metrics_init = jax.tree.map( lambda x: [x] * n, TrainingMetrics(inverse_pth_root_errors=inverse_failure_threshold)) init_state = [preconditioners_init, metrics_init] @@ -3411,12 +3411,12 @@ def _select_preconditioner(error, new_p, old_p): if generate_training_metrics: # pylint:disable=cell-var-from-loop Used immediately. - metrics_for_state = jax.tree_map( + metrics_for_state = jax.tree.map( lambda x: jnp.stack(x[idx:idx + num_statistics]), metrics_flat, is_leaf=functools.partial(isinstance, list)) assert jax.tree_util.tree_all( - jax.tree_map(lambda x: len(state.statistics) == len(x), + jax.tree.map(lambda x: len(state.statistics) == len(x), metrics_for_state)) # pylint:enable=cell-var-from-loop else: @@ -3636,24 +3636,24 @@ def update_fn(grads, state, params): Returns: A tuple containing the new parameters and the new optimizer state. """ - params_flat, treedef = jax.tree_flatten(params) + params_flat, treedef = jax.tree.flatten(params) stats_flat = treedef.flatten_up_to(state.stats) grads_flat = treedef.flatten_up_to(grads) stats_grads = grads_flat - new_stats_flat = jax.tree_map( + new_stats_flat = jax.tree.map( lambda g, s, p: _compute_stats(g, s, p, state.count), stats_grads, stats_flat, params_flat) new_stats_flat = _compute_preconditioners(new_stats_flat, params_flat, state.count) - outputs = jax.tree_map( + outputs = jax.tree.map( lambda g, s, p: _transform_grad(g, s, p, state.count), grads_flat, new_stats_flat, params_flat) updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ()) - updates = jax.tree_unflatten(treedef, updates_flat) - new_stats = jax.tree_unflatten(treedef, new_stats_flat) + updates = jax.tree.unflatten(treedef, updates_flat) + new_stats = jax.tree.unflatten(treedef, new_stats_flat) new_state = ShampooState(count=state.count + 1, stats=new_stats) return updates, new_state diff --git a/precondition/sm3.py b/precondition/sm3.py index 9387af8..d44ff74 100644 --- a/precondition/sm3.py +++ b/precondition/sm3.py @@ -77,7 +77,7 @@ def _init(param): return ParameterStats(accumulators, momentum) # pytype: disable=wrong-arg-types # numpy-scalars return SM3State( - count=jnp.zeros([], jnp.int32), stats=jax.tree_map(_init, params)) + count=jnp.zeros([], jnp.int32), stats=jax.tree.map(_init, params)) def _get_expanded_shape(shape, i): rank = len(shape) @@ -110,11 +110,11 @@ def _sketch_diagonal_statistics(grad, updated_diagonal_statistics): def update_fn(updates, state, params): stats = state.stats if normalize_grads: - updates = jax.tree_map( + updates = jax.tree.map( lambda g: g / (jnp.linalg.norm(g) + 1e-16), updates) # Reshape all vectors into N-d tensors to compute min over them. 
# [n], [m] -> [n, 1], [1, m] - expanded_diagonal_statistics = jax.tree_map( + expanded_diagonal_statistics = jax.tree.map( lambda grad, state: # pylint:disable=g-long-lambda [ jnp.reshape(state.diagonal_statistics[i], @@ -125,28 +125,28 @@ def update_fn(updates, state, params): stats) # Compute new diagonal statistics - new_diagonal_statistics = jax.tree_map(_moving_averages, updates, + new_diagonal_statistics = jax.tree.map(_moving_averages, updates, expanded_diagonal_statistics) # Compute preconditioners (1/sqrt(s)) where s is the statistics. - new_preconditioners = jax.tree_map( + new_preconditioners = jax.tree.map( lambda t: 1.0 / jnp.sqrt(t + diagonal_epsilon), new_diagonal_statistics) - preconditioned_grads = jax.tree_map(lambda g, p: g * p, updates, + preconditioned_grads = jax.tree.map(lambda g, p: g * p, updates, new_preconditioners) # Compute updated momentum (also handle quantization) - updated_momentum = jax.tree_map( + updated_momentum = jax.tree.map( lambda preconditioned_grad, state: # pylint:disable=g-long-lambda _moving_averages_momentum(preconditioned_grad, state.diagonal_momentum), preconditioned_grads, stats) # Update diagonal statistics. - updated_diagonal_statistics = jax.tree_map(_sketch_diagonal_statistics, + updated_diagonal_statistics = jax.tree.map(_sketch_diagonal_statistics, updates, new_diagonal_statistics) # Update momentum. - new_sm3_stats = jax.tree_map( + new_sm3_stats = jax.tree.map( lambda momentum, diagonal_stats: # pylint:disable=g-long-lambda ParameterStats(diagonal_stats, _quantize_momentum(momentum)), updated_momentum, @@ -155,14 +155,14 @@ def update_fn(updates, state, params): # Apply weight decay updated_momentum_with_wd = updated_momentum if weight_decay > 0.0: - updated_momentum_with_wd = jax.tree_map(lambda g, p: g + weight_decay * p, + updated_momentum_with_wd = jax.tree.map(lambda g, p: g + weight_decay * p, updated_momentum, params) lr = learning_rate if callable(learning_rate): lr = learning_rate(state.count) - new_updates = jax.tree_map(lambda pg: -lr * pg, updated_momentum_with_wd) + new_updates = jax.tree.map(lambda pg: -lr * pg, updated_momentum_with_wd) return new_updates, SM3State(count=state.count+1, stats=new_sm3_stats) return optax.GradientTransformation(init_fn, update_fn) diff --git a/precondition/tearfree/grafting.py b/precondition/tearfree/grafting.py index a607c89..18acb54 100644 --- a/precondition/tearfree/grafting.py +++ b/precondition/tearfree/grafting.py @@ -205,7 +205,7 @@ def _rmsprop(options: Options) -> praxis_shim.ShardedGradientTransformation: """Create RMSProp sharded gradient transform.""" def init_fn(params): - acc = jax.tree_map(jnp.zeros_like, params) + acc = jax.tree.map(jnp.zeros_like, params) return RMSPropAccumulator(acc=acc) def update_fn(updates, state, params=None): @@ -219,9 +219,9 @@ def ema(prev, new): else: return snew * (1 - second_moment_decay) + second_moment_decay * prev - new_state = RMSPropAccumulator(jax.tree_map(ema, state.acc, updates)) + new_state = RMSPropAccumulator(jax.tree.map(ema, state.acc, updates)) epsilon = options.epsilon - new_updates = jax.tree_map( + new_updates = jax.tree.map( lambda g, acc: g * jax.lax.rsqrt(acc + epsilon), updates, new_state.acc ) return new_updates, new_state @@ -232,7 +232,7 @@ def _opt_state_sharding_spec(var_hparams): s_var_hparams.init = None return s_var_hparams - mdl_sharding = jax.tree_map(_opt_state_sharding_spec, mdl_params) + mdl_sharding = jax.tree.map(_opt_state_sharding_spec, mdl_params) return RMSPropAccumulator(acc=mdl_sharding) return 
praxis_shim.ShardedGradientTransformation( @@ -291,7 +291,7 @@ def maybe_graft(graft_upd, base): graft_upd, ) - new_updates = jax.tree_map( + new_updates = jax.tree.map( maybe_graft, graft_updates, base_updates, is_leaf=_masked ) return new_updates, new_state @@ -334,7 +334,7 @@ def _maybe_mask(x: jax.Array): return _GraftMask() return x - return jax.tree_map(_maybe_mask, tree) + return jax.tree.map(_maybe_mask, tree) def _masked(tree_node: Any) -> bool: diff --git a/precondition/tearfree/grafting_test.py b/precondition/tearfree/grafting_test.py index 69cbeb3..0d866f0 100644 --- a/precondition/tearfree/grafting_test.py +++ b/precondition/tearfree/grafting_test.py @@ -30,7 +30,7 @@ def _minustwo() -> praxis_shim.ShardedGradientTransformation: """Generate a direction-reversing gradient transformation.""" - update = functools.partial(jax.tree_map, lambda x: -2 * x) + update = functools.partial(jax.tree.map, lambda x: -2 * x) return praxis_shim.ShardedGradientTransformation( lambda _: optax.EmptyState, lambda u, s, _: (update(u), s), diff --git a/precondition/tearfree/momentum.py b/precondition/tearfree/momentum.py index 93a241d..67b58c6 100644 --- a/precondition/tearfree/momentum.py +++ b/precondition/tearfree/momentum.py @@ -131,7 +131,7 @@ def _opt_state_sharding_spec(var_hparams): s_var_hparams.init = None return s_var_hparams - mdl_sharding = jax.tree_map(_opt_state_sharding_spec, mdl_params) + mdl_sharding = jax.tree.map(_opt_state_sharding_spec, mdl_params) return optax.TraceState(trace=mdl_sharding) return praxis_shim.ShardedGradientTransformation( diff --git a/precondition/tearfree/praxis_shim.py b/precondition/tearfree/praxis_shim.py index 3e7394d..96dd36b 100644 --- a/precondition/tearfree/praxis_shim.py +++ b/precondition/tearfree/praxis_shim.py @@ -61,7 +61,7 @@ def update_fn(updates, state, params=None): for s, fn in zip(state, args): updates, new_s = fn.update(updates, s, params) # Some of the new states may have None instead of optax.MaskedNode. 
- new_s = jax.tree_map( + new_s = jax.tree.map( lambda x: optax.MaskedNode() if x is None else x, new_s, is_leaf=lambda x: x is None, diff --git a/precondition/tearfree/reshaper.py b/precondition/tearfree/reshaper.py index db8bc8e..3e9efa8 100644 --- a/precondition/tearfree/reshaper.py +++ b/precondition/tearfree/reshaper.py @@ -103,8 +103,8 @@ def update( state: optax.MaskedNode, params: optax.Params, ) -> tuple[optax.Updates, optax.MaskedNode]: - shapes = jax.tree_map(functools.partial(_derive_shapes, options), params) - new_updates = jax.tree_map(_merge, updates, shapes) + shapes = jax.tree.map(functools.partial(_derive_shapes, options), params) + new_updates = jax.tree.map(_merge, updates, shapes) return new_updates, state return optax.GradientTransformation(lambda _: optax.MaskedNode(), update) @@ -126,8 +126,8 @@ def update( state: optax.MaskedNode, params: optax.Params, ) -> tuple[optax.Updates, optax.MaskedNode]: - shapes = jax.tree_map(functools.partial(_derive_shapes, options), params) - new_updates = jax.tree_map(_unmerge, updates, shapes) + shapes = jax.tree.map(functools.partial(_derive_shapes, options), params) + new_updates = jax.tree.map(_unmerge, updates, shapes) return new_updates, state return optax.GradientTransformation(lambda _: optax.MaskedNode(), update) diff --git a/precondition/tearfree/reshaper_test.py b/precondition/tearfree/reshaper_test.py index 94d27ad..f3c41a4 100644 --- a/precondition/tearfree/reshaper_test.py +++ b/precondition/tearfree/reshaper_test.py @@ -106,13 +106,13 @@ def test_tree(self): 1, ), } - init = jax.tree_map( + init = jax.tree.map( jnp.zeros, shapes, is_leaf=lambda x: isinstance(x, tuple) ) options = reshaper.Options(merge_dims=2, block_size=2) init_fn, update_fn = reshaper.merge(options) out, _ = update_fn(init, init_fn(None), init) - out_shapes = jax.tree_map(lambda x: tuple(x.shape), out) + out_shapes = jax.tree.map(lambda x: tuple(x.shape), out) expected_shapes = {'w': [[{'b': (4, 2)}]], 'z': (2,)} self.assertEqual(out_shapes, expected_shapes) diff --git a/precondition/tearfree/shampoo.py b/precondition/tearfree/shampoo.py index f5fff14..86526b1 100644 --- a/precondition/tearfree/shampoo.py +++ b/precondition/tearfree/shampoo.py @@ -283,11 +283,11 @@ def _update( lambda path, x: _blocks_metadata(options, x.shape, str(path)), updates ) blocks = state.blocks - blockified_updates = jax.tree_map(_blockify, updates, meta) + blockified_updates = jax.tree.map(_blockify, updates, meta) is_block = lambda x: isinstance(x, _AxesBlocks) stats_updated_blocks = functools.partial( - jax.tree_map, + jax.tree.map, functools.partial(_update_block_stats, options.second_moment_decay), blockified_updates, blocks, @@ -300,7 +300,7 @@ def _update( ) precond_updated_blocks = functools.partial( - jax.tree_map, + jax.tree.map, _update_block_precond, blocks, meta, @@ -313,10 +313,10 @@ def _update( should_update_precond, precond_updated_blocks, lambda: blocks ) new_state = _ShampooState(count=state.count + 1, blocks=blocks) - new_updates = jax.tree_map( + new_updates = jax.tree.map( _precondition_blocks, blockified_updates, blocks, meta, is_leaf=is_block ) - new_updates = jax.tree_map(_deblockify, new_updates, meta) + new_updates = jax.tree.map(_deblockify, new_updates, meta) return new_updates, new_state diff --git a/precondition/tearfree/sketchy_test.py b/precondition/tearfree/sketchy_test.py index 3a71835..a39d16b 100644 --- a/precondition/tearfree/sketchy_test.py +++ b/precondition/tearfree/sketchy_test.py @@ -154,14 +154,14 @@ def test_realloc(self): 'c': 
{'d': [4], 'e': [8]}, } tx = sketchy.apply(sketchy.Options(memory_alloc=memory_dict)) - shape = jax.tree_map( + shape = jax.tree.map( lambda x: (dim,), memory_dict, is_leaf=lambda x: isinstance(x, list) and all(not isinstance(y, list) for y in x), ) grads_tree, updates = self._unroll(tx, nsteps, shape, None, True) - emw_run = jax.tree_map( + emw_run = jax.tree.map( lambda k, sp, grad: self._unroll( tx=sketchy.apply(sketchy.Options(rank=k[0])), n=nsteps, @@ -174,7 +174,7 @@ def test_realloc(self): is_leaf=lambda x: isinstance(x, list) and all(not isinstance(y, list) for y in x), ) - jax.tree_map(np.testing.assert_allclose, updates, emw_run) + jax.tree.map(np.testing.assert_allclose, updates, emw_run) # test ekfac resulting preconditioned gradient on random values are finite def test_ekfac(self): @@ -287,14 +287,14 @@ def _make_cov(sketch: sketchy._AxisState, add_tail=True): def _unroll(self, tx, n, shape, grads=None, return_grads=False): """Generate states and grad updates n times.""" rng = jax.random.PRNGKey(0) - params = jax.tree_map( + params = jax.tree.map( jnp.zeros, shape, is_leaf=lambda x: isinstance(x, tuple) and all(isinstance(y, int) for y in x), ) if grads is None: - grads = jax.tree_map( + grads = jax.tree.map( lambda sp: jax.random.normal(rng, (n, *sp)), shape, is_leaf=lambda x: isinstance(x, tuple)
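
As a minimal sketch of the rename applied throughout this patch (assuming JAX 0.4.25 or newer, where the short `jax.tree` aliases were added; the `params` pytree below is illustrative only, not taken from the precondition codebase):

    import jax
    import jax.numpy as jnp

    params = {"w": jnp.ones((2, 3)), "b": jnp.zeros((3,))}

    # Deprecated top-level aliases being removed:
    #   scaled = jax.tree_map(lambda x: 0.1 * x, params)
    #   leaves, treedef = jax.tree_flatten(params)
    #   restored = jax.tree_unflatten(treedef, leaves)

    # Short aliases in the jax.tree submodule, used throughout this patch:
    scaled = jax.tree.map(lambda x: 0.1 * x, params)
    leaves, treedef = jax.tree.flatten(params)
    restored = jax.tree.unflatten(treedef, leaves)

    # Equivalent long-form APIs in jax.tree_util, which also work on
    # JAX versions older than 0.4.25:
    scaled_util = jax.tree_util.tree_map(lambda x: 0.1 * x, params)
    assert jax.tree_util.tree_all(
        jax.tree.map(lambda a, b: jnp.allclose(a, b), scaled, scaled_util)
    )

On JAX versions older than 0.4.25 only the `jax.tree_util` long forms are available, which is why the commit message points to them as the alternate API.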