diff --git a/MaxText/convert_gemma_chkpt.py b/MaxText/convert_gemma_chkpt.py
index c690130c2..75fdd7730 100644
--- a/MaxText/convert_gemma_chkpt.py
+++ b/MaxText/convert_gemma_chkpt.py
@@ -143,7 +143,7 @@ def main(raw_args=None) -> None:
     layer_weight["self_attention"] = copy.deepcopy(self_attention)
 
   jax_weights["decoder"]["layers"] = copy.deepcopy(layer_weight)
-  jax_weights = jax.tree_map(jnp.array, jax_weights)
+  jax_weights = jax.tree_util.tree_map(jnp.array, jax_weights)
 
   def astype_fn(x):
     if isinstance(x, jnp.ndarray):
@@ -151,7 +151,7 @@ def astype_fn(x):
     else:
       return x
 
-  jax_weights = jax.tree_map(astype_fn, jax_weights)
+  jax_weights = jax.tree_util.tree_map(astype_fn, jax_weights)
 
   enable_checkpointing = True
   async_checkpointing = False
diff --git a/MaxText/generate_param_only_checkpoint.py b/MaxText/generate_param_only_checkpoint.py
index 673b3ce51..68270e070 100644
--- a/MaxText/generate_param_only_checkpoint.py
+++ b/MaxText/generate_param_only_checkpoint.py
@@ -54,13 +54,13 @@ def _possibly_unroll_params(config, training_state, training_state_annotations,
   def new_pspec(x):
     return jax.sharding.PartitionSpec(*x[0 : config.param_scan_axis] + x[config.param_scan_axis + 1 :])
 
-  new_per_layer_state_annotation = jax.tree_map(new_pspec, training_state_annotations_layers)
-  new_per_layer_state_sharding = jax.tree_map(lambda x: jax.sharding.NamedSharding(mesh, x), new_per_layer_state_annotation)
+  new_per_layer_state_annotation = jax.tree_util.tree_map(new_pspec, training_state_annotations_layers)
+  new_per_layer_state_sharding = jax.tree_util.tree_map(lambda x: jax.sharding.NamedSharding(mesh, x), new_per_layer_state_annotation)
 
   for i in range(config.num_decoder_layers):
 
     def slice_ith(input_layers):
-      return jax.tree_map(lambda x: jax.numpy.take(x, i, axis=config.param_scan_axis), input_layers)
+      return jax.tree_util.tree_map(lambda x: jax.numpy.take(x, i, axis=config.param_scan_axis), input_layers)
 
     new_layer = jax.jit(slice_ith, out_shardings=new_per_layer_state_sharding)(training_state_layers)
 
@@ -70,7 +70,7 @@ def slice_ith(input_layers):
   del training_state.params["params"]["decoder"]["layers"]
   del training_state_annotations.params["params"]["decoder"]["layers"]
 
-  jax.tree_map(lambda x: x.delete(), training_state_layers)
+  jax.tree_util.tree_map(lambda x: x.delete(), training_state_layers)
 
 
 def _read_train_checkpoint(config, checkpoint_manager, mesh):
@@ -90,7 +90,7 @@ def _read_train_checkpoint(config, checkpoint_manager, mesh):
 def _save_decode_checkpoint(config, state, checkpoint_manager):
   """Generate checkpoint for decode from the training_state."""
   with jax.spmd_mode("allow_all"):
-    decode_state = max_utils.init_decode_state(None, jax.tree_map(lambda x: x.astype(jax.numpy.bfloat16), state.params))
+    decode_state = max_utils.init_decode_state(None, jax.tree_util.tree_map(lambda x: x.astype(jax.numpy.bfloat16), state.params))
   if checkpoint_manager is not None:
     if save_checkpoint(checkpoint_manager, 0, decode_state):
       max_logging.log(f"saved an decode checkpoint at {config.checkpoint_dir}")
diff --git a/MaxText/input_pipeline/input_pipeline_interface.py b/MaxText/input_pipeline/input_pipeline_interface.py
index 81d3eaa75..492484431 100644
--- a/MaxText/input_pipeline/input_pipeline_interface.py
+++ b/MaxText/input_pipeline/input_pipeline_interface.py
@@ -98,7 +98,7 @@ def __init__(self, config, mesh):
     self.mesh = mesh
     self.config = config
     data_pspec = P(*config.data_sharding)
-    data_pspec_shardings = jax.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec)
+    data_pspec_shardings = jax.tree_util.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec)
     self.data_generator = jax.jit(
         SyntheticDataIterator.raw_generate_synthetic_data, out_shardings=data_pspec_shardings, static_argnums=0
     )
diff --git a/MaxText/llama_or_mistral_ckpt.py b/MaxText/llama_or_mistral_ckpt.py
index 97a456764..9f3cc94ab 100644
--- a/MaxText/llama_or_mistral_ckpt.py
+++ b/MaxText/llama_or_mistral_ckpt.py
@@ -323,7 +323,7 @@ def checkpoint_device_put(arr):
     return jax.device_put(arr, device=s3)
 
   # convert all weights to jax.numpy with sharding if applicable
-  jax_weights = jax.tree_map(checkpoint_device_put, jax_weights)
+  jax_weights = jax.tree_util.tree_map(checkpoint_device_put, jax_weights)
 
   # dummy configs for the checkpoint_manager
   step_number_to_save_new_ckpt = 0
diff --git a/MaxText/max_utils.py b/MaxText/max_utils.py
index cce01bfee..3142f59f4 100644
--- a/MaxText/max_utils.py
+++ b/MaxText/max_utils.py
@@ -49,7 +49,7 @@ def find_nans_and_infs(pytree):
   def finder(x):
     return jnp.any(jnp.isinf(x) | jnp.isnan(x))
 
-  bad_pytree = jax.tree_map(finder, pytree)
+  bad_pytree = jax.tree_util.tree_map(finder, pytree)
   return jax.tree_util.tree_flatten(bad_pytree)
 
 
@@ -660,7 +660,7 @@ def delete_leaf(leaf):
       leaf.delete()
     del leaf
 
-  jax.tree_map(delete_leaf, p)
+  jax.tree_util.tree_map(delete_leaf, p)
 
 
 def summarize_pytree_data(params, name="Params", raw=False):
diff --git a/MaxText/maxengine.py b/MaxText/maxengine.py
index ed38929e2..7634eb8e5 100644
--- a/MaxText/maxengine.py
+++ b/MaxText/maxengine.py
@@ -79,14 +79,14 @@ def load_params(self, *args, **kwargs) -> Params:
     """Load Parameters, typically from GCS"""
     # pylint: disable=unused-argument
     state, self.state_mesh_annotations = max_utils.setup_decode_state(self.model, self.config, self.rng, self._mesh, None)
-    self.abstract_params = jax.tree_map(
+    self.abstract_params = jax.tree_util.tree_map(
         lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding), state.params
     )
     self.kv_cache_annotations = max_utils.get_kv_cache_annotations(self.model, self.config, self.rng, self._mesh)
-    self.kv_cache_shardings = jax.tree_map(lambda x: jax.sharding.NamedSharding(self._mesh, x), self.kv_cache_annotations)
+    self.kv_cache_shardings = jax.tree_util.tree_map(lambda x: jax.sharding.NamedSharding(self._mesh, x), self.kv_cache_annotations)
 
     if not self.model.quant:
-      self.abstract_params = jax.tree_map(
+      self.abstract_params = jax.tree_util.tree_map(
          lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding), state.params
       )
       return state.params
@@ -113,7 +113,7 @@ def model_apply(_p, _rng):
 
       # Remove param values which have corresponding qtensors in aqt to save memory.
       params["params"] = quantizations.remove_quantized_params(state.params["params"], new_vars["aqt"])
-      self.abstract_params = jax.tree_map(
+      self.abstract_params = jax.tree_util.tree_map(
          lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding), params
       )
 
@@ -342,13 +342,13 @@ def init(abstract_params):
     with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
       mesh_annotations = nn.logical_to_mesh(logical_annotations)
 
-      shardings = jax.tree_map(
+      shardings = jax.tree_util.tree_map(
          lambda mesh_annotation: jax.sharding.NamedSharding(self._mesh, mesh_annotation), mesh_annotations
       )
 
       @functools.partial(jax.jit, out_shardings=shardings)
       def initialize():
-        return jax.tree_map(lambda x: jnp.zeros(x.shape, x.dtype), abstract_outputs)
+        return jax.tree_util.tree_map(lambda x: jnp.zeros(x.shape, x.dtype), abstract_outputs)
 
       cache = initialize()["cache"]
diff --git a/MaxText/maxtext_utils.py b/MaxText/maxtext_utils.py
index 3a2b84d13..cd7fb4ad7 100644
--- a/MaxText/maxtext_utils.py
+++ b/MaxText/maxtext_utils.py
@@ -33,8 +33,8 @@ def get_functional_train_with_signature(train_step, mesh, state_mesh_annotations
   functional_train = get_functional_train_step(train_step, model, config)
   functional_train.__name__ = "train_step"
   data_pspec = P(*config.data_sharding)
-  state_mesh_shardings = jax.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), state_mesh_annotations)
-  data_sharding = jax.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec)
+  state_mesh_shardings = jax.tree_util.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), state_mesh_annotations)
+  data_sharding = jax.tree_util.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec)
   in_shardings = (state_mesh_shardings, data_sharding, None)  # State, batch, rng
   out_shardings = (state_mesh_shardings, None)  # State, metrics
   static_argnums = ()  # We partial out the static argnums of model and config
@@ -51,8 +51,8 @@ def get_functional_eval_with_signature(eval_step, mesh, state_mesh_annotations,
   functional_eval = get_functional_eval_step(eval_step, model, config)
   functional_eval.__name__ = "eval_step"
   data_pspec = P(*config.data_sharding)
-  state_mesh_shardings = jax.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), state_mesh_annotations)
-  data_sharding = jax.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec)
+  state_mesh_shardings = jax.tree_util.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), state_mesh_annotations)
+  data_sharding = jax.tree_util.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec)
   in_shardings = (state_mesh_shardings, data_sharding, None)  # State, batch, rng
   out_shardings = None  # metrics
   static_argnums = ()  # We partial out the static argnums of model, config
diff --git a/MaxText/optimizers.py b/MaxText/optimizers.py
index 1d15b358f..04c432958 100644
--- a/MaxText/optimizers.py
+++ b/MaxText/optimizers.py
@@ -129,19 +129,19 @@ def _update_momentum(update, mu, nu):
       nu = (1.0 - beta2_decay) * (update**2) + beta2_decay * nu
       return _slot_opt_state(mu=mu, nu=nu)
 
-    updated_moments = jax.tree_map(_update_momentum, updates, state.mu, state.nu)
+    updated_moments = jax.tree_util.tree_map(_update_momentum, updates, state.mu, state.nu)
 
-    mu = jax.tree_map(lambda x: x.mu, updated_moments)
-    nu = jax.tree_map(lambda x: x.nu, updated_moments)
+    mu = jax.tree_util.tree_map(lambda x: x.mu, updated_moments)
+    nu = jax.tree_util.tree_map(lambda x: x.nu, updated_moments)
 
-    updates = jax.tree_map(lambda mu, nu: mu / (jnp.sqrt(nu + epsilon_root) + epsilon), mu, nu)
+    updates = jax.tree_util.tree_map(lambda mu, nu: mu / (jnp.sqrt(nu + epsilon_root) + epsilon), mu, nu)
 
     if weight_decay > 0:
-      updates = jax.tree_map(lambda x, v: x + weight_decay * v, updates, params)
+      updates = jax.tree_util.tree_map(lambda x, v: x + weight_decay * v, updates, params)
 
     step_size = -1.0 * learning_rate_fn(count)
     # Finally, fold in step size.
-    updates = jax.tree_map(lambda x: step_size * x, updates)
+    updates = jax.tree_util.tree_map(lambda x: step_size * x, updates)
 
     updated_states = optax.ScaleByAdamState(count=count + 1, mu=mu, nu=nu)
     return updates, updated_states
diff --git a/pedagogical_examples/shardings.py b/pedagogical_examples/shardings.py
index 912266667..73374b02e 100644
--- a/pedagogical_examples/shardings.py
+++ b/pedagogical_examples/shardings.py
@@ -207,16 +207,16 @@ def multiply_layers_with_loss(in_act, in_layers):
 
   def training_step(in_act, in_layers):
     _, grad_layers = multiply_layers_and_grad(in_act, in_layers)
-    out_layers = jax.tree_map(lambda param, grad: param - 1e-4 * grad, in_layers, grad_layers[0])
+    out_layers = jax.tree_util.tree_map(lambda param, grad: param - 1e-4 * grad, in_layers, grad_layers[0])
     return out_layers
 
   print("finished includes ", flush=True)
 
   replicated_sharding = jax.sharding.NamedSharding(mesh, data_sharding)
 
-  parameter_mesh_shardings = jax.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), parameter_sharding)
+  parameter_mesh_shardings = jax.tree_util.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), parameter_sharding)
 
-  data_pspec_shardings = jax.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), parameter_sharding)
+  data_pspec_shardings = jax.tree_util.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), parameter_sharding)
 
   jit_func = jax.jit(
       training_step,
@@ -224,11 +224,11 @@ def training_step(in_act, in_layers):
       out_shardings=data_pspec_shardings,
   )
 
-  data_mesh_shardings = jax.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), data_sharding)
+  data_mesh_shardings = jax.tree_util.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), data_sharding)
 
   jit_gen_data = jax.jit(gen_data, in_shardings=None, out_shardings=data_mesh_shardings)
 
-  parameter_mesh_shardings = jax.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), parameter_sharding)
+  parameter_mesh_shardings = jax.tree_util.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), parameter_sharding)
 
   jit_gen_layers = jax.jit(gen_layers, in_shardings=None, out_shardings=parameter_mesh_shardings)
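
Note (reviewer sketch, not part of the applied diff): jax.tree_map is a deprecated alias of jax.tree_util.tree_map, so every change above is a drop-in rename; the (f, tree, *rest) call signature and behavior are unchanged. A minimal, self-contained example of the new spelling, with hypothetical parameter names used only for illustration:

import jax
import jax.numpy as jnp

# A small pytree of parameters (hypothetical names, for illustration only).
params = {"w": jnp.ones((2, 2)), "b": jnp.zeros((2,))}

# Deprecated spelling:  jax.tree_map(lambda x: x.astype(jnp.bfloat16), params)
# Current spelling, as used throughout this diff:
bf16_params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)

# Each leaf is mapped independently; the tree structure is preserved.
assert jax.tree_util.tree_structure(bf16_params) == jax.tree_util.tree_structure(params)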