Commit

Merge branch 'main' into te_llama_tutorial_enhancement
sudhakarsingh27 authored Sep 16, 2024
2 parents 674e499 + af5daa0 commit bc9d706
Showing 26 changed files with 2,264 additions and 1,544 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
@@ -31,7 +31,7 @@ jobs:
run: |
sudo apt-get update
sudo apt-get install pip -y
pip install torch
pip install torch numpy
export PYTHON_ONLY=1
export TE_PATH=.
bash ./qa/L0_pytorch_lint/test.sh
1 change: 1 addition & 0 deletions .github/workflows/trigger-ci.yml
@@ -34,6 +34,7 @@ jobs:
|| github.actor == 'phu0ngng'
|| github.actor == 'xrennvidia'
|| github.actor == 'yaox12'
|| github.actor == 'huanghua1994'
)
steps:
- name: Check if comment is issued by authorized person
63 changes: 53 additions & 10 deletions tests/jax/test_fused_attn.py
@@ -29,7 +29,10 @@
get_qkv_format,
)
from transformer_engine.jax.cpp_extensions import FusedAttnHelper
from transformer_engine.transformer_engine_jax import NVTE_Fused_Attn_Backend
from transformer_engine.transformer_engine_jax import (
NVTE_Fused_Attn_Backend,
get_cudnn_version,
)

from utils import assert_allclose

@@ -230,7 +233,14 @@ def customcall_fused_dpa(
kwargs.pop("max_segments_per_seq")
return fused_attn(qkv_args, bias, mask, dropout_rng, **kwargs).astype(query.dtype)
return fused_attn_thd(
qkv_args, bias, seqlens_q, seqlens_kv, offsets_q, offsets_kv, dropout_rng, **kwargs
qkv_args,
bias,
seqlens_q,
seqlens_kv,
offsets_q,
offsets_kv,
dropout_rng,
**kwargs,
).astype(query.dtype)


@@ -265,6 +275,15 @@ class FusedAttnRunner:
qkv_layout: QKVLayout
bias_shape: BiasShape

# See https://docs.nvidia.com/deeplearning/cudnn/latest/release-notes.html#cudnn-9-4-0 for a known issue
# with generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases.
def _get_max_segments_per_sequence(self):
if 90400 <= get_cudnn_version() < 90500:
return self.num_segments_per_seq
else:
# +1 for testing runtime_segments < max_segments
return self.num_segments_per_seq + 1

def _check_configs(self):
# TODO(rewang): probably adds this in is_fused_attn_available
if get_qkv_format(self.qkv_layout) == QKVFormat.THD and not self.attn_mask_type in [
@@ -299,7 +318,10 @@ def _check_configs(self):
self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS
and self.bias_shape != BiasShape.BIAS_1HSS
):
if self.attn_mask_type not in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
if self.attn_mask_type not in [
AttnMaskType.NO_MASK,
AttnMaskType.CAUSAL_MASK,
]:
pytest.skip(
"B1SS, BHSS and 11SS bias shapes are only supported for "
"AttnMaskType.NO_MASK and AttnMaskType.CAUSAL_MASK."
@@ -316,7 +338,12 @@ def _setup_inputs(self):
q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)

q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim)
k_shape = v_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim)
k_shape = v_shape = (
self.batch_size,
self.max_seqlen_kv,
self.num_heads_kv,
self.head_dim,
)

if self.attn_bias_type == AttnBiasType.NO_BIAS:
bias_shape = None
@@ -325,7 +352,12 @@
elif self.bias_shape == BiasShape.BIAS_B1SS:
bias_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv)
elif self.bias_shape == BiasShape.BIAS_BHSS:
bias_shape = (self.batch_size, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
bias_shape = (
self.batch_size,
self.num_heads_q,
self.max_seqlen_q,
self.max_seqlen_kv,
)
elif self.bias_shape == BiasShape.BIAS_11SS:
bias_shape = (1, 1, self.max_seqlen_q, self.max_seqlen_kv)
else:
@@ -405,7 +437,10 @@ def generate_random_segment_ids(
self.segment_pad_kv = self.segment_pad_q
else:
self.token_kv, self.segment_pad_kv = generate_random_segment_ids(
self.batch_size, self.max_seqlen_kv, self.num_segments_per_seq, seed=2024
self.batch_size,
self.max_seqlen_kv,
self.num_segments_per_seq,
seed=2024,
)
self.pad_q = self.segment_pad_q
self.pad_kv = self.segment_pad_kv
@@ -464,8 +499,7 @@ def test_forward(self):
"dropout_probability": self.dropout_prob,
"is_training": self.is_training,
"qkv_layout": self.qkv_layout,
# +1 for testing runtime_segments < max_segments
"max_segments_per_seq": self.num_segments_per_seq + 1,
"max_segments_per_seq": self._get_max_segments_per_sequence(),
}

# Convert the outputs to float32 for the elementwise comparison
@@ -522,7 +556,7 @@ def grad_func(func, *args, **kwargs):
"dropout_probability": self.dropout_prob,
"is_training": self.is_training,
"qkv_layout": self.qkv_layout,
"max_segments_per_seq": self.num_segments_per_seq + 1,
"max_segments_per_seq": self._get_max_segments_per_sequence(),
}

# We can compute dBias only for the [1, h, s, s] layout
@@ -635,7 +669,16 @@ def check_dqkv(primitive, reference, pad):
pytest.param(4, 128, 128, 16, 16, 64, jnp.float16, id="4-128-128-16-16-64-FP16-SELF"),
pytest.param(2, 2048, 2048, 12, 12, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-BF16-SELF"),
pytest.param(4, 512, 128, 16, 16, 64, jnp.bfloat16, id="4-512-128-16-16-64-BF16-CROSS"),
pytest.param(2, 2048, 1024, 12, 12, 64, jnp.bfloat16, id="2-2048-1048-12-12-64-BF16-CROSS"),
pytest.param(
2,
2048,
1024,
12,
12,
64,
jnp.bfloat16,
id="2-2048-1048-12-12-64-BF16-CROSS",
),
pytest.param(4, 128, 128, 16, 8, 64, jnp.bfloat16, id="4-128-128-16-8-64-BF16-GQA"),
pytest.param(2, 2048, 2048, 12, 6, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-BF16-GQA"),
],
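
The _get_max_segments_per_sequence helper added above drops the usual +1 headroom segment when running on cuDNN 9.4.x, where zero-length ragged tensors are a known issue. Below is a minimal standalone sketch of the same version-gating pattern, assuming get_cudnn_version() reports the version as an integer (e.g. 90400 for cuDNN 9.4.0, as the 90400 <= ... < 90500 comparison in the diff suggests):

from transformer_engine.transformer_engine_jax import get_cudnn_version

def max_segments_per_sequence(num_segments_per_seq: int) -> int:
    # cuDNN 9.4.x has a known issue with zero-length ragged tensors, so use the
    # exact segment count there instead of adding headroom.
    if 90400 <= get_cudnn_version() < 90500:
        return num_segments_per_seq
    # Otherwise keep one spare segment so the test exercises
    # runtime_segments < max_segments.
    return num_segments_per_seq + 1
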
6 changes: 3 additions & 3 deletions tests/pytorch/fused_attn/test_fused_attn.py
@@ -233,9 +233,9 @@ def test_dot_product_attention(
"""Test DotProductAttention module"""

# Get configs
tols = dict(atol=5e-3, rtol=5e-3)
tols = dict(atol=1e-3, rtol=1e-3)
if dtype == torch.bfloat16:
tols = dict(atol=2.5e-2, rtol=2.5e-2)
tols = dict(atol=1.5e-2, rtol=1.5e-2)
config = model_configs[model]
is_mla = config.head_dim_qk != config.head_dim_v
if qkv_layout is None:
@@ -1035,7 +1035,7 @@ def test_transformer_layer(

# Get configs
config = model_configs[model]
tols = dict(atol=5e-1, rtol=5e-2)
tols = dict(atol=5e-2, rtol=5e-2)
workspace_opt = True

# Test backend availability
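
The tolerance changes above tighten the comparison bounds for test_dot_product_attention (1e-3 for FP16/FP32, 1.5e-2 for BF16) and for test_transformer_layer (atol 5e-2). A hedged sketch of how such dtype-dependent tolerances can be selected and applied; the helper name here is illustrative, not part of the test file:

import torch

def comparison_tols(dtype: torch.dtype) -> dict:
    # BF16 carries fewer mantissa bits than FP16/FP32, so it gets looser bounds.
    if dtype == torch.bfloat16:
        return dict(atol=1.5e-2, rtol=1.5e-2)
    return dict(atol=1e-3, rtol=1e-3)

# Example usage with two attention outputs `out` and `ref_out`:
# torch.testing.assert_close(out, ref_out, **comparison_tols(out.dtype))
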
110 changes: 64 additions & 46 deletions tests/pytorch/fused_attn/test_fused_attn_with_cp.py
@@ -22,10 +22,16 @@
"cp_1_2": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
), # MHA
"cp_2_0": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
"cp_2_1": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
"cp_1_3": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias", window_size=(512, 512)
), # MHA
"cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
"cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
"cp_2_2": ModelConfig(
2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
), # GQA
"cp_2_3": ModelConfig(
2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias", window_size=(512, 512)
), # GQA
}

@@ -45,31 +51,32 @@ def get_bash_arguments(**kwargs):
@pytest.mark.parametrize("dtype", ["bf16", "fp16"])
@pytest.mark.parametrize("model", model_configs_flash_attn.keys())
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather"])
@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a"])
def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
config = model_configs_flash_attn[model]
if cp_comm_type == "p2p" and config.window_size != (-1, 0) and config.window_size != (-1, -1):
pytest.skip("CP implementation with KV P2P does not support sliding window yet!")
if cp_comm_type == "all_gather" and qkv_format == "thd":
pytest.skip(
f"CP implementation with KV all-gather does not support {qkv_format} format yet!"
)
if cp_comm_type == "all_gather" and "causal" not in config.attn_mask_type:
pytest.skip(
f"CP implementation with KV all-gather does not support {config.attn_mask_type} mask"
" type yet!"
)
pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
pytest.skip("CP implementation with KV all-gather does not support bias yet!")
if cp_comm_type == "a2a" and qkv_format == "thd":
pytest.skip("CP implementation with QKVO A2A does not support THD format yet!")
if cp_comm_type == "a2a" and config.attn_bias_type != "no_bias":
pytest.skip("CP implementation with QKVO A2A does not support bias yet!")
if cp_comm_type == "a2a" and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0):
pytest.skip(
f"CP implementation with KV all-gather does not support {config.attn_bias_type} bias"
" type yet!"
)
if cp_comm_type == "p2p" and config.window_size != (-1, 0) and config.window_size != (-1, -1):
pytest.skip(
f"CP implementation with KV P2P does not support window size {config.window_size} yet!"
f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and"
f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!"
)

subprocess.run(
get_bash_arguments(
dtype=dtype, model=model, qkv_format=qkv_format, kernel_backend="FlashAttention"
dtype=dtype,
model=model,
qkv_format=qkv_format,
kernel_backend="FlashAttention",
cp_comm_type=cp_comm_type,
),
check=True,
)
@@ -81,10 +88,16 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
"cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA
"cp_1_2": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # MHA
"cp_1_3": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # MHA
"cp_2_0": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
"cp_2_1": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
"cp_2_2": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # GQA
"cp_2_3": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # GQA
"cp_1_4": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
), # MHA
"cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
"cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
"cp_2_2": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # GQA
"cp_2_3": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # GQA
"cp_2_4": ModelConfig(
2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
), # GQA
}


@@ -93,48 +106,53 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
@pytest.mark.parametrize("dtype", ["bf16", "fp16", "fp8"])
@pytest.mark.parametrize("model", model_configs_fused_attn.keys())
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather"])
@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a"])
def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type):
if qkv_format == "thd" and get_device_compute_capability() < (9, 0):
pytest.skip("THD format is only supported on sm90+.")
pytest.skip("THD format is only supported on sm90+!")
if cp_comm_type == "all_gather" and get_cudnn_version() < (9, 3, 0):
pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0")
pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0!")

config = model_configs_fused_attn[model]
if qkv_format == "thd" and config.num_heads != config.num_gqa_groups:
pytest.skip(f"{qkv_format} format does not support QGA/MQA yet!")
pytest.skip("THD format does not support QGA/MQA yet!")
if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias":
pytest.skip(f"{qkv_format} format does not support {config.attn_bias_type} bias type yet!")
if cp_comm_type == "all_gather" and qkv_format == "thd":
pytest.skip(
f"CP implementation with KV all-gather does not support {qkv_format} format yet!"
)
if cp_comm_type == "all_gather" and "causal" not in config.attn_mask_type:
pytest.skip(
f"CP implementation with KV all-gather does not support {config.attn_mask_type} mask"
" type yet!"
)
if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
pytest.skip(
f"CP implementation with KV all-gather does not support {config.attn_bias_type} bias"
" type yet!"
)
if config.window_size != (-1, 0) and config.window_size != (-1, -1):
pytest.skip("THD format does not support post_scale_bias yet!")
if qkv_format == "thd" and cp_comm_type == "all_gather":
pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
if qkv_format == "thd" and cp_comm_type == "a2a":
pytest.skip("CP implementation with QKVO A2A does not support THD format yet!")
if config.window_size != (-1, 0) and config.window_size != (-1, -1) and cp_comm_type != "a2a":
pytest.skip(
"Fused attention does not support sliding window attention + context parallelism yet!"
"Sliding window attention only can be supported with the implementation of QKVO A2A!"
)
if cp_comm_type == "all_gather" and dtype == "fp8":
if dtype == "fp8" and cp_comm_type == "all_gather":
pytest.skip(
"CP implementation with KV all-gather does not support FP8 + context parallelism yet!"
)
if dtype == "fp8" and qkv_format == "thd":
pytest.skip("FP8 attention cannot work with THD format yet!")
if dtype == "fp8" and config.attn_bias_type != "no_bias":
pytest.skip("FP8 attention cannot work with bias yet!")
if dtype == "fp8" and config.window_size != (-1, 0) and config.window_size != (-1, -1):
pytest.skip("FP8 attention cannot work with sliding window yet!")
if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
pytest.skip("CP implementation with KV all-gather does not support bias yet!")
if cp_comm_type == "a2a" and config.attn_bias_type != "no_bias":
pytest.skip("CP implementation with QKVO A2A does not support bias yet!")
if cp_comm_type == "a2a" and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0):
pytest.skip(
f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and"
f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!"
)

subprocess.run(
get_bash_arguments(
dtype=dtype, model=model, qkv_format=qkv_format, kernel_backend="FusedAttention"
dtype=dtype,
model=model,
qkv_format=qkv_format,
kernel_backend="FusedAttention",
cp_comm_type=cp_comm_type,
),
check=True,
)
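
Several of the new skips above encode the requirement that the QKVO A2A (all-to-all) context-parallel path shards attention heads across ranks, so both num_heads and num_gqa_groups must divide evenly by the context-parallel size (2 in these tests). A hedged sketch of that check; the helper name is illustrative:

def a2a_head_sharding_ok(num_heads: int, num_gqa_groups: int, cp_size: int = 2) -> bool:
    # QKVO A2A redistributes heads across context-parallel ranks, so every rank
    # must receive a whole number of query heads and of GQA (KV) groups.
    return num_heads % cp_size == 0 and num_gqa_groups % cp_size == 0

# e.g. the GQA configs above (12 query heads, 2 GQA groups) pass for cp_size=2,
# while odd head counts would be skipped by the new checks.
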
9 changes: 4 additions & 5 deletions tests/pytorch/test_float8tensor.py
@@ -293,7 +293,7 @@ def test_transpose(
with pytest.raises(AssertionError):
torch.testing.assert_close(x_fp8_t, x, **tols)

# Caching test.
# Caching test
assert x_fp8._transpose_invalid, "Transpose cache must be invalid when not caching."
x_fp8 += 0.5
x = x_fp8.from_float8()
Expand All @@ -302,14 +302,13 @@ def test_transpose(
torch.testing.assert_close(x_fp8_t, x_t, **tols)
assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly."

# Inplace update test.
# Inplace update test
x_fp8 += 0.5
assert x_fp8._transpose_invalid, "Transpose cache not invalidated properly."
assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly."
x = x_fp8.from_float8()
x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8.transpose_2d(fill_cache=True))
x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8._transpose)
x_t = x.transpose(0, 1)
torch.testing.assert_close(x_fp8_t, x_t, **tols)
assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly."

def test_serialization(
self,
5 changes: 1 addition & 4 deletions tests/pytorch/test_fusible_ops.py
@@ -88,10 +88,7 @@ def make_reference_and_test_tensors(
ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)
test = ref.to(device=test_device, dtype=test_dtype)
if test_is_fp8:
test = Float8Tensor.to_float8(test)
test._transpose = test._data.reshape(-1, test.size(-1)).transpose(0, 1)
test._transpose = test._transpose.contiguous()
test._transpose_invalid = False
test = Float8Tensor.to_float8(test, with_transpose_cache=True)
elif test.data_ptr() == ref.data_ptr():
test = test.clone()
ref.copy_(test)
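
The change above replaces manual population of the _transpose / _transpose_invalid fields with the with_transpose_cache=True argument to Float8Tensor.to_float8. A minimal usage sketch, assuming Float8Tensor is importable from transformer_engine.pytorch.float8_tensor (the import path is not shown in this diff):

import torch
from transformer_engine.pytorch.float8_tensor import Float8Tensor  # assumed import path

x = torch.rand(32, 64, dtype=torch.float32, device="cuda")

# New style: ask to_float8 to build the transpose cache up front instead of
# filling x_fp8._transpose and x_fp8._transpose_invalid by hand.
x_fp8 = Float8Tensor.to_float8(x, with_transpose_cache=True)
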
Diffs for the remaining changed files are not shown.
