feat(pytorch): Allow TransformerLayer and MultiheadAttention to accept sequence length parameters (#1066)

* Added the ability to pass sequence length parameters to the TransformerLayer and MultiheadAttention layers

Signed-off-by: Lukasz Pierscieniewski <[email protected]>

* Documentation for new parameters

Signed-off-by: Lukasz Pierscieniewski <[email protected]>

* Add tests for the THD layout; assert that the THD layout is not used together with the KV-cache

Signed-off-by: Lukasz Pierscieniewski <[email protected]>

* Fixed tests

Signed-off-by: Lukasz Pierscieniewski <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Move the THD logic into the shape calculation; add missing Optional annotations to the parameters

Signed-off-by: Lukasz Pierscieniewski <[email protected]>

* Skip the THD test on GPUs older than Ampere

Signed-off-by: Przemek Tredak <[email protected]>

---------

Signed-off-by: Lukasz Pierscieniewski <[email protected]>
Signed-off-by: Przemek Tredak <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
Co-authored-by: Przemek Tredak <[email protected]>
4 people committed Aug 20, 2024
1 parent ee541e8 commit 5d5fe81
Showing 4 changed files with 102 additions and 12 deletions.
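For orientation (not part of the commit), a minimal usage sketch of the new interface, modeled on the THD test added below. The layer sizes, sequence lengths, and dtype are illustrative, and it assumes a Transformer Engine install with an Ampere-or-newer GPU and a fused/flash attention backend available:

import torch
from transformer_engine.pytorch import TransformerLayer

# Two variable-length sequences packed into one token dimension (THD layout).
seq_lens = [5, 3]                        # illustrative per-sequence token counts
total_tokens = sum(seq_lens)
hidden_size, ffn_hidden_size, num_heads = 256, 1024, 4

layer = TransformerLayer(
    hidden_size,
    ffn_hidden_size,
    num_heads,
    hidden_dropout=0.0,
    attention_dropout=0.0,
    params_dtype=torch.bfloat16,
    device="cuda",
    attn_input_format="thd",             # packed [total_tokens, hidden_size] input
    self_attn_mask_type="padding_causal",
)

x = torch.randn(total_tokens, hidden_size, dtype=torch.bfloat16, device="cuda")

# Cumulative sequence lengths [0, 5, 8]: int32, one entry per sequence boundary.
cu_seqlens = torch.zeros(len(seq_lens) + 1, dtype=torch.int32, device="cuda")
cu_seqlens[1:] = torch.cumsum(torch.tensor(seq_lens, device="cuda"), dim=0)

y = layer(
    x,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_kv=cu_seqlens,
    max_seqlen_q=max(seq_lens),
    max_seqlen_kv=max(seq_lens),
)                                        # y keeps the packed shape [total_tokens, hidden_size]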
47 changes: 45 additions & 2 deletions tests/pytorch/test_numerics.py
@@ -34,11 +34,13 @@
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import fp8_gemm, fp8_grouped_gemm, gemm, grouped_gemm
from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
from transformer_engine.pytorch.utils import get_device_compute_capability
import transformer_engine_torch as tex

# Only run FP8 tests on H100.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()

sm_80plus = get_device_compute_capability() >= (8, 0)

seed = 1234
torch.manual_seed(seed)
@@ -1548,8 +1550,29 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
attn_input_format="bshd",
)

for (n1, p1), (n2, p2) in zip(block_bshd.named_parameters(), block_sbhd.named_parameters()):
assert torch.all(torch.eq(p1, p2)), f"{n1}, {n2} not identical"
torch.manual_seed(0)
block_thd = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0,
attention_dropout=0,
kv_channels=config.embed,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
device="cuda",
attn_input_format="thd",
self_attn_mask_type="padding_causal",
)

for (n1, p1), (n2, p2), (n3, p3) in zip(
block_bshd.named_parameters(), block_sbhd.named_parameters(), block_thd.named_parameters()
):
assert torch.all(torch.eq(p1, p2) & torch.eq(p1, p3)), f"{n1}, {n2} and {n3} not identical"

x_sbhd = torch.randn(
(config.seq_len, bs, config.hidden_size),
@@ -1559,6 +1582,8 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
)

x_bshd = x_sbhd.transpose(0, 1).contiguous()
x_thd = x_bshd.reshape(bs * config.seq_len, config.hidden_size).contiguous()
x_thd_cumsum = torch.arange(bs + 1, device="cuda", dtype=torch.int32) * config.seq_len

# To make sure forward is also identical (just in case some module decides
# to act fancy)
@@ -1576,6 +1601,24 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
y_sbhd.transpose(0, 1).contiguous(),
)

# THD is not supported in float32 and on GPUs older than Ampere, skip the test here
if dtype != torch.float32 and sm_80plus:
# To make sure forward is also identical (just in case some module decides
# to act fancy)
torch.manual_seed(0)
y_thd = block_thd(
x_thd,
cu_seqlens_q=x_thd_cumsum,
cu_seqlens_kv=x_thd_cumsum,
max_seqlen_q=config.seq_len,
max_seqlen_kv=config.seq_len,
)

torch.testing.assert_close(
y_bshd,
y_thd.reshape(bs, config.seq_len, config.hidden_size).contiguous(),
)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
44 changes: 35 additions & 9 deletions transformer_engine/pytorch/attention.py
@@ -7048,6 +7048,10 @@ def forward(
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None,
alibi_slopes: Optional[torch.Tensor] = None,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_kv: Optional[int] = None,
fast_zero_fill: bool = True,
) -> Tuple[Union[torch.Tensor, None], ...]:
"""
@@ -7113,6 +7117,18 @@
ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
to the attention score of query i and key j.
cu_seqlens_q: Optional[torch.Tensor], default = `None`
Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`,
with shape [batch_size + 1] and dtype torch.int32.
cu_seqlens_kv: Optional[torch.Tensor], default = `None`
Cumulative sum of sequence lengths (without offset) in a batch for `key_layer`
and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
max_seqlen_q: Optional[int], default = `None`
Maximum sequence length in `query_layer`.
Calculated from `cu_seqlens_q` if not provided.
max_seqlen_kv: Optional[int], default = `None`
Maximum sequence length in `key_layer` and `value_layer`.
Calculated from `cu_seqlens_kv` if not provided.
fast_zero_fill: bool, default = `True`
Whether to set output tensors to 0 or not before use.
"""
@@ -7139,6 +7155,9 @@ def forward(
# =================================================

if inference_params and self.layer_number is not None:
assert (
self.qkv_format != "thd"
), "qkv_format == thd is not supported for inference with KV-cache!"
if self.layer_number not in inference_params.key_value_memory_dict:
inf_max_seq_len = inference_params.max_sequence_length
inf_max_batch_size = inference_params.max_batch_size
@@ -7221,13 +7240,18 @@ def forward(
dim=split_dim,
)

# query: -> [sq, b, np, hn]
# key, value: -> [sq, b, ng, hn]
query_layer, key_layer, value_layer = (
x.reshape(x.size(0), x.size(1), -1, self.hidden_size_per_attention_head)
for x in (query_layer, key_layer, value_layer)
)

if self.qkv_format == "thd":
query_layer, key_layer, value_layer = (
x.reshape(x.size(0), -1, self.hidden_size_per_attention_head)
for x in (query_layer, key_layer, value_layer)
)
else:
# query: -> [sq, b, np, hn]
# key, value: -> [sq, b, ng, hn]
query_layer, key_layer, value_layer = (
x.reshape(x.size(0), x.size(1), -1, self.hidden_size_per_attention_head)
for x in (query_layer, key_layer, value_layer)
)
elif self.attention_type == "cross":
# Attention heads [sk, b, h] --> [sk, b, (ng * 2 * hn)]
mixed_kv_layer = self.key_value(
@@ -7341,8 +7365,10 @@ def forward(
key_layer,
value_layer,
qkv_format=self.qkv_format,
cu_seqlens_q=None,
cu_seqlens_kv=None,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
attention_mask=attention_mask,
attn_mask_type=attn_mask_type,
window_size=window_size,
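For reference (not part of the commit), a toy sketch of the per-branch reshape added above: the thd path keeps a single packed token dimension, while the default path keeps separate sequence and batch dimensions. The head count and channel size below are placeholders:

import torch

heads, head_dim = 3, 16

# thd: [total_tokens, heads * head_dim] -> [total_tokens, heads, head_dim]
q_thd = torch.randn(8, heads * head_dim)
assert q_thd.reshape(q_thd.size(0), -1, head_dim).shape == (8, heads, head_dim)

# sbhd (default): [seq, batch, heads * head_dim] -> [seq, batch, heads, head_dim]
q_sbhd = torch.randn(4, 2, heads * head_dim)
assert q_sbhd.reshape(q_sbhd.size(0), q_sbhd.size(1), -1, head_dim).shape == (4, 2, heads, head_dim)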
3 changes: 2 additions & 1 deletion transformer_engine/pytorch/module/layernorm_mlp.py
@@ -13,6 +13,7 @@

from .base import (
get_workspace,
_ub_communicators,
get_ub,
TransformerEngineBaseModule,
_2X_ACC_FPROP,
@@ -1297,7 +1298,7 @@ def __init__(
self.gemm_gelu_fusion = (
bool(int(os.getenv("NVTE_GEMM_GELU_FUSION", "0")))
and self.activation == "gelu"
and not get_ub("fc1_fprop").is_atomic_gemm()
and ((_ub_communicators is None) or (not get_ub("fc1_fprop").is_atomic_gemm()))
)

if tp_group is None:
20 changes: 20 additions & 0 deletions transformer_engine/pytorch/transformer.py
@@ -529,6 +529,10 @@ def forward(
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None,
alibi_slopes: Optional[torch.Tensor] = None,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_kv: Optional[int] = None,
fast_zero_fill: bool = True,
) -> torch.Tensor:
"""
@@ -604,6 +608,18 @@ def forward(
ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
to the attention score of query i and key j.
cu_seqlens_q: Optional[torch.Tensor], default = `None`
Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`,
with shape [batch_size + 1] and dtype torch.int32.
cu_seqlens_kv: Optional[torch.Tensor], default = `None`
Cumulative sum of sequence lengths (without offset) in a batch for `key_layer`
and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
max_seqlen_q: Optional[int], default = `None`
Maximum sequence length in `query_layer`.
Calculated from `cu_seqlens_q` if not provided.
max_seqlen_kv: Optional[int], default = `None`
Maximum sequence length in `key_layer` and `value_layer`.
Calculated from `cu_seqlens_kv` if not provided.
fast_zero_fill: bool, default = `True`
Whether to set output tensors to 0 or not before use.
inference_params: InferenceParams, default = None
@@ -664,6 +680,10 @@ def forward(
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
fast_zero_fill=fast_zero_fill,
)

