Apply OWG to TE's FP8 meta
mingxu1067 authored and ashors1 committed May 14, 2024
1 parent 3d6ff0f commit b289c45
Showing 2 changed files with 4 additions and 67 deletions.
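Assuming "OWG" in the title stands for the overwrite-with-gradient convention for variable collections, this commit stops the train step from copying Transformer Engine's FP8 scaling metadata (amax history and scale factors) out of the gradients by hand; the flagged variables are instead overwritten with the values TE reports through the gradient tree when updates are applied. A minimal sketch of that convention follows; the names are illustrative, not the Praxis or TE API.

# Minimal sketch of the overwrite-with-gradient (OWG) idea, assuming that is what
# "OWG" in the commit title refers to. Names here are illustrative only.
import jax


def apply_updates(params, grads, owg_mask, lr=1e-3):
  """Regular variables take a plain SGD step; OWG-flagged variables (such as FP8
  amax/scale metadata) are simply replaced by their 'gradient', which carries the
  new value rather than a true derivative."""
  return jax.tree_util.tree_map(
      lambda p, g, is_owg: g if is_owg else p - lr * g,
      params, grads, owg_mask,
  )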
paxml/contrib/gpu/scripts_gpu/te_helper.py (59 changes: 0 additions & 59 deletions)
@@ -2,7 +2,6 @@
 from contextlib import contextmanager
 from typing import Optional, Sequence
 
-import jax
 import jax.numpy as jnp
 from jax.ad_checkpoint import checkpoint_name
 from praxis import base_layer
@@ -234,18 +233,6 @@ def get_stack_transformer(stacked_transformer_p, dtype):
     def get_pipeline_transformer(pipeline_transformer_p):
         raise NotImplementedError
 
-    @staticmethod
-    def update_fp8_metas_if_needed(mdl_vars, grads):
-        raise NotImplementedError
-
-    @staticmethod
-    def include_fp8_for_grads_if_needed(variables):
-        raise NotImplementedError
-
-    @staticmethod
-    def mask_out_fp8_meta_grads_if_needed(grads, vars_with_opt):
-        raise NotImplementedError
-
     @staticmethod
     @contextmanager
     def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"):
@@ -262,18 +249,6 @@
     def get_pipeline_transformer(pipeline_transformer_p):
         return pipeline_transformer_p
 
-    @staticmethod
-    def update_fp8_metas_if_needed(mdl_vars, grads):
-        return mdl_vars
-
-    @staticmethod
-    def include_fp8_for_grads_if_needed(variables):
-        return variables
-
-    @staticmethod
-    def mask_out_fp8_meta_grads_if_needed(grads, vars_with_opt):
-        return grads
-
     @staticmethod
     @contextmanager
     def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"):
@@ -363,28 +338,6 @@ def get_pipeline_transformer(pipeline_transformer_p):
 
         return te_pipeline_transformer_p
 
-    @staticmethod
-    def update_fp8_metas_if_needed(mdl_vars, grads):
-        FP8_COLLECTION_NAME = te.fp8.FP8Helper.FP8_COLLECTION_NAME
-        if FP8_COLLECTION_NAME in grads:
-            mdl_vars[FP8_COLLECTION_NAME] = grads[FP8_COLLECTION_NAME]
-        return mdl_vars
-
-    @staticmethod
-    def include_fp8_for_grads_if_needed(variables):
-        FP8_COLLECTION_NAME = te.fp8.FP8Helper.FP8_COLLECTION_NAME
-        if FP8_COLLECTION_NAME in variables:
-            variables[FP8_COLLECTION_NAME] = \
-                jax.tree_util.tree_map(lambda x: False, variables[FP8_COLLECTION_NAME])
-        return variables
-
-    @staticmethod
-    def mask_out_fp8_meta_grads_if_needed(grads, vars_with_opt):
-        FP8_COLLECTION_NAME = te.fp8.FP8Helper.FP8_COLLECTION_NAME
-        if FP8_COLLECTION_NAME in grads:
-            grads[FP8_COLLECTION_NAME] = vars_with_opt[FP8_COLLECTION_NAME].copy()
-        return grads
-
     @staticmethod
     @contextmanager
     def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"):
@@ -424,18 +377,6 @@ def get_stack_transformer(stacked_transformer_p, dtype):
     def get_pipeline_transformer(pipeline_transformer_p):
         return TransformerEngineHelper.get_helper().get_pipeline_transformer(pipeline_transformer_p)
 
-    @staticmethod
-    def update_fp8_metas_if_needed(mdl_vars, grads):
-        return TransformerEngineHelper.get_helper().update_fp8_metas_if_needed(mdl_vars, grads)
-
-    @staticmethod
-    def include_fp8_for_grads_if_needed(variables):
-        return TransformerEngineHelper.get_helper().include_fp8_for_grads_if_needed(variables)
-
-    @staticmethod
-    def mask_out_fp8_meta_grads_if_needed(grads, vars_with_opt):
-        return TransformerEngineHelper.get_helper().mask_out_fp8_meta_grads_if_needed(grads, vars_with_opt)
-
     @staticmethod
     @contextmanager
     def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"):
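The only FP8-related entry point left in te_helper.py after this change is the fp8_autocast context manager visible in the unchanged context lines above. A usage sketch follows; only the fp8_autocast signature is taken from the diff, the wiring around it is assumed.

# Hypothetical usage of the surviving helper entry point; the surrounding setup
# is an assumption, not part of this diff.
from paxml.contrib.gpu.scripts_gpu.te_helper import TransformerEngineHelper

with TransformerEngineHelper.fp8_autocast(
    dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"
):
  ...  # build the model and run the train step; TE layers pick up the FP8 recipe from this scope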
paxml/trainer_lib.py (12 changes: 4 additions & 8 deletions)
@@ -995,16 +995,15 @@ def get_excluded_var_masks(
   excluded_for_grad = tasks_lib.get_excluded_var_mask_for_grad(
       var_weight_hparams, learner
   )
-  excluded_for_grad_but_fp8_meta = TransformerEngineHelper.include_fp8_for_grads_if_needed(excluded_for_grad.copy())
 
-  _log_bprop_include_exclude_list(var_weight_hparams, excluded_for_grad_but_fp8_meta)
+  _log_bprop_include_exclude_list(var_weight_hparams, excluded_for_grad)
 
   # Excluded for optimizer states.
   excluded_for_opt = tasks_lib.get_excluded_var_mask_for_opt(
       var_weight_hparams,
       learner,
   )
-  return excluded_for_grad, excluded_for_grad_but_fp8_meta, excluded_for_opt
+  return excluded_for_grad, excluded_for_opt
 
 
 def _prepare_tree_data_for_summary(tree):
@@ -1103,13 +1102,13 @@ def train_step_single_learner(
 
   _, subkey = jax.random.split(prng_key)
 
-  excluded_for_grad, excluded_for_grad_but_fp8_meta, excluded_for_opt = get_excluded_var_masks(
+  excluded_for_grad, excluded_for_opt = get_excluded_var_masks(
       var_weight_hparams, learner
   )
 
   # Construct and call the grad function.
   if not grad_fn:
-    grad_fn = _get_default_grad_fn(excluded_for_grad_but_fp8_meta, excluded_for_opt)
+    grad_fn = _get_default_grad_fn(excluded_for_grad, excluded_for_opt)
   (weighted_loss, aux_info), grads = grad_fn(
       loss_fn=_get_default_loss_fn(
           jax_task=jax_task,
@@ -1166,9 +1165,6 @@ def train_step_single_learner(
       var_weight_hparams, excluded_for_learner
   )
 
-  mdl_vars = TransformerEngineHelper.update_fp8_metas_if_needed(mdl_vars, grads)
-  grads = TransformerEngineHelper.mask_out_fp8_meta_grads_if_needed(grads, vars_with_opt)
-
   transformed_grads, new_opt_states = learner.update_states(
       grads, states.opt_states[0], vars_with_opt, wps_with_opt
   )
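With the helper calls removed, learner.update_states receives the FP8 metadata like any other variable collection, and the overwrite is expected to happen inside the optimizer stack, again assuming the OWG reading of the title. The sketch below expresses that semantics as an optax-style transformation; it is an illustration of the idea, not paxml's or praxis's actual learner code, and label_fp8_meta_as_owg is a hypothetical labeling helper.

# Illustrative only: an optax-style transform that overwrites a variable with its
# incoming "gradient" (which, for TE's FP8 meta, carries the new amax/scale values).
import jax
import optax


def overwrite_with_gradient() -> optax.GradientTransformation:
  def init_fn(params):
    del params
    return optax.EmptyState()

  def update_fn(updates, state, params):
    # params are required: emit (g - p) so that applying the update as p + u
    # yields p + (g - p) == g, i.e. the variable becomes the reported new value.
    new_updates = jax.tree_util.tree_map(lambda g, p: g - p, updates, params)
    return new_updates, state

  return optax.GradientTransformation(init_fn, update_fn)


# Hypothetical wiring that routes only the FP8-meta collection through it:
# tx = optax.multi_transform(
#     {"owg": overwrite_with_gradient(), "default": optax.adamw(1e-3)},
#     param_labels=label_fp8_meta_as_owg,  # assumed labeling helper
# )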
