[PyTorch] Fused dbias-cast-transpose in bias operation #1168

Open · wants to merge 1 commit into base: main
45 changes: 31 additions & 14 deletions tests/pytorch/test_fusible_ops.py
@@ -450,36 +450,45 @@ def test_reshape(
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)

@pytest.mark.parametrize("size", (1, 7, 32))
@pytest.mark.parametrize("size", (7, 32))
@pytest.mark.parametrize("in_shape", ((-1,), (1, 3, -1), (2, 3, 4, -1)))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("device", _devices)
@pytest.mark.parametrize("fp8", (False, True))
@pytest.mark.parametrize("fp8_grad_input", (False, True))
def test_bias(
self,
*,
size: int,
in_shape: Iterable[int],
dtype: torch.dtype,
device: torch.device,
fp8: bool,
device: torch.device = "cuda",
fp8_input: bool = False,
fp8_grad_output: bool = False,
fp8_grad_input: bool,
) -> None:

# Make input and bias shapes consistent
in_shape = list(in_shape)[:-1] + [size]

# Skip invalid configurations
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8 and torch.device(device).type != "cuda":
pytest.skip("FP8 is only supported on CUDA devices")
if fp8_input or fp8_grad_output or fp8_grad_input:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if torch.device(device).type != "cuda":
pytest.skip("FP8 is only supported on CUDA devices")

# FP8 recipe
fp8_recipe = None
if fp8_grad_input:
fp8_recipe = transformer_engine.common.recipe.DelayedScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)

# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
test_is_fp8=fp8_input,
)
b_ref, b_test = make_reference_and_test_tensors(
size,
@@ -490,26 +499,34 @@ def test_bias(
in_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=(fp8_grad_output or fp8_grad_input),
requires_grad=False,
)
if not fp8_grad_output and is_float8_tensor(dy_test):
dy_test = dy_test.from_float8()

# Plain PyTorch implementation
y_ref = x_ref + b_ref.reshape([1] * (len(in_shape) - 1) + [size])
y_ref.backward(dy_ref)

# Implementation with fusible operation
op = te_ops.Bias(size, device=device, dtype=dtype)
bias_op = te_ops.Bias(size, device=device, dtype=dtype)
with torch.no_grad():
op.bias.copy_(b_test)
bias_op.bias.copy_(b_test)
del b_test
y_test = op(x_test)
forward = te_ops.Sequential(
te_ops.CastFloat8(forward=False, backward=fp8_grad_input),
bias_op,
)
with te.fp8_autocast(enabled=fp8_grad_input, fp8_recipe=fp8_recipe):
y_test = forward(x_test)
y_test.backward(dy_test)

# Check results
tols = dtype_tols(dtype)
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
db_test = op.bias.grad.to(dtype=torch.float64, device="cpu")
db_test = bias_op.bias.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
torch.testing.assert_close(db_test, b_ref.grad, **tols)
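For reference, outside the test harness the op pattern this updated test drives looks roughly like the sketch below. This is illustrative only: it assumes the `te` / `te_ops` import aliases used elsewhere in `test_fusible_ops.py`, and the shapes and dtype are arbitrary (chosen as multiples of 16 so the fused FP8 dgrad path in `bias.py` is eligible).

```python
import torch
import transformer_engine.pytorch as te
import transformer_engine.common.recipe
from transformer_engine.pytorch import ops as te_ops

# Backward-only FP8 cast placed in front of the bias op: during backward,
# the bias op sees an FP8 consumer for grad_output and can take the fused
# dbias-cast-transpose path introduced in this PR.
model = te_ops.Sequential(
    te_ops.CastFloat8(forward=False, backward=True),
    te_ops.Bias(32, device="cuda", dtype=torch.bfloat16),
)
recipe = transformer_engine.common.recipe.DelayedScaling(
    fp8_format=transformer_engine.common.recipe.Format.E4M3,
)

x = torch.randn(64, 32, device="cuda", dtype=torch.bfloat16, requires_grad=True)
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    y = model(x)
y.backward(torch.randn_like(y))  # exercises the fused dbias-cast-transpose backward
```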
12 changes: 1 addition & 11 deletions transformer_engine/pytorch/ops/__init__.py
@@ -8,17 +8,7 @@

"""

from transformer_engine.pytorch.ops.basic import (
AddInPlace,
AllGather,
AllReduce,
BasicLinear,
Bias,
Identity,
MakeExtraOutput,
ReduceScatter,
Reshape,
)
from transformer_engine.pytorch.ops.basic import *
from transformer_engine.pytorch.ops.linear import Linear
from transformer_engine.pytorch.ops.op import FusibleOperation
from transformer_engine.pytorch.ops.sequential import Sequential
1 change: 1 addition & 0 deletions transformer_engine/pytorch/ops/basic/__init__.py
@@ -9,6 +9,7 @@
from .all_reduce import AllReduce
from .basic_linear import BasicLinear
from .bias import Bias
from .cast_float8 import CastFloat8
from .identity import Identity
from .make_extra_output import MakeExtraOutput
from .reduce_scatter import ReduceScatter
98 changes: 88 additions & 10 deletions transformer_engine/pytorch/ops/basic/bias.py
@@ -9,13 +9,15 @@

import torch

from transformer_engine.pytorch.ops.op import (
BasicOperation,
OperationContext,
)
from ...cpp_extensions import fp8_cast_transpose_bgrad_fused
from ...float8_tensor import Float8Tensor
from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype
from ..op import BasicOperation, OperationContext
from .._common import (
canonicalize_device,
canonicalize_dtype,
is_float8_tensor,
reshape,
)


@@ -60,9 +62,14 @@ def __init__(
if device.type == "meta":
defer_param_init = True
device = canonicalize_device(None)
if device.type != "cuda":
raise ValueError(f"Only CUDA devices are supported (got {device})")
self.device: torch.device = device

# Bias tensor datatype
dtype = canonicalize_dtype(dtype)
if dtype not in (torch.float32, torch.float16, torch.bfloat16):
raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")
self.dtype: torch.dtype = canonicalize_dtype(dtype)

# Tensor parallel configuration
@@ -101,7 +108,7 @@ def reset_parameters(self) -> None:

# Make sure parameter is initialized
bias = self.bias
if bias.device.type != "cuda":
if bias.device.type != self.device.type:
bias = torch.empty_like(bias, device=self.device)
bias = bias.to(device=self.device, dtype=self.dtype)

@@ -125,18 +132,89 @@ def op_forward(
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
) -> torch.Tensor:

# Apply bias
x = input_
b = self.bias.reshape([1] * (x.dim() - 1) + [self.local_size])
return x + b
y = x + b

# Save state for backward pass
ctx.bias_requires_grad = self.bias.requires_grad
ctx.with_fp8_compute = FP8GlobalStateManager.is_fp8_enabled()
ctx.prev_op = prev_op

return y

def op_backward(
self,
ctx: OperationContext,
grad_output: torch.Tensor,
) -> tuple[torch.Tensor, tuple[()]]:

# Check if FP8 is enabled
with_fp8_grad_input = (
ctx.with_fp8_compute
and ctx.prev_op is not None
and ctx.prev_op.num_fp8_scales("grad_output") > 0
and grad_output.size(-1) % 16 == 0
and grad_output.numel() // grad_output.size(-1) % 16 == 0
)

# Compute grad bias
dy = grad_output
if dy.dim() > 1:
db = dy.sum(tuple(range(dy.dim() - 1)))
db = None
dx: torch.Tensor
if not ctx.bias_requires_grad:
# Trivial case: Don't compute bgrad, don't do anything
# with dgrad
dx = dy
if not with_fp8_grad_input or is_float8_tensor(dy):
# Non-FP8 case: Compute bgrad, don't do anything with
# dgrad
if dy.dim() > 1:
db = dy.sum(tuple(range(dy.dim() - 1)))
else:
db = dy
dx = dy
else:
db = dy
return dy, (db,)
# FP8 case: Call fused kernel to compute bgrad and cast
# dgrad to FP8

# Check grad output tensor
output_dims = grad_output.size()
dy = reshape(
dy,
(-1, output_dims[-1]),
device=self.device,
dtype=self.dtype,
)

# Call fused kernel for bgrad and casting dgrad to FP8
fp8_meta = ctx.prev_op.get_fp8_meta("grad_output")
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=False)
fp8_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False)
fp8_scale_inv = torch.empty([1], dtype=torch.float32, device=self.device)
db, dx_data, dx_data_transpose = fp8_cast_transpose_bgrad_fused(
dy,
fp8_meta[fp8_meta_key],
0,
fp8_dtype,
scale_inv=fp8_scale_inv,
)

# Construct grad input tensor
if dx_data.size() != output_dims:
dx_data = dx_data.reshape(output_dims)
dx = Float8Tensor(
data=dx_data,
fp8_meta=fp8_meta,
fp8_meta_forward=False,
fp8_meta_index=0,
fp8_dtype=fp8_dtype,
fp8_scale_inv=fp8_scale_inv,
dtype=self.dtype,
)
dx._transpose = dx_data_transpose
dx._transpose_invalid = False

return dx, (db,)
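For reviewers less familiar with `fp8_cast_transpose_bgrad_fused`: the kernel fuses the bgrad reduction, the FP8 quantization of dgrad, and the transposed copy of dgrad that the weight-gradient GEMM consumes later. A rough unfused reference for the three outputs is sketched below; it is conceptual only, since the real kernel quantizes with the scale held in `fp8_meta` and returns FP8 data rather than the high-precision tensors shown here.

```python
import torch

def dbias_cast_transpose_reference(dy: torch.Tensor, scale: torch.Tensor):
    # Flatten all leading dimensions, matching the reshape in op_backward.
    dy_2d = dy.reshape(-1, dy.size(-1))
    db = dy_2d.sum(dim=0)                  # bias gradient
    dx_data = dy_2d * scale                # the fused kernel quantizes this to FP8
    dx_data_t = dx_data.t().contiguous()   # transposed copy reused by the wgrad GEMM
    return db, dx_data, dx_data_t
```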
100 changes: 100 additions & 0 deletions transformer_engine/pytorch/ops/basic/cast_float8.py
@@ -0,0 +1,100 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Fusible operation for identity."""

from __future__ import annotations
from typing import Optional

import torch

from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.fp8 import (
FP8GlobalStateManager,
get_fp8_te_dtype,
)
from transformer_engine.pytorch.ops.op import (
BasicOperation,
OperationContext,
)
from .._common import is_float8_tensor


class CastFloat8(BasicOperation):
"""Cast tensor to FP8

Uses FP8 recipe from `fp8_autocast` context. When called outside
of an `fp8_autocast` context, this is an identity operation.

Parameters
----------
forward: bool, default = `True`
Perform FP8 cast in forward pass
backward: bool, default = `False`
Perform FP8 cast in backward pass

"""

def __init__(
self,
forward: bool = True,
backward: bool = False,
) -> None:
super().__init__()
self._cast_forward = forward
self._cast_backward = backward

def num_fp8_scales(self, mode: str) -> int:
if mode == "input" and self._cast_forward:
return 1
if mode == "grad_output" and self._cast_backward:
return 1
return 0

def op_forward(
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
) -> torch.Tensor:

# Check if FP8 is enabled
fp8_enabled = FP8GlobalStateManager.is_fp8_enabled()
cast_forward = fp8_enabled and self._cast_forward
cast_backward = fp8_enabled and self._cast_backward

# Cast to FP8 if needed
out = input_
if cast_forward and not is_float8_tensor(out):
fp8_meta = self.get_fp8_meta("input")
fp8_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
out = Float8Tensor.to_float8(
out,
fp8_meta=fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=0,
fp8_dtype=fp8_dtype,
)

ctx.cast_backward = cast_backward
return out

def op_backward(
self,
ctx: OperationContext,
grad_output: torch.Tensor,
) -> tuple[torch.Tensor, tuple[()]]:
grad_input = grad_output
if ctx.cast_backward and not is_float8_tensor(grad_input):
fp8_meta = self.get_fp8_meta("grad_output")
fp8_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False)
grad_input = Float8Tensor.to_float8(
grad_input,
fp8_meta=fp8_meta,
fp8_meta_forward=False,
fp8_meta_index=0,
fp8_dtype=fp8_dtype,
)
return grad_input, ()
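A quick illustration of the intended semantics of the new op: it is a no-op unless FP8 is enabled by an enclosing `fp8_autocast` region, in which case the selected pass casts to `Float8Tensor`. The sketch below is hypothetical usage, assuming basic ops can be called standalone the way `te_ops.Bias` is in the existing tests.

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch import ops as te_ops

cast = te_ops.CastFloat8(forward=True, backward=False)
x = torch.randn(32, 32, device="cuda", dtype=torch.bfloat16)

y = cast(x)  # outside fp8_autocast: identity, y stays bf16
with te.fp8_autocast(enabled=True):
    y = cast(x)  # inside: y is a Float8Tensor, using the recipe's forward FP8 dtype
```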
Changes to the fused linear + bias forward operation (file header not captured in this view)
@@ -72,12 +72,13 @@ def fuser_forward(
idx = self._op_idxs["linear"]
linear_op = self.basic_ops[idx]
linear_op_ctx = basic_op_ctxs[idx]
if self._op_idxs["bias"] is None:
bias_op = None
bias = None
else:
bias_op = None
bias_op_ctx = None
bias = None
if self._op_idxs["bias"] is not None:
idx = self._op_idxs["bias"]
bias_op = self.basic_ops[idx]
bias_op_ctx = basic_op_ctxs[idx]
bias = bias_op.bias
if basic_op_kwargs[idx]:
raise ValueError("Bias operation forward does not expect keyword arguments")
@@ -129,7 +130,11 @@ def fuser_forward(
linear_op_ctx.input_dims = input_.size()
linear_op_ctx.input_requires_grad = input_.requires_grad
linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad
linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None
linear_op_ctx.has_prev_op = basic_op_prev_ops[self._op_idxs["linear"]] is not None
if bias_op_ctx is not None:
bias_op_ctx.bias_requires_grad = bias.requires_grad
bias_op_ctx.with_fp8_compute = with_fp8_compute
bias_op_ctx.prev_op = basic_op_prev_ops[self._op_idxs["bias"]]

return output, [() for _ in range(len(self.basic_ops))]
