From 9639307ce58f8732a944f2a6826a15743112e48e Mon Sep 17 00:00:00 2001 From: Matthew Wong Date: Thu, 19 Dec 2024 23:12:03 +0000 Subject: [PATCH] Dedupe FA/PA attn toggles, set FA off by default --- vllm/attention/backends/rocm_flash_attn.py | 4 +++- vllm/attention/layer.py | 2 -- vllm/attention/ops/triton_flash_attention.py | 2 +- vllm/envs.py | 13 ++++++++++--- vllm/model_executor/layers/quantization/kv_cache.py | 10 ++++++---- vllm/model_executor/models/llama.py | 6 +++--- 6 files changed, 23 insertions(+), 14 deletions(-) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index f1bb90550a045..5a146940765bb 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -686,7 +686,9 @@ def forward( full_scales = ( 1.0 / q_scale.item(), 1.0 / k_scale.item(), 1.0 / v_scale.item(), 1.0 / prob_scale.item(), - fp8_out_scale.item()) if fp8_out_scale else None + fp8_out_scale.item()) if ( + fp8_out_scale + and envs.VLLM_USE_ROCM_FP8_FLASH_ATTN) else None out, _ = self.attn_func( query, key, diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 65e0b714aad55..392736137aa12 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -42,7 +42,6 @@ def __init__( logits_soft_cap: Optional[float] = None, per_layer_sliding_window: Optional[int] = None, prefix: str = "", - use_fp8: bool = False, ) -> None: super().__init__() if per_layer_sliding_window is not None: @@ -74,7 +73,6 @@ def __init__( # with the model weights. self.kv_cache_dtype = kv_cache_dtype self.calculate_kv_scales = calculate_kv_scales - self.use_fp8 = use_fp8 self._k_scale = torch.tensor(1.0, dtype=torch.float32) self._v_scale = torch.tensor(1.0, dtype=torch.float32) self._q_scale = torch.tensor(1.0, dtype=torch.float32) diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py index 2e647a8e52278..828bdc2905957 100644 --- a/vllm/attention/ops/triton_flash_attention.py +++ b/vllm/attention/ops/triton_flash_attention.py @@ -742,7 +742,7 @@ def attn_fwd( mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) out_ptrs_mask = (mask_m_offsets[:, None] >= out_mask_boundary[None, :]) - z = 0.0 + z = tl.zeros((1, ), tl.float32) acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) # write back LSE # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m diff --git a/vllm/envs.py b/vllm/envs.py index 70b189be974ab..eddfa174ed55a 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -16,7 +16,8 @@ VLLM_USE_TRITON_FLASH_ATTN: bool = True VLLM_USE_ROCM_SKINNY_GEMM: bool = True VLLM_USE_ROCM_CUSTOM_PAGED_ATTN: bool = True - VLLM_USE_ROCM_FP8_ATTN: bool = True + VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT: bool = True + VLLM_USE_ROCM_FP8_FLASH_ATTN: bool = False RANK: int = 0 LOCAL_RANK: int = 0 CUDA_VISIBLE_DEVICES: Optional[str] = None @@ -246,8 +247,14 @@ def get_default_config_root(): ("true", "1")), # have custom paged attention implemented for MI3* cards write out fp8 - "VLLM_USE_ROCM_FP8_ATTN": - lambda: (os.getenv("VLLM_USE_ROCM_FP8_ATTN", "True").lower() in + "VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT": + lambda: + (os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT", "True").lower() in + ("true", "1")), + + # use quantized q,k,v,softmax(qk^T), attn output during prefill + "VLLM_USE_ROCM_FP8_FLASH_ATTN": + lambda: (os.getenv("VLLM_USE_ROCM_FP8_FLASH_ATTN", "False").lower() in ("true", "1")), # rank of the process in the distributed setting, used to determine diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py index bcb28193a4056..c8f9e86a9cc7c 100644 --- a/vllm/model_executor/layers/quantization/kv_cache.py +++ b/vllm/model_executor/layers/quantization/kv_cache.py @@ -1,5 +1,6 @@ import torch +import vllm.envs as envs from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.platforms import current_platform @@ -76,18 +77,18 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer._k_scale.copy_(k_scale) layer._v_scale.copy_(v_scale) if (k_scale == 1.0 and v_scale == 1.0 - and (layer.kv_cache_dtype != "auto" or layer.use_fp8) + and (layer.kv_cache_dtype != "auto" + or envs.VLLM_USE_ROCM_FP8_FLASH_ATTN) and "e5m2" not in layer.kv_cache_dtype): print_warning_once( "Using KV cache scaling factor 1.0 for fp8_e4m3. This " "may cause accuracy issues. Please make sure k/v_scale " "scaling factors are available in the fp8 checkpoint.") - if layer.q_scale > 0.0 and layer.prob_scale > 0.0: + if layer.q_scale > 0.0: q_scale = layer.q_scale.to("cpu").tolist() if current_platform.is_rocm() and not is_navi(): q_scale *= 2 - layer.calculate_kv_scales = False else: q_scale = 1.0 if layer.prob_scale > 0.0: @@ -104,7 +105,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # These are used in the final Attention.forward() layer._q_scale.copy_(q_scale) layer._prob_scale.copy_(prob_scale) - if (q_scale == 1.0 or prob_scale == 1.0) and layer.use_fp8: + if (q_scale == 1.0 + or prob_scale == 1.0) and envs.VLLM_USE_ROCM_FP8_FLASH_ATTN: print_warning_once( f"Using Q scale {q_scale} and prob scale {prob_scale} " "with fp8 attention. This may cause accuracy issues. " diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 69d63216d379d..760b6eecfa3bd 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -198,7 +198,7 @@ def __init__( sliding_window = None # For CUDA devices and Navi4x, attn_fp8 will be set to false. - self.attn_fp8 = envs.VLLM_USE_ROCM_FP8_ATTN \ + self.attn_fp8_out = envs.VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT \ and current_platform.is_rocm() \ and not is_navi() \ and isinstance(quant_config, Fp8Config) @@ -232,8 +232,8 @@ def forward( attn_metadata, fp8_comp_scales=(self.attn._q_scale, self.attn._prob_scale, - self.o_proj.input_scale) - if self.attn_fp8 else None) + self.o_proj.input_scale + if self.attn_fp8_out else None)) output, _ = self.o_proj(attn_output) return output