Skip to content

Commit

Permalink
[Bugfix] Get available quantization methods from quantization registry
Browse files Browse the repository at this point in the history
  • Loading branch information
mgoin authored Apr 18, 2024
1 parent 66ded03 commit 53b018e
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 13 deletions.
3 changes: 2 additions & 1 deletion benchmarks/benchmark_latency.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from tqdm import tqdm

from vllm import LLM, SamplingParams
+from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS


def main(args: argparse.Namespace):
Expand Down Expand Up @@ -101,7 +102,7 @@ def run_to_completion(profile_dir: Optional[str] = None):
parser.add_argument('--tokenizer', type=str, default=None)
parser.add_argument('--quantization',
'-q',
-choices=['awq', 'gptq', 'squeezellm', None],
+choices=[*QUANTIZATION_METHODS, None],
default=None)
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
parser.add_argument('--input-len', type=int, default=32)
Expand Down
4 changes: 3 additions & 1 deletion benchmarks/benchmark_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from transformers import (AutoModelForCausalLM, AutoTokenizer,
PreTrainedTokenizerBase)

+from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
+

def sample_requests(
dataset_path: str,
Expand Down Expand Up @@ -267,7 +269,7 @@ def main(args: argparse.Namespace):
parser.add_argument("--tokenizer", type=str, default=None)
parser.add_argument('--quantization',
'-q',
-choices=['awq', 'gptq', 'squeezellm', None],
+choices=[*QUANTIZATION_METHODS, None],
default=None)
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
parser.add_argument("--n",
Expand Down
7 changes: 3 additions & 4 deletions tests/models/test_marlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,12 @@
import pytest
import torch

-from vllm.model_executor.layers.quantization import (
-_QUANTIZATION_CONFIG_REGISTRY)
+from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS

capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
-marlin_not_supported = (
-capability < _QUANTIZATION_CONFIG_REGISTRY["marlin"].get_min_capability())
+marlin_not_supported = (capability <
+QUANTIZATION_METHODS["marlin"].get_min_capability())


@dataclass
Expand Down
7 changes: 4 additions & 3 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from transformers import PretrainedConfig

from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.transformers_utils.config import get_config, get_hf_text_config
from vllm.utils import (get_cpu_memory, get_nvcc_cuda_version, is_cpu, is_hip,
is_neuron)
Expand Down Expand Up @@ -118,8 +119,8 @@ def _verify_tokenizer_mode(self) -> None:
self.tokenizer_mode = tokenizer_mode

def _verify_quantization(self) -> None:
-supported_quantization = ["awq", "gptq", "squeezellm", "marlin"]
-rocm_not_supported_quantization = ["awq", "marlin"]
+supported_quantization = [*QUANTIZATION_METHODS]
+rocm_supported_quantization = ["gptq", "squeezellm"]
if self.quantization is not None:
self.quantization = self.quantization.lower()

Expand Down Expand Up @@ -155,7 +156,7 @@ def _verify_quantization(self) -> None:
f"Unknown quantization method: {self.quantization}. Must "
f"be one of {supported_quantization}.")
if is_hip(
-) and self.quantization in rocm_not_supported_quantization:
+) and self.quantization not in rocm_supported_quantization:
raise ValueError(
f"{self.quantization} quantization is currently not "
f"supported in ROCm.")
Expand Down
3 changes: 2 additions & 1 deletion vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
TokenizerPoolConfig, VisionLanguageConfig)
+from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import str_to_int_tuple


Expand Down Expand Up @@ -286,7 +287,7 @@ def add_cli_args(
parser.add_argument('--quantization',
'-q',
type=str,
-choices=['awq', 'gptq', 'squeezellm', None],
+choices=[*QUANTIZATION_METHODS, None],
default=EngineArgs.quantization,
help='Method used to quantize the weights. If '
'None, we first check the `quantization_config` '
Expand Down
7 changes: 4 additions & 3 deletions vllm/model_executor/layers/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig

-_QUANTIZATION_CONFIG_REGISTRY = {
+QUANTIZATION_METHODS = {
"awq": AWQConfig,
"gptq": GPTQConfig,
"squeezellm": SqueezeLLMConfig,
Expand All @@ -16,12 +16,13 @@


def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
-if quantization not in _QUANTIZATION_CONFIG_REGISTRY:
+if quantization not in QUANTIZATION_METHODS:
raise ValueError(f"Invalid quantization method: {quantization}")
-return _QUANTIZATION_CONFIG_REGISTRY[quantization]
+return QUANTIZATION_METHODS[quantization]


__all__ = [
"QuantizationConfig",
"get_quantization_config",
+"QUANTIZATION_METHODS",
]

0 comments on commit 53b018e

Please sign in to comment.