Dedupe FA/PA attn toggles, set FA off by default
mawong-amd committed Dec 19, 2024
1 parent 7908e9b commit 1ed1389
Showing 6 changed files with 23 additions and 15 deletions.
vllm/attention/backends/rocm_flash_attn.py (4 changes: 3 additions & 1 deletion)
@@ -686,7 +686,9 @@ def forward(
                 full_scales = (
                     1.0 / q_scale.item(), 1.0 / k_scale.item(),
                     1.0 / v_scale.item(), 1.0 / prob_scale.item(),
-                    fp8_out_scale.item()) if fp8_out_scale else None
+                    fp8_out_scale.item()) if (
+                        fp8_out_scale
+                        and envs.VLLM_USE_ROCM_FP8_FLASH_ATTN) else None
                 out, _ = self.attn_func(
                     query,
                     key,
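
For reference, a standalone sketch of the new gating (torch only; use_rocm_fp8_flash_attn stands in for envs.VLLM_USE_ROCM_FP8_FLASH_ATTN, and the scale tensors are made-up placeholders): full_scales collapses to None unless an fp8 output scale is present and the toggle is enabled.

import torch

use_rocm_fp8_flash_attn = False  # mirrors the new default of the env toggle
q_scale = k_scale = v_scale = prob_scale = torch.tensor(1.0)
fp8_out_scale = torch.tensor(0.5)  # stands in for a loaded output scale; may be None

# Same shape as the expression in forward(): reciprocal q/k/v/prob scales plus
# the raw output scale, or None when the fp8 flash-attention path is disabled.
full_scales = (
    1.0 / q_scale.item(), 1.0 / k_scale.item(),
    1.0 / v_scale.item(), 1.0 / prob_scale.item(),
    fp8_out_scale.item()) if (
        fp8_out_scale is not None
        and use_rocm_fp8_flash_attn) else None
print(full_scales)  # None while the toggle keeps its default value
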
vllm/attention/layer.py (2 changes: 0 additions & 2 deletions)
@@ -42,7 +42,6 @@ def __init__(
         logits_soft_cap: Optional[float] = None,
         per_layer_sliding_window: Optional[int] = None,
         prefix: str = "",
-        use_fp8: bool = False,
     ) -> None:
         super().__init__()
         if per_layer_sliding_window is not None:
@@ -74,7 +73,6 @@ def __init__(
         # with the model weights.
         self.kv_cache_dtype = kv_cache_dtype
         self.calculate_kv_scales = calculate_kv_scales
-        self.use_fp8 = use_fp8
         self._k_scale = torch.tensor(1.0, dtype=torch.float32)
         self._v_scale = torch.tensor(1.0, dtype=torch.float32)
         self._q_scale = torch.tensor(1.0, dtype=torch.float32)
vllm/attention/ops/triton_flash_attention.py (2 changes: 1 addition & 1 deletion)
@@ -742,7 +742,7 @@ def attn_fwd(
         mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
         out_ptrs_mask = (mask_m_offsets[:, None] >=
                          out_mask_boundary[None, :])
-        z = 0.0
+        z = tl.zeros((1, ), tl.float32)
         acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
         # write back LSE
         # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
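
For context, a minimal standalone sketch of the pattern this fix relies on (hypothetical kernel; assumes Triton and a CUDA/ROCm device are available): the fill value passed to tl.where is materialized with tl.zeros so it carries a Triton dtype and can be cast with .to(...).

import torch
import triton
import triton.language as tl

@triton.jit
def masked_fill_kernel(x_ptr, out_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    z = tl.zeros((1, ), tl.float32)        # typed zero, as in the change above
    x = tl.where(mask, x, z.to(x.dtype))   # cast the fill value to x's dtype
    tl.store(out_ptr + offs, x, mask=mask)

x = torch.randn(1000, device="cuda")
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 256), )
masked_fill_kernel[grid](x, out, x.numel(), BLOCK=256)
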
vllm/envs.py (13 changes: 10 additions & 3 deletions)
@@ -16,7 +16,8 @@
     VLLM_USE_TRITON_FLASH_ATTN: bool = True
     VLLM_USE_ROCM_SKINNY_GEMM: bool = True
     VLLM_USE_ROCM_CUSTOM_PAGED_ATTN: bool = True
-    VLLM_USE_ROCM_FP8_ATTN: bool = True
+    VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT: bool = True
+    VLLM_USE_ROCM_FP8_FLASH_ATTN: bool = False
     RANK: int = 0
     LOCAL_RANK: int = 0
     CUDA_VISIBLE_DEVICES: Optional[str] = None
@@ -246,8 +247,14 @@ def get_default_config_root():
     ("true", "1")),

     # have custom paged attention implemented for MI3* cards write out fp8
-    "VLLM_USE_ROCM_FP8_ATTN":
-    lambda: (os.getenv("VLLM_USE_ROCM_FP8_ATTN", "True").lower() in
+    "VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT":
+    lambda:
+    (os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT", "True").lower() in
     ("true", "1")),
+
+    # use quantized q,k,v,softmax(qk^T), attn output during prefill
+    "VLLM_USE_ROCM_FP8_FLASH_ATTN":
+    lambda: (os.getenv("VLLM_USE_ROCM_FP8_FLASH_ATTN", "False").lower() in
+    ("true", "1")),

     # rank of the process in the distributed setting, used to determine
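
A small standalone sketch of the truthiness convention these entries follow (plain os.getenv; the helper name env_flag is invented for illustration): only a case-insensitive "true" or "1" enables a toggle, so the new defaults leave the custom paged-attention fp8 output path on and fp8 flash attention off.

import os

def env_flag(name: str, default: str) -> bool:
    # Same convention as the lambdas above: "true"/"1" (case-insensitive) means on.
    return os.getenv(name, default).lower() in ("true", "1")

use_pa_fp8_out = env_flag("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT", "True")
use_fa_fp8 = env_flag("VLLM_USE_ROCM_FP8_FLASH_ATTN", "False")
print(use_pa_fp8_out, use_fa_fp8)  # True False unless overridden in the environment
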
vllm/model_executor/layers/quantization/kv_cache.py (10 changes: 6 additions & 4 deletions)
@@ -1,5 +1,6 @@
 import torch

+import vllm.envs as envs
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.platforms import current_platform
@@ -76,18 +77,18 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         layer._k_scale.copy_(k_scale)
         layer._v_scale.copy_(v_scale)
         if (k_scale == 1.0 and v_scale == 1.0
-                and (layer.kv_cache_dtype != "auto" or layer.use_fp8)
+                and (layer.kv_cache_dtype != "auto"
+                     or envs.VLLM_USE_ROCM_FP8_FLASH_ATTN)
                 and "e5m2" not in layer.kv_cache_dtype):
             print_warning_once(
                 "Using KV cache scaling factor 1.0 for fp8_e4m3. This "
                 "may cause accuracy issues. Please make sure k/v_scale "
                 "scaling factors are available in the fp8 checkpoint.")

-        if layer.q_scale > 0.0 and layer.prob_scale > 0.0:
+        if layer.q_scale > 0.0:
             q_scale = layer.q_scale.to("cpu").tolist()
             if current_platform.is_rocm() and not is_navi():
                 q_scale *= 2
-            layer.calculate_kv_scales = False
         else:
             q_scale = 1.0
         if layer.prob_scale > 0.0:
@@ -104,7 +105,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         # These are used in the final Attention.forward()
         layer._q_scale.copy_(q_scale)
         layer._prob_scale.copy_(prob_scale)
-        if (q_scale == 1.0 or prob_scale == 1.0) and layer.use_fp8:
+        if (q_scale == 1.0
+                or prob_scale == 1.0) and envs.VLLM_USE_ROCM_FP8_FLASH_ATTN:
             print_warning_once(
                 f"Using Q scale {q_scale} and prob scale {prob_scale} "
                 "with fp8 attention. This may cause accuracy issues. "
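
To illustrate the reworked q_scale branch above, a standalone sketch with hypothetical values (checkpoint_q_scale is invented; current_platform.is_rocm() and is_navi() are replaced by plain booleans): a positive checkpoint q_scale is now used on its own, without requiring prob_scale, and is doubled on ROCm non-Navi devices.

import torch

checkpoint_q_scale = torch.tensor(0.25)  # hypothetical value loaded from an fp8 checkpoint
on_rocm, on_navi = True, False           # stand-ins for current_platform.is_rocm() / is_navi()

if checkpoint_q_scale > 0.0:
    q_scale = checkpoint_q_scale.to("cpu").tolist()
    if on_rocm and not on_navi:
        q_scale *= 2                     # doubled on ROCm non-Navi, as in the diff
else:
    q_scale = 1.0
print(q_scale)  # 0.5 under these assumptions
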
vllm/model_executor/models/llama.py (7 changes: 3 additions & 4 deletions)
@@ -198,7 +198,7 @@ def __init__(
             sliding_window = None

         # For CUDA devices and Navi4x, attn_fp8 will be set to false.
-        self.attn_fp8 = envs.VLLM_USE_ROCM_FP8_ATTN \
+        self.attn_fp8_out = envs.VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT \
             and current_platform.is_rocm() \
             and not is_navi() \
             and isinstance(quant_config, Fp8Config)
@@ -212,7 +212,6 @@ def __init__(
             quant_config=quant_config,
             per_layer_sliding_window=sliding_window,
             prefix=f"{prefix}.attn",
-            use_fp8=self.attn_fp8,
         )

     def forward(
@@ -232,8 +231,8 @@ def forward(
                                 attn_metadata,
                                 fp8_comp_scales=(self.attn._q_scale,
                                                  self.attn._prob_scale,
-                                                 self.o_proj.input_scale)
-                                if self.attn_fp8 else None)
+                                                 self.o_proj.input_scale
+                                                 if self.attn_fp8_out else None))
         output, _ = self.o_proj(attn_output)
         return output

Check failure on line 235 in vllm/model_executor/models/llama.py: GitHub Actions / ruff (3.12), Ruff (E501): vllm/model_executor/models/llama.py:235:81: Line too long (81 > 80)
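
As a toy illustration of the tuple built above (plain Python floats; the variable names are invented): the q and prob scales are always forwarded, and only the o_proj input scale is gated on attn_fp8_out.

q_scale, prob_scale, o_proj_input_scale = 0.5, 0.25, 0.125

for attn_fp8_out in (True, False):
    fp8_comp_scales = (q_scale,
                       prob_scale,
                       o_proj_input_scale
                       if attn_fp8_out else None)
    print(attn_fp8_out, fp8_comp_scales)
# True  -> (0.5, 0.25, 0.125)
# False -> (0.5, 0.25, None)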
