diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py
index 3e3c0668198ad..df68de908eb31 100644
--- a/tests/kernels/test_attention.py
+++ b/tests/kernels/test_attention.py
@@ -15,6 +15,8 @@
     from xformers import ops as xops
     from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
 
+    from vllm.attention.backends.xformers import _make_alibi_bias
+
 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
 # This will change depending on the compute capability.
 # - 512 as a buffer
@@ -337,20 +339,26 @@ def ref_multi_query_kv_attention(
     key: torch.Tensor,
     value: torch.Tensor,
     scale: float,
+    alibi_bias: Optional[List[torch.Tensor]],
     dtype: torch.dtype,
 ) -> torch.Tensor:
     num_seqs = len(cu_seq_lens) - 1
     ref_outputs: List[torch.Tensor] = []
+    if alibi_bias:
+        assert len(alibi_bias) == num_seqs
     for i in range(num_seqs):
         start_idx = cu_seq_lens[i]
         end_idx = cu_seq_lens[i + 1]
         seq_len = end_idx - start_idx
 
-        # Create attention mask.
-        attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
-                               diagonal=1)
-        attn_mask = attn_mask * torch.finfo(dtype).min
-        attn_mask = attn_mask.to(dtype=dtype)
+        # Create attention mask. ALiBi already includes a tril causal mask.
+        if alibi_bias:
+            attn_mask = alibi_bias[i]
+        else:
+            attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
+                                   diagonal=1)
+            attn_mask = attn_mask * torch.finfo(dtype).min
+            attn_mask = attn_mask.to(dtype=dtype)
 
         ref_output = ref_masked_attention(
             query[start_idx:end_idx],
@@ -364,10 +372,10 @@ def ref_multi_query_kv_attention(
     return torch.cat(ref_outputs, dim=0)
 
 
-# TODO(woosuk): Add tests for USE_ALIBI=True.
 @pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("use_alibi", USE_ALIBI)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
@@ -378,6 +386,7 @@ def test_multi_query_kv_attention(
     num_seqs: int,
     num_heads: Tuple[int, int],
     head_size: int,
+    use_alibi: bool,
    dtype: torch.dtype,
     seed: int,
     device: str,
@@ -406,16 +415,40 @@ def test_multi_query_kv_attention(
     # Handle MQA and GQA
     key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
     value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
-    attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
-    output = xops.memory_efficient_attention_forward(
-        query.unsqueeze(0),
-        key.unsqueeze(0),
-        value.unsqueeze(0),
-        attn_bias=attn_bias,
-        p=0.0,
-        scale=scale,
-    )
-    output = output.squeeze(0)
+    alibi_bias = None
+    if use_alibi:
+        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
+        attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype,
+                                     seq_lens)
+        output = torch.empty_like(query)
+        start = 0
+        # Dynamic sequence length not supported with custom attn_bias.
+        for i, seq_len in enumerate(seq_lens):
+            end = start + seq_len
+            out = xops.memory_efficient_attention_forward(
+                query[None, start:end],
+                key[None, start:end],
+                value[None, start:end],
+                attn_bias=attn_bias[i],
+                p=0.0,
+                scale=scale)
+            output[start:end].copy_(out.view_as(query[start:end]))
+            start += seq_len
+        # xformers.AttentionBias to Tensor for use in reference impl.
+        alibi_bias = [
+            b.materialize(b.shape, device=device).squeeze() for b in attn_bias
+        ]
+    else:
+        attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
+        output = xops.memory_efficient_attention_forward(
+            query.unsqueeze(0),
+            key.unsqueeze(0),
+            value.unsqueeze(0),
+            attn_bias=attn_bias,
+            p=0.0,
+            scale=scale,
+        )
+        output = output.squeeze(0)
 
     cu_seq_lens = [0]
     for seq_len in seq_lens:
@@ -426,6 +459,7 @@ def test_multi_query_kv_attention(
         key,
         value,
         scale,
+        alibi_bias,
         dtype,
     )
     atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py
index 3e59b3603d2c6..951b01b2e4237 100644
--- a/vllm/attention/backends/xformers.py
+++ b/vllm/attention/backends/xformers.py
@@ -779,8 +779,6 @@ def _make_alibi_bias(
             dtype=dtype,
         )[:, :, :, :seq_len].copy_(bias)
         bias.mul_(alibi_slopes[:, None, None])
-        if num_heads != num_kv_heads:
-            bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
         attn_biases.append(LowerTriangularMaskWithTensorBias(bias))
 
     return attn_biases
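For reference, the dense mask that the ALiBi path materializes above combines a per-head, slope-scaled distance term with a causal mask, which is why the reference implementation skips the extra triu mask when alibi_bias is set. Below is a minimal sketch of that tensor, assuming a hypothetical helper name make_dense_alibi_bias; it is not the vLLM _make_alibi_bias, which instead returns xformers LowerTriangularMaskWithTensorBias objects that the test materializes per sequence.

# Hypothetical sketch, not vLLM code: dense ALiBi bias plus causal mask for
# one sequence, comparable to materializing LowerTriangularMaskWithTensorBias.
import torch

def make_dense_alibi_bias(alibi_slopes: torch.Tensor, seq_len: int,
                          dtype: torch.dtype) -> torch.Tensor:
    # Distance term: entry [i, j] = j - i, scaled per head by its ALiBi slope.
    pos = torch.arange(seq_len, dtype=dtype)
    bias = (pos[None, :] - pos[:, None])[None, :, :] * alibi_slopes[:, None, None]
    # Causal part: keys at positions j > i are masked out with -inf.
    causal = torch.triu(
        torch.full((seq_len, seq_len), float("-inf"), dtype=dtype), diagonal=1)
    return bias + causal  # shape: (num_heads, seq_len, seq_len)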