[JAX] Context Parallel Attention with All-Gather #1106

Merged
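For orientation, the tests in this PR exercise an all-gather flavor of context parallelism: Q, K, and V are sharded along the sequence dimension across a context-parallel ("cp") mesh axis, K and V are gathered back to full length on every device, and each device computes attention for its local Q shard. Below is a minimal, illustrative sketch of that pattern in plain JAX; it is not the fused kernel added by this PR, the mesh layout and helper names are assumptions for illustration only, and causal masking with load balancing is omitted.

# Illustrative sketch of all-gather context-parallel attention (assumed layout, not the PR's kernel).
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P


def _attention(q, k, v):
    # q: [batch, q_len, heads, dim]; k, v: [batch, kv_len, heads, dim]
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
    return jnp.einsum("bhqk,bkhd->bqhd", jax.nn.softmax(scores, axis=-1), v)


def cp_attention_all_gather(mesh, q, k, v):
    # Each device holds a sequence shard of q/k/v along the "cp" axis.
    @partial(
        shard_map,
        mesh=mesh,
        in_specs=(P(None, "cp", None, None),) * 3,
        out_specs=P(None, "cp", None, None),
        check_rep=False,
    )
    def _sharded(q, k, v):
        # Recover the full key/value sequence on every device, then attend locally.
        k_full = jax.lax.all_gather(k, "cp", axis=1, tiled=True)
        v_full = jax.lax.all_gather(v, "cp", axis=1, tiled=True)
        return _attention(q, k_full, v_full)

    return jax.jit(_sharded)(q, k, v)


if __name__ == "__main__":
    # The sequence length must be divisible by the size of the "cp" axis.
    mesh = Mesh(np.array(jax.devices()).reshape(1, -1), ("dp", "cp"))
    q = k = v = jnp.ones((2, 128, 12, 64), dtype=jnp.bfloat16)
    print(cp_attention_all_gather(mesh, q, k, v).shape)  # (2, 128, 12, 64)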
24 changes: 24 additions & 0 deletions tests/jax/distributed_test_base.py
@@ -4,6 +4,8 @@
import operator
import re
from functools import reduce
from itertools import product
import pytest

import jax
from jax.experimental.pjit import pjit, _UNSPECIFIED
@@ -29,6 +31,28 @@ def generate_configs():
return configs


def generate_context_parallel_configs():
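"""Build pytest params for (DP, CP, TP) mesh shapes, skipping combinations that need more devices than are available."""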
configs = []

DP_sizes = (1, 2)
CP_sizes = (1, 2, 4, 8)
TP_sizes = (1, 2)
for dp, cp, tp in product(DP_sizes, CP_sizes, TP_sizes):
ndev = cp * tp * dp
if is_devices_enough(ndev):
configs.append(
pytest.param(
ndev,
(dp, cp, tp),
("dp", "cp", "tp"),
MeshResource(dp_resource="dp", cp_resource="cp", tp_resource="tp"),
id=f"n{ndev}_dp{dp}_cp{cp}_tp{tp}",
)
)

return configs


COLL_AR_KEY = "all-reduce"
COLL_AG_KEY = "all-gather"
COLL_OTHER_KEY = "other"
245 changes: 241 additions & 4 deletions tests/jax/test_distributed_fused_attn.py
@@ -3,22 +3,32 @@
# See LICENSE for license information.

import pytest
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from flax.linen import dot_product_attention
from jax import random
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from distributed_test_base import generate_configs, generate_collectives_count, compare_ops
from utils import make_causal_mask, make_self_mask
from distributed_test_base import (
generate_configs,
generate_context_parallel_configs,
generate_collectives_count,
compare_ops,
)
from utils import make_causal_mask, make_self_mask, assert_tree_like_allclose, assert_allclose
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.attention import (
is_fused_attn_kernel_available,
fused_attn,
AttnBiasType,
AttnMaskType,
QKVLayout,
QKVFormat,
get_qkv_format,
reorder_causal_load_balancing,
inverse_reorder_causal_load_balancing,
)


@@ -263,7 +273,8 @@ def target_func(q, kv, mask):
scaling_factor=scaling_factor,
dropout_probability=dropout_prob,
is_training=is_training,
- )
+ ),
+ dtype=jnp.float32,
)

def ref_func(query, kv, mask):
@@ -284,7 +295,7 @@ def ref_func(query, kv, mask):
dtype=jnp.float32,
)

- return jnp.mean(output).astype(dtype)
+ return jnp.mean(output, dtype=jnp.float32)

(q, kv, mask), (q_pspec, kv_pspec, mask_pspec) = self.generate_inputs(
data_shape, mesh_resource, attn_mask_type, dtype
@@ -310,3 +321,229 @@ def ref_func(query, kv, mask):
in_shardings=(q_pspec, kv_pspec, mask_pspec),
out_shardings=(None, (q_pspec, kv_pspec)),
)


class TestDistributedContextParallelSelfAttn:
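"""Compares TE fused attention sharded over a context-parallel mesh against a single-device flax reference, for both forward values and gradients."""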

def generate_inputs(self, shape, kv_groups: int, attn_mask_type: AttnMaskType, dtype):
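"""Create Q with `heads` attention heads and K/V with `heads // kv_groups` heads (grouped-query style), plus an optional causal mask."""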
batch, seqlen, heads, hidden = shape
qkey, kkey, vkey = random.split(random.PRNGKey(1124), 3)
q = random.normal(qkey, shape, dtype=dtype)
k = random.normal(kkey, (batch, seqlen, heads // kv_groups, hidden), dtype=dtype)
v = random.normal(vkey, (batch, seqlen, heads // kv_groups, hidden), dtype=dtype)

mask = None
if attn_mask_type == AttnMaskType.CAUSAL_MASK:
mask = make_causal_mask(batch, seqlen)

return q, k, v, mask

def qkv_to_layout(self, q, k, v, qkv_layout):
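"""Pack (q, k, v) into the requested layout: BSHD_BS2HD folds K and V into a single combined KV tensor, while BSHD_BSHD_BSHD keeps them separate."""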
qkv_args = ()
match qkv_layout:
case QKVLayout.BSHD_BS2HD:
k, v = map(partial(jnp.expand_dims, axis=-3), [k, v])
kv = jnp.concatenate((k, v), axis=-3)
qkv_args = (q, kv)
case QKVLayout.BSHD_BSHD_BSHD:
qkv_args = (q, k, v)
case _:
raise ValueError(f"Unsupported {qkv_layout=}")
return qkv_args

@pytest.mark.parametrize(
"device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
)
@pytest.mark.parametrize(
"data_shape",
[
pytest.param([2, 512, 12, 128], id="2-512-12-128"),
pytest.param([4, 1024, 16, 64], id="4-1024-16-64"),
],
)
@pytest.mark.parametrize("kv_groups", [1, 4, 8, 12, 16])
@pytest.mark.parametrize(
"attn_mask_type",
[
pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"),
pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
],
)
@pytest.mark.parametrize("dtype", [jnp.bfloat16])
@pytest.mark.parametrize(
"qkv_layout",
[
pytest.param(QKVLayout.BSHD_BS2HD, id="COMBINED_KV"),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
],
)
@pytest.mark.parametrize(
"load_balanced", [pytest.param(False, id="UNBALANCED"), pytest.param(True, id="BALANCED")]
)
def test_context_parallel_self_attn(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
kv_groups,
attn_mask_type,
dtype,
qkv_layout,
load_balanced,
):
attn_bias_type = AttnBiasType.NO_BIAS
dropout_prob = 0.0
is_training = True
scaling_factor = 1.0
dp_size, cp_size, tp_size = mesh_shape
qkv_format = get_qkv_format(qkv_layout)

_, seqlen, num_head, hidden = data_shape
num_kv_heads = num_head // kv_groups

# Make sure num_head is divisible by kv_groups and the resulting KV heads are divisible by the TP size.
if num_head % kv_groups != 0 or (num_head // kv_groups) % tp_size != 0:
pytest.skip(f"Skipping: {num_head=} is not divisible by {kv_groups=}, or {num_kv_heads=} is not divisible by {tp_size=}")

def target_func(q, k, v, mask):
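# Function under test: TE fused attention with the context_parallel_causal_load_balanced flag; the mean reduces the output to a scalar so value_and_grad can differentiate it.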
return jnp.mean(
fused_attn(
self.qkv_to_layout(q, k, v, qkv_layout),
bias=None,
mask=mask,
seed=None,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
dropout_probability=dropout_prob,
is_training=is_training,
context_parallel_causal_load_balanced=load_balanced,
),
).astype(dtype)

def ref_func(q, k, v, mask, kv_groups):
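# Reference: unsharded flax dot_product_attention, with K/V heads repeated to match the number of query heads.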
q = jnp.squeeze(q)
k = jnp.squeeze(jnp.repeat(k, kv_groups, axis=2))
v = jnp.squeeze(jnp.repeat(v, kv_groups, axis=2))
output = dot_product_attention(
q,
k,
v,
bias=None,
mask=mask,
deterministic=is_training,
dropout_rate=dropout_prob,
dropout_rng=None,
dtype=jnp.float32,
)
return jnp.mean(output).astype(dtype)

q, k, v, mask = self.generate_inputs(data_shape, kv_groups, attn_mask_type, dtype)

# Single GPU (reference)
ref_func_jit = jax.jit(jax.value_and_grad(ref_func, argnums=[0, 1, 2]), static_argnums=[4])
ref_fwd, ref_grads = ref_func_jit(q, k, v, mask, kv_groups)

# Multi GPU (function under test)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource):
qkv_ps = PartitionSpec(
mesh_resource.dp_resource,
mesh_resource.cp_resource,
mesh_resource.tp_resource,
None,
)
qkv_sharding = NamedSharding(mesh, qkv_ps)

mask_ps = PartitionSpec(
mesh_resource.dp_resource, None, mesh_resource.cp_resource, None
)
mask_sharding = NamedSharding(mesh, mask_ps)

reorder = partial(
reorder_causal_load_balancing, cp_size=cp_size, tensor_format=qkv_format
)
inverse_reorder = partial(
inverse_reorder_causal_load_balancing, cp_size=cp_size, tensor_format=qkv_format
)

if load_balanced:
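# Pre-shuffle the sequence dimension with the causal load-balancing order so each CP rank receives a balanced workload.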
q, k, v = jax.tree.map(reorder, (q, k, v))

q_, k_, v_ = map(partial(jax.device_put, device=qkv_sharding), [q, k, v])
mask_ = jax.device_put(mask, device=mask_sharding)

target_func_jit = jax.jit(
jax.value_and_grad(target_func, argnums=[0, 1, 2]),
in_shardings=[qkv_sharding, qkv_sharding, qkv_sharding, mask_sharding],
out_shardings=(None, (qkv_sharding, qkv_sharding, qkv_sharding)),
)

target_fwd, target_grads = target_func_jit(q_, k_, v_, mask_)

if load_balanced:
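# Undo the load-balancing reorder on dQ/dK/dV so they can be compared against the reference gradients.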
target_dq, target_dk, target_dv = jax.tree.map(inverse_reorder, target_grads[0:3])
target_grads = (target_dq, target_dk, target_dv, *target_grads[3:])

def _print_diffs(target, ref):
print("min: ", jnp.min(target), jnp.min(ref))
print("max: ", jnp.max(target), jnp.max(ref))
print("mean: ", jnp.mean(target), jnp.mean(ref))
print("median: ", jnp.median(target), jnp.median(ref))
print("std: ", jnp.std(target), jnp.std(ref))
print("var: ", jnp.var(target), jnp.var(ref))
print("max diff: ", jnp.max(jnp.abs(target - ref)))

has_diffs = False

try:
assert_allclose(target_fwd, ref_fwd, dtype=dtype)
except AssertionError as e:
has_diffs = True
print(f"target_fwd v. ref_fwd")
_print_diffs(target_fwd, ref_fwd)

for i in range(len(target_grads)):
if ref_grads[i] is None or target_grads[i] is None:
# expect both none if one is
assert target_grads[i] is None and ref_grads[i] is None
else:
try:
assert_allclose(target_grads[i], ref_grads[i])
except AssertionError as e:
has_diffs = True
print(f"target_grads[{i}] v. ref_grads[{i}]")
_print_diffs(target_grads[i], ref_grads[i])

assert not has_diffs, "Found mismatches between the target and reference outputs"


class TestReorderCausalLoadBalancing:
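"""Round-trip check: reorder_causal_load_balancing followed by its inverse must reproduce the original tensor."""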
@pytest.mark.parametrize("cp_size", [2, 4, 8])
@pytest.mark.parametrize(
"shape",
[
pytest.param([1, 16, 1, 1], id="1-16-1-1"),
pytest.param([4, 32, 12, 32], id="4-32-12-32"),
pytest.param([3, 32, 8, 64], id="3-32-8-64"),
],
)
@pytest.mark.parametrize("qkv_format", [QKVFormat.BSHD, QKVFormat.SBHD])
def test(self, cp_size, shape, qkv_format):
tensor = random.normal(random.PRNGKey(1124), shape, dtype=jnp.bfloat16)
if qkv_format == QKVFormat.SBHD:
tensor = tensor.swapaxes(0, 1)

ref = tensor.copy()

reorder = jax.jit(reorder_causal_load_balancing, static_argnums=[1, 2])
inverse = jax.jit(inverse_reorder_causal_load_balancing, static_argnums=[1, 2])

reordered = reorder(tensor, cp_size, qkv_format)
inversed = inverse(reordered, cp_size, qkv_format)

assert jnp.array_equal(inversed, ref)