[core] improve cpu offloading implementation #10609

Draft: wants to merge 19 commits into main
18 changes: 16 additions & 2 deletions tests/basic_correctness/test_cpu_offload.py
@@ -2,5 +2,19 @@


def test_cpu_offload():
compare_two_settings("meta-llama/Llama-3.2-1B", [],
["--cpu-offload-gb", "1"])
compare_two_settings("meta-llama/Llama-3.1-8B", [],
["--cpu-offload-gb", "2"])


#
#
# def test_cpu_offload_gptq():
# # Test GPTQ Marlin
# compare_two_settings("Qwen/Qwen2-1.5B-Instruct-GPTQ-Int4", [],
# ["--cpu-offload-gb", "1"],
# max_wait_seconds=480)
# # Test GPTQ
# compare_two_settings("Qwen/Qwen2-1.5B-Instruct-GPTQ-Int4",
# ["--quantization", "gptq"],
# ["--quantization", "gptq", "--cpu-offload-gb", "1"],
# max_wait_seconds=480)
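For reference, a minimal sketch of exercising the offloaded configuration outside the test harness, assuming the offline LLM entry point accepts cpu_offload_gb as the Python-side counterpart of the --cpu-offload-gb flag (illustrative, not part of this diff):

```python
# Hedged sketch: load the same model the test uses with 2 GiB of weights
# offloaded to CPU memory, then run a short generation.
from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-3.1-8B", cpu_offload_gb=2)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)
```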
40 changes: 22 additions & 18 deletions vllm/model_executor/layers/quantization/awq.py
@@ -7,8 +7,10 @@
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
PackedvLLMParameter)
# from vllm.model_executor.parameter import (GroupQuantScaleParameter,
# PackedvLLMParameter)
from vllm.model_executor.parameter import (construct_group_quant_scale_parameter,
construct_packed_vllm_parameter)


class AWQConfig(QuantizationConfig):
@@ -111,7 +113,7 @@ def create_weights(self, layer: torch.nn.Module,
"tensor parallel size.")

weight_loader = extra_weight_attrs.get("weight_loader")
qweight = PackedvLLMParameter(
qweight = construct_packed_vllm_parameter(
data=torch.empty(
input_size_per_partition,
output_size_per_partition // self.quant_config.pack_factor,
@@ -123,7 +125,7 @@ def create_weights(self, layer: torch.nn.Module,
packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader)

qzeros = PackedvLLMParameter(
qzeros = construct_packed_vllm_parameter(
data=torch.empty(
input_size_per_partition // self.quant_config.group_size,
output_size_per_partition // self.quant_config.pack_factor,
@@ -135,26 +137,28 @@ def create_weights(self, layer: torch.nn.Module,
packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader)

scales = GroupQuantScaleParameter(data=torch.empty(
input_size_per_partition // self.quant_config.group_size,
output_size_per_partition,
dtype=params_dtype,
),
input_dim=0,
output_dim=1,
weight_loader=weight_loader)
scales = construct_group_quant_scale_parameter(
data=torch.empty(
input_size_per_partition // self.quant_config.group_size,
output_size_per_partition,
dtype=params_dtype,
),
input_dim=0,
output_dim=1,
weight_loader=weight_loader)

layer.register_parameter("qweight", qweight)
layer.register_parameter("qzeros", qzeros)
layer.register_parameter("scales", scales)

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.qweight = torch.nn.Parameter(layer.qweight.data,
requires_grad=False)
layer.qzeros = torch.nn.Parameter(layer.qzeros.data,
requires_grad=False)
layer.scales = torch.nn.Parameter(layer.scales.data,
requires_grad=False)
# layer.qweight = torch.nn.Parameter(layer.qweight.data,
# requires_grad=False)
# layer.qzeros = torch.nn.Parameter(layer.qzeros.data,
# requires_grad=False)
# layer.scales = torch.nn.Parameter(layer.scales.data,
# requires_grad=False)
pass

def apply(self,
layer: torch.nn.Module,
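For context, a minimal sketch of the factory-function pattern these hunks switch to, under the assumption that each construct_* helper returns a plain torch.nn.Parameter with the loading metadata attached as attributes (the real helpers in vllm/model_executor/parameter.py may differ in detail):

```python
# Hypothetical sketch, not the PR's actual implementation: build a plain
# nn.Parameter carrying packed-weight metadata instead of a Parameter subclass.
import torch


def construct_packed_vllm_parameter(data, *, input_dim, output_dim, packed_dim,
                                    packed_factor, weight_loader,
                                    marlin_tile_size=None):
    param = torch.nn.Parameter(data, requires_grad=False)
    # Attach the attributes the weight loader reads; since no subclass is
    # involved, the parameter never needs re-wrapping after loading.
    param.input_dim = input_dim
    param.output_dim = output_dim
    param.packed_dim = packed_dim
    param.packed_factor = packed_factor
    param.marlin_tile_size = marlin_tile_size
    param.weight_loader = weight_loader
    return param
```

If that assumption holds, it also explains why the re-wrapping in process_weights_after_loading is commented out above: the registered parameters are already vanilla nn.Parameter objects.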
39 changes: 21 additions & 18 deletions vllm/model_executor/layers/quantization/awq_marlin.py
@@ -21,8 +21,10 @@
marlin_permute_scales, moe_awq_to_marlin_zero_points,
verify_marlin_supported, verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
PackedvLLMParameter)
# from vllm.model_executor.parameter import (GroupQuantScaleParameter,
# PackedvLLMParameter)
from vllm.model_executor.parameter import (construct_group_quant_scale_parameter,
construct_packed_vllm_parameter)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types

@@ -189,7 +191,7 @@ def create_weights(
input_size=input_size,
group_size=group_size)

qweight = PackedvLLMParameter(
qweight = construct_packed_vllm_parameter(
data=torch.empty(
input_size_per_partition,
output_size_per_partition // self.quant_config.pack_factor,
@@ -203,7 +205,7 @@ def create_weights(

num_groups = input_size_per_partition // group_size

qzeros = PackedvLLMParameter(
qzeros = construct_packed_vllm_parameter(
data=torch.empty(
num_groups,
output_size_per_partition // self.quant_config.pack_factor,
@@ -215,14 +217,15 @@ def create_weights(
packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader)

scales = GroupQuantScaleParameter(data=torch.empty(
num_groups,
output_size_per_partition,
dtype=params_dtype,
),
input_dim=0,
output_dim=1,
weight_loader=weight_loader)
scales = construct_group_quant_scale_parameter(
data=torch.empty(
num_groups,
output_size_per_partition,
dtype=params_dtype,
),
input_dim=0,
output_dim=1,
weight_loader=weight_loader)

layer.register_parameter("qweight", qweight)
layer.register_parameter("qzeros", qzeros)
@@ -238,12 +241,12 @@ def create_weights(
# Here, we handle the repacking
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
device = layer.qweight.device
layer.qweight = torch.nn.Parameter(layer.qweight.data,
requires_grad=False)
layer.qzeros = torch.nn.Parameter(layer.qzeros.data,
requires_grad=False)
layer.scales = torch.nn.Parameter(layer.scales.data,
requires_grad=False)
# layer.qweight = torch.nn.Parameter(layer.qweight.data,
# requires_grad=False)
# layer.qzeros = torch.nn.Parameter(layer.qzeros.data,
# requires_grad=False)
# layer.scales = torch.nn.Parameter(layer.scales.data,
# requires_grad=False)

# Allocate marlin workspace
layer.workspace = marlin_make_workspace(
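As background, a generic sketch of the CPU-offloading pattern that plain parameters make easy to apply: keep weights in pinned host memory and point param.data at a device copy only around the forward pass. This illustrates the technique, not vLLM's actual implementation:

```python
# Generic illustration (not vLLM's code): stream a module's weights from
# pinned CPU memory to the GPU for each forward pass, then release them.
import torch


def offload_module_weights(module: torch.nn.Module, device: torch.device) -> None:
    cpu_copies = {}
    for name, param in module.named_parameters():
        # Swapping .data is only clean when param is a plain nn.Parameter,
        # which is what the construct_* helpers return.
        cpu_copies[name] = param.data.to("cpu").pin_memory()
        param.data = cpu_copies[name]

    original_forward = module.forward

    def forward_with_onloading(*args, **kwargs):
        for name, param in module.named_parameters():
            param.data = cpu_copies[name].to(device, non_blocking=True)
        try:
            return original_forward(*args, **kwargs)
        finally:
            # Point the parameters back at the CPU copies so the transient
            # GPU tensors can be freed after the call.
            for name, param in module.named_parameters():
                param.data = cpu_copies[name]

    module.forward = forward_with_onloading
```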
@@ -10,10 +10,15 @@
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise, sparse_cutlass_supported)
from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
# from vllm.model_executor.parameter import (BasevLLMParameter,
# ChannelQuantScaleParameter,
# ModelWeightParameter,
# PerTensorScaleParameter)
from vllm.model_executor.parameter import (construct_base_vllm_parameter,
construct_channel_quant_scale_parameter,
construct_model_weight_parameter,
construct_per_tensor_scale_parameter)


__all__ = ["CompressedTensors24"]

@@ -50,27 +55,27 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
self.weights_dtype: torch.dtype = self._get_params_dtype(params_dtype)

# parameter to store uncompressed weight
weight = ModelWeightParameter(data=torch.empty(
weight = construct_model_weight_parameter(data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition,
dtype=self.weights_dtype),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
input_dim=1,
output_dim=0,
weight_loader=weight_loader)

# Check if quantized, not just 2:4 Sparse
if self.quantized:
if (self.weight_quant and self.weight_quant.strategy
== QuantizationStrategy.CHANNEL.value):
weight_scale = ChannelQuantScaleParameter(
weight_scale = construct_channel_quant_scale_parameter(
data=torch.empty((sum(output_partition_sizes), 1),
dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader)
else:
assert (self.weight_quant and self.weight_quant.strategy
== QuantizationStrategy.TENSOR.value)
weight_scale = PerTensorScaleParameter(
weight_scale = construct_per_tensor_scale_parameter(
data=torch.empty(len(output_partition_sizes),
dtype=torch.float32),
weight_loader=weight_loader)
@@ -82,9 +87,9 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
# register input quant scale
assert (self.input_quant.strategy ==
QuantizationStrategy.TENSOR.value)
input_scale = BasevLLMParameter(data=torch.empty(
1, dtype=torch.float32),
weight_loader=weight_loader)
input_scale = construct_base_vllm_parameter(
data=torch.empty(1, dtype=torch.float32),
weight_loader=weight_loader)

layer.register_parameter("input_scale", input_scale)

@@ -113,20 +118,20 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:

"""
# torch.compile workaround
if hasattr(layer, "input_scale"):
layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
requires_grad=False)
# if hasattr(layer, "input_scale"):
# layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
# requires_grad=False)

if self.weight_quant:
if self.weight_quant.strategy == QuantizationStrategy.TENSOR.value:
layer.weight_scale = torch.nn.Parameter(convert_to_channelwise(
weight_scale=layer.weight_scale,
logical_widths=layer.logical_widths),
requires_grad=False)
else:
# torch.compile workaround
layer.weight_scale = torch.nn.Parameter(
layer.weight_scale.data, requires_grad=False)
# else:
# # torch.compile workaround
# layer.weight_scale = torch.nn.Parameter(
# layer.weight_scale.data, requires_grad=False)

w_compressed, meta = ops.cutlass_sparse_compress(layer.weight.data)
layer.weight = torch.nn.Parameter(w_compressed, requires_grad=False)
@@ -8,10 +8,14 @@
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N)
from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedvLLMParameter)
# from vllm.model_executor.parameter import (BasevLLMParameter,
# ChannelQuantScaleParameter,
# GroupQuantScaleParameter,
# PackedvLLMParameter)
from vllm.model_executor.parameter import (construct_base_vllm_parameter,
construct_channel_quant_scale_parameter,
construct_group_quant_scale_parameter,
construct_packed_vllm_parameter)
from vllm.scalar_type import scalar_types

__all__ = ["CompressedTensorsW4A16Sparse24"]
@@ -48,12 +52,13 @@ def get_min_capability(cls) -> int:
return 80

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# required by torch.compile to be torch.nn.Parameter
layer.weight_packed = Parameter(layer.weight_packed.data,
requires_grad=False)
layer.scale_packed = Parameter(layer.scale_packed.data,
requires_grad=False)
layer.meta = Parameter(layer.meta.data, requires_grad=False)
# # required by torch.compile to be torch.nn.Parameter
# layer.weight_packed = Parameter(layer.weight_packed.data,
# requires_grad=False)
# layer.scale_packed = Parameter(layer.scale_packed.data,
# requires_grad=False)
# layer.meta = Parameter(layer.meta.data, requires_grad=False)
pass

def create_weights(self, layer: torch.nn.Module, input_size: int,
output_partition_sizes: List[int],
@@ -64,17 +69,18 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
pack_factor = 32 // self.quant_type.size_bits
output_size_per_partition = sum(output_partition_sizes)

qweight = PackedvLLMParameter(data=torch.empty(
input_size_per_partition // self.tile_size // 2,
output_size_per_partition * self.tile_size // pack_factor,
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=pack_factor,
marlin_tile_size=self.tile_size,
weight_loader=weight_loader)
qweight = construct_packed_vllm_parameter(
data=torch.empty(
input_size_per_partition // self.tile_size // 2,
output_size_per_partition * self.tile_size // pack_factor,
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=pack_factor,
marlin_tile_size=self.tile_size,
weight_loader=weight_loader)

input_groups = (1 if self.group_size is None else
input_size_per_partition // self.group_size)
@@ -91,28 +97,31 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
}

if self.group_size is not None:
scales = GroupQuantScaleParameter(output_dim=1,
input_dim=0,
**weight_scale_args)
scales = construct_group_quant_scale_parameter(
output_dim=1,
input_dim=0,
**weight_scale_args)
else:
scales = ChannelQuantScaleParameter(output_dim=1,
**weight_scale_args)

weight_shape = BasevLLMParameter(data=torch.empty(2,
dtype=torch.int64),
weight_loader=weight_loader)

meta = PackedvLLMParameter(data=torch.empty(
input_size_per_partition // 8 // 2 // 2,
output_size_per_partition * 2,
dtype=torch.int16,
),
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=1,
marlin_tile_size=2,
weight_loader=weight_loader)
scales = construct_channel_quant_scale_parameter(
output_dim=1,
**weight_scale_args)

weight_shape = construct_base_vllm_parameter(
data=torch.empty(2, dtype=torch.int64),
weight_loader=weight_loader)

meta = construct_packed_vllm_parameter(
data=torch.empty(
input_size_per_partition // 8 // 2 // 2,
output_size_per_partition * 2,
dtype=torch.int16,
),
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=1,
marlin_tile_size=2,
weight_loader=weight_loader)

layer.register_parameter("weight_packed", qweight)
layer.register_parameter("weight_shape", weight_shape)
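Finally, a hedged sketch of how a weight loader might consume the attributes these constructors attach when copying tensor-parallel shards; the function name and logic here are assumptions for illustration, not vLLM's loader:

```python
# Illustrative only: shard selection driven by metadata attached to a plain
# nn.Parameter by a construct_* helper.
import torch


def load_output_sharded_weight(param: torch.nn.Parameter,
                               loaded_weight: torch.Tensor,
                               tp_rank: int) -> None:
    output_dim = getattr(param, "output_dim", None)
    if output_dim is not None:
        # Narrow the checkpoint tensor to this rank's slice along output_dim.
        shard_size = param.data.shape[output_dim]
        loaded_weight = loaded_weight.narrow(output_dim,
                                             tp_rank * shard_size, shard_size)
    param.data.copy_(loaded_weight)
```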