Skip to content

Commit

Permalink
Fix missing DEFAULT_INIT_MUTABLE_LIST
Browse files Browse the repository at this point in the history
  • Loading branch information
mingxu1067 authored and ashors1 committed May 14, 2024
1 parent c91d866 commit b6add3c
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions paxml/contrib/gpu/scripts_gpu/te_helper.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import os
from contextlib import contextmanager

from praxis import base_layer

try:
import transformer_engine.jax as te
from transformer_engine.common import recipe
_IS_TRANSFORMER_ENGINE_INSTALLED = True
DEFAULT_INIT_MUTABLE_LIST = base_layer.DEFAULT_INIT_MUTABLE_LIST + [te.fp8.FP8Helper.FP8_COLLECTION_NAME]

except ModuleNotFoundError as e:
_IS_TRANSFORMER_ENGINE_INSTALLED = False
DEFAULT_INIT_MUTABLE_LIST = base_layer.DEFAULT_INIT_MUTABLE_LIST


class TransformerEngineHelperBase:
Expand Down

0 comments on commit b6add3c

Please sign in to comment.