[PyTorch] make GroupedLinear inp support collection of torch.Tensor #1120

Closed
53 changes: 53 additions & 0 deletions transformer_engine/pytorch/module/base.py
@@ -703,6 +703,59 @@ def prepare_forward(
FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
return

@contextmanager
def prepare_grouped_forward(
self,
inputmats: List[torch.Tensor],
is_first_microbatch: Union[bool, None], # pylint: disable=unused-argument
num_gemms: int = 1,
allow_non_contiguous: bool = False,
) -> Generator[List[torch.Tensor], None, None]:
"""Checks and prep for FWD.
The context manager is needed because there isn't a way for a module to know
if it's the last FP8 module in the forward autocast. It is useful
to setup the forward aggregated amax reduction for every module
just in case. The autocast exit will pick up the most recent one.
"""
# Activation recomputation is used and this is the second forward phase.
if self.fp8 and in_fp8_activation_recompute_phase():
FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
else:
for inp in inputmats:
assert inp.is_cuda, "TransformerEngine needs CUDA."
self.set_activation_dtype(inp)

if self.tp_size > 1:
assert self.tp_group_initialized, "TP group not initialized."

self.init_fp8_metadata(num_gemms=num_gemms)

if self.fp8 and self.sequence_parallel:
assert self.fp8_meta["recipe"].reduce_amax, (
"Amax reduction across tensor parallel group is "
"necessary when using sequence parallelism with FP8."
)

if self.fp8 and not FP8GlobalStateManager.fp8_graph_capturing():
FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
self.fp8_meta, fp8_weights=self._get_fp8_params()
)

# Activation recomputation is used and this is the first forward phase.
if self.fp8 and self.training and is_fp8_activation_recompute_enabled():
FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)

with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"):
for i, inp in enumerate(inputmats):
if not allow_non_contiguous:
inputmats[i] = inp.contiguous()

yield inputmats

if self.fp8 and in_fp8_activation_recompute_phase():
FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
return

def set_nccl_overlap_warning_if_tp(self) -> None:
"""When using TP, the NCCL communication needs to be scheduled
before the GEMM for there to be a guaranteed overlap. From the
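For orientation, here is a standalone sketch (no Transformer Engine dependencies; the helper name and the CPU-friendly setup are illustrative assumptions) of the pattern prepare_grouped_forward follows: validate each input, normalize contiguity, yield the list, and do any post-forward bookkeeping once the with-block exits.

from contextlib import contextmanager
from typing import Generator, List

import torch

@contextmanager
def grouped_forward_sketch(
    inputmats: List[torch.Tensor], allow_non_contiguous: bool = False
) -> Generator[List[torch.Tensor], None, None]:
    # The real method also asserts inp.is_cuda, sets the activation dtype,
    # and initializes FP8 metadata; this sketch only mirrors the control flow.
    if not allow_non_contiguous:
        inputmats = [inp.contiguous() for inp in inputmats]
    yield inputmats
    # TE restores FP8 recompute metadata here when applicable.

x = torch.randn(4, 3).t()  # transposed view -> non-contiguous
with grouped_forward_sketch([x]) as mats:
    assert mats[0].is_contiguous()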
81 changes: 57 additions & 24 deletions transformer_engine/pytorch/module/grouped_linear.py
@@ -66,7 +66,6 @@ class _GroupedLinear(torch.autograd.Function):
@staticmethod
def forward(
ctx,
inp: torch.Tensor,
m_splits: List[int],
use_bias: bool,
is_first_microbatch: Union[bool, None],
@@ -82,26 +81,35 @@ def forward(
activation_dtype: torch.dtype,
parallel_mode: Union[str, None],
is_grad_enabled: bool,
weights_fp8: List[Union[Float8Tensor, None]],
*weights_and_biases: Union[Float8Tensor, torch.Tensor, None],
split_inp: bool,
*inp_weights_biases: Union[torch.Tensor, Float8Tensor, torch.Tensor, None],
) -> torch.Tensor:
num_gemms = len(m_splits)
weights = weights_and_biases[:num_gemms]
biases = weights_and_biases[num_gemms:]
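# Layout of *inp_weights_biases:
#   input tensor(s) | num_gemms weights | num_gemms FP8 weights | biases
# The input part is a single packed tensor when split_inp is True, otherwise
# num_gemms pre-split tensors.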
if split_inp:
offset = 1
inp = inp_weights_biases[0]
else:
offset = num_gemms
inputmats = inp_weights_biases[0:offset]

weights = inp_weights_biases[offset : num_gemms + offset]
weights_fp8 = inp_weights_biases[num_gemms + offset : 2 * num_gemms + offset]
biases = inp_weights_biases[2 * num_gemms + offset :]

# Make sure input dimensions are compatible
in_features = weights[0].shape[-1]
assert inp.shape[-1] == in_features, "GEMM not possible"
inputmats = torch.split(inp.view(-1, in_features), m_splits)
if split_inp:
inputmats = torch.split(inp.view(-1, in_features), m_splits)

for inp in inputmats:
assert inp.shape[-1] == in_features, "GEMM not possible"
if fp8:
for i in range(num_gemms):
assert_dim_for_fp8_exec(inputmats[i])
assert_dim_for_fp8_exec(weights[i])

# Cast input to expected dtype
inputmats_no_fp8 = [cast_if_needed(mat, activation_dtype) for mat in inputmats]
inputmats = []
inputmats_t = []

global _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT
if fp8:
@@ -251,12 +259,12 @@ def forward(
ctx.use_bias = use_bias
ctx.sequence_parallel = sequence_parallel
ctx.tensor_parallel = tensor_parallel
ctx.inp_shape = inp.shape
ctx.parallel_mode = parallel_mode
ctx.tp_group = tp_group
ctx.tp_size = tp_size
ctx.requires_dgrad = inp.requires_grad
ctx.reduce_and_update_bwd_fp8_tensors = False
ctx.split_inp = split_inp
if ctx.fp8 and requires_grad(inp, weights[0], biases[0]):
ctx.reduce_and_update_bwd_fp8_tensors = (
ctx.reduce_and_update_bwd_fp8_tensors
@@ -340,6 +348,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
dtype=ctx.activation_dtype,
device=grad_output.device,
)
dgrad_list = torch.split(dgrad, ctx.m_splits)
fp8_grouped_gemm(
[w.transpose_2d() for w in weights_fp8],
torch.cat(
@@ -351,7 +360,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
ctx.fp8_meta["scaling_bwd"].scale_inv,
_GRAD_OUTPUT,
fp8_dtype_backward,
torch.split(dgrad, ctx.m_splits),
dgrad_list,
ctx.activation_dtype,
get_multi_stream_cublas_workspace(),
use_split_accumulator=_2X_ACC_DGRAD,
@@ -362,10 +371,11 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
dtype=ctx.activation_dtype,
device=grad_output.device,
)
dgrad_list = torch.split(dgrad, ctx.m_splits)
grouped_gemm(
weights,
grad_output_mats,
torch.split(dgrad, ctx.m_splits),
dgrad_list,
ctx.activation_dtype,
get_multi_stream_cublas_workspace(),
layout="NN",
@@ -469,11 +479,18 @@ def handle_custom_ddp_from_mcore(w, wgrad):
handle_custom_ddp_from_mcore(w, wgrad) for w, wgrad in zip(weights, wgrad_list)
]

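# Gradients w.r.t. the input(s): a single packed dgrad when the forward was
# handed one tensor to split, one gradient per pre-split input otherwise, and
# None placeholders when no input gradient is required.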
if ctx.requires_dgrad:
if ctx.split_inp:
dgrad_holder = [dgrad]
else:
dgrad_holder = dgrad_list
else:
dgrad_holder = [None] * ctx.num_gemms

if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)

return (
dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
None, # m_splits
None, # use_bias
None, # is_first_microbatch
@@ -489,8 +506,10 @@ def handle_custom_ddp_from_mcore(w, wgrad):
None, # activation_dtype
None, # parallel_mode
None, # is_grad_enabled
None, # weights_fp8
None, # split_inp
*dgrad_holder,
*wgrad_list,
*([None] * ctx.num_gemms), # weights_fp8
*grad_biases,
)

@@ -696,7 +715,7 @@ def reset_parameters(self, defer_init=False):
@no_torch_dynamo()
def forward(
self,
inp: torch.Tensor,
inp: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
m_splits: List[int],
is_first_microbatch: Optional[bool] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
@@ -705,8 +724,11 @@ def forward(

Parameters
----------
inp : torch.Tensor
Input tensor.
inp : {torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]}
Input tensor or a collection of input tensors.

* If `inp` is a torch.Tensor, it is split according to `m_splits`.
* If `inp` is a collection of torch.Tensor, it is assumed to be already split according to `m_splits`.
m_splits : List[int]
List of integers representing the split of the input tensor.
is_first_microbatch : {True, False, None}, default = None
@@ -723,17 +745,27 @@ def forward(
first microbatch (since it is the first gradient being
produced)
"""
assert not isinstance(
inp, Float8Tensor
), "GroupedLinear doesn't support input tensor in FP8."
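# Accept either a single packed tensor (split internally by m_splits) or a
# collection of tensors that is already split.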
if isinstance(inp, torch.Tensor):
split_inp = True
inputmats = [inp]
else:
# inp is a list or tuple
split_inp = False
inputmats = inp

for inp in inputmats:
assert not isinstance(
inp, Float8Tensor
), "GroupedLinear doesn't support input tensor in FP8."
assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."

skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
if skip_fp8_weight_update is not None:
is_first_microbatch = False

with self.prepare_forward(inp, is_first_microbatch, num_gemms=self.num_gemms) as inp:

with self.prepare_grouped_forward(
list(inputmats), is_first_microbatch, num_gemms=self.num_gemms
) as inputmats:
weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
if not self.fp8:
@@ -785,7 +817,6 @@ def forward(
linear_fn = _GroupedLinear.forward
args = [None]
args += (
inp,
m_splits,
self.apply_bias and not self.gemm_bias_unfused_add,
is_first_microbatch,
@@ -801,8 +832,10 @@ def forward(
self.activation_dtype,
self.parallel_mode,
torch.is_grad_enabled(),
weight_tensors_fp8,
split_inp,
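# Trailing varargs must follow the layout unpacked in _GroupedLinear.forward:
# input tensor(s), then weights, FP8 weights, and biases.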
*inputmats,
*weight_tensors,
*weight_tensors_fp8,
*bias_tensors,
)
out = linear_fn(*args)
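Taken together, the user-visible change is the accepted type of `inp`. Below is a usage sketch of both calling conventions, assuming a build of Transformer Engine that includes this change, a CUDA device, and GroupedLinear's (num_gemms, in_features, out_features) constructor; shapes and split sizes are arbitrary.

import torch
import transformer_engine.pytorch as te

m_splits = [16, 32, 16]
layer = te.GroupedLinear(len(m_splits), 128, 256).cuda()

# 1) Single packed tensor: split internally according to m_splits.
inp = torch.randn(sum(m_splits), 128, device="cuda")
out_packed = layer(inp, m_splits)

# 2) Collection of tensors: assumed to be already split according to m_splits.
inp_list = list(torch.split(inp, m_splits))
out_list = layer(inp_list, m_splits)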