From 467b39a3aed56987c4b0ec60eef1935b994fd7da Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 20 Aug 2024 20:56:19 -0700 Subject: [PATCH] [PyTorch] Add support for padding mask in `UnfusedDotProductAttention` (#1073) * add support for padding in UnfusedDPA Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add support for padding_causal/_bottom_right Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix padding_causal/_bottom_right Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * need to test max512 backend Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert last commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix mask logic in unfused Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * use actual_seqlen for alibi/causal_bottom_right padding Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix lint Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor fixes and convert causal to causal_bottom_right for inference Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * use causal in kv cache inference test Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * simplify get_alibi logic Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * simplify the non-padding path for get_alibi Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * avoid batch_size loop in generating padding_causal/_bottom_right masks Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/pytorch/test_numerics.py | 6 +- .../common/fused_attn/fused_attn.cpp | 5 +- transformer_engine/pytorch/attention.py | 172 +++++++++++++----- transformer_engine/pytorch/softmax.py | 39 ++-- transformer_engine/pytorch/transformer.py | 2 +- 5 files changed, 155 insertions(+), 69 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index a2023f539a..85cd4fc256 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -1655,8 
+1655,8 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, ffn_hidden_size=4 * D, num_attention_heads=H, attn_input_format=input_format, - self_attn_mask_type="causal_bottom_right", - enc_dec_attn_mask_type="causal_bottom_right", + self_attn_mask_type="causal", + enc_dec_attn_mask_type="causal", layer_number=layer_number, attention_dropout=0.0, params_dtype=dtype, @@ -1670,7 +1670,7 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, qkv_format=input_format, layer_number=layer_number, attention_dropout=0.0, - attn_mask_type="causal_bottom_right", + attn_mask_type="causal", params_dtype=dtype, ) .cuda() diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 0fe62f8cb4..70f1fa409f 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -142,7 +142,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS || (bias_type == NVTE_Bias_Type::NVTE_ALIBI && attn_mask_type != NVTE_Mask_Type::NVTE_NO_MASK && - attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK && sm_arch_ >= 90) || + attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK && + attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK && + attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && + sm_arch_ >= 90) || (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 90))) || ((cudnn_runtime_version >= 90000) && (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 80))) && diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 8fac4778c8..6a46d6c3c1 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -472,19 +472,25 @@ def get_attention_backend( use_fused_attention = False # Filter: Attention mask - # attn_mask_type | supported backends - # ------------------------------------------------------------------- - # no_mask | All - # padding | FlashAttention, FusedAttention - # causal | - # self-attention | All - # cross-attention | FusedAttention - # padding_causal | - # self-attention | FlashAttention, FusedAttention - # cross-attention | FusedAttention - # causal_bottom_right | All - # padding_causal_bottom_right | FlashAttention, FusedAttention - # arbitrary | UnfusedDotProductAttention + # attn_mask_type | attention_mask | supported backends + # ---------------------------------------------------------------------------------------- + # no_mask | None | All + # padding | | All + # self-attention | One tensor in shape [b, 1, 1, sq] | + # cross-attention | Tuple of two tensors in shapes | + # | [b, 1, 1, sq] and [b, 1, 1, skv] | + # causal | None | + # self-attention | | All + # cross-attention | | FusedAttention, UnfusedDotProductAttention + # padding_causal | Same as "padding" | + # self-attention | | All + # cross-attention | | FusedAttention, UnfusedDotProductAttention + # causal_bottom_right | None | All + # padding_causal_bottom_right | Same as "padding" | + # self-attention | | All + # cross-attention | | FlashAttention, UnfusedDotProductAttention + # arbitrary | One tensor in shape broadcastable to | UnfusedDotProductAttention + # | [b, h, sq, skv] | if attn_mask_type == "arbitrary": if use_flash_attention: logger.debug("Disabling FlashAttention for arbitrary mask") @@ -492,9 +498,6 @@ def get_attention_backend( if use_fused_attention: 
logger.debug("Disabling FusedAttention for arbitrary mask") use_fused_attention = False - if use_unfused_attention and "padding" in attn_mask_type: - logger.debug("Disabling UnfusedDotProductAttention for %s mask", attn_mask_type) - use_unfused_attention = False if ( use_flash_attention and _flash_attn_2_1_plus @@ -780,7 +783,7 @@ def get_attention_backend( class InferenceParams: # pylint: disable=too-few-public-methods """ Inference parameters that are passed to the main model in order - to efficienly calculate and store the context during inference. + to efficiently calculate and store the context during inference. Parameters ---------- @@ -886,6 +889,8 @@ def get_alibi( num_heads: int, max_seqlen_q: int, max_seqlen_kv: int, + actual_seqlens_q: Optional[torch.Tensor] = None, + actual_seqlens_kv: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, bias_dtype: Optional[torch.dtype] = None, bottom_right_alignment: bool = True, @@ -899,6 +904,10 @@ def get_alibi( Maximum sequence length for queries. max_seqlen_kv: int Maximum sequence length for keys and values. + actual_seqlens_q: Optional[torch.Tensor], default = `None` + Actual sequence lengths for queries, in shape [batch_size]. + actual_seqlens_kv: Optional[torch.Tensor], default = `None` + Actual sequence lengths for keys and values, in shape [batch_size]. alibi_slopes: Optional[torch.Tensor], default = `None` Custom ALiBi slopes, FP32, CUDA tensor, in shape [num_heads] or [batch_size, num_heads]. bias_dtype: Optional[torch.dtype], default = `None` @@ -912,10 +921,12 @@ def get_alibi( alibi_slopes: torch.Tensor ALiBi slopes in FP32 and shape [num_heads] or [batch_size, num_heads]. alibi_bias: torch.Tensor - ALiBi bias in FP32 or `bias_dtype`. If `alibi_slopes` is in [num_heads] shape, - then `alibi_bias` is in [1, num_heads, max_seqlen_q, max_seqlen_kv], and if - `alibi_slopes` is in [batch_size, num_heads], then the bias is in - [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. + ALiBi bias in FP32 or `bias_dtype`. Its shape is + (1) [1, num_heads, max_seqlen_q, max_seqlen_kv] if `alibi_slopes` is in [num_heads] shape, + and `actual_seqlens_q` and `actual_seqlens_kv` are `None`; or + (2) [batch_size, num_heads, max_seqlen_q, max_seqlen_kv] if `alibi_slopes` is in + [batch_size, num_heads] shape, or, if `alibi_slopes` is in [num_heads] shape and + `actual_seqlens_q` and `actual_seqlens_kv` are not `None`. 
""" global _alibi_cache if _alibi_cache["_alibi_slopes_require_update"]: @@ -941,17 +952,23 @@ def get_alibi( slopes_shape = torch.Size([1, _alibi_cache["_alibi_slopes"].shape[0], 1, 1]) if _alibi_cache["_alibi_slopes"].dim() == 2: slopes_shape = torch.Size([*_alibi_cache["_alibi_slopes"].shape[:], 1, 1]) - if bottom_right_alignment: - bias = torch.arange(1 - max_seqlen_kv, 1, dtype=torch.int32, device="cuda").view( - 1, 1, 1, max_seqlen_kv - ) - else: - bias = torch.arange( - 1 - max_seqlen_q, max_seqlen_kv - max_seqlen_q + 1, dtype=torch.int32, device="cuda" - ).view(1, 1, 1, max_seqlen_kv) - bias = bias - torch.arange(1 - max_seqlen_q, 1, dtype=torch.int32, device="cuda").view( + bias = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view( 1, 1, max_seqlen_q, 1 + ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view( + 1, 1, 1, max_seqlen_kv ) + if actual_seqlens_q is None and actual_seqlens_kv is None: + if bottom_right_alignment: + bias = bias + max_seqlen_kv - max_seqlen_q + elif actual_seqlens_q is not None and actual_seqlens_kv is not None: + batch_size = actual_seqlens_q.shape[0] + bias = bias.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + if bottom_right_alignment: + bias = bias + (actual_seqlens_kv - actual_seqlens_q).view(batch_size, 1, 1, 1) + else: + assert ( + False + ), "actual_seqlens_q and actual_seqlens_kv need to be both None or torch.Tensors!" bias = bias.abs().mul(-1) bias = bias * _alibi_cache["_alibi_slopes"].view(slopes_shape) _alibi_cache["_max_seqlen_q"], _alibi_cache["_max_seqlen_kv"] = max_seqlen_q, max_seqlen_kv @@ -3705,6 +3722,7 @@ class UnfusedDotProductAttention(torch.nn.Module): def __init__( self, softmax_scale: float, + attention_type: str = "self", attention_dropout: float = 0.0, attention_dropout_ctx: Optional[Callable] = nullcontext, layer_number: Optional[int] = None, @@ -3712,6 +3730,7 @@ def __init__( super().__init__() self.softmax_scale = softmax_scale + self.attention_type = attention_type self.attention_dropout_ctx = attention_dropout_ctx self.layer_number = layer_number @@ -3751,6 +3770,58 @@ def forward( query_layer, key_layer, value_layer = [ x.transpose(0, 1) for x in [query_layer, key_layer, value_layer] ] + batch_size, max_seqlen_q, max_seqlen_kv = ( + query_layer.shape[1], + query_layer.shape[0], + key_layer.shape[0], + ) + if "padding" in attn_mask_type: + if self.attention_type == "self": + assert attention_mask.shape == ( + batch_size, + 1, + 1, + max_seqlen_q, + ), "attention_mask should be a single tensor with [b, 1, 1, sq] shape!" + attention_mask = torch.logical_or( + attention_mask.squeeze(1).unsqueeze(3), attention_mask + ) + else: + assert ( + len(attention_mask) == 2 + and attention_mask[0].shape == (batch_size, 1, 1, max_seqlen_q) + and attention_mask[1].shape == (batch_size, 1, 1, max_seqlen_kv) + ), ( + "attention_mask should be a tuple of two tensors with shapes " + "[b, 1, 1, sq] and [b, 1, 1, skv]!" 
+ ) + attention_mask = torch.logical_or( + attention_mask[0].squeeze(1).unsqueeze(3), attention_mask[1] + ) + mask = attention_mask.squeeze(1).logical_not() + actual_seqlens_q = mask[:, :, 0].sum(dim=1) + actual_seqlens_kv = mask[:, 0, :].sum(dim=1) + mask = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view( + 1, 1, max_seqlen_q, 1 + ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view( + 1, 1, 1, max_seqlen_kv + ) + if attn_mask_type == "padding_causal": + attention_mask = torch.logical_or( + torch.where(mask.view(1, 1, max_seqlen_q, max_seqlen_kv) < 0, 1, 0), + attention_mask, + ) + if attn_mask_type == "padding_causal_bottom_right": + attention_mask = torch.logical_or( + torch.where( + mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + + (actual_seqlens_kv - actual_seqlens_q).view(batch_size, 1, 1, 1) + < 0, + 1, + 0, + ), + attention_mask, + ) batch_size, seqlen = query_layer.shape[1], query_layer.shape[0] apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16 @@ -3805,7 +3876,7 @@ def forward( key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] beta=0.0, alpha=scale, - ) + ).view(*output_size) elif core_attention_bias_type == "pre_scale_bias": assert core_attention_bias is not None, "core_attention_bias should not be None!" @@ -3813,10 +3884,7 @@ def forward( query_layer.transpose(0, 1), # [b * np, sq, hn] key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] ) - matmul_result = ( - matmul_result.view(output_size[0], output_size[1], output_size[2], output_size[3]) - + core_attention_bias - ).view(-1, output_size[2], output_size[3]) + matmul_result = matmul_result.view(*output_size) + core_attention_bias matmul_result *= scale elif core_attention_bias_type in ["post_scale_bias", "alibi"]: @@ -3827,6 +3895,8 @@ def forward( output_size[1], output_size[2], output_size[3], + actual_seqlens_q=actual_seqlens_q if "padding" in attn_mask_type else None, + actual_seqlens_kv=actual_seqlens_kv if "padding" in attn_mask_type else None, alibi_slopes=alibi_slopes, bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"], ) @@ -3837,26 +3907,21 @@ def forward( beta=0.0, alpha=scale, ) - matmul_result = ( - ( - matmul_result.view( - output_size[0], output_size[1], output_size[2], output_size[3] - ) - + core_attention_bias - ) - .view(-1, output_size[2], output_size[3]) - .to(dtype=query_layer.dtype) + matmul_result = (matmul_result.view(*output_size) + core_attention_bias).to( + dtype=query_layer.dtype ) - # change view to [b, np, sq, sk] - attention_scores = matmul_result.view(*output_size) - # attention scores and attention mask [b, np, sq, sk] softmax_scale = self.layer_number if apply_qk_layer_scaling else None attention_probs = self.scale_mask_softmax( - attention_scores, attention_mask, attn_mask_type, softmax_scale + matmul_result, attention_mask, attn_mask_type, softmax_scale ) + # mask out the pad positions in softmax results, mostly for the rows (pad tokens from q) + # the columns (pad tokens from k) are already zeroed out during softmax + if "padding" in attn_mask_type: + attention_probs = attention_probs.masked_fill(attention_mask, 0) + # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. 
with self.attention_dropout_ctx(): @@ -6232,7 +6297,10 @@ def __init__( ) self.unfused_attention = UnfusedDotProductAttention( - softmax_scale, **attn_kwargs, layer_number=layer_number + softmax_scale, + attention_type=attention_type, + **attn_kwargs, + layer_number=layer_number, ) def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument @@ -6522,6 +6590,11 @@ def forward( if inference_params is not None: assert self.layer_number is not None, "Layer number must be set!" + # convert causal to causal_bottom_right in inference when KV-caching is in use + # so users can run with the same attn_mask_type for training and inference + if attn_mask_type in ["causal", "padding_causal"]: + attn_mask_type = attn_mask_type + "_bottom_right" + if qkv_format == "bshd": key_layer = key_layer.transpose(0, 1) value_layer = value_layer.transpose(0, 1) @@ -6628,7 +6701,6 @@ def forward( attention_mask is not None ), "Please provide attention_mask for padding!" if self.attention_type == "self": - assert max_seqlen_q == max_seqlen_kv cu_seqlens_q = get_cu_seqlens(attention_mask) cu_seqlens_kv = cu_seqlens_q else: diff --git a/transformer_engine/pytorch/softmax.py b/transformer_engine/pytorch/softmax.py index 3632d2f367..4fb8a28857 100644 --- a/transformer_engine/pytorch/softmax.py +++ b/transformer_engine/pytorch/softmax.py @@ -329,25 +329,22 @@ def is_kernel_available(self, mask: torch.Tensor, b: int, np: int, sq: int, sk: return False # sk must be 16 ~ 16384 if sk % 8 != 0: return False # sk must be divisor of 8 - if self.attn_mask_type == "arbitrary": - return False # Custom masks not supported - + if sq == 1: + return False # sq must be > 1 if self.attn_mask_type == "causal" and sq != sk: return False # Fused causal kernel only support causal_bottom_right if ( sq % 4 == 0 # sq must be divisor of 4 and attn_batches % 4 == 0 # np * b must be divisor of 4 - and self.attn_mask_type != "arbitrary" # Custom masks not supported ): batch_per_block = self.get_batch_per_block(int(sk)) - - if self.attn_mask_type == "padding": + if "padding" in self.attn_mask_type or self.attn_mask_type == "arbitrary": if ( mask is not None and sq % batch_per_block == 0 - and mask.shape[-2] == sq - and mask.shape[-1] == sk + and mask.shape[0] in [1, b] + and mask.shape[1:] == (1, sq, sk) ): return True else: @@ -358,13 +355,21 @@ def is_kernel_available(self, mask: torch.Tensor, b: int, np: int, sq: int, sk: def forward_fused_softmax( self, inp: torch.Tensor, mask: torch.Tensor, scale: Optional[float] = None ) -> torch.Tensor: - """Fused masked softmax kernel""" + """ + Fused masked softmax path. 
+ attn_mask_type | module + ----------------------------------------------------------------------------------------- + no_mask | ScaledSoftmax + causal (self-attention), causal_bottom_right | ScaledAlignedCausalMaskedSoftmax + padding, padding_causal, padding_causal_bottom_right | ScaledMaskedSoftmax + arbitrary ([1, 1, sq, sk] or [b, 1, sq, sk]) | ScaledMaskedSoftmax + """ scale = 1.0 if scale is None else scale - if "causal" in self.attn_mask_type: + if self.attn_mask_type in ["causal", "causal_bottom_right"]: return ScaledAlignedCausalMaskedSoftmax.apply(inp, scale) - # input is 4D tensor (b, np, sq, sk) + # input is 4D tensor (1, 1, sq, sk) or (b, 1, sq, sk) if mask is not None and self.attn_mask_type != "no_mask": return ScaledMaskedSoftmax.apply(inp, mask, scale) return ScaledSoftmax.apply(inp, scale) @@ -379,13 +384,19 @@ def forward_torch_softmax( if scale is not None: inp = inp * scale - if "causal" in self.attn_mask_type: + if self.attn_mask_type in ["causal", "causal_bottom_right"]: seq_len_q, seq_len_k = inp.size(2), inp.size(3) if is_in_onnx_export_mode() and self.kvcache_max_seq > 0: assert self.kvcache_max_seq >= seq_len_k - mask = _get_onnx_export_causal_mask(seq_len_q, seq_len_k, self.onnx_causal_mask) + causal_mask = _get_onnx_export_causal_mask( + seq_len_q, seq_len_k, self.onnx_causal_mask + ) + else: + causal_mask = _get_default_causal_mask(self.attn_mask_type, seq_len_q, seq_len_k) + if mask is None: + mask = causal_mask else: - mask = _get_default_causal_mask(self.attn_mask_type, seq_len_q, seq_len_k) + mask = torch.logical_or(mask, causal_mask) mask_output = inp if mask is not None and self.attn_mask_type != "no_mask": diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 4cbee3d628..bd6e27594d 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -624,7 +624,7 @@ def forward( Whether to set output tensors to 0 or not before use. inference_params: InferenceParams, default = None Inference parameters that are passed to the main model in order - to efficienly calculate and store the context during inference. + to efficiently calculate and store the context during inference. """ if self_attn_mask_type is None:
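Illustrative sketch (not part of the patch; all names are placeholders): how the padding masks listed in the backend table above (one [b, 1, 1, sq] tensor for self-attention, a tuple of [b, 1, 1, sq] and [b, 1, 1, skv] tensors for cross-attention) can be expanded into the full [b, 1, sq, skv] boolean mask that the new UnfusedDotProductAttention path applies, including the bottom-right causal shift derived from per-sequence actual lengths. For self-attention, passing the same tensor as both arguments reproduces the logical_or of the row and column masks used in the patch.

    import torch

    def expand_padding_mask(mask_q, mask_kv, bottom_right=False):
        # mask_q: [b, 1, 1, sq], mask_kv: [b, 1, 1, skv]; True marks padded tokens.
        b, _, _, sq = mask_q.shape
        skv = mask_kv.shape[-1]
        # Position (i, j) is masked if either query token i or key token j is padding.
        combined = torch.logical_or(mask_q.squeeze(1).unsqueeze(3), mask_kv)  # [b, 1, sq, skv]
        if bottom_right:
            # Actual (unpadded) lengths per sequence, assuming right padding.
            actual_q = mask_q.logical_not().sum(dim=-1).view(b)
            actual_kv = mask_kv.logical_not().sum(dim=-1).view(b)
            # rel[i, j] = i - j; shifting by (actual_kv - actual_q) aligns the causal
            # diagonal with the last real token of each sequence (bottom-right).
            rel = torch.arange(sq).view(1, 1, sq, 1) - torch.arange(skv).view(1, 1, 1, skv)
            causal = rel + (actual_kv - actual_q).view(b, 1, 1, 1) < 0
            combined = torch.logical_or(combined, causal)
        return combined

    # Example: batch of 2, sq = skv = 4; the second sequence has one padded query token
    # and two padded key/value tokens.
    mask_q = torch.tensor([[0, 0, 0, 0], [0, 0, 0, 1]], dtype=torch.bool).view(2, 1, 1, 4)
    mask_kv = torch.tensor([[0, 0, 0, 0], [0, 0, 1, 1]], dtype=torch.bool).view(2, 1, 1, 4)
    full_mask = expand_padding_mask(mask_q, mask_kv, bottom_right=True)  # [2, 1, 4, 4]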
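In the same spirit, a minimal sketch of the per-batch ALiBi bias that get_alibi() produces once actual_seqlens_q and actual_seqlens_kv are supplied. The slope formula below is the usual power-of-two ALiBi recipe and is only an assumption for illustration; the real get_alibi() caches and reuses its own slopes (or accepts user-provided alibi_slopes).

    import torch

    def alibi_bias(num_heads, max_seqlen_q, max_seqlen_kv, actual_q, actual_kv,
                   bottom_right=True):
        # Assumed slopes: 2^(-8/h), 2^(-16/h), ... (valid when num_heads is a power of two).
        slopes = torch.tensor([2.0 ** (-8.0 * (i + 1) / num_heads) for i in range(num_heads)])
        # rel[i, j] = i - j over the padded sequence lengths.
        rel = torch.arange(max_seqlen_q).view(1, 1, max_seqlen_q, 1) - torch.arange(
            max_seqlen_kv
        ).view(1, 1, 1, max_seqlen_kv)
        b = actual_q.shape[0]
        rel = rel.expand(b, 1, max_seqlen_q, max_seqlen_kv)
        if bottom_right:
            # Align the zero of the bias with each sequence's last real token.
            rel = rel + (actual_kv - actual_q).view(b, 1, 1, 1)
        # Bias decays linearly with |distance|, scaled per head: [b, h, sq, skv].
        return -rel.abs().to(torch.float32) * slopes.view(1, num_heads, 1, 1)

With full (unpadded) lengths and bottom_right=False this reduces to the previous top-left behaviour, which is why the non-padding path in the patch only needs the scalar max_seqlen_kv - max_seqlen_q offset.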