From 1dcd9fe8095941acf596b39bd0a075302fb3e960 Mon Sep 17 00:00:00 2001 From: Matt Wong <156021403+mawong-amd@users.noreply.github.com> Date: Thu, 19 Dec 2024 18:21:18 -0600 Subject: [PATCH] Ingest FP8 attn scales and use them in ROCm FlashAttention (#338) * Ingest FP8 attn scales and use them in Triton FA, if present * Disabling calc_kv_scales if the checkoint has them. Enabling fp8 attention for dynamic quantization * q_range as an env * format * Dedupe FA/PA attn toggles, set FA off by default * Lint again, to fixed point * Don't calculate KV scales dynamically if Q scale is included --------- Co-authored-by: Gregory Shtrasberg --- vllm/attention/backends/abstract.py | 2 +- vllm/attention/backends/blocksparse_attn.py | 2 +- vllm/attention/backends/flash_attn.py | 2 +- vllm/attention/backends/flashinfer.py | 2 +- vllm/attention/backends/hpu_attn.py | 2 +- vllm/attention/backends/ipex_attn.py | 2 +- vllm/attention/backends/pallas.py | 2 +- vllm/attention/backends/rocm_flash_attn.py | 12 +- vllm/attention/backends/torch_sdpa.py | 2 +- vllm/attention/backends/xformers.py | 2 +- vllm/attention/layer.py | 19 +-- vllm/attention/ops/triton_flash_attention.py | 2 +- vllm/envs.py | 29 +++-- .../quantization/compressed_tensors/utils.py | 4 + .../layers/quantization/kv_cache.py | 117 ++++++++++++------ vllm/model_executor/models/aria.py | 4 +- vllm/model_executor/models/exaone.py | 4 +- vllm/model_executor/models/granite.py | 4 +- vllm/model_executor/models/llama.py | 21 ++-- vllm/model_executor/models/solar.py | 4 +- 20 files changed, 157 insertions(+), 81 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 3127285df6c63..de34c0da55b4f 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -252,6 +252,6 @@ def forward( v_scale: torch.Tensor, attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, - fp8_out_scale: Optional[torch.Tensor] = None, + fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index 1869dbab0cbf5..e6f63a1ff0fc9 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -363,7 +363,7 @@ def forward( v_scale: torch.Tensor, attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, - fp8_out_scale: Optional[torch.Tensor] = None, + fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index c640de998149e..94a7478f8eb51 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -642,7 +642,7 @@ def forward( v_scale: float = 1.0, attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, - fp8_out_scale: Optional[torch.Tensor] = None, + fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None, ) -> torch.Tensor: """Forward pass with FlashAttention. diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 2285f20c1c8af..5301474833163 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -777,7 +777,7 @@ def forward( v_scale: float = 1.0, attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, - fp8_out_scale: Optional[torch.Tensor] = None, + fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None, ) -> torch.Tensor: # TODO: directly write to output tensor diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index b3065495ab396..efb3ac980aaba 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -154,7 +154,7 @@ def forward( v_scale: float = 1.0, attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, - fp8_out_scale: Optional[torch.Tensor] = None, + fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index d02fbcb6ca0ae..0225554b0ceb1 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -174,7 +174,7 @@ def forward( v_scale: float = 1.0, attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, - fp8_out_scale: Optional[torch.Tensor] = None, + fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None, ) -> torch.Tensor: """Forward pass with IPEX varlen_attention and PagedAttention. diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index a2612c97ca23b..7c702291513d2 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -152,7 +152,7 @@ def forward( v_scale: float = 1.0, attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, - fp8_out_scale: Optional[torch.Tensor] = None, + fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None, ) -> torch.Tensor: """Forward pass with Pallas attention. diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index efaa74f67bafd..5a146940765bb 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -551,7 +551,7 @@ def forward( v_scale: torch.Tensor, attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, - fp8_out_scale: torch.Tensor = None, + fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -601,6 +601,8 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ + q_scale, prob_scale, fp8_out_scale = fp8_comp_scales or (None, None, + None) query = query.view(-1, self.num_heads, self.head_size) if key is not None: @@ -681,6 +683,12 @@ def forward( query.dtype, seq_lens, make_attn_mask=False) # type: ignore + full_scales = ( + 1.0 / q_scale.item(), 1.0 / k_scale.item(), + 1.0 / v_scale.item(), 1.0 / prob_scale.item(), + fp8_out_scale.item()) if ( + fp8_out_scale + and envs.VLLM_USE_ROCM_FP8_FLASH_ATTN) else None out, _ = self.attn_func( query, key, @@ -694,7 +702,7 @@ def forward( self.scale, attn_masks[0][None] if attn_masks is not None else None, - None, + full_scales, ) elif self.use_naive_attn: if self.num_kv_heads != self.num_heads: diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 8e586c00024d5..660408f3ea477 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -434,7 +434,7 @@ def forward( v_scale: float = 1.0, attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, - fp8_out_scale: Optional[torch.Tensor] = None, + fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None, ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index dcd60da43d520..d0f1034dc966f 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -420,7 +420,7 @@ def forward( v_scale: float = 1.0, attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, - fp8_out_scale: Optional[torch.Tensor] = None, + fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 800a1ff2f4f65..392736137aa12 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -1,5 +1,5 @@ """Attention layer.""" -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple import torch import torch.nn as nn @@ -75,6 +75,8 @@ def __init__( self.calculate_kv_scales = calculate_kv_scales self._k_scale = torch.tensor(1.0, dtype=torch.float32) self._v_scale = torch.tensor(1.0, dtype=torch.float32) + self._q_scale = torch.tensor(1.0, dtype=torch.float32) + self._prob_scale = torch.tensor(1.0, dtype=torch.float32) quant_method = quant_config.get_quant_method( self, prefix=prefix) if quant_config else None if quant_method is not None: @@ -106,11 +108,11 @@ def __init__( self.num_kv_heads = num_kv_heads self.backend = backend_name_to_enum(attn_backend.get_name()) - # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how + # For cuda and cpu platforms, we control how # torch.compile works by registering the attention as one giant # opaque custom op. For other platforms, we directly call them # and let torch.compile handle them. - self.use_direct_call = not current_platform.is_cuda_alike( + self.use_direct_call = not current_platform.is_cuda( ) and not current_platform.is_cpu() # For some attention backends, we allocate an output tensor before @@ -124,6 +126,7 @@ def __init__( compilation_config.static_forward_context[prefix] = self self.layer_name = prefix + self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32) self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32) self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32) @@ -135,12 +138,11 @@ def forward( kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, attn_type: str = AttentionType.DECODER, - fp8_out_scale: Optional[torch.Tensor] = None, + fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None, ) -> torch.Tensor: if self.calculate_kv_scales and \ attn_metadata.enable_kv_scales_calculation: - self.calc_kv_scales(key, value) - + self.calc_kv_scales(query, key, value) if self.use_direct_call: return self.impl.forward(query, key, @@ -150,7 +152,7 @@ def forward( self._k_scale, self._v_scale, attn_type=attn_type, - fp8_out_scale=fp8_out_scale) + fp8_comp_scales=fp8_comp_scales) elif self.use_output: output = torch.empty_like(query) hidden_size = query.size(-1) @@ -172,7 +174,8 @@ def forward( kv_cache, attn_type, self.layer_name) - def calc_kv_scales(self, key, value): + def calc_kv_scales(self, query, key, value): + self._q_scale.copy_(torch.abs(query).max() / self.q_range) self._k_scale.copy_(torch.abs(key).max() / self.k_range) self._v_scale.copy_(torch.abs(value).max() / self.v_range) # We only calculate the scales once diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py index 2e647a8e52278..828bdc2905957 100644 --- a/vllm/attention/ops/triton_flash_attention.py +++ b/vllm/attention/ops/triton_flash_attention.py @@ -742,7 +742,7 @@ def attn_fwd( mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) out_ptrs_mask = (mask_m_offsets[:, None] >= out_mask_boundary[None, :]) - z = 0.0 + z = tl.zeros((1, ), tl.float32) acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) # write back LSE # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m diff --git a/vllm/envs.py b/vllm/envs.py index 19e520691e436..eddfa174ed55a 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -17,6 +17,7 @@ VLLM_USE_ROCM_SKINNY_GEMM: bool = True VLLM_USE_ROCM_CUSTOM_PAGED_ATTN: bool = True VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT: bool = True + VLLM_USE_ROCM_FP8_FLASH_ATTN: bool = False RANK: int = 0 LOCAL_RANK: int = 0 CUDA_VISIBLE_DEVICES: Optional[str] = None @@ -83,8 +84,9 @@ VLLM_FP8_PADDING: bool = True VLLM_ENABLE_V1_MULTIPROCESSING: bool = True VLLM_LOG_BATCHSIZE_INTERVAL: float = -1 - K_SCALE_CONSTANT: int = 200 - V_SCALE_CONSTANT: int = 100 + Q_SCALE_CONSTANT: int = 20 + K_SCALE_CONSTANT: int = 20 + V_SCALE_CONSTANT: int = 10 def get_default_cache_root(): @@ -242,13 +244,18 @@ def get_default_config_root(): # custom paged attention implemented for MI3* cards "VLLM_USE_ROCM_CUSTOM_PAGED_ATTN": lambda: (os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in - ("true", "1") != "0"), + ("true", "1")), # have custom paged attention implemented for MI3* cards write out fp8 "VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT": lambda: (os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT", "True").lower() in - ("true", "1") != "0"), + ("true", "1")), + + # use quantized q,k,v,softmax(qk^T), attn output during prefill + "VLLM_USE_ROCM_FP8_FLASH_ATTN": + lambda: (os.getenv("VLLM_USE_ROCM_FP8_FLASH_ATTN", "False").lower() in + ("true", "1")), # rank of the process in the distributed setting, used to determine # the driver worker @@ -530,13 +537,19 @@ def get_default_config_root(): "VLLM_FP8_PADDING": lambda: bool(int(os.getenv("VLLM_FP8_PADDING", "1"))), - # Divisor for dynamic key scale factor calculation for FP8 KV Cache + # Divisor for dynamic query scale factor calculation for FP8 attention + "Q_SCALE_CONSTANT": + lambda: int(os.getenv("Q_SCALE_CONSTANT", "20")), + + # Divisor for dynamic key scale factor calculation + # for FP8 KV Cache and attention "K_SCALE_CONSTANT": - lambda: int(os.getenv("K_SCALE_CONSTANT", "200")), + lambda: int(os.getenv("K_SCALE_CONSTANT", "20")), - # Divisor for dynamic value scale factor calculation for FP8 KV Cache + # Divisor for dynamic value scale factor calculation + # for FP8 KV Cache and attention "V_SCALE_CONSTANT": - lambda: int(os.getenv("V_SCALE_CONSTANT", "100")), + lambda: int(os.getenv("V_SCALE_CONSTANT", "10")), # If set, enable multiprocessing in LLM for the V1 code path. "VLLM_ENABLE_V1_MULTIPROCESSING": diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py index a74eaef5efdee..c474dcd0c5246 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py @@ -146,6 +146,10 @@ def get_compressed_tensors_cache_scale(name: str) -> Optional[str]: return name.replace(".k_proj.output_scale", ".attn.k_scale") if name.endswith(".output_scale") and ".v_proj" in name: return name.replace(".v_proj.output_scale", ".attn.v_scale") + if name.endswith(".output_scale") and ".q_proj" in name: + return name.replace(".q_proj.output_scale", ".attn.q_scale") + if name.endswith("self_attn.prob_output_scale"): + return name.replace(".prob_output_scale", ".attn.prob_scale") # If no matches, return None return None diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py index b386a9f309639..388d3a91228c8 100644 --- a/vllm/model_executor/layers/quantization/kv_cache.py +++ b/vllm/model_executor/layers/quantization/kv_cache.py @@ -1,5 +1,6 @@ import torch +import vllm.envs as envs from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.platforms import current_platform @@ -31,55 +32,89 @@ def create_weights(self, layer: torch.nn.Module): requires_grad=False) layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False) + # Initialize Q and P = softmax(QK^T) scales + layer.q_scale = torch.nn.Parameter(torch.tensor(-1.0), + requires_grad=False) + layer.prob_scale = torch.nn.Parameter(torch.tensor(-1.0), + requires_grad=False) def apply(self, layer: torch.nn.Module) -> torch.Tensor: raise RuntimeError( f"{self.__class__.__name__}.apply should not be called.") def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0 - # regardless whether the kv-scale is available in the checkpoint. - # No need to process kv scales after loading if we are going to - # calculate them on the fly. - if layer.kv_cache_dtype != "auto" and not layer.calculate_kv_scales: - if layer.k_scale > 0.0 and layer.v_scale > 0.0: - # We prefer to use separate k_scale and v_scale if present - k_scale = layer.k_scale.to("cpu").tolist() - v_scale = layer.v_scale.to("cpu").tolist() - if current_platform.is_rocm() and not is_navi(): - k_scale *= 2 - v_scale *= 2 - elif layer.k_scale < 0.0 and layer.v_scale < 0.0: - # If no scales were loaded (both scales are invalid negative - # values), use the default value of 1.0 - k_scale = 1.0 - v_scale = 1.0 - else: - # If we find a single kv_scale in the checkpoint, we remap - # kv_scale to k_scale during weight loading, and duplicate - # k_scale to v_scale here - assert layer.k_scale > 0.0 - scale_to_duplicate = max(layer.k_scale, layer.v_scale) - k_scale = scale_to_duplicate.to("cpu").tolist() - v_scale = scale_to_duplicate.to("cpu").tolist() - if current_platform.is_rocm() and not is_navi(): - k_scale *= 2 - v_scale *= 2 + if layer.k_scale > 0.0 and layer.v_scale > 0.0: + # We prefer to use separate k_scale and v_scale if present + k_scale = layer.k_scale.to("cpu").tolist() + v_scale = layer.v_scale.to("cpu").tolist() + if current_platform.is_rocm() and not is_navi(): + k_scale *= 2 + v_scale *= 2 + layer.calculate_kv_scales = False + elif layer.k_scale < 0.0 and layer.v_scale < 0.0: + # If no scales were loaded (both scales are invalid negative + # values), use the default value of 1.0 + k_scale = 1.0 + v_scale = 1.0 + else: + # If we find a single kv_scale in the checkpoint, we remap + # kv_scale to k_scale during weight loading, and duplicate + # k_scale to v_scale here + assert layer.k_scale > 0.0 + scale_to_duplicate = max(layer.k_scale, layer.v_scale) + k_scale = scale_to_duplicate.to("cpu").tolist() + v_scale = scale_to_duplicate.to("cpu").tolist() + if current_platform.is_rocm() and not is_navi(): + k_scale *= 2 + v_scale *= 2 + layer.calculate_kv_scales = False + + if not isinstance(k_scale, float) or not isinstance(v_scale, float): + raise ValueError("Only support per-tensor scaling factor " + "for fp8 KV cache") + + # These are used in the final Attention.forward() + layer._k_scale.copy_(k_scale) + layer._v_scale.copy_(v_scale) + if (k_scale == 1.0 and v_scale == 1.0 + and (layer.kv_cache_dtype != "auto" + or envs.VLLM_USE_ROCM_FP8_FLASH_ATTN) + and "e5m2" not in layer.kv_cache_dtype): + print_warning_once( + "Using KV cache scaling factor 1.0 for fp8_e4m3. This " + "may cause accuracy issues. Please make sure k/v_scale " + "scaling factors are available in the fp8 checkpoint.") + + if layer.q_scale > 0.0: + q_scale = layer.q_scale.to("cpu").tolist() + if current_platform.is_rocm() and not is_navi(): + q_scale *= 2 + layer.calculate_kv_scales = False + else: + q_scale = 1.0 + if layer.prob_scale > 0.0: + prob_scale = layer.prob_scale.to("cpu").tolist() + if current_platform.is_rocm() and not is_navi(): + prob_scale *= 2 + else: + prob_scale = 1.0 - if not isinstance(k_scale, float) or not isinstance( - v_scale, float): - raise ValueError("Only support per-tensor scaling factor " - "for fp8 KV cache") + if not isinstance(q_scale, float) or not isinstance(prob_scale, float): + raise ValueError("Only support per-tensor scaling factor" + "for fp8-quantized Q/prob") - # These are used in the final Attention.forward() - layer._k_scale.copy_(k_scale) - layer._v_scale.copy_(v_scale) - if (k_scale == 1.0 and v_scale == 1.0 - and "e5m2" not in layer.kv_cache_dtype): - print_warning_once( - "Using KV cache scaling factor 1.0 for fp8_e4m3. This " - "may cause accuracy issues. Please make sure k/v_scale " - "scaling factors are available in the fp8 checkpoint.") + # These are used in the final Attention.forward() + layer._q_scale.copy_(q_scale) + layer._prob_scale.copy_(prob_scale) + if (q_scale == 1.0 + or prob_scale == 1.0) and envs.VLLM_USE_ROCM_FP8_FLASH_ATTN: + print_warning_once( + f"Using Q scale {q_scale} and prob scale {prob_scale} " + "with fp8 attention. This may cause accuracy issues. " + "Please make sure Q/prob scaling factors are " + "available in the fp8 checkpoint.") del layer.k_scale del layer.v_scale + del layer.q_scale + del layer.prob_scale diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index dd4b0c75cb84d..684e7f5382277 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -395,7 +395,9 @@ def load_weights(self, weights: Iterable[Tuple[str, param = params_dict[scale_name] weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = loaded_weight[0] + if loaded_weight.shape: + # scalar shape is torch.Size([1]), not torch.Size([]) + loaded_weight = loaded_weight[0] weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index 6926de007bc84..ca85f418d5762 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -543,7 +543,9 @@ def load_weights(self, weights: Iterable[Tuple[str, param = params_dict[scale_name] weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = loaded_weight[0] + if loaded_weight.shape: + # scalar shape is torch.Size([1]), not torch.Size([]) + loaded_weight = loaded_weight[0] weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index 6353a7703d6cb..3c5dfc2e7f45d 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -485,7 +485,9 @@ def load_weights(self, weights: Iterable[Tuple[str, param = params_dict[scale_name] weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = loaded_weight[0] + if loaded_weight.shape: + # scalar shape is torch.Size([1]), not torch.Size([]) + loaded_weight = loaded_weight[0] weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 514ba946af6f7..ee8dc07a756b7 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -197,6 +197,12 @@ def __init__( else: sliding_window = None + # For CUDA devices and Navi4x, attn_fp8 will be set to false. + self.attn_fp8_out = envs.VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT \ + and current_platform.is_rocm() \ + and not is_navi() \ + and isinstance(quant_config, Fp8Config) + self.attn = Attention( self.num_heads, self.head_dim, @@ -207,11 +213,6 @@ def __init__( per_layer_sliding_window=sliding_window, prefix=f"{prefix}.attn", ) - # For CUDA devices and Navi4x, attn_fp8_out will be set to false. - self.attn_fp8_out = envs.VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT \ - and current_platform.is_rocm() \ - and not is_navi() \ - and isinstance(quant_config, Fp8Config) def forward( self, @@ -228,8 +229,10 @@ def forward( v, kv_cache, attn_metadata, - fp8_out_scale=self.o_proj.input_scale - if self.attn_fp8_out else None) + fp8_comp_scales=(self.attn._q_scale, + self.attn._prob_scale, + self.o_proj.input_scale if + self.attn_fp8_out else None)) output, _ = self.o_proj(attn_output) return output @@ -428,7 +431,9 @@ def load_weights(self, weights: Iterable[Tuple[str, param = params_dict[scale_name] weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = loaded_weight[0] + if loaded_weight.shape: + # scalar shape is torch.Size([1]), not torch.Size([]) + loaded_weight = loaded_weight[0] weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py index 087697bc45e61..4f3cdbbcee9f4 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -502,7 +502,9 @@ def load_weights(self, weights: Iterable[Tuple[str, param = params_dict[scale_name] weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = loaded_weight[0] + if loaded_weight.shape: + # scalar shape is torch.Size([1]), not torch.Size([]) + loaded_weight = loaded_weight[0] weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue