diff --git a/docs/_templates/layout.html b/docs/_templates/layout.html
index a68b4531e3..f94e526f57 100644
--- a/docs/_templates/layout.html
+++ b/docs/_templates/layout.html
@@ -1,4 +1,11 @@
{% extends "!layout.html" %}
+
+ {% block extrahead %}
+
+
+
+ {% endblock %}
+
{% block sidebartitle %} {{ super() }}
- {%- if nvidia_analytics_id %}
-
- {%- endif %}
+ {% endblock %}
+
+ {% block footer %}
+
+
{% endblock %}
diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
index 82875e2791..d6358d1062 100644
--- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
+++ b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
@@ -22,10 +22,16 @@
"cp_1_2": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
), # MHA
- "cp_2_0": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
- "cp_2_1": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
+ "cp_1_3": ModelConfig(
+ 2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias", window_size=(512, 512)
+ ), # MHA
+ "cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
+ "cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
"cp_2_2": ModelConfig(
- 2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
+ 2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
+ ), # GQA
+ "cp_2_3": ModelConfig(
+ 2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias", window_size=(512, 512)
), # GQA
}
@@ -45,31 +51,32 @@ def get_bash_arguments(**kwargs):
@pytest.mark.parametrize("dtype", ["bf16", "fp16"])
@pytest.mark.parametrize("model", model_configs_flash_attn.keys())
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
-@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather"])
+@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a"])
def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
config = model_configs_flash_attn[model]
+ if cp_comm_type == "p2p" and config.window_size != (-1, 0) and config.window_size != (-1, -1):
+ pytest.skip("CP implementation with KV P2P does not support sliding window yet!")
if cp_comm_type == "all_gather" and qkv_format == "thd":
- pytest.skip(
- f"CP implementation with KV all-gather does not support {qkv_format} format yet!"
- )
- if cp_comm_type == "all_gather" and "causal" not in config.attn_mask_type:
- pytest.skip(
- f"CP implementation with KV all-gather does not support {config.attn_mask_type} mask"
- " type yet!"
- )
+ pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
+ pytest.skip("CP implementation with KV all-gather does not support bias yet!")
+ if cp_comm_type == "a2a" and qkv_format == "thd":
+ pytest.skip("CP implementation with QKVO A2A does not support THD format yet!")
+ if cp_comm_type == "a2a" and config.attn_bias_type != "no_bias":
+ pytest.skip("CP implementation with QKVO A2A does not support bias yet!")
+ if cp_comm_type == "a2a" and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0):
pytest.skip(
- f"CP implementation with KV all-gather does not support {config.attn_bias_type} bias"
- " type yet!"
- )
- if cp_comm_type == "p2p" and config.window_size != (-1, 0) and config.window_size != (-1, -1):
- pytest.skip(
- f"CP implementation with KV P2P does not support window size {config.window_size} yet!"
+ f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and"
+ f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!"
)
subprocess.run(
get_bash_arguments(
- dtype=dtype, model=model, qkv_format=qkv_format, kernel_backend="FlashAttention"
+ dtype=dtype,
+ model=model,
+ qkv_format=qkv_format,
+ kernel_backend="FlashAttention",
+ cp_comm_type=cp_comm_type,
),
check=True,
)
@@ -81,10 +88,16 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
"cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA
"cp_1_2": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # MHA
"cp_1_3": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # MHA
- "cp_2_0": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
- "cp_2_1": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
- "cp_2_2": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # GQA
- "cp_2_3": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # GQA
+ "cp_1_4": ModelConfig(
+ 2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
+ ), # MHA
+ "cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
+ "cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
+ "cp_2_2": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # GQA
+ "cp_2_3": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # GQA
+ "cp_2_4": ModelConfig(
+ 2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
+ ), # GQA
}
@@ -93,37 +106,27 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
@pytest.mark.parametrize("dtype", ["bf16", "fp16", "fp8"])
@pytest.mark.parametrize("model", model_configs_fused_attn.keys())
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
-@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather"])
+@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a"])
def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type):
if qkv_format == "thd" and get_device_compute_capability() < (9, 0):
- pytest.skip("THD format is only supported on sm90+.")
+ pytest.skip("THD format is only supported on sm90+!")
if cp_comm_type == "all_gather" and get_cudnn_version() < (9, 3, 0):
- pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0")
+ pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0!")
config = model_configs_fused_attn[model]
if qkv_format == "thd" and config.num_heads != config.num_gqa_groups:
- pytest.skip(f"{qkv_format} format does not support QGA/MQA yet!")
+        pytest.skip("THD format does not support GQA/MQA yet!")
if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias":
- pytest.skip(f"{qkv_format} format does not support {config.attn_bias_type} bias type yet!")
- if cp_comm_type == "all_gather" and qkv_format == "thd":
- pytest.skip(
- f"CP implementation with KV all-gather does not support {qkv_format} format yet!"
- )
- if cp_comm_type == "all_gather" and "causal" not in config.attn_mask_type:
- pytest.skip(
- f"CP implementation with KV all-gather does not support {config.attn_mask_type} mask"
- " type yet!"
- )
- if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
- pytest.skip(
- f"CP implementation with KV all-gather does not support {config.attn_bias_type} bias"
- " type yet!"
- )
- if config.window_size != (-1, 0) and config.window_size != (-1, -1):
+ pytest.skip("THD format does not support post_scale_bias yet!")
+ if qkv_format == "thd" and cp_comm_type == "all_gather":
+ pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
+ if qkv_format == "thd" and cp_comm_type == "a2a":
+ pytest.skip("CP implementation with QKVO A2A does not support THD format yet!")
+ if config.window_size != (-1, 0) and config.window_size != (-1, -1) and cp_comm_type != "a2a":
pytest.skip(
- "Fused attention does not support sliding window attention + context parallelism yet!"
+            "Sliding window attention is only supported with the QKVO A2A implementation!"
)
- if cp_comm_type == "all_gather" and dtype == "fp8":
+ if dtype == "fp8" and cp_comm_type == "all_gather":
pytest.skip(
"CP implementation with KV all-gather does not support FP8 + context parallelism yet!"
)
@@ -131,10 +134,25 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type):
pytest.skip("FP8 attention cannot work with THD format yet!")
if dtype == "fp8" and config.attn_bias_type != "no_bias":
pytest.skip("FP8 attention cannot work with bias yet!")
+ if dtype == "fp8" and config.window_size != (-1, 0) and config.window_size != (-1, -1):
+ pytest.skip("FP8 attention cannot work with sliding window yet!")
+ if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
+ pytest.skip("CP implementation with KV all-gather does not support bias yet!")
+ if cp_comm_type == "a2a" and config.attn_bias_type != "no_bias":
+ pytest.skip("CP implementation with QKVO A2A does not support bias yet!")
+ if cp_comm_type == "a2a" and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0):
+ pytest.skip(
+ f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and"
+ f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!"
+ )
subprocess.run(
get_bash_arguments(
- dtype=dtype, model=model, qkv_format=qkv_format, kernel_backend="FusedAttention"
+ dtype=dtype,
+ model=model,
+ qkv_format=qkv_format,
+ kernel_backend="FusedAttention",
+ cp_comm_type=cp_comm_type,
),
check=True,
)
diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py
index 723f68369b..ad34b4996f 100644
--- a/tests/pytorch/test_numerics.py
+++ b/tests/pytorch/test_numerics.py
@@ -1266,12 +1266,15 @@ def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False
)
inp_hidden_states.retain_grad()
- m = config.seq_len // 16
- dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
- dist.append(dist[-1]) # Manually add a zero
- m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist)
- m_splits = m_splits * 16
- assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms
+ if num_gemms > 1:
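+        # split seq_len into num_gemms chunks that are multiples of 16, forcing one zero-sized split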
+ m = config.seq_len // 16
+ dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
+ dist.append(dist[-1]) # Manually add a zero
+ m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist)
+ m_splits = m_splits * 16
+ assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms
+ else:
+ m_splits = torch.tensor([config.seq_len])
with fp8_autocast(enabled=fp8):
if isinstance(block, GroupedLinear):
@@ -1353,7 +1356,7 @@ def test_grouped_linear_accuracy(
@pytest.mark.parametrize("parallel_mode", ["column", "row"])
def test_grouped_linear_accuracy_parallel_mode(parallel_mode):
- """Split the tests to reduce CI time"""
+ """Split the tests to save CI time"""
test_grouped_linear_accuracy(
dtype=torch.float32,
num_gemms=6,
@@ -1365,6 +1368,18 @@ def test_grouped_linear_accuracy_parallel_mode(parallel_mode):
)
+def test_grouped_linear_accuracy_single_gemm():
+ """Split the tests to save CI time"""
+ test_grouped_linear_accuracy(
+ dtype=torch.float32,
+ num_gemms=1,
+ bs=2,
+ model=list(model_configs.keys())[0],
+ fp8=True,
+ fp8_model_params=True,
+ )
+
+
def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False):
def _pad_tensor_for_fp8(hidden_states, tokens_per_expert):
@@ -2034,7 +2049,7 @@ def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate):
fp8_grouped_gemm(
A_fp8,
- scale_inv,
+ [scale_inv],
0, # A_offset
tex.DType.kFloat8E4M3,
B_fp8,
diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index 91c14899ec..f8ba46b2ea 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -614,12 +614,6 @@ def get_attention_backend(
"with causal mask, no dropout, and qkv_format = bshd/sbhd"
)
use_fused_attention = False
- elif context_parallel:
- logger.debug(
- "Disabling FusedAttention as it does not support sliding window attention "
- "with context parallelism"
- )
- use_fused_attention = False
elif max_seqlen_q != max_seqlen_kv and attn_mask_type in [
"no_mask",
"padding",
@@ -1429,9 +1423,6 @@ def forward(
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
dropout_p,
- cp_group,
- cp_global_ranks,
- cp_stream,
softmax_scale,
qkv_format,
attn_mask_type,
@@ -1441,6 +1432,9 @@ def forward(
use_fused_attention,
fp8,
fp8_meta,
+ cp_group,
+ cp_global_ranks,
+ cp_stream,
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
@@ -2946,10 +2940,10 @@ def backward(ctx, dout):
None,
None,
None,
+ attn_dbias,
None,
None,
None,
- attn_dbias,
None,
None,
None,
@@ -2958,30 +2952,56 @@ def backward(ctx, dout):
@torch.compile
-def get_seq_chunk_ids_to_all_gathered_kv(
- local_chunk_id, cp_size, max_seqlen_q, max_seqlen_kv, window_size_left, device
+def get_seq_chunk_ids_for_reordering(cp_size, device, to_contiguous):
+ """
+ Context parallelism assigns two discontiguous sequence chunks to each GPU for load balancing.
+ To make sure tokens are ordered correctly for compute, we need to reorder sequence chunks
+    before or after CP communications (e.g., all-gather, all-to-all). This function computes
+    the sequence chunk ids needed for that reordering.
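+
+    For example, with cp_size = 2 the all-gathered sequence chunks arrive in rank order
+    [0, 3, 1, 2] (rank 0 holds chunks 0 and 3, rank 1 holds chunks 1 and 2). With
+    to_contiguous=True this function returns [0, 2, 3, 1], which index-selects them back
+    into the contiguous order [0, 1, 2, 3]; with to_contiguous=False it returns [0, 3, 1, 2],
+    which maps a contiguous tensor back to that rank-interleaved layout.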
+ """
+ chunk_ids = torch.empty(2 * cp_size, dtype=torch.int32, device=device)
+ if to_contiguous:
+ for rank in range(cp_size):
+ chunk_ids[rank] = 2 * rank
+ chunk_ids[rank + cp_size] = 2 * cp_size - 2 * rank - 1
+ else:
+ for rank in range(cp_size):
+ chunk_ids[2 * rank] = rank
+ chunk_ids[2 * rank + 1] = 2 * cp_size - rank - 1
+ return chunk_ids
+
+
+def get_kv_seq_info_after_all_gather(
+ local_chunk_id, cp_size, max_seqlen_q, max_seqlen_kv, window_size, causal
):
- """Compute sequence chunk ids to the all-gathered KV."""
- seq_end_idx = (local_chunk_id + 1) * max_seqlen_kv
- seq_start_idx = max(0, seq_end_idx - max_seqlen_q - window_size_left)
- seqlen = seq_end_idx - seq_start_idx
- num_chunks = (seqlen + max_seqlen_kv - 1) // max_seqlen_kv
- chunk_ids = torch.arange(
- local_chunk_id - num_chunks + 1,
- local_chunk_id + 1,
- dtype=torch.int32,
- device=device,
- )
- chunk_ids_to_all_gathered_kv = torch.where(
- chunk_ids < cp_size, 2 * chunk_ids, 2 * (2 * cp_size - chunk_ids) - 1
- )
- return chunk_ids_to_all_gathered_kv
+ """Compute KV sequence index range and update window size after all-gather."""
+ local_chunk_end_idx = (local_chunk_id + 1) * max_seqlen_kv
+ full_seq_end_idx = max_seqlen_kv * cp_size * 2
+
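+    # Determine the slice [seq_start_idx, seq_end_idx) of the all-gathered, contiguously
+    # ordered KV that the local q chunk can attend to, and re-express the sliding window
+    # relative to that slice.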
+ if window_size is None:
+ window_size = (-1, 0) if causal else (-1, -1)
+
+ if window_size[1] == -1:
+ seq_end_idx = full_seq_end_idx
+ window_size_right = -1
+ else:
+ seq_end_idx = min(full_seq_end_idx, local_chunk_end_idx + window_size[1])
+ window_size_right = local_chunk_end_idx + window_size[1] - seq_end_idx
+
+ if window_size[0] == -1:
+ seq_start_idx = 0
+ window_size_left = -1
+ else:
+ seq_start_idx = max(0, local_chunk_end_idx - max_seqlen_q - window_size[0])
+ window_size_left = window_size[0] + seq_end_idx - local_chunk_end_idx
+
+ return (seq_start_idx, seq_end_idx), (window_size_left, window_size_right)
class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
"""
- Attention implementation with context parallelism.
- KV all-gather between CP ranks is exposed.
+ Attention implementation with context parallelism. KV all-gather between CP ranks is exposed.
+    Refer to Section 3.3.2 of `The Llama 3 Herd of Models `_.
"""
@staticmethod
@@ -2992,14 +3012,10 @@ def forward(
k,
v,
cu_seqlens_q,
- cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
cu_seqlens_q_padded,
- cu_seqlens_kv_padded,
dropout_p,
- cp_group,
- cp_stream,
softmax_scale,
qkv_format,
attn_mask_type,
@@ -3008,6 +3024,8 @@ def forward(
deterministic,
use_fused_attention,
window_size,
+ cp_group,
+ cp_stream,
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
@@ -3017,10 +3035,9 @@ def forward(
causal = "causal" in attn_mask_type
padding = "padding" in attn_mask_type
- assert causal and not padding, f"{attn_mask_type} mask type is not supported!"
+ assert not padding, f"{attn_mask_type} mask type is not supported!"
if use_fused_attention and causal and "bottom_right" not in attn_mask_type:
attn_mask_type = attn_mask_type + "_bottom_right"
-
assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!"
assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!"
assert (
@@ -3029,6 +3046,8 @@ def forward(
fa_optional_forward_kwargs = {}
if _flash_attn_2_4_plus:
fa_optional_forward_kwargs["alibi_slopes"] = None
+ if _flash_attn_2_5_7_plus:
+ fa_optional_forward_kwargs["block_table"] = None
assert qkv_format != "thd", f"{qkv_format} format is not supported!"
qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format
@@ -3041,31 +3060,35 @@ def forward(
max_seqlen_q = max_seqlen_q // (2 * cp_size)
max_seqlen_kv = max_seqlen_kv // (2 * cp_size)
cu_seqlens_q = cu_seqlens_q // (2 * cp_size)
- cu_seqlens_kv = cu_seqlens_kv // (2 * cp_size)
cu_seqlens_q_padded = cu_seqlens_q_padded // (2 * cp_size)
- cu_seqlens_kv_padded = cu_seqlens_kv_padded // (2 * cp_size)
-
- if causal:
- if qkv_format == "bshd":
- # [b, s, np, hn] -> [b, 2, s//2, np, hn]
- q = q.view(q.shape[0], 2, q.shape[1] // 2, *q.shape[2:])
- # [b, s, np, hn] -> [s, b, np, hn]
- k, v = [x.transpose(0, 1).contiguous() for x in [k, v]]
- elif qkv_format == "sbhd":
- # [s, b, np, hn] -> [2, s//2, b, np, hn]
- q = q.view(2, q.shape[0] // 2, *q.shape[1:])
- # create two streams to resolve wave quantization issue of Flash Attn in each step
- flash_attn_streams = [torch.cuda.current_stream(), cp_stream]
+ # [b, s, np, hn] -> [b, 2, s//2, np, hn] or [s, b, np, hn] -> [2, s//2, b, np, hn]
+ q = q.view(*q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1) :])
+ # [b, s, np, hn] or [s, b, np, hn] -> [s, b, np, hn]
+ k, v = [x.movedim(seq_dim, 0).contiguous() for x in [k, v]]
+ # [s, b, np, hn] -> [cp, s, b, np, hn]
k_ag, _ = gather_along_first_dim(k, cp_group)
v_ag, _ = gather_along_first_dim(v, cp_group)
- cp_stream.wait_stream(torch.cuda.current_stream())
+
+ # [cp, s, b, np, hn] -> [cp*2, s//2, b, np, hn]
k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:])
v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:])
+ chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering(cp_size, k.device, True)
+ k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag)
+ v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag)
+ # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn]
+ k_ag = k_ag.view(-1, *k.shape[1:])
+ v_ag = v_ag.view(-1, *v.shape[1:])
+ cp_stream.wait_stream(torch.cuda.current_stream())
+
+ # create two streams to resolve wave quantization issue of Flash Attn in each step
+ flash_attn_streams = [torch.cuda.current_stream(), cp_stream]
local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1]
- chunk_ids_to_kv_ag_per_step = [None, None]
+ kv_seq_range_per_step = [None, None]
+ window_size_per_step = [None, None]
+ cu_seqlens_kv_per_step = [None, None]
out_per_step = [None, None]
softmax_lse_per_step = [None, None]
rng_states = [None, None]
@@ -3074,53 +3097,36 @@ def forward(
for i in range(len(local_seq_chunk_ids) + 1):
if i < len(local_seq_chunk_ids):
with torch.cuda.stream(flash_attn_streams[i]):
- chunk_ids_to_kv_ag = get_seq_chunk_ids_to_all_gathered_kv(
- local_seq_chunk_ids[i],
- cp_size,
- max_seqlen_q,
- max_seqlen_kv,
- (
- max_seqlen_kv * cp_size * 2
- if (window_size is None or window_size[0] == -1)
- else window_size[0]
- ),
- k.device,
- )
- chunk_ids_to_kv_ag_per_step[i] = chunk_ids_to_kv_ag
- num_kv_chunks = chunk_ids_to_kv_ag.numel()
- if qkv_format == "bshd":
- # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
- q_ = q[:, i].contiguous()
- # [num_kv_chunks, sq//2, b, np, hn] -> [b, num_kv_chunks*sq//2, np, hn]
- k_ = (
- torch.index_select(k_ag, dim=0, index=chunk_ids_to_kv_ag)
- .movedim(2, 0)
- .contiguous()
- .view(k.shape[1], -1, *k.shape[-2:])
- )
- v_ = (
- torch.index_select(v_ag, dim=0, index=chunk_ids_to_kv_ag)
- .movedim(2, 0)
- .contiguous()
- .view(v.shape[1], -1, *v.shape[-2:])
- )
- elif qkv_format == "sbhd":
- # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
- q_ = q[i].contiguous()
- # [num_kv_chunks, sq//2, b, np, hn] -> [num_kv_chunks*sq//2, b, np, hn]
- k_ = torch.index_select(k_ag, dim=0, index=chunk_ids_to_kv_ag).view(
- -1, *k.shape[-3:]
- )
- v_ = torch.index_select(v_ag, dim=0, index=chunk_ids_to_kv_ag).view(
- -1, *v.shape[-3:]
+ # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] or [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
+ q_ = q.select(seq_dim, i).contiguous()
+ kv_seq_range_per_step[i], window_size_per_step[i] = (
+ get_kv_seq_info_after_all_gather(
+ local_seq_chunk_ids[i],
+ cp_size,
+ max_seqlen_q,
+ max_seqlen_kv,
+ window_size,
+ causal,
)
+ )
+ seq_start_idx, seq_end_idx = (
+ kv_seq_range_per_step[i][0],
+ kv_seq_range_per_step[i][1],
+ )
+ max_seqlen_kv_ = seq_end_idx - seq_start_idx
+ cu_seqlens_kv_per_step[i] = _get_full_cu_seqlens(
+ k.shape[1], max_seqlen_kv_, k.device
+ )
+ k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]]
+ # [s_range, b, np, hn] -> [b, s_range, np, hn] or [s_range, b, np, hn]
+ k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]]
if use_fused_attention:
out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = fused_attn_fwd(
is_training,
max_seqlen_q,
- max_seqlen_kv * num_kv_chunks,
+ max_seqlen_kv_,
cu_seqlens_q,
- cu_seqlens_kv * num_kv_chunks,
+ cu_seqlens_kv_per_step[i],
q_,
k_,
v_,
@@ -3133,8 +3139,8 @@ def forward(
attn_bias_type=attn_bias_type,
attn_bias=attn_bias,
cu_seqlens_q_padded=cu_seqlens_q_padded,
- cu_seqlens_kv_padded=cu_seqlens_kv_padded * num_kv_chunks,
- window_size=window_size,
+ cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i],
+ window_size=window_size_per_step[i],
)
else:
q_, k_, v_ = [x.view(-1, *x.shape[-2:]) for x in [q_, k_, v_]]
@@ -3144,14 +3150,14 @@ def forward(
k_,
v_,
cu_seqlens_q,
- cu_seqlens_kv * num_kv_chunks,
+ cu_seqlens_kv_per_step[i],
max_seqlen_q,
- max_seqlen_kv * num_kv_chunks,
+ max_seqlen_kv_,
dropout_p,
softmax_scale,
- causal=True,
+ causal=causal,
return_softmax=False,
- window_size=window_size,
+ window_size=window_size_per_step[i],
**fa_optional_forward_kwargs,
)
)
@@ -3159,9 +3165,9 @@ def forward(
if i > 0:
with torch.cuda.stream(flash_attn_streams[i - 1]):
if qkv_format == "bshd":
- out[:, i - 1].copy_(out_per_step[i - 1].view_as(out[:, i - 1]))
+ out[:, i - 1].copy_(out_per_step[i - 1].view(out[:, i - 1].shape))
elif qkv_format == "sbhd":
- out[i - 1].copy_(out_per_step[i - 1].view_as(out[i - 1]))
+ out[i - 1].copy_(out_per_step[i - 1].view(out[i - 1].shape))
torch.cuda.current_stream().wait_stream(cp_stream)
@@ -3178,26 +3184,24 @@ def forward(
k,
v,
cu_seqlens_q,
- cu_seqlens_kv,
cu_seqlens_q_padded,
- cu_seqlens_kv_padded,
- *chunk_ids_to_kv_ag_per_step,
+ *cu_seqlens_kv_per_step,
*out_per_step,
*softmax_lse_per_step,
*rng_states,
)
+ ctx.kv_seq_range_per_step = kv_seq_range_per_step
+ ctx.window_size_per_step = window_size_per_step
ctx.cp_group = cp_group
ctx.cp_stream = cp_stream
ctx.dropout_p = dropout_p
ctx.max_seqlen_q = max_seqlen_q
- ctx.max_seqlen_kv = max_seqlen_kv
ctx.softmax_scale = softmax_scale
ctx.qkv_format = qkv_format
- ctx.attn_mask_type = attn_mask_type
ctx.attn_bias_type = attn_bias_type
+ ctx.attn_mask_type = attn_mask_type
ctx.deterministic = deterministic
ctx.use_fused_attention = use_fused_attention
- ctx.window_size = window_size
return out
@staticmethod
@@ -3205,21 +3209,20 @@ def backward(ctx, dout):
cp_size = get_distributed_world_size(ctx.cp_group)
rank = get_distributed_rank(ctx.cp_group)
- (q, k, v, cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded) = (
- ctx.saved_tensors[:7]
- )
- chunk_ids_to_kv_ag_per_step = ctx.saved_tensors[7:9]
- out_per_step = ctx.saved_tensors[9:11]
- softmax_lse_per_step = ctx.saved_tensors[11:13]
- rng_states = ctx.saved_tensors[13:15]
+ (q, k, v, cu_seqlens_q, cu_seqlens_q_padded) = ctx.saved_tensors[:5]
+ cu_seqlens_kv_per_step = ctx.saved_tensors[5:7]
+ out_per_step = ctx.saved_tensors[7:9]
+ softmax_lse_per_step = ctx.saved_tensors[9:11]
+ rng_states = ctx.saved_tensors[11:13]
+ kv_seq_range_per_step = ctx.kv_seq_range_per_step
+ window_size_per_step = ctx.window_size_per_step
+ seq_dim = ctx.qkv_format.index("s")
qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format
- dout = dout.view_as(q)
+ dout = dout.view(q.shape)
dq = torch.empty_like(q)
- dk = torch.zeros(
- (2 * cp_size, k.shape[0] // 2, *k.shape[1:]), dtype=k.dtype, device=k.device
- )
+ dk = torch.zeros((k.shape[0] * cp_size, *k.shape[1:]), dtype=k.dtype, device=k.device)
dv = torch.zeros_like(dk)
dq_per_step = [None, None]
dk_per_step = [None, None]
@@ -3230,11 +3233,20 @@ def backward(ctx, dout):
# synchronize dkv update across steps
dkv_update_done = torch.cuda.Event()
+ # [s, b, np, hn] -> [cp, s, b, np, hn]
k_ag, _ = gather_along_first_dim(k, ctx.cp_group)
v_ag, _ = gather_along_first_dim(v, ctx.cp_group)
- ctx.cp_stream.wait_stream(torch.cuda.current_stream())
+
+ # [cp, s, b, np, hn] -> [cp*2, s//2, b, np, hn]
k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:])
v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:])
+ chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering(cp_size, k.device, True)
+ k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag)
+ v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag)
+ # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn]
+ k_ag = k_ag.view(-1, *k.shape[1:])
+ v_ag = v_ag.view(-1, *v.shape[1:])
+ ctx.cp_stream.wait_stream(torch.cuda.current_stream())
local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1]
@@ -3247,66 +3259,46 @@ def backward(ctx, dout):
for i in range(len(local_seq_chunk_ids) + 1):
if i < len(local_seq_chunk_ids):
with torch.cuda.stream(flash_attn_streams[i]):
- chunk_ids_to_kv_ag = chunk_ids_to_kv_ag_per_step[i]
- num_kv_chunks = chunk_ids_to_kv_ag.numel()
+ # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] or [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
+ q_ = q.select(seq_dim, i).contiguous()
+ seq_start_idx, seq_end_idx = (
+ kv_seq_range_per_step[i][0],
+ kv_seq_range_per_step[i][1],
+ )
+ max_seqlen_kv = seq_end_idx - seq_start_idx
+ k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]]
+                    # [s_range, b, np, hn] -> [b, s_range, np, hn] or [s_range, b, np, hn]
+ k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]]
out_ = out_per_step[i]
- if ctx.qkv_format == "bshd":
- # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
- q_ = q[:, i].contiguous()
- # [num_kv_chunks, sq//2, b, np, hn] -> [b, num_kv_chunks*sq//2, np, hn]
- k_ = (
- torch.index_select(k_ag, dim=0, index=chunk_ids_to_kv_ag)
- .movedim(2, 0)
- .contiguous()
- .view(k.shape[1], -1, *k.shape[-2:])
- )
- v_ = (
- torch.index_select(v_ag, dim=0, index=chunk_ids_to_kv_ag)
- .movedim(2, 0)
- .contiguous()
- .view(v.shape[1], -1, *v.shape[-2:])
- )
- dout_ = dout[:, i].contiguous().view_as(out_)
- elif ctx.qkv_format == "sbhd":
- # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
- q_ = q[i].contiguous()
- # [num_kv_chunks, sq//2, b, np, hn] -> [num_kv_chunks*sq//2, b, np, hn]
- k_ = torch.index_select(k_ag, dim=0, index=chunk_ids_to_kv_ag).view(
- -1, *k.shape[-3:]
- )
- v_ = torch.index_select(v_ag, dim=0, index=chunk_ids_to_kv_ag).view(
- -1, *v.shape[-3:]
- )
- dout_ = dout[i].contiguous().view_as(out_)
+ dout_ = dout.select(seq_dim, i).contiguous().view(out_.shape)
if ctx.use_fused_attention:
- dq_per_step[i], dk_per_step[i], dv_per_step[i] = [
- torch.empty_like(x) for x in [q_, k_, v_]
- ]
aux_ctx_tensors = [softmax_lse_per_step[i], rng_states[i]]
dq_per_step[i], dk_per_step[i], dv_per_step[i], _ = fused_attn_bwd(
ctx.max_seqlen_q,
- ctx.max_seqlen_kv * num_kv_chunks,
+ max_seqlen_kv,
cu_seqlens_q,
- cu_seqlens_kv * num_kv_chunks,
+ cu_seqlens_kv_per_step[i],
q_,
k_,
v_,
out_,
dout_,
TE_DType[q.dtype],
- TE_DType[k.dtype],
+ TE_DType[dout.dtype],
aux_ctx_tensors,
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
cu_seqlens_q_padded=cu_seqlens_q_padded,
- cu_seqlens_kv_padded=cu_seqlens_kv_padded * num_kv_chunks,
+ cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i],
attn_scale=ctx.softmax_scale,
dropout=ctx.dropout_p,
qkv_layout=qkv_layout,
attn_mask_type=ctx.attn_mask_type,
attn_bias_type=ctx.attn_bias_type,
- window_size=ctx.window_size,
+ window_size=window_size_per_step[i],
+ deterministic=ctx.deterministic,
)
else:
+ batch_size = k_.shape[0]
q_, k_, v_ = [x.view(-1, *x.shape[-2:]) for x in [q_, k_, v_]]
dq_per_step[i], dk_per_step[i], dv_per_step[i] = [
torch.empty_like(x) for x in [q_, k_, v_]
@@ -3322,65 +3314,601 @@ def backward(ctx, dout):
dk_per_step[i],
dv_per_step[i],
cu_seqlens_q,
- cu_seqlens_kv * num_kv_chunks,
+ cu_seqlens_kv_per_step[i],
ctx.max_seqlen_q,
- ctx.max_seqlen_kv * num_kv_chunks,
+ max_seqlen_kv,
ctx.dropout_p,
ctx.softmax_scale,
- True,
- window_size=ctx.window_size,
+ "causal" in ctx.attn_mask_type,
+ window_size=window_size_per_step[i],
rng_state=rng_states[i],
**fa_optional_backward_kwargs,
)
+ # [b*sq//2, np, hn] -> [b, sq//2, np, hn]
+ dq_per_step[i] = dq_per_step[i].view(dq[:, i].shape)
+ # [b*s_range, np, hn] -> [b, s_range, np, hn]
+ dk_per_step[i], dv_per_step[i] = [
+ x.view(batch_size, -1, *x.shape[-2:])
+ for x in [dk_per_step[i], dv_per_step[i]]
+ ]
if i > 0:
with torch.cuda.stream(flash_attn_streams[i - 1]):
- chunk_ids_to_kv_ag = chunk_ids_to_kv_ag_per_step[i - 1]
- num_kv_chunks = chunk_ids_to_kv_ag.numel()
if ctx.qkv_format == "bshd":
- dq[:, i - 1].copy_(dq_per_step[i - 1].view_as(dq[:, i - 1]))
- dk_per_step[i - 1] = (
- dk_per_step[i - 1]
- .view(k.shape[1], num_kv_chunks, -1, *k.shape[-2:])
- .movedim(0, 2)
- .contiguous()
- )
- dv_per_step[i - 1] = (
- dv_per_step[i - 1]
- .view(v.shape[1], num_kv_chunks, -1, *v.shape[-2:])
- .movedim(0, 2)
- .contiguous()
- )
+ dq[:, i - 1].copy_(dq_per_step[i - 1])
elif ctx.qkv_format == "sbhd":
- dq[i - 1].copy_(dq_per_step[i - 1].view_as(dq[i - 1]))
- dk_per_step[i - 1] = dk_per_step[i - 1].view(
- num_kv_chunks, -1, *k.shape[-3:]
- )
- dv_per_step[i - 1] = dv_per_step[i - 1].view(
- num_kv_chunks, -1, *v.shape[-3:]
- )
-
+ dq[i - 1].copy_(dq_per_step[i - 1])
+ # [b, s_range, np, hn] or [s_range, b, np, hn] -> [s_range, b, np, hn]
+ dk_per_step[i - 1], dv_per_step[i - 1] = [
+ x.movedim(seq_dim, 0).contiguous()
+ for x in [dk_per_step[i - 1], dv_per_step[i - 1]]
+ ]
# wait until dkv update of last step is done
if i > 1:
flash_attn_streams[i - 1].wait_event(dkv_update_done)
- dk.index_add_(0, chunk_ids_to_kv_ag, dk_per_step[i - 1])
- dv.index_add_(0, chunk_ids_to_kv_ag, dv_per_step[i - 1])
+ seq_start_idx, seq_end_idx = (
+ kv_seq_range_per_step[i - 1][0],
+ kv_seq_range_per_step[i - 1][1],
+ )
+ dk[seq_start_idx:seq_end_idx].add_(dk_per_step[i - 1])
+ dv[seq_start_idx:seq_end_idx].add_(dv_per_step[i - 1])
if i < len(local_seq_chunk_ids):
flash_attn_streams[i - 1].record_event(dkv_update_done)
torch.cuda.current_stream().wait_stream(ctx.cp_stream)
+ # [cp*s, b, np, hn] -> [cp*2, s//2, b, np, hn]
+ dk = dk.view(2 * cp_size, -1, *dk.shape[-3:])
+ dv = dv.view(2 * cp_size, -1, *dv.shape[-3:])
+ chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering(cp_size, dk.device, False)
+ dk = torch.index_select(dk, dim=0, index=chunk_ids_for_kv_ag)
+ dv = torch.index_select(dv, dim=0, index=chunk_ids_for_kv_ag)
+ # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn]
dk = dk.view(-1, *dk.shape[-3:])
dv = dv.view(-1, *dv.shape[-3:])
dk, _ = reduce_scatter_along_first_dim(dk, ctx.cp_group)
dv, _ = reduce_scatter_along_first_dim(dv, ctx.cp_group)
+ dq = dq.view(*dq.shape[:seq_dim], -1, *dq.shape[(seq_dim + 2) :])
+ dk = dk.movedim(0, seq_dim).contiguous()
+ dv = dv.movedim(0, seq_dim).contiguous()
+
+ return (
+ None,
+ dq,
+ dk,
+ dv,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ )
+
+
+@torch.compile
+def reorder_seq_chunks_for_a2a(x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn):
+    """Reorder sequence chunks for A2A communication."""
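+    # before_attn: after the A2A that gathers the full sequence (and scatters heads), put the
+    # 2*cp_size sequence chunks into contiguous token order; otherwise regroup each rank's two
+    # chunks so the A2A can scatter the sequence back and re-gather the heads.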
+ if before_attn:
+ # [cp, b, s, np//cp, hn] -> [b, cp, s, np//cp, hn] or [cp, s, b, np//cp, hn] -> [cp, s, b, np//cp, hn]
+ x = x.movedim(0, seq_dim).contiguous()
+ # [b, cp, s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn] or [cp, s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn]
+ x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 2) :])
+ # reorder the sequence chunks
+ x = torch.index_select(x, dim=seq_dim, index=chunk_ids_for_a2a)
+ else:
+ # [b, cp*2, s//2, np//cp, hn] -> [cp*2, b, s//2, np//cp, hn] or [cp*2, s//2, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn]
+ x = x.movedim(seq_dim, 0).contiguous()
+ # reorder the sequence chunks
+ x = torch.index_select(x, dim=0, index=chunk_ids_for_a2a)
+ # [cp*2, b, s//2, np//cp, hn] -> [cp, 2, b, s//2, np//cp, hn] or [cp*2, s//2, b, np//cp, hn] -> [cp, 2, s//2, b, np//cp, hn]
+ x = x.view(cp_size, 2, *x.shape[1:])
+ return x
+
+
+def flash_attn_a2a_communicate(
+ a2a_inputs: Union[torch.Tensor, List[torch.Tensor]],
+ chunk_ids_for_a2a: torch.Tensor,
+ seq_dim: int,
+ cp_size: int,
+ cp_group: dist_group_type,
+ cp_stream: torch.cuda.Stream,
+ before_attn: bool,
+) -> Union[torch.Tensor, List[torch.Tensor]]:
+ """A2A communication for context parallelism."""
+ a2a_inputs = [a2a_inputs] if not isinstance(a2a_inputs, list) else a2a_inputs
+ a2a_outputs, a2a_reqs = [None] * len(a2a_inputs), [None] * len(a2a_inputs)
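+    # Pipeline the A2A communications: each iteration launches the async A2A for the previous
+    # tensor, post-processes the tensor before that on cp_stream, and pre-processes the next
+    # tensor, so communication overlaps with the surrounding layout transforms.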
+ if before_attn:
+ for i in range(len(a2a_inputs) + 2):
+ if 0 < i < len(a2a_inputs) + 1:
+ a2a_outputs[i - 1] = torch.empty_like(a2a_inputs[i - 1])
+ a2a_reqs[i - 1] = torch.distributed.all_to_all_single(
+ a2a_outputs[i - 1], a2a_inputs[i - 1], group=cp_group, async_op=True
+ )
+ if i > 1:
+ with torch.cuda.stream(cp_stream):
+ a2a_reqs[i - 2].wait()
+ x = a2a_outputs[i - 2]
+ # reorder the sequence chunks
+ x = reorder_seq_chunks_for_a2a(
+ x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn
+ )
+ # [b, cp*2, s//2, np//cp, hn] -> [b, cp*s, np//cp, hn] or [cp*2, s//2, b, np//cp, hn] -> [cp*s, b, np//cp, hn]
+ a2a_outputs[i - 2] = x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :])
+ if i < len(a2a_inputs):
+ x = a2a_inputs[i]
+ # [b, s, np, hn] -> [b, s, cp, np//cp, hn] or [s, b, np, hn] -> [s, b, cp, np//cp, hn]
+ x = x.view(*x.shape[:-2], cp_size, x.shape[-2] // cp_size, x.shape[-1])
+ # [b, s, cp, np//cp, hn] -> [cp, b, s, np//cp, hn] or [s, b, cp, np//cp, hn] -> [cp, s, b, np//cp, hn]
+ a2a_inputs[i] = x.movedim(-3, 0).contiguous()
+ else:
+ for i in range(len(a2a_inputs) + 2):
+ if 0 < i < len(a2a_inputs) + 1:
+ a2a_outputs[i - 1] = torch.empty_like(a2a_inputs[i - 1])
+ a2a_reqs[i - 1] = torch.distributed.all_to_all_single(
+ a2a_outputs[i - 1], a2a_inputs[i - 1], group=cp_group, async_op=True
+ )
+ if i < len(a2a_inputs):
+ x = a2a_inputs[i]
+ # [b, cp*s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn] or [cp*s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn]
+ x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 1) :])
+ # reorder the sequence chunks
+ a2a_inputs[i] = reorder_seq_chunks_for_a2a(
+ x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn
+ )
+ if i > 1:
+ with torch.cuda.stream(cp_stream):
+ a2a_reqs[i - 2].wait()
+ x = a2a_outputs[i - 2]
+ # [cp, 2, b, s//2, np//cp, hn] -> [b, 2, s//2, cp, np//cp, hn] or [cp, 2, s//2, b, np//cp, hn] -> [2, s//2, b, cp, np//cp, hn]
+ x = x.movedim(0, -3).movedim(0, seq_dim).contiguous()
+ # [b, 2, s//2, cp, np//cp, hn] -> [b*s, np, hn] or [2, s//2, b, cp, np//cp, hn] -> [s*b, np, hn]
+ a2a_outputs[i - 2] = x.view(-1, x.shape[-3] * x.shape[-2], x.shape[-1])
+ torch.cuda.current_stream().wait_stream(cp_stream)
+ return a2a_outputs[0] if len(a2a_inputs) == 1 else a2a_outputs
+
+
+class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
+ """
+    Attention implementation with context parallelism. Like Ulysses, it applies A2A to QKVO.
+    Refer to the paper `DeepSpeed Ulysses `_.
+ """
+
+ @staticmethod
+ def forward(
+ ctx,
+ is_training,
+ q,
+ k,
+ v,
+ cu_seqlens_q,
+ cu_seqlens_kv,
+ max_seqlen_q,
+ max_seqlen_kv,
+ cu_seqlens_q_padded,
+ cu_seqlens_kv_padded,
+ dropout_p,
+ softmax_scale,
+ qkv_format,
+ attn_mask_type,
+ attn_bias_type,
+ attn_bias,
+ deterministic,
+ use_fused_attention,
+ window_size,
+ fp8,
+ fp8_meta,
+ cp_group,
+ cp_stream,
+ ):
+ if softmax_scale is None:
+ softmax_scale = q.shape[-1] ** (-0.5)
+
+ cp_size = get_distributed_world_size(cp_group)
+
+ causal = "causal" in attn_mask_type
+ padding = "padding" in attn_mask_type
+ assert not padding, f"{attn_mask_type} mask type is not supported!"
+ assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!"
+ assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!"
+ assert (
+ window_size == (-1, 0)
+ or window_size == (-1, -1)
+ or use_fused_attention
+ or _flash_attn_2_3_plus
+        ), "Sliding window attention can only work with FusedAttention or FlashAttention >= 2.3!"
+ fa_optional_forward_kwargs = {}
+ if _flash_attn_2_3_plus:
+ fa_optional_forward_kwargs["window_size"] = window_size
+ if _flash_attn_2_4_plus:
+ fa_optional_forward_kwargs["alibi_slopes"] = None
+ if _flash_attn_2_5_7_plus:
+ fa_optional_forward_kwargs["block_table"] = None
+
+ assert (
+ q.shape[-2] % cp_size == 0 and k.shape[-2] % cp_size == 0
+ ), "The number of attention heads needs to be divisible by CP size!"
+
+ assert qkv_format != "thd", f"{qkv_format} format is not supported!"
+ qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format
+
+ batch_dim = qkv_format.index("b")
+ seq_dim = qkv_format.index("s")
+ assert (
+ q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0
+ ), "Sequence length per GPU needs to be divisible by 2!"
+
+ if fp8:
+ if use_fused_attention:
+ fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
+ fused_attn_qkv_dtype = fp8_dtype_forward
+ fused_attn_backend = FusedAttnBackend["FP8"]
+ if fp8_meta["recipe"].fp8_mha:
+ assert (
+ isinstance(q, Float8Tensor)
+ and isinstance(k, Float8Tensor)
+ and isinstance(v, Float8Tensor)
+ ), "q/k/v must be Float8Tensors for FP8 MHA!"
+ fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv
+ q_fp8, k_fp8, v_fp8 = q, k, v
+ q, k, v = q_fp8._data, k_fp8._data, v_fp8._data
+ elif int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
+ q_f16, k_f16, v_f16 = q, k, v
+ q, k, v = [
+ cast_to_fp8(x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward)
+ for x in [q_f16, k_f16, v_f16]
+ ]
+ fp8_meta_kwargs = {}
+ fp8_meta_kwargs["d_scale_qkv"] = fp8_meta["scaling_fwd"].scale_inv
+ fp8_meta_kwargs["d_scale_qkv_offset"] = META_QKV
+ fp8_meta_kwargs["d_scale_s"] = fp8_meta["scaling_fwd"].scale_inv
+ fp8_meta_kwargs["d_scale_s_offset"] = META_S
+ fp8_meta_kwargs["q_scale_s"] = fp8_meta["scaling_fwd"].scale
+ fp8_meta_kwargs["q_scale_s_offset"] = META_S
+ fp8_meta_kwargs["q_scale_o"] = fp8_meta["scaling_fwd"].scale
+ fp8_meta_kwargs["q_scale_o_offset"] = META_O
+ fp8_meta_kwargs["amax_s"] = fp8_meta["scaling_fwd"].amax_history
+ fp8_meta_kwargs["amax_s_offset"] = META_S
+ fp8_meta_kwargs["amax_o"] = fp8_meta["scaling_fwd"].amax_history
+ fp8_meta_kwargs["amax_o_offset"] = META_O
+ else:
+ assert False, "FP8 is only supported with Fused Attention!"
+ else:
+ if use_fused_attention:
+ fp8_meta_kwargs = {}
+ fused_attn_qkv_dtype = TE_DType[q.dtype]
+ fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]
+
+ chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, q.device, True)
+ q, k, v = flash_attn_a2a_communicate(
+ [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, True
+ )
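+        # after this A2A, each rank holds the full sequence but only np//cp attention heads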
+
+ if fp8 and not fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
+ q_f16, k_f16, v_f16 = q, k, v
+ q, k, v = [
+ cast_to_fp8(x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward)
+ for x in [q_f16, k_f16, v_f16]
+ ]
+
+ batch_size = q.shape[batch_dim]
+ if use_fused_attention:
+ out, aux_ctx_tensors = fused_attn_fwd(
+ is_training,
+ max_seqlen_q,
+ max_seqlen_kv,
+ cu_seqlens_q,
+ cu_seqlens_kv,
+ q,
+ k,
+ v,
+ fused_attn_qkv_dtype,
+ fused_attn_backend,
+ attn_scale=softmax_scale,
+ dropout=dropout_p,
+ qkv_layout=qkv_layout,
+ attn_mask_type=attn_mask_type,
+ attn_bias_type=attn_bias_type,
+ attn_bias=attn_bias,
+ cu_seqlens_q_padded=cu_seqlens_q_padded,
+ cu_seqlens_kv_padded=cu_seqlens_kv_padded,
+ window_size=window_size,
+ **fp8_meta_kwargs,
+ )
+ else:
+ # [b, cp*s, np//cp, hn] -> [b*cp*s, np//cp, hn]
+ q, k, v = [x.view(-1, *x.shape[-2:]) for x in [q, k, v]]
+ (
+ _,
+ _,
+ _,
+ _,
+ out,
+ softmax_lse,
+ _,
+ rng_state,
+ ) = _flash_attn_forward(
+ q,
+ k,
+ v,
+ cu_seqlens_q,
+ cu_seqlens_kv,
+ max_seqlen_q,
+ max_seqlen_kv,
+ dropout_p,
+ softmax_scale,
+ causal=causal,
+ return_softmax=False,
+ **fa_optional_forward_kwargs,
+ )
+ aux_ctx_tensors = [softmax_lse, rng_state]
+ # [b*cp*s, np//cp, hn] -> [b, cp*s, np//cp, hn]
+ out = out.view(batch_size, -1, *out.shape[-2:])
+
+ chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, out.device, False)
+ out = flash_attn_a2a_communicate(
+ out, chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, False
+ )
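+        # A2A the attention output back so each rank holds its local sequence chunks with all np heads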
+
+ if use_fused_attention:
+ if qkv_format == "bshd":
+ # [b*s, np, hn] -> [b, s, np, hn]
+ out = out.view(batch_size, -1, *out.shape[-2:])
+ elif qkv_format == "sbhd":
+ # [s*b, np, hn] -> [s, b, np, hn]
+ out = out.view(-1, batch_size, *out.shape[-2:])
+
+ if fp8:
+ if fp8_meta["recipe"].fp8_mha:
+ out_fp8 = Float8Tensor(
+ data=out,
+ fp8_meta=fp8_meta,
+ fp8_meta_forward=True,
+ fp8_meta_index=META_O,
+ fp8_dtype=fp8_dtype_forward,
+ dtype=q_fp8.dtype,
+ )
+ out = out_fp8._data
+ out_ret = out_fp8
+ else:
+ out_f16 = cast_from_fp8(
+ out,
+ fp8_meta["scaling_fwd"],
+ META_O,
+ fp8_dtype_forward,
+ TE_DType[q_f16.dtype],
+ )
+ out_ret = out_f16
+ else:
+ out_ret = out
+
+ if fp8:
+ if int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
+ q_save, k_save, v_save, out_save = q, k, v, out
+ elif fp8_meta["recipe"].fp8_mha:
+ q_fp8, k_fp8, v_fp8 = [
+ Float8Tensor(
+ data=x,
+ fp8_meta=fp8_meta,
+ fp8_meta_forward=True,
+ fp8_meta_index=META_QKV,
+ fp8_dtype=fp8_dtype_forward,
+ dtype=out_fp8.dtype,
+ )
+ for x in [q, k, v]
+ ]
+ q_save, k_save, v_save, out_save = q_fp8, k_fp8, v_fp8, out_fp8
+ else:
+ q_save, k_save, v_save, out_save = q_f16, k_f16, v_f16, out_f16
+ else:
+ q_save, k_save, v_save, out_save = q, k, v, out
+
+ if fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
+ fp8_fwd_scales = fp8_meta["scaling_fwd"].scale.clone()
+ fp8_fwd_scale_invs = fp8_meta["scaling_fwd"].scale_inv.clone()
+ else:
+ fp8_fwd_scales, fp8_fwd_scale_invs = None, None
+
+ ctx.save_for_backward(
+ q_save,
+ k_save,
+ v_save,
+ out_save,
+ cu_seqlens_q,
+ cu_seqlens_kv,
+ cu_seqlens_q_padded,
+ cu_seqlens_kv_padded,
+ fp8_fwd_scales,
+ fp8_fwd_scale_invs,
+ *aux_ctx_tensors,
+ )
+ ctx.batch_size = batch_size
+ ctx.cp_group = cp_group
+ ctx.cp_stream = cp_stream
+ ctx.dropout_p = dropout_p
+ ctx.max_seqlen_q = max_seqlen_q
+ ctx.max_seqlen_kv = max_seqlen_kv
+ ctx.softmax_scale = softmax_scale
+ ctx.qkv_format = qkv_format
+ ctx.attn_mask_type = attn_mask_type
+ ctx.attn_bias_type = attn_bias_type
+ ctx.deterministic = deterministic
+ ctx.window_size = window_size
+ ctx.use_fused_attention = use_fused_attention
+ ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
+ ctx.fp8_meta = fp8_meta
+ return out_ret
+
+ @staticmethod
+ def backward(ctx, dout):
+ cp_size = get_distributed_world_size(ctx.cp_group)
+
+ q, k, v, out = ctx.saved_tensors[:4]
+ cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded = ctx.saved_tensors[
+ 4:8
+ ]
+ fp8_fwd_scales, fp8_fwd_scale_invs = ctx.saved_tensors[8:10]
+ aux_ctx_tensors = ctx.saved_tensors[10:]
+
+ qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format
+ causal = "causal" in ctx.attn_mask_type
+ seq_dim = ctx.qkv_format.index("s")
+
+ if ctx.fp8:
+ if ctx.use_fused_attention:
+ fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
+ fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)
+ fused_attn_qkv_dtype = fp8_dtype_forward
+ fused_attn_dqkv_dtype = fp8_dtype_backward
+ fused_attn_backend = FusedAttnBackend["FP8"]
+ if ctx.fp8_meta["recipe"].fp8_mha:
+                    assert isinstance(dout, Float8Tensor), "dout must be a Float8Tensor for FP8 MHA!"
+ ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = dout._scale_inv
+ dout_fp8 = dout
+ dout = dout_fp8._data
+ else:
+ dout_f16 = dout
+ dout = cast_to_fp8(
+ dout_f16, ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward
+ )
+ fp8_meta_kwargs = {}
+ fp8_meta_kwargs["d_scale_qkv"] = fp8_fwd_scale_invs[META_QKV]
+ fp8_meta_kwargs["d_scale_s"] = fp8_fwd_scale_invs[META_S]
+ fp8_meta_kwargs["d_scale_o"] = fp8_fwd_scale_invs[META_O]
+ fp8_meta_kwargs["d_scale_do"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO]
+ fp8_meta_kwargs["d_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP]
+ fp8_meta_kwargs["q_scale_s"] = fp8_fwd_scales[META_S]
+ fp8_meta_kwargs["q_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale[META_DP]
+ fp8_meta_kwargs["q_scale_dqkv"] = ctx.fp8_meta["scaling_bwd"].scale[META_DQKV]
+ fp8_meta_kwargs["amax_dp"] = ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP]
+ fp8_meta_kwargs["amax_dqkv"] = ctx.fp8_meta["scaling_bwd"].amax_history[0][
+ META_DQKV
+ ]
+ else:
+ assert False, "FP8 is only supported with Fused Attention!"
+ else:
+ if ctx.fp8_meta is not None and ctx.fp8_meta["recipe"].fp8_mha:
+                assert isinstance(dout, Float8Tensor), "dout must be a Float8Tensor for FP8 MHA!"
+ q, k, v, out, dout = [x.from_float8(x.dtype) for x in [q, k, v, out, dout]]
+ if ctx.use_fused_attention:
+ fp8_meta_kwargs = {}
+ fused_attn_qkv_dtype = TE_DType[q.dtype]
+ fused_attn_dqkv_dtype = TE_DType[dout.dtype]
+ fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]
+
+ if not ctx.use_fused_attention:
+ out = out.view(ctx.batch_size, -1, *out.shape[-2:])
+ dout = dout.view(*out.shape)
+
+ chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, out.device, True)
+ out, dout = flash_attn_a2a_communicate(
+ [out, dout], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, True
+ )
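+        # redistribute out/dout to the full-sequence, np//cp-heads layout used during attention compute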
+
+ fa_optional_backward_kwargs = {}
+ if _flash_attn_2_3_plus:
+ fa_optional_backward_kwargs["window_size"] = ctx.window_size
+ if _flash_attn_2_4_plus:
+ fa_optional_backward_kwargs["alibi_slopes"] = None
+ if _flash_attn_2_4_1_plus:
+ fa_optional_backward_kwargs["deterministic"] = ctx.deterministic
+
+ if ctx.use_fused_attention:
+ dq, dk, dv, _ = fused_attn_bwd(
+ ctx.max_seqlen_q,
+ ctx.max_seqlen_kv,
+ cu_seqlens_q,
+ cu_seqlens_kv,
+ q,
+ k,
+ v,
+ out,
+ dout,
+ fused_attn_qkv_dtype,
+ fused_attn_dqkv_dtype,
+ aux_ctx_tensors,
+ fused_attn_backend,
+ cu_seqlens_q_padded=cu_seqlens_q_padded,
+ cu_seqlens_kv_padded=cu_seqlens_kv_padded,
+ attn_scale=ctx.softmax_scale,
+ dropout=ctx.dropout_p,
+ qkv_layout=qkv_layout,
+ attn_mask_type=ctx.attn_mask_type,
+ attn_bias_type=ctx.attn_bias_type,
+ window_size=ctx.window_size,
+ deterministic=ctx.deterministic,
+ **fp8_meta_kwargs,
+ )
+ else:
+ softmax_lse, rng_state = aux_ctx_tensors
+ out, dout = [x.view(-1, *x.shape[-2:]) for x in [out, dout]]
+ dq, dk, dv = [torch.empty_like(x) for x in [q, k, v]]
+ _flash_attn_backward(
+ dout,
+ q,
+ k,
+ v,
+ out,
+ softmax_lse,
+ dq,
+ dk,
+ dv,
+ cu_seqlens_q,
+ cu_seqlens_kv,
+ ctx.max_seqlen_q,
+ ctx.max_seqlen_kv,
+ ctx.dropout_p,
+ ctx.softmax_scale,
+ causal,
+ rng_state=rng_state,
+ **fa_optional_backward_kwargs,
+ )
+ dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]]
+
+ chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, q.device, False)
+ dq, dk, dv = flash_attn_a2a_communicate(
+ [dq, dk, dv], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, False
+ )
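+        # A2A the gradients back to the local-sequence, all-heads layout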
+
if ctx.qkv_format == "bshd":
- dq = dq.view(dq.shape[0], -1, *dq.shape[-2:])
- dk = dk.transpose(0, 1).contiguous()
- dv = dv.transpose(0, 1).contiguous()
+ dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]]
elif ctx.qkv_format == "sbhd":
- dq = dq.view(-1, *dq.shape[-3:])
+ dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]]
+
+ if ctx.fp8:
+ if ctx.fp8_meta["recipe"].fp8_mha:
+ dq, dk, dv = [
+ Float8Tensor(
+ data=x,
+ fp8_meta=ctx.fp8_meta,
+ fp8_meta_forward=False,
+ fp8_meta_index=META_DQKV,
+ fp8_dtype=fp8_dtype_backward,
+ dtype=dout_fp8.dtype,
+ )
+ for x in [dq, dk, dv]
+ ]
+ else:
+ dq, dk, dv = [
+ cast_from_fp8(
+ x,
+ ctx.fp8_meta["scaling_bwd"],
+ META_DQKV,
+ fp8_dtype_backward,
+ TE_DType[dout_f16.dtype],
+ )
+ for x in [dq, dk, dv]
+ ]
return (
None,
@@ -3404,6 +3932,9 @@ def backward(ctx, dout):
None,
None,
None,
+ None,
+ None,
+ None,
)
@@ -3465,57 +3996,44 @@ def attn_forward_func_with_cp(
sliding_window_attn = (
window_size is not None and window_size != (-1, 0) and window_size != (-1, -1)
)
+ assert (
+ not sliding_window_attn
+ or cp_comm_type == "a2a"
+ or (cp_comm_type == "all_gather" and not use_fused_attention)
+    ), "The given context parallel configuration does not support sliding window attention!"
- if sliding_window_attn or cp_comm_type == "all_gather":
- out = AttnFuncWithCPAndKVAllGather.apply(
- is_training,
- q,
- k,
- v,
- cu_seqlens_q,
- cu_seqlens_kv,
- max_seqlen_q,
- max_seqlen_kv,
- cu_seqlens_q_padded,
- cu_seqlens_kv_padded,
- dropout_p,
- cp_group,
- cp_stream,
- softmax_scale,
- qkv_format,
- attn_mask_type,
- attn_bias_type,
- attn_bias,
- deterministic,
- use_fused_attention,
- window_size,
- )
- elif cp_comm_type == "p2p":
- out = AttnFuncWithCPAndKVP2P.apply(
- is_training,
- q,
- k,
- v,
- cu_seqlens_q,
- cu_seqlens_kv,
- max_seqlen_q,
- max_seqlen_kv,
- cu_seqlens_q_padded,
- cu_seqlens_kv_padded,
- dropout_p,
- cp_group,
- cp_global_ranks,
- cp_stream,
- softmax_scale,
- qkv_format,
- attn_mask_type,
- attn_bias_type,
- attn_bias,
- deterministic,
- use_fused_attention,
- fp8,
- fp8_meta,
- )
+ args = [
+ is_training,
+ q,
+ k,
+ v,
+ cu_seqlens_q,
+ cu_seqlens_kv,
+ max_seqlen_q,
+ max_seqlen_kv,
+ cu_seqlens_q_padded,
+ cu_seqlens_kv_padded,
+ dropout_p,
+ softmax_scale,
+ qkv_format,
+ attn_mask_type,
+ attn_bias_type,
+ attn_bias,
+ deterministic,
+ use_fused_attention,
+ ]
+
+ if cp_comm_type == "p2p":
+ args += [fp8, fp8_meta, cp_group, cp_global_ranks, cp_stream]
+ out = AttnFuncWithCPAndKVP2P.apply(*args)
+ elif cp_comm_type == "all_gather":
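+        # drop cu_seqlens_kv (index 5) and cu_seqlens_kv_padded (index 8 after the first pop),
+        # which AttnFuncWithCPAndKVAllGather does not take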
+ args.pop(5)
+ args.pop(8)
+ args += [window_size, cp_group, cp_stream]
+ out = AttnFuncWithCPAndKVAllGather.apply(*args)
+ elif cp_comm_type == "a2a":
+ args += [window_size, fp8, fp8_meta, cp_group, cp_stream]
+ out = AttnFuncWithCPAndQKVOA2A.apply(*args)
else:
raise ValueError(f"Unsupported communication type: {cp_comm_type}!")
@@ -6416,7 +6934,13 @@ class DotProductAttention(TransformerEngineBaseModule):
can overlap two flash attention kernels.
cp_comm_type : str
inter-gpu communication type for context parallelism.
- Can be "p2p" or "all_gather".
+ Can be "p2p" or "all_gather" or "a2a".
+ "p2p": Exchange KV chunks with P2P communications in ring topology.
+ P2P is async and can be overlapped with attention compute.
+ "all_gather": All-gather to get full sequence of KV before attention.
+ The all-gather is not async, and cannot be overlapped.
+ "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP
+ group, and gather to get full sequence of QKV.
"""
def __init__(
@@ -6608,7 +7132,13 @@ def set_context_parallel_group(
cuda stream for context parallel execution.
cp_comm_type : str
inter-gpu communication type for context parallelism.
- Can be "p2p" or "all_gather".
+ Can be "p2p" or "all_gather" or "a2a".
+ "p2p": Exchange KV chunks with P2P communications in ring topology.
+ P2P is async and can be overlapped with attention compute.
+ "all_gather": All-gather to get full sequence of KV before attention.
+ The all-gather is not async, and cannot be overlapped.
+ "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP
+ group, and gather to get full sequence of QKV.
"""
self.cp_group = cp_group
self.cp_global_ranks = cp_global_ranks
@@ -7633,7 +8163,13 @@ def set_context_parallel_group(
cuda stream for context parallel execution.
cp_comm_type : str
inter-gpu communication type for context parallelism.
- Can be "p2p" or "all_gather".
+ Can be "p2p" or "all_gather" or "a2a".
+ "p2p": Exchange KV chunks with P2P communications in ring topology.
+ P2P is async and can be overlapped with attention compute.
+ "all_gather": All-gather to get full sequence of KV before attention.
+ The all-gather is not async, and cannot be overlapped.
+ "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP
+ group, and gather to get full sequence of QKV.
"""
# Deep iterate but skip self to avoid infinite recursion.
for index, child in enumerate(self.modules()):
diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py
index 8502f70491..fd1eb4a810 100644
--- a/transformer_engine/pytorch/cpp_extensions/gemm.py
+++ b/transformer_engine/pytorch/cpp_extensions/gemm.py
@@ -11,7 +11,12 @@
from ..utils import assert_dim_for_fp8_exec
-__all__ = ["gemm", "fp8_gemm", "grouped_gemm", "fp8_grouped_gemm"]
+__all__ = [
+ "gemm",
+ "fp8_gemm",
+ "grouped_gemm",
+ "fp8_grouped_gemm",
+]
@functools.lru_cache(maxsize=None)
@@ -313,7 +318,7 @@ def grouped_gemm(
layout: str = "TN",
bias: Optional[List[torch.Tensor]] = None,
use_bias: bool = False,
-) -> Tuple[Union[List[torch.Tensor], None], ...]:
+) -> Tuple[List[torch.Tensor], ...]:
"""Non FP8 Grouped GEMM."""
assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported."
@@ -380,7 +385,7 @@ def grouped_gemm(
def fp8_grouped_gemm(
A: List[torch.Tensor],
- A_scale_inv: torch.Tensor,
+ A_scale_inv: List[torch.Tensor],
A_fp8_tensor_offset: int,
A_dtype: tex.DType,
B: List[torch.Tensor],
@@ -390,6 +395,7 @@ def fp8_grouped_gemm(
out: List[torch.Tensor],
out_dtype: torch.dtype,
workspaces: List[torch.Tensor],
+ m_splits: Optional[List[int]] = None,
out_offset: Optional[int] = None,
fp8_meta_tensor: tex.FP8TensorMeta = None,
gelu: bool = False,
@@ -398,16 +404,25 @@ def fp8_grouped_gemm(
use_bias: bool = False,
use_split_accumulator: bool = False,
D_dtype: Optional[tex.DType] = None,
-) -> Tuple[Union[List[torch.Tensor], None], ...]:
+) -> Tuple[List[torch.Tensor], ...]:
"""
TN layout Grouped GEMM with fp8 inputs.
- This method assumes the scale/scale_inv/amax of A/B/out is contiguous in the meta tensor.
- scale: [ ...A_scale... | ...B_scale... | ...out_scale...]
- scale_inv: [ ...A_scale_inv... | ...B_scale_inv... | ...out_scale_inv...]
- amax: [ ...A_amax... | ...B_amax... | ...out_amax...]
+ Input requirements:
+    1. If len(A_scale_inv) == num_gemms, len(out) must be 1, and m_splits must not be None.
+       This is used for the calculation of output (fwd) and dgrad (bwd).
+    2. If len(A_scale_inv) == 1, len(out) must be num_gemms. This is used for the
+       calculation of wgrad.
"""
-
num_gemms = len(A)
+ if num_gemms > 1 and len(A_scale_inv) == num_gemms:
+ assert len(out) == 1 and m_splits is not None
+ elif num_gemms > 1 and len(A_scale_inv) == 1:
+ assert len(out) == num_gemms
+ elif num_gemms == 1:
+ assert len(A_scale_inv) == 1 and len(out) == 1
+ else:
+ raise ValueError("Invalid input combinations of A_scale_inv and out.")
+
empty_tensor = _empty_tensor()
empty_tensors = [empty_tensor] * num_gemms
if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]:
@@ -420,41 +435,71 @@ def fp8_grouped_gemm(
# Use bfloat16 as default bias_dtype
bias_dtype = torch.bfloat16 if bias is None else bias[0].dtype
- if gelu:
- gelu_input = [
- torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format)
- for o in out
- ]
- else:
- gelu_input = empty_tensors
bias_dtype = TE_DType[bias_dtype]
-
+ gelu_input = empty_tensors
out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype
- torch.ops.tex_ts.te_grouped_gemm_ts(
- A,
- A_scale_inv,
- A_fp8_tensor_offset,
- A_dtype,
- True, # transa
- B,
- B_scale_inv,
- B_fp8_tensor_offset,
- B_dtype,
- False, # transb
- out,
- 0 if out_offset is None else out_offset,
- empty_tensor if out_offset is None else fp8_meta_tensor.scale,
- out_dtype,
- empty_tensor if out_offset is None else fp8_meta_tensor.amax_history,
- bias if use_bias else empty_tensors,
- bias_dtype,
- gelu_input, # this is pre_gelu_out
- False, # grad
- workspaces,
- workspaces[0].shape[0],
- accumulate,
- use_split_accumulator,
- )
+ if len(A_scale_inv) == 1:
+ if gelu:
+ gelu_input = [
+ torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format)
+ for o in out
+ ]
+
+ torch.ops.tex_ts.te_grouped_gemm_ts(
+ A,
+ A_scale_inv[0],
+ A_fp8_tensor_offset,
+ A_dtype,
+ True, # transa
+ B,
+ B_scale_inv,
+ B_fp8_tensor_offset,
+ B_dtype,
+ False, # transb
+ out,
+ 0 if out_offset is None else out_offset,
+ empty_tensor if out_offset is None else fp8_meta_tensor.scale,
+ out_dtype,
+ empty_tensor if out_offset is None else fp8_meta_tensor.amax_history,
+ bias if use_bias else empty_tensors,
+ bias_dtype,
+ gelu_input, # this is pre_gelu_out
+ False, # grad
+ workspaces,
+ workspaces[0].shape[0],
+ accumulate,
+ use_split_accumulator,
+ )
+ else:
+ if gelu:
+ gelu_input = [torch.empty((m, A[0].size(0)), dtype=bias_dtype) for m in m_splits]
+
+ torch.ops.tex_ts.te_grouped_gemm_single_output_ts(
+ A,
+ A_scale_inv,
+ A_fp8_tensor_offset,
+ A_dtype,
+ True, # transa
+ B,
+ B_scale_inv,
+ B_fp8_tensor_offset,
+ B_dtype,
+ False, # transb
+ m_splits,
+ out[0],
+ 0 if out_offset is None else out_offset,
+ empty_tensor if out_offset is None else fp8_meta_tensor.scale,
+ out_dtype,
+ empty_tensor if out_offset is None else fp8_meta_tensor.amax_history,
+ bias if use_bias else empty_tensors,
+ bias_dtype,
+ gelu_input, # this is pre_gelu_out
+ False, # grad
+ workspaces,
+ workspaces[0].shape[0],
+ accumulate,
+ use_split_accumulator,
+ )
return out, gelu_input
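
For readers skimming the new fp8_grouped_gemm contract above, here is a minimal standalone sketch (not part of the patch; the helper name, shapes, and example values are made up) of which backend op each input combination selects:

# Illustrative sketch only -- mirrors the dispatch rules documented in fp8_grouped_gemm.
from typing import List, Optional

def select_grouped_gemm_op(
    A_scale_inv: List,                 # per-GEMM scale_inv tensors, or a single shared one
    out: List,                         # one concatenated output, or one output per GEMM
    m_splits: Optional[List[int]],
    num_gemms: int,
) -> str:
    """Return the TorchScript op the Python wrapper would call."""
    if num_gemms > 1 and len(A_scale_inv) == num_gemms:
        # fwd output / bwd dgrad: one contiguous output tensor, partitioned by m_splits
        assert len(out) == 1 and m_splits is not None
        return "te_grouped_gemm_single_output_ts"
    if num_gemms > 1 and len(A_scale_inv) == 1:
        # bwd wgrad: one output tensor per GEMM, shared scale_inv buffer
        assert len(out) == num_gemms
        return "te_grouped_gemm_ts"
    assert num_gemms == 1 and len(A_scale_inv) == 1 and len(out) == 1
    return "te_grouped_gemm_ts"

# e.g. the fprop call in GroupedLinear: per-weight scale_inv, a single fused output
assert select_grouped_gemm_op([0.1, 0.2], [None], [128, 64], 2) == "te_grouped_gemm_single_output_ts"
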
diff --git a/transformer_engine/pytorch/cpp_extensions/transpose.py b/transformer_engine/pytorch/cpp_extensions/transpose.py
index 37a1b59da2..ddc3b67e9e 100644
--- a/transformer_engine/pytorch/cpp_extensions/transpose.py
+++ b/transformer_engine/pytorch/cpp_extensions/transpose.py
@@ -175,6 +175,7 @@ def fp8_multi_cast_transpose_fused(
amax_indices: List[int],
scale_inv_indices: List[int],
otype: tex.DType,
+ scale_inv: Optional[torch.Tensor] = None,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
"""Cast + Transpose with FP8 output"""
@@ -182,7 +183,7 @@ def fp8_multi_cast_transpose_fused(
input_list,
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
- fp8_meta_tensor.scale_inv,
+ scale_inv if scale_inv is not None else fp8_meta_tensor.scale_inv,
scale_indices,
amax_indices,
scale_inv_indices,
diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h
index 31103cbe8e..c797208e06 100644
--- a/transformer_engine/pytorch/csrc/extensions.h
+++ b/transformer_engine/pytorch/csrc/extensions.h
@@ -165,6 +165,16 @@ void te_grouped_gemm(std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int
std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, int math_sm_count);
+void te_grouped_gemm_single_output(
+ std::vector<at::Tensor> A, std::vector<at::Tensor> A_scale_inverse, int A_offset,
+ transformer_engine::DType A_type, bool transa, std::vector<at::Tensor> B,
+ at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type, bool transb,
+ std::vector<int64_t> m_splits, at::Tensor D, int D_offset, at::Tensor D_scale,
+ transformer_engine::DType D_type, at::Tensor D_amax, std::vector<at::Tensor> bias,
+ transformer_engine::DType bias_type, std::vector<at::Tensor> pre_gelu_out, bool grad,
+ std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
+ bool use_split_accumulator, int math_sm_count);
+
/***************************************************************************************************
* Transpose
**************************************************************************************************/
diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cu b/transformer_engine/pytorch/csrc/extensions/gemm.cu
index 7405914a0e..ba9851e7e8 100644
--- a/transformer_engine/pytorch/csrc/extensions/gemm.cu
+++ b/transformer_engine/pytorch/csrc/extensions/gemm.cu
@@ -151,3 +151,64 @@ void te_grouped_gemm(std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int
te_workspace.data(), accumulate, use_split_accumulator,
math_sm_count, at::cuda::getCurrentCUDAStream());
}
+
+void te_grouped_gemm_single_output(
+ std::vector<at::Tensor> A, std::vector<at::Tensor> A_scale_inverse, int A_offset,
+ transformer_engine::DType A_type, bool transa, std::vector<at::Tensor> B,
+ at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type, bool transb,
+ std::vector<int64_t> m_splits, at::Tensor D, int D_offset, at::Tensor D_scale,
+ transformer_engine::DType D_type, at::Tensor D_amax, std::vector<at::Tensor> bias,
+ transformer_engine::DType bias_type, std::vector<at::Tensor> pre_gelu_out, bool grad,
+ std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
+ bool use_split_accumulator, int math_sm_count) {
+ using namespace transformer_engine;
+ std::vector<NVTETensor> te_A, te_B, te_D, te_bias, te_pre_gelu_out, te_workspace;
+ std::vector<transformer_engine::TensorWrapper> tensor_wrappers;
+ auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector<size_t>& shape,
+ transformer_engine::DType dtype, void* amax_dptr,
+ void* scale_dptr, void* scale_inv_dptr) -> NVTETensor {
+ tensor_wrappers.emplace_back(
+ makeTransformerEngineTensor(dptr, shape, dtype, amax_dptr, scale_dptr, scale_inv_dptr));
+ return tensor_wrappers.back().data();
+ };
+ NVTE_CHECK(D.is_contiguous(), "D must be contiguous.");
+ void* d_i_ptr = reinterpret_cast<void*>(D.data_ptr());
+ for (size_t i = 0; i < A.size(); i++) {
+ if (m_splits[i] == 0) continue;
+ NVTE_CHECK(A[i].is_contiguous(), "A[", i, "] must be contiguous.");
+ NVTE_CHECK(B[i].is_contiguous(), "B[", i, "] must be contiguous.");
+ te_A.emplace_back(make_tensor(
+ A[i].data_ptr(), {static_cast<size_t>(A[i].size(0)), static_cast<size_t>(A[i].size(1))},
+ A_type, nullptr, nullptr, getDataPtr(A_scale_inverse[i], A_offset)));
+ te_B.emplace_back(make_tensor(
+ B[i].data_ptr(), {static_cast<size_t>(B[i].size(0)), static_cast<size_t>(B[i].size(1))},
+ B_type, nullptr, nullptr, getDataPtr(B_scale_inverse, B_offset + i)));
+ te_D.emplace_back(make_tensor(
+ d_i_ptr, {static_cast<size_t>(m_splits[i]), static_cast<size_t>(A[i].size(0))}, D_type,
+ getDataPtr(D_amax, D_offset + i), getDataPtr(D_scale, D_offset + i), nullptr));
+ te_bias.emplace_back(make_tensor(bias[i].data_ptr(), {static_cast<size_t>(bias[i].size(0))},
+ bias_type, nullptr, nullptr, nullptr));
+
+ const auto gelu_shape = pre_gelu_out[i].data_ptr() == nullptr
+ ? std::vector<size_t>{static_cast<size_t>(pre_gelu_out[i].size(0))}
+ : std::vector<size_t>{static_cast<size_t>(pre_gelu_out[i].size(0)),
+ static_cast<size_t>(pre_gelu_out[i].size(1))};
+ te_pre_gelu_out.emplace_back(make_tensor(
+ pre_gelu_out[i].data_ptr(), gelu_shape,
+ GetTransformerEngineDType(pre_gelu_out[i].scalar_type()), nullptr, nullptr, nullptr));
+ // Move the D pointer to the next split.
+ char* char_ptr = reinterpret_cast<char*>(d_i_ptr);
+ char_ptr += m_splits[i] * A[i].size(0) * D.element_size();
+ d_i_ptr = reinterpret_cast(char_ptr);
+ }
+ for (size_t i = 0; i < workspace.size(); i++) {
+ te_workspace.emplace_back(make_tensor(workspace[i].data_ptr(), {workspaceSize}, DType::kByte,
+ nullptr, nullptr, nullptr));
+ }
+
+ // For now, we only have the multi-stream cublas backend.
+ nvte_multi_stream_cublas_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(),
+ te_pre_gelu_out.data(), te_A.size(), transa, transb, grad,
+ te_workspace.data(), accumulate, use_split_accumulator,
+ math_sm_count, at::cuda::getCurrentCUDAStream());
+}
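
The pointer arithmetic at the end of the loop above advances the single contiguous D buffer by one split per GEMM. A small sketch of the same offset computation (purely illustrative; the function name and example numbers are made up, not part of the patch):

# Illustrative sketch only: byte offsets of each GEMM's slice in the contiguous D tensor,
# matching "char_ptr += m_splits[i] * A[i].size(0) * D.element_size()" above.
def d_split_byte_offsets(m_splits, n_cols, element_size):
    offsets, off = [], 0
    for m, n in zip(m_splits, n_cols):
        offsets.append(off)          # where GEMM i starts writing (m rows x n cols)
        off += m * n * element_size  # the kernel loop skips the entry entirely when m == 0
    return offsets

# Three GEMMs, all with N = 512 output columns, fp16 output (2 bytes per element):
# d_split_byte_offsets([128, 0, 64], [512, 512, 512], 2) -> [0, 131072, 131072]
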
diff --git a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp
index 8c480e8343..9f31dba669 100644
--- a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp
+++ b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp
@@ -305,6 +305,41 @@ std::vector<at::Tensor> te_grouped_gemm_ts(
return D;
}
+at::Tensor te_grouped_gemm_single_output_ts(
+ std::vector<at::Tensor> A, std::vector<at::Tensor> A_scale_inverse, int64_t A_offset,
+ int64_t A_type, int64_t transa, std::vector<at::Tensor> B, at::Tensor B_scale_inverse,
+ int64_t B_offset, int64_t B_type, int64_t transb, std::vector<int64_t> m_splits, at::Tensor D,
+ int64_t D_offset, at::Tensor D_scale, int64_t D_type, at::Tensor D_amax,
+ std::vector<at::Tensor> bias, int64_t bias_type, std::vector<at::Tensor> pre_gelu_out,
+ int64_t grad, std::vector<at::Tensor> workspace, int64_t workspaceSize, int64_t accumulate,
+ int64_t use_split_accumulator) {
+ // cast inputs to types accepted by te_gemm
+ transformer_engine::DType A_type_arg = reverse_map_dtype(A_type);
+ bool transa_arg = static_cast<bool>(transa);
+ transformer_engine::DType B_type_arg = reverse_map_dtype(B_type);
+ bool transb_arg = static_cast<bool>(transb);
+ transformer_engine::DType D_type_arg = reverse_map_dtype(D_type);
+ transformer_engine::DType bias_type_arg = reverse_map_dtype(bias_type);
+ bool grad_arg = static_cast<bool>(grad);
+ size_t workspaceSize_arg = static_cast<size_t>(workspaceSize);
+ bool accumulate_arg = static_cast<bool>(accumulate);
+ bool use_split_accumulator_arg = static_cast<bool>(use_split_accumulator);
+
+ // Set an external SM Margin to all the GEMMs.
+ // This comes in handy when DP is overlapped with GEMMs
+
+ const int device_id = at::cuda::current_device();
+ const int sm_count = transformer_engine::cuda::sm_count(device_id);
+ int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count);
+
+ te_grouped_gemm_single_output(A, A_scale_inverse, A_offset, A_type_arg, transa_arg, B,
+ B_scale_inverse, B_offset, B_type_arg, transb_arg, m_splits, D,
+ D_offset, D_scale, D_type_arg, D_amax, bias, bias_type_arg,
+ pre_gelu_out, grad_arg, workspace, workspaceSize_arg,
+ accumulate_arg, use_split_accumulator_arg, num_math_sms);
+ return D;
+}
+
at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input, const at::Tensor &weight,
const at::Tensor &bias, double eps, at::Tensor scale,
at::Tensor amax, at::Tensor scale_inv, int64_t fp8_tensor,
@@ -371,6 +406,7 @@ TORCH_LIBRARY(tex_ts, m) {
m.def("srelu_ts", &srelu_ts);
m.def("te_gemm_ts", &te_gemm_ts);
m.def("te_grouped_gemm_ts", &te_grouped_gemm_ts);
+ m.def("te_grouped_gemm_single_output_ts", &te_grouped_gemm_single_output_ts);
m.def("layernorm_fwd_fp8_inf_ts", &layernorm_fwd_fp8_inf_ts);
m.def("layernorm_fwd_inf_ts", &layernorm_fwd_inf_ts);
m.def("rmsnorm_fwd_fp8_inf_ts", &rmsnorm_fwd_fp8_inf_ts);
diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py
index a91ff5c361..ca100392c7 100644
--- a/transformer_engine/pytorch/module/grouped_linear.py
+++ b/transformer_engine/pytorch/module/grouped_linear.py
@@ -42,6 +42,7 @@
from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ..float8_tensor import Float8Tensor
+from ..export import is_in_onnx_export_mode
__all__ = ["GroupedLinear"]
@@ -102,10 +103,12 @@ def forward(
inputmats_no_fp8 = [cast_if_needed(mat, activation_dtype) for mat in inputmats]
inputmats = []
inputmats_t = []
+ inputmat_scale_inv = None
global _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT
if fp8:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
+ inputmat_scale_inv = torch.empty([num_gemms], dtype=torch.float32, device=inp.device)
if (
not fp8_meta["recipe"].override_linear_precision.wgrad
and is_grad_enabled
@@ -121,6 +124,7 @@ def forward(
indices, # amax_indices
indices, # scale_inv_indices
fp8_dtype_forward,
+ scale_inv=inputmat_scale_inv,
)
else:
# FP8 input for forward
@@ -130,9 +134,22 @@ def forward(
fp8_meta["scaling_fwd"],
_GEMM_INPUT + i,
fp8_dtype_forward,
+ scale_inv=inputmat_scale_inv,
)
for i in range(num_gemms)
]
+
+ # Hack for ONNX export
+ # Note: ONNX models are represented as a graph of tensor
+ # operations, so the in-place scale-inv update doesn't fit
+ # very well. We work around this by making it look like
+ # the scale-inv tensor is initialized with a copy.
+ # Note: ONNX export expects FP8 scales to be representable
+ # with constant ops. However, copying into a buffer
+ # involves an expand op for array broadcasting. We work
+ # around this by filling the buffer instead.
+ if is_in_onnx_export_mode():
+ inputmat_scale_inv.fill_(inputmat_scale_inv.item())
else:
inputmats = inputmats_no_fp8
@@ -153,16 +170,17 @@ def forward(
_ = fp8_grouped_gemm(
[w._data for w in weights_fp8],
- fp8_meta["scaling_fwd"].scale_inv,
- _GEMM_WEIGHT,
+ [w._scale_inv for w in weights_fp8],
+ 0, # weight offset is 0 for the newly created _scale_inv
fp8_dtype_forward,
inputmats,
- fp8_meta["scaling_fwd"].scale_inv,
- _GEMM_INPUT,
+ inputmat_scale_inv,
+ 0,
fp8_dtype_forward,
- torch.split(out, m_splits),
+ [out],
activation_dtype,
get_multi_stream_cublas_workspace(),
+ m_splits=m_splits,
bias=biases,
use_bias=use_bias,
use_split_accumulator=_2X_ACC_FPROP,
@@ -230,7 +248,7 @@ def forward(
t.activation_offloading = True
ctx.save_for_backward(
- fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
+ inputmat_scale_inv,
*saved_inputmats,
*saved_inputmats_t,
*weights,
@@ -270,7 +288,7 @@ def forward(
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
with torch.cuda.nvtx.range("_GroupedLinear_backward"):
(
- fwd_scale_inverses,
+ inputmat_scale_inv,
*saved_tensors,
) = ctx.saved_tensors
inputmats = saved_tensors[: ctx.num_gemms]
@@ -342,18 +360,17 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
)
fp8_grouped_gemm(
[w.transpose_2d() for w in weights_fp8],
- torch.cat(
- [w._scale_inv for w in weights_fp8]
- ), # avoiding torch.cat requires another interface
+ [w._scale_inv for w in weights_fp8],
0, # weight offset is 0 for the newly created _scale_inv
weights_fp8[0]._fp8_dtype,
grad_output_c,
ctx.fp8_meta["scaling_bwd"].scale_inv,
_GRAD_OUTPUT,
fp8_dtype_backward,
- torch.split(dgrad, ctx.m_splits),
+ [dgrad],
ctx.activation_dtype,
get_multi_stream_cublas_workspace(),
+ m_splits=ctx.m_splits,
use_split_accumulator=_2X_ACC_DGRAD,
)
else:
@@ -396,8 +413,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
inp._data if isinstance(inp, Float8Tensor) else inp
for inp in inputmats_t
],
- fwd_scale_inverses,
- _GEMM_INPUT,
+ [inputmat_scale_inv],
+ 0,
fp8_dtype_forward,
grad_output_t,
ctx.fp8_meta["scaling_bwd"].scale_inv,
diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py
index bd6e27594d..958c7019ba 100644
--- a/transformer_engine/pytorch/transformer.py
+++ b/transformer_engine/pytorch/transformer.py
@@ -503,7 +503,13 @@ def set_context_parallel_group(
cuda stream for context parallel execution.
cp_comm_type : str
inter-gpu communication type for context parallelism.
- Can be "p2p" or "all_gather".
+ Can be "p2p" or "all_gather" or "a2a".
+ "p2p": Exchange KV chunks with P2P communications in ring topology.
+ P2P is async and can be overlapped with attention compute.
+ "all_gather": All-gather to get full sequence of KV before attention.
+ The all-gather is blocking and cannot be overlapped with compute.
+ "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP
+ group, and gather to get full sequence of QKV.
"""
# Deep iterate but skip self to avoid infinite recursion.
for index, child in enumerate(self.modules()):
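
For the "a2a" option documented in the docstring above, the head-scatter/sequence-gather step is conceptually an all-to-all over the CP group, as in DeepSpeed Ulysses. A rough sketch, not the library's implementation: it assumes torch.distributed is initialized, the sequence is sharded contiguously across CP ranks, and the head count divides evenly by the CP size.

# Illustrative sketch only: Ulysses-style A2A that trades a sequence shard for a head shard.
import torch
import torch.distributed as dist

def seq_to_head_all_to_all(x: torch.Tensor, cp_group) -> torch.Tensor:
    """[s/cp, b, h, d] per rank -> [s, b, h/cp, d] per rank."""
    cp = dist.get_world_size(cp_group)
    s_local, b, h, d = x.shape
    assert h % cp == 0
    # Split heads into cp groups and put the destination-rank dimension first.
    x = x.view(s_local, b, cp, h // cp, d).permute(2, 0, 1, 3, 4).contiguous()
    out = torch.empty_like(x)
    # Each rank keeps one head group and receives the other ranks' sequence chunks.
    dist.all_to_all_single(out, x, group=cp_group)
    # Received chunks are ordered by source rank, i.e. sequence order under contiguous sharding.
    return out.view(cp * s_local, b, h // cp, d)
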