Skip to content

Commit

Permalink
[Bugfix] Allow fallback to AWQ from AWQMarlin at per-layer granularity (
Browse files Browse the repository at this point in the history
  • Loading branch information
mgoin authored Feb 12, 2025
1 parent 36a0863 commit 09972e7
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 29 deletions.
35 changes: 19 additions & 16 deletions vllm/model_executor/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,29 +290,30 @@ def __init__(self,
quant_config: Optional[QuantizationConfig] = None,
output_sizes: Optional[list[int]] = None,
prefix: str = ""):
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config, prefix)

self.gather_output = gather_output

# Divide the weight matrix along the last dimension.
tp_size = get_tensor_model_parallel_world_size()
assert self.quant_method is not None
self.output_size_per_partition = divide(self.output_size, tp_size)
self.tp_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = input_size
self.output_size_per_partition = divide(output_size, self.tp_size)
self.output_partition_sizes = [self.output_size_per_partition]
# If QKV or MergedColumn, use output size of each partition.
if hasattr(self, "output_sizes"):
self.output_partition_sizes = [
divide(output_size, tp_size)
divide(output_size, self.tp_size)
for output_size in self.output_sizes
]

super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config, prefix)

self.gather_output = gather_output

if output_sizes is None:
output_sizes = [output_size]

assert self.quant_method is not None
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size,
input_size_per_partition=self.input_size_per_partition,
output_partition_sizes=self.output_partition_sizes,
input_size=self.input_size,
output_size=self.output_size,
Expand Down Expand Up @@ -1044,22 +1045,24 @@ def __init__(self,
reduce_results: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
# Divide the weight matrix along the first dimension.
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size
self.output_partition_sizes = [output_size]

super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config, prefix)

self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results

# Divide the weight matrix along the last dimension.
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, self.tp_size)
assert self.quant_method is not None

self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size_per_partition,
output_partition_sizes=[self.output_size],
output_partition_sizes=self.output_partition_sizes,
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
Expand Down
28 changes: 18 additions & 10 deletions vllm/model_executor/layers/quantization/awq_marlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,17 @@
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod,
set_weight_attrs)
from vllm.model_executor.layers.quantization.awq import is_layer_skipped_awq
from vllm.model_executor.layers.quantization.awq import (AWQConfig,
is_layer_skipped_awq)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales,
marlin_permute_scales, moe_awq_to_marlin_zero_points,
verify_marlin_supported, verify_marlin_supports_shape)
check_marlin_supports_layer, marlin_make_empty_g_idx,
marlin_make_workspace, marlin_moe_permute_scales, marlin_permute_scales,
moe_awq_to_marlin_zero_points, verify_marlin_supported,
verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
PackedvLLMParameter)
Expand All @@ -40,18 +42,17 @@ class AWQMarlinConfig(QuantizationConfig):
8: scalar_types.uint8,
}

def __init__(self,
weight_bits: int,
group_size: int,
zero_point: bool,
def __init__(self, weight_bits: int, group_size: int, zero_point: bool,
lm_head_quantized: bool,
modules_to_not_convert: Optional[List[str]] = None) -> None:
modules_to_not_convert: Optional[List[str]],
full_config: Dict[str, Any]) -> None:
self.pack_factor = 32 // weight_bits # packed into int32
self.group_size = group_size
self.zero_point = zero_point
self.lm_head_quantized = lm_head_quantized
self.weight_bits = weight_bits
self.modules_to_not_convert = modules_to_not_convert or []
self.full_config = full_config

if self.weight_bits not in self.TYPE_MAP:
raise ValueError(f"Unsupported num_bits = {self.weight_bits}. "
Expand Down Expand Up @@ -96,7 +97,7 @@ def from_config(cls, config: Dict[str, Any]) -> "AWQMarlinConfig":
modules_to_not_convert = cls.get_from_keys_or(
config, ["modules_to_not_convert"], None)
return cls(weight_bits, group_size, zero_point, lm_head_quantized,
modules_to_not_convert)
modules_to_not_convert, config)

@classmethod
def override_quantization_method(cls, hf_quant_cfg,
Expand Down Expand Up @@ -124,6 +125,13 @@ def get_quant_method(self, layer: torch.nn.Module,
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
return UnquantizedLinearMethod()
# Check if the layer is supported by AWQMarlin.
if not check_marlin_supports_layer(layer, self.group_size):
logger.warning_once(
f"Layer '{prefix}' is not supported by AWQMarlin. "
"Falling back to unoptimized AWQ kernels.")
return AWQConfig.from_config(
self.full_config).get_quant_method(layer, prefix)
return AWQMarlinLinearMethod(self)
elif isinstance(layer, FusedMoE):
return AWQMoEMethod(self)
Expand Down
9 changes: 6 additions & 3 deletions vllm/model_executor/layers/quantization/moe_wna16.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supports_layer)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform

Expand Down Expand Up @@ -87,8 +89,8 @@ def from_config(cls, config: Dict[str, Any]) -> "MoeWNA16Config":
modules_to_not_convert = []
elif linear_quant_method == "awq":
has_zp = cls.get_from_keys(config, ["zero_point"])
modules_to_not_convert = cls.get_from_keys(
config, ["modules_to_not_convert"])
modules_to_not_convert = cls.get_from_keys_or(
config, ["modules_to_not_convert"], None)
else:
raise ValueError("moe_wna16 only support gptq and awq.")

Expand Down Expand Up @@ -135,7 +137,8 @@ def get_quant_method(self, layer: torch.nn.Module,
return GPTQConfig.from_config(
self.full_config).get_quant_method(layer, prefix)
elif self.linear_quant_method == "awq":
if self.use_marlin:
if self.use_marlin and check_marlin_supports_layer(
layer, self.group_size):
return AWQMarlinConfig.from_config(
self.full_config).get_quant_method(layer, prefix)
else:
Expand Down
15 changes: 15 additions & 0 deletions vllm/model_executor/layers/quantization/utils/marlin_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearBase
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types

Expand Down Expand Up @@ -135,6 +136,20 @@ def check_marlin_supports_shape(output_size_per_partition: int,
return True, None


def check_marlin_supports_layer(layer: LinearBase, group_size: int) \
-> bool:
output_size_per_partition = getattr(layer, "output_size_per_partition",
None) or layer.output_size
input_size_per_partition = getattr(layer, "input_size_per_partition",
None) or layer.input_size

return check_marlin_supports_shape(
output_size_per_partition=output_size_per_partition,
input_size_per_partition=input_size_per_partition,
input_size=layer.input_size,
group_size=group_size)[0]


def marlin_make_workspace(output_size_per_partition: int,
device: torch.device) -> torch.Tensor:
max_workspace_size = (output_size_per_partition //
Expand Down

0 comments on commit 09972e7

Please sign in to comment.