diff --git a/tests/pytorch/test_fused_attn.py b/tests/pytorch/test_fused_attn.py
index b3efc2d414..9ddf475250 100644
--- a/tests/pytorch/test_fused_attn.py
+++ b/tests/pytorch/test_fused_attn.py
@@ -12,7 +12,7 @@
 )
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 from transformer_engine.pytorch import TransformerLayer
-from transformer_engine.pytorch.attention import DotProductAttention
+from transformer_engine.pytorch.attention import DotProductAttention, RotaryPositionEmbedding
 import os
 
 from pkg_resources import packaging
@@ -21,6 +21,8 @@
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
 _flash_attn_version = packaging.version.Version(version("flash-attn"))
 _flash_attn_2_available = _flash_attn_version >= packaging.version.Version("2")
+_cudnn_version = [int(i) for i in os.environ['CUDNN_VERSION'].split('.')]
+
 
 class ModelConfig:
     def __init__(
@@ -45,22 +47,26 @@ def __init__(
     "test5": ModelConfig(1, 1024, 16, 64, 128, 0.0, "no_mask"),
 }
 
-if os.getenv('NVTE_ADDITIONAL_TESTS', '0') == '1':
-    model_configs["test6"] = ModelConfig(1, 1024, 16, 64, 512, 0.0, "causal")
-    model_configs["test7"] = ModelConfig(1, 2048, 16, 128, 512, 0.0, "causal")
-    model_configs["test8"] = ModelConfig(1, 2048, 16, 128, 2048, 0.0, "causal")
-    model_configs["test9"] = ModelConfig(1, 1024, 16, 64, 512, 0.0, "no_mask")
-
 param_types = [torch.float16]
 if torch.cuda.is_bf16_supported():
     param_types.append(torch.bfloat16)
 
-batch_sizes = [1, 2] # add more if needed, e.g. 32
+batch_sizes = [1, 32]
+
+model_configs_lean = {
+    "test6": ModelConfig(1, 1024, 16, 64, 512, 0.0, "no_mask"),
+    "test7": ModelConfig(1, 2048, 16, 128, 2048, 0.0, "causal"),
+}
+
+param_types_lean = [torch.bfloat16]
+
+batch_sizes_lean = [2]
+
 
 @pytest.mark.skipif(
     get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.")
 @pytest.mark.parametrize("dtype", param_types)
-@pytest.mark.parametrize("bs", batch_sizes)
+@pytest.mark.parametrize("bs", batch_sizes_lean)
 @pytest.mark.parametrize("model", model_configs.keys())
 @pytest.mark.parametrize("ckpt_attn", [True, False])
 @pytest.mark.parametrize("bias_type", ["no_bias", "post_scale_bias"])
@@ -69,7 +75,6 @@ def test_dot_product_attention(dtype, bs, model, ckpt_attn, bias_type):
     FlashAttention, FusedAttention and UnfusedDotProductAttention"""
     config = model_configs[model]
-
     if bias_type == "no_bias":
         flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
             dtype, bs, config, "FlashAttention", ckpt_attn, bias_type)
@@ -94,6 +99,7 @@ def _run_dot_product_attention(dtype, bs, config, backend, ckpt_attn, bias_type)
         os.environ["NVTE_FLASH_ATTN"] = "1"
     if backend == "FusedAttention":
         os.environ["NVTE_FUSED_ATTN"] = "1"
+        os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1"
 
     inp = torch.randn(
         config.seq_len, bs, 3, config.num_attention_heads, config.head_dim,
@@ -150,15 +156,17 @@ def _run_dot_product_attention(dtype, bs, config, backend, ckpt_attn, bias_type)
 
 @pytest.mark.skipif(
     get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.")
-@pytest.mark.parametrize("dtype", param_types)
-@pytest.mark.parametrize("bs", batch_sizes)
-@pytest.mark.parametrize("model", model_configs.keys())
+@pytest.mark.skipif(
+    _cudnn_version < [8,9,5], reason="cuDNN 8.9.5+ is required.")
+@pytest.mark.parametrize("dtype", param_types_lean)
+@pytest.mark.parametrize("bs", batch_sizes_lean)
+@pytest.mark.parametrize("model", model_configs_lean.keys())
 @pytest.mark.parametrize("workspace_opt", [True, False])
@pytest.mark.parametrize("qkv_layout", qkv_layouts) def test_dpa_qkv_layout(dtype, bs, model, workspace_opt, qkv_layout): """Test DotProductAttention module with different QKV layouts""" - config = model_configs[model] + config = model_configs_lean[model] flash_attn_fwd, flash_attn_bwd = _run_dpa_qkv_layout( dtype, bs, config, "FlashAttention", qkv_layout, workspace_opt) @@ -188,7 +196,6 @@ def _run_dpa_qkv_layout(dtype, bs, config, backend, qkv_layout, workspace_opt): os.environ["NVTE_FUSED_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0" - dim_to_num = {'b': bs, 's': config.seq_len, 'h': config.num_attention_heads, @@ -269,23 +276,23 @@ def _run_dpa_qkv_layout(dtype, bs, config, backend, qkv_layout, workspace_opt): get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.") @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) -@pytest.mark.parametrize("model", model_configs.keys()) -@pytest.mark.parametrize("ckpt_attn", [False]) +@pytest.mark.parametrize("model", model_configs_lean.keys()) @pytest.mark.parametrize("bias_type", ["no_bias", "post_scale_bias"]) @pytest.mark.parametrize("fused_qkv_params", [True, False]) -def test_transformer_layer(dtype, bs, model, ckpt_attn, bias_type, fused_qkv_params): +@pytest.mark.parametrize("RoPE", [True, False]) +def test_transformer_layer(dtype, bs, model, bias_type, fused_qkv_params, RoPE): """Test TransformerLayer module when its DotProductAttention is enabled with FlashAttention, FusedAttention, or UnfusedDotProductAttention backend""" - config = model_configs[model] + config = model_configs_lean[model] if bias_type == "no_bias": flash_attn_fwd, flash_attn_bwd = _run_transformer_layer( - dtype, bs, config, "FlashAttention", ckpt_attn, bias_type, fused_qkv_params) + dtype, bs, config, "FlashAttention", bias_type, fused_qkv_params, RoPE) fused_attn_fwd, fused_attn_bwd = _run_transformer_layer( - dtype, bs, config, "FusedAttention", ckpt_attn, bias_type, fused_qkv_params) + dtype, bs, config, "FusedAttention", bias_type, fused_qkv_params, RoPE) unfused_attn_fwd, unfused_attn_bwd = _run_transformer_layer( - dtype, bs, config, "UnfusedDotProductAttention", ckpt_attn, bias_type, fused_qkv_params) + dtype, bs, config, "UnfusedDotProductAttention", bias_type, fused_qkv_params, RoPE) atol, rtol = (5e-1, 5e-2) if bias_type == "no_bias": @@ -294,7 +301,7 @@ def test_transformer_layer(dtype, bs, model, ckpt_attn, bias_type, fused_qkv_par torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, atol=atol, rtol=rtol) torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, atol=atol, rtol=rtol) -def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type, fused_qkv_params): +def _run_transformer_layer(dtype, bs, config, backend, bias_type, fused_qkv_params, RoPE): reset_rng_states() os.environ["NVTE_FLASH_ATTN"] = "0" @@ -327,6 +334,11 @@ def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type, fus else: bias = None + rotary_pos_emb = None + if RoPE: + PE = RotaryPositionEmbedding(dim=config.head_dim) + rotary_pos_emb = PE(config.seq_len).cuda().to(dtype=dtype) + block = ( TransformerLayer( config.hidden_size, @@ -365,7 +377,7 @@ def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type, fus num_iters = 5 for i in range(num_iters): op = block(inp, self_attn_mask_type=config.attn_mask_type, - checkpoint_core_attention=ckpt_attn, + rotary_pos_emb=rotary_pos_emb, 
            core_attention_bias_type=bias_type,
            core_attention_bias=bias)
         loss = op.sum()
@@ -376,14 +388,14 @@ def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type, fus
 
 @pytest.mark.skipif(not _flash_attn_2_available, reason="FA2.0 is not available")
 @pytest.mark.skipif(
     get_device_compute_capability() < 8.0, reason="Compute capability 8.0+ is required.")
-@pytest.mark.parametrize("dtype", param_types)
-@pytest.mark.parametrize("bs", batch_sizes)
-@pytest.mark.parametrize("model", model_configs.keys())
+@pytest.mark.parametrize("dtype", param_types_lean)
+@pytest.mark.parametrize("bs", batch_sizes_lean)
+@pytest.mark.parametrize("model", model_configs_lean.keys())
 def test_transformer_layer_gqa(dtype, bs, model):
     """Test TransformerLayer module when its DotProductAttention is enabled with
     FlashAttention, FusedAttention, or UnfusedDotProductAttention backend"""
-    config = model_configs[model]
+    config = model_configs_lean[model]
 
     def find_factors(x):
         f = []
         for i in range(1, x + 1):
diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
index e2da13729b..dd4bf301a3 100644
--- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
+++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
@@ -1338,12 +1338,13 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t max_seqlen,
         }
         use_workspace_opt = transformer_engine::getenv(
             "NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT", use_workspace_opt);
-        // will not be needed in cuDNN 8.9.6
+#if (CUDNN_VERSION < 8906)
         NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
         if ((layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD)
             || (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D)) {
             use_workspace_opt = false;
         }
+#endif
     }
 #endif
@@ -1485,12 +1486,13 @@ void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t max_seqlen_q, size_t m
         }
         use_workspace_opt = transformer_engine::getenv(
             "NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT", use_workspace_opt);
-        // will not be needed in cuDNN 8.9.6
+#if (CUDNN_VERSION < 8906)
         NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
         if ((layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD)
             || (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D)) {
             use_workspace_opt = false;
         }
+#endif
     }
 #endif
diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index f9a8f25d5b..37903d8025 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -691,6 +691,64 @@ def flash_attn_forward_func_with_cp(q, k, v, cu_seqlens_q, cu_seqlens_k,
     return out
 
 
+class RotaryPositionEmbedding(torch.nn.Module):
+    """
+    Implements Rotary Position Embedding from https://arxiv.org/abs/2104.09864.
+ """ + def __init__( + self, + dim: int, + seq_len_interpolation_factor: Optional[int] = None, + pretrained_max_position_embeddings: Optional[int] = None, + ): + """ + Parameters + ---------- + dim: int + rotary embedding dimension + seq_len_interpolation_factor: int + if not None, discrete positions will be interpolated by this factor via the trick in + https://arxiv.org/abs/2306.15595 + pretrained_max_position_embeddings: int + pre-trained max_position_embeddings before position interpolation + """ + super().__init__() + self.seq_len_interpolation_factor = seq_len_interpolation_factor + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + self.pretrained_max_position_embeddings = pretrained_max_position_embeddings + + def forward(self, max_seq_len: int, offset: int = 0): + """ + Create rotary position embedding frequencies + + Parameters + ---------- + max_seq_len: int + sequence length of a sample + offset: int, default = 0 + fixed offset for freqencies + """ + seq = torch.arange(max_seq_len, device=self.inv_freq.device) + offset + seq = seq.type_as(self.inv_freq) + + if (self.pretrained_max_position_embeddings is not None + and self.seq_len_interpolation_factor is not None): + if (max_seq_len > + self.pretrained_max_position_embeddings * self.seq_len_interpolation_factor): + # dynamic linear scaling (length > position we have learned) + seq *= 1 / (max_seq_len / self.pretrained_max_position_embeddings) + else: + # fixed linear scaling + seq *= 1 / self.seq_len_interpolation_factor + + freqs = torch.einsum('i , j -> i j', seq, self.inv_freq) + # first part even vector components, second part odd vector components, + # 2 * dim in dimension size + emb = torch.cat((freqs, freqs), dim=-1) + # emb [seq_length, .., dim] + return emb.reshape(emb.size(0), 1, 1, emb.size(1)) + def _rotate_half(x: torch.Tensor) -> torch.Tensor: """ change sign so the last dimension becomes [-odd, +even] @@ -1488,9 +1546,10 @@ class FusedAttention(torch.nn.Module): | qkv_layout | | | | - qkv | qkv_interleaved | qkv_interleaved | | - (q,kv) | kv_interleaved | | - | - (q,k,v) | sb3hd, bs3hd | sb3hd, bs3hd | + | - (q,k,v) | sb3hd, bs3hd | sb3hd, bs3hd, sbh3d, bsh3d | | | sbhd_sb2hd, bshd_bs2hd | sbhd_sb2hd, bshd_bs2hd | - | | bshd_bshd_bshd | sbhd_sbhd_sbhd, bshd_bshd_bshd | + | | bshd_bshd_bshd | sbhd_sbh2d, bshd_bsh2d | + | | | sbhd_sbhd_sbhd, bshd_bshd_bshd | | mask_type | causal/no_mask | causal | | bias_type | no_bias/post_scale_bias | no_bias | | dropout | yes | yes | @@ -2736,6 +2795,7 @@ def forward( q_pos_emb, k_pos_emb = rotary_pos_emb query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb) key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb) + value_layer = value_layer.contiguous() context_layer = self.core_attention( query_layer,