Commit 6570445

Merge pull request #623 from google:lizhiyu/change_norm_sharding
PiperOrigin-RevId: 628504084
maxtext authors committed Apr 26, 2024
2 parents: 18ba1a7 + 9feab51
Showing 9 changed files with 16 additions and 12 deletions.
1 change: 1 addition & 0 deletions MaxText/configs/base.yml
@@ -125,6 +125,7 @@ logical_axis_rules: [
   ['vocab', ['tensor', 'autoregressive']],
   ['embed', ['fsdp', 'fsdp_transpose', 'sequence']],
   ['embed', ['fsdp', 'sequence']],
+  ['norm', 'tensor'],
   ['heads', ['tensor', 'autoregressive']],
   ['kv', []],
   ['cache_batch', []],
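Note on the base.yml change above: the new ['norm', 'tensor'] rule maps the logical "norm" axis, which the normalization-layer scale parameters in the files below are re-annotated with, onto the "tensor" mesh axis instead of the fsdp/sequence axes that "embed" resolves to. A minimal sketch of how such a rule is resolved, using Flax's logical-partitioning helpers rather than MaxText code, with an illustrative subset of the rule table:

# Minimal sketch, not part of this commit: resolving logical axis names to a
# mesh PartitionSpec with Flax's logical-partitioning helpers. The `rules`
# tuple below is an illustrative subset of the table above.
import flax.linen as nn

rules = (
    ("embed", ("fsdp", "sequence")),
    ("norm", "tensor"),  # the rule added by this change
    ("heads", ("tensor", "autoregressive")),
)

with nn.logical_axis_rules(rules):
  # A 1-D scale parameter annotated with the logical axis ("norm",) now maps
  # to the tensor mesh axis rather than the fsdp/sequence axes used by "embed".
  print(nn.logical_to_mesh_axes(("norm",)))   # expected: PartitionSpec('tensor',)
  print(nn.logical_to_mesh_axes(("embed",)))  # expected: PartitionSpec(('fsdp', 'sequence'),)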
6 changes: 4 additions & 2 deletions MaxText/generate_param_only_checkpoint.py
@@ -55,7 +55,8 @@ def new_pspec(x):
     return jax.sharding.PartitionSpec(*x[0 : config.param_scan_axis] + x[config.param_scan_axis + 1 :])

   new_per_layer_state_annotation = jax.tree_util.tree_map(new_pspec, training_state_annotations_layers)
-  new_per_layer_state_sharding = jax.tree_util.tree_map(lambda x: jax.sharding.NamedSharding(mesh, x), new_per_layer_state_annotation)
+  new_per_layer_state_sharding = jax.tree_util.tree_map(
+      lambda x: jax.sharding.NamedSharding(mesh, x), new_per_layer_state_annotation)

   for i in range(config.num_decoder_layers):

@@ -90,7 +91,8 @@ def _read_train_checkpoint(config, checkpoint_manager, mesh):
 def _save_decode_checkpoint(config, state, checkpoint_manager):
   """Generate checkpoint for decode from the training_state."""
   with jax.spmd_mode("allow_all"):
-    decode_state = max_utils.init_decode_state(None, jax.tree_util.tree_map(lambda x: x.astype(jax.numpy.bfloat16), state.params))
+    decode_state = max_utils.init_decode_state(
+        None, jax.tree_util.tree_map(lambda x: x.astype(jax.numpy.bfloat16), state.params))
   if checkpoint_manager is not None:
     if save_checkpoint(checkpoint_manager, 0, decode_state):
       max_logging.log(f"saved an decode checkpoint at {config.checkpoint_dir}")
4 changes: 2 additions & 2 deletions MaxText/layers/gemma.py
@@ -71,7 +71,7 @@ def __call__(
     inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed"))

     # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim]
-    lnx = RMSNorm(dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="pre_self_attention_norm", kernel_axes=("embed",))(
+    lnx = RMSNorm(dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="pre_self_attention_norm", kernel_axes=("norm",))(
         inputs
     )

@@ -108,7 +108,7 @@ def __call__(
     attention_lnx = nn.with_logical_constraint(attention_lnx, ("activation_batch", "activation_length", "activation_embed"))
     attention_lnx += inputs
     residual = attention_lnx
-    attn_output = RMSNorm(dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="pre_ffw_norm", kernel_axes=("embed",))(
+    attn_output = RMSNorm(dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="pre_ffw_norm", kernel_axes=("norm",))(
         attention_lnx
     )

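The kernel_axes argument changed above is the logical name attached to each norm's scale parameter; under the new ['norm', 'tensor'] rule those vectors shard along the tensor mesh axis. The sketch below is illustrative only (TinyRMSNorm is not MaxText's RMSNorm) and shows the usual pattern of attaching kernel_axes through Flax logical partitioning.

# Illustrative sketch only, not MaxText's RMSNorm: shows how a kernel_axes
# tuple such as ("norm",) is typically attached to a scale parameter, so that
# logical_axis_rules (e.g. the new ['norm', 'tensor'] entry) decides its sharding.
import jax
import jax.numpy as jnp
import flax.linen as nn


class TinyRMSNorm(nn.Module):
  epsilon: float = 1e-6
  kernel_axes: tuple = ("norm",)

  @nn.compact
  def __call__(self, x):
    # Wrapping the initializer attaches the logical axis name to the parameter;
    # the active logical_axis_rules then map it to concrete mesh axes.
    scale = self.param(
        "scale",
        nn.with_logical_partitioning(nn.initializers.ones, self.kernel_axes),
        (x.shape[-1],),
    )
    mean_sq = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(mean_sq + self.epsilon) * scale


# Usage: initialize on dummy data; the scale parameter carries the ("norm",) annotation.
y, variables = TinyRMSNorm().init_with_output(jax.random.PRNGKey(0), jnp.ones((2, 8)))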
2 changes: 1 addition & 1 deletion MaxText/layers/gpt3.py
@@ -278,7 +278,7 @@ def __call__(
     lnx_layer_norm = Gpt3LayerNorm(
         dtype=cfg.dtype,
         name="pre_self_attention_norm",
-        kernel_axes=("embed",),
+        kernel_axes=("norm",),
         epsilon=cfg.normalization_layer_epsilon,
         reductions_in_fp32=False,
         use_bias=True,
2 changes: 1 addition & 1 deletion MaxText/layers/linears.py
@@ -197,7 +197,7 @@ def __call__(self, inputs, decode: bool = False, deterministic: bool = False):
         name="mlp_layer_norm",
         dtype=cfg.dtype,
         weight_dtype=cfg.weight_dtype,
-        kernel_axes=("embed",),
+        kernel_axes=("norm",),
         epsilon=cfg.normalization_layer_epsilon,
     )(inputs)

4 changes: 2 additions & 2 deletions MaxText/layers/llama2.py
@@ -73,7 +73,7 @@ def __call__(
         dtype=cfg.dtype,
         weight_dtype=cfg.weight_dtype,
         name="pre_self_attention_layer_norm",
-        kernel_axes=("embed",),
+        kernel_axes=("norm",),
         epsilon=cfg.normalization_layer_epsilon,
     )
     lnx = lnx_rms(inputs)
@@ -115,7 +115,7 @@ def __call__(
         dtype=cfg.dtype,
         weight_dtype=cfg.weight_dtype,
         name="post_self_attention_layer_norm",
-        kernel_axes=("embed",),
+        kernel_axes=("norm",),
         epsilon=cfg.normalization_layer_epsilon,
     )(intermediate_inputs)
     hidden_states = nn.with_logical_constraint(hidden_states, ("activation_batch", "activation_length", "activation_embed"))
4 changes: 2 additions & 2 deletions MaxText/layers/mistral.py
@@ -74,7 +74,7 @@ def __call__(
         dtype=cfg.dtype,
         weight_dtype=cfg.weight_dtype,
         name="pre_self_attention_layer_norm",
-        kernel_axes=("embed",),
+        kernel_axes=("norm",),
         epsilon=cfg.normalization_layer_epsilon,
     )
     lnx = lnx_rms(inputs)
@@ -116,7 +116,7 @@ def __call__(
         dtype=cfg.dtype,
         weight_dtype=cfg.weight_dtype,
         name="post_self_attention_layer_norm",
-        kernel_axes=("embed",),
+        kernel_axes=("norm",),
         epsilon=cfg.normalization_layer_epsilon,
     )(intermediate_inputs)
     hidden_states = nn.with_logical_constraint(hidden_states, ("activation_batch", "activation_length", "activation_embed"))
2 changes: 1 addition & 1 deletion MaxText/layers/models.py
@@ -73,7 +73,7 @@ def __call__(
         weight_dtype=cfg.weight_dtype,
         name="pre_self_attention_norm",
         epsilon=cfg.normalization_layer_epsilon,
-        kernel_axes=("embed",),
+        kernel_axes=("norm",),
     )(inputs)
     lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_length", "activation_embed"))

3 changes: 2 additions & 1 deletion MaxText/maxengine.py
@@ -83,7 +83,8 @@ def load_params(self, *args, **kwargs) -> Params:
         lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding), state.params
     )
     self.kv_cache_annotations = max_utils.get_kv_cache_annotations(self.model, self.config, self.rng, self._mesh)
-    self.kv_cache_shardings = jax.tree_util.tree_map(lambda x: jax.sharding.NamedSharding(self._mesh, x), self.kv_cache_annotations)
+    self.kv_cache_shardings = jax.tree_util.tree_map(
+        lambda x: jax.sharding.NamedSharding(self._mesh, x), self.kv_cache_annotations)

     if not self.model.quant:
       self.abstract_params = jax.tree_util.tree_map(
