From fbb425cfcf171a0d9e3cbff76a2d542895970a68 Mon Sep 17 00:00:00 2001 From: cenzhiyao <2523403608@qq.com> Date: Mon, 30 Dec 2024 17:36:51 +0800 Subject: [PATCH] bash format.sh --- vllm/model_executor/layers/linear.py | 39 ++++---- .../layers/quantization/gptq.py | 1 - .../layers/quantization/gptq_marlin_24.py | 1 - .../layers/quantization/hqq_marlin.py | 81 ++++++++++------ .../layers/quantization/kernels/exllama.py | 6 +- .../layers/quantization/kernels/machete.py | 3 +- .../layers/quantization/kernels/marlin.py | 6 +- .../layers/vocab_parallel_embedding.py | 4 +- vllm/model_executor/parameter.py | 96 +++++++++++-------- 9 files changed, 138 insertions(+), 99 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 8313c12b9b85f..ed6877e6fb968 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -15,12 +15,7 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) # yapf: disable -from vllm.model_executor.parameter import (BasevLLMParameter, - BlockQuantScaleParameter, - PackedColumnParameter, - PackedvLLMParameter, - PerTensorScaleParameter, - RowvLLMParameter, has_any_param_feature, Features) +from vllm.model_executor.parameter import Features, has_any_param_feature # yapf: enable from vllm.model_executor.utils import set_weight_attrs @@ -574,7 +569,7 @@ def weight_loader(self, assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) - def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter, + def _load_fused_module_from_checkpoint(self, param: Parameter, loaded_weight: torch.Tensor): """ Handle special case for models where MLP layers are already @@ -596,8 +591,11 @@ def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter, # Special case for Quantization. # If quantized, we need to adjust the offset and size to account # for the packing. - if isinstance(param, (PackedColumnParameter, PackedvLLMParameter - )) and param.packed_dim == param.output_dim: + # if isinstance(param, (PackedColumnParameter, PackedvLLMParameter + # )) and param.packed_dim == param.output_dim: + if has_any_param_feature(param, + [Features.PackedColumn, Features.Packed]) \ + and param.packed_dim == param.output_dim: shard_size, shard_offset = \ param.adjust_shard_indexes_for_packing( shard_size=shard_size, shard_offset=shard_offset) @@ -608,15 +606,15 @@ def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter, self.weight_loader_v2(param, loaded_weight_shard, shard_id) def weight_loader_v2(self, - param: BasevLLMParameter, + param: Parameter, loaded_weight: torch.Tensor, loaded_shard_id: Optional[int] = None): if loaded_shard_id is None: - if isinstance(param, PerTensorScaleParameter): + if has_any_param_feature(param, Features.PerTensorScale): param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0) return - elif type(param) in (RowvLLMParameter, BasevLLMParameter): + elif has_any_param_feature(param, [Features.Row, Features.Base]): param.load_merged_column_weight(loaded_weight=loaded_weight) return # TODO: @dsikka - move to parameter.py @@ -738,7 +736,7 @@ def _get_shard_size_mapping(self, loaded_shard_id: str): } return shard_size_mapping.get(loaded_shard_id) - def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter, + def _load_fused_module_from_checkpoint(self, param: Parameter, loaded_weight: torch.Tensor): """ Handle special case for models where QKV layers are already @@ -763,8 +761,9 @@ def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter, # Special case for Quantization. # If quantized, we need to adjust the offset and size to account # for the packing. - if isinstance(param, (PackedColumnParameter, PackedvLLMParameter - )) and param.packed_dim == param.output_dim: + if has_any_param_feature(param, + [Features.PackedColumn, Features.Packed]) \ + and param.packed_dim == param.output_dim: shard_size, shard_offset = \ param.adjust_shard_indexes_for_packing( shard_size=shard_size, shard_offset=shard_offset) @@ -775,14 +774,14 @@ def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter, self.weight_loader_v2(param, loaded_weight_shard, shard_id) def weight_loader_v2(self, - param: BasevLLMParameter, + param: Parameter, loaded_weight: torch.Tensor, loaded_shard_id: Optional[str] = None): if loaded_shard_id is None: # special case for certain models - if isinstance(param, PerTensorScaleParameter): + if has_any_param_feature(param, Features.PerTensorScale): param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0) return - elif type(param) in (RowvLLMParameter, BasevLLMParameter): + elif has_any_param_feature(param, [Features.Row, Features.Base]): param.load_qkv_weight(loaded_weight=loaded_weight) return # TODO: @dsikka - move to parameter.py @@ -1087,9 +1086,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) - def weight_loader_v2(self, param: BasevLLMParameter, - loaded_weight: torch.Tensor): - + def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor): # Special case for loading scales off disk, which often do not # have a shape (such as in the case of AutoFP8). if len(loaded_weight.shape) == 0: diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index 05e9edd59246c..24b61b79b0bbd 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -4,7 +4,6 @@ from typing import Any, Dict, List, Optional import torch -from torch.nn.parameter import Parameter from vllm import _custom_ops as ops from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase diff --git a/vllm/model_executor/layers/quantization/gptq_marlin_24.py b/vllm/model_executor/layers/quantization/gptq_marlin_24.py index 8db7e175650a9..73c862d076015 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin_24.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin_24.py @@ -1,7 +1,6 @@ from typing import Any, Dict, List, Optional import torch -from torch.nn.parameter import Parameter from vllm import _custom_ops as ops from vllm.logger import init_logger diff --git a/vllm/model_executor/layers/quantization/hqq_marlin.py b/vllm/model_executor/layers/quantization/hqq_marlin.py index ca187296b7084..fcc5f0ee3c79b 100644 --- a/vllm/model_executor/layers/quantization/hqq_marlin.py +++ b/vllm/model_executor/layers/quantization/hqq_marlin.py @@ -15,13 +15,12 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( MarlinWorkspace) from vllm.model_executor.layers.quantization.utils.quant_utils import gptq_pack -from vllm.model_executor.parameter import (wrap_base_vllm_parameter, +from vllm.model_executor.parameter import (Features, add_param_feature, + wrap_base_vllm_parameter, wrap_column_vllm_parameter, - wrap_row_vllm_parameter, - add_param_feature, - Features, wrap_packed_vllm_parameter, has_any_param_feature, - wrap_group_quant_scale_parameter) - + wrap_group_quant_scale_parameter, + wrap_packed_vllm_parameter, + wrap_row_vllm_parameter) from vllm.scalar_type import scalar_types logger = init_logger(__name__) @@ -119,8 +118,10 @@ def HQQweightParameter(data: torch.Tensor, **kwargs) -> Parameter: return param -def wrap_hqq_weight_parameter(param: Parameter, packed_factor: int, packed_dim: int, weight_bits: int, +def wrap_hqq_weight_parameter(param: Parameter, packed_factor: int, + packed_dim: int, weight_bits: int, **kwargs) -> None: + def unpack_4bit_u8(param: Parameter, W_q: torch.Tensor) -> torch.Tensor: assert param.weight_bits == 4, "Unsupported quant bitsize (must be 4)" @@ -133,7 +134,9 @@ def unpack_4bit_u8(param: Parameter, W_q: torch.Tensor) -> torch.Tensor: tmp[step:] = W_q & 0b00001111 return tmp - def load_merged_column_weight(param: Parameter, loaded_weight: torch.Tensor, **kwargs) -> None: + def load_merged_column_weight(param: Parameter, + loaded_weight: torch.Tensor, + **kwargs) -> None: loaded_weight = param.unpack_4bit_u8(loaded_weight) loaded_weight = loaded_weight.reshape(-1, param.input_shape).transpose( 1, 0) @@ -143,7 +146,8 @@ def load_merged_column_weight(param: Parameter, loaded_weight: torch.Tensor, **k # load_merged_column_weight from wrap_column_vllm_parameter param.load_merged_column_weight_(loaded_weight, **kwargs) - def load_row_parallel_weight(param: Parameter, loaded_weight: torch.Tensor) -> None: + def load_row_parallel_weight(param: Parameter, + loaded_weight: torch.Tensor) -> None: loaded_weight = param.unpack_4bit_u8(loaded_weight) loaded_weight = loaded_weight.reshape(param.output_shape, -1).transpose(1, 0) @@ -153,10 +157,12 @@ def load_row_parallel_weight(param: Parameter, loaded_weight: torch.Tensor) -> N # load_row_parallel_weight from wrap_row_vllm_parameter param.load_row_parallel_weight_(loaded_weight) - def load_qkv_weight(param: Parameter, loaded_weight: torch.Tensor, **kwargs) -> None: + def load_qkv_weight(param: Parameter, loaded_weight: torch.Tensor, + **kwargs) -> None: loaded_weight = param.unpack_4bit_u8(loaded_weight) - loaded_weight = (loaded_weight.reshape(-1, param.input_shape) - .transpose(1, 0)) + loaded_weight = (loaded_weight.reshape(-1, + param.input_shape).transpose( + 1, 0)) loaded_weight = gptq_pack(loaded_weight, param.weight_bits, loaded_weight.shape[0], loaded_weight.shape[1]) @@ -167,13 +173,20 @@ def load_qkv_weight(param: Parameter, loaded_weight: torch.Tensor, **kwargs) -> param.input_shape = param.shape[param.input_dim] * packed_factor param.output_shape = param.shape[param.output_dim] param.unpack_4bit_u8 = lambda W_q: unpack_4bit_u8(param, W_q) - param.load_merged_column_weight_ = param.load_merged_column_weight # save the original method - param.load_merged_column_weight = lambda loaded_weight, **kwargs: load_merged_column_weight(param, loaded_weight, - **kwargs) - param.load_row_parallel_weight_ = param.load_row_parallel_weight # save the original method - param.load_row_parallel_weight = lambda loaded_weight: load_row_parallel_weight(param, loaded_weight) + # save the original method + param.load_merged_column_weight_ = param.load_merged_column_weight + param.load_merged_column_weight = \ + lambda loaded_weight, **kwargs: \ + load_merged_column_weight(param, loaded_weight, **kwargs) + # save the original method + param.load_row_parallel_weight_ = param.load_row_parallel_weight + param.load_row_parallel_weight = \ + lambda loaded_weight: \ + load_row_parallel_weight(param, loaded_weight) param.load_qkv_weight_ = param.load_qkv_weight # save the original method - param.load_qkv_weight = lambda loaded_weight, **kwargs: load_qkv_weight(param, loaded_weight, **kwargs) + param.load_qkv_weight = \ + lambda loaded_weight, **kwargs: \ + load_qkv_weight(param, loaded_weight, **kwargs) # Zero points and scales in HQQ must also be reshaped to correspond to W_q's @@ -189,25 +202,37 @@ def HQQZeroScaleParameter(data: torch.Tensor, **kwargs) -> Parameter: def wrap_hqq_zero_scale_parameter(param: Parameter, **kwargs) -> None: - def load_merged_column_weight(param: Parameter, loaded_weight: torch.Tensor, **kwargs) -> None: + + def load_merged_column_weight(param: Parameter, + loaded_weight: torch.Tensor, + **kwargs) -> None: loaded_weight = loaded_weight.reshape(-1, param.shape[1]) param.load_merged_column_weight_(loaded_weight, **kwargs) - def load_row_parallel_weight(param: Parameter, loaded_weight: torch.Tensor) -> None: + def load_row_parallel_weight(param: Parameter, + loaded_weight: torch.Tensor) -> None: loaded_weight = loaded_weight.reshape(param.shape[0], -1) param.load_row_parallel_weight_(loaded_weight) - def load_qkv_weight(param: Parameter, loaded_weight: torch.Tensor, **kwargs) -> None: + def load_qkv_weight(param: Parameter, loaded_weight: torch.Tensor, + **kwargs) -> None: loaded_weight = loaded_weight.reshape(-1, param.shape[1]) param.load_qkv_weight_(loaded_weight, **kwargs) - param.load_merged_column_weight_ = param.load_merged_column_weight # save the original method - param.load_merged_column_weight = lambda loaded_weight, **kwargs: load_merged_column_weight(param, loaded_weight, - **kwargs) - param.load_row_parallel_weight_ = param.load_row_parallel_weight # save the original method - param.load_row_parallel_weight = lambda loaded_weight: load_row_parallel_weight(param, loaded_weight) - param.load_qkv_weight_ = param.load_qkv_weight # save the original method - param.load_qkv_weight = lambda loaded_weight, **kwargs: load_qkv_weight(param, loaded_weight, **kwargs) + # save the original method + param.load_merged_column_weight_ = param.load_merged_column_weight + param.load_merged_column_weight = \ + lambda loaded_weight, **kwargs: \ + (load_merged_column_weight(param, loaded_weight, **kwargs)) + # save the original method + param.load_row_parallel_weight_ = param.load_row_parallel_weight + param.load_row_parallel_weight = \ + lambda loaded_weight: load_row_parallel_weight(param, loaded_weight) + # save the original method + param.load_qkv_weight_ = param.load_qkv_weight + param.load_qkv_weight = \ + lambda loaded_weight, **kwargs: \ + (load_qkv_weight(param, loaded_weight, **kwargs)) class HQQMarlinMethod(LinearMethodBase): diff --git a/vllm/model_executor/layers/quantization/kernels/exllama.py b/vllm/model_executor/layers/quantization/kernels/exllama.py index 1d85d62ec83ee..7a35915d24eb4 100644 --- a/vllm/model_executor/layers/quantization/kernels/exllama.py +++ b/vllm/model_executor/layers/quantization/kernels/exllama.py @@ -5,7 +5,7 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.quant_utils import ( pack_quantized_values_into_int32) -from vllm.model_executor.parameter import (BasevLLMParameter, +from vllm.model_executor.parameter import (Features, has_any_param_feature, permute_param_layout_) from vllm.scalar_type import scalar_types @@ -100,7 +100,7 @@ def transform_w_g_idx(x): setattr(layer, self.w_gidx_name, empty_g_idx) def transform_w_q(x): - assert isinstance(x, BasevLLMParameter) + assert has_any_param_feature(x, Features.Base) assert self.w_gidx_name is not None g_idx = getattr(layer, self.w_gidx_name) @@ -110,7 +110,7 @@ def transform_w_q(x): return x_cont def transform_w_s(x): - assert isinstance(x, BasevLLMParameter) + assert has_any_param_feature(x, Features.Base) permute_param_layout_(x, input_dim=0, output_dim=1) x.data = x.data.contiguous() return x.to(dtype=c.act_type) diff --git a/vllm/model_executor/layers/quantization/kernels/machete.py b/vllm/model_executor/layers/quantization/kernels/machete.py index 873e0985c57f9..ae97cb0b4d2fb 100644 --- a/vllm/model_executor/layers/quantization/kernels/machete.py +++ b/vllm/model_executor/layers/quantization/kernels/machete.py @@ -9,7 +9,8 @@ query_machete_supported_quant_types) from vllm.model_executor.layers.quantization.utils.quant_utils import ( pack_quantized_values_into_int32, unpack_quantized_values_into_int32) -from vllm.model_executor.parameter import (permute_param_layout_, has_any_param_feature, Features) +from vllm.model_executor.parameter import (Features, has_any_param_feature, + permute_param_layout_) from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig diff --git a/vllm/model_executor/layers/quantization/kernels/marlin.py b/vllm/model_executor/layers/quantization/kernels/marlin.py index 6969583d6d473..9d55aab1da94b 100644 --- a/vllm/model_executor/layers/quantization/kernels/marlin.py +++ b/vllm/model_executor/layers/quantization/kernels/marlin.py @@ -8,7 +8,7 @@ check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, marlin_sort_g_idx, query_marlin_supported_quant_types) -from vllm.model_executor.parameter import (BasevLLMParameter, +from vllm.model_executor.parameter import (Features, has_any_param_feature, permute_param_layout_) from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig @@ -88,7 +88,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device)) def transform_w_q(x): - assert isinstance(x, BasevLLMParameter) + assert has_any_param_feature(x, Features.Base) permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) x.data = ops.gptq_marlin_repack(x.data.contiguous(), perm=layer.g_idx_sort_indices, @@ -98,7 +98,7 @@ def transform_w_q(x): return x def transform_w_s(x): - assert isinstance(x, BasevLLMParameter) + assert has_any_param_feature(x, Features.Base) permute_param_layout_(x, input_dim=0, output_dim=1) x.data = marlin_permute_scales(x.data.contiguous(), size_k=c.partition_weight_shape[0], diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index 30548e656c557..b733b07c2ea98 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -10,7 +10,6 @@ tensor_model_parallel_all_reduce) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding) -from vllm.model_executor.parameter import BasevLLMParameter from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform @@ -372,8 +371,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): # If param packed on the same dim we are sharding on, then # need to adjust offsets of loaded weight by pack_factor. if packed_dim is not None and packed_dim == output_dim: - packed_factor = param.packed_factor if isinstance( - param, BasevLLMParameter) else param.pack_factor + packed_factor = param.packed_factor assert loaded_weight.shape[output_dim] == (self.org_vocab_size // param.packed_factor) start_idx = start_idx // packed_factor diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py index 455e74342ee8b..780d696332477 100644 --- a/vllm/model_executor/parameter.py +++ b/vllm/model_executor/parameter.py @@ -1,3 +1,4 @@ +from enum import Enum from fractions import Fraction from typing import Callable, Optional, Union @@ -7,8 +8,6 @@ from vllm.distributed import get_tensor_model_parallel_rank from vllm.logger import init_logger -from enum import Enum - __all__ = [ "BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter", "ModelWeightParameter", "ChannelQuantScaleParameter", @@ -69,7 +68,8 @@ def BasevLLMParameter(data: torch.Tensor, **kwargs) -> Parameter: return param -def wrap_base_vllm_parameter(param: Parameter, weight_loader: Callable, **kwargs): +def wrap_base_vllm_parameter(param: Parameter, weight_loader: Callable, + **kwargs): """ Add basic functionality for vLLM linear layer parameters. """ @@ -79,17 +79,20 @@ def _assert_and_load(param: Parameter, loaded_weight: torch.Tensor): param.data.copy_(loaded_weight) param.weight_loader = weight_loader - param.load_column_parallel_weight = lambda loaded_weight: _assert_and_load(param, loaded_weight) - param.load_row_parallel_weight = lambda loaded_weight: _assert_and_load(param, loaded_weight) - param.load_merged_column_weight = lambda loaded_weight, **kwargs: _assert_and_load(param, loaded_weight) - param.load_qkv_weight = lambda loaded_weight, **kwargs: _assert_and_load(param, loaded_weight) + param.load_column_parallel_weight = lambda loaded_weight: _assert_and_load( + param, loaded_weight) + param.load_row_parallel_weight = lambda loaded_weight: _assert_and_load( + param, loaded_weight) + param.load_merged_column_weight = \ + lambda loaded_weight, **kwargs: _assert_and_load( + param, loaded_weight) + param.load_qkv_weight = lambda loaded_weight, **kwargs: _assert_and_load( + param, loaded_weight) add_param_feature(param, Features.Base) -def wrap_column_vllm_parameter(param: Parameter, - output_dim: int, - **kwargs - ) -> None: +def wrap_column_vllm_parameter(param: Parameter, output_dim: int, + **kwargs) -> None: """ Add functionality to the parameter for loading weights into linear layers with column parallelism. This includes QKV and MLP @@ -98,7 +101,8 @@ def wrap_column_vllm_parameter(param: Parameter, of the column parallel linear layers. """ - def load_column_parallel_weight(param: Parameter, loaded_weight: torch.Tensor): + def load_column_parallel_weight(param: Parameter, + loaded_weight: torch.Tensor): tp_rank = get_tensor_model_parallel_rank() shard_size = param.data.shape[param.output_dim] loaded_weight = loaded_weight.narrow(param.output_dim, @@ -106,10 +110,12 @@ def load_column_parallel_weight(param: Parameter, loaded_weight: torch.Tensor): assert param.data.shape == loaded_weight.shape param.data.copy_(loaded_weight) - def load_merged_column_weight(param: Parameter, loaded_weight: torch.Tensor, **kwargs): + def load_merged_column_weight(param: Parameter, + loaded_weight: torch.Tensor, **kwargs): shard_offset = kwargs.get("shard_offset") shard_size = kwargs.get("shard_size") - if (has_any_param_feature(param, [Features.PackedColumn, Features.Packed]) + if (has_any_param_feature(param, + [Features.PackedColumn, Features.Packed]) and param.output_dim == param.packed_dim): shard_size, shard_offset = param.adjust_shard_indexes_for_packing( shard_offset=shard_offset, shard_size=shard_size) @@ -124,13 +130,15 @@ def load_merged_column_weight(param: Parameter, loaded_weight: torch.Tensor, **k assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) - def load_qkv_weight(param: Parameter, loaded_weight: torch.Tensor, **kwargs): + def load_qkv_weight(param: Parameter, loaded_weight: torch.Tensor, + **kwargs): shard_offset = kwargs.get("shard_offset") shard_size = kwargs.get("shard_size") shard_id = kwargs.get("shard_id") num_heads = kwargs.get("num_heads") - if (has_any_param_feature(param, [Features.PackedColumn, Features.Packed]) + if (has_any_param_feature(param, + [Features.PackedColumn, Features.Packed]) and output_dim == param.packed_dim): shard_size, shard_offset = param.adjust_shard_indexes_for_packing( shard_offset=shard_offset, shard_size=shard_size) @@ -147,10 +155,12 @@ def load_qkv_weight(param: Parameter, loaded_weight: torch.Tensor, **kwargs): param_data.copy_(loaded_weight) param.output_dim = output_dim - param.load_column_parallel_weight = lambda loaded_weight: load_column_parallel_weight(param, loaded_weight) - param.load_merged_column_weight = lambda loaded_weight, **kwargs: load_merged_column_weight(param, loaded_weight, - **kwargs) - param.load_qkv_weight = lambda loaded_weight, **kwargs: load_qkv_weight(param, loaded_weight, **kwargs) + param.load_column_parallel_weight = lambda loaded_weight: ( + load_column_parallel_weight(param, loaded_weight)) + param.load_merged_column_weight = lambda loaded_weight, **kwargs: ( + load_merged_column_weight(param, loaded_weight, **kwargs)) + param.load_qkv_weight = lambda loaded_weight, **kwargs: (load_qkv_weight( + param, loaded_weight, **kwargs)) add_param_feature(param, Features.Column) @@ -161,10 +171,8 @@ def RowvLLMParameter(data: torch.Tensor, **kwargs) -> Parameter: return param -def wrap_row_vllm_parameter(param: Parameter, - input_dim: int, - **kwargs - ) -> None: +def wrap_row_vllm_parameter(param: Parameter, input_dim: int, + **kwargs) -> None: """ Add functionality to the parameter for loading weights into linear layers with row parallelism. This includes layers @@ -173,15 +181,18 @@ def wrap_row_vllm_parameter(param: Parameter, row parallel linear layers. """ - def load_row_parallel_weight(param: Parameter, loaded_weight: torch.Tensor): + def load_row_parallel_weight(param: Parameter, + loaded_weight: torch.Tensor): tp_rank = get_tensor_model_parallel_rank() shard_size = param.data.shape[input_dim] - loaded_weight = loaded_weight.narrow(input_dim, tp_rank * shard_size, shard_size) + loaded_weight = loaded_weight.narrow(input_dim, tp_rank * shard_size, + shard_size) assert param.data.shape == loaded_weight.shape param.data.copy_(loaded_weight) param.input_dim = input_dim - param.load_row_parallel_weight = lambda loaded_weight: load_row_parallel_weight(param, loaded_weight) + param.load_row_parallel_weight = lambda loaded_weight: ( + load_row_parallel_weight(param, loaded_weight)) add_param_feature(param, Features.Row) @@ -254,7 +265,8 @@ def shard_id_as_int(shard_id: Union[str, int]) -> int: assert shard_id in param.qkv_idxs return param.qkv_idxs[shard_id] - def load_into_shard_id(param: Parameter, loaded_weight: torch.Tensor, shard_id: int, **kwargs): + def load_into_shard_id(param: Parameter, loaded_weight: torch.Tensor, + shard_id: int, **kwargs): param_data = param.data shard_id = param.shard_id_as_int(shard_id) # AutoFP8 scales do not have a shape @@ -266,7 +278,8 @@ def load_into_shard_id(param: Parameter, loaded_weight: torch.Tensor, shard_id: assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) - def load_row_parallel_weight(param: Parameter, loaded_weight: torch.Tensor): + def load_row_parallel_weight(param: Parameter, + loaded_weight: torch.Tensor): assert param.data.shape == loaded_weight.shape param.data.copy_(loaded_weight) @@ -276,17 +289,21 @@ def load_merged_column_weight(param: Parameter, **kwargs): def load_qkv_weight(param: Parameter, **kwargs): param.load_into_shard_id(param, **kwargs) - def load_column_parallel_weight(param: Parameter, loaded_weight: torch.Tensor): + def load_column_parallel_weight(param: Parameter, + loaded_weight: torch.Tensor): assert param.data.shape == loaded_weight.shape param.data.copy_(loaded_weight) param.qkv_idxs = {"q": 0, "k": 1, "v": 2} param.shard_id_as_int = shard_id_as_int param.load_into_shard_id = load_into_shard_id - param.load_row_parallel_weight = lambda loaded_weight: load_row_parallel_weight(param, loaded_weight) - param.load_merged_column_weight = lambda **kwargs: load_merged_column_weight(param, **kwargs) - param.load_qkv_weight = lambda **kwargs: load_qkv_weight(param, **kwargs) - param.load_column_parallel_weight = lambda loaded_weight: load_column_parallel_weight(param, loaded_weight) + param.load_row_parallel_weight = lambda loaded_weight: ( + load_row_parallel_weight(param, loaded_weight)) + param.load_merged_column_weight = lambda **kwargs: ( + load_merged_column_weight(param, **kwargs)) + param.load_qkv_weight = lambda **kwargs: (load_qkv_weight(param, **kwargs)) + param.load_column_parallel_weight = lambda loaded_weight: ( + load_column_parallel_weight(param, loaded_weight)) add_param_feature(param, Features.PerTensorScale) @@ -298,8 +315,11 @@ def PackedColumnParameter(data: torch.Tensor, **kwargs) -> Parameter: return param -def wrap_packed_column_parameter(param: Parameter, packed_factor: Union[int, Fraction], - packed_dim: int, marlin_tile_size: Optional[int] = None, **kwargs) -> None: +def wrap_packed_column_parameter(param: Parameter, + packed_factor: Union[int, Fraction], + packed_dim: int, + marlin_tile_size: Optional[int] = None, + **kwargs) -> None: """ Add properties and methods for parameters which are packed on disk and support column parallelism only. See PackedvLLMParameter @@ -367,8 +387,8 @@ def BlockQuantScaleParameter(data: torch.Tensor, **kwargs) -> Parameter: return param -def permute_param_layout_(param: BasevLLMParameter, input_dim: int, - output_dim: int, **kwargs) -> BasevLLMParameter: +def permute_param_layout_(param: Parameter, input_dim: int, output_dim: int, + **kwargs) -> Parameter: """ Permute a parameter's layout to the specified input and output dimensions, useful for forcing the parameter into a known layout, for example, if I need