From 0f7052bc7e7c3301588705abf7c7fadf3db293a6 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Wed, 7 Aug 2024 12:17:58 -0400 Subject: [PATCH] [Misc] Refactor linear layer weight loading; introduce `BasevLLMParameter` and `weight_loader_v2` (#5874) --- tests/quantization/test_compressed_tensors.py | 20 +- vllm/model_executor/__init__.py | 4 + vllm/model_executor/layers/linear.py | 153 +++++++++- .../compressed_tensors/compressed_tensors.py | 21 +- .../schemes/compressed_tensors_unquantized.py | 20 +- .../schemes/compressed_tensors_w4a16_24.py | 112 ++++--- .../schemes/compressed_tensors_w8a16_fp8.py | 51 ++-- .../schemes/compressed_tensors_w8a8_fp8.py | 51 ++-- .../schemes/compressed_tensors_w8a8_int8.py | 47 +-- .../schemes/compressed_tensors_wNa16.py | 102 +++---- vllm/model_executor/parameter.py | 277 ++++++++++++++++++ 11 files changed, 655 insertions(+), 203 deletions(-) create mode 100644 vllm/model_executor/parameter.py diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index bd79da84a7764..2ea340779b819 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -9,7 +9,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, - CompressedTensorsWNA16) + CompressedTensorsW8A16Fp8, CompressedTensorsWNA16) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( QuantizationType) @@ -109,7 +109,7 @@ def test_compressed_tensors_wNa16(vllm_runner, wNa16_args): assert qkv_proj.weight_packed.dtype is torch.int32 assert qkv_proj.weight_scale.dtype is torch.float16 - assert qkv_proj.weight_packed.pack_factor == pack_factor + assert qkv_proj.scheme.pack_factor == pack_factor output = llm.generate_greedy("Hello my name is", max_tokens=20) assert output @@ -140,13 +140,17 @@ def test_compressed_tensors_fp8(vllm_runner): qkv_proj = layer.self_attn.qkv_proj assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) - assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8) - assert qkv_proj.weight.dtype is torch.float8_e4m3fn + assert isinstance( + qkv_proj.scheme, + (CompressedTensorsW8A8Fp8, CompressedTensorsW8A16Fp8)) + assert qkv_proj.input_scale.dtype is torch.float32 - assert qkv_proj.weight_scale.dtype is torch.float32 - # should be scalars after processing - assert len(qkv_proj.input_scale.shape) == 0 - assert len(qkv_proj.weight_scale.shape) == 0 + + if isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8): + assert len(qkv_proj.input_scale.shape) == 0 + assert qkv_proj.weight.dtype is torch.float8_e4m3fn + assert qkv_proj.weight_scale.dtype is torch.float32 + assert len(qkv_proj.weight_scale.shape) == 0 output = llm.generate_greedy("Hello my name is", max_tokens=20) assert output diff --git a/vllm/model_executor/__init__.py b/vllm/model_executor/__init__.py index fb98f4a6b46f4..5c767e22de4d0 100644 --- a/vllm/model_executor/__init__.py +++ b/vllm/model_executor/__init__.py @@ -1,7 +1,11 @@ +from vllm.model_executor.parameter import (BasevLLMParameter, + PackedvLLMParameter) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_random_seed __all__ = [ "SamplingMetadata", "set_random_seed", + "BasevLLMParameter", + "PackedvLLMParameter", ] diff --git a/vllm/model_executor/layers/linear.py 
b/vllm/model_executor/layers/linear.py index cd53c2b916211..646839ff303ee 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -13,10 +13,14 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.parameter import (BasevLLMParameter, + PackedvLLMParameter) from vllm.model_executor.utils import set_weight_attrs logger = init_logger(__name__) +WEIGHT_LOADER_V2_SUPPORTED = ["CompressedTensorsLinearMethod"] + def adjust_marlin_shard(param, shard_size, shard_offset): marlin_tile_size = getattr(param, "marlin_tile_size", None) @@ -288,6 +292,7 @@ def __init__(self, if output_sizes is None: output_sizes = [output_size] + self.quant_method.create_weights( layer=self, input_size_per_partition=self.input_size, @@ -295,7 +300,9 @@ def __init__(self, input_size=self.input_size, output_size=self.output_size, params_dtype=self.params_dtype, - weight_loader=self.weight_loader, + weight_loader=( + self.weight_loader_v2 if self.quant_method.__class__.__name__ + in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader), prefix=prefix) if bias: self.bias = Parameter( @@ -337,6 +344,9 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) + def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor): + param.load_column_parallel_weight(loaded_weight=loaded_weight) + def forward(self, input_): bias = self.bias if not self.skip_bias_add else None @@ -527,6 +537,62 @@ def weight_loader(self, assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) + def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter, + loaded_weight: torch.Tensor): + """ + Handle special case for models where MLP layers are already + fused on disk. In this case, we have no shard id. This function + determines the shard id by splitting these layers and then calls + the weight loader using the shard id. + + An example of a model with these fused layers: + https://huggingface.co/microsoft/Phi-3-mini-4k-instruct + """ + + current_shard_offset = 0 + shard_offsets: List[Tuple[int, int, int]] = [] + for i, output_size in enumerate(self.output_sizes): + shard_offsets.append((i, current_shard_offset, output_size)) + current_shard_offset += output_size + + for shard_id, shard_offset, shard_size in shard_offsets: + # Special case for Quantization. + # If quantized, we need to adjust the offset and size to account + # for the packing.
+ if isinstance(param, PackedvLLMParameter + ) and param.packed_dim == param.output_dim: + shard_size, shard_offset = param.adjust_shard_indexes_for_packing( + shard_size=shard_size, shard_offset=shard_offset) + + loaded_weight_shard = loaded_weight.narrow(param.output_dim, + shard_offset, + shard_size) + self.weight_loader_v2(param, loaded_weight_shard, shard_id) + + def weight_loader_v2(self, + param: BasevLLMParameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[int] = None): + param_data = param.data + if loaded_shard_id is None: + if param.output_dim is None: + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + return + self._load_fused_module_from_checkpoint(param, loaded_weight) + return + + assert loaded_shard_id < len(self.output_sizes) + + tp_size = get_tensor_model_parallel_world_size() + shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size + shard_size = self.output_sizes[loaded_shard_id] // tp_size + + param.load_merged_column_weight(loaded_weight=loaded_weight, + shard_id=loaded_shard_id, + shard_offset=shard_offset, + shard_size=shard_size) + class QKVParallelLinear(ColumnParallelLinear): """Linear layers for the attention's QKV transformation. @@ -598,6 +664,82 @@ def __init__(self, quant_config=quant_config, prefix=prefix) + def _get_shard_offset_mapping(self, loaded_shard_id: str): + shard_offset_mapping = { + "q": 0, + "k": self.num_heads * self.head_size, + "v": (self.num_heads + self.num_kv_heads) * self.head_size, + "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size + } + return shard_offset_mapping.get(loaded_shard_id) + + def _get_shard_size_mapping(self, loaded_shard_id: str): + shard_size_mapping = { + "q": self.num_heads * self.head_size, + "k": self.num_kv_heads * self.head_size, + "v": self.num_kv_heads * self.head_size, + } + return shard_size_mapping.get(loaded_shard_id) + + def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter, + loaded_weight: torch.Tensor): + """ + Handle special case for models where QKV layers are already + fused on disk. In this case, we have no shard id. This function + determines the shard id by splitting these layers and then calls + the weight loader using the shard id. + + An example of a model with these fused layers: + https://huggingface.co/microsoft/Phi-3-mini-4k-instruct + """ + shard_offsets = [ + # (shard_id, shard_offset, shard_size) + ("q", 0, self.total_num_heads * self.head_size), + ("k", self.total_num_heads * self.head_size, + self.total_num_kv_heads * self.head_size), + ("v", + (self.total_num_heads + self.total_num_kv_heads) * self.head_size, + self.total_num_kv_heads * self.head_size), + ] + + for shard_id, shard_offset, shard_size in shard_offsets: + # Special case for Quantization. + # If quantized, we need to adjust the offset and size to account + # for the packing.
+ if isinstance(param, PackedvLLMParameter + ) and param.packed_dim == param.output_dim: + shard_size, shard_offset = param.adjust_shard_indexes_for_packing( + shard_size=shard_size, shard_offset=shard_offset) + + loaded_weight_shard = loaded_weight.narrow(param.output_dim, + shard_offset, + shard_size) + self.weight_loader_v2(param, loaded_weight_shard, shard_id) + + def weight_loader_v2(self, + param: BasevLLMParameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[str] = None): + param_data = param.data + if loaded_shard_id is None: # special case for certain models + if param.output_dim is None: + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + return + self._load_fused_module_from_checkpoint(param, loaded_weight) + return + + assert loaded_shard_id in ["q", "k", "v"] + + shard_offset = self._get_shard_offset_mapping(loaded_shard_id) + shard_size = self._get_shard_size_mapping(loaded_shard_id) + + param.load_qkv_weight(loaded_weight=loaded_weight, + num_heads=self.num_kv_head_replicas, + shard_id=loaded_shard_id, + shard_offset=shard_offset, + shard_size=shard_size) + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor, @@ -798,6 +940,7 @@ def __init__(self, self.tp_size = get_tensor_model_parallel_world_size() self.input_size_per_partition = divide(input_size, self.tp_size) assert self.quant_method is not None + self.quant_method.create_weights( layer=self, input_size_per_partition=self.input_size_per_partition, @@ -805,7 +948,9 @@ def __init__(self, input_size=self.input_size, output_size=self.output_size, params_dtype=self.params_dtype, - weight_loader=self.weight_loader, + weight_loader=( + self.weight_loader_v2 if self.quant_method.__class__.__name__ + in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader), prefix=prefix) if not reduce_results and (bias and not skip_bias_add): raise ValueError("When not reduce the results, adding bias to the " @@ -850,6 +995,10 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) + def weight_loader_v2(self, param: BasevLLMParameter, + loaded_weight: torch.Tensor): + param.load_row_parallel_weight(loaded_weight=loaded_weight) + def forward(self, input_): if self.input_is_parallel: input_parallel = input_ diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 39d00bd5733ff..ae75781927381 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -19,6 +19,8 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.platforms import current_platform +__all__ = ["CompressedTensorsLinearMethod"] + class CompressedTensorsConfig(QuantizationConfig): @@ -146,18 +148,15 @@ def _is_fp8_w8a8(self, weight_quant: BaseModel, if weight_quant is None or input_quant is None: return False - # Confirm we have floating points. - if not (weight_quant.type == QuantizationType.FLOAT - and input_quant.type == QuantizationType.FLOAT): - return False - # Confirm weight scheme is supported.
+ is_floating_point = (weight_quant.type == QuantizationType.FLOAT + and input_quant.type == QuantizationType.FLOAT) is_symmetric_weight = weight_quant.symmetric is_static_weight = not weight_quant.dynamic is_per_tensor_or_channel_weight = (weight_quant.strategy in [ QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL ]) - if not (is_symmetric_weight and is_static_weight + if not (is_floating_point and is_symmetric_weight and is_static_weight and is_per_tensor_or_channel_weight): return False @@ -169,11 +168,7 @@ def _is_fp8_w8a8(self, weight_quant: BaseModel, is_symmetric_activation = input_quant.symmetric is_per_tensor_activation = ( input_quant.strategy == QuantizationStrategy.TENSOR) - if not (is_symmetric_activation and is_per_tensor_activation): - return False - - # All conditions satisfied. - return True + return is_symmetric_activation and is_per_tensor_activation def _is_fp8_w8a16(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool: @@ -230,6 +225,7 @@ def _get_scheme_from_parts( group_size=weight_quant.group_size) # Detect If Activation Quantization. + # TODO @dsikka: clean-up conditions if is_activation_quantization_format(self.quant_format): if self._is_fp8_w8a8(weight_quant, input_quant): is_fp8_w8a8_supported = self._check_scheme_supported( @@ -237,7 +233,8 @@ def _get_scheme_from_parts( if is_fp8_w8a8_supported: return CompressedTensorsW8A8Fp8( strategy=weight_quant.strategy, - is_static_input_scheme=(not input_quant.dynamic)) + is_static_input_scheme=(input_quant + and not input_quant.dynamic)) else: return CompressedTensorsW8A16Fp8( strategy=weight_quant.strategy, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py index b7ba29ddc9840..2e8d520eacc81 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py @@ -2,11 +2,10 @@ import torch import torch.nn.functional as F -from torch.nn import Parameter from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) -from vllm.model_executor.utils import set_weight_attrs +from vllm.model_executor.parameter import ModelWeightParameter __all__ = ["CompressedTensorsUnquantized"] @@ -24,7 +23,9 @@ def get_min_capability(cls) -> int: return 70 def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - pass + # required by torch.compile to be torch.nn.Parameter + layer.weight = torch.nn.Parameter(layer.weight.data, + requires_grad=False) def create_weights(self, layer: torch.nn.Module, output_partition_sizes: List[int], @@ -32,14 +33,15 @@ def create_weights(self, layer: torch.nn.Module, params_dtype: torch.dtype, weight_loader: Callable, **kwargs): - weight = Parameter(torch.empty(sum(output_partition_sizes), - input_size_per_partition, - dtype=params_dtype), - requires_grad=False) + weight = ModelWeightParameter(data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) - set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) layer.register_parameter("weight", weight) - set_weight_attrs(weight, {"weight_loader": weight_loader}) def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]) -> 
torch.Tensor: diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py index c1adfdb2980b6..9ad61a64e406c 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py @@ -8,7 +8,10 @@ CompressedTensorsScheme) from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N) -from vllm.model_executor.utils import set_weight_attrs +from vllm.model_executor.parameter import (BasevLLMParameter, + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedvLLMParameter) from vllm.scalar_type import scalar_types __all__ = ["CompressedTensorsW4A16Sparse24"] @@ -45,7 +48,12 @@ def get_min_capability(cls) -> int: return 80 def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - pass + # required by torch.compile to be torch.nn.Parameter + layer.weight_packed = Parameter(layer.weight_packed.data, + requires_grad=False) + layer.scale_packed = Parameter(layer.scale_packed.data, + requires_grad=False) + layer.meta = Parameter(layer.meta.data, requires_grad=False) def create_weights(self, layer: torch.nn.Module, input_size: int, output_partition_sizes: List[int], @@ -56,79 +64,65 @@ def create_weights(self, layer: torch.nn.Module, input_size: int, pack_factor = 32 // self.quant_type.size_bits output_size_per_partition = sum(output_partition_sizes) - qweight = Parameter( - torch.empty( - input_size_per_partition // self.tile_size // 2, - output_size_per_partition * self.tile_size // pack_factor, - dtype=torch.int32, - ), - requires_grad=False, - ) - set_weight_attrs( - qweight, - { - "input_dim": 0, - "output_dim": 1, - "packed_dim": 1, - "pack_factor": pack_factor, - "marlin_tile_size": self.tile_size, - "weight_loader": weight_loader - }, - ) - - layer.register_parameter("weight_packed", qweight) + qweight = PackedvLLMParameter(data=torch.empty( + input_size_per_partition // self.tile_size // 2, + output_size_per_partition * self.tile_size // pack_factor, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=pack_factor, + marlin_tile_size=self.tile_size, + weight_loader=weight_loader) input_groups = (1 if self.group_size is None else input_size_per_partition // self.group_size) - scales = Parameter( + weight_scale_args = { + "data": torch.empty( input_groups, output_size_per_partition, dtype=params_dtype, ), - requires_grad=False, - ) - set_weight_attrs( - scales, - { - "output_dim": 1, - "input_dim": None if input_groups == 1 else 0, - "weight_loader": weight_loader - }, - ) - layer.register_parameter("scale_packed", scales) - - weight_shape = Parameter(torch.empty(2, dtype=torch.int64), - requires_grad=False) + "weight_loader": + weight_loader + } + + if self.group_size is not None: + scales = GroupQuantScaleParameter(output_dim=1, + input_dim=0, + **weight_scale_args) + else: + scales = ChannelQuantScaleParameter(output_dim=1, + **weight_scale_args) + + weight_shape = BasevLLMParameter(data=torch.empty(2, + dtype=torch.int64), + weight_loader=weight_loader) + + meta = PackedvLLMParameter(data=torch.empty( + input_size_per_partition // 8 // 2 // 2, + output_size_per_partition * 2, + dtype=torch.int16, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=1, + 
marlin_tile_size=2, + weight_loader=weight_loader) + layer.register_parameter("weight_packed", qweight) layer.register_parameter("weight_shape", weight_shape) - set_weight_attrs(weight_shape, {"weight_loader": weight_loader}) - - meta = Parameter( - torch.empty( - input_size_per_partition // 8 // 2 // 2, - output_size_per_partition * 2, - dtype=torch.int16, - ), - requires_grad=False, - ) - set_weight_attrs( - meta, - { - "input_dim": 0, - "packed_dim": 1, - "pack_factor": 1, - "output_dim": 1, - "marlin_tile_size": 2, - "weight_loader": weight_loader - }, - ) + layer.register_parameter("scale_packed", scales) layer.register_parameter("meta", meta) max_workspace_size = ( output_size_per_partition // GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL + workspace = Parameter(torch.zeros(max_workspace_size, dtype=torch.int), requires_grad=False) layer.workspace = workspace diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py index eeb7c042e1d1f..3d55d55cc390d 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py @@ -9,9 +9,10 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - convert_to_channelwise, create_per_channel_scale_param, - create_per_tensor_scale_param) -from vllm.model_executor.utils import set_weight_attrs + convert_to_channelwise) +from vllm.model_executor.parameter import (ChannelQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter) __all__ = ["CompressedTensorsW8A16Fp8"] @@ -40,11 +41,19 @@ def process_weights_after_loading(self, layer) -> None: layer.logical_widths) layer.weight_scale = torch.nn.Parameter(ws_channelwise, requires_grad=False) + else: + # required by torch.compile to be torch.nn.Parameter + layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data, + requires_grad=False) # Weights must be transposed for marlin layer.weight = torch.nn.Parameter(layer.weight.t(), requires_grad=False) + if self.is_static_input_scheme: + # required by torch.compile to be torch.nn.Parameter + layer.input_scale = torch.nn.Parameter(layer.input_scale.data, + requires_grad=False) prepare_fp8_layer_for_marlin(layer, strategy="channel") def create_weights(self, layer: torch.nn.Module, input_size: int, @@ -60,35 +69,39 @@ def create_weights(self, layer: torch.nn.Module, input_size: int, layer.orig_dtype = params_dtype # WEIGHT - weight = torch.nn.Parameter(torch.empty(output_size_per_partition, - input_size_per_partition, - dtype=torch.float8_e4m3fn), - requires_grad=False) + weight = ModelWeightParameter(data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=torch.float8_e4m3fn), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) layer.register_parameter("weight", weight) - set_weight_attrs(weight, { - "input_dim": 1, - "output_dim": 0, - "weight_loader": weight_loader, - }) # WEIGHT SCALE - layer_kwargs = {"weight_loader": weight_loader} if self.strategy == QuantizationStrategy.CHANNEL: - weight_scale = create_per_channel_scale_param( - output_partition_sizes, **layer_kwargs) + weight_scale = ChannelQuantScaleParameter( + 
data=torch.empty((sum(output_partition_sizes), 1), + dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader) elif self.strategy == QuantizationStrategy.TENSOR: - weight_scale = create_per_tensor_scale_param( - output_partition_sizes, **layer_kwargs) + weight_scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) else: raise ValueError( f"Unsupported weight strategy={self.strategy}, " f"supported strategies are {SUPPORTED_STRATEGIES}") + + weight_scale[:] = torch.finfo(torch.float32).min layer.register_parameter("weight_scale", weight_scale) # INPUT SCALE (to deal with converted checkpoints) if self.is_static_input_scheme: - input_scale = create_per_tensor_scale_param( - output_partition_sizes, **layer_kwargs) + input_scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) layer.register_parameter("input_scale", input_scale) def apply_weights(self, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index cc9d71db140c2..8a3d24e2fd258 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -8,10 +8,10 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( QuantizationStrategy) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - apply_fp8_linear, create_per_channel_scale_param, - create_per_tensor_scale_param, cutlass_fp8_supported, - requantize_with_max_scale) -from vllm.model_executor.utils import set_weight_attrs + apply_fp8_linear, cutlass_fp8_supported, requantize_with_max_scale) +from vllm.model_executor.parameter import (ChannelQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter) __all__ = ["CompressedTensorsW8A8Fp8"] @@ -46,6 +46,9 @@ def process_weights_after_loading(self, layer) -> None: elif self.strategy == QuantizationStrategy.CHANNEL: weight = layer.weight layer.weight = Parameter(weight.t(), requires_grad=False) + # required by torch.compile to be torch.nn.Parameter + layer.weight_scale = Parameter(layer.weight_scale.data, + requires_grad=False) else: raise ValueError(f"Unknown quantization strategy {self.strategy}") @@ -66,32 +69,40 @@ def create_weights(self, layer: torch.nn.Module, layer.logical_widths = output_partition_sizes # WEIGHT - weight = torch.nn.Parameter(torch.empty(output_size_per_partition, - input_size_per_partition, - dtype=torch.float8_e4m3fn), - requires_grad=False) + weight = ModelWeightParameter(data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=torch.float8_e4m3fn), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) layer.register_parameter("weight", weight) - set_weight_attrs(weight, { - "input_dim": 1, - "output_dim": 0, - "weight_loader": weight_loader, - }) # WEIGHT SCALE - layer_kwargs = {"weight_loader": weight_loader} + # TODO: update create_xxx_parameter functions to return + # the newly added parameters if self.strategy == QuantizationStrategy.CHANNEL: - weight_scale = create_per_channel_scale_param( - output_partition_sizes, **layer_kwargs) + weight_scale = ChannelQuantScaleParameter( + data=torch.empty((sum(output_partition_sizes), 1), + 
dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader) else: assert self.strategy == QuantizationStrategy.TENSOR - weight_scale = create_per_tensor_scale_param( - output_partition_sizes, **layer_kwargs) + weight_scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + + # min requirement for fp8 kernels + weight_scale[:] = torch.finfo(torch.float32).min layer.register_parameter("weight_scale", weight_scale) # INPUT SCALE if self.is_static_input_scheme: - input_scale = create_per_tensor_scale_param( - output_partition_sizes, **layer_kwargs) + input_scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + input_scale[:] = torch.finfo(torch.float32).min layer.register_parameter("input_scale", input_scale) def apply_weights(self, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py index 3a80863d3abbe..078380f159291 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py @@ -8,9 +8,11 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( QuantizationStrategy) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - apply_int8_linear, convert_to_channelwise, create_per_channel_scale_param, - create_per_tensor_scale_param) -from vllm.model_executor.utils import set_weight_attrs + apply_int8_linear, convert_to_channelwise) +from vllm.model_executor.parameter import (BasevLLMParameter, + ChannelQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter) class CompressedTensorsW8A8Int8(CompressedTensorsScheme): @@ -39,7 +41,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: ws_channelwise = convert_to_channelwise(layer.weight_scale, self.logical_widths) layer.weight_scale = Parameter(ws_channelwise, requires_grad=False) - + else: + layer.weight_scale = Parameter(layer.weight_scale.data, + requires_grad=False) # INPUT SCALE if self.is_static_input_scheme: layer.input_scale = Parameter(layer.input_scale.max(), @@ -55,32 +59,35 @@ def create_weights(self, layer: torch.nn.Module, self.logical_widths = output_partition_sizes # WEIGHT - weight = Parameter(torch.empty(sum(output_partition_sizes), - input_size_per_partition, - dtype=torch.int8), - requires_grad=False) + weight = ModelWeightParameter(data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=torch.int8), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + layer.register_parameter("weight", weight) - set_weight_attrs(weight, { - "input_dim": 1, - "output_dim": 0, - "weight_loader": weight_loader, - }) # WEIGHT SCALE - layer_kwargs = {"weight_loader": weight_loader} if self.strategy == QuantizationStrategy.CHANNEL: - weight_scale = create_per_channel_scale_param( - output_partition_sizes, **layer_kwargs) + weight_scale = ChannelQuantScaleParameter( + data=torch.empty((sum(output_partition_sizes), 1), + dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader) else: assert self.strategy == QuantizationStrategy.TENSOR - weight_scale = create_per_tensor_scale_param( - output_partition_sizes, **layer_kwargs) + weight_scale = 
PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) layer.register_parameter("weight_scale", weight_scale) # INPUT SCALE if self.is_static_input_scheme: - input_scale = create_per_tensor_scale_param( - output_partition_sizes, **layer_kwargs) + input_scale = BasevLLMParameter(data=torch.empty( + 1, dtype=torch.float32), + weight_loader=weight_loader) layer.register_parameter("input_scale", input_scale) def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py index b8880f7ac136f..94699c27d5cee 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py @@ -1,7 +1,6 @@ from typing import Callable, List, Optional import torch -from torch.nn import Parameter from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( @@ -10,7 +9,10 @@ apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, replace_tensor, verify_marlin_supported, verify_marlin_supports_shape) -from vllm.model_executor.utils import set_weight_attrs +from vllm.model_executor.parameter import (BasevLLMParameter, + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedvLLMParameter) from vllm.scalar_type import scalar_types __all__ = ["CompressedTensorsWNA16"] @@ -30,17 +32,12 @@ def __init__(self, self.pack_factor = 32 // num_bits self.strategy = strategy + self.group_size = -1 if group_size is None else group_size - self.group_size: int - if group_size is None: - if self.strategy != "channel": - raise ValueError( - "Marlin kernels require group quantization or " - "channelwise quantization, but found no group " - "size and strategy is not channelwise.") - self.group_size = -1 - else: - self.group_size = group_size + if self.group_size == -1 and self.strategy != "channel": + raise ValueError("Marlin kernels require group quantization or " + "channelwise quantization, but found no group " + "size and strategy is not channelwise.") if num_bits not in WNA16_SUPPORTED_TYPES_MAP: raise ValueError( @@ -63,11 +60,12 @@ def create_weights(self, layer: torch.nn.Module, input_size: int, input_size_per_partition: int, params_dtype: torch.dtype, weight_loader: Callable, **kwargs): + output_size_per_partition = sum(output_partition_sizes) # If group_size is -1, we are in channelwise case. channelwise = (self.group_size == -1) - group_size = input_size if channelwise else self.group_size + group_size = self.group_size if self.group_size != -1 else input_size row_parallel = (input_size != input_size_per_partition) # In the case of channelwise quantization, we need to replicate the # scales across all gpus. 
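A minimal sketch of the scale-shape convention used by `CompressedTensorsWNA16` above, assuming (as the surrounding hunks suggest) that scales are only partitioned for row-parallel layers with grouped quantization; the helper name and the example sizes are illustrative and not part of the vLLM API:

```python
# Illustrative only: mirrors the group_size handling in
# CompressedTensorsWNA16.create_weights; not part of the patch.
def num_scale_groups(input_size: int, input_size_per_partition: int,
                     group_size: int) -> int:
    # group_size == -1 denotes channelwise quantization: a single scale
    # group spanning the full (replicated) input dimension.
    effective_group = input_size if group_size == -1 else group_size
    row_parallel = input_size != input_size_per_partition
    partition_scales = row_parallel and group_size != -1
    if partition_scales:
        assert input_size_per_partition % effective_group == 0
        return input_size_per_partition // effective_group
    return input_size // effective_group


# Grouped (group_size=128) with tensor parallel size 2 on a 4096-wide input:
assert num_scale_groups(4096, 2048, 128) == 16
# Channelwise: scales are replicated across ranks as a single "group":
assert num_scale_groups(4096, 2048, -1) == 1
```

This is why the channelwise case below maps to `ChannelQuantScaleParameter` (no input dim) while the grouped case maps to `GroupQuantScaleParameter` with `input_dim=1`.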
@@ -79,60 +77,51 @@ def create_weights(self, layer: torch.nn.Module, input_size: int, input_size=input_size, group_size=group_size) - weight_scale_dim = None scales_and_zp_size = input_size // group_size if partition_scales: assert input_size_per_partition % group_size == 0 - weight_scale_dim = 1 scales_and_zp_size = input_size_per_partition // group_size - weight = Parameter( - torch.empty( - output_size_per_partition, - input_size_per_partition // self.pack_factor, - dtype=torch.int32, - ), - requires_grad=False, - ) - - set_weight_attrs( - weight, { - "input_dim": 1, - "output_dim": 0, - "packed_dim": 1, - "pack_factor": self.pack_factor, - "weight_loader": weight_loader - }) - layer.register_parameter("weight_packed", weight) - - weight_scale = Parameter( + weight = PackedvLLMParameter(input_dim=1, + output_dim=0, + weight_loader=weight_loader, + packed_factor=self.pack_factor, + packed_dim=1, + data=torch.empty( + output_size_per_partition, + input_size_per_partition // + self.pack_factor, + dtype=torch.int32, + )) + + weight_scale_args = { + "weight_loader": + weight_loader, + "data": torch.empty( output_size_per_partition, scales_and_zp_size, dtype=params_dtype, - ), - requires_grad=False, - ) - - set_weight_attrs( - weight_scale, { - "weight_loader": weight_loader, - "input_dim": weight_scale_dim, - "output_dim": 0 - }) - layer.register_parameter("weight_scale", weight_scale) + ) + } + if self.group_size == -1: + weight_scale = ChannelQuantScaleParameter(output_dim=0, + **weight_scale_args) + else: + weight_scale = GroupQuantScaleParameter(output_dim=0, + input_dim=1, + **weight_scale_args) # A 2D array defining the original shape of the weights # before packing - weight_shape = Parameter(torch.empty(2, dtype=torch.int64), - requires_grad=False) + weight_shape = BasevLLMParameter(data=torch.empty(2, + dtype=torch.int64), + weight_loader=weight_loader) + layer.register_parameter("weight_packed", weight) + layer.register_parameter("weight_scale", weight_scale) layer.register_parameter("weight_shape", weight_shape) - set_weight_attrs(weight_shape, { - "weight_loader": weight_loader, - "ignore_warning": True, - }) layer.input_size_per_partition = input_size_per_partition layer.output_size_per_partition = output_size_per_partition @@ -154,10 +143,15 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # No zero-point layer.weight_zp = marlin_make_empty_g_idx(device) + # Update for kernel + layer.weight_packed = torch.nn.Parameter( + layer.weight_packed.t().contiguous(), requires_grad=False) + layer.weight_scale = torch.nn.Parameter( + layer.weight_scale.squeeze().t().contiguous(), requires_grad=False) # Repack weights from compressed-tensors format to marlin format. marlin_qweight = ops.gptq_marlin_repack( - layer.weight_packed.t().contiguous(), + layer.weight_packed, perm=layer.g_idx_sort_indices, size_k=layer.input_size_per_partition, size_n=layer.output_size_per_partition, @@ -166,7 +160,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Permute scales from compressed-tensors format to marlin format. 
marlin_scales = marlin_permute_scales( - layer.weight_scale.squeeze().t().contiguous(), + layer.weight_scale, size_k=layer.input_size_per_partition, size_n=layer.output_size_per_partition, group_size=layer.group_size) diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py new file mode 100644 index 0000000000000..10239843b3222 --- /dev/null +++ b/vllm/model_executor/parameter.py @@ -0,0 +1,277 @@ +from typing import Callable, Optional, Union + +import torch +from torch.nn import Parameter + +from vllm.distributed import get_tensor_model_parallel_rank +from vllm.logger import init_logger + +__all__ = [ + "BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter", + "ModelWeightParameter", "ChannelQuantScaleParameter", + "GroupQuantScaleParameter" +] + +logger = init_logger(__name__) + + +class BasevLLMParameter(Parameter): + """ + Base parameter for vLLM linear layers. Extends the torch.nn.parameter + by taking in a linear weight loader. Will copy the loaded weight + into the parameter when the provided weight loader is called. + """ + + def __new__(cls, data: torch.Tensor, **kwargs): + + return super().__new__(cls, data=data, requires_grad=False) + + def __init__(self, data: torch.Tensor, weight_loader: Callable): + """ + Initialize the BasevLLMParameter + + :param data: torch tensor with the parameter data + :param weight_loader: weight loader callable + + :returns: a torch.nn.parameter + """ + + self._weight_loader = weight_loader + + @property + def weight_loader(self): + return self._weight_loader + + def _assert_and_load(self, loaded_weight: torch.Tensor): + assert self.data.shape == loaded_weight.shape + self.data.copy_(loaded_weight) + + def load_column_parallel_weight(self, loaded_weight: torch.Tensor): + self._assert_and_load(loaded_weight) + + def load_row_parallel_weight(self, loaded_weight: torch.Tensor): + self._assert_and_load(loaded_weight) + + def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs): + self._assert_and_load(loaded_weight) + + def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs): + self._assert_and_load(loaded_weight) + + +class _ColumnvLLMParameter(BasevLLMParameter): + """ + Private class defining weight loading functionality + (load_merged_column_weight, load_qkv_weight) + for parameters being loaded into linear layers with column + parallelism. This includes QKV and MLP layers which are + not already fused on disk. Requires an output dimension + to be defined. Called within the weight loader of + each of the column parallel linear layers. 
+ """ + + def __init__(self, output_dim: int, **kwargs): + self._output_dim = output_dim + super().__init__(**kwargs) + + @property + def output_dim(self): + return self._output_dim + + def load_column_parallel_weight(self, loaded_weight: torch.Tensor): + tp_rank = get_tensor_model_parallel_rank() + shard_size = self.data.shape[self.output_dim] + loaded_weight = loaded_weight.narrow(self.output_dim, + tp_rank * shard_size, shard_size) + assert self.data.shape == loaded_weight.shape + self.data.copy_(loaded_weight) + + def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs): + + shard_offset = kwargs.get("shard_offset") + shard_size = kwargs.get("shard_size") + if isinstance( + self, + PackedvLLMParameter) and self.packed_dim == self.output_dim: + shard_size, shard_offset = self.adjust_shard_indexes_for_packing( + shard_offset=shard_offset, shard_size=shard_size) + + param_data = self.data + + tp_rank = get_tensor_model_parallel_rank() + param_data = param_data.narrow(self.output_dim, shard_offset, + shard_size) + loaded_weight = loaded_weight.narrow(self.output_dim, + tp_rank * shard_size, shard_size) + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs): + + shard_offset = kwargs.get("shard_offset") + shard_size = kwargs.get("shard_size") + shard_id = kwargs.get("shard_id") + num_heads = kwargs.get("num_heads") + + if isinstance( + self, + PackedvLLMParameter) and self.output_dim == self.packed_dim: + shard_size, shard_offset = self.adjust_shard_indexes_for_packing( + shard_offset=shard_offset, shard_size=shard_size) + + param_data = self.data + tp_rank = get_tensor_model_parallel_rank() + shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads + param_data = param_data.narrow(self.output_dim, shard_offset, + shard_size) + loaded_weight = loaded_weight.narrow(self.output_dim, + shard_id * shard_size, shard_size) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + +class ModelWeightParameter(_ColumnvLLMParameter): + """ + Parameter class for linear layer weights. Extends the + _ColumnvLLMParameter by adding loading functionality + for linear layers with row parallel functionality. + Requires an input dimension to be defined. + """ + + def __init__(self, input_dim: int, **kwargs): + self._input_dim = input_dim + super().__init__(**kwargs) + + @property + def input_dim(self): + return self._input_dim + + def load_row_parallel_weight(self, loaded_weight: torch.Tensor): + tp_rank = get_tensor_model_parallel_rank() + shard_size = self.data.shape[self.input_dim] + loaded_weight = loaded_weight.narrow(self.input_dim, + tp_rank * shard_size, shard_size) + + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + + assert self.data.shape == loaded_weight.shape + self.data.copy_(loaded_weight) + + +class GroupQuantScaleParameter(ModelWeightParameter): + """ + Parameter class for weight scales loaded for weights with + grouped quantization. Equivalent to ModelWeightParameter. + """ + pass + + +class ChannelQuantScaleParameter(_ColumnvLLMParameter): + """ + Parameter class for weight scales loaded for weights with + channel-wise quantization. Equivalent to _ColumnvLLMParameter. + """ + pass + + +class PerTensorScaleParameter(BasevLLMParameter): + """ + Parameter class for scales where the number of scales is + equivalent to the number of logical matrices in fused linear + layers (e.g. 
for QKV, there are 3 scales loaded from disk). + This is relevant to weights with per-tensor quantization. + Adds functionality to map the scales to a shard during + weight loading. + + Note: additional parameter manipulation may be handled + for each quantization config specifically, within + process_weights_after_loading + """ + + def __init__(self, **kwargs): + self.qkv_idxs = {"q": 0, "k": 1, "v": 2} + super().__init__(**kwargs) + + def _shard_id_as_int(self, shard_id: Union[str, int]) -> int: + if isinstance(shard_id, int): + return shard_id + + assert isinstance(shard_id, str) + assert shard_id in self.qkv_idxs + return self.qkv_idxs[shard_id] + + def load_merged_column_weight(self, *args, **kwargs): + self._load_into_shard_id(*args, **kwargs) + + def load_qkv_weight(self, *args, **kwargs): + self._load_into_shard_id(*args, **kwargs) + + def load_column_parallel_weight(self, *args, **kwargs): + self._load_into_shard_id(*args, **kwargs) + + def _load_into_shard_id(self, loaded_weight: torch.Tensor, + shard_id: Union[str, int], **kwargs): + """ + Slice the parameter data based on the shard id for + loading. + """ + + param_data = self.data + shard_id = self._shard_id_as_int(shard_id) + + # AutoFP8 scales do not have a shape + # compressed-tensors scales do have a shape + if len(loaded_weight.shape) != 0: + assert loaded_weight.shape[0] == 1 + loaded_weight = loaded_weight[0] + + param_data = param_data[shard_id] + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + +class PackedvLLMParameter(ModelWeightParameter): + """ + Parameter for model weights which are packed on disk. + Example: GPTQ Marlin weights are int4 or int8, packed into int32. + Extends the ModelWeightParameter to take in the + packed factor, the packed dimension, and optionally, marlin + tile size for marlin kernels. Adjusts the shard_size and + shard_offset for fused linear layers model weight loading + by accounting for packing and optionally, marlin tile size. + """ + + def __init__(self, + packed_factor: int, + packed_dim: int, + marlin_tile_size: Optional[int] = None, + **kwargs): + self._packed_factor = packed_factor + self._packed_dim = packed_dim + self._marlin_tile = marlin_tile_size + super().__init__(**kwargs) + + @property + def packed_dim(self): + return self._packed_dim + + @property + def packed_factor(self): + return self._packed_factor + + @property + def marlin_tile(self): + return self._marlin_tile + + def _adjust_shard_indexes_for_marlin(self, shard_size, shard_offset): + return shard_size * self.marlin_tile, shard_offset * self.marlin_tile + + def adjust_shard_indexes_for_packing(self, shard_size, shard_offset): + shard_size = shard_size // self.packed_factor + shard_offset = shard_offset // self.packed_factor + if self.marlin_tile is not None: + return self._adjust_shard_indexes_for_marlin( + shard_size, shard_offset) + return shard_size, shard_offset
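As a reviewer aid (not part of the patch), here is a minimal sketch of the shard-index arithmetic that `PackedvLLMParameter.adjust_shard_indexes_for_packing` performs; the standalone helper name and the example numbers (pack factor 8 for int4-in-int32, an assumed marlin tile of 16) are illustrative only:

```python
from typing import Optional, Tuple


# Illustrative only: mirrors adjust_shard_indexes_for_packing above.
def adjust_for_packing(shard_size: int, shard_offset: int, packed_factor: int,
                       marlin_tile_size: Optional[int] = None
                       ) -> Tuple[int, int]:
    # Shard bounds arrive in unpacked-element units; convert them into
    # units of the packed storage dtype (e.g. 8 int4 values per int32).
    shard_size //= packed_factor
    shard_offset //= packed_factor
    if marlin_tile_size is not None:
        # Marlin layouts tile the packed dimension, so scale back up by
        # the tile size (the role of _adjust_shard_indexes_for_marlin).
        return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
    return shard_size, shard_offset


# A fused shard covering 4096 unpacked elements, starting at offset 4096:
assert adjust_for_packing(4096, 4096, packed_factor=8) == (512, 512)
assert adjust_for_packing(4096, 4096, 8, marlin_tile_size=16) == (8192, 8192)
```

Keeping the offsets in the packed tensor's own coordinates is what lets the fused-checkpoint loaders above `narrow()` the packed weight directly before delegating to `weight_loader_v2`.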