[ Kernel ] Fp8 Channelwise Weight Support #6487

Merged
vllm/config.py: 3 changes (2 additions & 1 deletion)
@@ -237,7 +237,8 @@ def _verify_quantization(self) -> None:
f"{self.quantization} quantization is currently not "
f"supported in ROCm.")
if (self.quantization
not in ("fp8", "marlin", "gptq_marlin_24", "gptq_marlin")):
not in ("fp8", "marlin", "gptq_marlin_24", "gptq_marlin",
"compressed-tesnors")):
robertgshaw2-neuralmagic marked this conversation as resolved.
Show resolved Hide resolved
logger.warning(
"%s quantization is not fully "
"optimized yet. The speed can be slower than "
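
For orientation, here is a minimal usage sketch of the path this check guards. The checkpoint name is hypothetical, and for compressed-tensors checkpoints vLLM can usually infer `quantization` from the model config, so passing it explicitly is only for illustration.

```python
from vllm import LLM, SamplingParams

# Hypothetical channelwise fp8 compressed-tensors checkpoint.
llm = LLM(model="nm-testing/llama-3-8b-fp8-channelwise",
          quantization="compressed-tensors")

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)
```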
@@ -132,10 +132,11 @@ def _is_fp8_w8a8(self, weight_quant: BaseModel,
# Confirm weight scheme is supported.
is_symmetric_weight = weight_quant.symmetric
is_static_weight = not weight_quant.dynamic
-        is_per_tensor_weight = (
-            weight_quant.strategy == QuantizationStrategy.TENSOR)
+        is_per_tensor_or_channel_weight = (
+            weight_quant.strategy == QuantizationStrategy.TENSOR
+            or weight_quant.strategy == QuantizationStrategy.CHANNEL)
if not (is_symmetric_weight and is_static_weight
-                and is_per_tensor_weight):
+                and is_per_tensor_or_channel_weight):
return False

# Dynamic quantization is always supported if weights supported.
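
As a reference for what the two accepted weight strategies mean numerically, here is a small standalone sketch (not vLLM code) of symmetric, static fp8 quantization with a per-tensor scale versus a per-output-channel scale:

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

def quantize_weight_per_tensor(w: torch.Tensor):
    # TENSOR strategy: a single scale for the whole weight matrix.
    scale = w.abs().max() / FP8_MAX
    q = (w / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return q, scale                      # scale: scalar

def quantize_weight_per_channel(w: torch.Tensor):
    # CHANNEL strategy: one scale per output channel (row of [N, K]).
    scale = w.abs().amax(dim=1, keepdim=True) / FP8_MAX
    q = (w / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return q, scale                      # scale: [N, 1]

w = torch.randn(8, 16, dtype=torch.float32)
q_t, s_t = quantize_weight_per_tensor(w)
q_c, s_c = quantize_weight_per_channel(w)
```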
@@ -167,6 +168,7 @@ def _is_wNa16_group_channel(self, weight_quant: BaseModel,
def _get_schema(self, weight_quant: BaseModel,
input_quant: BaseModel) -> "CompressedTensorsScheme":

# Detect If Mixed Precision
if self._is_wNa16_group_channel(weight_quant, input_quant):
self._check_gptq_and_marlin_can_run()
if (self.quant_format == CompressionFormat.marlin_24.value
@@ -182,11 +184,15 @@ def _get_schema(self, weight_quant: BaseModel,
strategy=weight_quant.strategy,
group_size=weight_quant.group_size)

-        if (self.quant_format == CompressionFormat.int_quantized.value or
-                self.quant_format == CompressionFormat.float_quantized.value):
+        # Detect If Activation Quantization.
+        if (self.quant_format == CompressionFormat.naive_quantized.value
+                or self.quant_format == CompressionFormat.int_quantized.value
+                or self.quant_format
+                == CompressionFormat.float_quantized.value):
if self._is_fp8_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Fp8(
-                    input_dynamic=input_quant.dynamic)
+                    strategy=weight_quant.strategy,
+                    is_static_input_scheme=(not input_quant.dynamic))

if self._is_static_tensor_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Int8(
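
The `is_static_input_scheme` flag passed above distinguishes activation scales calibrated offline (static, shipped as `input_scale`) from scales computed at runtime (dynamic). A rough standalone sketch of the difference, not the actual kernel path:

```python
from typing import Optional

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def quantize_activation(x: torch.Tensor,
                        static_scale: Optional[torch.Tensor]):
    if static_scale is not None:
        # Static scheme: reuse the calibrated per-tensor scale.
        scale = static_scale
    else:
        # Dynamic scheme: derive the scale from the current batch.
        scale = x.abs().max() / FP8_MAX
    x_q = (x / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return x_q, scale
```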
@@ -1,11 +1,15 @@
from typing import Callable, List, Optional

import torch
from torch.nn import Parameter

from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
QuantizationStrategy)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear, create_per_tensor_scale_param, cutlass_fp8_supported,
+    apply_fp8_linear, create_per_channel_scale_param,
+    create_per_tensor_scale_param, cutlass_fp8_supported,
requantize_with_max_scale)
from vllm.model_executor.utils import set_weight_attrs

@@ -14,39 +18,56 @@

class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):

-    def __init__(self, input_dynamic: bool):
-        self.input_dynamic = input_dynamic
+    def __init__(self, strategy: str, is_static_input_scheme: bool):
+        self.strategy = strategy
+        self.is_static_input_scheme = is_static_input_scheme
self.cutlass_fp8_supported = cutlass_fp8_supported()

-    # W8A8-Fp8 kernels support only per-tensor and per-channel cases.
-    # So if we have a fused module (QKV, MLP) with per tensor scales (thus N
-    # scales being passed to the kernel), we requantize with a single scale.
+        # On Lovelace, no fallback is available: channelwise fp8 requires
+        # the custom cutlass kernels, so fail fast here.
if (not self.cutlass_fp8_supported
and self.strategy == QuantizationStrategy.CHANNEL):
raise ValueError(
"Channelwise fp8 quantization requires vLLM's custom "
"cutlass kernels, which are not supported on your device."
"Consider quantizing with per tensor scales or upgrading "
"to Hopper.")

def process_weights_after_loading(self, layer) -> None:
-        # Dequant -> Quant with max scale.
-        max_w_scale, weight = requantize_with_max_scale(
-            weight=layer.weight,
-            weight_scale=layer.weight_scale,
-            logical_widths=layer.logical_widths,
-        )
-
-        # Update layer with new values.
-        layer.weight = torch.nn.Parameter(weight.t(), requires_grad=False)
-        layer.weight_scale = torch.nn.Parameter(max_w_scale,
-                                                requires_grad=False)
-        if self.input_dynamic:
-            layer.input_scale = None
# If per tensor, when we have a fused module (e.g. QKV) with per
# tensor scales (thus N scales being passed to the kernel),
# requantize so we can always run per tensor with torch._scaled_mm
if (self.strategy == QuantizationStrategy.TENSOR
or not self.cutlass_fp8_supported):
max_w_scale, weight = requantize_with_max_scale(
weight=layer.weight,
weight_scale=layer.weight_scale,
logical_widths=layer.logical_widths,
)

layer.weight = Parameter(weight.t(), requires_grad=False)
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)

# If channelwise, scales are already lined up, so just transpose.
elif self.strategy == QuantizationStrategy.CHANNEL:
assert self.cutlass_fp8_supported
weight = layer.weight
layer.weight = Parameter(weight.t(), requires_grad=False)

else:
raise ValueError(f"Unknown quantization strategy {self.strategy}")

# INPUT SCALE
if self.is_static_input_scheme:
layer.input_scale = Parameter(layer.input_scale.max(),
requires_grad=False)
else:
-            layer.input_scale = torch.nn.Parameter(layer.input_scale.max(),
-                                                   requires_grad=False)
+            layer.input_scale = None
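
For fused modules (e.g. QKV packed into one weight), each logical shard arrives with its own per-tensor scale, and `requantize_with_max_scale` collapses them into one so the per-tensor path sees a single scale. A simplified sketch of the assumed behavior, inferred from the helper's name and how it is used here rather than copied from `w8a8_utils`:

```python
from typing import List

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def requantize_with_max_scale_sketch(weight: torch.Tensor,
                                     weight_scale: torch.Tensor,
                                     logical_widths: List[int]):
    # weight: [sum(logical_widths), K] fp8, shards stacked along dim 0.
    # weight_scale: [num_shards] per-tensor scales, one per shard.
    max_scale = weight_scale.max()
    out = torch.empty_like(weight)
    start = 0
    for idx, width in enumerate(logical_widths):
        end = start + width
        # Dequantize the shard with its own scale ...
        shard = weight[start:end].to(torch.float32) * weight_scale[idx]
        # ... and requantize it against the shared max scale.
        out[start:end] = (shard / max_scale).clamp(
            -FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
        start = end
    return max_scale, out
```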

def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):

del params_dtype

output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes

@@ -63,12 +84,17 @@ def create_weights(self, layer: torch.nn.Module,
})

# WEIGHT SCALE
-        weight_scale = create_per_tensor_scale_param(
-            output_partition_sizes, weight_loader=weight_loader)
+        if self.strategy == QuantizationStrategy.CHANNEL:
+            weight_scale = create_per_channel_scale_param(
+                output_partition_sizes, weight_loader=weight_loader)
+        else:
+            assert self.strategy == QuantizationStrategy.TENSOR
+            weight_scale = create_per_tensor_scale_param(
+                output_partition_sizes, weight_loader=weight_loader)
layer.register_parameter("weight_scale", weight_scale)

# INPUT SCALE
-        if not self.input_dynamic:
+        if self.is_static_input_scheme:
input_scale = create_per_tensor_scale_param(
output_partition_sizes, weight_loader=weight_loader)
layer.register_parameter("input_scale", input_scale)
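
The two scale helpers differ mainly in the shape of the parameter they register; a sketch of the assumed shapes, inferred from how the scales are consumed above rather than copied from `w8a8_utils`:

```python
import torch
from torch.nn import Parameter

output_partition_sizes = [4096, 4096, 1024]   # e.g. fused Q, K, V

# Per-tensor: one scale per logical shard; later collapsed to a single
# max scale in process_weights_after_loading.
per_tensor_scale = Parameter(torch.empty(len(output_partition_sizes),
                                         dtype=torch.float32),
                             requires_grad=False)

# Per-channel: one scale per output row across all shards; already lines
# up with the rows of the weight, so no requantization is needed.
per_channel_scale = Parameter(torch.empty(sum(output_partition_sizes), 1,
                                          dtype=torch.float32),
                              requires_grad=False)
```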
@@ -8,10 +8,15 @@

class CompressionFormat(Enum):
dense = "dense"
# Sparsity
sparse_bitmask = "sparse-bitmask"
# For Activation Quantization
naive_quantized = "naive-quantized"
float_quantized = "float-quantized"
int_quantized = "int-quantized"
# For Marlin
pack_quantized = "pack-quantized"
# For Marlin 2:4
marlin_24 = "marlin-24"

