Implementation of context parallel fused attention using all-gather.
Signed-off-by: Michael Goldfarb <[email protected]>
mgoldfarb-nvidia committed Sep 5, 2024
1 parent 4ddb0a7 commit ba88ede
Showing 6 changed files with 942 additions and 262 deletions.
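The tests below exercise the new all-gather context-parallel attention path. As a rough mental model (an editor's sketch with assumed names, not the code in this commit), each context-parallel rank keeps its local chunk of the query sequence and all-gathers K/V across the cp mesh axis before attending; masking, causal load balancing, and the fused kernel itself are omitted here.

# Editor's sketch, not part of this commit. Run under shard_map/pmap with a "cp" axis.
import jax
import jax.numpy as jnp

def cp_attention_allgather_sketch(q, k, v):
    # q, k, v: local shards of shape [batch, seqlen // cp_size, heads, hidden]
    k_full = jax.lax.all_gather(k, axis_name="cp", axis=1, tiled=True)
    v_full = jax.lax.all_gather(v, axis_name="cp", axis=1, tiled=True)
    # Each rank attends its local queries against the full key/value sequence.
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k_full) / jnp.sqrt(q.shape[-1])
    probs = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", probs, v_full)
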
24 changes: 24 additions & 0 deletions tests/jax/distributed_test_base.py
@@ -4,6 +4,8 @@
import operator
import re
from functools import reduce
from itertools import product
import pytest

import jax
from jax.experimental.pjit import pjit, _UNSPECIFIED
@@ -29,6 +31,28 @@ def generate_configs():
return configs


def generate_context_parallel_configs():
configs = []

DP_sizes = (1, 2)
CP_sizes = (1, 2, 4, 8)
TP_sizes = (1, 2)
for dp, cp, tp in product(DP_sizes, CP_sizes, TP_sizes):
ndev = cp * tp * dp
if is_devices_enough(ndev):
configs.append(
pytest.param(
ndev,
(dp, cp, tp),
("dp", "cp", "tp"),
MeshResource(dp_resource="dp", cp_resource="cp", tp_resource="tp"),
id=f"n{ndev}_dp{dp}_cp{cp}_tp{tp}",
)
)

return configs
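
For example, with at least four visible devices the loop above produces entries like the following (an editor's illustration of the parameter layout, using the names already imported in this file):

pytest.param(
    4,                   # device_count
    (1, 2, 2),           # mesh_shape = (dp, cp, tp)
    ("dp", "cp", "tp"),  # mesh_axes
    MeshResource(dp_resource="dp", cp_resource="cp", tp_resource="tp"),
    id="n4_dp1_cp2_tp2",
)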


COLL_AR_KEY = "all-reduce"
COLL_AG_KEY = "all-gather"
COLL_OTHER_KEY = "other"
245 changes: 241 additions & 4 deletions tests/jax/test_distributed_fused_attn.py
@@ -3,22 +3,32 @@
# See LICENSE for license information.

import pytest
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from flax.linen import dot_product_attention
from jax import random
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from distributed_test_base import generate_configs, generate_collectives_count, compare_ops
from utils import make_causal_mask, make_self_mask
from distributed_test_base import (
generate_configs,
generate_context_parallel_configs,
generate_collectives_count,
compare_ops,
)
from utils import make_causal_mask, make_self_mask, assert_tree_like_allclose, assert_allclose
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.attention import (
is_fused_attn_kernel_available,
fused_attn,
AttnBiasType,
AttnMaskType,
QKVLayout,
QKVFormat,
get_qkv_format,
reorder_causal_load_balancing,
inverse_reorder_causal_load_balancing,
)


@@ -263,7 +273,8 @@ def target_func(q, kv, mask):
scaling_factor=scaling_factor,
dropout_probability=dropout_prob,
is_training=is_training,
)
),
dtype=jnp.float32,
)

def ref_func(query, kv, mask):
@@ -284,7 +295,7 @@ def ref_func(query, kv, mask):
dtype=jnp.float32,
)

return jnp.mean(output).astype(dtype)
return jnp.mean(output, dtype=jnp.float32)

(q, kv, mask), (q_pspec, kv_pspec, mask_pspec) = self.generate_inputs(
data_shape, mesh_resource, attn_mask_type, dtype
@@ -310,3 +321,229 @@ def ref_func(query, kv, mask):
in_shardings=(q_pspec, kv_pspec, mask_pspec),
out_shardings=(None, (q_pspec, kv_pspec)),
)


class TestDistributedContexParallelSelfAttn:

def generate_inputs(self, shape, kv_groups: int, attn_mask_type: AttnMaskType, dtype):
batch, seqlen, heads, hidden = shape
qkey, kkey, vkey = random.split(random.PRNGKey(1124), 3)
q = random.normal(qkey, shape, dtype=dtype)
k = random.normal(kkey, (batch, seqlen, heads // kv_groups, hidden), dtype=dtype)
v = random.normal(vkey, (batch, seqlen, heads // kv_groups, hidden), dtype=dtype)

mask = None
if attn_mask_type == AttnMaskType.CAUSAL_MASK:
mask = make_causal_mask(batch, seqlen)

return q, k, v, mask

def qkv_to_layout(self, q, k, v, qkv_layout):
qkv_args = ()
match qkv_layout:
case QKVLayout.BSHD_BS2HD:
k, v = map(partial(jnp.expand_dims, axis=-3), [k, v])
kv = jnp.concatenate((k, v), axis=-3)
qkv_args = (q, kv)
case QKVLayout.BSHD_BSHD_BSHD:
qkv_args = (q, k, v)
case _:
raise ValueError(f"Unsupported {qkv_layout=}")
return qkv_args
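
As a quick shape check (an editor's illustration using the first data_shape case with kv_groups=4, not part of the commit), BSHD_BS2HD packs K and V along a new axis while BSHD_BSHD_BSHD keeps three separate tensors:

q = jnp.zeros((2, 512, 12, 128))      # BSHD query
k = v = jnp.zeros((2, 512, 3, 128))   # 12 heads // kv_groups=4 -> 3 KV heads
kv = jnp.concatenate(
    (jnp.expand_dims(k, axis=-3), jnp.expand_dims(v, axis=-3)), axis=-3
)
assert kv.shape == (2, 512, 2, 3, 128)  # BS2HD: K and V stacked on the new axis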

@pytest.mark.parametrize(
"device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
)
@pytest.mark.parametrize(
"data_shape",
[
pytest.param([2, 512, 12, 128], id="2-512-12-128"),
pytest.param([4, 1024, 16, 64], id="4-1024-16-64"),
],
)
@pytest.mark.parametrize("kv_groups", [1, 4, 8, 12, 16])
@pytest.mark.parametrize(
"attn_mask_type",
[
pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"),
pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
],
)
@pytest.mark.parametrize("dtype", [jnp.bfloat16])
@pytest.mark.parametrize(
"qkv_layout",
[
pytest.param(QKVLayout.BSHD_BS2HD, id="COMBINED_KV"),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
],
)
@pytest.mark.parametrize(
"load_balanced", [pytest.param(False, id="UNBALANCED"), pytest.param(True, id="BALANCED")]
)
def test_contex_parallel_self_attn(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
kv_groups,
attn_mask_type,
dtype,
qkv_layout,
load_balanced,
):
attn_bias_type = AttnBiasType.NO_BIAS
dropout_prob = 0.0
is_training = True
scaling_factor = 1.0
dp_size, cp_size, tp_size = mesh_shape
qkv_format = get_qkv_format(qkv_layout)

_, seqlen, num_head, hidden = data_shape
num_kv_heads = num_head // kv_groups

# make sure num_head divides evenly by kv_groups and the resulting KV heads divide evenly across the tp axis
if num_head % kv_groups != 0 or (num_head // kv_groups) % tp_size != 0:
pytest.skip(f"Skipping {kv_groups=} not multiple of {data_shape=} or {tp_size=}")

def target_func(q, k, v, mask):
return jnp.mean(
fused_attn(
self.qkv_to_layout(q, k, v, qkv_layout),
bias=None,
mask=mask,
seed=None,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
dropout_probability=dropout_prob,
is_training=is_training,
context_parallel_causal_load_balanced=load_balanced,
),
).astype(dtype)

def ref_func(q, k, v, mask, kv_groups):
q = jnp.squeeze(q)
k = jnp.squeeze(jnp.repeat(k, kv_groups, axis=2))
v = jnp.squeeze(jnp.repeat(v, kv_groups, axis=2))
output = dot_product_attention(
q,
k,
v,
bias=None,
mask=mask,
deterministic=is_training,
dropout_rate=dropout_prob,
dropout_rng=None,
dtype=jnp.float32,
)
return jnp.mean(output).astype(dtype)

q, k, v, mask = self.generate_inputs(data_shape, kv_groups, attn_mask_type, dtype)

# Single GPU (reference)
ref_func_jit = jax.jit(jax.value_and_grad(ref_func, argnums=[0, 1, 2]), static_argnums=[4])
ref_fwd, ref_grads = ref_func_jit(q, k, v, mask, kv_groups)

# Multi GPU (function under test)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource):
qkv_ps = PartitionSpec(
mesh_resource.dp_resource,
mesh_resource.cp_resource,
mesh_resource.tp_resource,
None,
)
qkv_sharding = NamedSharding(mesh, qkv_ps)

mask_ps = PartitionSpec(
mesh_resource.dp_resource, None, mesh_resource.cp_resource, None
)
mask_sharding = NamedSharding(mesh, mask_ps)

reorder = partial(
reorder_causal_load_balancing, cp_size=cp_size, tensor_format=qkv_format
)
inverse_reorder = partial(
inverse_reorder_causal_load_balancing, cp_size=cp_size, tensor_format=qkv_format
)

if load_balanced:
q, k, v = jax.tree.map(reorder, (q, k, v))

q_, k_, v_ = map(partial(jax.device_put, device=qkv_sharding), [q, k, v])
mask_ = jax.device_put(mask, device=mask_sharding)

target_func_jit = jax.jit(
jax.value_and_grad(target_func, argnums=[0, 1, 2]),
in_shardings=[qkv_sharding, qkv_sharding, qkv_sharding, mask_sharding],
out_shardings=(None, (qkv_sharding, qkv_sharding, qkv_sharding)),
)

target_fwd, target_grads = target_func_jit(q_, k_, v_, mask_)

if load_balanced:
target_dq, target_dk, target_dv = jax.tree.map(inverse_reorder, target_grads[0:3])
target_grads = (target_dq, target_dk, target_dv, *target_grads[3:])

def _print_diffs(target, ref):
print("min: ", jnp.min(target), jnp.min(ref))
print("max: ", jnp.max(target), jnp.max(ref))
print("mean: ", jnp.mean(target), jnp.mean(ref))
print("median: ", jnp.median(target), jnp.median(ref))
print("std: ", jnp.std(target), jnp.std(ref))
print("var: ", jnp.var(target), jnp.var(ref))
print("max diff: ", jnp.max(jnp.abs(target - ref)))

has_diffs = False

try:
assert_allclose(target_fwd, ref_fwd, dtype=dtype)
except AssertionError as e:
has_diffs = True
print(f"target_fwd v. ref_fwd")
_print_diffs(target_fwd, ref_fwd)

for i in range(len(target_grads)):
if ref_grads[i] is None or target_grads[i] is None:
# expect both none if one is
assert target_grads[i] is None and ref_grads[i] is None
else:
try:
assert_allclose(target_grads[i], ref_grads[i])
except AssertionError as e:
has_diffs = True
print(f"target_grads[{i}] v. ref_grads[{i}]")
_print_diffs(target_grads[i], ref_grads[i])

assert has_diffs == False, "has_diffs != False"


class TestReorderCausalLoadBalancing:
@pytest.mark.parametrize("cp_size", [2, 4, 8])
@pytest.mark.parametrize(
"shape",
[
pytest.param([1, 16, 1, 1], id="1-16-1-1"),
pytest.param([4, 32, 12, 32], id="4-32-12-32"),
pytest.param([3, 32, 8, 64], id="3-32-8-64"),
],
)
@pytest.mark.parametrize("qkv_format", [QKVFormat.BSHD, QKVFormat.SBHD])
def test(self, cp_size, shape, qkv_format):
tensor = random.normal(random.PRNGKey(1124), shape, dtype=jnp.bfloat16)
if qkv_format == QKVFormat.SBHD:
tensor = tensor.swapaxes(0, 1)

ref = tensor.copy()

reorder = jax.jit(reorder_causal_load_balancing, static_argnums=[1, 2])
inverse = jax.jit(inverse_reorder_causal_load_balancing, static_argnums=[1, 2])

reordered = reorder(tensor, cp_size, qkv_format)
inversed = inverse(reordered, cp_size, qkv_format)

assert jnp.array_equal(inversed, ref)
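
The round-trip check above is the only property the test relies on. For intuition, a common causal load-balancing scheme (an editor's sketch of an assumed permutation, not necessarily the exact one implemented by reorder_causal_load_balancing) splits the sequence into 2 * cp_size chunks and pairs an early chunk with a late one, so each context-parallel rank gets a comparable amount of causal work:

import jax.numpy as jnp

def reorder_sketch(x, cp_size, seq_dim=1):
    # Split the sequence axis into 2*cp_size chunks and give rank r the pair
    # (r, 2*cp_size - 1 - r): one cheap early chunk plus one expensive late chunk.
    chunks = jnp.split(x, 2 * cp_size, axis=seq_dim)
    paired = [c for r in range(cp_size) for c in (chunks[r], chunks[2 * cp_size - 1 - r])]
    return jnp.concatenate(paired, axis=seq_dim)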