[PyTorch] Remove special handling for FP8 params in FP8 recipe infrastructure #1326

Open · wants to merge 4 commits into base: main

tests/pytorch/test_fusible_ops.py (2 changes: 1 addition & 1 deletion)

@@ -293,7 +293,7 @@ def test_fp8_scale_update(
     )

     # Check that scaling factors match expected
-    w_amax_ref = max(w_vals[: step + 2])
+    w_amax_ref = max(w_vals[: step + 1])
     x_amax_ref = max(x_vals[: step + 1])
     dy_amax_ref = max(dy_vals[: step + 1])
     w_scale_ref = (fp8_format.value.max_fwd / w_amax_ref) / (2**margin)
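With the special weight handling gone, the weight amax reference now uses the same history window as the activation and gradient references, so it shifts from step + 2 to step + 1. A minimal worked sketch of the reference computation the test checks, with made-up per-step amax values and assuming the hybrid FP8 format (E4M3 forward max 448, E5M2 backward max 57344):

# Hypothetical recorded per-step amaxes (stand-ins for the test's w_vals/x_vals/dy_vals).
w_vals, x_vals, dy_vals = [2.0, 4.0, 1.0], [1.0, 3.0, 2.0], [0.5, 0.25, 1.0]
step, margin = 1, 0
fp8_max_fwd, fp8_max_bwd = 448.0, 57344.0  # E4M3 / E5M2 representable maxima

# All three tensors now use the same window: steps 0..step inclusive.
w_amax_ref = max(w_vals[: step + 1])    # 4.0
x_amax_ref = max(x_vals[: step + 1])    # 3.0
dy_amax_ref = max(dy_vals[: step + 1])  # 0.5

# Delayed scaling: scale = (fp8_max / amax) / 2**margin.
w_scale_ref = (fp8_max_fwd / w_amax_ref) / (2**margin)    # 112.0
x_scale_ref = (fp8_max_fwd / x_amax_ref) / (2**margin)    # ~149.33
dy_scale_ref = (fp8_max_bwd / dy_amax_ref) / (2**margin)  # 114688.0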
transformer_engine/pytorch/fp8.py (47 changes: 7 additions & 40 deletions)

@@ -109,8 +109,6 @@ def reset(cls) -> None:
         cls.fp8_available = None
         cls.reason_for_no_fp8 = ""
         cls.autocast_arguments = {}
-        cls.autocast_to_fp8_params = {}
-        cls.fp8_param_to_autocast = {}
         cls.skip_fp8_weight_update_tensor = None

     @classmethod

@@ -156,28 +154,25 @@ def get_buffer_info(cls) -> str:
     def get_key_in_buffer(
         cls,
         forward: bool,
-        fp8_weights: bool,
         fp8_recipe: DelayedScaling,
         fp8_group: dist_group_type,
     ) -> str:
         """Returns a key into the global FP8 buffers."""
         autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group)
         fwd_bwd_key = cls.get_fwd_bwd_key(forward)
-        return f"{fwd_bwd_key}_{fp8_weights}_{autocast_key}"
+        return f"{fwd_bwd_key}_{autocast_key}"

     @classmethod
-    def split_key_in_buffer(cls, key: str) -> Tuple[bool, bool, str]:
+    def split_key_in_buffer(cls, key: str) -> Tuple[bool, str]:
         """Splits buffer key into relevant parts."""
-        forward, fp8_weights, autocast_key = key.split("_", 2)
+        forward, autocast_key = key.split("_", 1)
         forward = forward == "forward"
-        fp8_weights = fp8_weights == "True"
-        return forward, fp8_weights, autocast_key
+        return forward, autocast_key

     @classmethod
     def add_fp8_tensors_to_global_buffer(
         cls,
         fp8_meta: Dict[str, Any],
-        fp8_weights: Optional[List[torch.Tensor]] = None,
     ) -> None:
         """
         The amax reduction process happens completely outside the FP8 modules.
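For illustration, a minimal standalone sketch of the simplified key round-trip after this change (toy functions mirroring the two methods above; the autocast key string is a made-up placeholder):

def make_key(forward: bool, autocast_key: str) -> str:
    # Mirrors get_key_in_buffer: "<forward|backward>_<unique autocast key>".
    fwd_bwd_key = "forward" if forward else "backward"
    return f"{fwd_bwd_key}_{autocast_key}"

def split_key(key: str):
    # Mirrors split_key_in_buffer: no fp8_weights flag is encoded anymore.
    forward, autocast_key = key.split("_", 1)
    return forward == "forward", autocast_key

key = make_key(True, "recipe0_group0")  # hypothetical autocast key
assert split_key(key) == (True, "recipe0_group0")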

@@ -202,33 +197,12 @@ def add_fp8_tensors_to_global_buffer(

         fp8_meta[index_in_buffer] = []
         for forward in (True, False):
-            # This algorithm creates a two-way map with `autocast_to_fp8_params` and
-            # `fp8_param_to_autocast`. This is used for keeping track of FP8 weights
-            # in an autocasted region and cross reference them in `float8_tensor.py`
-            # to perform the forward amax reduction.
             fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward)
             if fp8_meta_tensor_key not in fp8_meta:
                 # Handles non-parameter FP8 modules, e.g. DPA.
                 continue

-            if forward and fp8_weights is not None:
-                autocast_key = cls.get_unique_autocast_key(
-                    fp8_meta["recipe"], fp8_meta["fp8_group"]
-                )
-                fp8_weight_set = {id(w._data) for w in fp8_weights}
-                if autocast_key not in cls.autocast_to_fp8_params:
-                    cls.autocast_to_fp8_params[autocast_key] = fp8_weight_set
-                else:
-                    cls.autocast_to_fp8_params[autocast_key] = cls.autocast_to_fp8_params[
-                        autocast_key
-                    ].union(fp8_weight_set)
-                # Identify correct autocast key for a given param.
-                for w in fp8_weight_set:
-                    cls.fp8_param_to_autocast[w] = autocast_key
-
-            key = cls.get_key_in_buffer(
-                forward, fp8_weights is not None, fp8_meta["recipe"], fp8_meta["fp8_group"]
-            )
+            key = cls.get_key_in_buffer(forward, fp8_meta["recipe"], fp8_meta["fp8_group"])

             if key not in cls.global_amax_buffer:
                 cls.global_amax_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]]

@@ -327,20 +301,13 @@ def reduce_tensor_across_group_op_max(tensor: torch.Tensor, group: dist_group_ty
     def reduce_and_update_fp8_tensors(
         cls,
         forward: bool = True,
-        fp8_weights: bool = False,
     ) -> None:
         """Concatenate, reduce, and split amaxes in the global buffer."""
         for buffer_key, amax_buffer in cls.global_amax_buffer.items():
             # Check for forward or backward reduction.
-            fwd_update, fp8_weights_update, autocast_key = cls.split_key_in_buffer(buffer_key)
+            fwd_update, autocast_key = cls.split_key_in_buffer(buffer_key)
             if fwd_update != forward:
                 continue
-            # Only skip a forward update when `fp8_weights` is explicitly set to `True`
-            # (inside optimizer) and the current key is not an `fp8_weight_update` key.
-            # For other cases, we need to reduce because of activation tensors.
-            # TODO(ksivaman) consider separate weight and activation fp8_tensors.
-            if fwd_update and fp8_weights and not fp8_weights_update:
-                continue
             if len(amax_buffer) == 0:
                 continue

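As the docstring says, the reduction concatenates all amaxes registered under a key, performs one max reduction across the process group, and splits the result back. A minimal standalone sketch of that pattern (a simplified stand-in, not the actual class internals; the distributed call is guarded so the snippet also runs single-process):

import torch
import torch.distributed as dist

def reduce_and_split_amaxes(amax_buffer, group=None):
    # Concatenate every registered amax tensor into one flat tensor.
    flat = torch.cat([amax.view(-1) for amax in amax_buffer])
    # One max all-reduce over the group replaces per-tensor reductions.
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(flat, op=dist.ReduceOp.MAX, group=group)
    # Split the reduced values back into the original amax tensors.
    offset = 0
    for amax in amax_buffer:
        amax.copy_(flat[offset : offset + amax.numel()].view_as(amax))
        offset += amax.numel()

buffer = [torch.tensor([2.0, 0.5]), torch.tensor([1.0])]
reduce_and_split_amaxes(buffer)  # no-op reduction when running single-process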

@@ -434,7 +401,7 @@ def fp8_autocast_exit(cls, enabled: bool, _graph: bool) -> None:
         # FP8 weight modules are reduced at the end of the optimizer
         # step after the weight amax is populated.
         if enabled and cls.FP8_AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled():
-            cls.reduce_and_update_fp8_tensors(forward=True, fp8_weights=False)
+            cls.reduce_and_update_fp8_tensors(forward=True)

     @classmethod
     def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None:
transformer_engine/pytorch/graph.py (2 changes: 1 addition & 1 deletion)

@@ -465,7 +465,7 @@ def new_fwd(*user_args, **user_kwargs):
                     m.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()
                     m.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
                     FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
-                        m.fp8_meta, fp8_weights=m._get_fp8_params()
+                        m.fp8_meta,
                     )
             return graphed(*user_args, **user_kwargs)
         return orig_fwd(*user_args, **user_kwargs)
transformer_engine/pytorch/module/base.py (4 changes: 1 addition & 3 deletions)

@@ -762,9 +762,7 @@ def prepare_forward(
         )

         if self.fp8 and not FP8GlobalStateManager.fp8_graph_capturing():
-            FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
-                self.fp8_meta, fp8_weights=self._get_fp8_params()
-            )
+            FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta)

         # Activation recomputation is used and this is the first forward phase.
         if self.fp8 and self.training and is_fp8_activation_recompute_enabled():
transformer_engine/pytorch/ops/op.py (4 changes: 1 addition & 3 deletions)

@@ -19,7 +19,7 @@
     FP8GlobalStateManager,
     get_default_fp8_recipe,
 )
-from ._common import canonicalize_device, is_float8_tensor
+from ._common import canonicalize_device


 @dataclasses.dataclass

@@ -379,10 +379,8 @@ def pre_forward(
                 self.get_fp8_meta("input"),
             )
         if self.num_fp8_scales("param"):
-            fp8_params = list(filter(is_float8_tensor, self.parameters()))
             FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
                 self.get_fp8_meta("param"),
-                fp8_weights=(fp8_params if fp8_params else None),
             )
         if self.num_fp8_scales("grad_output"):
             FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
transformer_engine/pytorch/tensor/float8_tensor.py (27 changes: 0 additions & 27 deletions)

@@ -74,30 +74,6 @@ def backward(
         return grad, None


-def post_optimizer_step_fwd_amax_reduction(param: Float8Tensor) -> None:
-    """Amax scale and update when there is at least 1 trainable FP8 parameter."""
-    param_id = id(param._data)
-
-    if param_id not in FP8GlobalStateManager.fp8_param_to_autocast:
-        return
-
-    autocast_key = FP8GlobalStateManager.fp8_param_to_autocast[param_id]
-
-    if autocast_key not in FP8GlobalStateManager.autocast_to_fp8_params:
-        return
-
-    if autocast_key in updated_fp8_params:
-        updated_fp8_params[autocast_key].add(param_id)
-    else:
-        updated_fp8_params[autocast_key] = {param_id}
-
-    current_fp8_params_set = FP8GlobalStateManager.autocast_to_fp8_params[autocast_key]
-    # All FP8 trainable parameters have been updated.
-    if updated_fp8_params[autocast_key] == current_fp8_params_set:
-        FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=True, fp8_weights=True)
-        del updated_fp8_params[autocast_key]
-
-
 class _ToFloat8Func(torch.autograd.Function):
     """Cast to FP8 from other dtype"""


@@ -676,9 +652,6 @@ def quantize_(
             )
         dst._transpose_invalid = False

-        # Callback hook to perform amax reduction after optimizer step
-        post_optimizer_step_fwd_amax_reduction(self)
-
         return self

     @classmethod
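With the post-optimizer-step callback gone, Float8Tensor.quantize_ no longer triggers a separate forward amax reduction for weights; everything is reduced on the existing path when the autocast region exits (see fp8_autocast_exit above). A minimal usage sketch of that remaining path, assuming a Transformer Engine Linear module, a CUDA device, and the delayed-scaling recipe (shapes and hyperparameters are made up):

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max")
layer = te.Linear(768, 768, bias=True).cuda()
optimizer = torch.optim.SGD(layer.parameters(), lr=1e-3)

x = torch.randn(32, 768, device="cuda", requires_grad=True)
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    y = layer(x)
# Exiting the autocast region triggers reduce_and_update_fp8_tensors(forward=True)
# for all registered forward amaxes; weights no longer get a separate reduction
# after optimizer.step().
y.sum().backward()
optimizer.step()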