Merge branch 'NVIDIA:main' into mgoldfarb-nvidia/context_parallel_attention_with_all_gather
mgoldfarb-nvidia committed Sep 9, 2024
2 parents d3c9d06 + 40dda92 commit 5c6bc6c
Showing 11 changed files with 1,110 additions and 356 deletions.
15 changes: 12 additions & 3 deletions docs/_templates/layout.html
@@ -1,4 +1,11 @@
{% extends "!layout.html" %}

{% block extrahead %}

<script src="https://assets.adobedtm.com/5d4962a43b79/c1061d2c5e7b/launch-191c2462b890.min.js"></script>

{% endblock %}

{% block sidebartitle %} {{ super() }}

<style>
@@ -83,8 +90,10 @@
}
</style>

{%- if nvidia_analytics_id %}
<script type="text/javascript">_satellite.pageBottom();</script>
{%- endif %}
{% endblock %}

{% block footer %}

<script type="text/javascript">if (typeof _satellite !== "undefined"){ _satellite.pageBottom();}</script>

{% endblock %}
110 changes: 64 additions & 46 deletions tests/pytorch/fused_attn/test_fused_attn_with_cp.py
@@ -22,10 +22,16 @@
"cp_1_2": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
), # MHA
"cp_2_0": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
"cp_2_1": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
"cp_1_3": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias", window_size=(512, 512)
), # MHA
"cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
"cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
"cp_2_2": ModelConfig(
2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
), # GQA
"cp_2_3": ModelConfig(
2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias", window_size=(512, 512)
), # GQA
}

@@ -45,31 +51,32 @@ def get_bash_arguments(**kwargs):
@pytest.mark.parametrize("dtype", ["bf16", "fp16"])
@pytest.mark.parametrize("model", model_configs_flash_attn.keys())
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather"])
@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a"])
def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
config = model_configs_flash_attn[model]
if cp_comm_type == "p2p" and config.window_size != (-1, 0) and config.window_size != (-1, -1):
pytest.skip("CP implementation with KV P2P does not support sliding window yet!")
if cp_comm_type == "all_gather" and qkv_format == "thd":
pytest.skip(
f"CP implementation with KV all-gather does not support {qkv_format} format yet!"
)
if cp_comm_type == "all_gather" and "causal" not in config.attn_mask_type:
pytest.skip(
f"CP implementation with KV all-gather does not support {config.attn_mask_type} mask"
" type yet!"
)
pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
pytest.skip("CP implementation with KV all-gather does not support bias yet!")
if cp_comm_type == "a2a" and qkv_format == "thd":
pytest.skip("CP implementation with QKVO A2A does not support THD format yet!")
if cp_comm_type == "a2a" and config.attn_bias_type != "no_bias":
pytest.skip("CP implementation with QKVO A2A does not support bias yet!")
if cp_comm_type == "a2a" and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0):
pytest.skip(
f"CP implementation with KV all-gather does not support {config.attn_bias_type} bias"
" type yet!"
)
if cp_comm_type == "p2p" and config.window_size != (-1, 0) and config.window_size != (-1, -1):
pytest.skip(
f"CP implementation with KV P2P does not support window size {config.window_size} yet!"
f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and"
f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!"
)

subprocess.run(
get_bash_arguments(
dtype=dtype, model=model, qkv_format=qkv_format, kernel_backend="FlashAttention"
dtype=dtype,
model=model,
qkv_format=qkv_format,
kernel_backend="FlashAttention",
cp_comm_type=cp_comm_type,
),
check=True,
)
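
The new a2a skip above encodes a head-count constraint: with QKVO all-to-all, attention heads are resharded across context-parallel ranks, so num_heads and num_gqa_groups must both divide evenly by cp_size (2 in these tests). The following is a minimal sketch, not part of this commit and using illustrative shapes, of the head split that makes the divisibility requirement concrete:

# Sketch only: illustrates the head sharding that the QKVO A2A path relies on.
# Shapes and names here are illustrative, not TransformerEngine internals.
import torch

cp_size = 2          # context-parallel world size used by these tests
batch, seq, num_heads, head_dim = 2, 4096, 12, 128

q = torch.randn(batch, seq, num_heads, head_dim)  # bshd layout

# The all-to-all trades sequence sharding for head sharding: each rank ends up
# with num_heads // cp_size full heads over the whole sequence, which only
# works if the heads split evenly.
assert num_heads % cp_size == 0, "a2a needs num_heads divisible by cp_size"
per_rank_heads = torch.chunk(q, cp_size, dim=2)
print([t.shape for t in per_rank_heads])  # 2 tensors of shape (2, 4096, 6, 128)

The same argument applies to the key/value heads, which is presumably why the cp_2_* configs move from 1 to 2 GQA groups in this commit.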
@@ -81,10 +88,16 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
"cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA
"cp_1_2": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # MHA
"cp_1_3": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # MHA
"cp_2_0": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
"cp_2_1": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
"cp_2_2": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # GQA
"cp_2_3": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # GQA
"cp_1_4": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
), # MHA
"cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
"cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
"cp_2_2": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # GQA
"cp_2_3": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # GQA
"cp_2_4": ModelConfig(
2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
), # GQA
}


@@ -93,48 +106,53 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
@pytest.mark.parametrize("dtype", ["bf16", "fp16", "fp8"])
@pytest.mark.parametrize("model", model_configs_fused_attn.keys())
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather"])
@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a"])
def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type):
if qkv_format == "thd" and get_device_compute_capability() < (9, 0):
pytest.skip("THD format is only supported on sm90+.")
pytest.skip("THD format is only supported on sm90+!")
if cp_comm_type == "all_gather" and get_cudnn_version() < (9, 3, 0):
pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0")
pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0!")

config = model_configs_fused_attn[model]
if qkv_format == "thd" and config.num_heads != config.num_gqa_groups:
pytest.skip(f"{qkv_format} format does not support QGA/MQA yet!")
pytest.skip("THD format does not support QGA/MQA yet!")
if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias":
pytest.skip(f"{qkv_format} format does not support {config.attn_bias_type} bias type yet!")
if cp_comm_type == "all_gather" and qkv_format == "thd":
pytest.skip(
f"CP implementation with KV all-gather does not support {qkv_format} format yet!"
)
if cp_comm_type == "all_gather" and "causal" not in config.attn_mask_type:
pytest.skip(
f"CP implementation with KV all-gather does not support {config.attn_mask_type} mask"
" type yet!"
)
if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
pytest.skip(
f"CP implementation with KV all-gather does not support {config.attn_bias_type} bias"
" type yet!"
)
if config.window_size != (-1, 0) and config.window_size != (-1, -1):
pytest.skip("THD format does not support post_scale_bias yet!")
if qkv_format == "thd" and cp_comm_type == "all_gather":
pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
if qkv_format == "thd" and cp_comm_type == "a2a":
pytest.skip("CP implementation with QKVO A2A does not support THD format yet!")
if config.window_size != (-1, 0) and config.window_size != (-1, -1) and cp_comm_type != "a2a":
pytest.skip(
"Fused attention does not support sliding window attention + context parallelism yet!"
"Sliding window attention only can be supported with the implementation of QKVO A2A!"
)
if cp_comm_type == "all_gather" and dtype == "fp8":
if dtype == "fp8" and cp_comm_type == "all_gather":
pytest.skip(
"CP implementation with KV all-gather does not support FP8 + context parallelism yet!"
)
if dtype == "fp8" and qkv_format == "thd":
pytest.skip("FP8 attention cannot work with THD format yet!")
if dtype == "fp8" and config.attn_bias_type != "no_bias":
pytest.skip("FP8 attention cannot work with bias yet!")
if dtype == "fp8" and config.window_size != (-1, 0) and config.window_size != (-1, -1):
pytest.skip("FP8 attention cannot work with sliding window yet!")
if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
pytest.skip("CP implementation with KV all-gather does not support bias yet!")
if cp_comm_type == "a2a" and config.attn_bias_type != "no_bias":
pytest.skip("CP implementation with QKVO A2A does not support bias yet!")
if cp_comm_type == "a2a" and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0):
pytest.skip(
f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and"
f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!"
)

subprocess.run(
get_bash_arguments(
dtype=dtype, model=model, qkv_format=qkv_format, kernel_backend="FusedAttention"
dtype=dtype,
model=model,
qkv_format=qkv_format,
kernel_backend="FusedAttention",
cp_comm_type=cp_comm_type,
),
check=True,
)
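
Both config tables gain sliding-window entries (window_size=(512, 0) and (512, 512)), and the fused-attention test now routes sliding-window cases through the a2a path only. As a point of reference, here is a small sketch, not part of this commit, of the mask such a tuple describes, assuming the usual (left, right) convention in which -1 means unbounded, (-1, 0) is plain causal, and (512, 0) is causal attention restricted to the 512 most recent keys:

# Sketch only: builds the boolean mask implied by window_size=(left, right)
# under the assumed convention described above.
import torch

def sliding_window_mask(seq_len, window_size):
    left, right = window_size
    q_idx = torch.arange(seq_len).unsqueeze(1)   # query positions, column vector
    k_idx = torch.arange(seq_len).unsqueeze(0)   # key positions, row vector
    allowed = torch.ones(seq_len, seq_len, dtype=torch.bool)
    if left >= 0:
        allowed &= k_idx >= q_idx - left         # limit how far back a query may look
    if right >= 0:
        allowed &= k_idx <= q_idx + right        # limit how far ahead; 0 keeps it causal
    return allowed                               # True where attention is permitted

mask = sliding_window_mask(8, (2, 0))  # small stand-in for (512, 0)
print(mask.int())                      # banded lower-triangular pattern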
31 changes: 23 additions & 8 deletions tests/pytorch/test_numerics.py
@@ -1266,12 +1266,15 @@ def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False
)
inp_hidden_states.retain_grad()

m = config.seq_len // 16
dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
dist.append(dist[-1]) # Manually add a zero
m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist)
m_splits = m_splits * 16
assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms
if num_gemms > 1:
m = config.seq_len // 16
dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
dist.append(dist[-1]) # Manually add a zero
m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist)
m_splits = m_splits * 16
assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms
else:
m_splits = torch.tensor([config.seq_len])

with fp8_autocast(enabled=fp8):
if isinstance(block, GroupedLinear):
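
The guard added to _test_grouped_linear_accuracy lets num_gemms == 1 fall back to a single full-length split (exercised by the new test_grouped_linear_accuracy_single_gemm below). The following standalone sketch, with small illustrative values rather than the test's real configuration, reproduces the construction and shows what m_splits holds:

# Sketch only: standalone version of the m_splits construction above, with a
# fixed seed and small illustrative values (seq_len=256, num_gemms=4).
import torch

torch.manual_seed(0)
seq_len, num_gemms = 256, 4

if num_gemms > 1:
    m = seq_len // 16                     # number of 16-token blocks to distribute
    dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
    dist.append(dist[-1])                 # duplicated cut point forces one zero-sized split
    m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist)
    m_splits = m_splits * 16              # back to token counts
    assert m_splits.sum() == seq_len and len(m_splits) == num_gemms
else:
    m_splits = torch.tensor([seq_len])    # single GEMM gets the whole sequence

print(m_splits)  # four multiples of 16 summing to 256, exactly one of them zero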
@@ -1353,7 +1356,7 @@ def test_grouped_linear_accuracy(

@pytest.mark.parametrize("parallel_mode", ["column", "row"])
def test_grouped_linear_accuracy_parallel_mode(parallel_mode):
"""Split the tests to reduce CI time"""
"""Split the tests to save CI time"""
test_grouped_linear_accuracy(
dtype=torch.float32,
num_gemms=6,
@@ -1365,6 +1368,18 @@ def test_grouped_linear_accuracy_parallel_mode(parallel_mode):
)


def test_grouped_linear_accuracy_single_gemm():
"""Split the tests to save CI time"""
test_grouped_linear_accuracy(
dtype=torch.float32,
num_gemms=1,
bs=2,
model=list(model_configs.keys())[0],
fp8=True,
fp8_model_params=True,
)


def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False):

def _pad_tensor_for_fp8(hidden_states, tokens_per_expert):
@@ -2034,7 +2049,7 @@ def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate):

fp8_grouped_gemm(
A_fp8,
scale_inv,
[scale_inv],
0, # A_offset
tex.DType.kFloat8E4M3,
B_fp8,