Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Re-applying G42 bias triton fix on 0.4.3 #41

Merged
merged 3 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 79 additions & 8 deletions vllm/attention/backends/rocm_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,62 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
)
return self._cached_decode_metadata

def _make_alibi_bias(
alibi_slopes: torch.Tensor,
dtype: torch.dtype,
seq_lens: List[int],
) -> List[torch.Tensor]:
attn_biases = []
for seq_len in seq_lens:
bias = torch.arange(seq_len, dtype=dtype)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(seq_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
bias = bias[None, :] - bias[:, None]

num_heads = alibi_slopes.shape[0]
bias = bias[None, :].repeat((num_heads, 1, 1)).to(alibi_slopes.device)
bias.mul_(alibi_slopes[:, None, None])
inf_mask = torch.empty(
(1, seq_len, seq_len),
dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to(alibi_slopes.device)
attn_biases.append((bias + inf_mask).to(dtype))

return attn_biases


def _make_alibi_bias_v2(
alibi_slopes: torch.Tensor,
dtype: torch.dtype,
seq_lens: List[int],
make_attn_mask: bool = True
) -> List[torch.Tensor]:
attn_biases = []
for seq_len in seq_lens:
bias = torch.arange(seq_len, dtype=dtype)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(seq_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
bias = bias[None, :] - bias[:, None]

num_heads = alibi_slopes.shape[0]
bias = bias[None, :].repeat((num_heads, 1, 1)).to(alibi_slopes.device)
bias.mul_(alibi_slopes[:, None, None])
if make_attn_mask:
inf_mask = torch.empty(
(1, seq_len, seq_len),
dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to(alibi_slopes.device)
attn_biases.append((bias + inf_mask).to(dtype))
else:
attn_biases.append(bias.to(dtype))

return attn_biases



class ROCmFlashAttentionImpl(AttentionImpl):
"""
Expand Down Expand Up @@ -324,7 +380,12 @@ def forward(
# triton attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
att_masks = None
if self.use_triton_flash_attn:
if self.alibi_slopes is not None:
att_masks = _make_alibi_bias_v2(
self.alibi_slopes, query.dtype,
attn_metadata.seq_lens, make_attn_mask=False) # type: ignore
out, _ = self.attn_func(
query,
key,
Expand All @@ -336,8 +397,13 @@ def forward(
prefill_meta.max_prefill_seq_len,
True,
self.scale,
att_masks[0][None] if att_masks is not None else None,
)
elif self.use_naive_attn:
if self.alibi_slopes is not None:
att_masks = _make_alibi_bias_v2(
self.alibi_slopes, query.dtype,
attn_metadata.seq_lens, make_attn_mask=True) # type: ignore
if self.num_kv_heads != self.num_heads:
# Interleave for MQA workaround.
key = self.repeat_kv(key, self.num_queries_per_kv)
Expand All @@ -348,6 +414,7 @@ def forward(
value,
prefill_meta.seq_lens,
self.scale,
att_masks
)
else:
out = self.attn_func(
Expand Down Expand Up @@ -408,16 +475,18 @@ def _naive_attention(
value: torch.Tensor,
seq_lens: List[int],
scale: float,
attn_masks: Optional[List[torch.Tensor]],
) -> torch.Tensor:
output = torch.empty_like(query)
start = 0
for _, seq_len in enumerate(seq_lens):
for i, seq_len in enumerate(seq_lens):
end = start + seq_len
out = _naive_masked_attention(
query[start:end],
key[start:end],
value[start:end],
scale,
attn_masks[i],
)
# TODO(woosuk): Unnecessary copy. Optimize.
output[start:end].copy_(out)
Expand All @@ -431,16 +500,18 @@ def _naive_masked_attention(
key: torch.Tensor,
value: torch.Tensor,
scale: float,
attn_mask: Optional[torch.Tensor],
) -> torch.Tensor:
seq_len, head_size, head_dim = query.shape
attn_mask = torch.triu(torch.ones(seq_len,
seq_len,
dtype=query.dtype,
device=query.device),
diagonal=1)
attn_mask = attn_mask * torch.finfo(query.dtype).min
if attn_mask is None:
attn_mask = torch.triu(torch.ones(seq_len,
seq_len,
dtype=query.dtype,
device=query.device),
diagonal=1)
attn_mask = attn_mask * torch.finfo(query.dtype).min
attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
attn_weights = attn_weights + attn_mask.float()
attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
out = torch.einsum("hqk,khd->qhd", attn_weights, value)
return out
return out
Loading
Loading