From 9ba2fab2748c06f34816185118bb79b71544e790 Mon Sep 17 00:00:00 2001 From: Matthew Wong Date: Thu, 19 Dec 2024 01:07:21 +0000 Subject: [PATCH 1/7] Ingest FP8 attn scales and use them in Triton FA, if present --- vllm/attention/backends/abstract.py | 2 +- vllm/attention/backends/blocksparse_attn.py | 2 +- vllm/attention/backends/flash_attn.py | 2 +- vllm/attention/backends/flashinfer.py | 2 +- vllm/attention/backends/hpu_attn.py | 2 +- vllm/attention/backends/ipex_attn.py | 2 +- vllm/attention/backends/pallas.py | 2 +- vllm/attention/backends/rocm_flash_attn.py | 10 ++++- vllm/attention/backends/torch_sdpa.py | 2 +- vllm/attention/backends/xformers.py | 2 +- vllm/attention/layer.py | 14 ++++--- vllm/envs.py | 11 +++--- .../quantization/compressed_tensors/utils.py | 4 ++ .../layers/quantization/kv_cache.py | 39 +++++++++++++++++-- vllm/model_executor/models/aria.py | 4 +- vllm/model_executor/models/exaone.py | 4 +- vllm/model_executor/models/granite.py | 4 +- vllm/model_executor/models/llama.py | 24 ++++++++---- vllm/model_executor/models/solar.py | 4 +- 19 files changed, 99 insertions(+), 37 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 3127285df6c63..de34c0da55b4f 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -252,6 +252,6 @@ def forward( v_scale: torch.Tensor, attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, - fp8_out_scale: Optional[torch.Tensor] = None, + fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index 1869dbab0cbf5..e6f63a1ff0fc9 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -363,7 +363,7 @@ def forward( v_scale: torch.Tensor, attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, - fp8_out_scale: Optional[torch.Tensor] = None, + fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index c640de998149e..94a7478f8eb51 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -642,7 +642,7 @@ def forward( v_scale: float = 1.0, attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, - fp8_out_scale: Optional[torch.Tensor] = None, + fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None, ) -> torch.Tensor: """Forward pass with FlashAttention. 
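This patch replaces the single fp8_out_scale argument with an fp8_comp_scales tuple in every backend's forward() signature. A minimal sketch of the shared convention, with illustrative helper names that are not vLLM API (note that k_scale and v_scale still travel as separate forward() arguments; the tuple only carries the new Q, softmax-probability, and output scales):

    from typing import Optional, Tuple
    import torch

    Fp8CompScales = Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]

    def pack_scales(q_scale: torch.Tensor, prob_scale: torch.Tensor,
                    out_scale: Optional[torch.Tensor]) -> Fp8CompScales:
        # Order matters and matches the unpacking in rocm_flash_attn.py:
        # (q_scale, prob_scale, fp8_out_scale).
        return (q_scale, prob_scale, out_scale)

    def unpack_scales(scales: Optional[Fp8CompScales]):
        # Backends that do not implement FP8 attention accept the argument
        # and simply never read it.
        q_scale, prob_scale, fp8_out_scale = scales or (None, None, None)
        return q_scale, prob_scale, fp8_out_scale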
diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 2285f20c1c8af..5301474833163 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -777,7 +777,7 @@ def forward( v_scale: float = 1.0, attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, - fp8_out_scale: Optional[torch.Tensor] = None, + fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None, ) -> torch.Tensor: # TODO: directly write to output tensor diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index b3065495ab396..efb3ac980aaba 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -154,7 +154,7 @@ def forward( v_scale: float = 1.0, attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, - fp8_out_scale: Optional[torch.Tensor] = None, + fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index d02fbcb6ca0ae..0225554b0ceb1 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -174,7 +174,7 @@ def forward( v_scale: float = 1.0, attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, - fp8_out_scale: Optional[torch.Tensor] = None, + fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None, ) -> torch.Tensor: """Forward pass with IPEX varlen_attention and PagedAttention. diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index a2612c97ca23b..7c702291513d2 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -152,7 +152,7 @@ def forward( v_scale: float = 1.0, attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, - fp8_out_scale: Optional[torch.Tensor] = None, + fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None, ) -> torch.Tensor: """Forward pass with Pallas attention. diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index efaa74f67bafd..335bae56f6ab4 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -551,7 +551,7 @@ def forward( v_scale: torch.Tensor, attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, - fp8_out_scale: torch.Tensor = None, + fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. 
@@ -601,6 +601,8 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ + q_scale, prob_scale, fp8_out_scale = fp8_comp_scales or (None, None, + None) query = query.view(-1, self.num_heads, self.head_size) if key is not None: @@ -681,6 +683,10 @@ def forward( query.dtype, seq_lens, make_attn_mask=False) # type: ignore + full_scales = ( + 1.0 / q_scale.item(), 1.0 / k_scale.item(), + 1.0 / v_scale.item(), 1.0 / prob_scale.item(), + fp8_out_scale.item()) if fp8_comp_scales else None out, _ = self.attn_func( query, key, @@ -694,7 +700,7 @@ def forward( self.scale, attn_masks[0][None] if attn_masks is not None else None, - None, + full_scales, ) elif self.use_naive_attn: if self.num_kv_heads != self.num_heads: diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 8e586c00024d5..660408f3ea477 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -434,7 +434,7 @@ def forward( v_scale: float = 1.0, attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, - fp8_out_scale: Optional[torch.Tensor] = None, + fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None, ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index dcd60da43d520..d0f1034dc966f 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -420,7 +420,7 @@ def forward( v_scale: float = 1.0, attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, - fp8_out_scale: Optional[torch.Tensor] = None, + fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 800a1ff2f4f65..d0d981e04db15 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -1,5 +1,5 @@ """Attention layer.""" -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple import torch import torch.nn as nn @@ -42,6 +42,7 @@ def __init__( logits_soft_cap: Optional[float] = None, per_layer_sliding_window: Optional[int] = None, prefix: str = "", + use_fp8: bool = False, ) -> None: super().__init__() if per_layer_sliding_window is not None: @@ -73,8 +74,11 @@ def __init__( # with the model weights. self.kv_cache_dtype = kv_cache_dtype self.calculate_kv_scales = calculate_kv_scales + self.use_fp8 = use_fp8 self._k_scale = torch.tensor(1.0, dtype=torch.float32) self._v_scale = torch.tensor(1.0, dtype=torch.float32) + self._q_scale = torch.tensor(1.0, dtype=torch.float32) + self._prob_scale = torch.tensor(1.0, dtype=torch.float32) quant_method = quant_config.get_quant_method( self, prefix=prefix) if quant_config else None if quant_method is not None: @@ -106,11 +110,11 @@ def __init__( self.num_kv_heads = num_kv_heads self.backend = backend_name_to_enum(attn_backend.get_name()) - # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how + # For cuda and cpu platforms, we control how # torch.compile works by registering the attention as one giant # opaque custom op. For other platforms, we directly call them # and let torch.compile handle them. 
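For orientation, the full_scales tuple built in the rocm_flash_attn.py hunk above is all reciprocals except the last entry: vLLM stores per-tensor scales as dequantization factors (fp16 is roughly fp8 times scale), so the Triton kernel receives 1/scale to quantize Q, K, V and the softmax probabilities, plus the raw scale for its fp8 output. A condensed sketch, with a None guard borrowed from the tightened check added later in this series rather than this exact hunk:

    from typing import Optional, Tuple
    import torch

    def build_full_scales(q_scale: torch.Tensor, k_scale: torch.Tensor,
                          v_scale: torch.Tensor, prob_scale: torch.Tensor,
                          fp8_out_scale: Optional[torch.Tensor]
                          ) -> Optional[Tuple[float, ...]]:
        # Without an output scale the kernel falls back to unquantized math.
        if fp8_out_scale is None:
            return None
        return (1.0 / q_scale.item(), 1.0 / k_scale.item(),
                1.0 / v_scale.item(), 1.0 / prob_scale.item(),
                fp8_out_scale.item())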
- self.use_direct_call = not current_platform.is_cuda_alike( + self.use_direct_call = not current_platform.is_cuda( ) and not current_platform.is_cpu() # For some attention backends, we allocate an output tensor before @@ -135,7 +139,7 @@ def forward( kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, attn_type: str = AttentionType.DECODER, - fp8_out_scale: Optional[torch.Tensor] = None, + fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None, ) -> torch.Tensor: if self.calculate_kv_scales and \ attn_metadata.enable_kv_scales_calculation: @@ -150,7 +154,7 @@ def forward( self._k_scale, self._v_scale, attn_type=attn_type, - fp8_out_scale=fp8_out_scale) + fp8_comp_scales=fp8_comp_scales) elif self.use_output: output = torch.empty_like(query) hidden_size = query.size(-1) diff --git a/vllm/envs.py b/vllm/envs.py index 19e520691e436..03646192e12d7 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -16,7 +16,7 @@ VLLM_USE_TRITON_FLASH_ATTN: bool = True VLLM_USE_ROCM_SKINNY_GEMM: bool = True VLLM_USE_ROCM_CUSTOM_PAGED_ATTN: bool = True - VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT: bool = True + VLLM_USE_ROCM_FP8_ATTN: bool = True RANK: int = 0 LOCAL_RANK: int = 0 CUDA_VISIBLE_DEVICES: Optional[str] = None @@ -242,13 +242,12 @@ def get_default_config_root(): # custom paged attention implemented for MI3* cards "VLLM_USE_ROCM_CUSTOM_PAGED_ATTN": lambda: (os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in - ("true", "1") != "0"), + ("true", "1")), # have custom paged attention implemented for MI3* cards write out fp8 - "VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT": - lambda: - (os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT", "True").lower() in - ("true", "1") != "0"), + "VLLM_USE_ROCM_FP8_ATTN": + lambda: (os.getenv("VLLM_USE_ROCM_FP8_ATTN", "True").lower() in + ("true", "1")), # rank of the process in the distributed setting, used to determine # the driver worker diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py index a74eaef5efdee..c474dcd0c5246 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py @@ -146,6 +146,10 @@ def get_compressed_tensors_cache_scale(name: str) -> Optional[str]: return name.replace(".k_proj.output_scale", ".attn.k_scale") if name.endswith(".output_scale") and ".v_proj" in name: return name.replace(".v_proj.output_scale", ".attn.v_scale") + if name.endswith(".output_scale") and ".q_proj" in name: + return name.replace(".q_proj.output_scale", ".attn.q_scale") + if name.endswith("self_attn.prob_output_scale"): + return name.replace(".prob_output_scale", ".attn.prob_scale") # If no matches, return None return None diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py index b386a9f309639..4502c8bea47c4 100644 --- a/vllm/model_executor/layers/quantization/kv_cache.py +++ b/vllm/model_executor/layers/quantization/kv_cache.py @@ -31,17 +31,20 @@ def create_weights(self, layer: torch.nn.Module): requires_grad=False) layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False) + # Initialize Q and P = softmax(QK^T) scales + layer.q_scale = torch.nn.Parameter(torch.tensor(-1.0), + requires_grad=False) + layer.prob_scale = torch.nn.Parameter(torch.tensor(-1.0), + requires_grad=False) def apply(self, layer: torch.nn.Module) -> torch.Tensor: raise RuntimeError( 
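On the envs.py cleanup above: the old expressions were not functionally wrong, just misleading. Python chains comparison operators, so the trailing != "0" was dead code that always evaluated to True, and dropping it preserves behavior. A quick standalone demonstration:

    # `x in t != "0"` parses as `(x in t) and (t != "0")`; the second clause
    # compares a tuple with a string and is therefore always True.
    for flag in ("true", "1", "false", "0"):
        old_style = flag.lower() in ("true", "1") != "0"
        new_style = flag.lower() in ("true", "1")
        assert old_style == new_style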
f"{self.__class__.__name__}.apply should not be called.") def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0 - # regardless whether the kv-scale is available in the checkpoint. # No need to process kv scales after loading if we are going to # calculate them on the fly. - if layer.kv_cache_dtype != "auto" and not layer.calculate_kv_scales: + if not layer.calculate_kv_scales: if layer.k_scale > 0.0 and layer.v_scale > 0.0: # We prefer to use separate k_scale and v_scale if present k_scale = layer.k_scale.to("cpu").tolist() @@ -75,11 +78,41 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer._k_scale.copy_(k_scale) layer._v_scale.copy_(v_scale) if (k_scale == 1.0 and v_scale == 1.0 + and (layer.kv_cache_dtype != "auto" or layer.use_fp8) and "e5m2" not in layer.kv_cache_dtype): print_warning_once( "Using KV cache scaling factor 1.0 for fp8_e4m3. This " "may cause accuracy issues. Please make sure k/v_scale " "scaling factors are available in the fp8 checkpoint.") + if layer.q_scale > 0.0 and layer.prob_scale > 0.0: + q_scale = layer.q_scale.to("cpu").tolist() + if current_platform.is_rocm() and not is_navi(): + q_scale *= 2 + else: + q_scale = 1.0 + if layer.prob_scale > 0.0: + prob_scale = layer.prob_scale.to("cpu").tolist() + if current_platform.is_rocm() and not is_navi(): + prob_scale *= 2 + else: + prob_scale = 1.0 + + if not isinstance(q_scale, float) or not isinstance(prob_scale, float): + raise ValueError("Only support per-tensor scaling factor" + "for fp8-quantized Q/prob") + + # These are used in the final Attention.forward() + layer._q_scale.copy_(q_scale) + layer._prob_scale.copy_(prob_scale) + if (q_scale == 1.0 or prob_scale == 1.0) and layer.use_fp8: + print_warning_once( + f"Using Q scale {q_scale} and prob scale {prob_scale} " + "with fp8 attention. This may cause accuracy issues. 
" + "Please make sure Q/prob scaling factors are " + "available in the fp8 checkpoint.") + del layer.k_scale del layer.v_scale + del layer.q_scale + del layer.prob_scale diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index dd4b0c75cb84d..684e7f5382277 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -395,7 +395,9 @@ def load_weights(self, weights: Iterable[Tuple[str, param = params_dict[scale_name] weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = loaded_weight[0] + if loaded_weight.shape: + # scalar shape is torch.Size([1]), not torch.Size([]) + loaded_weight = loaded_weight[0] weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index 6926de007bc84..ca85f418d5762 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -543,7 +543,9 @@ def load_weights(self, weights: Iterable[Tuple[str, param = params_dict[scale_name] weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = loaded_weight[0] + if loaded_weight.shape: + # scalar shape is torch.Size([1]), not torch.Size([]) + loaded_weight = loaded_weight[0] weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index 6353a7703d6cb..3c5dfc2e7f45d 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -485,7 +485,9 @@ def load_weights(self, weights: Iterable[Tuple[str, param = params_dict[scale_name] weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = loaded_weight[0] + if loaded_weight.shape: + # scalar shape is torch.Size([1]), not torch.Size([]) + loaded_weight = loaded_weight[0] weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 514ba946af6f7..05b237c955369 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -197,6 +197,14 @@ def __init__( else: sliding_window = None + # For CUDA devices and Navi4x, attn_fp8 will be set to false. + self.attn_fp8 = envs.VLLM_USE_ROCM_FP8_ATTN \ + and current_platform.is_rocm() \ + and not is_navi() \ + and isinstance(quant_config, Fp8Config) \ + and hasattr(self.o_proj, "input_scale") \ + and self.o_proj.input_scale is not None + self.attn = Attention( self.num_heads, self.head_dim, @@ -206,12 +214,8 @@ def __init__( quant_config=quant_config, per_layer_sliding_window=sliding_window, prefix=f"{prefix}.attn", + use_fp8=self.attn_fp8, ) - # For CUDA devices and Navi4x, attn_fp8_out will be set to false. 
- self.attn_fp8_out = envs.VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT \ - and current_platform.is_rocm() \ - and not is_navi() \ - and isinstance(quant_config, Fp8Config) def forward( self, @@ -228,8 +232,10 @@ def forward( v, kv_cache, attn_metadata, - fp8_out_scale=self.o_proj.input_scale - if self.attn_fp8_out else None) + fp8_comp_scales=(self.attn._q_scale, + self.attn._prob_scale, + self.o_proj.input_scale) + if self.attn_fp8 else None) output, _ = self.o_proj(attn_output) return output @@ -428,7 +434,9 @@ def load_weights(self, weights: Iterable[Tuple[str, param = params_dict[scale_name] weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = loaded_weight[0] + if loaded_weight.shape: + # scalar shape is torch.Size([1]), not torch.Size([]) + loaded_weight = loaded_weight[0] weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py index 087697bc45e61..4f3cdbbcee9f4 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -502,7 +502,9 @@ def load_weights(self, weights: Iterable[Tuple[str, param = params_dict[scale_name] weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = loaded_weight[0] + if loaded_weight.shape: + # scalar shape is torch.Size([1]), not torch.Size([]) + loaded_weight = loaded_weight[0] weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue From 37f37d1e73c68abbb90d84deb5d6f284e986594e Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg Date: Thu, 19 Dec 2024 14:51:11 -0600 Subject: [PATCH 2/7] Disabling calc_kv_scales if the checkoint has them. Enabling fp8 attention for dynamic quantization --- vllm/attention/backends/rocm_flash_attn.py | 2 +- vllm/attention/layer.py | 6 +- .../layers/quantization/kv_cache.py | 80 +++++++++---------- vllm/model_executor/models/llama.py | 4 +- 4 files changed, 45 insertions(+), 47 deletions(-) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 335bae56f6ab4..f1bb90550a045 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -686,7 +686,7 @@ def forward( full_scales = ( 1.0 / q_scale.item(), 1.0 / k_scale.item(), 1.0 / v_scale.item(), 1.0 / prob_scale.item(), - fp8_out_scale.item()) if fp8_comp_scales else None + fp8_out_scale.item()) if fp8_out_scale else None out, _ = self.attn_func( query, key, diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index d0d981e04db15..179f7aefcbcc7 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -143,8 +143,7 @@ def forward( ) -> torch.Tensor: if self.calculate_kv_scales and \ attn_metadata.enable_kv_scales_calculation: - self.calc_kv_scales(key, value) - + self.calc_kv_scales(query, key, value) if self.use_direct_call: return self.impl.forward(query, key, @@ -176,7 +175,8 @@ def forward( kv_cache, attn_type, self.layer_name) - def calc_kv_scales(self, key, value): + def calc_kv_scales(self, query, key, value): + self._q_scale.copy_(torch.abs(query).max() / self.k_range) self._k_scale.copy_(torch.abs(key).max() / self.k_range) self._v_scale.copy_(torch.abs(value).max() / self.v_range) # We only calculate the scales once diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py index 4502c8bea47c4..c96192b250d2c 100644 --- a/vllm/model_executor/layers/quantization/kv_cache.py 
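Patch 2 above extends the on-the-fly scale path so a Q scale is derived alongside K/V when the checkpoint supplies none. A condensed sketch of what calc_kv_scales now computes; note that at this point the query is still divided by k_range, and patch 3 below introduces a dedicated q_range:

    import torch

    def calc_dynamic_scales(query: torch.Tensor, key: torch.Tensor,
                            value: torch.Tensor, q_range: torch.Tensor,
                            k_range: torch.Tensor, v_range: torch.Tensor):
        # Per-tensor scales derived from the first batch seen; per the hunk's
        # comment, they are only calculated once.
        q_scale = torch.abs(query).max() / q_range
        k_scale = torch.abs(key).max() / k_range
        v_scale = torch.abs(value).max() / v_range
        return q_scale, k_scale, v_scale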
+++ b/vllm/model_executor/layers/quantization/kv_cache.py @@ -42,53 +42,53 @@ def apply(self, layer: torch.nn.Module) -> torch.Tensor: f"{self.__class__.__name__}.apply should not be called.") def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # No need to process kv scales after loading if we are going to - # calculate them on the fly. - if not layer.calculate_kv_scales: - if layer.k_scale > 0.0 and layer.v_scale > 0.0: - # We prefer to use separate k_scale and v_scale if present - k_scale = layer.k_scale.to("cpu").tolist() - v_scale = layer.v_scale.to("cpu").tolist() - if current_platform.is_rocm() and not is_navi(): - k_scale *= 2 - v_scale *= 2 - elif layer.k_scale < 0.0 and layer.v_scale < 0.0: - # If no scales were loaded (both scales are invalid negative - # values), use the default value of 1.0 - k_scale = 1.0 - v_scale = 1.0 - else: - # If we find a single kv_scale in the checkpoint, we remap - # kv_scale to k_scale during weight loading, and duplicate - # k_scale to v_scale here - assert layer.k_scale > 0.0 - scale_to_duplicate = max(layer.k_scale, layer.v_scale) - k_scale = scale_to_duplicate.to("cpu").tolist() - v_scale = scale_to_duplicate.to("cpu").tolist() - if current_platform.is_rocm() and not is_navi(): - k_scale *= 2 - v_scale *= 2 + if layer.k_scale > 0.0 and layer.v_scale > 0.0: + # We prefer to use separate k_scale and v_scale if present + k_scale = layer.k_scale.to("cpu").tolist() + v_scale = layer.v_scale.to("cpu").tolist() + if current_platform.is_rocm() and not is_navi(): + k_scale *= 2 + v_scale *= 2 + layer.calculate_kv_scales = False + elif layer.k_scale < 0.0 and layer.v_scale < 0.0: + # If no scales were loaded (both scales are invalid negative + # values), use the default value of 1.0 + k_scale = 1.0 + v_scale = 1.0 + else: + # If we find a single kv_scale in the checkpoint, we remap + # kv_scale to k_scale during weight loading, and duplicate + # k_scale to v_scale here + assert layer.k_scale > 0.0 + scale_to_duplicate = max(layer.k_scale, layer.v_scale) + k_scale = scale_to_duplicate.to("cpu").tolist() + v_scale = scale_to_duplicate.to("cpu").tolist() + if current_platform.is_rocm() and not is_navi(): + k_scale *= 2 + v_scale *= 2 + layer.calculate_kv_scales = False - if not isinstance(k_scale, float) or not isinstance( - v_scale, float): - raise ValueError("Only support per-tensor scaling factor " - "for fp8 KV cache") + if not isinstance(k_scale, float) or not isinstance( + v_scale, float): + raise ValueError("Only support per-tensor scaling factor " + "for fp8 KV cache") - # These are used in the final Attention.forward() - layer._k_scale.copy_(k_scale) - layer._v_scale.copy_(v_scale) - if (k_scale == 1.0 and v_scale == 1.0 - and (layer.kv_cache_dtype != "auto" or layer.use_fp8) - and "e5m2" not in layer.kv_cache_dtype): - print_warning_once( - "Using KV cache scaling factor 1.0 for fp8_e4m3. This " - "may cause accuracy issues. Please make sure k/v_scale " - "scaling factors are available in the fp8 checkpoint.") + # These are used in the final Attention.forward() + layer._k_scale.copy_(k_scale) + layer._v_scale.copy_(v_scale) + if (k_scale == 1.0 and v_scale == 1.0 + and (layer.kv_cache_dtype != "auto" or layer.use_fp8) + and "e5m2" not in layer.kv_cache_dtype): + print_warning_once( + "Using KV cache scaling factor 1.0 for fp8_e4m3. This " + "may cause accuracy issues. 
Please make sure k/v_scale " + "scaling factors are available in the fp8 checkpoint.") if layer.q_scale > 0.0 and layer.prob_scale > 0.0: q_scale = layer.q_scale.to("cpu").tolist() if current_platform.is_rocm() and not is_navi(): q_scale *= 2 + layer.calculate_kv_scales = False else: q_scale = 1.0 if layer.prob_scale > 0.0: diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 05b237c955369..9e7bf84ae912e 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -201,9 +201,7 @@ def __init__( self.attn_fp8 = envs.VLLM_USE_ROCM_FP8_ATTN \ and current_platform.is_rocm() \ and not is_navi() \ - and isinstance(quant_config, Fp8Config) \ - and hasattr(self.o_proj, "input_scale") \ - and self.o_proj.input_scale is not None + and isinstance(quant_config, Fp8Config) self.attn = Attention( self.num_heads, From a283f40db5f8ac068353d42f51cfbef6c5b66317 Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg Date: Thu, 19 Dec 2024 22:07:46 +0000 Subject: [PATCH 3/7] q_range as an env --- vllm/attention/layer.py | 3 ++- vllm/envs.py | 17 +++++++++++------ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 179f7aefcbcc7..65e0b714aad55 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -128,6 +128,7 @@ def __init__( compilation_config.static_forward_context[prefix] = self self.layer_name = prefix + self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32) self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32) self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32) @@ -176,7 +177,7 @@ def forward( self.layer_name) def calc_kv_scales(self, query, key, value): - self._q_scale.copy_(torch.abs(query).max() / self.k_range) + self._q_scale.copy_(torch.abs(query).max() / self.q_range) self._k_scale.copy_(torch.abs(key).max() / self.k_range) self._v_scale.copy_(torch.abs(value).max() / self.v_range) # We only calculate the scales once diff --git a/vllm/envs.py b/vllm/envs.py index 03646192e12d7..a0d4817c36ae3 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -83,8 +83,9 @@ VLLM_FP8_PADDING: bool = True VLLM_ENABLE_V1_MULTIPROCESSING: bool = True VLLM_LOG_BATCHSIZE_INTERVAL: float = -1 - K_SCALE_CONSTANT: int = 200 - V_SCALE_CONSTANT: int = 100 + Q_SCALE_CONSTANT: int = 20 + K_SCALE_CONSTANT: int = 20 + V_SCALE_CONSTANT: int = 10 def get_default_cache_root(): @@ -529,13 +530,17 @@ def get_default_config_root(): "VLLM_FP8_PADDING": lambda: bool(int(os.getenv("VLLM_FP8_PADDING", "1"))), - # Divisor for dynamic key scale factor calculation for FP8 KV Cache + # Divisor for dynamic query scale factor calculation for FP8 attention + "Q_SCALE_CONSTANT": + lambda: int(os.getenv("Q_SCALE_CONSTANT", "20")), + + # Divisor for dynamic key scale factor calculation for FP8 KV Cache and attention "K_SCALE_CONSTANT": - lambda: int(os.getenv("K_SCALE_CONSTANT", "200")), + lambda: int(os.getenv("K_SCALE_CONSTANT", "20")), - # Divisor for dynamic value scale factor calculation for FP8 KV Cache + # Divisor for dynamic value scale factor calculation for FP8 KV Cache and attention "V_SCALE_CONSTANT": - lambda: int(os.getenv("V_SCALE_CONSTANT", "100")), + lambda: int(os.getenv("V_SCALE_CONSTANT", "10")), # If set, enable multiprocessing in LLM for the V1 code path. 
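On the retuned divisors above: with scale = absmax / constant, the element with the largest magnitude quantizes to exactly that constant in fp8 units, so the new defaults (20 for Q and K, 10 for V) leave most of e4m3's range as headroom, presumably to tolerate larger later batches since the scales are only computed once. A toy illustration of the convention; the real quantization happens inside the attention and cache kernels, not via a .to() cast:

    import torch

    def dynamic_quant(x: torch.Tensor, divisor: int):
        scale = x.abs().max() / divisor          # Q/K default 20, V default 10
        x_fp8 = (x / scale).to(torch.float8_e4m3fn)
        return x_fp8, scale                      # dequant: x_fp8.float() * scale

    q = torch.randn(16, 128)
    q_fp8, q_scale = dynamic_quant(q, 20)
    # Round-trip error is bounded by one fp8 step times the scale.
    assert torch.allclose(q_fp8.float() * q_scale, q, atol=2 * q_scale.item())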
"VLLM_ENABLE_V1_MULTIPROCESSING": From 7908e9b60f59564a02b225932da1472808dd5c51 Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg Date: Thu, 19 Dec 2024 22:09:00 +0000 Subject: [PATCH 4/7] format --- vllm/envs.py | 6 ++++-- vllm/model_executor/layers/quantization/kv_cache.py | 5 ++--- vllm/model_executor/models/llama.py | 2 +- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index a0d4817c36ae3..70b189be974ab 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -534,11 +534,13 @@ def get_default_config_root(): "Q_SCALE_CONSTANT": lambda: int(os.getenv("Q_SCALE_CONSTANT", "20")), - # Divisor for dynamic key scale factor calculation for FP8 KV Cache and attention + # Divisor for dynamic key scale factor calculation + # for FP8 KV Cache and attention "K_SCALE_CONSTANT": lambda: int(os.getenv("K_SCALE_CONSTANT", "20")), - # Divisor for dynamic value scale factor calculation for FP8 KV Cache and attention + # Divisor for dynamic value scale factor calculation + # for FP8 KV Cache and attention "V_SCALE_CONSTANT": lambda: int(os.getenv("V_SCALE_CONSTANT", "10")), diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py index c96192b250d2c..bcb28193a4056 100644 --- a/vllm/model_executor/layers/quantization/kv_cache.py +++ b/vllm/model_executor/layers/quantization/kv_cache.py @@ -68,10 +68,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: v_scale *= 2 layer.calculate_kv_scales = False - if not isinstance(k_scale, float) or not isinstance( - v_scale, float): + if not isinstance(k_scale, float) or not isinstance(v_scale, float): raise ValueError("Only support per-tensor scaling factor " - "for fp8 KV cache") + "for fp8 KV cache") # These are used in the final Attention.forward() layer._k_scale.copy_(k_scale) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 9e7bf84ae912e..69d63216d379d 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -201,7 +201,7 @@ def __init__( self.attn_fp8 = envs.VLLM_USE_ROCM_FP8_ATTN \ and current_platform.is_rocm() \ and not is_navi() \ - and isinstance(quant_config, Fp8Config) + and isinstance(quant_config, Fp8Config) self.attn = Attention( self.num_heads, From 1ed13896679e62e324bc1ae6d6f2bfb3ebdcee88 Mon Sep 17 00:00:00 2001 From: Matthew Wong Date: Thu, 19 Dec 2024 23:12:03 +0000 Subject: [PATCH 5/7] Dedupe FA/PA attn toggles, set FA off by default --- vllm/attention/backends/rocm_flash_attn.py | 4 +++- vllm/attention/layer.py | 2 -- vllm/attention/ops/triton_flash_attention.py | 2 +- vllm/envs.py | 13 ++++++++++--- vllm/model_executor/layers/quantization/kv_cache.py | 10 ++++++---- vllm/model_executor/models/llama.py | 7 +++---- 6 files changed, 23 insertions(+), 15 deletions(-) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index f1bb90550a045..5a146940765bb 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -686,7 +686,9 @@ def forward( full_scales = ( 1.0 / q_scale.item(), 1.0 / k_scale.item(), 1.0 / v_scale.item(), 1.0 / prob_scale.item(), - fp8_out_scale.item()) if fp8_out_scale else None + fp8_out_scale.item()) if ( + fp8_out_scale + and envs.VLLM_USE_ROCM_FP8_FLASH_ATTN) else None out, _ = self.attn_func( query, key, diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 65e0b714aad55..392736137aa12 100644 --- 
a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -42,7 +42,6 @@ def __init__( logits_soft_cap: Optional[float] = None, per_layer_sliding_window: Optional[int] = None, prefix: str = "", - use_fp8: bool = False, ) -> None: super().__init__() if per_layer_sliding_window is not None: @@ -74,7 +73,6 @@ def __init__( # with the model weights. self.kv_cache_dtype = kv_cache_dtype self.calculate_kv_scales = calculate_kv_scales - self.use_fp8 = use_fp8 self._k_scale = torch.tensor(1.0, dtype=torch.float32) self._v_scale = torch.tensor(1.0, dtype=torch.float32) self._q_scale = torch.tensor(1.0, dtype=torch.float32) diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py index 2e647a8e52278..828bdc2905957 100644 --- a/vllm/attention/ops/triton_flash_attention.py +++ b/vllm/attention/ops/triton_flash_attention.py @@ -742,7 +742,7 @@ def attn_fwd( mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) out_ptrs_mask = (mask_m_offsets[:, None] >= out_mask_boundary[None, :]) - z = 0.0 + z = tl.zeros((1, ), tl.float32) acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) # write back LSE # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m diff --git a/vllm/envs.py b/vllm/envs.py index 70b189be974ab..eddfa174ed55a 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -16,7 +16,8 @@ VLLM_USE_TRITON_FLASH_ATTN: bool = True VLLM_USE_ROCM_SKINNY_GEMM: bool = True VLLM_USE_ROCM_CUSTOM_PAGED_ATTN: bool = True - VLLM_USE_ROCM_FP8_ATTN: bool = True + VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT: bool = True + VLLM_USE_ROCM_FP8_FLASH_ATTN: bool = False RANK: int = 0 LOCAL_RANK: int = 0 CUDA_VISIBLE_DEVICES: Optional[str] = None @@ -246,8 +247,14 @@ def get_default_config_root(): ("true", "1")), # have custom paged attention implemented for MI3* cards write out fp8 - "VLLM_USE_ROCM_FP8_ATTN": - lambda: (os.getenv("VLLM_USE_ROCM_FP8_ATTN", "True").lower() in + "VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT": + lambda: + (os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT", "True").lower() in + ("true", "1")), + + # use quantized q,k,v,softmax(qk^T), attn output during prefill + "VLLM_USE_ROCM_FP8_FLASH_ATTN": + lambda: (os.getenv("VLLM_USE_ROCM_FP8_FLASH_ATTN", "False").lower() in ("true", "1")), # rank of the process in the distributed setting, used to determine diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py index bcb28193a4056..c8f9e86a9cc7c 100644 --- a/vllm/model_executor/layers/quantization/kv_cache.py +++ b/vllm/model_executor/layers/quantization/kv_cache.py @@ -1,5 +1,6 @@ import torch +import vllm.envs as envs from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.platforms import current_platform @@ -76,18 +77,18 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer._k_scale.copy_(k_scale) layer._v_scale.copy_(v_scale) if (k_scale == 1.0 and v_scale == 1.0 - and (layer.kv_cache_dtype != "auto" or layer.use_fp8) + and (layer.kv_cache_dtype != "auto" + or envs.VLLM_USE_ROCM_FP8_FLASH_ATTN) and "e5m2" not in layer.kv_cache_dtype): print_warning_once( "Using KV cache scaling factor 1.0 for fp8_e4m3. This " "may cause accuracy issues. 
Please make sure k/v_scale " "scaling factors are available in the fp8 checkpoint.") - if layer.q_scale > 0.0 and layer.prob_scale > 0.0: + if layer.q_scale > 0.0: q_scale = layer.q_scale.to("cpu").tolist() if current_platform.is_rocm() and not is_navi(): q_scale *= 2 - layer.calculate_kv_scales = False else: q_scale = 1.0 if layer.prob_scale > 0.0: @@ -104,7 +105,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # These are used in the final Attention.forward() layer._q_scale.copy_(q_scale) layer._prob_scale.copy_(prob_scale) - if (q_scale == 1.0 or prob_scale == 1.0) and layer.use_fp8: + if (q_scale == 1.0 + or prob_scale == 1.0) and envs.VLLM_USE_ROCM_FP8_FLASH_ATTN: print_warning_once( f"Using Q scale {q_scale} and prob scale {prob_scale} " "with fp8 attention. This may cause accuracy issues. " diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 69d63216d379d..b268375ce8a4b 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -198,7 +198,7 @@ def __init__( sliding_window = None # For CUDA devices and Navi4x, attn_fp8 will be set to false. - self.attn_fp8 = envs.VLLM_USE_ROCM_FP8_ATTN \ + self.attn_fp8_out = envs.VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT \ and current_platform.is_rocm() \ and not is_navi() \ and isinstance(quant_config, Fp8Config) @@ -212,7 +212,6 @@ def __init__( quant_config=quant_config, per_layer_sliding_window=sliding_window, prefix=f"{prefix}.attn", - use_fp8=self.attn_fp8, ) def forward( @@ -232,8 +231,8 @@ def forward( attn_metadata, fp8_comp_scales=(self.attn._q_scale, self.attn._prob_scale, - self.o_proj.input_scale) - if self.attn_fp8 else None) + self.o_proj.input_scale + if self.attn_fp8_out else None)) output, _ = self.o_proj(attn_output) return output From 06f53ba9be4abb25cbf614fde16b45ff8e14ca64 Mon Sep 17 00:00:00 2001 From: Matthew Wong Date: Thu, 19 Dec 2024 23:21:33 +0000 Subject: [PATCH 6/7] Lint again, to fixed point --- vllm/model_executor/models/llama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index b268375ce8a4b..ee8dc07a756b7 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -231,8 +231,8 @@ def forward( attn_metadata, fp8_comp_scales=(self.attn._q_scale, self.attn._prob_scale, - self.o_proj.input_scale - if self.attn_fp8_out else None)) + self.o_proj.input_scale if + self.attn_fp8_out else None)) output, _ = self.o_proj(attn_output) return output From 0bd414aa418d8b8a89222e84bde3cb628b2612de Mon Sep 17 00:00:00 2001 From: Matthew Wong Date: Thu, 19 Dec 2024 23:30:26 +0000 Subject: [PATCH 7/7] Don't calculate KV scales dynamically if Q scale is included --- vllm/model_executor/layers/quantization/kv_cache.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py index c8f9e86a9cc7c..388d3a91228c8 100644 --- a/vllm/model_executor/layers/quantization/kv_cache.py +++ b/vllm/model_executor/layers/quantization/kv_cache.py @@ -89,6 +89,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: q_scale = layer.q_scale.to("cpu").tolist() if current_platform.is_rocm() and not is_navi(): q_scale *= 2 + layer.calculate_kv_scales = False else: q_scale = 1.0 if layer.prob_scale > 0.0:
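Taken together, the series leaves FP8 prefill attention opt-in. A hedged end-to-end sketch of how a user would exercise it, assuming a ROCm MI300-class device and an FP8 (e4m3) checkpoint that carries the q_proj/prob output scales ingested in patch 1; the model path and engine arguments are illustrative:

    import os

    # Off by default after patch 5; enables quantized Q/K/V/softmax in the
    # Triton flash-attention prefill path.
    os.environ["VLLM_USE_ROCM_FP8_FLASH_ATTN"] = "1"
    # The paged-attention fp8 output path keeps its own toggle (default on).
    os.environ["VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT"] = "1"
    # Divisors for the dynamic scales used when the checkpoint has none.
    os.environ["Q_SCALE_CONSTANT"] = "20"
    os.environ["K_SCALE_CONSTANT"] = "20"
    os.environ["V_SCALE_CONSTANT"] = "10"

    from vllm import LLM  # set the env vars before vLLM reads them

    llm = LLM(model="/path/to/llama-fp8", kv_cache_dtype="fp8")
    print(llm.generate("Hello")[0].outputs[0].text)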