roll back fmha/common.py #5

Merged
merged 3 commits on Feb 29, 2024
131 changes: 3 additions & 128 deletions tests/test_mem_eff_attention.py
@@ -320,47 +320,6 @@ def T(t):
     return out.permute((0, 2, 1, 3))


-# this interface assumes the tensor is in BMHK, but q and k/v might have different number of heads
-def ref_attention_mqa(q, k, v, attn_bias=None, drop_mask=None, p=0.0, scale=None):
-    assert q.ndim == 4
-
-    B, M, Hq, K = q.shape
-    _, N, Hkv, Kv = v.shape
-    nhead_ratio_qk = Hq // Hkv
-
-    def attn_bias_head(head: int):
-        if isinstance(attn_bias, torch.Tensor):
-            assert attn_bias.ndim == 4
-            _, H, _, _ = attn_bias.shape
-            assert H == Hq
-            bias_bghmn = attn_bias.reshape(B, Hkv, nhead_ratio_qk, M, N)
-            return bias_bghmn[:, :, head]
-        if isinstance(attn_bias, fmha.attn_bias.LowerTriangularMaskWithTensorBias):
-            assert attn_bias._bias.ndim == 4
-            _, H, _, _ = attn_bias._bias.shape
-            assert H == Hq
-            bias_bghmn = attn_bias._bias.reshape(B, Hkv, nhead_ratio_qk, M, N)
-            return fmha.attn_bias.LowerTriangularMaskWithTensorBias(
-                bias_bghmn[:, :, head]
-            )
-        return attn_bias
-
-    q_bmghk = q.reshape((B, M, Hkv, nhead_ratio_qk, K))
-
-    return torch.stack(
-        [
-            ref_attention_bmhk(
-                q_bmghk[:, :, :, h],
-                k,
-                v,
-                attn_bias=attn_bias_head(h),
-            )
-            for h in range(q_bmghk.shape[3])
-        ],
-        dim=3,
-    ).reshape((B, M, Hq, Kv))
-
-
 def _rand_partition(r: random.Random, total: int, n: int) -> List[int]:
     # returns list of n nonnegative integers summing to total
     idx = {0, total}
@@ -571,92 +530,6 @@ def test_forward(opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, packed, fmt, **kwargs)
     )


-@rocm_only
-@pytest.mark.parametrize("hdim_k,hdim_v", [(64, 64), (128, 128)])
-@pytest.mark.parametrize("nhead_q,nhead_kv", [(8, 1), (8, 2), (12, 4), (4, 4)])
-@pytest.mark.parametrize("seqlen_q,seqlen_kv", [(100, 128), (128, 100), (200, 1000)])
-@pytest.mark.parametrize("batches", [100, 64, 1])
-@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
-@pytest.mark.parametrize(
-    "attn_bias_type", [type(None), torch.Tensor, fmha.attn_bias.LowerTriangularMask]
-)
-@pytest.mark.parametrize("op", [fmha.ck.FwOp])
-def test_mqa_forward(
-    op,
-    attn_bias_type,
-    dtype,
-    batches: int,
-    seqlen_kv: int,
-    seqlen_q: int,
-    nhead_kv: int,
-    nhead_q: int,
-    hdim_v: int,
-    hdim_k: int,
-):
-    B = batches
-    M = seqlen_q
-    N = seqlen_kv
-    Hq = nhead_q
-    Hkv = nhead_kv
-    K = hdim_k
-    Kv = hdim_v
-    nhead_ratio_qk = Hq // Hkv
-
-    device = torch.device("cuda")
-
-    torch.manual_seed(B * M + N * K + Hq * Hkv + Kv)
-
-    scale = 3
-    query = torch.randn((B, M, Hq, K), device=device, dtype=dtype).mul_(scale)
-    key = torch.randn((B, N, Hkv, K), device=device, dtype=dtype).mul_(scale)
-    value = torch.randn((B, N, Hkv, Kv), device=device, dtype=dtype).mul_(scale)
-
-    attn_bias = None
-    if attn_bias_type is not None:
-        attn_bias = create_attn_bias(
-            attn_bias_type,
-            batch_size=B,
-            num_heads=Hq,
-            num_heads_groups=nhead_ratio_qk,
-            q_len=M,
-            kv_len=N,
-            dtype=dtype,
-            device=device,
-            requires_grad=False,
-            fmt="BMHK",
-            op=op,
-        )
-
-    inputs = fmha.Inputs(query=query, key=key, value=value, attn_bias=attn_bias)
-    reasons = op.not_supported_reasons(inputs)
-    if reasons:
-        err_msg = f"{op.NAME}: unsupported ({'/'.join(reasons)})"
-        # Ensure we free memory to avoid OOMs
-        del query, key, value, attn_bias, inputs
-        assert False, err_msg
-
-    out = xformers.ops.memory_efficient_attention_forward(
-        query, key, value, attn_bias, op=op
-    )
-    assert not out.isnan().any(), ("Output has NaNs", attn_bias)
-    out2 = xformers.ops.memory_efficient_attention_forward(
-        query, key, value, attn_bias, op=op
-    )
-    assert torch.allclose(out, out2, atol=0.0, rtol=0.0), (
-        "Non-deterministic behavior",
-        attn_bias,
-    )
-
-    ref = ref_attention_mqa(query, key, value, attn_bias)
-    assert out.shape == ref.shape, out.shape
-    assert_allclose(
-        out.float(),
-        ref,
-        atol=op.ERROR_ATOL[dtype],
-        rtol=op.ERROR_RTOL.get(dtype, 1e-5),
-    )
-
-
 @cuda_only
 @pytest.mark.parametrize("k_len", [5, 6, 32])
 @pytest.mark.parametrize("batch_size", [1, 4])
@@ -2328,7 +2201,9 @@ def test_forward_splitk(

 @cuda_only
 @pytest.mark.parametrize(
-    "op", [fmha.triton_splitk.FwOp, fmha.flash.FwOp], ids=lambda op: op.NAME
+    "op",
+    [fmha.triton_splitk.FwOp, fmha.flash.FwOp, fmha.ck.FwOp],
+    ids=lambda op: op.NAME,
 )
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=str)
 @pytest.mark.parametrize(
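For context on what the removed reference path computed: ref_attention_mqa grouped the Hq query heads onto the Hkv key/value heads (requiring Hq % Hkv == 0) and ran the BMHK reference once per group. Below is a minimal standalone sketch of that head-grouping idea, using an illustrative einsum formulation rather than the deleted helper itself; the function name and the plain 1/sqrt(K) scaling are assumptions, not the test code.

import torch

def grouped_ref_attention(q, k, v):
    # q: (B, M, Hq, K); k: (B, N, Hkv, K); v: (B, N, Hkv, Kv); Hq must be a multiple of Hkv
    B, M, Hq, K = q.shape
    _, N, Hkv, Kv = v.shape
    g = Hq // Hkv  # query heads sharing one key/value head
    q = q.reshape(B, M, Hkv, g, K)  # group query heads by the kv head they attend with
    attn = torch.einsum("bmhgk,bnhk->bhgmn", q, k) * K ** -0.5
    attn = attn.softmax(dim=-1)
    out = torch.einsum("bhgmn,bnhv->bmhgv", attn, v)
    return out.reshape(B, M, Hq, Kv)

# Example: 8 query heads sharing 2 key/value heads (nhead_ratio_qk == 4)
q = torch.randn(2, 16, 8, 64)
k = torch.randn(2, 32, 2, 64)
v = torch.randn(2, 32, 2, 64)
print(grouped_ref_attention(q, k, v).shape)  # torch.Size([2, 16, 8, 64])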
6 changes: 2 additions & 4 deletions xformers/ops/fmha/common.py
@@ -180,13 +180,11 @@ def validate_inputs(self) -> None:
                 and self.value.shape == (B, Mkv, Kv)
             )
         H = self.query.shape[-2]
-        Hkv = self.key.shape[-2]
         if self.query.ndim == 4:  # BMHK
             valid_shapes = (
                 self.query.shape == (B, Mq, H, K)
-                and self.key.shape == (B, Mkv, Hkv, key_embed_dim)
-                and self.value.shape == (B, Mkv, Hkv, Kv)
-                and H % Hkv == 0
+                and self.key.shape == (B, Mkv, H, key_embed_dim)
+                and self.value.shape == (B, Mkv, H, Kv)
             )
         G = self.query.shape[2]
         if self.query.ndim == 5:  # BMNHK
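The net effect of the common.py rollback is that 4-D (BMHK) inputs must again use the same number of heads for query, key, and value; the earlier relaxation that allowed fewer key/value heads subject to H % Hkv == 0 is gone. Below is a simplified sketch of the shape check as it stands after this change; the standalone function and variable names are illustrative, and the real validate_inputs performs additional checks beyond this.

import torch

def bmhk_shapes_valid(query, key, value):
    # All three tensors must be 4-D, with key/value matching query's batch and head count.
    if not (query.ndim == key.ndim == value.ndim == 4):
        return False
    B, Mq, H, K = query.shape
    _, Mkv, _, Kv = value.shape
    key_embed_dim = key.shape[-1]
    return (
        key.shape == (B, Mkv, H, key_embed_dim)  # same head count H as query
        and value.shape == (B, Mkv, H, Kv)       # no H % Hkv == 0 grouping anymore
    )

# A GQA-style layout (8 query heads, 2 kv heads) no longer passes this check:
q = torch.randn(2, 16, 8, 64)
kv_grouped = torch.randn(2, 32, 2, 64)
kv_full = torch.randn(2, 32, 8, 64)
print(bmhk_shapes_valid(q, kv_grouped, kv_grouped))  # False
print(bmhk_shapes_valid(q, kv_full, kv_full))        # True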