Add a CP implementation variant with KV all-gather. (#1060)
* add window_size to AttnFuncWithCP

Signed-off-by: Xiaowei Ren <[email protected]>

* add seq_offsets_qkvo for cudnn thd

Signed-off-by: Xiaowei Ren <[email protected]>

* add seq_offsets_qkvo to AttnFuncWithCP

Signed-off-by: Xiaowei Ren <[email protected]>

* fix seq_offsets calculation of cudnn thd

Signed-off-by: Xiaowei Ren <[email protected]>

* remove a thd assert

Signed-off-by: Xiaowei Ren <[email protected]>

* fix bias for thd test

Signed-off-by: Xiaowei Ren <[email protected]>

* add thd test for cudnn FA with CP

Signed-off-by: Xiaowei Ren <[email protected]>

* skip GQA/MQA test for cuDNN THD

Signed-off-by: Xiaowei Ren <[email protected]>

* make sure seq_offsets are computed with qkv_group of hd_hd_hd when CP > 1

Signed-off-by: Xiaowei Ren <[email protected]>

* fix seq_offsets inputs

Signed-off-by: Xiaowei Ren <[email protected]>

* remove two comments

Signed-off-by: Xiaowei Ren <[email protected]>

* fix attn mask type for cudnn thd with cp

Signed-off-by: Xiaowei Ren <[email protected]>

* fix attn_mask_type check

Signed-off-by: Xiaowei Ren <[email protected]>

* fix attn_mask_type for cudnn fa with thd

Signed-off-by: Xiaowei Ren <[email protected]>

* fix a typo

Signed-off-by: Xiaowei Ren <[email protected]>

* fix out and dout in bwd

Signed-off-by: Xiaowei Ren <[email protected]>

* assert cudnn+thd does not support attn bias

Signed-off-by: Xiaowei Ren <[email protected]>

* check if attn_mask_type has padding

Signed-off-by: Xiaowei Ren <[email protected]>

* minor change

Signed-off-by: Xiaowei Ren <[email protected]>

* change cp test batch size to 2

Signed-off-by: Xiaowei Ren <[email protected]>

* fix code format

Signed-off-by: Xiaowei Ren <[email protected]>

* fix two assert info

Signed-off-by: Xiaowei Ren <[email protected]>

* fix assert comment

Signed-off-by: Xiaowei Ren <[email protected]>

* fix assert comments

Signed-off-by: Xiaowei Ren <[email protected]>

* minor fix

Signed-off-by: Xiaowei Ren <[email protected]>

* fix assert comments

Signed-off-by: Xiaowei Ren <[email protected]>

* assert swa+CP cannot work with thd format

Signed-off-by: Xiaowei Ren <[email protected]>

* add a new CP function for swa

Signed-off-by: Xiaowei Ren <[email protected]>

* add a missing dgrad

Signed-off-by: Xiaowei Ren <[email protected]>

* minor change

Signed-off-by: Xiaowei Ren <[email protected]>

* add draft fwd function for swa+cp

Signed-off-by: Xiaowei Ren <[email protected]>

* minor change

Signed-off-by: Xiaowei Ren <[email protected]>

* enable flash attention for swa+cp

Signed-off-by: Xiaowei Ren <[email protected]>

* remove an assert of swa+cp

Signed-off-by: Xiaowei Ren <[email protected]>

* call SWAFuncWithCP for swa+cp

Signed-off-by: Xiaowei Ren <[email protected]>

* typo fix

Signed-off-by: Xiaowei Ren <[email protected]>

* use 2hd layout

Signed-off-by: Xiaowei Ren <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* change qkv_format check

Signed-off-by: Xiaowei Ren <[email protected]>

* add a code comment

Signed-off-by: Xiaowei Ren <[email protected]>

* tensor shape bug fix

Signed-off-by: Xiaowei Ren <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* tensor shape fix

Signed-off-by: Xiaowei Ren <[email protected]>

* add function to compute cu_seqlens of a cp rank

Signed-off-by: Xiaowei Ren <[email protected]>

* add cu_seqlens and cu_seqlens_padded to context parallelism

Signed-off-by: Xiaowei Ren <[email protected]>

* typo fix

Signed-off-by: Xiaowei Ren <[email protected]>

* minor change

Signed-off-by: Xiaowei Ren <[email protected]>

* fix FlashAttention output sequence length

Signed-off-by: Xiaowei Ren <[email protected]>

* fix cu_seqlens_kv_per_step calculation

Signed-off-by: Xiaowei Ren <[email protected]>

* zero dQKV for ending padded tokens

Signed-off-by: Xiaowei Ren <[email protected]>

* zero dQKV tensors of FlashAttention

Signed-off-by: Xiaowei Ren <[email protected]>

* fix softmax_lse correction

Signed-off-by: Xiaowei Ren <[email protected]>

* remove padded tokens of KV to save communication

Signed-off-by: Xiaowei Ren <[email protected]>

* do not need to zero dKV for FlashAttention anymore

Signed-off-by: Xiaowei Ren <[email protected]>

* zero out tensors

Signed-off-by: Xiaowei Ren <[email protected]>

* remove redundant code

Signed-off-by: Xiaowei Ren <[email protected]>

* fix CP unit test

Signed-off-by: Xiaowei Ren <[email protected]>

* fix kv shape of cp test with thd format

Signed-off-by: Xiaowei Ren <[email protected]>

* update cp unit test

Signed-off-by: Xiaowei Ren <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add simple code framework

Signed-off-by: Xiaowei Ren <[email protected]>

* try not to have a separate CP function for SWA

Signed-off-by: Xiaowei Ren <[email protected]>

* backup some code change

Signed-off-by: Xiaowei Ren <[email protected]>

* back up code

Signed-off-by: Xiaowei Ren <[email protected]>

* clean up fwd implementation of SWAFuncWithCP

Signed-off-by: Xiaowei Ren <[email protected]>

* remove redundant code

Signed-off-by: Xiaowei Ren <[email protected]>

* code cleaning

Signed-off-by: Xiaowei Ren <[email protected]>

* fix assert info

Signed-off-by: Xiaowei Ren <[email protected]>

* reduce kv chunk concat overheads

Signed-off-by: Xiaowei Ren <[email protected]>

* minor change

Signed-off-by: Xiaowei Ren <[email protected]>

* make AttnFuncWithCP and SWAFuncWithCP have same API

Signed-off-by: Xiaowei Ren <[email protected]>

* add a docstring

Signed-off-by: Xiaowei Ren <[email protected]>

* preliminary implementation of SWAFuncWithCP forward seems working

Signed-off-by: Xiaowei Ren <[email protected]>

* fix output shape of SWAFuncWithCP

Signed-off-by: Xiaowei Ren <[email protected]>

* code refactoring for FlashAttention and add a code placeholder for bwd

Signed-off-by: Xiaowei Ren <[email protected]>

* use gather_along_first_dim

Signed-off-by: Xiaowei Ren <[email protected]>

* finish the preliminary implementation of bwd

Signed-off-by: Xiaowei Ren <[email protected]>

* remove redundant code

Signed-off-by: Xiaowei Ren <[email protected]>

* fix assert condition

Signed-off-by: Xiaowei Ren <[email protected]>

* add draft implementation of SWA+CP with FusedAttention

Signed-off-by: Xiaowei Ren <[email protected]>

* fix attention mask type of swa+cp

Signed-off-by: Xiaowei Ren <[email protected]>

* code cleaning

Signed-off-by: Xiaowei Ren <[email protected]>

* add qkv_layout

Signed-off-by: Xiaowei Ren <[email protected]>

* add missing window_size argument

Signed-off-by: Xiaowei Ren <[email protected]>

* typo fix

Signed-off-by: Xiaowei Ren <[email protected]>

* fix kv shape of swa+cp

Signed-off-by: Xiaowei Ren <[email protected]>

* bug and typo fix

Signed-off-by: Xiaowei Ren <[email protected]>

* fix dout shape

Signed-off-by: Xiaowei Ren <[email protected]>

* add multi stream in fwd of swa+cp

Signed-off-by: Xiaowei Ren <[email protected]>

* save chunk_ids_to_kv_ag in fwd

Signed-off-by: Xiaowei Ren <[email protected]>

* add multi stream in bwd of swa+cp

Signed-off-by: Xiaowei Ren <[email protected]>

* minor fix to cp stream sync

Signed-off-by: Xiaowei Ren <[email protected]>

* rename AttnFuncWithCP

Signed-off-by: Xiaowei Ren <[email protected]>

* check if window size is None

Signed-off-by: Xiaowei Ren <[email protected]>

* fix docstring of AttnFuncWithCP

Signed-off-by: Xiaowei Ren <[email protected]>

* minor fix

Signed-off-by: Xiaowei Ren <[email protected]>

* add env var for users to choose KV ag or KV p2p

Signed-off-by: Xiaowei Ren <[email protected]>

* update cp tests

Signed-off-by: Xiaowei Ren <[email protected]>

* fix window size in cp unit test

Signed-off-by: Xiaowei Ren <[email protected]>

* fix pytest skip messages

Signed-off-by: Xiaowei Ren <[email protected]>

* add cp_comm_type into API

Signed-off-by: Xiaowei Ren <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* code cleaning

Signed-off-by: Xiaowei Ren <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* assert sequence length divisible requirements

Signed-off-by: Xiaowei Ren <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add support table of context parallelism

Signed-off-by: Xiaowei Ren <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* typo and code format fix

Signed-off-by: Xiaowei Ren <[email protected]>

* do not print multiple disabling messages

Signed-off-by: Xiaowei Ren <[email protected]>

* bug fix

Signed-off-by: Xiaowei Ren <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix device in torch.arange and adjust code for the PR of MLA

Signed-off-by: Xiaowei Ren <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typos and clean asserts

Signed-off-by: Xiaowei Ren <[email protected]>

---------

Signed-off-by: Xiaowei Ren <[email protected]>
Co-authored-by: Charlene Yang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Xiaowei Ren <[email protected]>
4 people authored Aug 16, 2024
1 parent 941364d commit 3040785
Showing 4 changed files with 736 additions and 87 deletions.
68 changes: 41 additions & 27 deletions tests/pytorch/fused_attn/run_fused_attn_with_cp.py
@@ -13,7 +13,9 @@
dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16}


def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend="FlashAttention"):
def run_dpa_with_cp(
dtype="bf16", model=None, qkv_format="bshd", kernel_backend="FlashAttention", cp_comm_type="p2p"
):
"""Test DotProductAttention module with context parallelism"""

os.environ["NVTE_FLASH_ATTN"] = "0"
@@ -24,10 +26,16 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend=
if kernel_backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
config = model_configs_fused_attn[model]
if qkv_format == "thd" and (
config.num_heads != config.num_gqa_groups or config.attn_bias_type == "post_scale_bias"
):
return

assert config.attn_mask_type in [
"causal",
"no_mask",
], f"{config.attn_mask_type} is an unsupported attention mask type!"
if kernel_backend == "FusedAttention" and qkv_format == "thd":
if "causal" in config.attn_mask_type:
config.attn_mask_type = "padding_causal"
else:
config.attn_mask_type = "padding"

rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
@@ -49,73 +57,77 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend=
assert rank in cp_comm_ranks
cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl")

assert config.attn_mask_type in [
"causal",
"no_mask",
], f"{config.attn_mask_type} is an unsupported attention mask type!"

if kernel_backend == "FusedAttention" and qkv_format == "thd":
if "causal" in config.attn_mask_type:
config.attn_mask_type = "padding_causal"
else:
config.attn_mask_type = "padding"

# instantiate core attn module
core_attn = DotProductAttention(
config.num_heads,
config.head_dim,
config.head_dim_qk,
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
qkv_format=qkv_format,
attn_mask_type=config.attn_mask_type,
window_size=config.window_size,
)
core_attn = core_attn.cuda()

# create flash attn inputs
if qkv_format == "bshd":
q_input_shape = (config.batch_size, config.max_seqlen_q, config.num_heads, config.head_dim)
q_input_shape = (
config.batch_size,
config.max_seqlen_q,
config.num_heads,
config.head_dim_qk,
)
kv_input_shape = (
config.batch_size,
config.max_seqlen_kv,
config.num_gqa_groups,
config.head_dim,
config.head_dim_qk,
)
attn_output_shape = (
config.batch_size,
config.max_seqlen_q,
config.num_heads * config.head_dim,
config.num_heads * config.head_dim_qk,
)
cu_seqlens_q = None
cu_seqlens_kv = None
cu_seqlens_q_padded = None
cu_seqlens_kv_padded = None
elif qkv_format == "sbhd":
q_input_shape = (config.max_seqlen_q, config.batch_size, config.num_heads, config.head_dim)
q_input_shape = (
config.max_seqlen_q,
config.batch_size,
config.num_heads,
config.head_dim_qk,
)
kv_input_shape = (
config.max_seqlen_kv,
config.batch_size,
config.num_gqa_groups,
config.head_dim,
config.head_dim_qk,
)
attn_output_shape = (
config.max_seqlen_q,
config.batch_size,
config.num_heads * config.head_dim,
config.num_heads * config.head_dim_qk,
)
cu_seqlens_q = None
cu_seqlens_kv = None
cu_seqlens_q_padded = None
cu_seqlens_kv_padded = None
elif qkv_format == "thd":
q_input_shape = (config.batch_size * config.max_seqlen_q, config.num_heads, config.head_dim)
q_input_shape = (
config.batch_size * config.max_seqlen_q,
config.num_heads,
config.head_dim_qk,
)
kv_input_shape = (
config.batch_size * config.max_seqlen_q,
config.num_gqa_groups,
config.head_dim,
config.head_dim_qk,
)
attn_output_shape = (
config.batch_size * config.max_seqlen_q,
config.num_heads * config.head_dim,
config.num_heads * config.head_dim_qk,
)
seqlens_q = torch.randint(0, config.max_seqlen_q + 1, [config.batch_size]).to(torch.int32)
seqlens_q_padded = (seqlens_q + 2 * world_size - 1) // (world_size * 2) * (world_size * 2)
@@ -211,7 +223,9 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend=
)
bias_ = bias_.index_select(2, seq_idx)
bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1])
core_attn.set_context_parallel_group(cp_comm_group, cp_comm_ranks, torch.cuda.Stream())
core_attn.set_context_parallel_group(
cp_comm_group, cp_comm_ranks, torch.cuda.Stream(), cp_comm_type
)
out_ = core_attn(
q_,
k_,
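The hunk above is where the new cp_comm_type argument reaches DotProductAttention.set_context_parallel_group. A minimal sketch of how a caller might opt into the KV all-gather variant follows, assuming an NCCL process group can be initialized (e.g. under torchrun); the keyword names mirror the ones visible in this diff, while the concrete sizes are illustrative and not taken from the commit.

```python
import torch
import torch.distributed as dist
from transformer_engine.pytorch import DotProductAttention

# Assumes a launcher (e.g. torchrun) has set RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT,
# and a single node so that the global rank doubles as the local device index.
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())

# Use all ranks as one context-parallel group.
cp_comm_ranks = list(range(dist.get_world_size()))
cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl")

# Illustrative sizes; the keyword arguments follow the diff above.
core_attn = DotProductAttention(
    12,    # attention heads
    128,   # head dimension
    num_gqa_groups=12,
    attention_dropout=0.0,
    qkv_format="bshd",
    attn_mask_type="causal",
).cuda()

# Same positional call as in the test driver: the last argument selects the CP
# variant, "all_gather" for the new KV all-gather path or "p2p" for the existing
# P2P/ring exchange.
core_attn.set_context_parallel_group(
    cp_comm_group, cp_comm_ranks, torch.cuda.Stream(), "all_gather"
)
```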
63 changes: 59 additions & 4 deletions tests/pytorch/fused_attn/test_fused_attn_with_cp.py
@@ -16,11 +16,17 @@
)

model_configs_flash_attn = {
# test: b, h, hg, d, sq, skv, p, mask, bias
# test: b, h, hg, d, sq, skv, p, mask, bias
"cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA
"cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA
"cp_1_2": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
), # MHA
"cp_2_0": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
"cp_2_1": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
"cp_2_2": ModelConfig(
2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
), # GQA
}


@@ -39,7 +45,28 @@ def get_bash_arguments(**kwargs):
@pytest.mark.parametrize("dtype", ["bf16", "fp16"])
@pytest.mark.parametrize("model", model_configs_flash_attn.keys())
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
def test_cp_with_flash_attention(dtype, model, qkv_format):
@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather"])
def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
config = model_configs_flash_attn[model]
if cp_comm_type == "all_gather" and qkv_format == "thd":
pytest.skip(
f"CP implementation with KV all-gather does not support {qkv_format} format yet!"
)
if cp_comm_type == "all_gather" and "causal" not in config.attn_mask_type:
pytest.skip(
f"CP implementation with KV all-gather does not support {config.attn_mask_type} mask"
" type yet!"
)
if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
pytest.skip(
f"CP implementation with KV all-gather does not support {config.attn_bias_type} bias"
" type yet!"
)
if cp_comm_type == "p2p" and config.window_size != (-1, 0) and config.window_size != (-1, -1):
pytest.skip(
f"CP implementation with KV P2P does not support window size {config.window_size} yet!"
)

subprocess.run(
get_bash_arguments(
dtype=dtype, model=model, qkv_format=qkv_format, kernel_backend="FlashAttention"
@@ -49,7 +76,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format):


model_configs_fused_attn = {
# test: b, h, hg, d, sq, skv, p, mask, bias
# test: b, h, hg, d, sq, skv, p, mask, bias
"cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA
"cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA
"cp_1_2": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # MHA
@@ -66,9 +93,37 @@
@pytest.mark.parametrize("dtype", ["bf16", "fp16"])
@pytest.mark.parametrize("model", model_configs_fused_attn.keys())
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
def test_cp_with_fused_attention(dtype, model, qkv_format):
@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather"])
def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type):
if qkv_format == "thd" and get_device_compute_capability() < (9, 0):
pytest.skip("THD format is only supported on sm90+.")
if cp_comm_type == "all_gather" and get_cudnn_version() < (9, 3, 0):
pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0")

config = model_configs_fused_attn[model]
if qkv_format == "thd" and config.num_heads != config.num_gqa_groups:
pytest.skip(f"{qkv_format} format does not support QGA/MQA yet!")
if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias":
pytest.skip(f"{qkv_format} format does not support {config.attn_bias_type} bias type yet!")
if cp_comm_type == "all_gather" and qkv_format == "thd":
pytest.skip(
f"CP implementation with KV all-gather does not support {qkv_format} format yet!"
)
if cp_comm_type == "all_gather" and "causal" not in config.attn_mask_type:
pytest.skip(
f"CP implementation with KV all-gather does not support {config.attn_mask_type} mask"
" type yet!"
)
if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
pytest.skip(
f"CP implementation with KV all-gather does not support {config.attn_bias_type} bias"
" type yet!"
)
if config.window_size != (-1, 0) and config.window_size != (-1, -1):
pytest.skip(
f"Fused attention does not support sliding window attention + context parallelism yet!"
)

subprocess.run(
get_bash_arguments(
dtype=dtype, model=model, qkv_format=qkv_format, kernel_backend="FusedAttention"
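For background on the seqlens_q_padded rounding to a multiple of 2 * world_size in the first file, and on the "assert sequence length divisible requirements" step in the commit message: this CP code path splits each padded sequence into 2 * CP chunks and gives rank r the chunks r and (2 * CP - 1 - r), which balances causal-attention work across ranks. The sketch below only illustrates that slicing; the helper names are hypothetical and are not the functions added by this commit.

```python
import torch


def cp_chunk_indices(cp_size: int, rank: int) -> tuple[int, int]:
    """Chunk ids owned by `rank` when each sequence is split into 2 * cp_size chunks."""
    return rank, 2 * cp_size - 1 - rank


def slice_for_cp_rank(x: torch.Tensor, cp_size: int, rank: int, seq_dim: int = 1) -> torch.Tensor:
    """Take the two chunks of a [b, s, h, d]-style tensor that this CP rank processes.

    Assumes the sequence length is already padded to a multiple of 2 * cp_size,
    mirroring the seqlens_q_padded rounding in run_fused_attn_with_cp.py.
    """
    seq_len = x.size(seq_dim)
    assert seq_len % (2 * cp_size) == 0, "sequence length must be divisible by 2 * cp_size"
    chunks = x.chunk(2 * cp_size, dim=seq_dim)
    i, j = cp_chunk_indices(cp_size, rank)
    return torch.cat([chunks[i], chunks[j]], dim=seq_dim)


# Example: 2 CP ranks, sequence length 8 -> rank 0 keeps chunks 0 and 3.
x = torch.arange(8).view(1, 8, 1, 1)
print(slice_for_cp_rank(x, cp_size=2, rank=0).flatten().tolist())  # [0, 1, 6, 7]
```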
