Ingest FP8 attn scales and use them in ROCm FlashAttention #338

Merged (7 commits) on Dec 20, 2024
Changes from 1 commit
2 changes: 1 addition & 1 deletion vllm/attention/backends/abstract.py
@@ -252,6 +252,6 @@ def forward(
v_scale: torch.Tensor,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
fp8_out_scale: Optional[torch.Tensor] = None,
fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
raise NotImplementedError
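
For context (not part of the diff), here is a minimal sketch of the new argument as most backends see it: a single optional tuple replaces the old `fp8_out_scale` tensor, and backends without fp8 attention support simply leave it unused. The `(q_scale, prob_scale, fp8_out_scale)` ordering is inferred from the ROCm hunk further down.

```python
# Illustrative sketch only; not code from this PR.
from typing import Optional, Tuple

import torch

# Assumed layout, taken from the ROCm backend below:
# (q_scale, prob_scale, fp8_out_scale), each a 0-dim float32 tensor.
FP8CompScales = Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]


def forward_stub(query: torch.Tensor,
                 fp8_comp_scales: FP8CompScales = None) -> torch.Tensor:
    # Backends that do not use fp8 attention accept the argument and
    # ignore it, which is all the one-line signature changes above do.
    return query


out = forward_stub(torch.ones(2, 8))
```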
2 changes: 1 addition & 1 deletion vllm/attention/backends/blocksparse_attn.py
@@ -363,7 +363,7 @@ def forward(
v_scale: torch.Tensor,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
fp8_out_scale: Optional[torch.Tensor] = None,
fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.

2 changes: 1 addition & 1 deletion vllm/attention/backends/flash_attn.py
@@ -642,7 +642,7 @@ def forward(
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
fp8_out_scale: Optional[torch.Tensor] = None,
fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.

2 changes: 1 addition & 1 deletion vllm/attention/backends/flashinfer.py
@@ -777,7 +777,7 @@ def forward(
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
fp8_out_scale: Optional[torch.Tensor] = None,
fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:

# TODO: directly write to output tensor
2 changes: 1 addition & 1 deletion vllm/attention/backends/hpu_attn.py
@@ -154,7 +154,7 @@ def forward(
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
fp8_out_scale: Optional[torch.Tensor] = None,
fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.

2 changes: 1 addition & 1 deletion vllm/attention/backends/ipex_attn.py
@@ -174,7 +174,7 @@ def forward(
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
fp8_out_scale: Optional[torch.Tensor] = None,
fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
"""Forward pass with IPEX varlen_attention and PagedAttention.

2 changes: 1 addition & 1 deletion vllm/attention/backends/pallas.py
@@ -152,7 +152,7 @@ def forward(
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
fp8_out_scale: Optional[torch.Tensor] = None,
fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
"""Forward pass with Pallas attention.

10 changes: 8 additions & 2 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -551,7 +551,7 @@ def forward(
v_scale: torch.Tensor,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
fp8_out_scale: torch.Tensor = None,
fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.

@@ -601,6 +601,8 @@ def forward(
Returns:
shape = [num_tokens, num_heads * head_size]
"""
q_scale, prob_scale, fp8_out_scale = fp8_comp_scales or (None, None,
None)

query = query.view(-1, self.num_heads, self.head_size)
if key is not None:
@@ -681,6 +683,10 @@ def forward(
query.dtype,
seq_lens,
make_attn_mask=False) # type: ignore
full_scales = (
1.0 / q_scale.item(), 1.0 / k_scale.item(),
1.0 / v_scale.item(), 1.0 / prob_scale.item(),
fp8_out_scale.item()) if fp8_comp_scales else None
out, _ = self.attn_func(
query,
key,
@@ -694,7 +700,7 @@ def forward(
self.scale,
attn_masks[0][None]
if attn_masks is not None else None,
None,
full_scales,
)
elif self.use_naive_attn:
if self.num_kv_heads != self.num_heads:
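
A standalone sketch of the new scale plumbing in this backend, pulled out of the class so it can be run directly. It assumes all scales are 0-dim tensors, as they are set up in `vllm/attention/layer.py`; reading the reciprocals as descale factors for the kernel is my interpretation of the code, not something stated in the PR.

```python
# Sketch of the full_scales computation in isolation.
from typing import Optional, Tuple

import torch


def build_full_scales(
    fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]],
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
) -> Optional[Tuple[float, float, float, float, float]]:
    if fp8_comp_scales is None:
        return None
    q_scale, prob_scale, fp8_out_scale = fp8_comp_scales
    # Reciprocals of the Q/K/V/prob scales, plus the plain output scale,
    # matching the tuple handed to self.attn_func above.
    return (1.0 / q_scale.item(), 1.0 / k_scale.item(),
            1.0 / v_scale.item(), 1.0 / prob_scale.item(),
            fp8_out_scale.item())


scales = build_full_scales(
    (torch.tensor(0.5), torch.tensor(0.25), torch.tensor(2.0)),
    k_scale=torch.tensor(0.125),
    v_scale=torch.tensor(0.0625),
)
print(scales)  # (2.0, 8.0, 16.0, 4.0, 2.0)
```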
2 changes: 1 addition & 1 deletion vllm/attention/backends/torch_sdpa.py
@@ -434,7 +434,7 @@ def forward(
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
fp8_out_scale: Optional[torch.Tensor] = None,
fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention.

2 changes: 1 addition & 1 deletion vllm/attention/backends/xformers.py
@@ -420,7 +420,7 @@ def forward(
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
fp8_out_scale: Optional[torch.Tensor] = None,
fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.

14 changes: 9 additions & 5 deletions vllm/attention/layer.py
@@ -1,5 +1,5 @@
"""Attention layer."""
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
@@ -42,6 +42,7 @@ def __init__(
logits_soft_cap: Optional[float] = None,
per_layer_sliding_window: Optional[int] = None,
prefix: str = "",
use_fp8: bool = False,
) -> None:
super().__init__()
if per_layer_sliding_window is not None:
@@ -73,8 +74,11 @@ def __init__(
# with the model weights.
self.kv_cache_dtype = kv_cache_dtype
self.calculate_kv_scales = calculate_kv_scales
self.use_fp8 = use_fp8
self._k_scale = torch.tensor(1.0, dtype=torch.float32)
self._v_scale = torch.tensor(1.0, dtype=torch.float32)
self._q_scale = torch.tensor(1.0, dtype=torch.float32)
self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
quant_method = quant_config.get_quant_method(
self, prefix=prefix) if quant_config else None
if quant_method is not None:
@@ -106,11 +110,11 @@ def __init__(
self.num_kv_heads = num_kv_heads
self.backend = backend_name_to_enum(attn_backend.get_name())

# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
# For cuda and cpu platforms, we control how
# torch.compile works by registering the attention as one giant
# opaque custom op. For other platforms, we directly call them
# and let torch.compile handle them.
self.use_direct_call = not current_platform.is_cuda_alike(
self.use_direct_call = not current_platform.is_cuda(
) and not current_platform.is_cpu()

# For some attention backends, we allocate an output tensor before
@@ -135,7 +139,7 @@ def forward(
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
attn_type: str = AttentionType.DECODER,
fp8_out_scale: Optional[torch.Tensor] = None,
fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
if self.calculate_kv_scales and \
attn_metadata.enable_kv_scales_calculation:
Expand All @@ -150,7 +154,7 @@ def forward(
self._k_scale,
self._v_scale,
attn_type=attn_type,
fp8_out_scale=fp8_out_scale)
fp8_comp_scales=fp8_comp_scales)
elif self.use_output:
output = torch.empty_like(query)
hidden_size = query.size(-1)
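
A toy version (not the real vLLM class) of the layer-level state this hunk adds: `_q_scale` and `_prob_scale` tensors next to the existing K/V scales, a `use_fp8` flag, and a `forward` that passes `fp8_comp_scales` straight through to the backend implementation.

```python
# Toy sketch of the new per-layer state and plumbing; not the vLLM class.
from typing import Optional, Tuple

import torch


class TinyAttentionLayer:
    def __init__(self, use_fp8: bool = False) -> None:
        self.use_fp8 = use_fp8
        self._k_scale = torch.tensor(1.0, dtype=torch.float32)
        self._v_scale = torch.tensor(1.0, dtype=torch.float32)
        # New in this PR: Q and softmax-prob scales alongside K/V.
        self._q_scale = torch.tensor(1.0, dtype=torch.float32)
        self._prob_scale = torch.tensor(1.0, dtype=torch.float32)

    def forward(
        self,
        query: torch.Tensor,
        fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
    ) -> torch.Tensor:
        # The real layer forwards fp8_comp_scales to impl.forward();
        # here we only show that the tuple rides along unchanged.
        _ = fp8_comp_scales
        return query


layer = TinyAttentionLayer(use_fp8=True)
out = layer.forward(
    torch.randn(2, 16),
    fp8_comp_scales=(layer._q_scale, layer._prob_scale,
                     torch.tensor(1.0)))
```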
11 changes: 5 additions & 6 deletions vllm/envs.py
@@ -16,7 +16,7 @@
VLLM_USE_TRITON_FLASH_ATTN: bool = True
VLLM_USE_ROCM_SKINNY_GEMM: bool = True
VLLM_USE_ROCM_CUSTOM_PAGED_ATTN: bool = True
VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT: bool = True
VLLM_USE_ROCM_FP8_ATTN: bool = True
RANK: int = 0
LOCAL_RANK: int = 0
CUDA_VISIBLE_DEVICES: Optional[str] = None
@@ -242,13 +242,12 @@ def get_default_config_root():
# custom paged attention implemented for MI3* cards
"VLLM_USE_ROCM_CUSTOM_PAGED_ATTN":
lambda: (os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in
("true", "1") != "0"),
("true", "1")),

# have custom paged attention implemented for MI3* cards write out fp8
"VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT":
lambda:
(os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT", "True").lower() in
("true", "1") != "0"),
"VLLM_USE_ROCM_FP8_ATTN":
lambda: (os.getenv("VLLM_USE_ROCM_FP8_ATTN", "True").lower() in
("true", "1")),

# rank of the process in the distributed setting, used to determine
# the driver worker
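
The old expressions ended in `!= "0"`, which Python treats as a chained comparison: `x in t != "0"` means `(x in t) and (t != "0")`. Since the tuple is never equal to `"0"`, the result was unchanged, so dropping the tail (alongside the rename to `VLLM_USE_ROCM_FP8_ATTN`) is behaviour-preserving as far as I can tell. A quick check:

```python
# Old vs. new parsing of the flag; both evaluate the same way.
import os

os.environ["VLLM_USE_ROCM_FP8_ATTN"] = "0"

old_style = (os.getenv("VLLM_USE_ROCM_FP8_ATTN", "True").lower() in
             ("true", "1") != "0")  # chained: (x in t) and (t != "0")
new_style = (os.getenv("VLLM_USE_ROCM_FP8_ATTN", "True").lower() in
             ("true", "1"))

print(old_style, new_style)  # False False -- same result, clearer code
```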
@@ -146,6 +146,10 @@ def get_compressed_tensors_cache_scale(name: str) -> Optional[str]:
return name.replace(".k_proj.output_scale", ".attn.k_scale")
if name.endswith(".output_scale") and ".v_proj" in name:
return name.replace(".v_proj.output_scale", ".attn.v_scale")
if name.endswith(".output_scale") and ".q_proj" in name:
return name.replace(".q_proj.output_scale", ".attn.q_scale")
if name.endswith("self_attn.prob_output_scale"):
return name.replace(".prob_output_scale", ".attn.prob_scale")
# If no matches, return None
return None

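
A trimmed, standalone copy of the remapping logic in this hunk, for illustration only (the example layer names are typical Llama-style names, not taken from a specific checkpoint):

```python
# Standalone copy of the checkpoint-name remapping shown above.
from typing import Optional


def remap_cache_scale_name(name: str) -> Optional[str]:
    if name.endswith(".output_scale") and ".k_proj" in name:
        return name.replace(".k_proj.output_scale", ".attn.k_scale")
    if name.endswith(".output_scale") and ".v_proj" in name:
        return name.replace(".v_proj.output_scale", ".attn.v_scale")
    # New in this PR: Q and softmax-prob scales.
    if name.endswith(".output_scale") and ".q_proj" in name:
        return name.replace(".q_proj.output_scale", ".attn.q_scale")
    if name.endswith("self_attn.prob_output_scale"):
        return name.replace(".prob_output_scale", ".attn.prob_scale")
    return None


print(remap_cache_scale_name("model.layers.0.self_attn.q_proj.output_scale"))
# -> model.layers.0.self_attn.attn.q_scale
print(remap_cache_scale_name("model.layers.0.self_attn.prob_output_scale"))
# -> model.layers.0.self_attn.attn.prob_scale
```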
39 changes: 36 additions & 3 deletions vllm/model_executor/layers/quantization/kv_cache.py
@@ -31,17 +31,20 @@ def create_weights(self, layer: torch.nn.Module):
requires_grad=False)
layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0),
requires_grad=False)
# Initialize Q and P = softmax(QK^T) scales
layer.q_scale = torch.nn.Parameter(torch.tensor(-1.0),
requires_grad=False)
layer.prob_scale = torch.nn.Parameter(torch.tensor(-1.0),
requires_grad=False)

def apply(self, layer: torch.nn.Module) -> torch.Tensor:
raise RuntimeError(
f"{self.__class__.__name__}.apply should not be called.")

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
# regardless whether the kv-scale is available in the checkpoint.
# No need to process kv scales after loading if we are going to
# calculate them on the fly.
if layer.kv_cache_dtype != "auto" and not layer.calculate_kv_scales:
if not layer.calculate_kv_scales:
if layer.k_scale > 0.0 and layer.v_scale > 0.0:
# We prefer to use separate k_scale and v_scale if present
k_scale = layer.k_scale.to("cpu").tolist()
@@ -75,11 +78,41 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer._k_scale.copy_(k_scale)
layer._v_scale.copy_(v_scale)
if (k_scale == 1.0 and v_scale == 1.0
and (layer.kv_cache_dtype != "auto" or layer.use_fp8)
and "e5m2" not in layer.kv_cache_dtype):
print_warning_once(
"Using KV cache scaling factor 1.0 for fp8_e4m3. This "
"may cause accuracy issues. Please make sure k/v_scale "
"scaling factors are available in the fp8 checkpoint.")

if layer.q_scale > 0.0 and layer.prob_scale > 0.0:
q_scale = layer.q_scale.to("cpu").tolist()
if current_platform.is_rocm() and not is_navi():
q_scale *= 2
else:
q_scale = 1.0
if layer.prob_scale > 0.0:
prob_scale = layer.prob_scale.to("cpu").tolist()
if current_platform.is_rocm() and not is_navi():
prob_scale *= 2
else:
prob_scale = 1.0

if not isinstance(q_scale, float) or not isinstance(prob_scale, float):
raise ValueError("Only support per-tensor scaling factor"
"for fp8-quantized Q/prob")

# These are used in the final Attention.forward()
layer._q_scale.copy_(q_scale)
layer._prob_scale.copy_(prob_scale)
if (q_scale == 1.0 or prob_scale == 1.0) and layer.use_fp8:
print_warning_once(
f"Using Q scale {q_scale} and prob scale {prob_scale} "
"with fp8 attention. This may cause accuracy issues. "
"Please make sure Q/prob scaling factors are "
"available in the fp8 checkpoint.")

del layer.k_scale
del layer.v_scale
del layer.q_scale
del layer.prob_scale
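
The Q/prob handling mirrors the existing K/V path: take the checkpoint scale if it is positive, double it on non-Navi ROCm GPUs, otherwise fall back to 1.0 and warn. Below is a compressed per-scale sketch of that decision; note the real code couples the Q branch to the prob scale as shown above, and the doubling presumably compensates for the fp8e4m3fnuz format used on those GPUs (my assumption, not something the PR states).

```python
# Compressed sketch of the per-scale decision in the hunk above.
import torch


def finalize_scale(loaded: torch.Tensor, *, rocm_non_navi: bool) -> float:
    """Return the runtime scale for one checkpoint scale parameter."""
    if loaded.item() > 0.0:
        scale = float(loaded.to("cpu").item())
        if rocm_non_navi:
            scale *= 2
    else:
        # Missing in the checkpoint (initialized to -1.0): fall back to 1.0;
        # the real code also emits a warning in this case.
        scale = 1.0
    return scale


print(finalize_scale(torch.tensor(0.25), rocm_non_navi=True))  # 0.5
print(finalize_scale(torch.tensor(-1.0), rocm_non_navi=True))  # 1.0
```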
4 changes: 3 additions & 1 deletion vllm/model_executor/models/aria.py
@@ -395,7 +395,9 @@ def load_weights(self, weights: Iterable[Tuple[str,
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = loaded_weight[0]
if loaded_weight.shape:
# scalar shape is torch.Size([1]), not torch.Size([])
loaded_weight = loaded_weight[0]
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
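
The same two-line guard is added to exaone, granite, llama, and solar below. A quick demonstration of why it is needed: `loaded_weight[0]` works for a one-element vector but raises `IndexError` for a true 0-dim scalar, and `tensor.shape` is falsy only in the 0-dim case.

```python
# Why the guard is needed: shape [1] vs. shape [] scale tensors.
import torch

vector_scale = torch.tensor([0.5])  # shape torch.Size([1]) -> truthy .shape
scalar_scale = torch.tensor(0.5)    # shape torch.Size([])  -> falsy .shape

for loaded_weight in (vector_scale, scalar_scale):
    if loaded_weight.shape:
        # scalar shape is torch.Size([1]), not torch.Size([])
        loaded_weight = loaded_weight[0]
    print(loaded_weight.shape, loaded_weight.item())  # torch.Size([]) 0.5
```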
4 changes: 3 additions & 1 deletion vllm/model_executor/models/exaone.py
@@ -543,7 +543,9 @@ def load_weights(self, weights: Iterable[Tuple[str,
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = loaded_weight[0]
if loaded_weight.shape:
# scalar shape is torch.Size([1]), not torch.Size([])
loaded_weight = loaded_weight[0]
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
4 changes: 3 additions & 1 deletion vllm/model_executor/models/granite.py
@@ -485,7 +485,9 @@ def load_weights(self, weights: Iterable[Tuple[str,
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = loaded_weight[0]
if loaded_weight.shape:
# scalar shape is torch.Size([1]), not torch.Size([])
loaded_weight = loaded_weight[0]
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
24 changes: 16 additions & 8 deletions vllm/model_executor/models/llama.py
@@ -197,6 +197,14 @@ def __init__(
else:
sliding_window = None

# For CUDA devices and Navi4x, attn_fp8 will be set to false.
self.attn_fp8 = envs.VLLM_USE_ROCM_FP8_ATTN \
and current_platform.is_rocm() \
and not is_navi() \
and isinstance(quant_config, Fp8Config) \
and hasattr(self.o_proj, "input_scale") \
and self.o_proj.input_scale is not None

self.attn = Attention(
self.num_heads,
self.head_dim,
@@ -206,12 +214,8 @@ def __init__(
quant_config=quant_config,
per_layer_sliding_window=sliding_window,
prefix=f"{prefix}.attn",
use_fp8=self.attn_fp8,
)
# For CUDA devices and Navi4x, attn_fp8_out will be set to false.
self.attn_fp8_out = envs.VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT \
and current_platform.is_rocm() \
and not is_navi() \
and isinstance(quant_config, Fp8Config)

def forward(
self,
@@ -228,8 +232,10 @@ def forward(
v,
kv_cache,
attn_metadata,
fp8_out_scale=self.o_proj.input_scale
if self.attn_fp8_out else None)
fp8_comp_scales=(self.attn._q_scale,
self.attn._prob_scale,
self.o_proj.input_scale)
if self.attn_fp8 else None)
output, _ = self.o_proj(attn_output)
return output

@@ -428,7 +434,9 @@ def load_weights(self, weights: Iterable[Tuple[str,
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = loaded_weight[0]
if loaded_weight.shape:
[Review comment, Collaborator]: Same for mllama.py?
[Reply, Author]: I targeted only the models which unconditionally do this logic.
# scalar shape is torch.Size([1]), not torch.Size([])
loaded_weight = loaded_weight[0]
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
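
A sketch of the enablement condition introduced above, pulled out into a standalone function (`env_flag`, `is_rocm`, `is_navi`, `quant_is_fp8`, and `o_proj_input_scale` stand in for the real vLLM objects). When it evaluates to true, the layer passes `(attn._q_scale, attn._prob_scale, o_proj.input_scale)` as `fp8_comp_scales`; otherwise it passes `None`.

```python
# Standalone sketch of the attn_fp8 gating condition from the hunk above.
from typing import Optional

import torch


def should_use_fp8_attn(env_flag: bool, is_rocm: bool, is_navi: bool,
                        quant_is_fp8: bool,
                        o_proj_input_scale: Optional[torch.Tensor]) -> bool:
    # Enabled only on ROCm non-Navi GPUs, with an Fp8Config and an
    # o_proj input scale present in the checkpoint.
    return (env_flag and is_rocm and not is_navi and quant_is_fp8
            and o_proj_input_scale is not None)


print(should_use_fp8_attn(True, True, False, True, torch.tensor(0.1)))   # True
print(should_use_fp8_attn(True, False, False, True, torch.tensor(0.1)))  # False
```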
4 changes: 3 additions & 1 deletion vllm/model_executor/models/solar.py
@@ -502,7 +502,9 @@ def load_weights(self, weights: Iterable[Tuple[str,
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = loaded_weight[0]
if loaded_weight.shape:
# scalar shape is torch.Size([1]), not torch.Size([])
loaded_weight = loaded_weight[0]
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue