diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index 335bae56f6ab4..f1bb90550a045 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -686,7 +686,7 @@ def forward(
                     full_scales = (
                         1.0 / q_scale.item(), 1.0 / k_scale.item(),
                         1.0 / v_scale.item(), 1.0 / prob_scale.item(),
-                        fp8_out_scale.item()) if fp8_comp_scales else None
+                        fp8_out_scale.item()) if fp8_out_scale else None
                     out, _ = self.attn_func(
                         query,
                         key,
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index d0d981e04db15..65e0b714aad55 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -128,6 +128,7 @@ def __init__(
         compilation_config.static_forward_context[prefix] = self
         self.layer_name = prefix

+        self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
         self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
         self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
@@ -143,8 +144,7 @@ def forward(
     ) -> torch.Tensor:
         if self.calculate_kv_scales and \
             attn_metadata.enable_kv_scales_calculation:
-            self.calc_kv_scales(key, value)
-
+            self.calc_kv_scales(query, key, value)
         if self.use_direct_call:
             return self.impl.forward(query,
                                      key,
@@ -176,7 +176,8 @@ def forward(
                                   kv_cache,
                                   attn_type,
                                   self.layer_name)
-    def calc_kv_scales(self, key, value):
+    def calc_kv_scales(self, query, key, value):
+        self._q_scale.copy_(torch.abs(query).max() / self.q_range)
         self._k_scale.copy_(torch.abs(key).max() / self.k_range)
         self._v_scale.copy_(torch.abs(value).max() / self.v_range)
         # We only calculate the scales once
diff --git a/vllm/envs.py b/vllm/envs.py
index 03646192e12d7..70b189be974ab 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -83,8 +83,9 @@
     VLLM_FP8_PADDING: bool = True
     VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
     VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
-    K_SCALE_CONSTANT: int = 200
-    V_SCALE_CONSTANT: int = 100
+    Q_SCALE_CONSTANT: int = 20
+    K_SCALE_CONSTANT: int = 20
+    V_SCALE_CONSTANT: int = 10


 def get_default_cache_root():
@@ -529,13 +530,19 @@ def get_default_config_root():
     "VLLM_FP8_PADDING":
     lambda: bool(int(os.getenv("VLLM_FP8_PADDING", "1"))),

-    # Divisor for dynamic key scale factor calculation for FP8 KV Cache
+    # Divisor for dynamic query scale factor calculation for FP8 attention
+    "Q_SCALE_CONSTANT":
+    lambda: int(os.getenv("Q_SCALE_CONSTANT", "20")),
+
+    # Divisor for dynamic key scale factor calculation
+    # for FP8 KV Cache and attention
     "K_SCALE_CONSTANT":
-    lambda: int(os.getenv("K_SCALE_CONSTANT", "200")),
+    lambda: int(os.getenv("K_SCALE_CONSTANT", "20")),

-    # Divisor for dynamic value scale factor calculation for FP8 KV Cache
+    # Divisor for dynamic value scale factor calculation
+    # for FP8 KV Cache and attention
     "V_SCALE_CONSTANT":
-    lambda: int(os.getenv("V_SCALE_CONSTANT", "100")),
+    lambda: int(os.getenv("V_SCALE_CONSTANT", "10")),

     # If set, enable multiprocessing in LLM for the V1 code path.
     "VLLM_ENABLE_V1_MULTIPROCESSING":
diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py
index 4502c8bea47c4..bcb28193a4056 100644
--- a/vllm/model_executor/layers/quantization/kv_cache.py
+++ b/vllm/model_executor/layers/quantization/kv_cache.py
@@ -42,53 +42,52 @@ def apply(self, layer: torch.nn.Module) -> torch.Tensor:
             f"{self.__class__.__name__}.apply should not be called.")

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        # No need to process kv scales after loading if we are going to
-        # calculate them on the fly.
-        if not layer.calculate_kv_scales:
-            if layer.k_scale > 0.0 and layer.v_scale > 0.0:
-                # We prefer to use separate k_scale and v_scale if present
-                k_scale = layer.k_scale.to("cpu").tolist()
-                v_scale = layer.v_scale.to("cpu").tolist()
-                if current_platform.is_rocm() and not is_navi():
-                    k_scale *= 2
-                    v_scale *= 2
-            elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
-                # If no scales were loaded (both scales are invalid negative
-                # values), use the default value of 1.0
-                k_scale = 1.0
-                v_scale = 1.0
-            else:
-                # If we find a single kv_scale in the checkpoint, we remap
-                # kv_scale to k_scale during weight loading, and duplicate
-                # k_scale to v_scale here
-                assert layer.k_scale > 0.0
-                scale_to_duplicate = max(layer.k_scale, layer.v_scale)
-                k_scale = scale_to_duplicate.to("cpu").tolist()
-                v_scale = scale_to_duplicate.to("cpu").tolist()
-                if current_platform.is_rocm() and not is_navi():
-                    k_scale *= 2
-                    v_scale *= 2
+        if layer.k_scale > 0.0 and layer.v_scale > 0.0:
+            # We prefer to use separate k_scale and v_scale if present
+            k_scale = layer.k_scale.to("cpu").tolist()
+            v_scale = layer.v_scale.to("cpu").tolist()
+            if current_platform.is_rocm() and not is_navi():
+                k_scale *= 2
+                v_scale *= 2
+            layer.calculate_kv_scales = False
+        elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
+            # If no scales were loaded (both scales are invalid negative
+            # values), use the default value of 1.0
+            k_scale = 1.0
+            v_scale = 1.0
+        else:
+            # If we find a single kv_scale in the checkpoint, we remap
+            # kv_scale to k_scale during weight loading, and duplicate
+            # k_scale to v_scale here
+            assert layer.k_scale > 0.0
+            scale_to_duplicate = max(layer.k_scale, layer.v_scale)
+            k_scale = scale_to_duplicate.to("cpu").tolist()
+            v_scale = scale_to_duplicate.to("cpu").tolist()
+            if current_platform.is_rocm() and not is_navi():
+                k_scale *= 2
+                v_scale *= 2
+            layer.calculate_kv_scales = False

-            if not isinstance(k_scale, float) or not isinstance(
-                    v_scale, float):
-                raise ValueError("Only support per-tensor scaling factor "
-                                 "for fp8 KV cache")
+        if not isinstance(k_scale, float) or not isinstance(v_scale, float):
+            raise ValueError("Only support per-tensor scaling factor "
+                             "for fp8 KV cache")

-            # These are used in the final Attention.forward()
-            layer._k_scale.copy_(k_scale)
-            layer._v_scale.copy_(v_scale)
-            if (k_scale == 1.0 and v_scale == 1.0
-                    and (layer.kv_cache_dtype != "auto" or layer.use_fp8)
-                    and "e5m2" not in layer.kv_cache_dtype):
-                print_warning_once(
-                    "Using KV cache scaling factor 1.0 for fp8_e4m3. This "
-                    "may cause accuracy issues. Please make sure k/v_scale "
-                    "scaling factors are available in the fp8 checkpoint.")
+        # These are used in the final Attention.forward()
+        layer._k_scale.copy_(k_scale)
+        layer._v_scale.copy_(v_scale)
+        if (k_scale == 1.0 and v_scale == 1.0
+                and (layer.kv_cache_dtype != "auto" or layer.use_fp8)
+                and "e5m2" not in layer.kv_cache_dtype):
+            print_warning_once(
+                "Using KV cache scaling factor 1.0 for fp8_e4m3. This "
+                "may cause accuracy issues. Please make sure k/v_scale "
+                "scaling factors are available in the fp8 checkpoint.")

         if layer.q_scale > 0.0 and layer.prob_scale > 0.0:
             q_scale = layer.q_scale.to("cpu").tolist()
             if current_platform.is_rocm() and not is_navi():
                 q_scale *= 2
+            layer.calculate_kv_scales = False
         else:
             q_scale = 1.0
         if layer.prob_scale > 0.0:
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 05b237c955369..69d63216d379d 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -201,9 +201,7 @@ def __init__(
         self.attn_fp8 = envs.VLLM_USE_ROCM_FP8_ATTN \
                         and current_platform.is_rocm() \
                         and not is_navi() \
-                        and isinstance(quant_config, Fp8Config) \
-                        and hasattr(self.o_proj, "input_scale") \
-                        and self.o_proj.input_scale is not None
+                        and isinstance(quant_config, Fp8Config)

         self.attn = Attention(
             self.num_heads,