Skip to content

Commit

Permalink
Support FM32 to OWG parameters.
Browse files Browse the repository at this point in the history
Signed-off-by: Ming Huang <[email protected]>
  • Loading branch information
mingxu1067 authored and ashors1 committed Jun 28, 2024
1 parent 9c21623 commit ab247d6
Showing 1 changed file with 19 additions and 1 deletion.
20 changes: 19 additions & 1 deletion paxml/trainer_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from etils import epath
import fiddle as fdl
from flax import struct as flax_struct
from flax.linen.fp8_ops import fm32
import jax
from jax import numpy as jnp
from jax.experimental import pjit
Expand All @@ -35,7 +36,7 @@
from paxml import sgf
from paxml import tasks_lib
from paxml import train_states
from paxml.contrib.gpu.scripts_gpu.te_helper import TransformerEngineHelper, DEFAULT_INIT_MUTABLE_LIST
from paxml.contrib.gpu.scripts_gpu.te_helper import DEFAULT_INIT_MUTABLE_LIST
from praxis import asserts
from praxis import base_hyperparams
from praxis import base_input
Expand Down Expand Up @@ -804,6 +805,21 @@ def _default_apply_fn(
)


def _maybe_to_fm32_vars(mdl_vars, var_weight_hparams):
  """Converts overwrite-with-gradient (OWG) variables to the `fm32` dtype.

  `mdl_vars` and `var_weight_hparams` are walked in lockstep; each variable
  whose matching weight hparams satisfy
  `base_layer.var_overwrite_with_gradient` is passed through
  `jax.lax.convert_element_type(..., fm32)`. Every other variable is returned
  as-is.

  Args:
    mdl_vars: Nested structure (tuples / dicts / lists) of model variables.
    var_weight_hparams: Structure of per-variable weight hparams; it must have
      the same structure as `mdl_vars` (checked via
      `asserts.assert_same_structure`).

  Returns:
    A structure shaped like `mdl_vars` in which the OWG variables have been
    converted to `fm32`.
  """
  asserts.assert_same_structure(mdl_vars, var_weight_hparams)

  def _is_leaf(node):
    # Only tuples, dicts and lists are treated as containers; any other
    # object is handed to the mapping function as a single leaf.
    return not isinstance(node, (tuple, dict, list))

  def _cast_if_owg(value, weight_hparams):
    if not base_layer.var_overwrite_with_gradient(weight_hparams):
      return value
    return jax.lax.convert_element_type(value, fm32)

  return jax.tree_util.tree_map(
      _cast_if_owg, mdl_vars, var_weight_hparams, is_leaf=_is_leaf
  )


class LossFnProtocol(Protocol):

def __call__(
Expand Down Expand Up @@ -834,6 +850,8 @@ def _loss_fn(
else:
assert NotImplementedError(f'fprop_dtype {fprop_dtype} not supported.')

mdl_vars = _maybe_to_fm32_vars(mdl_vars, var_weight_hparams)

with base_layer.JaxContext.new_context(hparams=context_p):
k1, k2, k3 = jax.random.split(prng_key, 3)
(metrics, per_example_output), updated_vars = apply_fn(
Expand Down

0 comments on commit ab247d6

Please sign in to comment.