[platform] Add verify_quantization in platform. (#10757)
Signed-off-by: wangxiyuan <[email protected]>
wangxiyuan authored Nov 29, 2024
1 parent 3132aac commit 661175b
Showing 10 changed files with 38 additions and 27 deletions.
28 changes: 1 addition & 27 deletions vllm/config.py
@@ -393,17 +393,11 @@ def _parse_quant_hf_config(self):

def _verify_quantization(self) -> None:
supported_quantization = QUANTIZATION_METHODS
rocm_supported_quantization = [
"awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors",
"fbgemm_fp8", "gguf"
]
optimized_quantization_methods = [
"fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
"awq_marlin", "fbgemm_fp8", "compressed_tensors",
"compressed-tensors", "experts_int8"
]
tpu_supported_quantization = ["tpu_int8"]
neuron_supported_quantization = ["neuron_quant"]
if self.quantization is not None:
self.quantization = self.quantization.lower()

@@ -438,32 +432,12 @@ def _verify_quantization(self) -> None:
raise ValueError(
f"Unknown quantization method: {self.quantization}. Must "
f"be one of {supported_quantization}.")
if current_platform.is_rocm(
) and self.quantization not in rocm_supported_quantization:
raise ValueError(
f"{self.quantization} quantization is currently not "
f"supported in ROCm.")
if current_platform.is_tpu(
) and self.quantization not in tpu_supported_quantization:
raise ValueError(
f"{self.quantization} quantization is currently not "
f"supported in TPU Backend.")
current_platform.verify_quantization(self.quantization)
if self.quantization not in optimized_quantization_methods:
logger.warning(
"%s quantization is not fully "
"optimized yet. The speed can be slower than "
"non-quantized models.", self.quantization)
if (self.quantization == "awq" and current_platform.is_rocm()
and not envs.VLLM_USE_TRITON_AWQ):
logger.warning(
"Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
" is not set, enabling VLLM_USE_TRITON_AWQ.")
envs.VLLM_USE_TRITON_AWQ = True
if current_platform.is_neuron(
) and self.quantization not in neuron_supported_quantization:
raise ValueError(
f"{self.quantization} quantization is currently not "
f"supported in Neuron Backend.")

def _verify_cuda_graph(self) -> None:
if self.max_seq_len_to_capture is None:
1 change: 1 addition & 0 deletions vllm/platforms/cpu.py
@@ -19,6 +19,7 @@

class CpuPlatform(Platform):
_enum = PlatformEnum.CPU
device_name: str = "cpu"
device_type: str = "cpu"
dispatch_key: str = "CPU"

1 change: 1 addition & 0 deletions vllm/platforms/cuda.py
@@ -72,6 +72,7 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:

class CudaPlatformBase(Platform):
_enum = PlatformEnum.CUDA
device_name: str = "cuda"
device_type: str = "cuda"
dispatch_key: str = "CUDA"

1 change: 1 addition & 0 deletions vllm/platforms/hpu.py
@@ -12,6 +12,7 @@

class HpuPlatform(Platform):
_enum = PlatformEnum.HPU
device_name: str = "hpu"
device_type: str = "hpu"
dispatch_key: str = "HPU"

13 changes: 13 additions & 0 deletions vllm/platforms/interface.py
@@ -56,11 +56,13 @@ def to_int(self) -> int:

class Platform:
_enum: PlatformEnum
device_name: str
device_type: str
# available dispatch keys:
# check https://github.com/pytorch/pytorch/blob/313dac6c1ca0fa0cde32477509cce32089f8532a/torchgen/model.py#L134 # noqa
# use "CPU" as a fallback for platforms not registered in PyTorch
dispatch_key: str = "CPU"
supported_quantization: list[str] = []

def is_cuda(self) -> bool:
return self._enum == PlatformEnum.CUDA
@@ -171,6 +173,17 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
"""
pass

@classmethod
def verify_quantization(cls, quant: str) -> None:
"""
Verify whether the quantization is supported by the current platform.
"""
if cls.supported_quantization and \
quant not in cls.supported_quantization:
raise ValueError(
f"{quant} quantization is currently not supported in "
f"{cls.device_name}.")


class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED
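The new hook pairs a class-level supported_quantization list with a shared check: if the list is non-empty and the requested method is not in it, the base class raises a ValueError that names the platform via device_name (an empty list means the platform imposes no restriction). As a minimal sketch only — MyAcceleratorPlatform below is hypothetical and not part of this commit — an out-of-tree platform could opt in like this:

# Illustrative sketch only; MyAcceleratorPlatform is hypothetical and not part
# of this commit. It shows how a platform opts into the new base-class check.
from vllm.platforms.interface import Platform, PlatformEnum

class MyAcceleratorPlatform(Platform):
    # A real platform would register its own PlatformEnum value.
    _enum = PlatformEnum.UNSPECIFIED
    device_name: str = "my_accel"
    device_type: str = "my_accel"
    # Only these methods pass Platform.verify_quantization(); anything else
    # raises "<quant> quantization is currently not supported in my_accel."
    supported_quantization: list[str] = ["fp8", "awq"]

MyAcceleratorPlatform.verify_quantization("fp8")     # passes silently
# MyAcceleratorPlatform.verify_quantization("gptq")  # would raise ValueError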
2 changes: 2 additions & 0 deletions vllm/platforms/neuron.py
@@ -10,7 +10,9 @@

class NeuronPlatform(Platform):
_enum = PlatformEnum.NEURON
device_name: str = "neuron"
device_type: str = "neuron"
supported_quantization: list[str] = ["neuron_quant"]

@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
1 change: 1 addition & 0 deletions vllm/platforms/openvino.py
@@ -23,6 +23,7 @@

class OpenVinoPlatform(Platform):
_enum = PlatformEnum.OPENVINO
device_name: str = "openvino"
device_type: str = "openvino"
dispatch_key: str = "CPU"

15 changes: 15 additions & 0 deletions vllm/platforms/rocm.py
@@ -4,6 +4,7 @@

import torch

import vllm.envs as envs
from vllm.logger import init_logger

from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
@@ -35,8 +36,13 @@

class RocmPlatform(Platform):
_enum = PlatformEnum.ROCM
device_name: str = "rocm"
device_type: str = "cuda"
dispatch_key: str = "CUDA"
supported_quantization: list[str] = [
"awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors",
"fbgemm_fp8", "gguf"
]

@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
@@ -79,3 +85,12 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
else:
parallel_config.worker_cls = "vllm.worker.worker.Worker"

@classmethod
def verify_quantization(cls, quant: str) -> None:
super().verify_quantization(quant)
if quant == "awq" and not envs.VLLM_USE_TRITON_AWQ:
logger.warning(
"Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
" is not set, enabling VLLM_USE_TRITON_AWQ.")
envs.VLLM_USE_TRITON_AWQ = True
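For ROCm the override keeps the shared membership check and then folds in the AWQ/Triton handling that previously lived in config.py. A small usage sketch, assuming a vLLM build where vllm.platforms.rocm imports cleanly (not part of the commit):

# Illustrative only: exercising the ROCm override added in this commit.
from vllm.platforms.rocm import RocmPlatform

RocmPlatform.verify_quantization("gptq")        # listed for ROCm, so no error
try:
    RocmPlatform.verify_quantization("marlin")  # not in ROCm's supported list
except ValueError as err:
    print(err)  # "marlin quantization is currently not supported in rocm."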
2 changes: 2 additions & 0 deletions vllm/platforms/tpu.py
@@ -16,8 +16,10 @@

class TpuPlatform(Platform):
_enum = PlatformEnum.TPU
device_name: str = "tpu"
device_type: str = "tpu"
dispatch_key: str = "XLA"
supported_quantization: list[str] = ["tpu_int8"]

@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
1 change: 1 addition & 0 deletions vllm/platforms/xpu.py
@@ -16,6 +16,7 @@

class XPUPlatform(Platform):
_enum = PlatformEnum.XPU
device_name: str = "xpu"
device_type: str = "xpu"
dispatch_key: str = "XPU"

