Commit c91d866
Remove Praxis-related setup (Moving to Praxis TE/Patch)
mingxu1067 authored and ashors1 committed May 14, 2024
1 parent b289c45 commit c91d866
Showing 2 changed files with 0 additions and 324 deletions.
9 changes: 0 additions & 9 deletions paxml/contrib/gpu/scripts_gpu/configs.py
@@ -178,15 +178,6 @@ def task(self) -> pax_fiddle.Config[tasks_lib.SingleTask]:
      transformer_layer_p = stacked_p.transformer_layer_params_tpl
      transformer_layer_p.ln_tpl.reductions_in_fp32 = True
      transformer_layer_p.tr_fflayer_tpl.ln_tpl.reductions_in_fp32 = True
    else:
      stacked_p = TransformerEngineHelper.get_stack_transformer(
          stacked_p, jnp.dtype(self.FPROP_DTYPE))

    if issubclass(fdl.get_callable(model_p.lm_tpl.stacked_transformer_tpl),
                  transformers.StackedTransformerRepeated):
      model_p.lm_tpl.stacked_transformer_tpl.block = stacked_p
    else:
      model_p.lm_tpl.stacked_transformer_tpl = stacked_p


    model_p.params_init = WeightInit.Gaussian(self.INIT_STD)
    softmax_init = WeightInit.Gaussian(self.SOFTMAX_INIT_STD)
315 changes: 0 additions & 315 deletions paxml/contrib/gpu/scripts_gpu/te_helper.py
@@ -1,238 +1,17 @@
import os
from contextlib import contextmanager
from typing import Optional, Sequence

import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name
from praxis import base_layer
from praxis import pax_fiddle
from praxis import pytypes
from praxis.layers import transformers
from praxis.layers import stochastics

try:
    import transformer_engine.jax as te
    import transformer_engine.jax.flax as te_flax
    import transformer_engine.jax.praxis as te_praxis
    from transformer_engine.common import recipe
    _IS_TRANSFORMER_ENGINE_INSTALLED = True
    DEFAULT_INIT_MUTABLE_LIST = base_layer.DEFAULT_INIT_MUTABLE_LIST + [te.fp8.FP8Helper.FP8_COLLECTION_NAME]
    import praxis.layers.repeats as praxis_repeat
    # This is to make Repeat module correctly generate collections we need.
    praxis_repeat.SCAN_VARIABLE_AXES.update({base_layer.NON_PAX_VAR_COLLECTION[1]: 0,  # 1-idx = params_axes
                                             te.fp8.FP8Helper.FP8_COLLECTION_NAME: 0})

except ModuleNotFoundError as e:
    _IS_TRANSFORMER_ENGINE_INSTALLED = False
    DEFAULT_INIT_MUTABLE_LIST = base_layer.DEFAULT_INIT_MUTABLE_LIST


LayerTpl = pax_fiddle.Config[base_layer.BaseLayer]
JTensor = pytypes.JTensor


class StackedTransformer(transformers.StackedTransformer):
    """A mirror of StackedTransformer layers in Praxis."""

    def setup(self) -> None:

        assert self.num_layers > 0
        assert self.model_dims > 0
        assert self.hidden_dims > 0
        assert self.num_heads > 0
        assert 0.0 <= self.dropout_prob < 1.0
        assert 0.0 <= self.input_dropout_prob < 1.0

        def _layer_params(i):
            """Construct i-th layer params."""
            if isinstance(self.transformer_layer_params_tpl, Sequence):
                factor = self.num_layers // len(self.transformer_layer_params_tpl)
                ii = i // factor
                p_i = self._clone_layer_params(self.transformer_layer_params_tpl[ii])
            else:
                p_i = self._clone_layer_params(self.transformer_layer_params_tpl)
            p_i.name = f'layer_{i}'

            p_i.logical_axes_rules = te_flax.extend_logical_axis_rules(tuple())
            p_i.layer_type = te_praxis.TransformerLayerType.DECODER if self.use_cross_attention \
                else te_praxis.TransformerLayerType.ENCODER
            p_i.num_attention_heads = self.num_heads
            p_i.hidden_size = self.model_dims
            p_i.mlp_hidden_size = self.hidden_dims

            p_i.dropout_rng_name = base_layer.RANDOM
            p_i.attention_dropout = self.atten_dropout_prob or self.dropout_prob
            p_i.hidden_dropout = self.residual_dropout_prob or self.dropout_prob
            p_i.intermediate_dropout = self.relu_dropout_prob or self.dropout_prob
            if self.residual_droppath_prob > 0.0:
                p_i.drop_path = (
                    self.residual_droppath_prob * i / max(1, self.num_layers)
                )

            assert self.dim_per_head == self.model_dims // self.num_heads
            assert self.packed_input == False
            assert len(self.moe_layers) == 0
            assert self.ngrammer_tpls is None

            if self.ngrammer_tpls is not None:
                if self.ngrammer_tpls[i] is not None:
                    p_i.ngrammer_tpl = self.ngrammer_tpls[i]
            return p_i

        if isinstance(self.transformer_layer_params_tpl, (list, tuple)):
            if self.num_layers % len(self.transformer_layer_params_tpl):
                raise ValueError('num_layers should be divisible by '
                                 'transformer_layer_params_tpl')

        layer_params = [_layer_params(i) for i in range(self.num_layers)]
        self.create_children('x_layers', layer_params)

        if self.input_dropout_prob > 0.0:
            self.create_child(
                'input_dropout',
                pax_fiddle.Config(
                    stochastics.Dropout, keep_prob=1.0 - self.input_dropout_prob
                ),
            )

    def __call__(self,
                 inputs: JTensor,
                 paddings: JTensor,
                 segment_mask: Optional[JTensor] = None,
                 cross_inputs: Optional[JTensor] = None,
                 cross_paddings: Optional[JTensor] = None,
                 cross_segment_mask: Optional[JTensor] = None,
                 segment_pos: Optional[JTensor] = None) -> JTensor:

        if self.packed_input:
            assert segment_mask is not None

        if self.use_cross_attention:
            assert cross_inputs is not None
            assert cross_paddings is not None
            if self.packed_input:
                assert cross_segment_mask is not None

        attention_mask, cross_attention_mask = transformers.compute_attention_masks_for_fprop(
            inputs,
            paddings,
            self.mask_self_attention,
            segment_mask,
            cross_inputs,
            cross_paddings,
            cross_segment_mask,
            fold_padding_with_segment_mask=self.fold_padding_with_segment_mask,
        )

        x_out = inputs
        if self.input_dropout_prob > 0.0:
            x_out = self.input_dropout(x_out)

        attention_mask = 1 - (attention_mask == 0)
        attention_mask = attention_mask.astype(jnp.uint8)

        if cross_attention_mask is not None:
            cross_attention_mask = 1 - (cross_attention_mask == 0)
            cross_attention_mask = cross_attention_mask.astype(jnp.uint8)

        for i in range(self.num_layers):
            x_in = x_out
            x_out = self.x_layers[i](
                inputs=x_in,
                attention_mask=attention_mask,
                encoded=cross_inputs,
                encoder_decoder_mask=cross_attention_mask,
                deterministic=self.do_eval)
            x_out = checkpoint_name(x_out, 'transformer_layer_out')
        return x_out


class PipelinedTransformer(transformers.PipelinedTransformer):
    """A mirror of PipelinedTransformer in Praxis."""

    def __call__(
        self,
        inputs: JTensor,
        paddings: JTensor,
        segment_mask: JTensor | None = None,
        cross_inputs: JTensor | None = None,
        cross_paddings: JTensor | None = None,
        cross_segment_mask: JTensor | None = None,
        segment_pos: JTensor | None = None,
    ) -> JTensor:

        rules = te_flax.extend_logical_axis_rules(tuple())
        batch_mapping = rules[0]
        hidden_tp_mapping = rules[4]
        # [Batch, Seqlen, Hidden]
        bld_mapping = [batch_mapping, None, hidden_tp_mapping]

        if not self.stream_io:
            # Annotate the inputs before the pipeline to prevent unexpected
            # propagation from earlier layers.
            inputs = base_layer.maybe_shard(inputs, bld_mapping, self.mesh_axis_names)
        if bld_mapping is not None:
            # Annotate other broadcast inputs.
            paddings = base_layer.maybe_shard(
                paddings, bld_mapping[:-1], self.mesh_axis_names
            )

            # For cross inputs, we only specify the batch dim sharding.
            def _shard_batch_dim_only(x):
                return base_layer.maybe_shard(
                    x,
                    [bld_mapping[0]] + [-1] * (x.ndim - 1),
                    self.mesh_axis_names,
                    unconstrained_dims=range(1, x.ndim),
                )

            if segment_mask is not None:
                segment_mask = _shard_batch_dim_only(segment_mask)
            if cross_inputs is not None:
                cross_inputs = _shard_batch_dim_only(cross_inputs)
            if cross_paddings is not None:
                cross_paddings = _shard_batch_dim_only(cross_paddings)
            if cross_segment_mask is not None:
                cross_segment_mask = _shard_batch_dim_only(cross_segment_mask)

            if segment_pos is not None:
                segment_pos = base_layer.maybe_shard(
                    segment_pos, bld_mapping[:-1], self.mesh_axis_names
                )

        outputs = self.pipeline(
            inputs,
            paddings,
            segment_mask=segment_mask,
            cross_inputs=cross_inputs,
            cross_paddings=cross_paddings,
            cross_segment_mask=cross_segment_mask,
            segment_pos=segment_pos,
        )

        if not self.stream_io:
            outputs = base_layer.maybe_shard(
                outputs, bld_mapping, self.mesh_axis_names
            )

        outputs = base_layer.maybe_shard(
            outputs,
            self.activation_split_dims_mapping.final_out,
            self.mesh_axis_names,
        )
        return outputs


class TransformerEngineHelperBase:

    @staticmethod
    def get_stack_transformer(stacked_transformer_p, dtype):
        raise NotImplementedError

    @staticmethod
    def get_pipeline_transformer(pipeline_transformer_p):
        raise NotImplementedError

    @staticmethod
    @contextmanager
    def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"):
@@ -241,14 +20,6 @@ def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="dat

class TENotInstalledHelper(TransformerEngineHelperBase):

    @staticmethod
    def get_stack_transformer(stacked_transformer_p, dtype):
        return stacked_transformer_p

    @staticmethod
    def get_pipeline_transformer(pipeline_transformer_p):
        return pipeline_transformer_p

    @staticmethod
    @contextmanager
    def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"):
@@ -260,84 +31,6 @@ def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data

class TEInstalledHelper(TransformerEngineHelperBase):

    @staticmethod
    def get_stack_transformer(stacked_transformer_p, dtype):

        assert stacked_transformer_p.cls == transformers.StackedTransformer

        te_stacked_transformer_p = pax_fiddle.Config(StackedTransformer,
            use_cross_attention=stacked_transformer_p.use_cross_attention,
            mask_self_attention=stacked_transformer_p.mask_self_attention,
            num_layers=stacked_transformer_p.num_layers,
            model_dims=stacked_transformer_p.model_dims,
            hidden_dims=stacked_transformer_p.hidden_dims,
            num_heads=stacked_transformer_p.num_heads,
            dim_per_head=stacked_transformer_p.dim_per_head,
            dropout_prob=stacked_transformer_p.dropout_prob,
            atten_dropout_prob=stacked_transformer_p.atten_dropout_prob,
            residual_dropout_prob=stacked_transformer_p.residual_dropout_prob,
            relu_dropout_prob=stacked_transformer_p.relu_dropout_prob,
            residual_droppath_prob=stacked_transformer_p.residual_droppath_prob,
            input_dropout_prob=stacked_transformer_p.input_dropout_prob,
            gating_func=stacked_transformer_p.gating_func,
            unadjusted_expert_capacity_factor=stacked_transformer_p.unadjusted_expert_capacity_factor,
            packed_input=stacked_transformer_p.packed_input,
            fold_padding_with_segment_mask=stacked_transformer_p.fold_padding_with_segment_mask,
            moe_layer_tpl=stacked_transformer_p.moe_layer_tpl,
            num_experts=stacked_transformer_p.num_experts,
            num_groups=stacked_transformer_p.num_groups,
            min_group_size=stacked_transformer_p.min_group_size,
            moe_layers=stacked_transformer_p.moe_layers,
            ngrammer_tpls=stacked_transformer_p.ngrammer_tpls
        )

        ori_transformer_engine_p = stacked_transformer_p.transformer_layer_params_tpl

        te_stacked_transformer_p.transformer_layer_params_tpl = pax_fiddle.Config(te_praxis.TransformerLayer,
            name='transformer_layer',
            params_init=stacked_transformer_p.params_init,
            dtype=dtype,
            hidden_size=stacked_transformer_p.model_dims,
            mlp_hidden_size=stacked_transformer_p.hidden_dims,
            num_attention_heads=stacked_transformer_p.num_heads,
            layernorm_type='layernorm',
            layernorm_epsilon=ori_transformer_engine_p.ln_tpl.epsilon,
            zero_centered_gamma=True,
            hidden_dropout=ori_transformer_engine_p.residual_dropout_prob,
            attention_dropout=ori_transformer_engine_p.atten_dropout_prob,
            mlp_activations=('gelu',),
            use_bias=True,
            layer_type=te_praxis.TransformerLayerType.ENCODER,
            self_attn_mask_type='causal',
            enable_relative_embedding=False,
            drop_path=ori_transformer_engine_p.residual_droppath_prob,
            scaled_query_init=False,
            scale_attn_logits=True,
            transpose_batch_sequence=False
        )

        return te_stacked_transformer_p

    @staticmethod
    def get_pipeline_transformer(pipeline_transformer_p):

        assert pipeline_transformer_p.cls == transformers.PipelinedTransformer

        te_pipeline_transformer_p = pax_fiddle.Config(PipelinedTransformer,
            pipeline_stage=pipeline_transformer_p.pipeline_stage,
            circular_repeat=pipeline_transformer_p.circular_repeat,
            num_pipeline_stages=pipeline_transformer_p.num_pipeline_stages,
            num_pipeline_microbatches=pipeline_transformer_p.num_pipeline_microbatches,
            pipeline_microbatch_size=pipeline_transformer_p.pipeline_microbatch_size,
            stream_io=pipeline_transformer_p.stream_io,
            pipeline_broadcast_inputs=pipeline_transformer_p.pipeline_broadcast_inputs,
            checkpoint_policy=pipeline_transformer_p.checkpoint_policy,
            enable_async_circular_transfer=pipeline_transformer_p.enable_async_circular_transfer,
            bf16_accum_in_fp32=pipeline_transformer_p.bf16_accum_in_fp32
        )

        return te_pipeline_transformer_p

    @staticmethod
    @contextmanager
    def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"):
@@ -369,14 +62,6 @@ def get_helper():
            return TEInstalledHelper
        return TENotInstalledHelper

    @staticmethod
    def get_stack_transformer(stacked_transformer_p, dtype):
        return TransformerEngineHelper.get_helper().get_stack_transformer(stacked_transformer_p, dtype)

    @staticmethod
    def get_pipeline_transformer(pipeline_transformer_p):
        return TransformerEngineHelper.get_helper().get_pipeline_transformer(pipeline_transformer_p)

    @staticmethod
    @contextmanager
    def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"):
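
For context, a minimal, hypothetical usage sketch of the helper being removed (the wrapper function name below is invented for illustration and is not the Praxis TE patch API this commit migrates to): callers such as configs.py obtained a TE-backed stacked-transformer config through the helper's dispatch, which resolves to TEInstalledHelper when Transformer Engine imports successfully and to TENotInstalledHelper (a no-op) otherwise.

# Hypothetical illustration only; assumes the pre-removal te_helper.py shown above.
import jax.numpy as jnp

from paxml.contrib.gpu.scripts_gpu.te_helper import TransformerEngineHelper


def maybe_wrap_with_te(stacked_p, fprop_dtype="bfloat16"):
    # Delegates to TEInstalledHelper or TENotInstalledHelper depending on
    # whether transformer_engine could be imported at module load time.
    return TransformerEngineHelper.get_stack_transformer(
        stacked_p, jnp.dtype(fprop_dtype))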
