[PyTorch] Add support for padding mask in UnfusedDotProductAttention (#1073)

* add support for padding in UnfusedDPA

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add support for padding_causal/_bottom_right

Signed-off-by: Charlene Yang <[email protected]>

* fix padding_causal/_bottom_right

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* need to test max512 backend

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* revert last commit

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix mask logic in unfused

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use actual_seqlen for alibi/causal_bottom_right padding

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix lint

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor fixes and convert causal to causal_bottom_right for inference

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use causal in kv cache inference test

Signed-off-by: Charlene Yang <[email protected]>

* simplify get_alibi logic

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* simplify the non-padding path for get_alibi

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* avoid batch_size loop in generating padding_causal/_bottom_right masks

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Charlene Yang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
cyanguwa and pre-commit-ci[bot] committed Aug 21, 2024
1 parent 26c8fcc commit 467b39a
Showing 5 changed files with 155 additions and 69 deletions.
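The user-visible effect of this commit is that padding-style masks no longer require the flash or fused backends. A minimal sketch of exercising the unfused path (the NVTE_FLASH_ATTN/NVTE_FUSED_ATTN toggles and the exact DotProductAttention arguments are assumptions based on the usual Transformer Engine API, not part of this commit):

import os
import torch

# Assumed backend toggles: disable FlashAttention and cuDNN fused attention so the
# unfused (native PyTorch) backend handles the padding mask.
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"

import transformer_engine.pytorch as te

b, h, sq, d = 2, 16, 128, 64
q = torch.randn(sq, b, h, d, dtype=torch.bfloat16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

# Padding mask for self-attention: [b, 1, 1, sq], True marks padded tokens.
seq_lens = torch.tensor([100, 80], device="cuda")
pad_mask = torch.arange(sq, device="cuda").view(1, 1, 1, sq) >= seq_lens.view(b, 1, 1, 1)

attn = te.DotProductAttention(num_attention_heads=h, kv_channels=d, attn_mask_type="padding")
out = attn(q, k, v, attention_mask=pad_mask)  # context layer, [sq, b, h * d]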
6 changes: 3 additions & 3 deletions tests/pytorch/test_numerics.py
@@ -1655,8 +1655,8 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
ffn_hidden_size=4 * D,
num_attention_heads=H,
attn_input_format=input_format,
self_attn_mask_type="causal_bottom_right",
enc_dec_attn_mask_type="causal_bottom_right",
self_attn_mask_type="causal",
enc_dec_attn_mask_type="causal",
layer_number=layer_number,
attention_dropout=0.0,
params_dtype=dtype,
@@ -1670,7 +1670,7 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
qkv_format=input_format,
layer_number=layer_number,
attention_dropout=0.0,
attn_mask_type="causal_bottom_right",
attn_mask_type="causal",
params_dtype=dtype,
)
.cuda()
5 changes: 4 additions & 1 deletion transformer_engine/common/fused_attn/fused_attn.cpp
@@ -142,7 +142,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
(bias_type == NVTE_Bias_Type::NVTE_NO_BIAS ||
(bias_type == NVTE_Bias_Type::NVTE_ALIBI &&
attn_mask_type != NVTE_Mask_Type::NVTE_NO_MASK &&
attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK && sm_arch_ >= 90) ||
attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK &&
attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK &&
attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK &&
sm_arch_ >= 90) ||
(bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 90))) ||
((cudnn_runtime_version >= 90000) &&
(bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 80))) &&
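For readability, a rough Python paraphrase of just the ALiBi clause in the condition above (pre-cuDNN-9.0 branch); the helper name is illustrative and not part of the library:

def alibi_supported_by_fused_attn(attn_mask_type: str, sm_arch: int) -> bool:
    # ALiBi bias stays on the fused backend only for Hopper (sm90+) and only for
    # masks without padding; ALiBi + padding masks now fall through to other backends.
    disallowed = {"no_mask", "padding", "padding_causal", "padding_causal_bottom_right"}
    return attn_mask_type not in disallowed and sm_arch >= 90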
172 changes: 122 additions & 50 deletions transformer_engine/pytorch/attention.py
@@ -472,29 +472,32 @@ def get_attention_backend(
use_fused_attention = False

# Filter: Attention mask
# attn_mask_type | supported backends
# -------------------------------------------------------------------
# no_mask | All
# padding | FlashAttention, FusedAttention
# causal |
# self-attention | All
# cross-attention | FusedAttention
# padding_causal |
# self-attention | FlashAttention, FusedAttention
# cross-attention | FusedAttention
# causal_bottom_right | All
# padding_causal_bottom_right | FlashAttention, FusedAttention
# arbitrary | UnfusedDotProductAttention
# attn_mask_type | attention_mask | supported backends
# ----------------------------------------------------------------------------------------
# no_mask | None | All
# padding | | All
# self-attention | One tensor in shape [b, 1, 1, sq] |
# cross-attention | Tuple of two tensors in shapes |
# | [b, 1, 1, sq] and [b, 1, 1, skv] |
# causal | None |
# self-attention | | All
# cross-attention | | FusedAttention, UnfusedDotProductAttention
# padding_causal | Same as "padding" |
# self-attention | | All
# cross-attention | | FusedAttention, UnfusedDotProductAttention
# causal_bottom_right | None | All
# padding_causal_bottom_right | Same as "padding" |
# self-attention | | All
# cross-attention | | FlashAttention, UnfusedDotProductAttention
# arbitrary | One tensor in shape broadcastable to | UnfusedDotProductAttention
# | [b, h, sq, skv] |
if attn_mask_type == "arbitrary":
if use_flash_attention:
logger.debug("Disabling FlashAttention for arbitrary mask")
use_flash_attention = False
if use_fused_attention:
logger.debug("Disabling FusedAttention for arbitrary mask")
use_fused_attention = False
if use_unfused_attention and "padding" in attn_mask_type:
logger.debug("Disabling UnfusedDotProductAttention for %s mask", attn_mask_type)
use_unfused_attention = False
if (
use_flash_attention
and _flash_attn_2_1_plus
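The table in this hunk pins down the expected attention_mask shapes per mask type. A small helper for building such padding masks from per-sequence lengths (pure PyTorch, illustrative only; True marks padded positions, matching the unfused handling added below):

import torch

def make_padding_mask(seq_lens: torch.Tensor, max_seqlen: int) -> torch.Tensor:
    """Return a [b, 1, 1, s] boolean mask where True marks padded positions."""
    pos = torch.arange(max_seqlen, device=seq_lens.device).view(1, 1, 1, max_seqlen)
    return pos >= seq_lens.view(-1, 1, 1, 1)

seq_lens_q = torch.tensor([100, 80], device="cuda")
seq_lens_kv = torch.tensor([120, 90], device="cuda")
mask_q = make_padding_mask(seq_lens_q, 128)    # self-attention: pass this single tensor
mask_kv = make_padding_mask(seq_lens_kv, 256)  # cross-attention: pass (mask_q, mask_kv)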
@@ -780,7 +783,7 @@ def get_attention_backend(
class InferenceParams: # pylint: disable=too-few-public-methods
"""
Inference parameters that are passed to the main model in order
to efficienly calculate and store the context during inference.
to efficiently calculate and store the context during inference.
Parameters
----------
@@ -886,6 +889,8 @@ def get_alibi(
num_heads: int,
max_seqlen_q: int,
max_seqlen_kv: int,
actual_seqlens_q: Optional[torch.Tensor] = None,
actual_seqlens_kv: Optional[torch.Tensor] = None,
alibi_slopes: Optional[torch.Tensor] = None,
bias_dtype: Optional[torch.dtype] = None,
bottom_right_alignment: bool = True,
@@ -899,6 +904,10 @@
Maximum sequence length for queries.
max_seqlen_kv: int
Maximum sequence length for keys and values.
actual_seqlens_q: Optional[torch.Tensor], default = `None`
Actual sequence lengths for queries, in shape [batch_size].
actual_seqlens_kv: Optional[torch.Tensor], default = `None`
Actual sequence lengths for keys and values, in shape [batch_size].
alibi_slopes: Optional[torch.Tensor], default = `None`
Custom ALiBi slopes, FP32, CUDA tensor, in shape [num_heads] or [batch_size, num_heads].
bias_dtype: Optional[torch.dtype], default = `None`
@@ -912,10 +921,12 @@
alibi_slopes: torch.Tensor
ALiBi slopes in FP32 and shape [num_heads] or [batch_size, num_heads].
alibi_bias: torch.Tensor
ALiBi bias in FP32 or `bias_dtype`. If `alibi_slopes` is in [num_heads] shape,
then `alibi_bias` is in [1, num_heads, max_seqlen_q, max_seqlen_kv], and if
`alibi_slopes` is in [batch_size, num_heads], then the bias is in
[batch_size, num_heads, max_seqlen_q, max_seqlen_kv].
ALiBi bias in FP32 or `bias_dtype`. Its shape is
(1) [1, num_heads, max_seqlen_q, max_seqlen_kv] if `alibi_slopes` is in [num_heads] shape,
and `actual_seqlens_q` and `actual_seqlens_kv` are `None`; or
(2) [batch_size, num_heads, max_seqlen_q, max_seqlen_kv] if `alibi_slopes` is in
[batch_size, num_heads] shape, or, if `alibi_slopes` is in [num_heads] shape and
`actual_seqlens_q` and `actual_seqlens_kv` are not `None`.
"""
global _alibi_cache
if _alibi_cache["_alibi_slopes_require_update"]:
@@ -941,17 +952,23 @@
slopes_shape = torch.Size([1, _alibi_cache["_alibi_slopes"].shape[0], 1, 1])
if _alibi_cache["_alibi_slopes"].dim() == 2:
slopes_shape = torch.Size([*_alibi_cache["_alibi_slopes"].shape[:], 1, 1])
if bottom_right_alignment:
bias = torch.arange(1 - max_seqlen_kv, 1, dtype=torch.int32, device="cuda").view(
1, 1, 1, max_seqlen_kv
)
else:
bias = torch.arange(
1 - max_seqlen_q, max_seqlen_kv - max_seqlen_q + 1, dtype=torch.int32, device="cuda"
).view(1, 1, 1, max_seqlen_kv)
bias = bias - torch.arange(1 - max_seqlen_q, 1, dtype=torch.int32, device="cuda").view(
bias = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view(
1, 1, max_seqlen_q, 1
) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view(
1, 1, 1, max_seqlen_kv
)
if actual_seqlens_q is None and actual_seqlens_kv is None:
if bottom_right_alignment:
bias = bias + max_seqlen_kv - max_seqlen_q
elif actual_seqlens_q is not None and actual_seqlens_kv is not None:
batch_size = actual_seqlens_q.shape[0]
bias = bias.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv)
if bottom_right_alignment:
bias = bias + (actual_seqlens_kv - actual_seqlens_q).view(batch_size, 1, 1, 1)
else:
assert (
False
), "actual_seqlens_q and actual_seqlens_kv need to be both None or torch.Tensors!"
bias = bias.abs().mul(-1)
bias = bias * _alibi_cache["_alibi_slopes"].view(slopes_shape)
_alibi_cache["_max_seqlen_q"], _alibi_cache["_max_seqlen_kv"] = max_seqlen_q, max_seqlen_kv
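The rewritten block above builds one relative-position matrix and then shifts it per sequence when actual lengths are supplied. A standalone sketch of the same arithmetic, outside the cache machinery (illustrative; slopes assumed to be a [num_heads] FP32 tensor):

import torch

def alibi_bias_sketch(slopes, max_sq, max_skv, actual_sq=None, actual_skv=None,
                      bottom_right_alignment=True):
    # bias[..., i, j] = i - j, shape [1, 1, max_sq, max_skv]
    bias = torch.arange(max_sq, device="cuda").view(1, 1, max_sq, 1) - torch.arange(
        max_skv, device="cuda"
    ).view(1, 1, 1, max_skv)
    if actual_sq is None and actual_skv is None:
        if bottom_right_alignment:
            bias = bias + max_skv - max_sq  # align the diagonal to the bottom-right corner
    else:
        batch_size = actual_sq.shape[0]
        bias = bias.expand(batch_size, 1, max_sq, max_skv)
        if bottom_right_alignment:
            # shift by each sequence's true length difference rather than the padded maxima
            bias = bias + (actual_skv - actual_sq).view(batch_size, 1, 1, 1)
    return bias.abs().mul(-1) * slopes.view(1, -1, 1, 1)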
@@ -3705,13 +3722,15 @@ class UnfusedDotProductAttention(torch.nn.Module):
def __init__(
self,
softmax_scale: float,
attention_type: str = "self",
attention_dropout: float = 0.0,
attention_dropout_ctx: Optional[Callable] = nullcontext,
layer_number: Optional[int] = None,
) -> None:
super().__init__()

self.softmax_scale = softmax_scale
self.attention_type = attention_type
self.attention_dropout_ctx = attention_dropout_ctx
self.layer_number = layer_number

@@ -3751,6 +3770,58 @@ def forward(
query_layer, key_layer, value_layer = [
x.transpose(0, 1) for x in [query_layer, key_layer, value_layer]
]
batch_size, max_seqlen_q, max_seqlen_kv = (
query_layer.shape[1],
query_layer.shape[0],
key_layer.shape[0],
)
if "padding" in attn_mask_type:
if self.attention_type == "self":
assert attention_mask.shape == (
batch_size,
1,
1,
max_seqlen_q,
), "attention_mask should be a single tensor with [b, 1, 1, sq] shape!"
attention_mask = torch.logical_or(
attention_mask.squeeze(1).unsqueeze(3), attention_mask
)
else:
assert (
len(attention_mask) == 2
and attention_mask[0].shape == (batch_size, 1, 1, max_seqlen_q)
and attention_mask[1].shape == (batch_size, 1, 1, max_seqlen_kv)
), (
"attention_mask should be a tuple of two tensors with shapes "
"[b, 1, 1, sq] and [b, 1, 1, skv]!"
)
attention_mask = torch.logical_or(
attention_mask[0].squeeze(1).unsqueeze(3), attention_mask[1]
)
mask = attention_mask.squeeze(1).logical_not()
actual_seqlens_q = mask[:, :, 0].sum(dim=1)
actual_seqlens_kv = mask[:, 0, :].sum(dim=1)
mask = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view(
1, 1, max_seqlen_q, 1
) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view(
1, 1, 1, max_seqlen_kv
)
if attn_mask_type == "padding_causal":
attention_mask = torch.logical_or(
torch.where(mask.view(1, 1, max_seqlen_q, max_seqlen_kv) < 0, 1, 0),
attention_mask,
)
if attn_mask_type == "padding_causal_bottom_right":
attention_mask = torch.logical_or(
torch.where(
mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv)
+ (actual_seqlens_kv - actual_seqlens_q).view(batch_size, 1, 1, 1)
< 0,
1,
0,
),
attention_mask,
)

batch_size, seqlen = query_layer.shape[1], query_layer.shape[0]
apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16
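The padding branch above expands the per-token masks into a full [b, 1, sq, skv] mask and, for the causal variants, folds in a (possibly bottom-right-aligned) triangular constraint. A condensed paraphrase of that logic for reference (illustrative, not the module code verbatim):

import torch

def combine_padding_masks(mask_q, mask_kv, attn_mask_type, max_sq, max_skv):
    # mask_q: [b, 1, 1, sq], mask_kv: [b, 1, 1, skv]; True marks padded tokens.
    batch_size = mask_q.shape[0]
    # padded rows (q) OR padded columns (kv) -> [b, 1, sq, skv]
    mask = torch.logical_or(mask_q.squeeze(1).unsqueeze(3), mask_kv)
    if "causal" in attn_mask_type:
        # relative position i - j; negative entries are future kv positions
        rel = torch.arange(max_sq, device=mask.device).view(1, 1, max_sq, 1) - torch.arange(
            max_skv, device=mask.device
        ).view(1, 1, 1, max_skv)
        if attn_mask_type == "padding_causal_bottom_right":
            # shift the diagonal by each sequence's actual (skv - sq) before applying causality
            actual_sq = (~mask_q).squeeze(1).squeeze(1).sum(dim=1)
            actual_skv = (~mask_kv).squeeze(1).squeeze(1).sum(dim=1)
            rel = rel + (actual_skv - actual_sq).view(batch_size, 1, 1, 1)
        mask = torch.logical_or(rel < 0, mask)
    return mask  # True = do not attend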
@@ -3805,18 +3876,15 @@ def forward(
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
beta=0.0,
alpha=scale,
)
).view(*output_size)

elif core_attention_bias_type == "pre_scale_bias":
assert core_attention_bias is not None, "core_attention_bias should not be None!"
matmul_result = torch.bmm(
query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
)
matmul_result = (
matmul_result.view(output_size[0], output_size[1], output_size[2], output_size[3])
+ core_attention_bias
).view(-1, output_size[2], output_size[3])
matmul_result = matmul_result.view(*output_size) + core_attention_bias
matmul_result *= scale

elif core_attention_bias_type in ["post_scale_bias", "alibi"]:
@@ -3827,6 +3895,8 @@ def forward(
output_size[1],
output_size[2],
output_size[3],
actual_seqlens_q=actual_seqlens_q if "padding" in attn_mask_type else None,
actual_seqlens_kv=actual_seqlens_kv if "padding" in attn_mask_type else None,
alibi_slopes=alibi_slopes,
bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
)
@@ -3837,26 +3907,21 @@ def forward(
beta=0.0,
alpha=scale,
)
matmul_result = (
(
matmul_result.view(
output_size[0], output_size[1], output_size[2], output_size[3]
)
+ core_attention_bias
)
.view(-1, output_size[2], output_size[3])
.to(dtype=query_layer.dtype)
matmul_result = (matmul_result.view(*output_size) + core_attention_bias).to(
dtype=query_layer.dtype
)

# change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size)

# attention scores and attention mask [b, np, sq, sk]
softmax_scale = self.layer_number if apply_qk_layer_scaling else None
attention_probs = self.scale_mask_softmax(
attention_scores, attention_mask, attn_mask_type, softmax_scale
matmul_result, attention_mask, attn_mask_type, softmax_scale
)

# mask out the pad positions in softmax results, mostly for the rows (pad tokens from q)
# the columns (pad tokens from k) are already zeroed out during softmax
if "padding" in attn_mask_type:
attention_probs = attention_probs.masked_fill(attention_mask, 0)

# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
with self.attention_dropout_ctx():
@@ -6232,7 +6297,10 @@ def __init__(
)

self.unfused_attention = UnfusedDotProductAttention(
softmax_scale, **attn_kwargs, layer_number=layer_number
softmax_scale,
attention_type=attention_type,
**attn_kwargs,
layer_number=layer_number,
)

def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument
@@ -6522,6 +6590,11 @@ def forward(
if inference_params is not None:
assert self.layer_number is not None, "Layer number must be set!"

# convert causal to causal_bottom_right in inference when KV-caching is in use
# so users can run with the same attn_mask_type for training and inference
if attn_mask_type in ["causal", "padding_causal"]:
attn_mask_type = attn_mask_type + "_bottom_right"

if qkv_format == "bshd":
key_layer = key_layer.transpose(0, 1)
value_layer = value_layer.transpose(0, 1)
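Rewriting "causal" as "causal_bottom_right" matters once the query holds only the newly generated tokens while the keys/values span the whole cache: the causal diagonal must end at the bottom-right corner of the [sq, skv] tile. A tiny worked example of the two alignments:

import torch

sq, skv = 2, 6  # 2 new query tokens attending to 6 cached key/value positions
i = torch.arange(sq).view(sq, 1)
j = torch.arange(skv).view(1, skv)
top_left = j <= i                    # "causal": the new tokens would see only 1-2 positions
bottom_right = j <= i + (skv - sq)   # "causal_bottom_right": they see the full history
print(top_left.int(), bottom_right.int(), sep="\n")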
@@ -6628,7 +6701,6 @@ def forward(
attention_mask is not None
), "Please provide attention_mask for padding!"
if self.attention_type == "self":
assert max_seqlen_q == max_seqlen_kv
cu_seqlens_q = get_cu_seqlens(attention_mask)
cu_seqlens_kv = cu_seqlens_q
else:
(Diffs for the remaining changed files are not shown.)
