Commit af9ad46

[ Misc ] Refactor w8a8 to use process_weights_after_load (Simplify Weight Loading) (vllm-project#5940)

Authored by Robert Shaw (robertgshaw2-redhat) and Robert Shaw
Co-authored-by: Robert Shaw <rshaw@neuralmagic>
1 parent 7836fdc · commit af9ad46

10 files changed: +153 -158 lines
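
This commit replaces the parameter-level shard_splitter / fp8_scales_shard_indexer hooks with a two-step flow: the weight loader only tags fused per-tensor scale parameters with needs_scalar_to_array and drops each loaded scalar into a small array, and any conversion into the layout the kernels expect is deferred to process_weights_after_loading, which runs once all weights are loaded. Below is a minimal, self-contained sketch of that flow; TinyScheme and the example widths are illustrative, not vLLM classes or real model dimensions.

# Standalone sketch of the commit's loading pattern; TinyScheme and the
# widths are illustrative only.
import torch
from torch.nn import Module, Parameter


class TinyScheme:
    def __init__(self, logical_widths):
        # Output width of each "logical" matrix fused into one weight (q, k, v).
        self.logical_widths = logical_widths

    def create_weights(self, layer: Module):
        # One scalar scale per logical matrix; tagged so the loader knows
        # to slot each loaded scalar into the right array position.
        weight_scale = Parameter(torch.empty(len(self.logical_widths)),
                                 requires_grad=False)
        weight_scale.needs_scalar_to_array = True
        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: Module):
        # After loading, expand the N per-tensor scales into one
        # channelwise (sum(widths), 1) buffer the kernel can consume.
        channelwise = torch.empty((sum(self.logical_widths), 1))
        start = 0
        for idx, width in enumerate(self.logical_widths):
            channelwise[start:start + width, :] = layer.weight_scale[idx]
            start += width
        layer.weight_scale = Parameter(channelwise, requires_grad=False)


layer = Module()
scheme = TinyScheme(logical_widths=[2048, 256, 256])  # example fused q/k/v widths
scheme.create_weights(layer)
layer.weight_scale.data = torch.tensor([0.1, 0.2, 0.3])  # pretend-loaded scales
scheme.process_weights_after_loading(layer)
print(layer.weight_scale.shape)  # torch.Size([2560, 1])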

tests/quantization/test_compressed_tensors.py

+19-9
@@ -11,14 +11,18 @@
     CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24,
     CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor,
     CompressedTensorsWNA16)
+from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
+    QuantizationType)
 
 
 @pytest.mark.parametrize("model_args", [
-    ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", "tensor"),
-    ("nm-testing/tinyllama-oneshot-w8-channel-a8-tensor", "channel"),
+    ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", "tensor",
+     QuantizationType.INT, 2560),
+    ("nm-testing/tinyllama-oneshot-w8-channel-a8-tensor", "channel",
+     QuantizationType.INT, 2560),
 ])
 def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
-    model_path, strategy = model_args
+    model_path, strategy, quant_type, shape_0 = model_args
     with vllm_runner(model_path, enforce_eager=True) as llm:
         model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
         layer = model.model.layers[0]
@@ -34,17 +38,23 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
                           CompressedTensorsLinearMethod)
         assert isinstance(down_proj.quant_method,
                           CompressedTensorsLinearMethod)
-
         assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8StaticTensor)
 
         assert qkv_proj.scheme.strategy == strategy
-        assert qkv_proj.weight.dtype is torch.int8
-        assert o_proj.weight.dtype is torch.int8
-        assert gate_up_proj.weight.dtype is torch.int8
+        expected_type = (torch.int8 if quant_type == QuantizationType.INT else
+                         torch.float8_e4m3fn)
+
+        assert qkv_proj.weight.dtype is expected_type
+        assert o_proj.weight.dtype is expected_type
+        assert gate_up_proj.weight.dtype is expected_type
 
         if qkv_proj.scheme.strategy == "tensor":
-            assert qkv_proj.weight_scale.shard_splitter is not None
-            assert qkv_proj.weight_scale.logical_widths is not None
+            # Make sure it is a channelwise buffer
+            # After running process_weights_after_loading
+            assert len(qkv_proj.weight_scale.shape) == 2
+            assert qkv_proj.weight_scale.shape[0] == shape_0
+            assert qkv_proj.weight_scale.shape[1] == 1
+            assert qkv_proj.weight_scale.dtype is torch.float32
         assert qkv_proj.input_scale.dtype is torch.float32
 
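With the refactor, the per-tensor model no longer exposes shard_splitter / logical_widths on the scale parameter; instead the test checks that weight_scale has already been expanded into a channelwise (shape_0, 1) float32 buffer by process_weights_after_loading. The parametrized shape_0 == 2560 presumably reflects the fused q/k/v output widths of the TinyLlama checkpoint; the arithmetic below is an assumption for illustration, not taken from the diff.

# Assumed TinyLlama qkv partition (hidden_size=2048, 4 KV heads of head_dim 64),
# shown only to illustrate where shape_0 == 2560 could come from.
q_width, k_width, v_width = 2048, 256, 256
logical_widths = [q_width, k_width, v_width]
assert sum(logical_widths) == 2560  # rows of the channelwise weight_scale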

tests/quantization/test_fp8.py

+17
@@ -9,6 +9,23 @@
 from vllm._custom_ops import scaled_fp8_quant
 from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
 
+MODELS = [
+    "neuralmagic/Meta-Llama-3-8B-Instruct-FP8",
+    "nm-testing/Phi-3-mini-128k-instruct-FP8",
+]
+
+
+@pytest.mark.skipif(not is_quant_method_supported("fp8"),
+                    reason="FP8 is not supported on this GPU type.")
+@pytest.mark.parametrize("model", MODELS)
+def test_model_load_and_run(vllm_runner, model: str):
+    with vllm_runner(model) as llm:
+        # note: this does not test accuracy, just that we can run through
+        # see lm-eval tests for accuracy
+        outputs = llm.generate_greedy(prompts=["Hello my name is"],
+                                      max_tokens=10)
+        print(outputs[0][1])
+
 
 @pytest.mark.skipif(not is_quant_method_supported("fp8"),
                     reason="FP8 is not supported on this GPU type.")

vllm/model_executor/layers/linear.py

+39-71
@@ -41,6 +41,29 @@ def adjust_bitsandbytes_shard(param: Parameter,
     return quantized_size, quantized_offset
 
 
+def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
+    """For fused modules (QKV and MLP) we have an array of length
+    N that holds 1 scale for each "logical" matrix. So the param
+    is an array of length N. The loaded_weight corresponds to
+    one of the shards on disk. Here, we slice the param based on
+    the shard_id for loading.
+    """
+    qkv_idxs = {"q": 0, "k": 1, "v": 2}
+
+    if isinstance(shard_id, str):
+        shard_id = qkv_idxs[shard_id]
+    elif not isinstance(shard_id, int):
+        raise ValueError(f"Unknown Shard Id {shard_id}")
+
+    # AutoFP8 scales do not have a shape
+    # compressed-tensors scales do have a shape
+    if len(loaded_weight.shape) != 0:
+        assert loaded_weight.shape[0] == 1
+        loaded_weight = loaded_weight[0]
+
+    return param[shard_id], loaded_weight
+
+
 class LinearMethodBase(QuantizeMethodBase):
     """Base class for different (maybe quantized) linear methods."""
 
@@ -358,37 +381,15 @@ def weight_loader(self,
         output_dim = getattr(param, "output_dim", None)
         # Special case for AQLM codebooks.
         is_metadata = getattr(param, "is_metadata", False)
-
-        param_shard_splitter = getattr(param, "shard_splitter", None)
-
-        if output_dim is not None and param_shard_splitter is not None:
-            raise NotImplementedError(
-                "We do not currently support output_dim != None and "
-                "shard_splitter != None for a parameter. Please open an issue."
-            )
-        # If a parameter has defined a shard_splitter to be used for
-        # the weight, it should be applied before the weight is
-        # loaded/copied to the parameter. The shard_splitter applies
-        # logic by using the loaded_shard_id to ensure that the loaded
-        # param is loaded to the correct location
-        # within the parameter defined by the linear method.
-        if loaded_shard_id is None and param_shard_splitter is not None:
-            raise NotImplementedError(
-                "We do not currently support loaded_shard_id == None and "
-                "shard_splitter != None for a parameter. Please open an issue."
-            )
-
-        # Special case for Fp8 scales.
-        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
-                                           None)
+        # Special case for per-tensor scale to load scalar into fused array.
+        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
 
         if loaded_shard_id is None:
             # Loaded weight is already fused on disk (qkv/mlp).
             if output_dim is None:
-                # If fp8 + scale, need to send to each shard.
-                if fp8_scales_shard_indexer is not None:
-                    param_data, loaded_weight = fp8_scales_shard_indexer(
-                        param_data, loaded_weight, loaded_shard_id)
+                if needs_scalar_to_array is not None:
+                    param_data, loaded_weight = adjust_scalar_to_fused_array(
+                        param_data, loaded_weight, 0)
 
                 assert param_data.shape == loaded_weight.shape
                 param_data.copy_(loaded_weight)
@@ -450,15 +451,9 @@ def weight_loader(self,
                 shard_offset = loaded_shard_id * shard_size
             param_data = param_data.narrow(0, shard_offset, shard_size)
 
-        # If a param_shard_splitter is defined by the LinearMethod, use it.
-        elif param_shard_splitter is not None:
-            logical_widths = getattr(param, "logical_widths", None)
-            param_data, loaded_weight = param_shard_splitter(
-                param_data, loaded_weight, loaded_shard_id, logical_widths)
-
-        # Special case for Fp8 scales.
-        elif fp8_scales_shard_indexer is not None:
-            param_data, loaded_weight = fp8_scales_shard_indexer(
+        # Special case for per-tensor scales in fused case.
+        elif needs_scalar_to_array:
+            param_data, loaded_weight = adjust_scalar_to_fused_array(
                 param_data, loaded_weight, loaded_shard_id)
 
         else:
@@ -548,36 +543,15 @@ def weight_loader(self,
         # Special case for AQLM codebooks.
         is_metadata = getattr(param, "is_metadata", False)
 
-        param_shard_splitter = getattr(param, "shard_splitter", None)
-
-        if output_dim is not None and param_shard_splitter is not None:
-            raise NotImplementedError(
-                "We do not currently support output_dim != None and "
-                "shard_splitter != None for a parameter. Please open an issue."
-            )
-        # If a parameter has defined a shard_splitter to be used for
-        # the weight, it should be applied before the weight is
-        # loaded/copied to the parameter. The shard_splitter applies
-        # logic by using the loaded_shard_id to ensure that the loaded
-        # param is loaded to the correct location
-        # within the parameter defined by the linear method.
-        if loaded_shard_id is None and param_shard_splitter is not None:
-            raise NotImplementedError(
-                "We do not currently support loaded_shard_id == None and "
-                "shard_splitter != None for a parameter. Please open an issue."
-            )
-
-        # Special case for Fp8 scales.
-        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
-                                           None)
+        # Special case for per-tensor scales in fused case.
+        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
 
         if loaded_shard_id is None:
             # Loaded weight is already fused on disk (qkv/mlp).
             if output_dim is None:
-                # If fp8 + scale, need to send to each shard.
-                if fp8_scales_shard_indexer is not None:
-                    param_data, loaded_weight = fp8_scales_shard_indexer(
-                        param_data, loaded_weight, loaded_shard_id)
+                if needs_scalar_to_array is not None:
+                    param_data, loaded_weight = adjust_scalar_to_fused_array(
+                        param_data, loaded_weight, 0)
 
                 assert param_data.shape == loaded_weight.shape
                 param_data.copy_(loaded_weight)
@@ -667,15 +641,9 @@ def weight_loader(self,
                 shard_index = ["q", "k", "v"].index(loaded_shard_id)
                 param_data = param_data.narrow(0, shard_index * shard_size,
                                                shard_size)
-        # If a param_shard_splitter is defined by the LinearMethod, use it.
-        elif param_shard_splitter is not None:
-            logical_widths = getattr(param, "logical_widths", None)
-            param_data, loaded_weight = param_shard_splitter(
-                param_data, loaded_weight, loaded_shard_id, logical_widths)
-
-        # Special case for Fp8 scales.
-        elif fp8_scales_shard_indexer is not None:
-            param_data, loaded_weight = fp8_scales_shard_indexer(
+        # Special case for per-tensor scales in fused case.
+        elif needs_scalar_to_array:
+            param_data, loaded_weight = adjust_scalar_to_fused_array(
                 param_data, loaded_weight, loaded_shard_id)
         else:
             ignore_warning = getattr(param, "ignore_warning", False)
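The new adjust_scalar_to_fused_array helper replaces the per-parameter indexer callables: given a fused parameter that holds one scale per logical matrix, it maps a shard id ("q"/"k"/"v" or an integer) to the matching slot and strips an optional leading singleton dimension from the loaded scale. A standalone usage sketch, mirroring the helper added above:

import torch

# Mirrors the helper added above; shown standalone for illustration.
def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
    qkv_idxs = {"q": 0, "k": 1, "v": 2}
    if isinstance(shard_id, str):
        shard_id = qkv_idxs[shard_id]
    elif not isinstance(shard_id, int):
        raise ValueError(f"Unknown Shard Id {shard_id}")
    # AutoFP8 checkpoints store a bare scalar; compressed-tensors store shape (1,).
    if len(loaded_weight.shape) != 0:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]
    return param[shard_id], loaded_weight

param = torch.zeros(3)                       # fused q/k/v scale array
slot, scale = adjust_scalar_to_fused_array(param, torch.tensor([0.5]), "k")
slot.copy_(scale)                            # param is now [0., 0.5, 0.]
print(param)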

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py

+3
@@ -186,6 +186,9 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
     def __init__(self, quantization_config: CompressedTensorsConfig):
         self.quantization_config = quantization_config
 
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        return layer.scheme.process_weights_after_loading(layer)
+
     def create_weights(self, layer: torch.nn.Module,
                        input_size_per_partition: int,
                        output_partition_sizes: List[int], input_size: int,
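CompressedTensorsLinearMethod now simply forwards the new hook to its scheme, so each scheme decides whether any post-load fixup is needed. The sketch below shows how such a hook would typically be driven once the checkpoint has been read; the loop is illustrative, and the actual call site in vLLM's model loader may differ in detail.

import torch

def finalize_quantized_weights(model: torch.nn.Module) -> None:
    # Illustrative driver: walk the model and let each quantized layer's
    # method do its post-load cleanup.
    for module in model.modules():
        quant_method = getattr(module, "quant_method", None)
        if quant_method is not None and hasattr(
                quant_method, "process_weights_after_loading"):
            # For compressed-tensors layers this forwards to
            # layer.scheme.process_weights_after_loading(layer).
            quant_method.process_weights_after_loading(module)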

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py

+8
@@ -31,3 +31,11 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
 
         """
         raise NotImplementedError
+
+    @abstractmethod
+    def process_weights_after_loading(self, layer: torch.nn.Module):
+        """
+        Called after weight loading is complete for any cleanup that
+        needs to occur.
+        """
+        raise NotImplementedError
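Every scheme must now implement the hook, even when there is nothing to do (as the unquantized, W4A16 2:4, and WNA16 schemes below show). A minimal skeleton of the contract, using a hypothetical scheme name:

import torch
from torch.nn import Module


class MyPassthroughScheme:  # would subclass CompressedTensorsScheme in vLLM
    def create_weights(self, layer: Module, **kwargs):
        ...

    def apply_weights(self, layer: Module, x: torch.Tensor):
        ...

    def process_weights_after_loading(self, layer: Module) -> None:
        # No post-load fixup needed for this scheme.
        pass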

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py

+3
@@ -18,6 +18,9 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme):
     in a linear transformation.
     """
 
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        pass
+
     def create_weights(self, layer: torch.nn.Module,
                        output_partition_sizes: List[int],
                        input_size_per_partition: int,

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py

+3
@@ -29,6 +29,9 @@ def __init__(self,
             raise ValueError(
                 "group_size must be given when using strategy group")
 
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        pass
+
     def create_weights(self, layer: torch.nn.Module, input_size: int,
                        output_partition_sizes: List[int],
                        input_size_per_partition: int,

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8.py

+42-49
@@ -15,70 +15,63 @@ class CompressedTensorsW8A8(CompressedTensorsScheme):
     def __init__(self, strategy: str):
         self.strategy = strategy
 
-    def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
-        if isinstance(shard_id, int):
-            return shard_id
-
-        assert isinstance(shard_id, str)
-        qkv_idxs = {"q": 0, "k": 1, "v": 2}
-        assert shard_id in qkv_idxs
-        return qkv_idxs[shard_id]
-
-    def scales_shard_splitter(
-            self, param: torch.Tensor, loaded_weight: torch.Tensor,
-            shard_id: Union[str, int],
-            logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        shard_id = self._shard_id_as_int(shard_id)
-        offset = sum(logical_widths[:shard_id])
-        size = logical_widths[shard_id]
-        # update loaded weight with copies for broadcast.
-        loaded_weight = loaded_weight.repeat(size)
-        return param[offset:offset + size], loaded_weight
+    # Cutlass kernels support only per-tensor and per-channel cases.
+    # So if we have a fused module (QKV, MLP) with per tensor scales (thus N
+    # scales being passed to the kernel), we convert to the per-channel case.
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if (self.strategy == QuantizationStrategy.TENSOR
+                and len(self.logical_widths) > 1):
+
+            # Load the N per-tensor scales into the channelwise buffer.
+            weight_scale_channel = torch.empty(
+                (sum(self.logical_widths), 1),
+                dtype=torch.float32,
+                device=layer.weight_scale.device)
+            start = 0
+            for idx, logical_width in enumerate(self.logical_widths):
+                end = start + logical_width
+                weight_scale_channel[start:end, :] = layer.weight_scale[idx]
+                start = end
+
+            layer.weight_scale = Parameter(weight_scale_channel,
+                                           requires_grad=False)
 
     def create_weights(self, layer: torch.nn.Module,
                        output_partition_sizes: List[int],
                        input_size_per_partition: int,
                        params_dtype: torch.dtype, weight_loader: Callable,
                        **kwargs):
+        self.logical_widths = output_partition_sizes
 
-        is_tensor_partitioned = len(output_partition_sizes) != 1
-        weight_scale_dim = sum(output_partition_sizes) if (
-            is_tensor_partitioned
-            or self.strategy == QuantizationStrategy.CHANNEL) else 1
-
-        shape: Union[Tuple[int], Tuple[int, int]] = (weight_scale_dim, )
+        # WEIGHT SCALE
+        shape: Union[Tuple[int], Tuple[int, int]]
         if self.strategy == QuantizationStrategy.CHANNEL:
-            shape = (weight_scale_dim, 1)
+            shape = (sum(self.logical_widths), 1)
+        else:
+            shape = (len(self.logical_widths), )
 
         weight_scale = Parameter(torch.empty(*shape, dtype=torch.float32),
                                  requires_grad=False)
-
         layer.register_parameter("weight_scale", weight_scale)
-        set_weight_attrs(weight_scale, {"weight_loader": weight_loader})
+        if self.strategy == QuantizationStrategy.CHANNEL:
+            set_weight_attrs(weight_scale, {
+                "weight_loader": weight_loader,
+                "output_dim": 0,
+            })
+        else:
+            set_weight_attrs(weight_scale, {
+                "weight_loader": weight_loader,
+                "needs_scalar_to_array": True,
+            })
 
+        # WEIGHT
         weight = Parameter(torch.empty(sum(output_partition_sizes),
                                        input_size_per_partition,
                                        dtype=torch.int8),
                            requires_grad=False)
-
         layer.register_parameter("weight", weight)
-        set_weight_attrs(
-            weight, {
-                "input_dim": 1,
-                "output_dim": 0,
-                "weight_loader": weight_loader,
-                "logical_widths": output_partition_sizes
-            })
-
-        # Don't need a shard_splitter for channel-wise quantization
-        # Use the default loading method
-        if self.strategy == QuantizationStrategy.CHANNEL:
-            set_weight_attrs(weight_scale, {
-                "output_dim": 0,
-            })
-        else:
-            set_weight_attrs(
-                weight_scale, {
-                    "logical_widths": output_partition_sizes,
-                    "shard_splitter": self.scales_shard_splitter,
-                })
+        set_weight_attrs(weight, {
+            "input_dim": 1,
+            "output_dim": 0,
+            "weight_loader": weight_loader,
+        })
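The comment above captures the motivation: the CUTLASS w8a8 kernels accept either one scale per tensor or one per output channel, so a fused module carrying N per-tensor scales is rewritten into a single per-channel column. A small worked example with made-up widths and scales; torch.repeat_interleave gives the same result as the explicit copy loop in the diff:

import torch

# Worked example (made-up widths and scales) of the per-tensor -> per-channel
# conversion: each logical matrix's scalar scale is repeated across its output
# channels, giving a (sum(widths), 1) column for the kernel.
logical_widths = torch.tensor([4, 2, 2])
per_tensor_scales = torch.tensor([0.10, 0.20, 0.30])

channelwise = torch.repeat_interleave(per_tensor_scales,
                                      logical_widths).unsqueeze(-1)
print(channelwise.shape)       # torch.Size([8, 1])
print(channelwise.squeeze(1))  # 0.10 x4, 0.20 x2, 0.30 x2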

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py

+3
@@ -29,6 +29,9 @@ def __init__(self,
             raise ValueError(
                 "group_size must be given when using strategy group")
 
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        pass
+
     def create_weights(self, layer: torch.nn.Module, input_size: int,
                        output_partition_sizes: List[int],
                        input_size_per_partition: int,
