Commit fbb425c

bash format.sh
cennn committed Dec 30, 2024
1 parent eb6b394 commit fbb425c
Showing 9 changed files with 138 additions and 99 deletions.
39 changes: 18 additions & 21 deletions vllm/model_executor/layers/linear.py
@@ -15,12 +15,7 @@
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
# yapf: disable
from vllm.model_executor.parameter import (BasevLLMParameter,
BlockQuantScaleParameter,
PackedColumnParameter,
PackedvLLMParameter,
PerTensorScaleParameter,
RowvLLMParameter, has_any_param_feature, Features)
from vllm.model_executor.parameter import Features, has_any_param_feature
# yapf: enable
from vllm.model_executor.utils import set_weight_attrs

@@ -574,7 +569,7 @@ def weight_loader(self,
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)

def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
def _load_fused_module_from_checkpoint(self, param: Parameter,
loaded_weight: torch.Tensor):
"""
Handle special case for models where MLP layers are already
@@ -596,8 +591,11 @@ def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
# Special case for Quantization.
# If quantized, we need to adjust the offset and size to account
# for the packing.
if isinstance(param, (PackedColumnParameter, PackedvLLMParameter
)) and param.packed_dim == param.output_dim:
# if isinstance(param, (PackedColumnParameter, PackedvLLMParameter
# )) and param.packed_dim == param.output_dim:
if has_any_param_feature(param,
[Features.PackedColumn, Features.Packed]) \
and param.packed_dim == param.output_dim:
shard_size, shard_offset = \
param.adjust_shard_indexes_for_packing(
shard_size=shard_size, shard_offset=shard_offset)
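
For intuition, adjust_shard_indexes_for_packing (defined in vllm/model_executor/parameter.py, not shown in this diff) rescales the shard geometry by the pack factor. A rough standalone sketch with hypothetical numbers, assuming 8 int4 values packed per int32 element:

# Hypothetical numbers, illustrative only -- not part of this diff.
packed_factor = 8                             # e.g. 8 int4 values per int32
shard_size, shard_offset = 4096, 8192         # in unpacked (logical) elements
shard_size = shard_size // packed_factor      # -> 512 packed elements
shard_offset = shard_offset // packed_factor  # -> 1024 packed elements
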
@@ -608,15 +606,15 @@ def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
self.weight_loader_v2(param, loaded_weight_shard, shard_id)

def weight_loader_v2(self,
param: BasevLLMParameter,
param: Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[int] = None):
if loaded_shard_id is None:
if isinstance(param, PerTensorScaleParameter):
if has_any_param_feature(param, Features.PerTensorScale):
param.load_merged_column_weight(loaded_weight=loaded_weight,
shard_id=0)
return
elif type(param) in (RowvLLMParameter, BasevLLMParameter):
elif has_any_param_feature(param, [Features.Row, Features.Base]):
param.load_merged_column_weight(loaded_weight=loaded_weight)
return
# TODO: @dsikka - move to parameter.py
@@ -738,7 +736,7 @@ def _get_shard_size_mapping(self, loaded_shard_id: str):
}
return shard_size_mapping.get(loaded_shard_id)

def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
def _load_fused_module_from_checkpoint(self, param: Parameter,
loaded_weight: torch.Tensor):
"""
Handle special case for models where QKV layers are already
@@ -763,8 +761,9 @@ def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
# Special case for Quantization.
# If quantized, we need to adjust the offset and size to account
# for the packing.
if isinstance(param, (PackedColumnParameter, PackedvLLMParameter
)) and param.packed_dim == param.output_dim:
if has_any_param_feature(param,
[Features.PackedColumn, Features.Packed]) \
and param.packed_dim == param.output_dim:
shard_size, shard_offset = \
param.adjust_shard_indexes_for_packing(
shard_size=shard_size, shard_offset=shard_offset)
@@ -775,14 +774,14 @@ def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
self.weight_loader_v2(param, loaded_weight_shard, shard_id)

def weight_loader_v2(self,
param: BasevLLMParameter,
param: Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None):
if loaded_shard_id is None: # special case for certain models
if isinstance(param, PerTensorScaleParameter):
if has_any_param_feature(param, Features.PerTensorScale):
param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0)
return
elif type(param) in (RowvLLMParameter, BasevLLMParameter):
elif has_any_param_feature(param, [Features.Row, Features.Base]):
param.load_qkv_weight(loaded_weight=loaded_weight)
return
# TODO: @dsikka - move to parameter.py
@@ -1087,9 +1086,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)

def weight_loader_v2(self, param: BasevLLMParameter,
loaded_weight: torch.Tensor):

def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
# Special case for loading scales off disk, which often do not
# have a shape (such as in the case of AutoFP8).
if len(loaded_weight.shape) == 0:
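
The recurring edit in this file replaces isinstance checks against concrete parameter subclasses (PackedColumnParameter, PerTensorScaleParameter, RowvLLMParameter, ...) with feature-flag lookups via has_any_param_feature. A minimal standalone sketch of that pattern, assuming the features are kept as a bitmask attribute on the Parameter object (the real Features enum and helpers in vllm/model_executor/parameter.py differ in detail):

# Standalone sketch of the feature-flag pattern; not vLLM's actual implementation.
from enum import Flag, auto

import torch
from torch.nn import Parameter


class Features(Flag):
    Base = auto()
    Row = auto()
    PackedColumn = auto()
    Packed = auto()
    PerTensorScale = auto()


def add_param_feature(param: Parameter, feature: Features) -> None:
    # Record features as a bitmask attribute directly on the Parameter.
    param.features = getattr(param, "features", Features(0)) | feature


def has_any_param_feature(param: Parameter, features) -> bool:
    # Accept a single Features value or a list, matching the call sites above.
    if not isinstance(features, (list, tuple)):
        features = [features]
    mask = getattr(param, "features", Features(0))
    return any(bool(mask & f) for f in features)


w = Parameter(torch.empty(8, 4), requires_grad=False)
add_param_feature(w, Features.Packed)
assert has_any_param_feature(w, [Features.PackedColumn, Features.Packed])
assert not has_any_param_feature(w, Features.PerTensorScale)
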
1 change: 0 additions & 1 deletion vllm/model_executor/layers/quantization/gptq.py
@@ -4,7 +4,6 @@
from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
1 change: 0 additions & 1 deletion vllm/model_executor/layers/quantization/gptq_marlin_24.py
@@ -1,7 +1,6 @@
from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from vllm import _custom_ops as ops
from vllm.logger import init_logger
81 changes: 53 additions & 28 deletions vllm/model_executor/layers/quantization/hqq_marlin.py
@@ -15,13 +15,12 @@
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
MarlinWorkspace)
from vllm.model_executor.layers.quantization.utils.quant_utils import gptq_pack
from vllm.model_executor.parameter import (wrap_base_vllm_parameter,
from vllm.model_executor.parameter import (Features, add_param_feature,
wrap_base_vllm_parameter,
wrap_column_vllm_parameter,
wrap_row_vllm_parameter,
add_param_feature,
Features, wrap_packed_vllm_parameter, has_any_param_feature,
wrap_group_quant_scale_parameter)

wrap_group_quant_scale_parameter,
wrap_packed_vllm_parameter,
wrap_row_vllm_parameter)
from vllm.scalar_type import scalar_types

logger = init_logger(__name__)
@@ -119,8 +118,10 @@ def HQQweightParameter(data: torch.Tensor, **kwargs) -> Parameter:
return param


def wrap_hqq_weight_parameter(param: Parameter, packed_factor: int, packed_dim: int, weight_bits: int,
def wrap_hqq_weight_parameter(param: Parameter, packed_factor: int,
packed_dim: int, weight_bits: int,
**kwargs) -> None:

def unpack_4bit_u8(param: Parameter, W_q: torch.Tensor) -> torch.Tensor:
assert param.weight_bits == 4, "Unsupported quant bitsize (must be 4)"

@@ -133,7 +134,9 @@ def unpack_4bit_u8(param: Parameter, W_q: torch.Tensor) -> torch.Tensor:
tmp[step:] = W_q & 0b00001111
return tmp
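
The hunk above shows only the low-nibble half of unpack_4bit_u8; the collapsed context is not reproduced here. As a standalone illustration of the general idea (each uint8 byte carries two int4 values), with made-up data:

# Standalone illustration, not the elided lines of unpack_4bit_u8.
import torch

W_q = torch.tensor([[0xAB, 0xCD]], dtype=torch.uint8)  # packed rows
step = W_q.shape[0]
tmp = torch.empty((2 * step, W_q.shape[1]), dtype=torch.uint8)
tmp[:step] = (W_q & 0b11110000) >> 4  # high nibbles -> first half of rows
tmp[step:] = W_q & 0b00001111         # low nibbles  -> second half (as above)
print(tmp)                            # [[10, 12], [11, 13]]
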

def load_merged_column_weight(param: Parameter, loaded_weight: torch.Tensor, **kwargs) -> None:
def load_merged_column_weight(param: Parameter,
loaded_weight: torch.Tensor,
**kwargs) -> None:
loaded_weight = param.unpack_4bit_u8(loaded_weight)
loaded_weight = loaded_weight.reshape(-1, param.input_shape).transpose(
1, 0)
@@ -143,7 +146,8 @@ def load_merged_column_weight(param: Parameter, loaded_weight: torch.Tensor, **k
# load_merged_column_weight from wrap_column_vllm_parameter
param.load_merged_column_weight_(loaded_weight, **kwargs)

def load_row_parallel_weight(param: Parameter, loaded_weight: torch.Tensor) -> None:
def load_row_parallel_weight(param: Parameter,
loaded_weight: torch.Tensor) -> None:
loaded_weight = param.unpack_4bit_u8(loaded_weight)
loaded_weight = loaded_weight.reshape(param.output_shape,
-1).transpose(1, 0)
@@ -153,10 +157,12 @@ def load_row_parallel_weight(param: Parameter, loaded_weight: torch.Tensor) -> N
# load_row_parallel_weight from wrap_row_vllm_parameter
param.load_row_parallel_weight_(loaded_weight)

def load_qkv_weight(param: Parameter, loaded_weight: torch.Tensor, **kwargs) -> None:
def load_qkv_weight(param: Parameter, loaded_weight: torch.Tensor,
**kwargs) -> None:
loaded_weight = param.unpack_4bit_u8(loaded_weight)
loaded_weight = (loaded_weight.reshape(-1, param.input_shape)
.transpose(1, 0))
loaded_weight = (loaded_weight.reshape(-1,
param.input_shape).transpose(
1, 0))
loaded_weight = gptq_pack(loaded_weight, param.weight_bits,
loaded_weight.shape[0],
loaded_weight.shape[1])
@@ -167,13 +173,20 @@ def load_qkv_weight(param: Parameter, loaded_weight: torch.Tensor, **kwargs) ->
param.input_shape = param.shape[param.input_dim] * packed_factor
param.output_shape = param.shape[param.output_dim]
param.unpack_4bit_u8 = lambda W_q: unpack_4bit_u8(param, W_q)
param.load_merged_column_weight_ = param.load_merged_column_weight # save the original method
param.load_merged_column_weight = lambda loaded_weight, **kwargs: load_merged_column_weight(param, loaded_weight,
**kwargs)
param.load_row_parallel_weight_ = param.load_row_parallel_weight # save the original method
param.load_row_parallel_weight = lambda loaded_weight: load_row_parallel_weight(param, loaded_weight)
# save the original method
param.load_merged_column_weight_ = param.load_merged_column_weight
param.load_merged_column_weight = \
lambda loaded_weight, **kwargs: \
load_merged_column_weight(param, loaded_weight, **kwargs)
# save the original method
param.load_row_parallel_weight_ = param.load_row_parallel_weight
param.load_row_parallel_weight = \
lambda loaded_weight: \
load_row_parallel_weight(param, loaded_weight)
param.load_qkv_weight_ = param.load_qkv_weight # save the original method
param.load_qkv_weight = lambda loaded_weight, **kwargs: load_qkv_weight(param, loaded_weight, **kwargs)
param.load_qkv_weight = \
lambda loaded_weight, **kwargs: \
load_qkv_weight(param, loaded_weight, **kwargs)


# Zero points and scales in HQQ must also be reshaped to correspond to W_q's
@@ -189,25 +202,37 @@ def HQQZeroScaleParameter(data: torch.Tensor, **kwargs) -> Parameter:


def wrap_hqq_zero_scale_parameter(param: Parameter, **kwargs) -> None:
def load_merged_column_weight(param: Parameter, loaded_weight: torch.Tensor, **kwargs) -> None:

def load_merged_column_weight(param: Parameter,
loaded_weight: torch.Tensor,
**kwargs) -> None:
loaded_weight = loaded_weight.reshape(-1, param.shape[1])
param.load_merged_column_weight_(loaded_weight, **kwargs)

def load_row_parallel_weight(param: Parameter, loaded_weight: torch.Tensor) -> None:
def load_row_parallel_weight(param: Parameter,
loaded_weight: torch.Tensor) -> None:
loaded_weight = loaded_weight.reshape(param.shape[0], -1)
param.load_row_parallel_weight_(loaded_weight)

def load_qkv_weight(param: Parameter, loaded_weight: torch.Tensor, **kwargs) -> None:
def load_qkv_weight(param: Parameter, loaded_weight: torch.Tensor,
**kwargs) -> None:
loaded_weight = loaded_weight.reshape(-1, param.shape[1])
param.load_qkv_weight_(loaded_weight, **kwargs)

param.load_merged_column_weight_ = param.load_merged_column_weight # save the original method
param.load_merged_column_weight = lambda loaded_weight, **kwargs: load_merged_column_weight(param, loaded_weight,
**kwargs)
param.load_row_parallel_weight_ = param.load_row_parallel_weight # save the original method
param.load_row_parallel_weight = lambda loaded_weight: load_row_parallel_weight(param, loaded_weight)
param.load_qkv_weight_ = param.load_qkv_weight # save the original method
param.load_qkv_weight = lambda loaded_weight, **kwargs: load_qkv_weight(param, loaded_weight, **kwargs)
# save the original method
param.load_merged_column_weight_ = param.load_merged_column_weight
param.load_merged_column_weight = \
lambda loaded_weight, **kwargs: \
(load_merged_column_weight(param, loaded_weight, **kwargs))
# save the original method
param.load_row_parallel_weight_ = param.load_row_parallel_weight
param.load_row_parallel_weight = \
lambda loaded_weight: load_row_parallel_weight(param, loaded_weight)
# save the original method
param.load_qkv_weight_ = param.load_qkv_weight
param.load_qkv_weight = \
lambda loaded_weight, **kwargs: \
(load_qkv_weight(param, loaded_weight, **kwargs))


class HQQMarlinMethod(LinearMethodBase):
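
Both HQQ wrappers above follow the same convention: keep the previously installed loader under a trailing-underscore name, then rebind the public name to a lambda that pre-processes the incoming weight and delegates. A condensed standalone sketch of that convention (the real loaders also unpack, reshape, and re-pack the quantized data):

# Standalone sketch of the "save original, rebind to a delegating lambda" idiom.
import torch
from torch.nn import Parameter

param = Parameter(torch.empty(4, 6), requires_grad=False)

# Stand-in for the loader installed by an earlier wrap_* call.
param.load_row_parallel_weight = lambda w: param.data.copy_(w)


def load_row_parallel_weight(param: Parameter,
                             loaded_weight: torch.Tensor) -> None:
    loaded_weight = loaded_weight.reshape(4, 6)     # pre-process the weight
    param.load_row_parallel_weight_(loaded_weight)  # delegate to the original


# save the original method, then rebind the public name
param.load_row_parallel_weight_ = param.load_row_parallel_weight
param.load_row_parallel_weight = \
    lambda loaded_weight: load_row_parallel_weight(param, loaded_weight)

param.load_row_parallel_weight(torch.arange(24, dtype=torch.float32))
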
6 changes: 3 additions & 3 deletions vllm/model_executor/layers/quantization/kernels/exllama.py
@@ -5,7 +5,7 @@
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.quant_utils import (
pack_quantized_values_into_int32)
from vllm.model_executor.parameter import (BasevLLMParameter,
from vllm.model_executor.parameter import (Features, has_any_param_feature,
permute_param_layout_)
from vllm.scalar_type import scalar_types

@@ -100,7 +100,7 @@ def transform_w_g_idx(x):
setattr(layer, self.w_gidx_name, empty_g_idx)

def transform_w_q(x):
assert isinstance(x, BasevLLMParameter)
assert has_any_param_feature(x, Features.Base)
assert self.w_gidx_name is not None
g_idx = getattr(layer, self.w_gidx_name)

@@ -110,7 +110,7 @@ def transform_w_s(x):
return x_cont

def transform_w_s(x):
assert isinstance(x, BasevLLMParameter)
assert has_any_param_feature(x, Features.Base)
permute_param_layout_(x, input_dim=0, output_dim=1)
x.data = x.data.contiguous()
return x.to(dtype=c.act_type)
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/quantization/kernels/machete.py
@@ -9,7 +9,8 @@
query_machete_supported_quant_types)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
pack_quantized_values_into_int32, unpack_quantized_values_into_int32)
from vllm.model_executor.parameter import (permute_param_layout_, has_any_param_feature, Features)
from vllm.model_executor.parameter import (Features, has_any_param_feature,
permute_param_layout_)

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig

6 changes: 3 additions & 3 deletions vllm/model_executor/layers/quantization/kernels/marlin.py
@@ -8,7 +8,7 @@
check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx,
marlin_make_workspace, marlin_permute_scales, marlin_sort_g_idx,
query_marlin_supported_quant_types)
from vllm.model_executor.parameter import (BasevLLMParameter,
from vllm.model_executor.parameter import (Features, has_any_param_feature,
permute_param_layout_)

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
@@ -88,7 +88,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))

def transform_w_q(x):
assert isinstance(x, BasevLLMParameter)
assert has_any_param_feature(x, Features.Base)
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
x.data = ops.gptq_marlin_repack(x.data.contiguous(),
perm=layer.g_idx_sort_indices,
@@ -98,7 +98,7 @@ def transform_w_s(x):
return x

def transform_w_s(x):
assert isinstance(x, BasevLLMParameter)
assert has_any_param_feature(x, Features.Base)
permute_param_layout_(x, input_dim=0, output_dim=1)
x.data = marlin_permute_scales(x.data.contiguous(),
size_k=c.partition_weight_shape[0],
4 changes: 1 addition & 3 deletions vllm/model_executor/layers/vocab_parallel_embedding.py
@@ -10,7 +10,6 @@
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
from vllm.model_executor.parameter import BasevLLMParameter
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform

@@ -372,8 +371,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
# If param packed on the same dim we are sharding on, then
# need to adjust offsets of loaded weight by pack_factor.
if packed_dim is not None and packed_dim == output_dim:
packed_factor = param.packed_factor if isinstance(
param, BasevLLMParameter) else param.pack_factor
packed_factor = param.packed_factor
assert loaded_weight.shape[output_dim] == (self.org_vocab_size //
param.packed_factor)
start_idx = start_idx // packed_factor
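
For intuition (hypothetical numbers, not from this commit): with 4-bit weights packed 8 to an int32 along the vocab dimension, packed_factor is 8, so a shard whose first logical vocab row is 4096 starts at packed row 512:

# Hypothetical numbers, illustrative only -- not part of this diff.
packed_factor = 8                              # 8 int4 values per int32 element
org_vocab_size = 32768                         # logical vocab rows
packed_rows = org_vocab_size // packed_factor  # 4096 rows actually stored
start_idx = 4096                               # first logical row of this rank
start_idx = start_idx // packed_factor         # -> 512, packed row to read from
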