Skip to content

Commit

Permalink
[4/N] make quant config first-class citizen (vllm-project#9978)
Browse files Browse the repository at this point in the history
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: Richard Liu <[email protected]>
  • Loading branch information
youkaichao authored and richardsliu committed Nov 4, 2024
1 parent 9ddd35a commit 6eb2ed0
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 31 deletions.
38 changes: 38 additions & 0 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,13 @@
from ray.util.placement_group import PlacementGroup

from vllm.executor.executor_base import ExecutorBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.model_loader.loader import BaseModelLoader
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup)
else:
QuantizationConfig = None

logger = init_logger(__name__)

Expand Down Expand Up @@ -1966,6 +1970,35 @@ class VllmConfig:
decoding_config: Optional[DecodingConfig] = None
observability_config: Optional[ObservabilityConfig] = None
prompt_adapter_config: Optional[PromptAdapterConfig] = None
quant_config: Optional[QuantizationConfig] = None

@staticmethod
def _get_quantization_config(
model_config: ModelConfig,
load_config: LoadConfig) -> Optional[QuantizationConfig]:
"""Get the quantization config."""
if model_config.quantization is not None:
from vllm.model_executor.model_loader.weight_utils import (
get_quant_config)
quant_config = get_quant_config(model_config, load_config)
capability_tuple = current_platform.get_device_capability()

if capability_tuple is not None:
capability = capability_tuple.to_int()
if capability < quant_config.get_min_capability():
raise ValueError(
f"The quantization method {model_config.quantization} "
"is not supported for the current GPU. Minimum "
f"capability: {quant_config.get_min_capability()}. "
f"Current capability: {capability}.")
supported_dtypes = quant_config.get_supported_act_dtypes()
if model_config.dtype not in supported_dtypes:
raise ValueError(
f"{model_config.dtype} is not supported for quantization "
f"method {model_config.quantization}. Supported dtypes: "
f"{supported_dtypes}")
return quant_config
return None

def __post_init__(self):
"""Verify configs are valid & consistent with each other.
Expand All @@ -1983,3 +2016,8 @@ def __post_init__(self):
if self.prompt_adapter_config:
self.prompt_adapter_config.verify_with_model_config(
self.model_config)

if self.quant_config is None and \
self.model_config is not None and self.load_config is not None:
self.quant_config = VllmConfig._get_quantization_config(
self.model_config, self.load_config)
34 changes: 3 additions & 31 deletions vllm/model_executor/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from vllm.model_executor.model_loader.weight_utils import (
download_safetensors_index_file_from_hf, download_weights_from_hf,
filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
get_gguf_extra_tensor_names, get_quant_config, gguf_quant_weights_iterator,
get_gguf_extra_tensor_names, gguf_quant_weights_iterator,
initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator,
safetensors_weights_iterator)
from vllm.model_executor.models import (has_inner_state, supports_lora,
Expand Down Expand Up @@ -93,32 +93,6 @@ def device_loading_context(module: torch.nn.Module,
logger = init_logger(__name__)


def _get_quantization_config(
model_config: ModelConfig,
load_config: LoadConfig) -> Optional[QuantizationConfig]:
"""Get the quantization config."""
if model_config.quantization is not None:
quant_config = get_quant_config(model_config, load_config)
capability_tuple = current_platform.get_device_capability()

if capability_tuple is not None:
capability = capability_tuple.to_int()
if capability < quant_config.get_min_capability():
raise ValueError(
f"The quantization method {model_config.quantization} "
"is not supported for the current GPU. "
f"Minimum capability: {quant_config.get_min_capability()}. "
f"Current capability: {capability}.")
supported_dtypes = quant_config.get_supported_act_dtypes()
if model_config.dtype not in supported_dtypes:
raise ValueError(
f"{model_config.dtype} is not supported for quantization "
f"method {model_config.quantization}. Supported dtypes: "
f"{supported_dtypes}")
return quant_config
return None


def _get_model_initialization_kwargs(
model_class: Type[nn.Module],
lora_config: Optional[LoRAConfig],
Expand Down Expand Up @@ -185,15 +159,14 @@ def _initialize_model(vllm_config: VllmConfig) -> nn.Module:
lora_config = vllm_config.lora_config
scheduler_config = vllm_config.scheduler_config
cache_config = vllm_config.cache_config
load_config = vllm_config.load_config
model_class, _ = get_model_architecture(model_config)

return build_model(
model_class,
vllm_config,
model_config.hf_config,
cache_config=cache_config,
quant_config=_get_quantization_config(model_config, load_config),
quant_config=vllm_config.quant_config,
lora_config=lora_config,
multimodal_config=model_config.multimodal_config,
scheduler_config=scheduler_config,
Expand Down Expand Up @@ -518,8 +491,7 @@ def _load_model_serialized(
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model_class = get_model_architecture(model_config)[0]
quant_config = _get_quantization_config(
model_config, self.load_config)
quant_config = vllm_config.quant_config
extra_kwargs = _get_model_initialization_kwargs(
model_class, lora_config, model_config.multimodal_config)
extra_kwargs["quant_config"] = quant_config
Expand Down

0 comments on commit 6eb2ed0

Please sign in to comment.