
Commit f7ed83f
propagate scale_inv modification to GroupedLinear
Signed-off-by: Xin Yao <[email protected]>
yaox12 committed Aug 22, 2024
1 parent 47caafb commit f7ed83f
Showing 2 changed files with 29 additions and 9 deletions.
3 changes: 2 additions & 1 deletion transformer_engine/pytorch/cpp_extensions/transpose.py
@@ -175,14 +175,15 @@ def fp8_multi_cast_transpose_fused(
     amax_indices: List[int],
     scale_inv_indices: List[int],
     otype: tex.DType,
+    scale_inv: Optional[torch.Tensor] = None,
 ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
     """Cast + Transpose with FP8 output"""

     return tex.fused_multi_cast_transpose_alloc(
         input_list,
         fp8_meta_tensor.scale,
         fp8_meta_tensor.amax_history,
-        fp8_meta_tensor.scale_inv,
+        scale_inv if scale_inv is not None else fp8_meta_tensor.scale_inv,
         scale_indices,
         amax_indices,
         scale_inv_indices,
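For readers skimming the diff: the new optional scale_inv argument lets the caller pass its own destination tensor for the computed scale-inverses, instead of having them written only into the shared fp8_meta_tensor.scale_inv buffer. A minimal pure-PyTorch sketch of that pattern follows (the record_scale_inv helper and tensor names are illustrative stand-ins, not TE APIs):

from typing import Optional

import torch

def record_scale_inv(
    scales: torch.Tensor,
    default_buf: torch.Tensor,
    scale_inv: Optional[torch.Tensor] = None,
) -> None:
    # Write 1/scale into the caller-provided buffer if one is given,
    # otherwise fall back to the shared default buffer.
    out = scale_inv if scale_inv is not None else default_buf
    out.copy_(scales.reciprocal())

shared_scale_inv = torch.ones(4)    # shared meta buffer, mutated in place every step
private_scale_inv = torch.empty(4)  # caller-owned snapshot
scales = torch.tensor([2.0, 4.0, 8.0, 16.0])

record_scale_inv(scales, shared_scale_inv, scale_inv=private_scale_inv)
shared_scale_inv.mul_(0.5)  # later in-place updates no longer touch the snapshot
assert torch.allclose(private_scale_inv, scales.reciprocal())
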
35 changes: 27 additions & 8 deletions transformer_engine/pytorch/module/grouped_linear.py
@@ -42,6 +42,7 @@
 from ..jit import no_torch_dynamo
 from ..graph import is_graph_capturing
 from ..float8_tensor import Float8Tensor
+from ..export import is_in_onnx_export_mode

 __all__ = ["GroupedLinear"]

@@ -102,10 +103,12 @@ def forward(
         inputmats_no_fp8 = [cast_if_needed(mat, activation_dtype) for mat in inputmats]
         inputmats = []
         inputmats_t = []
+        inputmat_scale_inv = None

         global _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT
         if fp8:
             fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
+            inputmat_scale_inv = torch.empty([num_gemms], dtype=torch.float32, device=inp.device)
             if (
                 not fp8_meta["recipe"].override_linear_precision.wgrad
                 and is_grad_enabled
@@ -121,6 +124,7 @@
                     indices,  # amax_indices
                     indices,  # scale_inv_indices
                     fp8_dtype_forward,
+                    scale_inv=inputmat_scale_inv,
                 )
             else:
                 # FP8 input for forward
@@ -130,9 +134,22 @@
                         fp8_meta["scaling_fwd"],
                         _GEMM_INPUT + i,
                         fp8_dtype_forward,
+                        scale_inv=inputmat_scale_inv,
                     )
                     for i in range(num_gemms)
                 ]
+
+            # Hack for ONNX export
+            # Note: ONNX models are represented as a graph of tensor
+            # operations, so the in-place scale-inv update doesn't fit
+            # very well. We work around this by making it look like
+            # the scale-inv tensor is initialized with a copy.
+            # Note: ONNX export expects FP8 scales can be represented
+            # with constant ops. However, copying into a buffer
+            # involves an expand op for array broadcasting. We work
+            # around this by filling the buffer instead.
+            if is_in_onnx_export_mode():
+                inputmat_scale_inv.fill_(inputmat_scale_inv.item())
         else:
             inputmats = inputmats_no_fp8

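A note on the ONNX workaround above, paraphrasing the in-diff comments: export expects the FP8 scale-inverses to be representable as constants, and broadcasting a copy into the buffer would add an Expand op to the traced graph, so the buffer is filled with a Python scalar instead. A rough illustration of the two idioms (one-element buffer assumed, since .item() only works on single-element tensors; this is not the exported TE graph):

import torch

scale_inv = torch.full((1,), 0.25)

# Broadcast copy: the 0-d source is expanded to the buffer's shape,
# which surfaces as an extra op when the graph is traced for export.
scale_inv.copy_(torch.tensor(0.25))

# Scalar fill: .item() extracts a Python float, so the write traces
# like a constant initialization of the buffer.
scale_inv.fill_(scale_inv.item())
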
Expand All @@ -153,12 +170,14 @@ def forward(

_ = fp8_grouped_gemm(
[w._data for w in weights_fp8],
fp8_meta["scaling_fwd"].scale_inv,
_GEMM_WEIGHT,
torch.cat(
[w._scale_inv for w in weights_fp8]
), # avoiding torch.cat requires another interface,
0, # weight offset is 0 for the newly created _scale_inv
fp8_dtype_forward,
inputmats,
fp8_meta["scaling_fwd"].scale_inv,
_GEMM_INPUT,
inputmat_scale_inv,
0,
fp8_dtype_forward,
torch.split(out, m_splits),
activation_dtype,
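On the GEMM call above: weight scale-inverses now come from each Float8Tensor's own _scale_inv, concatenated into one contiguous vector, and input scale-inverses from the freshly allocated inputmat_scale_inv, so both offsets become 0 instead of _GEMM_WEIGHT/_GEMM_INPUT into the shared scaling_fwd.scale_inv. A small sketch of the gather-then-index-from-zero idea (FakeFP8Weight is a stand-in, not TE's Float8Tensor):

import torch

class FakeFP8Weight:
    """Stand-in for a Float8Tensor that carries its own scale-inverse."""
    def __init__(self, scale_inv: float):
        self._scale_inv = torch.tensor([scale_inv])

weights = [FakeFP8Weight(s) for s in (0.5, 0.25, 0.125)]

# One contiguous buffer holding exactly these weights' scale-inverses.
weight_scale_inv = torch.cat([w._scale_inv for w in weights])
offset = 0  # no base offset: the buffer was built for just these GEMMs

for i, w in enumerate(weights):
    assert weight_scale_inv[offset + i] == w._scale_inv[0]
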
@@ -230,7 +249,7 @@ def forward(
                         t.activation_offloading = True

             ctx.save_for_backward(
-                fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
+                inputmat_scale_inv,
                 *saved_inputmats,
                 *saved_inputmats_t,
                 *weights,
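Why the .clone() can be dropped in ctx.save_for_backward above (an interpretation, not text from the commit): the shared scaling_fwd.scale_inv buffer keeps being updated in place after the forward pass, so it had to be snapshotted before saving, whereas the per-call inputmat_scale_inv buffer is already private to this forward. A toy illustration of the difference:

import torch

shared_scale_inv = torch.tensor([0.5, 0.25])  # shared buffer, mutated between fwd and bwd
saved_alias = shared_scale_inv                # saving without clone keeps an alias
saved_clone = shared_scale_inv.clone()        # snapshot, as the old code did

shared_scale_inv.mul_(2.0)                    # in-place scale update after forward

print(saved_alias)  # tensor([1.0000, 0.5000]) -- follows the update
print(saved_clone)  # tensor([0.5000, 0.2500]) -- stable snapshot
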
@@ -270,7 +289,7 @@ def forward(
     def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
         with torch.cuda.nvtx.range("_GroupedLinear_backward"):
             (
-                fwd_scale_inverses,
+                inputmat_scale_inv,
                 *saved_tensors,
             ) = ctx.saved_tensors
             inputmats = saved_tensors[: ctx.num_gemms]
@@ -396,8 +415,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
                         inp._data if isinstance(inp, Float8Tensor) else inp
                         for inp in inputmats_t
                     ],
-                    fwd_scale_inverses,
-                    _GEMM_INPUT,
+                    inputmat_scale_inv,
+                    0,
                     fp8_dtype_forward,
                     grad_output_t,
                     ctx.fp8_meta["scaling_bwd"].scale_inv,
