diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py
index 3613e1fa5e..22db22d692 100644
--- a/transformer_engine/pytorch/module/base.py
+++ b/transformer_engine/pytorch/module/base.py
@@ -703,6 +703,59 @@ def prepare_forward(
             FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
         return
 
+    @contextmanager
+    def prepare_grouped_forward(
+        self,
+        inputmats: List[torch.Tensor],
+        is_first_microbatch: Union[bool, None],  # pylint: disable=unused-argument
+        num_gemms: int = 1,
+        allow_non_contiguous: bool = False,
+    ) -> Generator[List[torch.Tensor], None, None]:
+        """Checks and prep for FWD.
+        The context manager is needed because there isn't a way for a module to know
+        if it's the last FP8 module in the forward autocast. It is useful
+        to setup the forward aggregated amax reduction for every module
+        just in case. The autocast exit will pick up the most recent one.
+        """
+        # Activation recomputation is used and this is the second forward phase.
+        if self.fp8 and in_fp8_activation_recompute_phase():
+            FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
+        else:
+            for inp in inputmats:
+                assert inp.is_cuda, "TransformerEngine needs CUDA."
+                self.set_activation_dtype(inp)
+
+            if self.tp_size > 1:
+                assert self.tp_group_initialized, "TP group not initialized."
+
+            self.init_fp8_metadata(num_gemms=num_gemms)
+
+            if self.fp8 and self.sequence_parallel:
+                assert self.fp8_meta["recipe"].reduce_amax, (
+                    "Amax reduction across tensor parallel group is "
+                    "necessary when using sequence parallelism with FP8."
+                )
+
+            if self.fp8 and not FP8GlobalStateManager.fp8_graph_capturing():
+                FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
+                    self.fp8_meta, fp8_weights=self._get_fp8_params()
+                )
+
+            # Activation recomputation is used and this is the first forward phase.
+            if self.fp8 and self.training and is_fp8_activation_recompute_enabled():
+                FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)
+
+        with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"):
+            for i, inp in enumerate(inputmats):
+                if not allow_non_contiguous:
+                    inputmats[i] = inp.contiguous()
+
+            yield inputmats
+
+        if self.fp8 and in_fp8_activation_recompute_phase():
+            FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
+        return
+
     def set_nccl_overlap_warning_if_tp(self) -> None:
         """When using TP, the NCCL communication needs to be scheduled
         before the GEMM for there to be a guaranteed overlap.
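
Usage sketch (illustrative only, not part of the patch): how a TransformerEngineBaseModule subclass is expected to drive the new prepare_grouped_forward context manager. The run_grouped_forward helper below is hypothetical and only shows the call pattern; the actual consumer is the GroupedLinear change in the next file.

    from typing import List, Optional

    import torch

    def run_grouped_forward(module, inputmats: List[torch.Tensor],
                            is_first_microbatch: Optional[bool] = None):
        # `module` is assumed to be a TransformerEngineBaseModule subclass with a
        # `num_gemms` attribute (e.g. GroupedLinear). The context manager checks
        # that every input is on CUDA, initializes FP8 metadata for `num_gemms`
        # GEMMs, and yields the (contiguous) list of inputs back to the caller.
        with module.prepare_grouped_forward(
            list(inputmats), is_first_microbatch, num_gemms=module.num_gemms
        ) as inputmats:
            # The grouped GEMMs would be launched on `inputmats` here; this
            # sketch just returns them to show the shape of the contract.
            return inputmats
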
diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py
index a91ff5c361..660ab690f0 100644
--- a/transformer_engine/pytorch/module/grouped_linear.py
+++ b/transformer_engine/pytorch/module/grouped_linear.py
@@ -66,7 +66,6 @@ class _GroupedLinear(torch.autograd.Function):
     @staticmethod
     def forward(
         ctx,
-        inp: torch.Tensor,
         m_splits: List[int],
         use_bias: bool,
         is_first_microbatch: Union[bool, None],
@@ -82,17 +81,28 @@ def forward(
         activation_dtype: torch.dtype,
         parallel_mode: Union[str, None],
         is_grad_enabled: bool,
-        weights_fp8: List[Union[Float8Tensor, None]],
-        *weights_and_biases: Union[Float8Tensor, torch.Tensor, None],
+        split_inp: bool,
+        *inp_weights_biases: Union[torch.Tensor, Float8Tensor, torch.Tensor, None],
     ) -> torch.Tensor:
         num_gemms = len(m_splits)
-        weights = weights_and_biases[:num_gemms]
-        biases = weights_and_biases[num_gemms:]
+        if split_inp:
+            offset = 1
+            inp = inp_weights_biases[0]
+        else:
+            offset = num_gemms
+            inputmats = inp_weights_biases[0:offset]
+
+        weights = inp_weights_biases[offset : num_gemms + offset]
+        weights_fp8 = inp_weights_biases[num_gemms + offset : 2 * num_gemms + offset]
+        biases = inp_weights_biases[2 * num_gemms + offset :]
 
         # Make sure input dimensions are compatible
         in_features = weights[0].shape[-1]
-        assert inp.shape[-1] == in_features, "GEMM not possible"
-        inputmats = torch.split(inp.view(-1, in_features), m_splits)
+        if split_inp:
+            inputmats = torch.split(inp.view(-1, in_features), m_splits)
+
+        for inp in inputmats:
+            assert inp.shape[-1] == in_features, "GEMM not possible"
         if fp8:
             for i in range(num_gemms):
                 assert_dim_for_fp8_exec(inputmats[i])
@@ -100,8 +110,6 @@ def forward(
 
         # Cast input to expected dtype
         inputmats_no_fp8 = [cast_if_needed(mat, activation_dtype) for mat in inputmats]
-        inputmats = []
-        inputmats_t = []
 
         global _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT
         if fp8:
@@ -251,12 +259,12 @@ def forward(
             ctx.use_bias = use_bias
             ctx.sequence_parallel = sequence_parallel
             ctx.tensor_parallel = tensor_parallel
-            ctx.inp_shape = inp.shape
             ctx.parallel_mode = parallel_mode
             ctx.tp_group = tp_group
             ctx.tp_size = tp_size
             ctx.requires_dgrad = inp.requires_grad
             ctx.reduce_and_update_bwd_fp8_tensors = False
+            ctx.split_inp = split_inp
             if ctx.fp8 and requires_grad(inp, weights[0], biases[0]):
                 ctx.reduce_and_update_bwd_fp8_tensors = (
                     ctx.reduce_and_update_bwd_fp8_tensors
@@ -340,6 +348,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
                         dtype=ctx.activation_dtype,
                         device=grad_output.device,
                     )
+                    dgrad_list = torch.split(dgrad, ctx.m_splits)
                    fp8_grouped_gemm(
                        [w.transpose_2d() for w in weights_fp8],
                        torch.cat(
@@ -351,7 +360,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
                        ctx.fp8_meta["scaling_bwd"].scale_inv,
                        _GRAD_OUTPUT,
                        fp8_dtype_backward,
-                        torch.split(dgrad, ctx.m_splits),
+                        dgrad_list,
                        ctx.activation_dtype,
                        get_multi_stream_cublas_workspace(),
                        use_split_accumulator=_2X_ACC_DGRAD,
@@ -362,10 +371,11 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
                        dtype=ctx.activation_dtype,
                        device=grad_output.device,
                    )
+                    dgrad_list = torch.split(dgrad, ctx.m_splits)
                    grouped_gemm(
                        weights,
                        grad_output_mats,
-                        torch.split(dgrad, ctx.m_splits),
+                        dgrad_list,
                        ctx.activation_dtype,
                        get_multi_stream_cublas_workspace(),
                        layout="NN",
@@ -469,11 +479,18 @@ def handle_custom_ddp_from_mcore(w, wgrad):
                handle_custom_ddp_from_mcore(w, wgrad)
                for w, wgrad in zip(weights, wgrad_list)
            ]
+        if ctx.requires_dgrad:
+            if ctx.split_inp:
+                dgrad_holder = [dgrad]
+            else:
+                dgrad_holder = dgrad_list
+        else:
+            dgrad_holder = [None] * ctx.num_gemms
+
         if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
             FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
 
         return (
-            dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
             None,  # m_splits
             None,  # use_bias
             None,  # is_first_microbatch
@@ -489,8 +506,10 @@ def handle_custom_ddp_from_mcore(w, wgrad):
             None,  # activation_dtype
             None,  # parallel_mode
             None,  # is_grad_enabled
-            None,  # weights_fp8
+            None,  # split_inp
+            *dgrad_holder,
             *wgrad_list,
+            *([None] * ctx.num_gemms),  # weights_fp8
             *grad_biases,
         )
 
@@ -696,7 +715,7 @@ def reset_parameters(self, defer_init=False):
     @no_torch_dynamo()
     def forward(
         self,
-        inp: torch.Tensor,
+        inp: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
         m_splits: List[int],
         is_first_microbatch: Optional[bool] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
@@ -705,8 +724,11 @@ def forward(
 
        Parameters
        ----------
-        inp : torch.Tensor
-             Input tensor.
+        inp : {torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]}
+             Input tensor, or a collection of input tensors.
+
+             * If `inp` is a torch.Tensor, it is split internally based on `m_splits`.
+             * If `inp` is a list or tuple of torch.Tensor, it is assumed to already be split based on `m_splits`.
        m_splits : List[int]
                  List of integers representing the split of the input tensor.
        is_first_microbatch : {True, False, None}, default = None
@@ -723,17 +745,27 @@ def forward(
                              * it also allows skipping gradient accumulation during the
                                first microbatch (since it is the first gradient being produced)
        """
-        assert not isinstance(
-            inp, Float8Tensor
-        ), "GroupedLinear doesn't support input tensor in FP8."
+        if isinstance(inp, torch.Tensor):
+            split_inp = True
+            inputmats = [inp]
+        else:
+            # inp is a list or tuple of tensors, already split according to m_splits
+            split_inp = False
+            inputmats = inp
+
+        for inp in inputmats:
+            assert not isinstance(
+                inp, Float8Tensor
+            ), "GroupedLinear doesn't support input tensor in FP8."
        assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
 
        skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
        if skip_fp8_weight_update is not None:
            is_first_microbatch = False
 
-        with self.prepare_forward(inp, is_first_microbatch, num_gemms=self.num_gemms) as inp:
-
+        with self.prepare_grouped_forward(
+            list(inputmats), is_first_microbatch, num_gemms=self.num_gemms
+        ) as inputmats:
            weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
            bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
            if not self.fp8:
@@ -785,7 +817,6 @@ def forward(
            linear_fn = _GroupedLinear.forward
            args = [None]
            args += (
-                inp,
                m_splits,
                self.apply_bias and not self.gemm_bias_unfused_add,
                is_first_microbatch,
@@ -801,8 +832,10 @@ def forward(
                self.activation_dtype,
                self.parallel_mode,
                torch.is_grad_enabled(),
-                weight_tensors_fp8,
+                split_inp,
+                *inputmats,
                *weight_tensors,
+                *weight_tensors_fp8,
                *bias_tensors,
            )
            out = linear_fn(*args)
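
Usage sketch (illustrative only, not part of the patch): with this change, GroupedLinear.forward accepts either a single tensor, which it splits internally according to m_splits, or a pre-split list/tuple of tensors. The sizes below are made up, and the constructor call assumes the existing GroupedLinear(num_gemms, in_features, out_features, ...) signature.

    import torch
    import transformer_engine.pytorch as te

    m_splits = [128, 256, 128]
    in_features, out_features = 512, 1024
    layer = te.GroupedLinear(len(m_splits), in_features, out_features, bias=True)

    # 1) Single tensor: forward splits it based on m_splits (split_inp=True path).
    x = torch.randn(sum(m_splits), in_features, device="cuda")
    y = layer(x, m_splits)

    # 2) Pre-split inputs: assumed to already match m_splits (split_inp=False path).
    xs = list(torch.split(x, m_splits))
    y2 = layer(xs, m_splits)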