From 5194eb4b1618e377c2f0348a6acc0666c35380c8 Mon Sep 17 00:00:00 2001 From: Michael Goldfarb Date: Mon, 12 Aug 2024 16:42:52 +0000 Subject: [PATCH] Implementation of context parallel fused attention using all-gather. Signed-off-by: Michael Goldfarb --- tests/jax/distributed_test_base.py | 24 + tests/jax/test_distributed_fused_attn.py | 245 +++++- transformer_engine/jax/attention.py | 119 ++- .../jax/cpp_extensions/attention.py | 772 ++++++++++++------ .../jax/csrc/extensions/pybind.cpp | 5 +- transformer_engine/jax/sharding.py | 39 +- 6 files changed, 942 insertions(+), 262 deletions(-) diff --git a/tests/jax/distributed_test_base.py b/tests/jax/distributed_test_base.py index 3a7fe33378..bbd54ecce5 100644 --- a/tests/jax/distributed_test_base.py +++ b/tests/jax/distributed_test_base.py @@ -4,6 +4,8 @@ import operator import re from functools import reduce +from itertools import product +import pytest import jax from jax.experimental.pjit import pjit, _UNSPECIFIED @@ -29,6 +31,28 @@ def generate_configs(): return configs +def generate_context_parallel_configs(): + configs = [] + + DP_sizes = (1, 2) + CP_sizes = (1, 2, 4, 8) + TP_sizes = (1, 2) + for dp, cp, tp in product(DP_sizes, CP_sizes, TP_sizes): + ndev = cp * tp * dp + if is_devices_enough(ndev): + configs.append( + pytest.param( + ndev, + (dp, cp, tp), + ("dp", "cp", "tp"), + MeshResource(dp_resource="dp", cp_resource="cp", tp_resource="tp"), + id=f"n{ndev}_dp{dp}_cp{cp}_tp{tp}", + ) + ) + + return configs + + COLL_AR_KEY = "all-reduce" COLL_AG_KEY = "all-gather" COLL_OTHER_KEY = "other" diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index 15676dd270..61d68aacae 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -3,6 +3,7 @@ # See LICENSE for license information. 
import pytest +from functools import partial import jax import jax.numpy as jnp @@ -10,8 +11,13 @@ from flax.linen import dot_product_attention from jax import random from jax.sharding import Mesh, NamedSharding, PartitionSpec -from distributed_test_base import generate_configs, generate_collectives_count, compare_ops -from utils import make_causal_mask, make_self_mask +from distributed_test_base import ( + generate_configs, + generate_context_parallel_configs, + generate_collectives_count, + compare_ops, +) +from utils import make_causal_mask, make_self_mask, assert_tree_like_allclose, assert_allclose from transformer_engine.jax import fp8_autocast from transformer_engine.jax.attention import ( is_fused_attn_kernel_available, @@ -19,6 +25,10 @@ AttnBiasType, AttnMaskType, QKVLayout, + QKVFormat, + get_qkv_format, + reorder_causal_load_balancing, + inverse_reorder_causal_load_balancing, ) @@ -263,7 +273,8 @@ def target_func(q, kv, mask): scaling_factor=scaling_factor, dropout_probability=dropout_prob, is_training=is_training, - ) + ), + dtype=jnp.float32, ) def ref_func(query, kv, mask): @@ -284,7 +295,7 @@ def ref_func(query, kv, mask): dtype=jnp.float32, ) - return jnp.mean(output).astype(dtype) + return jnp.mean(output, dtype=jnp.float32) (q, kv, mask), (q_pspec, kv_pspec, mask_pspec) = self.generate_inputs( data_shape, mesh_resource, attn_mask_type, dtype @@ -310,3 +321,229 @@ def ref_func(query, kv, mask): in_shardings=(q_pspec, kv_pspec, mask_pspec), out_shardings=(None, (q_pspec, kv_pspec)), ) + + +class TestDistributedContexParallelSelfAttn: + + def generate_inputs(self, shape, kv_groups: int, attn_mask_type: AttnMaskType, dtype): + batch, seqlen, heads, hidden = shape + qkey, kkey, vkey = random.split(random.PRNGKey(1124), 3) + q = random.normal(qkey, shape, dtype=dtype) + k = random.normal(kkey, (batch, seqlen, heads // kv_groups, hidden), dtype=dtype) + v = random.normal(vkey, (batch, seqlen, heads // kv_groups, hidden), dtype=dtype) + + mask = None + if attn_mask_type == AttnMaskType.CAUSAL_MASK: + mask = make_causal_mask(batch, seqlen) + + return q, k, v, mask + + def qkv_to_layout(self, q, k, v, qkv_layout): + qkv_args = () + match qkv_layout: + case QKVLayout.BSHD_BS2HD: + k, v = map(partial(jnp.expand_dims, axis=-3), [k, v]) + kv = jnp.concatenate((k, v), axis=-3) + qkv_args = (q, kv) + case QKVLayout.BSHD_BSHD_BSHD: + qkv_args = (q, k, v) + case _: + raise ValueError(f"Unsupported {qkv_layout=}") + return qkv_args + + @pytest.mark.parametrize( + "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs() + ) + @pytest.mark.parametrize( + "data_shape", + [ + pytest.param([2, 512, 12, 128], id="2-512-12-128"), + pytest.param([4, 1024, 16, 64], id="4-1024-16-64"), + ], + ) + @pytest.mark.parametrize("kv_groups", [1, 4, 8, 12, 16]) + @pytest.mark.parametrize( + "attn_mask_type", + [ + pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"), + pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"), + ], + ) + @pytest.mark.parametrize("dtype", [jnp.bfloat16]) + @pytest.mark.parametrize( + "qkv_layout", + [ + pytest.param(QKVLayout.BSHD_BS2HD, id="COMBINED_KV"), + pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"), + ], + ) + @pytest.mark.parametrize( + "load_balanced", [pytest.param(False, id="UNBALANCED"), pytest.param(True, id="BALANCED")] + ) + def test_contex_parallel_self_attn( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + data_shape, + kv_groups, + attn_mask_type, + dtype, + qkv_layout, + load_balanced, + ): + attn_bias_type 
= AttnBiasType.NO_BIAS + dropout_prob = 0.0 + is_training = True + scaling_factor = 1.0 + dp_size, cp_size, tp_size = mesh_shape + qkv_format = get_qkv_format(qkv_layout) + + _, seqlen, num_head, hidden = data_shape + num_kv_heads = num_head // kv_groups + + # make sure the mesh evently divides cp and tp axis + if num_head % kv_groups != 0 or (num_head // kv_groups) % tp_size != 0: + pytest.skip(f"Skipping {kv_groups=} not multiple of {data_shape=} or {tp_size=}") + + def target_func(q, k, v, mask): + return jnp.mean( + fused_attn( + self.qkv_to_layout(q, k, v, qkv_layout), + bias=None, + mask=mask, + seed=None, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + qkv_layout=qkv_layout, + scaling_factor=scaling_factor, + dropout_probability=dropout_prob, + is_training=is_training, + context_parallel_causal_load_balanced=load_balanced, + ), + ).astype(dtype) + + def ref_func(q, k, v, mask, kv_groups): + q = jnp.squeeze(q) + k = jnp.squeeze(jnp.repeat(k, kv_groups, axis=2)) + v = jnp.squeeze(jnp.repeat(v, kv_groups, axis=2)) + output = dot_product_attention( + q, + k, + v, + bias=None, + mask=mask, + deterministic=is_training, + dropout_rate=dropout_prob, + dropout_rng=None, + dtype=jnp.float32, + ) + return jnp.mean(output).astype(dtype) + + q, k, v, mask = self.generate_inputs(data_shape, kv_groups, attn_mask_type, dtype) + + # Single GPU (reference) + ref_func_jit = jax.jit(jax.value_and_grad(ref_func, argnums=[0, 1, 2]), static_argnums=[4]) + ref_fwd, ref_grads = ref_func_jit(q, k, v, mask, kv_groups) + + # Multi GPU (function under test) + devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) + mesh = Mesh(devices, mesh_axes) + with mesh, fp8_autocast(mesh_resource=mesh_resource): + qkv_ps = PartitionSpec( + mesh_resource.dp_resource, + mesh_resource.cp_resource, + mesh_resource.tp_resource, + None, + ) + qkv_sharding = NamedSharding(mesh, qkv_ps) + + mask_ps = PartitionSpec( + mesh_resource.dp_resource, None, mesh_resource.cp_resource, None + ) + mask_sharding = NamedSharding(mesh, mask_ps) + + reorder = partial( + reorder_causal_load_balancing, cp_size=cp_size, tensor_format=qkv_format + ) + inverse_reorder = partial( + inverse_reorder_causal_load_balancing, cp_size=cp_size, tensor_format=qkv_format + ) + + if load_balanced: + q, k, v = jax.tree.map(reorder, (q, k, v)) + + q_, k_, v_ = map(partial(jax.device_put, device=qkv_sharding), [q, k, v]) + mask_ = jax.device_put(mask, device=mask_sharding) + + target_func_jit = jax.jit( + jax.value_and_grad(target_func, argnums=[0, 1, 2]), + in_shardings=[qkv_sharding, qkv_sharding, qkv_sharding, mask_sharding], + out_shardings=(None, (qkv_sharding, qkv_sharding, qkv_sharding)), + ) + + target_fwd, target_grads = target_func_jit(q_, k_, v_, mask_) + + if load_balanced: + target_dq, target_dk, target_dv = jax.tree.map(inverse_reorder, target_grads[0:3]) + target_grads = (target_dq, target_dk, target_dv, *target_grads[3:]) + + def _print_diffs(target, ref): + print("min: ", jnp.min(target), jnp.min(ref)) + print("max: ", jnp.max(target), jnp.max(ref)) + print("mean: ", jnp.mean(target), jnp.mean(ref)) + print("median: ", jnp.median(target), jnp.median(ref)) + print("std: ", jnp.std(target), jnp.std(ref)) + print("var: ", jnp.var(target), jnp.var(ref)) + print("max diff: ", jnp.max(jnp.abs(target - ref))) + + has_diffs = False + + try: + assert_allclose(target_fwd, ref_fwd, dtype=dtype) + except AssertionError as e: + has_diffs = True + print(f"target_fwd v. 
ref_fwd") + _print_diffs(target_fwd, ref_fwd) + + for i in range(len(target_grads)): + if ref_grads[i] is None or target_grads[i] is None: + # expect both none if one is + assert target_grads[i] is None and ref_grads[i] is None + else: + try: + assert_allclose(target_grads[i], ref_grads[i]) + except AssertionError as e: + has_diffs = True + print(f"target_grads[{i}] v. ref_grads[{i}]") + _print_diffs(target_grads[i], ref_grads[i]) + + assert has_diffs == False, "has_diffs != False" + + +class TestReorderCausalLoadBalancing: + @pytest.mark.parametrize("cp_size", [2, 4, 8]) + @pytest.mark.parametrize( + "shape", + [ + pytest.param([1, 16, 1, 1], id="1-16-1-1"), + pytest.param([4, 32, 12, 32], id="4-32-12-32"), + pytest.param([3, 32, 8, 64], id="3-32-8-64"), + ], + ) + @pytest.mark.parametrize("qkv_format", [QKVFormat.BSHD, QKVFormat.SBHD]) + def test(self, cp_size, shape, qkv_format): + tensor = random.normal(random.PRNGKey(1124), shape, dtype=jnp.bfloat16) + if qkv_format == QKVFormat.SBHD: + tensor = tensor.swapaxes(0, 1) + + ref = tensor.copy() + + reorder = jax.jit(reorder_causal_load_balancing, static_argnums=[1, 2]) + inverse = jax.jit(inverse_reorder_causal_load_balancing, static_argnums=[1, 2]) + + reordered = reorder(tensor, cp_size, qkv_format) + inversed = inverse(reordered, cp_size, qkv_format) + + assert jnp.array_equal(inversed, ref) diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index dcd860c3a4..9b8279be25 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -43,6 +43,8 @@ class AttnMaskType(Enum): PADDING_MASK = NVTE_Mask_Type.NVTE_PADDING_MASK CAUSAL_MASK = NVTE_Mask_Type.NVTE_CAUSAL_MASK PADDING_CAUSAL_MASK = NVTE_Mask_Type.NVTE_PADDING_CAUSAL_MASK + CAUSAL_BOTTOM_RIGHT_MASK = NVTE_Mask_Type.NVTE_CAUSAL_BOTTOM_RIGHT_MASK + PADDING_CAUSAL_BOTTOM_RIGHT_MASK = NVTE_Mask_Type.NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK class QKVLayout(Enum): @@ -97,11 +99,21 @@ def canonicalize_attn_mask_type(attn_mask_type: str): return AttnMaskType.PADDING_MASK case "causal": return AttnMaskType.CAUSAL_MASK + case "causal_bottom_right" | "bottom_right_causal": + return AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK case "padding_causal" | "causal_padding": return AttnMaskType.PADDING_CAUSAL_MASK + case ( + "padding_causal_bottom_right" + | "causal_padding_bottom_right" + | "bottom_right_causal_padding" + | "bottom_right_padding_causal" + ): + return AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK raise ValueError( - f"Unsupported {attn_mask_type=}, supported attn_mask_type=" - "{'no_mask', 'padding', 'causal', 'padding_causal', 'causal_padding'}" + f"Unsupported {attn_mask_type=}, supported attn_mask_type={{'no_mask', 'padding', 'causal'," + " 'padding_causal', 'causal_padding', 'causal_bottom_right'," + " 'padding_causal_bottom_right'}" ) @@ -155,6 +167,75 @@ def _obtain_batch_and_max_seqlen(qkv, qkv_layout): return batch, q_max_seqlen, kv_max_seqlen +def _reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat, inverse: bool): + match tensor_format: + case QKVFormat.SBHD: + seq_dim = 0 + case QKVFormat.BSHD: + seq_dim = 1 + case _: + raise ValueError(f"{tensor_format=} is not supported for causal load balancing.") + + if cp_size == 1: + return tensor + + if cp_size % 2 != 0: + raise ValueError(f"{cp_size=} must be a multiple of 2.") + + # Need to ensure we have 2 pairs to swap for balancing between cp ranks + if tensor.shape[seq_dim] % (cp_size * 2) != 0: + raise ValueError(f"{tensor.shape=} is not a 
multiple of {cp_size*2=}") + + # [B, S, H, D] -> [B, 2*cp_size, S/2*cp_size, D] + # [S, B, H, D] -> [2*cp_size, S/2*cp_size, B, H, D] + ori_tensor_shape = tensor.shape + tensor = tensor.reshape( + ( + *ori_tensor_shape[:seq_dim], + 2 * cp_size, + ori_tensor_shape[seq_dim] // (2 * cp_size), + *ori_tensor_shape[seq_dim + 1 :], + ) + ) + + parts = [] + if not inverse: + for cp_rank in range(cp_size): + # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] -> [B, 2, S/2*cp_size, H, D] + # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] -> [2, S/2*cp_size, B, H, D] + index = jnp.array([cp_rank, (2 * cp_size - cp_rank - 1)]) + parts.append(jnp.take(tensor, index, axis=seq_dim)) + else: + for cp_rank in range(cp_size // 2): + # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] -> [B, 2, S/2*cp_size, H, D] + # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] -> [2, S/2*cp_size, B, H, D] + base = 4 * cp_rank + index = jnp.array([base, base + 2]) + parts.append(jnp.take(tensor, index, axis=seq_dim)) + for cp_rank in range(cp_size // 2): + # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] -> [B, 2, S/2*cp_size, H, D] + # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] -> [2, S/2*cp_size, B, H, D] + base = 2 * cp_size - 1 - 4 * cp_rank + index = jnp.array([base, base - 2]) + parts.append(jnp.take(tensor, index, axis=seq_dim)) + + # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] + # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] + combined = jnp.stack(parts, axis=seq_dim) + + return combined.reshape(ori_tensor_shape) + + +def reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat): + """Reorders a tensor for load balancing the compute of causal attention.""" + return _reorder_causal_load_balancing(tensor, cp_size, tensor_format, False) + + +def inverse_reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat): + """Inverse operation of `reorder_causal_load_balancing`.""" + return _reorder_causal_load_balancing(tensor, cp_size, tensor_format, True) + + def fused_attn( qkv: Tuple[jnp.ndarray, ...], bias: Optional[jnp.ndarray], @@ -166,6 +247,8 @@ def fused_attn( scaling_factor: float, dropout_probability: float, is_training: bool, + context_parallel_causal_load_balanced: bool = False, + context_parallel_axis: str = "", ): """ Perform non-THD (non-packed) cuDNN fused attention. @@ -192,6 +275,9 @@ def fused_attn( scaling_factor (float): Scaling factor for the attention scores. dropout_probability (float): Dropout probability to apply during attention. is_training (bool): Flag indicating whether the model is in training mode. + context_parallel_causal_load_balanced (bool): + Indicates the sequences are ordered for causal mask load balancing when running context parallelism. + context_parallel_axis (str): The name of the context parallel axis. Returns: (jnp.ndarray): The output tensor from the fused attention. 
""" @@ -213,7 +299,11 @@ def fused_attn( ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}" # convert the mask to seqlens, mask doesn't support ragged offsets - if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]: + if attn_mask_type in [ + AttnMaskType.NO_MASK, + AttnMaskType.CAUSAL_MASK, + AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK, + ]: batch, q_max_seqlen, kv_max_seqlen = _obtain_batch_and_max_seqlen(qkv, qkv_layout) q_seq_lens = jnp.full((batch,), q_max_seqlen, dtype=jnp.int32) kv_seq_lens = jnp.full((batch,), kv_max_seqlen, dtype=jnp.int32) @@ -242,6 +332,8 @@ def fused_attn( dropout_probability=dropout_probability, is_training=is_training, max_segments_per_seq=1, + context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, + context_parallel_axis=context_parallel_axis, ) return output @@ -262,6 +354,8 @@ def fused_attn_thd( dropout_probability: float, is_training: bool, max_segments_per_seq: int = 1, + context_parallel_causal_load_balanced: bool = False, + context_parallel_axis: str = "", ): """ (Experimental) Perform THD (packed) cuDNN fused attention. @@ -300,6 +394,9 @@ def fused_attn_thd( Indicating the maximum number of segments inside a sequence. This parameter is to constrain the limit usage and need to be static during the e2e training. The XLA compile time and memory consumption is proportional to `max_segments_per_seq`. + context_parallel_causal_load_balanced (bool): + Indicates the sequences are ordered for causal mask load balancing when running context parallelism. + context_parallel_axis (str): The name of the context parallel axis. Returns: (jnp.ndarray): The output tensor from the fused attention. @@ -354,12 +451,14 @@ def fused_attn_thd( dropout_probability=dropout_probability, is_training=is_training, max_segments_per_seq=max_segments_per_seq, + context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, + context_parallel_axis=context_parallel_axis, ) return output -@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13)) +@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15)) def _fused_attn( qkv: Tuple[jnp.ndarray, ...], bias: Optional[jnp.ndarray], @@ -375,6 +474,8 @@ def _fused_attn( dropout_probability: float, is_training: bool, max_segments_per_seq: int, + context_parallel_causal_load_balanced: bool, + context_parallel_axis: str, ): output, _ = _fused_attn_fwd_rule( qkv, @@ -391,6 +492,8 @@ def _fused_attn( dropout_probability, is_training, max_segments_per_seq, + context_parallel_causal_load_balanced, + context_parallel_axis, ) return output @@ -410,6 +513,8 @@ def _fused_attn_fwd_rule( dropout_probability, is_training, max_segments_per_seq, + context_parallel_causal_load_balanced, + context_parallel_axis, ): output, softmax_aux, rng_state = tex.fused_attn_fwd( qkv, @@ -426,6 +531,8 @@ def _fused_attn_fwd_rule( dropout_probability=dropout_probability, is_training=is_training, max_segments_per_seq=max_segments_per_seq, + context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, + context_parallel_axis=context_parallel_axis, ) output = checkpoint_name(output, "context") softmax_aux = checkpoint_name(softmax_aux, "context") @@ -451,6 +558,8 @@ def _fused_attn_bwd_rule( dropout_probability, is_training, max_segments_per_seq, + context_parallel_causal_load_balanced, + context_parallel_axis, ctx, dz, ): @@ -483,6 +592,8 @@ def _fused_attn_bwd_rule( dropout_probability=dropout_probability, is_training=is_training, 
max_segments_per_seq=max_segments_per_seq, + context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, + context_parallel_axis=context_parallel_axis, ) if attn_bias_type == AttnBiasType.NO_BIAS: grad_bias = None diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 0cbf847dcd..d5b901c107 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -9,8 +9,9 @@ from typing import Optional, Tuple import warnings +import jax import jax.numpy as jnp -from jax import dtypes +from jax import dtypes, lax from jax.interpreters import mlir from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec, NamedSharding @@ -34,7 +35,11 @@ get_cudnn_version, ) from ..sharding import ( + global_mesh_resource, + lax_paral_op, all_reduce_sum_along_dp_fsdp, + get_mesh_axis_size, + get_mesh_axis_rank, get_all_mesh_axes, num_of_devices, ) @@ -47,6 +52,38 @@ ] +@partial( + jax.tree_util.register_dataclass, + data_fields=[], + meta_fields=[ + "attn_bias_type", + "attn_mask_type", + "qkv_layout", + "scaling_factor", + "dropout_probability", + "is_training", + "max_segments_per_seq", + "context_parallel_load_balanced", + "cp_axis", + ], +) +@dataclass(frozen=True) +class _FusedAttnConfig: + """ + Passes static configuration properties of fused attention. + """ + + attn_bias_type: NVTE_Bias_Type + attn_mask_type: NVTE_Mask_Type + qkv_layout: NVTE_QKV_Layout + scaling_factor: float + dropout_probability: float + is_training: bool + max_segments_per_seq: int + context_parallel_load_balanced: bool + cp_axis: str + + @dataclass(frozen=True) class FusedAttnHelper: """ @@ -178,7 +215,7 @@ class FusedAttnFwdPrimitive(BasePrimitive): name = "te_fused_attn_forward" multiple_results = True - impl_static_args = (9, 10, 11, 12, 13, 14, 15) + impl_static_args = (9,) inner_primitive = None outer_primitive = None @@ -194,13 +231,7 @@ def abstract( _k_seq_offsets, seed_aval, *, - attn_bias_type, - attn_mask_type, - qkv_layout, - scaling_factor, - dropout_probability, - is_training, - max_segments_per_seq, + config: _FusedAttnConfig, ): """ Fused attention fwd abstract @@ -213,7 +244,7 @@ def abstract( assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = ( - FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout) + FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout) ) output_shape = (*batch_shape, q_max_seqlen, attn_heads, head_dim) @@ -223,10 +254,10 @@ def abstract( backend = FusedAttnHelper( q_dtype, k_dtype, - qkv_layout, - attn_bias_type, - attn_mask_type, - dropout_probability, + config.qkv_layout, + config.attn_bias_type, + config.attn_mask_type, + config.dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, @@ -238,7 +269,7 @@ def abstract( softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen) softmax_dtype = q_dtype elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen: - softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, max_segments_per_seq) + softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, config.max_segments_per_seq) softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) else: raise ValueError(f"Unsupported {backend=}") @@ -252,7 +283,7 @@ def abstract( rng_state_shape = (seed_aval.shape[0], checker.rng_state_size) rng_state_aval = 
seed_aval.update(shape=rng_state_shape, dtype=checker.rng_state_dtype) - if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: + if config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: bias_batch = bias_heads = 0 else: *bias_batch_shape, bias_heads, _, _ = bias_aval.shape @@ -270,14 +301,14 @@ def abstract( num_gqa_groups, bias_heads, head_dim, - scaling_factor, - dropout_probability, - attn_bias_type, - attn_mask_type, - qkv_layout, + config.scaling_factor, + config.dropout_probability, + config.attn_bias_type, + config.attn_mask_type, + config.qkv_layout, jax_dtype_to_te_dtype(q_aval.dtype), - is_training, - max_segments_per_seq, + config.is_training, + config.max_segments_per_seq, ) wkspace_aval = q_aval.update( shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) @@ -308,28 +339,12 @@ def lowering( k_seq_offsets, seed, *, - attn_bias_type, - attn_mask_type, - qkv_layout, - scaling_factor, - dropout_probability, - is_training, - max_segments_per_seq, + config: _FusedAttnConfig, ): """ Fused attention fwd lowering rules """ - operands = [ - q, - k, - v, - bias, - q_cu_seqlen, - kv_cu_seqlen, - q_seq_offsets, - k_seq_offsets, - seed, - ] + operands = [q, k, v, bias, q_cu_seqlen, kv_cu_seqlen, q_seq_offsets, k_seq_offsets, seed] operand_shapes = map(lambda x: x.type.shape, operands) out_types = [ ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype)) @@ -340,12 +355,12 @@ def lowering( q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = ( - FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout) + FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout) ) input_batch = reduce(operator.mul, batch_shape) - if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: + if config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: bias_batch = bias_heads = 0 else: *bias_batch_shape, bias_heads, _, _ = bias_aval.shape @@ -362,16 +377,16 @@ def lowering( num_gqa_groups, bias_heads, head_dim, - max_segments_per_seq, + config.max_segments_per_seq, wkspace_aval.size, - scaling_factor, - dropout_probability, - attn_bias_type, - attn_mask_type, - qkv_layout, + config.scaling_factor, + config.dropout_probability, + config.attn_bias_type, + config.attn_mask_type, + config.qkv_layout, jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), - is_training, + config.is_training, not FusedAttnHelper.is_non_deterministic_allowed(), ) @@ -390,17 +405,11 @@ def impl( q_seq_offsets, k_seq_offsets, seed, - attn_bias_type, - attn_mask_type, - qkv_layout, - scaling_factor, - dropout_probability, - is_training, - max_segments_per_seq, + config: _FusedAttnConfig, ): assert FusedAttnFwdPrimitive.inner_primitive is not None - if nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format.NVTE_THD: + if nvte_get_qkv_format(config.qkv_layout) == NVTE_QKV_Format.NVTE_THD: def _fix_len_take(x, condition, fill_value=-1): x_shape = x.shape @@ -418,7 +427,7 @@ def convert_to_2d(offsets, batch, max_seqlen): ) return offsets_2d - match qkv_layout: + match config.qkv_layout: case NVTE_QKV_Layout.NVTE_T3HD: kv_max_seqlen = q_max_seqlen = q.shape[-4] kv_batch = q_batch = reduce(operator.mul, q.shape[:-4]) @@ -472,66 +481,27 @@ def convert_to_2d(offsets, batch, max_seqlen): q_seq_offsets, k_seq_offsets, seed, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - qkv_layout=qkv_layout, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - 
is_training=is_training, - max_segments_per_seq=max_segments_per_seq, + config=config, ) return output, softmax_aux, rng_state @staticmethod - def batcher( - batched_args, - batch_dims, - *, - attn_bias_type, - attn_mask_type, - qkv_layout, - scaling_factor, - dropout_probability, - is_training, - max_segments_per_seq, - ): + def batcher(batched_args, batch_dims, *, config): check_valid_batch_dims(batch_dims) assert FusedAttnFwdPrimitive.outer_primitive is not None q_bdim, *_, seed_bdim = batch_dims out_bdims = q_bdim, q_bdim, seed_bdim return ( - FusedAttnFwdPrimitive.outer_primitive.bind( - *batched_args, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - qkv_layout=qkv_layout, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training, - max_segments_per_seq=max_segments_per_seq, - ), + FusedAttnFwdPrimitive.outer_primitive.bind(*batched_args, config=config), out_bdims, ) @staticmethod - def infer_sharding_from_operands( - attn_bias_type, - attn_mask_type, - qkv_layout, - scaling_factor, - dropout_probability, - is_training, - max_segments_per_seq, - mesh, - arg_infos, - result_infos, - ): - del attn_bias_type, attn_mask_type, scaling_factor - del dropout_probability, is_training, max_segments_per_seq, result_infos + def infer_sharding_from_operands(config, mesh, arg_infos, result_infos): + del result_infos q_spec = get_padded_spec(arg_infos[0]) - k_spec = get_padded_spec(arg_infos[1]) - match qkv_layout: + match config.qkv_layout: case NVTE_QKV_Layout.NVTE_BS3HD | NVTE_QKV_Layout.NVTE_T3HD: # q_spec = (...batch, q_seqlen, head, hidden) out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec[:-3], *q_spec[-2:])) @@ -543,33 +513,22 @@ def infer_sharding_from_operands( # k_spec = (...batch, kv_seqlen, 2, num_gqa_groups, hidden) out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) softmax_aux_sharding = NamedSharding( - mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], k_spec[-4]) + mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None) ) case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD | NVTE_QKV_Layout.NVTE_THD_THD_THD: # q_spec = (...batch, q_seqlen, head, hidden) # k_spec = (...batch, kv_seqlen, num_gqa_groups, hidden) out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) softmax_aux_sharding = NamedSharding( - mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], k_spec[-3]) + mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None) ) case _: - raise ValueError(f"Unsupported {qkv_layout=}") + raise ValueError(f"Unsupported {config.qkv_layout=}") rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None)) return (out_sharding, softmax_aux_sharding, rng_state_sharding) @staticmethod - def partition( - attn_bias_type, - attn_mask_type, - qkv_layout, - scaling_factor, - dropout_probability, - is_training, - max_segments_per_seq, - mesh, - arg_infos, - result_infos, - ): + def partition(config, mesh, arg_infos, result_infos): out_sharding = result_infos[0].sharding softmax_aux_sharding = result_infos[1].sharding rng_state_sharding = seed_sharding = NamedSharding( @@ -577,16 +536,7 @@ def partition( ) arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [seed_sharding]) out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) - impl = partial( - FusedAttnFwdPrimitive.impl, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - qkv_layout=qkv_layout, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - 
is_training=is_training, - max_segments_per_seq=max_segments_per_seq, - ) + impl = partial(FusedAttnFwdPrimitive.impl, config=config) return mesh, impl, out_shardings, arg_shardings @@ -600,7 +550,7 @@ class FusedAttnBwdPrimitive(BasePrimitive): name = "te_fused_attn_backward" multiple_results = True - impl_static_args = (12, 13, 14, 15, 16, 17, 18) + impl_static_args = (12,) inner_primitive = None outer_primitive = None @@ -619,13 +569,7 @@ def abstract( _q_seq_offsets, _k_seq_offsets, *, - attn_bias_type, - attn_mask_type, - qkv_layout, - scaling_factor, - dropout_probability, - is_training, - max_segments_per_seq, + config, ): """ Fused attention bwd abstract @@ -641,10 +585,10 @@ def abstract( assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = ( - FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout) + FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout) ) - if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: + if config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: bias_batch = bias_heads = 0 else: *bias_batch_shape, bias_heads, _, _ = bias_aval.shape @@ -662,15 +606,15 @@ def abstract( num_gqa_groups, bias_heads, head_dim, - scaling_factor, - dropout_probability, - attn_bias_type, - attn_mask_type, - qkv_layout, + config.scaling_factor, + config.dropout_probability, + config.attn_bias_type, + config.attn_mask_type, + config.qkv_layout, jax_dtype_to_te_dtype(q_aval.dtype), - is_training, + config.is_training, deterministic, - max_segments_per_seq, + config.max_segments_per_seq, ) dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype) @@ -707,13 +651,7 @@ def lowering( q_seq_offsets, k_seq_offsets, *, - attn_bias_type, - attn_mask_type, - qkv_layout, - scaling_factor, - dropout_probability, - is_training, - max_segments_per_seq, + config, ): """ Fused attention bwd lowering rules @@ -743,12 +681,12 @@ def lowering( q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = ( - FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout) + FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout) ) input_batch = reduce(operator.mul, batch_shape) - if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: + if config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: bias_batch = bias_heads = 0 else: *bias_batch_shape, bias_heads, _, _ = bias_aval.shape @@ -765,16 +703,16 @@ def lowering( num_gqa_groups, bias_heads, head_dim, - max_segments_per_seq, + config.max_segments_per_seq, wkspace_aval.size, - scaling_factor, - dropout_probability, - attn_bias_type, - attn_mask_type, - qkv_layout, + config.scaling_factor, + config.dropout_probability, + config.attn_bias_type, + config.attn_mask_type, + config.qkv_layout, jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), - is_training, + config.is_training, not FusedAttnHelper.is_non_deterministic_allowed(), ) @@ -796,17 +734,11 @@ def impl( kv_seqlen, q_seq_offsets, k_seq_offsets, - attn_bias_type, - attn_mask_type, - qkv_layout, - scaling_factor, - dropout_probability, - is_training, - max_segments_per_seq, + config, ): assert FusedAttnBwdPrimitive.inner_primitive is not None - if nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format.NVTE_THD: + if nvte_get_qkv_format(config.qkv_layout) == NVTE_QKV_Format.NVTE_THD: def _fix_len_take(x, condition, fill_value=-1): x_shape = x.shape @@ -825,7 
+757,7 @@ def convert_to_2d(offsets, batch, max_seqlen): ) return offsets_2d - match qkv_layout: + match config.qkv_layout: case NVTE_QKV_Layout.NVTE_T3HD: kv_max_seqlen = q_max_seqlen = q.shape[-4] kv_batch = q_batch = reduce(operator.mul, q.shape[:-4]) @@ -882,63 +814,25 @@ def convert_to_2d(offsets, batch, max_seqlen): kv_cu_seqlen, q_seq_offsets, k_seq_offsets, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - qkv_layout=qkv_layout, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training, - max_segments_per_seq=max_segments_per_seq, + config=config, ) return dq, dk, dv, dbias @staticmethod - def batcher( - batched_args, - batch_dims, - *, - attn_bias_type, - attn_mask_type, - qkv_layout, - scaling_factor, - dropout_probability, - is_training, - max_segments_per_seq, - ): + def batcher(batched_args, batch_dims, *, config): check_valid_batch_dims(batch_dims) assert FusedAttnBwdPrimitive.outer_primitive is not None q_bdim, k_bdim, v_bdim, *_ = batch_dims out_bdims = q_bdim, k_bdim, v_bdim, q_bdim return ( - FusedAttnBwdPrimitive.outer_primitive.bind( - *batched_args, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - qkv_layout=qkv_layout, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training, - max_segments_per_seq=max_segments_per_seq, - ), + FusedAttnBwdPrimitive.outer_primitive.bind(*batched_args, config=config), out_bdims, ) @staticmethod - def infer_sharding_from_operands( - attn_bias_type, - attn_mask_type, - qkv_layout, - scaling_factor, - dropout_probability, - is_training, - max_segments_per_seq, - mesh, - arg_infos, - result_infos, - ): - del attn_bias_type, attn_mask_type, qkv_layout, scaling_factor, max_segments_per_seq - del dropout_probability, is_training, result_infos + def infer_sharding_from_operands(config, mesh, arg_infos, result_infos): + del config, result_infos q_spec = get_padded_spec(arg_infos[0]) k_spec = get_padded_spec(arg_infos[1]) v_spec = get_padded_spec(arg_infos[2]) @@ -950,18 +844,7 @@ def infer_sharding_from_operands( return (dq_sharding, dk_sharding, dv_sharding, dbias_sharding) @staticmethod - def partition( - attn_bias_type, - attn_mask_type, - qkv_layout, - scaling_factor, - dropout_probability, - is_training, - max_segments_per_seq, - mesh, - arg_infos, - result_infos, - ): + def partition(config, mesh, arg_infos, result_infos): del result_infos q_spec = get_padded_spec(arg_infos[0]) k_spec = get_padded_spec(arg_infos[1]) @@ -1001,16 +884,10 @@ def sharded_impl( kv_cu_seqlen, q_seq_offsets, k_seq_offsets, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - qkv_layout=qkv_layout, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training, - max_segments_per_seq=max_segments_per_seq, + config=config, ) global_dbias = local_dbias - if attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS: + if config.attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS: global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh) return local_dq, local_dk, local_dv, global_dbias @@ -1020,6 +897,378 @@ def sharded_impl( register_primitive(FusedAttnBwdPrimitive) +@dataclass(frozen=True) +class _FusedAttnCPWithAllGatherHelper: + """Helper class to assist with running the all-gather strategy for CP attention.""" + + mesh: jax.sharding.Mesh + config: _FusedAttnConfig + + def check_supported(self): + """Checks if the context parallel implementation is supported by the given arguments.""" + 
header = "Context parallel fused attention" + + allowed_layouts = [NVTE_QKV_Layout.NVTE_BSHD_BS2HD, NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD] + assert self.config.qkv_layout in allowed_layouts, ( + f"{header} only supports layouts: {','.join([str(x) for x in allowed_layouts])} got:" + f" {self.config.qkv_layout}" + ) + + assert ( + self.config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS + ), f"{header} does not support bias got: {self.config.attn_bias_type}" + + allowed_masks = [NVTE_Mask_Type.NVTE_NO_MASK, NVTE_Mask_Type.NVTE_CAUSAL_MASK] + assert self.config.attn_mask_type in allowed_masks, ( + f"{header} only supports masking types: " + f" {','.join([str(x) for x in allowed_masks])} got: {self.config.attn_mask_type}" + ) + + assert self.config.max_segments_per_seq == 1, ( + f"{header} only supports max_segments_per_seq == 1 got:" + f" {self.config.max_segments_per_seq}" + ) + assert self.config.dropout_probability == 0.0, f"{header} does not support dropout" + + def get_adjusted_mask(self): + """Converts the mask for context parallelism.""" + if self.config.attn_mask_type == NVTE_Mask_Type.NVTE_CAUSAL_MASK: + return NVTE_Mask_Type.NVTE_CAUSAL_BOTTOM_RIGHT_MASK + return self.config.attn_mask_type + + def all_gather_kv(self, k, v): + """Performs a all-gather of k and v over context parallel ranks.""" + + def ag(x): + return lax_paral_op( + x, lax.all_gather, self.config.cp_axis, mesh=self.mesh, axis=1, tiled=True + ) + + match self.config.qkv_layout: + case NVTE_QKV_Layout.NVTE_BSHD_BS2HD: + return ag(k), v + case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: + return ag(k), ag(v) + + return k, v # fall through + + def reduce_scatter_dkv(self, dk, dv): + """Performs a reduce-scatter of dk and dv over context parallel ranks.""" + + def rs(x): + return lax_paral_op( + x, + lax.psum_scatter, + self.config.cp_axis, + mesh=self.mesh, + scatter_dimension=1, + tiled=True, + ) + + match self.config.qkv_layout: + case NVTE_QKV_Layout.NVTE_BSHD_BS2HD: + return rs(dk), dv + case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: + return rs(dk), rs(dv) + + return dk, dv # fall through + + def kv_seqlens_for_rank(self, cp_rank, kv_max_seqlen, kv_seqlen_per_subrank): + """Returns sequence lengths of KV to use for each sub rank of the given cp_rank. 
+ + Example: CP=4, MaxLen = 1024, Unbalanced + cp_rank 0: [128, 256] + cp_rank 1: [384, 512] + cp_rank 2: [640, 768] + cp_rank 3: [896, 1024] + + Example: CP=4, MaxLen = 1024, Balanced + cp_rank 0: [128, 1024] + cp_rank 1: [256, 896] + cp_rank 2: [384, 768] + cp_rank 3: [512, 640] + """ + if self.config.context_parallel_load_balanced: + kv_seq_this_rank = [ + (cp_rank + 1) * kv_seqlen_per_subrank, + kv_max_seqlen - cp_rank * kv_seqlen_per_subrank, + ] + else: + kv_seq_this_rank = [ + (cp_rank * 2 + 1) * kv_seqlen_per_subrank, + (cp_rank * 2 + 2) * kv_seqlen_per_subrank, + ] + return kv_seq_this_rank + + def slice_kv(self, k, v, slice_seq_len): + """Slices k and v tensors to a sequence length of slice_seq_len.""" + + def sliced(x): + return lax.dynamic_slice_in_dim(x, 0, slice_seq_len, axis=1) + + match self.config.qkv_layout: + case NVTE_QKV_Layout.NVTE_BSHD_BS2HD: + return sliced(k), v + case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: + return sliced(k), sliced(v) + + return k, v # fall through + + def pad_kv(self, dk, dv, pad_seq_len): + """Pads dk and dv tensors to a sequence length of pad_seq_len.""" + + def pad(x, npad): + return jnp.pad(x, npad, "constant", constant_values=0.0) + + match self.config.qkv_layout: + case NVTE_QKV_Layout.NVTE_BSHD_BS2HD: + npad = [[0, 0], [0, pad_seq_len], [0, 0], [0, 0], [0, 0]] + return pad(dk, npad), dv + case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: + npad = [[0, 0], [0, pad_seq_len], [0, 0], [0, 0]] + return pad(dk, npad), pad(dv, npad) + + return dk, dv # fall through + + +class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive): + """ + Fused Attention Forward with Context Parallelism Primitive + + This context parallel implementation uses all-gather to collect KV inputs from context parallel ranks. + """ + + @staticmethod + def partition(config, mesh, arg_infos, result_infos): + # Call base implementation for non-context parallel mesh to avoid unecessary work. + is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1 + if not is_context_parallel: + return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos) + + helper = _FusedAttnCPWithAllGatherHelper(mesh, config) + helper.check_supported() + + out_sharding = result_infos[0].sharding + softmax_aux_sharding = result_infos[1].sharding + rng_state_sharding = seed_sharding = NamedSharding( + mesh, PartitionSpec(get_all_mesh_axes(), None) + ) + arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [seed_sharding]) + out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) + + def impl(q, k, v, bias, q_seqlen, kv_seqlen, q_seq_offsets, k_seq_offsets, seed): + + cp_size = get_mesh_axis_size(config.cp_axis, mesh) + cp_rank = get_mesh_axis_rank(config.cp_axis, mesh) + + # cuDNN does not support right-aligned masking with dynamic sequence length padding. + # Therefore we must explicitly instantiate each CP rank slicing and use a runtime switch + # to select the appropriate computation. Each case generates a [..., SEQ/CP, ..] tensor + # meeting the expectation of the SPMD model. + # TODO(mgoldfarb-nvidia): When cuDNN supports we should be able to make use of a padding + # mask/sequence length tensor to avoid this unrolled loop. 
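+            # _cross_attn computes this rank's slice of the attention output: the rank's Q shard
+            # is split into its two sub-chunks, the gathered K/V is truncated to the lengths
+            # returned by kv_seqlens_for_rank (or used in full when unmasked), the base fused
+            # attention runs on each sub-chunk, and the two partial outputs are concatenated
+            # back along the sequence axis.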
+ def _cross_attn(idx, q, k, v, bias, q_seqlen, kv_seqlen, seed): + kv_max_seqlen = k.shape[1] + kv_seqlen_per_subrank = kv_max_seqlen // (cp_size * 2) + assert kv_max_seqlen % cp_size == 0, "sequence length must evenly divide cp size" + + q_split = jnp.split(q, 2, axis=1) + + kv_seqlens_for_rank = helper.kv_seqlens_for_rank( + idx, kv_max_seqlen, kv_seqlen_per_subrank + ) + + results = [] + for sub_idx in range(2): + if config.attn_mask_type == NVTE_Mask_Type.NVTE_NO_MASK: + k_unmasked, v_unmasked = k, v # full kv used for unmasked + else: + k_unmasked, v_unmasked = helper.slice_kv(k, v, kv_seqlens_for_rank[sub_idx]) + + q_seqlen_for_step = q_seqlen / (cp_size * 2) + num_kv_chunks = kv_max_seqlen // kv_seqlens_for_rank[sub_idx] + kv_seqlen_for_step = (kv_seqlen / (cp_size * 2)) * num_kv_chunks + + output, softmax_aux, rng_state = FusedAttnFwdPrimitive.impl( + q_split[sub_idx], + k_unmasked, + v_unmasked, + bias, + q_seqlen_for_step, + kv_seqlen_for_step, + q_seq_offsets, + k_seq_offsets, + seed, + config=config, + ) + results.append((output, softmax_aux, rng_state)) + + output = jnp.concatenate((results[0][0], results[1][0]), axis=1) + softmax_aux = jnp.concatenate((results[0][1], results[1][1]), axis=2) + rng_state = results[1][2] # Use the final RNG state + + return output, softmax_aux, rng_state + + k_ag, v_ag = helper.all_gather_kv(k, v) + + functions = [ + partial(_cross_attn, idx, q, k_ag, v_ag, bias, q_seqlen, kv_seqlen, seed) + for idx in range(cp_size) + ] + + return lax.switch(cp_rank, functions) + + return mesh, impl, out_shardings, arg_shardings + + +register_primitive(FusedAttnCPWithAllGatherFwdPrimitive) + + +class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive): + """ + Fused Attention Backward with Context Parallelism Primitive. + + This context parallel implementation uses all-gather to collect KV and dKV inputs from context parallel ranks. + The gradients are subsequently reduce-scattered back to each context parallel rank. + """ + + @staticmethod + def partition(config, mesh, arg_infos, result_infos): + # Call base implementation for non-context parallel mesh to avoid unecessary work. + is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1 + if not is_context_parallel: + return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos) + + # Ensure we can support this configuration with context parallelism. + helper = _FusedAttnCPWithAllGatherHelper(mesh, config) + helper.check_supported() + + del result_infos + q_spec = get_padded_spec(arg_infos[0]) + k_spec = get_padded_spec(arg_infos[1]) + v_spec = get_padded_spec(arg_infos[2]) + bias_spec = get_padded_spec(arg_infos[3]) + dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) + dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec)) + dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec)) + dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding) + + def impl( + q, + k, + v, + bias, + softmax_aux, + rng_state, + output, + doutput, + q_seqlen, + kv_seqlen, + q_seq_offsets, + k_seq_offsets, + ): + cp_size = get_mesh_axis_size(config.cp_axis, mesh) + cp_rank = get_mesh_axis_rank(config.cp_axis, mesh) + + # See comment in FusedAttnCPFwdPrimitive.partition for why we define this function. 
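+            # The backward mirrors the forward unrolling: each sub-chunk's gradients are computed
+            # against the same K/V slice used in the forward pass, dK/dV are zero-padded back to
+            # the full gathered sequence length and summed over the two sub-chunks, and finally
+            # reduce-scattered over the cp axis so each rank keeps only its shard of dK/dV.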
+ def _cross_attn_bwd( + idx, q, k, v, bias, softmax_aux, rng_state, output, doutput, q_seqlen, kv_seqlen + ): + kv_max_seqlen = k.shape[1] + kv_seqlen_per_subrank = kv_max_seqlen // (cp_size * 2) + assert kv_max_seqlen % cp_size == 0, "sequence length must evenly divide cp size" + + q_split = jnp.split(q, 2, axis=1) + output_split = jnp.split(output, 2, axis=1) + doutput_split = jnp.split(doutput, 2, axis=1) + softmax_aux_split = jnp.split(softmax_aux, 2, axis=2) + + kv_seqlens_for_rank = helper.kv_seqlens_for_rank( + idx, kv_max_seqlen, kv_seqlen_per_subrank + ) + + results = [] + for sub_idx in range(2): + if config.attn_mask_type == NVTE_Mask_Type.NVTE_NO_MASK: + k_unmasked, v_unmasked = k, v # full kv used for unmasked + else: + k_unmasked, v_unmasked = helper.slice_kv(k, v, kv_seqlens_for_rank[sub_idx]) + + q_seqlen_for_step = q_seqlen // (cp_size * 2) + num_kv_chunks = kv_max_seqlen // kv_seqlens_for_rank[sub_idx] + kv_seqlen_for_step = (kv_seqlen // (cp_size * 2)) * num_kv_chunks + + dq_local, dk_local, dv_local, dbias_local = FusedAttnBwdPrimitive.impl( + q_split[sub_idx], + k_unmasked, + v_unmasked, + bias, + softmax_aux_split[sub_idx], + rng_state, + output_split[sub_idx], + doutput_split[sub_idx], + q_seqlen_for_step, + kv_seqlen_for_step, + q_seq_offsets, + k_seq_offsets, + config=config, + ) + + # pad dk/dv to be unsliced shape so we can reduce scatter over all ranks. + if config.attn_mask_type != NVTE_Mask_Type.NVTE_NO_MASK: + pad_length = kv_max_seqlen - kv_seqlens_for_rank[sub_idx] + dk_local, dv_local = helper.pad_kv(dk_local, dv_local, pad_length) + + results.append((dq_local, dk_local, dv_local, dbias_local)) + + dq_local = jnp.concatenate((results[0][0], results[1][0]), axis=1) + dk_local_pad = results[0][1] + results[1][1] + dv_local_pad = results[0][2] + results[1][2] + return dq_local, dk_local_pad, dv_local_pad, results[1][3] + + k_ag, v_ag = helper.all_gather_kv(k, v) + + functions = [ + partial( + _cross_attn_bwd, + idx, + q, + k_ag, + v_ag, + bias, + softmax_aux, + rng_state, + output, + doutput, + q_seqlen, + kv_seqlen, + ) + for idx in range(cp_size) + ] + + dq, dk_local, dv_local, dbias = lax.switch(cp_rank, functions) + dk, dv = helper.reduce_scatter_dkv(dk_local, dv_local) + + return dq, dk, dv, dbias + + return mesh, impl, out_shardings, arg_shardings + + +register_primitive(FusedAttnCPWithAllGatherBwdPrimitive) + + +def _maybe_context_parallel_axis(cp_axis: str): + if not cp_axis: + gmr = global_mesh_resource() + if gmr is not None: + cp_axis = gmr.cp_resource + else: + cp_axis = "" + return cp_axis + + def fused_attn_fwd( qkv: Tuple[jnp.ndarray, ...], bias: Optional[jnp.ndarray], @@ -1035,6 +1284,8 @@ def fused_attn_fwd( dropout_probability: float, is_training: bool, max_segments_per_seq: int, + context_parallel_causal_load_balanced: bool = False, + context_parallel_axis: str = "", ) -> jnp.ndarray: """ Perform the forward pass of with cuDNN fused attention implementations. @@ -1063,6 +1314,9 @@ def fused_attn_fwd( scaling_factor (float): Scaling factor for the attention scores. dropout_probability (float): Dropout probability to apply during attention. is_training (bool): Flag indicating whether the model is in training mode. + context_parallel_causal_load_balanced (bool): + Indicates the sequences are ordered for causal mask load balancing when running context parallelism. + context_parallel_axis (str): The name of the context parallel axis. Returns: (jnp.ndarray): The output tensor from the fused attention. 
""" @@ -1094,14 +1348,7 @@ def fused_attn_fwd( assert bias is None bias = jnp.zeros(0, dtype=qkv[0].dtype) - return FusedAttnFwdPrimitive.outer_primitive.bind( - *qkv_for_primitive, - bias, - q_seqlen, - kv_seqlen, - q_seq_offsets if is_ragged else _not_used, - kv_seq_offsets if is_ragged else _not_used, - seed, + fused_config = _FusedAttnConfig( attn_bias_type=attn_bias_type, attn_mask_type=attn_mask_type, qkv_layout=qkv_layout, @@ -1109,6 +1356,19 @@ def fused_attn_fwd( dropout_probability=dropout_probability, is_training=is_training, max_segments_per_seq=max_segments_per_seq, + context_parallel_load_balanced=context_parallel_causal_load_balanced, + cp_axis=_maybe_context_parallel_axis(context_parallel_axis), + ) + + return FusedAttnCPWithAllGatherFwdPrimitive.outer_primitive.bind( + *qkv_for_primitive, + bias, + q_seqlen, + kv_seqlen, + q_seq_offsets if is_ragged else _not_used, + kv_seq_offsets if is_ragged else _not_used, + seed, + config=fused_config, ) @@ -1130,6 +1390,8 @@ def fused_attn_bwd( dropout_probability: float, is_training: bool, max_segments_per_seq: int, + context_parallel_causal_load_balanced: bool = False, + context_parallel_axis: str = "", ): """ Perform the backward pass of the cuDNN fused attention implementations. @@ -1159,7 +1421,9 @@ def fused_attn_bwd( scaling_factor (float): Scaling factor for the attention scores. dropout_probability (float): Dropout probability to apply during attention. is_training (bool): Flag indicating whether the model is in training mode. - + context_parallel_causal_load_balanced (bool): + Indicates the sequences are ordered for causal mask load balancing when running context parallelism. + context_parallel_axis (str): The name of the context parallel axis. Returns: Tuple[jnp.ndarray, ...], jnp.ndarray: - The first tuple contains the gradients with respect to the input `qkv` tensors in the @@ -1194,7 +1458,19 @@ def fused_attn_bwd( assert bias is None bias = jnp.zeros(0, dtype=qkv[0].dtype) - *qkv_grads, bias_grad = FusedAttnBwdPrimitive.outer_primitive.bind( + fused_config = _FusedAttnConfig( + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + qkv_layout=qkv_layout, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training, + max_segments_per_seq=max_segments_per_seq, + context_parallel_load_balanced=context_parallel_causal_load_balanced, + cp_axis=_maybe_context_parallel_axis(context_parallel_axis), + ) + + *qkv_grads, bias_grad = FusedAttnCPWithAllGatherBwdPrimitive.outer_primitive.bind( *qkv_for_primitive, bias, softmax_aux, @@ -1205,12 +1481,6 @@ def fused_attn_bwd( kv_seqlen, q_seq_offsets if is_ragged else _not_used, kv_seq_offsets if is_ragged else _not_used, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - qkv_layout=qkv_layout, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training, - max_segments_per_seq=max_segments_per_seq, + config=fused_config, ) return tuple(qkv_grads[: len(qkv)]), bias_grad diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 0a2172bb1b..14f449a76b 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -100,7 +100,10 @@ PYBIND11_MODULE(transformer_engine_jax, m) { .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK) - 
.value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK); + .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) + .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) + .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", + NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); pybind11::enum_(m, "NVTE_QKV_Layout", pybind11::module_local()) .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index 586e1a70c9..a14a8384cf 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -20,6 +20,7 @@ BATCH_AXES = "nvte_batch" SEQLEN_AXES = "nvte_seqlen" SEQLEN_TP_AXES = "nvte_seqlen_tp" +SEQLEN_CP_AXES = "nvte_seqlen_cp" HEAD_AXES = "nvte_head" HIDDEN_AXES = "nvte_hidden" HIDDEN_TP_AXES = "nvte_hidden_tp" @@ -65,6 +66,7 @@ def get_sharding_map_logic_axis_to_mesh_axis(): BATCH_AXES: batch_dim_rule, SEQLEN_AXES: None, SEQLEN_TP_AXES: gsr.tp_resource, + SEQLEN_CP_AXES: gsr.cp_resource, HEAD_AXES: gsr.tp_resource, HIDDEN_AXES: None, HIDDEN_TP_AXES: gsr.tp_resource, @@ -131,13 +133,15 @@ def get_padded_spec(spec, ndim): return spec + (None,) * (ndim - len(spec)) -def lax_paral_op(x: jnp.array, ops: Callable, mesh_resource: str, mesh: jax.sharding.Mesh): +def lax_paral_op( + x: jnp.array, ops: Callable, mesh_resource: str, mesh: jax.sharding.Mesh, **kwargs +): """ A wrapper function to invoke lax.p* operations, like psum. """ if mesh_resource is not None: _, resource = _get_mesh_info(mesh_resource, mesh) - return ops(x, resource) + return ops(x, resource, **kwargs) return x @@ -148,6 +152,33 @@ def num_of_devices(): return len(jax.devices()) +def get_mesh_axis_size(axis, mesh=None): + """ + Get the axis size of the given mesh. + If the mesh is None, it would be replaced + by the global mesh. + """ + if mesh is None: + mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh + + if axis is None: + return 1 + + assert axis in mesh.shape, f"{axis} is not a axis of the given mesh {mesh.shape}" + return mesh.shape[axis] + + +def get_mesh_axis_rank(axis: str, mesh=None): + """ + Gets the local axis rank of the `axis` of the array. + If the mesh is None the rank is 0. + """ + if mesh is None: + return 0 + _, axis_name = _get_mesh_info(axis, mesh) + return jax.lax.axis_index(axis_name) + + @dataclass class MeshResource: """ @@ -168,12 +199,16 @@ class MeshResource: pp_resource : str, default = None The axis name in Mesh used to split model layers. along. If it is None, then pipeline parallelism is disabled. + cp_resource : str, default = None + The axis name in Mesh used to split sequence (context) dimensions along + in the attention. If it is None, then context parallelism is disabled. """ dp_resource: str = None tp_resource: str = None fsdp_resource: str = None pp_resource: str = None + cp_resource: str = None _GLOBAL_MESH_RESOURCE = MeshResource()