[C/PyTorch] RoPE fixes and minor improvements for fused attention (#453)
* add support for h2d/2hd in 8.9.6

Signed-off-by: Charlene Yang <[email protected]>

* cull unit tests in fused_attn.py and add skipif for layout tests

Signed-off-by: Charlene Yang <[email protected]>

* add workopt=1 flag for dpa tests

Signed-off-by: Charlene Yang <[email protected]>

* update support table for arbi_seqlen backend

Signed-off-by: Charlene Yang <[email protected]>

* fix rotary position embedding and add unit tests accordingly

Signed-off-by: Charlene Yang <[email protected]>

* further cut down unit tests for CI efficiency

Signed-off-by: Charlene Yang <[email protected]>

* fix lint

Signed-off-by: Charlene Yang <[email protected]>

* remove einops dependency

Signed-off-by: Charlene Yang <[email protected]>

---------

Signed-off-by: Charlene Yang <[email protected]>
cyanguwa authored Oct 9, 2023
1 parent 79f5fac commit 92d1ba0
Showing 3 changed files with 106 additions and 32 deletions.
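For orientation, the sketch below shows roughly how the RoPE path added in this commit is exercised end to end, based on the new RoPE parameter in test_transformer_layer further down. The layer sizes and the bare-bones TransformerLayer construction are illustrative assumptions rather than the exact test setup.

import torch
from transformer_engine.pytorch import TransformerLayer
from transformer_engine.pytorch.attention import RotaryPositionEmbedding

# Sizes mirror the test configs below (hidden 1024 = 16 heads x 64 dims);
# TransformerLayer accepts many more options than shown here.
hidden_size, heads, head_dim, seq_len, bs = 1024, 16, 64, 2048, 2

layer = TransformerLayer(
    hidden_size, 4 * hidden_size, heads, params_dtype=torch.bfloat16,
).cuda()

# Precompute the rotary frequencies once and pass them to the forward call,
# as the updated test does.
rope = RotaryPositionEmbedding(dim=head_dim)
rotary_pos_emb = rope(seq_len).cuda().to(dtype=torch.bfloat16)

x = torch.randn(seq_len, bs, hidden_size, dtype=torch.bfloat16,
                device="cuda", requires_grad=True)
out = layer(x, self_attn_mask_type="causal", rotary_pos_emb=rotary_pos_emb)
out.sum().backward()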
68 changes: 40 additions & 28 deletions tests/pytorch/test_fused_attn.py
@@ -12,7 +12,7 @@
)
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch import TransformerLayer
from transformer_engine.pytorch.attention import DotProductAttention
from transformer_engine.pytorch.attention import DotProductAttention, RotaryPositionEmbedding
import os

from pkg_resources import packaging
@@ -21,6 +21,8 @@
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
_flash_attn_version = packaging.version.Version(version("flash-attn"))
_flash_attn_2_available = _flash_attn_version >= packaging.version.Version("2")
_cudnn_version = [int(i) for i in os.environ['CUDNN_VERSION'].split('.')]


class ModelConfig:
def __init__(
@@ -45,22 +47,26 @@ def __init__(
"test5": ModelConfig(1, 1024, 16, 64, 128, 0.0, "no_mask"),
}

if os.getenv('NVTE_ADDITIONAL_TESTS', '0') == '1':
model_configs["test6"] = ModelConfig(1, 1024, 16, 64, 512, 0.0, "causal")
model_configs["test7"] = ModelConfig(1, 2048, 16, 128, 512, 0.0, "causal")
model_configs["test8"] = ModelConfig(1, 2048, 16, 128, 2048, 0.0, "causal")
model_configs["test9"] = ModelConfig(1, 1024, 16, 64, 512, 0.0, "no_mask")

param_types = [torch.float16]
if torch.cuda.is_bf16_supported():
param_types.append(torch.bfloat16)

batch_sizes = [1, 2] # add more if needed, e.g. 32
batch_sizes = [1, 32]

model_configs_lean = {
"test6": ModelConfig(1, 1024, 16, 64, 512, 0.0, "no_mask"),
"test7": ModelConfig(1, 2048, 16, 128, 2048, 0.0, "causal"),
}

param_types_lean = [torch.bfloat16]

batch_sizes_lean = [2]


@pytest.mark.skipif(
get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("bs", batch_sizes_lean)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("ckpt_attn", [True, False])
@pytest.mark.parametrize("bias_type", ["no_bias", "post_scale_bias"])
@@ -69,7 +75,6 @@ def test_dot_product_attention(dtype, bs, model, ckpt_attn, bias_type):
FlashAttention, FusedAttention and UnfusedDotProductAttention"""

config = model_configs[model]

if bias_type == "no_bias":
flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
dtype, bs, config, "FlashAttention", ckpt_attn, bias_type)
@@ -94,6 +99,7 @@ def _run_dot_product_attention(dtype, bs, config, backend, ckpt_attn, bias_type)
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1"

inp = torch.randn(
config.seq_len, bs, 3, config.num_attention_heads, config.head_dim,
@@ -150,15 +156,17 @@ def _run_dot_product_attention(dtype, bs, config, backend, ckpt_attn, bias_type)

@pytest.mark.skipif(
get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.skipif(
    _cudnn_version < [8,9,5], reason="cuDNN 8.9.5+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("bs", batch_sizes_lean)
@pytest.mark.parametrize("model", model_configs_lean.keys())
@pytest.mark.parametrize("workspace_opt", [True, False])
@pytest.mark.parametrize("qkv_layout", qkv_layouts)
def test_dpa_qkv_layout(dtype, bs, model, workspace_opt, qkv_layout):
"""Test DotProductAttention module with different QKV layouts"""

config = model_configs[model]
config = model_configs_lean[model]

flash_attn_fwd, flash_attn_bwd = _run_dpa_qkv_layout(
dtype, bs, config, "FlashAttention", qkv_layout, workspace_opt)
@@ -188,7 +196,6 @@ def _run_dpa_qkv_layout(dtype, bs, config, backend, qkv_layout, workspace_opt):
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0"


dim_to_num = {'b': bs,
's': config.seq_len,
'h': config.num_attention_heads,
@@ -269,23 +276,23 @@ def _run_dpa_qkv_layout(dtype, bs, config, backend, qkv_layout, workspace_opt):
get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("ckpt_attn", [False])
@pytest.mark.parametrize("model", model_configs_lean.keys())
@pytest.mark.parametrize("bias_type", ["no_bias", "post_scale_bias"])
@pytest.mark.parametrize("fused_qkv_params", [True, False])
def test_transformer_layer(dtype, bs, model, ckpt_attn, bias_type, fused_qkv_params):
@pytest.mark.parametrize("RoPE", [True, False])
def test_transformer_layer(dtype, bs, model, bias_type, fused_qkv_params, RoPE):
"""Test TransformerLayer module when its DotProductAttention is enabled with
FlashAttention, FusedAttention, or UnfusedDotProductAttention backend"""

config = model_configs[model]
config = model_configs_lean[model]

if bias_type == "no_bias":
flash_attn_fwd, flash_attn_bwd = _run_transformer_layer(
dtype, bs, config, "FlashAttention", ckpt_attn, bias_type, fused_qkv_params)
dtype, bs, config, "FlashAttention", bias_type, fused_qkv_params, RoPE)
fused_attn_fwd, fused_attn_bwd = _run_transformer_layer(
dtype, bs, config, "FusedAttention", ckpt_attn, bias_type, fused_qkv_params)
dtype, bs, config, "FusedAttention", bias_type, fused_qkv_params, RoPE)
unfused_attn_fwd, unfused_attn_bwd = _run_transformer_layer(
dtype, bs, config, "UnfusedDotProductAttention", ckpt_attn, bias_type, fused_qkv_params)
dtype, bs, config, "UnfusedDotProductAttention", bias_type, fused_qkv_params, RoPE)

atol, rtol = (5e-1, 5e-2)
if bias_type == "no_bias":
@@ -294,7 +301,7 @@ def test_transformer_layer(dtype, bs, model, ckpt_attn, bias_type, fused_qkv_par
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol)
torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol)

def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type, fused_qkv_params):
def _run_transformer_layer(dtype, bs, config, backend, bias_type, fused_qkv_params, RoPE):

reset_rng_states()
os.environ["NVTE_FLASH_ATTN"] = "0"
@@ -327,6 +334,11 @@ def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type, fus
else:
bias = None

rotary_pos_emb = None
if RoPE:
PE = RotaryPositionEmbedding(dim=config.head_dim)
rotary_pos_emb = PE(config.seq_len).cuda().to(dtype=dtype)

block = (
TransformerLayer(
config.hidden_size,
@@ -365,7 +377,7 @@ def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type, fus
num_iters = 5
for i in range(num_iters):
op = block(inp, self_attn_mask_type=config.attn_mask_type,
checkpoint_core_attention=ckpt_attn,
rotary_pos_emb=rotary_pos_emb,
core_attention_bias_type=bias_type,
core_attention_bias=bias)
loss = op.sum()
@@ -376,14 +388,14 @@ def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type, fus
@pytest.mark.skipif(not _flash_attn_2_available, reason="FA2.0 is not available")
@pytest.mark.skipif(
get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("bs", batch_sizes_lean)
@pytest.mark.parametrize("model", model_configs_lean.keys())
def test_transformer_layer_gqa(dtype, bs, model):
"""Test TransformerLayer module when its DotProductAttention is enabled with
FlashAttention, FusedAttention, or UnfusedDotProductAttention backend"""

config = model_configs[model]
config = model_configs_lean[model]
def find_factors(x):
f = []
for i in range(1, x + 1):
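As a rough standalone illustration of what these tests drive, the sketch below selects the FusedAttention backend and forces the workspace optimization through the NVTE environment variables used above. The DotProductAttention constructor and forward arguments shown are assumptions based on the test code; the real tests build the module from the full ModelConfig and sweep layouts, dtypes, and bias types.

import os
import torch
from transformer_engine.pytorch.attention import DotProductAttention

# Pick the fused backend and force the workspace-optimized kernels,
# mirroring _run_dot_product_attention above (assumed usage, not the exact test).
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1"

seq_len, bs, heads, head_dim = 1024, 2, 16, 64
q, k, v = (torch.randn(seq_len, bs, heads, head_dim, dtype=torch.bfloat16,
                       device="cuda", requires_grad=True) for _ in range(3))

dpa = DotProductAttention(heads, head_dim, attention_dropout=0.0).cuda()
out = dpa(q, k, v)   # [seq, batch, heads, head_dim] inputs; causal mask by default
out.sum().backward()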
@@ -1338,12 +1338,13 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t max_seqlen,
}
use_workspace_opt = transformer_engine::getenv<bool>(
"NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT", use_workspace_opt);
// will not be needed in cuDNN 8.9.6
#if (CUDNN_VERSION < 8906)
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
if ((layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD)
|| (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D)) {
use_workspace_opt = false;
}
#endif
}
#endif

@@ -1485,12 +1486,13 @@ void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t max_seqlen_q, size_t m
}
use_workspace_opt = transformer_engine::getenv<bool>(
"NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT", use_workspace_opt);
// will not be needed in cuDNN 8.9.6
#if (CUDNN_VERSION < 8906)
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
if ((layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD)
|| (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D)) {
use_workspace_opt = false;
}
#endif
}
#endif

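The guards above keep the workspace optimization disabled for the HD_2HD and HD_H2D layout groups on cuDNN builds older than 8.9.6, matching the "add support for h2d/2hd in 8.9.6" item in the commit message. Purely to illustrate what those layout groups mean, the sketch below unpacks a packed KV tensor under the assumed interpretation that "2hd" stacks K/V outside the head dimension while "h2d" interleaves them per head; the shape conventions are inferred from the layout names rather than taken from an official specification.

import torch

s, b, h, d = 128, 2, 16, 64

# "sbhd_sb2hd": separate Q plus a packed KV of shape [s, b, 2, h, d]
kv_2hd = torch.randn(s, b, 2, h, d)
k, v = kv_2hd.unbind(dim=2)      # each a [s, b, h, d] view into the packed buffer

# "sbhd_sbh2d": separate Q plus a packed KV of shape [s, b, h, 2, d]
kv_h2d = torch.randn(s, b, h, 2, d)
k2, v2 = kv_h2d.unbind(dim=3)    # each [s, b, h, d], with K and V interleaved
                                 # per head in memory (different strides)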
64 changes: 62 additions & 2 deletions transformer_engine/pytorch/attention.py
@@ -691,6 +691,64 @@ def flash_attn_forward_func_with_cp(q, k, v, cu_seqlens_q, cu_seqlens_k,
return out


class RotaryPositionEmbedding(torch.nn.Module):
"""
Implements Rotary Position Embedding from https://arxiv.org/abs/2104.09864.
"""
def __init__(
self,
dim: int,
seq_len_interpolation_factor: Optional[int] = None,
pretrained_max_position_embeddings: Optional[int] = None,
):
"""
Parameters
----------
dim: int
rotary embedding dimension
seq_len_interpolation_factor: int
if not None, discrete positions will be interpolated by this factor via the trick in
https://arxiv.org/abs/2306.15595
pretrained_max_position_embeddings: int
pre-trained max_position_embeddings before position interpolation
"""
super().__init__()
self.seq_len_interpolation_factor = seq_len_interpolation_factor
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
self.pretrained_max_position_embeddings = pretrained_max_position_embeddings

def forward(self, max_seq_len: int, offset: int = 0):
"""
Create rotary position embedding frequencies
Parameters
----------
max_seq_len: int
sequence length of a sample
offset: int, default = 0
fixed offset for frequencies
"""
seq = torch.arange(max_seq_len, device=self.inv_freq.device) + offset
seq = seq.type_as(self.inv_freq)

if (self.pretrained_max_position_embeddings is not None
and self.seq_len_interpolation_factor is not None):
if (max_seq_len >
self.pretrained_max_position_embeddings * self.seq_len_interpolation_factor):
# dynamic linear scaling (length > position we have learned)
seq *= 1 / (max_seq_len / self.pretrained_max_position_embeddings)
else:
# fixed linear scaling
seq *= 1 / self.seq_len_interpolation_factor

freqs = torch.einsum('i , j -> i j', seq, self.inv_freq)
# first part even vector components, second part odd vector components,
# 2 * dim in dimension size
emb = torch.cat((freqs, freqs), dim=-1)
# emb [seq_length, .., dim]
return emb.reshape(emb.size(0), 1, 1, emb.size(1))

def _rotate_half(x: torch.Tensor) -> torch.Tensor:
"""
change sign so the last dimension becomes [-odd, +even]
@@ -1488,9 +1546,10 @@ class FusedAttention(torch.nn.Module):
| qkv_layout | | |
| - qkv | qkv_interleaved | qkv_interleaved |
| - (q,kv) | kv_interleaved | |
| - (q,k,v) | sb3hd, bs3hd | sb3hd, bs3hd |
| - (q,k,v) | sb3hd, bs3hd | sb3hd, bs3hd, sbh3d, bsh3d |
| | sbhd_sb2hd, bshd_bs2hd | sbhd_sb2hd, bshd_bs2hd |
| | bshd_bshd_bshd | sbhd_sbhd_sbhd, bshd_bshd_bshd |
| | bshd_bshd_bshd | sbhd_sbh2d, bshd_bsh2d |
| | | sbhd_sbhd_sbhd, bshd_bshd_bshd |
| mask_type | causal/no_mask | causal |
| bias_type | no_bias/post_scale_bias | no_bias |
| dropout | yes | yes |
@@ -2736,6 +2795,7 @@ def forward(
q_pos_emb, k_pos_emb = rotary_pos_emb
query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb)
key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb)
value_layer = value_layer.contiguous()

context_layer = self.core_attention(
query_layer,
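To make the new RotaryPositionEmbedding concrete, here is a minimal sketch of generating the frequency tensor and applying the standard rotation from https://arxiv.org/abs/2104.09864 to a query tensor. The rotate_half and apply_rope helpers below are local illustrations of the usual half-rotation formulation and may differ in detail from TE's own apply_rotary_pos_emb.

import torch
from transformer_engine.pytorch.attention import RotaryPositionEmbedding

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in half and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    # t: [seq, batch, heads, dim]; freqs: [seq, 1, 1, dim] from the module above.
    return t * freqs.cos() + rotate_half(t) * freqs.sin()

seq_len, bs, heads, dim = 2048, 2, 16, 64
rope = RotaryPositionEmbedding(dim)
freqs = rope(seq_len)                  # shape [seq_len, 1, 1, dim]
q = torch.randn(seq_len, bs, heads, dim)
q_rot = apply_rope(q, freqs)
assert q_rot.shape == q.shape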
