Ingest FP8 attn scales and use them in Triton FA, if present

mawong-amd committed Dec 19, 2024
1 parent d08b78b commit 9ba2fab
Showing 19 changed files with 99 additions and 37 deletions.
2 changes: 1 addition & 1 deletion vllm/attention/backends/abstract.py
@@ -252,6 +252,6 @@ def forward(
v_scale: torch.Tensor,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
fp8_out_scale: Optional[torch.Tensor] = None,
fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
raise NotImplementedError
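
Note: the new fp8_comp_scales argument replaces the single fp8_out_scale tensor with a tuple of computation scales. A minimal sketch of how a backend can unpack it, assuming the (q_scale, prob_scale, out_scale) ordering used by the ROCm backend below; the helper name is illustrative and not part of this commit:

from typing import Optional, Tuple

import torch


def unpack_fp8_comp_scales(
    fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]],
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor],
           Optional[torch.Tensor]]:
    # Backends without fp8 attention support can simply ignore the argument;
    # passing None keeps the previous behavior.
    q_scale, prob_scale, out_scale = fp8_comp_scales or (None, None, None)
    return q_scale, prob_scale, out_scale


print(unpack_fp8_comp_scales(None))  # (None, None, None)
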
2 changes: 1 addition & 1 deletion vllm/attention/backends/blocksparse_attn.py
@@ -363,7 +363,7 @@ def forward(
v_scale: torch.Tensor,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
fp8_out_scale: Optional[torch.Tensor] = None,
fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
2 changes: 1 addition & 1 deletion vllm/attention/backends/flash_attn.py
@@ -642,7 +642,7 @@ def forward(
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
fp8_out_scale: Optional[torch.Tensor] = None,
fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
2 changes: 1 addition & 1 deletion vllm/attention/backends/flashinfer.py
@@ -777,7 +777,7 @@ def forward(
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
fp8_out_scale: Optional[torch.Tensor] = None,
fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:

# TODO: directly write to output tensor
2 changes: 1 addition & 1 deletion vllm/attention/backends/hpu_attn.py
@@ -154,7 +154,7 @@ def forward(
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
fp8_out_scale: Optional[torch.Tensor] = None,
fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.
2 changes: 1 addition & 1 deletion vllm/attention/backends/ipex_attn.py
@@ -174,7 +174,7 @@ def forward(
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
fp8_out_scale: Optional[torch.Tensor] = None,
fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
"""Forward pass with IPEX varlen_attention and PagedAttention.
2 changes: 1 addition & 1 deletion vllm/attention/backends/pallas.py
@@ -152,7 +152,7 @@ def forward(
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
fp8_out_scale: Optional[torch.Tensor] = None,
fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
"""Forward pass with Pallas attention.
10 changes: 8 additions & 2 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -551,7 +551,7 @@ def forward(
v_scale: torch.Tensor,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
fp8_out_scale: torch.Tensor = None,
fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
@@ -601,6 +601,8 @@ def forward(
Returns:
shape = [num_tokens, num_heads * head_size]
"""
q_scale, prob_scale, fp8_out_scale = fp8_comp_scales or (None, None,
None)

query = query.view(-1, self.num_heads, self.head_size)
if key is not None:
@@ -681,6 +683,10 @@ def forward(
query.dtype,
seq_lens,
make_attn_mask=False) # type: ignore
full_scales = (
1.0 / q_scale.item(), 1.0 / k_scale.item(),
1.0 / v_scale.item(), 1.0 / prob_scale.item(),
fp8_out_scale.item()) if fp8_comp_scales else None
out, _ = self.attn_func(
query,
key,
...
self.scale,
attn_masks[0][None]
if attn_masks is not None else None,
None,
full_scales,
)
elif self.use_naive_attn:
if self.num_kv_heads != self.num_heads:
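
Note: in the Triton path above, the kernel receives reciprocal ("descale") factors for Q, K, V and the softmax probabilities, plus the raw output scale. A standalone sketch of that construction, mirroring the diff; the function name and example values are illustrative:

from typing import Tuple

import torch


def build_full_scales(q_scale: torch.Tensor, k_scale: torch.Tensor,
                      v_scale: torch.Tensor, prob_scale: torch.Tensor,
                      fp8_out_scale: torch.Tensor) -> Tuple[float, ...]:
    # Descale (1 / scale) for the fp8 inputs, plain scale for the fp8 output,
    # matching the full_scales tuple built in the diff.
    return (1.0 / q_scale.item(), 1.0 / k_scale.item(),
            1.0 / v_scale.item(), 1.0 / prob_scale.item(),
            fp8_out_scale.item())


scales = [torch.tensor(s) for s in (0.5, 0.25, 0.125, 1.0, 2.0)]
print(build_full_scales(*scales))  # (2.0, 4.0, 8.0, 1.0, 2.0)
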
2 changes: 1 addition & 1 deletion vllm/attention/backends/torch_sdpa.py
@@ -434,7 +434,7 @@ def forward(
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
fp8_out_scale: Optional[torch.Tensor] = None,
fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention.
2 changes: 1 addition & 1 deletion vllm/attention/backends/xformers.py
@@ -420,7 +420,7 @@ def forward(
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
fp8_out_scale: Optional[torch.Tensor] = None,
fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.
14 changes: 9 additions & 5 deletions vllm/attention/layer.py
@@ -1,5 +1,5 @@
"""Attention layer."""
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
@@ -42,6 +42,7 @@ def __init__(
logits_soft_cap: Optional[float] = None,
per_layer_sliding_window: Optional[int] = None,
prefix: str = "",
use_fp8: bool = False,
) -> None:
super().__init__()
if per_layer_sliding_window is not None:
@@ -73,8 +74,11 @@ def __init__(
# with the model weights.
self.kv_cache_dtype = kv_cache_dtype
self.calculate_kv_scales = calculate_kv_scales
self.use_fp8 = use_fp8
self._k_scale = torch.tensor(1.0, dtype=torch.float32)
self._v_scale = torch.tensor(1.0, dtype=torch.float32)
self._q_scale = torch.tensor(1.0, dtype=torch.float32)
self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
quant_method = quant_config.get_quant_method(
self, prefix=prefix) if quant_config else None
if quant_method is not None:
@@ -106,11 +110,11 @@ def __init__(
self.num_kv_heads = num_kv_heads
self.backend = backend_name_to_enum(attn_backend.get_name())

# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
# For cuda and cpu platforms, we control how
# torch.compile works by registering the attention as one giant
# opaque custom op. For other platforms, we directly call them
# and let torch.compile handle them.
self.use_direct_call = not current_platform.is_cuda_alike(
self.use_direct_call = not current_platform.is_cuda(
) and not current_platform.is_cpu()

# For some attention backends, we allocate an output tensor before
@@ -135,7 +139,7 @@ def forward(
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
attn_type: str = AttentionType.DECODER,
fp8_out_scale: Optional[torch.Tensor] = None,
fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
if self.calculate_kv_scales and \
attn_metadata.enable_kv_scales_calculation:
@@ -150,7 +154,7 @@ def forward(
self._k_scale,
self._v_scale,
attn_type=attn_type,
fp8_out_scale=fp8_out_scale)
fp8_comp_scales=fp8_comp_scales)
elif self.use_output:
output = torch.empty_like(query)
hidden_size = query.size(-1)
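
Note: a condensed sketch (not the real vLLM class) of the bookkeeping this file adds: the layer now owns per-layer Q and prob scales that default to 1.0 and forwards the optional scales tuple to the backend implementation unchanged:

from typing import Optional, Tuple

import torch
import torch.nn as nn


class TinyAttention(nn.Module):
    def __init__(self, use_fp8: bool = False) -> None:
        super().__init__()
        self.use_fp8 = use_fp8
        # Defaults; checkpoint loading may overwrite these later.
        self._q_scale = torch.tensor(1.0, dtype=torch.float32)
        self._prob_scale = torch.tensor(1.0, dtype=torch.float32)

    def forward(
        self,
        query: torch.Tensor,
        fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
    ) -> torch.Tensor:
        # A real backend would consume fp8_comp_scales here; this sketch only
        # shows that the tuple is threaded through without modification.
        return query


layer = TinyAttention(use_fp8=True)
out = layer(torch.randn(2, 8), (layer._q_scale, layer._prob_scale, None))
print(out.shape)  # torch.Size([2, 8])
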
11 changes: 5 additions & 6 deletions vllm/envs.py
@@ -16,7 +16,7 @@
VLLM_USE_TRITON_FLASH_ATTN: bool = True
VLLM_USE_ROCM_SKINNY_GEMM: bool = True
VLLM_USE_ROCM_CUSTOM_PAGED_ATTN: bool = True
VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT: bool = True
VLLM_USE_ROCM_FP8_ATTN: bool = True
RANK: int = 0
LOCAL_RANK: int = 0
CUDA_VISIBLE_DEVICES: Optional[str] = None
@@ -242,13 +242,12 @@ def get_default_config_root():
# custom paged attention implemented for MI3* cards
"VLLM_USE_ROCM_CUSTOM_PAGED_ATTN":
lambda: (os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in
("true", "1") != "0"),
("true", "1")),

# have custom paged attention implemented for MI3* cards write out fp8
"VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT":
lambda:
(os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT", "True").lower() in
("true", "1") != "0"),
"VLLM_USE_ROCM_FP8_ATTN":
lambda: (os.getenv("VLLM_USE_ROCM_FP8_ATTN", "True").lower() in
("true", "1")),

# rank of the process in the distributed setting, used to determine
# the driver worker
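
Note: the old expressions relied on a confusing chained comparison (... in ("true", "1") != "0"), which happens to evaluate the same but reads like a bug; the replacement parses the flag directly. A small sketch of the resulting pattern for the renamed VLLM_USE_ROCM_FP8_ATTN variable; the helper name is illustrative:

import os


def env_flag(name: str, default: str = "True") -> bool:
    # Same parsing as the lambda in the diff: "true"/"1" (case-insensitive)
    # enable the feature, everything else disables it.
    return os.getenv(name, default).lower() in ("true", "1")


os.environ["VLLM_USE_ROCM_FP8_ATTN"] = "0"
print(env_flag("VLLM_USE_ROCM_FP8_ATTN"))  # False
os.environ["VLLM_USE_ROCM_FP8_ATTN"] = "true"
print(env_flag("VLLM_USE_ROCM_FP8_ATTN"))  # True
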
@@ -146,6 +146,10 @@ def get_compressed_tensors_cache_scale(name: str) -> Optional[str]:
return name.replace(".k_proj.output_scale", ".attn.k_scale")
if name.endswith(".output_scale") and ".v_proj" in name:
return name.replace(".v_proj.output_scale", ".attn.v_scale")
if name.endswith(".output_scale") and ".q_proj" in name:
return name.replace(".q_proj.output_scale", ".attn.q_scale")
if name.endswith("self_attn.prob_output_scale"):
return name.replace(".prob_output_scale", ".attn.prob_scale")
# If no matches, return None
return None

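
Note: a quick usage sketch of the extended scale-name remapping (logic copied from the hunk above; the checkpoint tensor names are illustrative):

def remap_scale_name(name):
    # Same branches as get_compressed_tensors_cache_scale in the hunk above.
    if name.endswith(".output_scale") and ".k_proj" in name:
        return name.replace(".k_proj.output_scale", ".attn.k_scale")
    if name.endswith(".output_scale") and ".v_proj" in name:
        return name.replace(".v_proj.output_scale", ".attn.v_scale")
    if name.endswith(".output_scale") and ".q_proj" in name:
        return name.replace(".q_proj.output_scale", ".attn.q_scale")
    if name.endswith("self_attn.prob_output_scale"):
        return name.replace(".prob_output_scale", ".attn.prob_scale")
    return None


for name in ("model.layers.0.self_attn.q_proj.output_scale",
             "model.layers.0.self_attn.prob_output_scale"):
    print(name, "->", remap_scale_name(name))
# -> model.layers.0.self_attn.attn.q_scale
# -> model.layers.0.self_attn.attn.prob_scale
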
39 changes: 36 additions & 3 deletions vllm/model_executor/layers/quantization/kv_cache.py
@@ -31,17 +31,20 @@ def create_weights(self, layer: torch.nn.Module):
requires_grad=False)
layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0),
requires_grad=False)
# Initialize Q and P = softmax(QK^T) scales
layer.q_scale = torch.nn.Parameter(torch.tensor(-1.0),
requires_grad=False)
layer.prob_scale = torch.nn.Parameter(torch.tensor(-1.0),
requires_grad=False)

def apply(self, layer: torch.nn.Module) -> torch.Tensor:
raise RuntimeError(
f"{self.__class__.__name__}.apply should not be called.")

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
# regardless whether the kv-scale is available in the checkpoint.
# No need to process kv scales after loading if we are going to
# calculate them on the fly.
if layer.kv_cache_dtype != "auto" and not layer.calculate_kv_scales:
if not layer.calculate_kv_scales:
if layer.k_scale > 0.0 and layer.v_scale > 0.0:
# We prefer to use separate k_scale and v_scale if present
k_scale = layer.k_scale.to("cpu").tolist()
@@ -75,11 +78,41 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer._k_scale.copy_(k_scale)
layer._v_scale.copy_(v_scale)
if (k_scale == 1.0 and v_scale == 1.0
and (layer.kv_cache_dtype != "auto" or layer.use_fp8)
and "e5m2" not in layer.kv_cache_dtype):
print_warning_once(
"Using KV cache scaling factor 1.0 for fp8_e4m3. This "
"may cause accuracy issues. Please make sure k/v_scale "
"scaling factors are available in the fp8 checkpoint.")

if layer.q_scale > 0.0 and layer.prob_scale > 0.0:
q_scale = layer.q_scale.to("cpu").tolist()
if current_platform.is_rocm() and not is_navi():
q_scale *= 2
else:
q_scale = 1.0
if layer.prob_scale > 0.0:
prob_scale = layer.prob_scale.to("cpu").tolist()
if current_platform.is_rocm() and not is_navi():
prob_scale *= 2
else:
prob_scale = 1.0

if not isinstance(q_scale, float) or not isinstance(prob_scale, float):
raise ValueError("Only support per-tensor scaling factor"
"for fp8-quantized Q/prob")

# These are used in the final Attention.forward()
layer._q_scale.copy_(q_scale)
layer._prob_scale.copy_(prob_scale)
if (q_scale == 1.0 or prob_scale == 1.0) and layer.use_fp8:
print_warning_once(
f"Using Q scale {q_scale} and prob scale {prob_scale} "
"with fp8 attention. This may cause accuracy issues. "
"Please make sure Q/prob scaling factors are "
"available in the fp8 checkpoint.")

del layer.k_scale
del layer.v_scale
del layer.q_scale
del layer.prob_scale
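
Note: the Q/prob handling above mirrors the existing K/V path: a positive checkpoint scale is used as-is, it is doubled on ROCm for non-Navi GPUs (the fnuz fp8 format on those devices has half the dynamic range of OCP e4m3), and missing scales fall back to 1.0. A hardware-independent sketch of that rule, with the platform check abstracted into a boolean argument:

from typing import Optional

import torch


def resolve_scale(raw: Optional[torch.Tensor], double_for_fnuz: bool) -> float:
    # raw: checkpoint scale, or None / a negative sentinel when absent.
    # double_for_fnuz: True on ROCm non-Navi devices in the diff above.
    if raw is not None and raw.item() > 0.0:
        scale = float(raw.item())
        if double_for_fnuz:
            scale *= 2
        return scale
    return 1.0


print(resolve_scale(torch.tensor(0.25), double_for_fnuz=True))   # 0.5
print(resolve_scale(torch.tensor(-1.0), double_for_fnuz=False))  # 1.0
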
4 changes: 3 additions & 1 deletion vllm/model_executor/models/aria.py
@@ -395,7 +395,9 @@ def load_weights(self, weights: Iterable[Tuple[str,
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = loaded_weight[0]
if loaded_weight.shape:
# scalar shape is torch.Size([1]), not torch.Size([])
loaded_weight = loaded_weight[0]
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
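
Note: the new guard in this and the following model files handles scale tensors stored either as a 0-d scalar (shape torch.Size([])) or as a 1-element vector (shape torch.Size([1])); indexing a 0-d tensor raises, so the value is only unwrapped when a dimension exists. A minimal reproduction of the distinction:

import torch

for loaded_weight in (torch.tensor(0.5), torch.tensor([0.5])):
    if loaded_weight.shape:
        # Only unwrap when there is a dimension to index into;
        # torch.Size([]) is an empty tuple and therefore falsy.
        loaded_weight = loaded_weight[0]
    print(loaded_weight.shape, float(loaded_weight))
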
4 changes: 3 additions & 1 deletion vllm/model_executor/models/exaone.py
@@ -543,7 +543,9 @@ def load_weights(self, weights: Iterable[Tuple[str,
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = loaded_weight[0]
if loaded_weight.shape:
# scalar shape is torch.Size([1]), not torch.Size([])
loaded_weight = loaded_weight[0]
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
4 changes: 3 additions & 1 deletion vllm/model_executor/models/granite.py
@@ -485,7 +485,9 @@ def load_weights(self, weights: Iterable[Tuple[str,
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = loaded_weight[0]
if loaded_weight.shape:
# scalar shape is torch.Size([1]), not torch.Size([])
loaded_weight = loaded_weight[0]
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
24 changes: 16 additions & 8 deletions vllm/model_executor/models/llama.py
@@ -197,6 +197,14 @@ def __init__(
else:
sliding_window = None

# For CUDA devices and Navi4x, attn_fp8 will be set to false.
self.attn_fp8 = envs.VLLM_USE_ROCM_FP8_ATTN \
and current_platform.is_rocm() \
and not is_navi() \
and isinstance(quant_config, Fp8Config) \
and hasattr(self.o_proj, "input_scale") \
and self.o_proj.input_scale is not None

self.attn = Attention(
self.num_heads,
self.head_dim,
@@ -206,12 +214,8 @@
quant_config=quant_config,
per_layer_sliding_window=sliding_window,
prefix=f"{prefix}.attn",
use_fp8=self.attn_fp8,
)
# For CUDA devices and Navi4x, attn_fp8_out will be set to false.
self.attn_fp8_out = envs.VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT \
and current_platform.is_rocm() \
and not is_navi() \
and isinstance(quant_config, Fp8Config)

def forward(
self,
@@ -228,8 +232,10 @@ def forward(
v,
kv_cache,
attn_metadata,
fp8_out_scale=self.o_proj.input_scale
if self.attn_fp8_out else None)
fp8_comp_scales=(self.attn._q_scale,
self.attn._prob_scale,
self.o_proj.input_scale)
if self.attn_fp8 else None)
output, _ = self.o_proj(attn_output)
return output

@@ -428,7 +434,9 @@ def load_weights(self, weights: Iterable[Tuple[str,
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = loaded_weight[0]
if loaded_weight.shape:
# scalar shape is torch.Size([1]), not torch.Size([])
loaded_weight = loaded_weight[0]
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
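
Note: a condensed sketch of the gating added here, with the environment flag and platform/quantization checks replaced by plain booleans (the real code uses envs.VLLM_USE_ROCM_FP8_ATTN, current_platform.is_rocm(), is_navi() and Fp8Config, as in the diff):

from typing import Optional, Tuple

import torch


def fp8_attn_enabled(env_flag: bool, is_rocm: bool, is_navi: bool,
                     is_fp8_checkpoint: bool,
                     o_proj_input_scale: Optional[torch.Tensor]) -> bool:
    # Same decision as self.attn_fp8 in the diff, with inputs made explicit.
    return (env_flag and is_rocm and not is_navi and is_fp8_checkpoint
            and o_proj_input_scale is not None)


q_scale, prob_scale = torch.tensor(1.0), torch.tensor(1.0)
o_proj_input_scale = torch.tensor(0.02)

fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = (
    (q_scale, prob_scale, o_proj_input_scale)
    if fp8_attn_enabled(True, True, False, True, o_proj_input_scale) else None)
print(fp8_comp_scales is not None)  # True
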
4 changes: 3 additions & 1 deletion vllm/model_executor/models/solar.py
@@ -502,7 +502,9 @@ def load_weights(self, weights: Iterable[Tuple[str,
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = loaded_weight[0]
if loaded_weight.shape:
# scalar shape is torch.Size([1]), not torch.Size([])
loaded_weight = loaded_weight[0]
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
