Commit b5a161f: revert flash attn hacks
tlrmchlsmth committed Nov 1, 2024 (1 parent: 1f832ba)

Showing 2 changed files with 43 additions and 6 deletions.
25 changes: 22 additions & 3 deletions vllm/attention/backends/flash_attn.py
@@ -575,7 +575,7 @@ def forward(
         assert k_scale == 1.0 and v_scale == 1.0, (
             "key/v_scale is not supported in FlashAttention.")
 
-        output = unified_flash_attention(
+        output = torch.ops.vllm.unified_flash_attention(
             query,
             key,
             value,
@@ -595,8 +595,8 @@ def forward(
         return output
 
 
-#@torch.library.custom_op("vllm::unified_flash_attention",
-#                         mutates_args=["kv_cache"])
+@torch.library.custom_op("vllm::unified_flash_attention",
+                         mutates_args=["kv_cache"])
 def unified_flash_attention(
     query: torch.Tensor,
     key: torch.Tensor,
@@ -754,3 +754,22 @@ def unified_flash_attention(
     output = torch.cat([prefill_output, decode_output], dim=0)
     return output.view(num_tokens, hidden_size)
 
+
+@unified_flash_attention.register_fake
+def _(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    num_heads: int,
+    head_size: int,
+    num_kv_heads: int,
+    kv_cache: torch.Tensor,
+    kv_cache_dtype: str,
+    k_scale: float,
+    v_scale: float,
+    softmax_scale: float,
+    window_size: Optional[List[int]] = None,
+    alibi_slopes: Optional[torch.Tensor] = None,
+    logits_soft_cap: Optional[float] = None,
+) -> torch.Tensor:
+    return torch.empty_like(query)
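
The change in both files follows the same pattern: unified_flash_attention is registered as a PyTorch custom op, a fake implementation is registered so the compiler can infer the output's shape and dtype without running the kernel, and call sites go through torch.ops.vllm.unified_flash_attention instead of calling the Python function directly. A minimal standalone sketch of that pattern, using a toy op in a made-up "mylib" namespace rather than the vllm code (torch.library.custom_op requires PyTorch 2.4 or newer):

import torch

@torch.library.custom_op("mylib::scaled_add", mutates_args=())
def scaled_add(x: torch.Tensor, y: torch.Tensor, scale: float) -> torch.Tensor:
    # Real implementation: runs whenever the op is actually executed.
    return x + scale * y

@scaled_add.register_fake
def _(x: torch.Tensor, y: torch.Tensor, scale: float) -> torch.Tensor:
    # Fake (meta) implementation: only describes the output's shape and dtype,
    # mirroring the register_fake stub added in this commit.
    return torch.empty_like(x)

# Call sites use the dispatcher-visible handle, as the commit does with
# torch.ops.vllm.unified_flash_attention.
out = torch.ops.mylib.scaled_add(torch.ones(4), torch.ones(4), 2.0)
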
24 changes: 21 additions & 3 deletions vllm/v1/attention/backends/flash_attn.py
@@ -133,7 +133,7 @@ def forward(
         assert k_scale == 1.0 and v_scale == 1.0, (
             "key/v_scale is not supported in FlashAttention.")
 
-        output = unified_flash_attention(
+        output = torch.ops.vllm.unified_flash_attention(
             query,
             key,
             value,
@@ -152,8 +152,8 @@ def forward(
         return output
 
 
-#@torch.library.custom_op("vllm::unified_flash_attention",
-#                         mutates_args=["kv_cache"])
+@torch.library.custom_op("vllm::unified_flash_attention",
+                         mutates_args=["kv_cache"])
 def unified_flash_attention(
     query: torch.Tensor,
     key: torch.Tensor,
@@ -217,3 +217,21 @@ def unified_flash_attention(
     return output.view(num_tokens, hidden_size)
 
 
+@unified_flash_attention.register_fake
+def _(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    num_heads: int,
+    head_size: int,
+    num_kv_heads: int,
+    kv_cache: torch.Tensor,
+    kv_cache_dtype: str,
+    k_scale: float,
+    v_scale: float,
+    softmax_scale: float,
+    window_size: Optional[List[int]] = None,
+    alibi_slopes: Optional[torch.Tensor] = None,
+    logits_soft_cap: Optional[float] = None,
+) -> torch.Tensor:
+    return torch.empty_like(query)
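
One detail worth calling out is mutates_args=["kv_cache"]: the attention kernel writes new keys and values into the KV cache in place, and declaring that side effect is what stops the compiler from reordering or eliminating the call. A self-contained toy analogue of a custom op that mutates one of its arguments (hypothetical names, not the vllm kernel):

import torch

@torch.library.custom_op("mylib::write_slot", mutates_args=["cache"])
def write_slot(cache: torch.Tensor, new: torch.Tensor, slot: int) -> torch.Tensor:
    # In-place update of the cache argument, loosely analogous to the KV-cache
    # write inside unified_flash_attention.
    cache[slot].copy_(new)
    return new.clone()

@write_slot.register_fake
def _(cache: torch.Tensor, new: torch.Tensor, slot: int) -> torch.Tensor:
    # The fake impl only reports output metadata; the mutation itself is
    # declared via mutates_args above.
    return torch.empty_like(new)

cache = torch.zeros(4, 3)
out = torch.ops.mylib.write_slot(cache, torch.ones(3), 2)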
