From afc9a28aa0e2ccad56b0262660b4d01b86758f26 Mon Sep 17 00:00:00 2001
From: AlpinDale <52078762+AlpinDale@users.noreply.github.com>
Date: Sun, 15 Dec 2024 18:51:58 -0800
Subject: [PATCH] chore: add AphroditeParameter support for FP8 quant (#902)

---
 aphrodite/modeling/layers/linear.py | 15 +++++++++++
 aphrodite/modeling/parameter.py     |  9 ++++++-
 aphrodite/quantization/fp8.py       | 41 ++++++++++++++++++-----------
 3 files changed, 49 insertions(+), 16 deletions(-)

diff --git a/aphrodite/modeling/layers/linear.py b/aphrodite/modeling/layers/linear.py
index 26cafecc9..9fd19d7ad 100644
--- a/aphrodite/modeling/layers/linear.py
+++ b/aphrodite/modeling/layers/linear.py
@@ -26,6 +26,7 @@
 WEIGHT_LOADER_V2_SUPPORTED = [
     "CompressedTensorsLinearMethod", "GPTQMarlinLinearMethod",
     "AWQMarlinLinearMethod", "AWQLinearMethod", "HQQMarlinMethod",
+    "Fp8LinearMethod",
 ]
@@ -359,6 +360,12 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         param_data.copy_(loaded_weight)
 
     def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
+        # Special case for loading scales off disk, which often do not
+        # have a shape (such as in the case of AutoFP8).
+        if len(loaded_weight.shape) == 0:
+            assert loaded_weight.numel() == 1
+            loaded_weight = loaded_weight.reshape(1)
+
         param.load_column_parallel_weight(loaded_weight=loaded_weight)
 
     def forward(self, input_):
@@ -1081,8 +1088,16 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
 
+    def weight_loader_v2(self, param: BaseAphroditeParameter,
+                         loaded_weight: torch.Tensor):
+
+        # Special case for loading scales off disk, which often do not
+        # have a shape (such as in the case of AutoFP8).
+        if len(loaded_weight.shape) == 0:
+            assert loaded_weight.numel() == 1
+            loaded_weight = loaded_weight.reshape(1)
+
+        param.load_row_parallel_weight(loaded_weight=loaded_weight)
 
     def forward(self, input_):
diff --git a/aphrodite/modeling/parameter.py b/aphrodite/modeling/parameter.py
index 271958b79..77492736c 100644
--- a/aphrodite/modeling/parameter.py
+++ b/aphrodite/modeling/parameter.py
@@ -208,10 +208,17 @@ def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
         if isinstance(shard_id, int):
             return shard_id
 
+        # if not int, assume shard_id for qkv
+        # map to int and return
         assert isinstance(shard_id, str)
         assert shard_id in self.qkv_idxs
         return self.qkv_idxs[shard_id]
 
+    # For row parallel layers, no sharding needed
+    # load weight into parameter as is
+    def load_row_parallel_weight(self, *args, **kwargs):
+        super().load_row_parallel_weight(*args, **kwargs)
+
     def load_merged_column_weight(self, *args, **kwargs):
         self._load_into_shard_id(*args, **kwargs)
 
@@ -219,7 +226,7 @@ def load_qkv_weight(self, *args, **kwargs):
         self._load_into_shard_id(*args, **kwargs)
 
     def load_column_parallel_weight(self, *args, **kwargs):
-        self._load_into_shard_id(*args, **kwargs)
+        super().load_row_parallel_weight(*args, **kwargs)
 
     def _load_into_shard_id(self, loaded_weight: torch.Tensor,
                             shard_id: Union[str, int], **kwargs):
diff --git a/aphrodite/quantization/fp8.py b/aphrodite/quantization/fp8.py
index ddc392f01..bee8ae9ad 100644
--- a/aphrodite/quantization/fp8.py
+++ b/aphrodite/quantization/fp8.py
@@ -11,6 +11,8 @@
 from aphrodite.modeling.layers.fused_moe import FusedMoE, FusedMoEMethodBase
 from aphrodite.modeling.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod)
+from aphrodite.modeling.parameter import (ModelWeightParameter,
+                                          PerTensorScaleParameter)
 from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.platforms import current_platform
 from aphrodite.quantization.base_config import (QuantizationConfig,
@@ -21,8 +23,7 @@
 from aphrodite.quantization.utils.quant_utils import is_layer_skipped
 from aphrodite.quantization.utils.w8a8_utils import (
     all_close_1d, apply_fp8_linear, convert_to_channelwise,
-    create_per_tensor_scale_param, cutlass_fp8_supported,
-    normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
+    cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
     requantize_with_max_scale)
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]
@@ -136,6 +137,7 @@ def create_weights(
     ):
         del input_size, output_size
         output_size_per_partition = sum(output_partition_sizes)
+        weight_loader = extra_weight_attrs.get("weight_loader")
 
         layer.logical_widths = output_partition_sizes
@@ -147,34 +149,38 @@
         weight_dtype = (torch.float8_e4m3fn
                         if self.quant_config.is_checkpoint_fp8_serialized else
                         params_dtype)
 
-        weight = Parameter(torch.empty(output_size_per_partition,
-                                       input_size_per_partition,
-                                       dtype=weight_dtype),
-                           requires_grad=False)
+        weight = ModelWeightParameter(data=torch.empty(
+            output_size_per_partition,
+            input_size_per_partition,
+            dtype=weight_dtype),
+                                      input_dim=1,
+                                      output_dim=0,
+                                      weight_loader=weight_loader)
         layer.register_parameter("weight", weight)
-        set_weight_attrs(weight, {
-            **extra_weight_attrs,
-            "input_dim": 1,
-            "output_dim": 0,
-        })
 
         # If checkpoint is serialized fp8, load them.
         # Otherwise, wait until process_weights_after_loading.
         if self.quant_config.is_checkpoint_fp8_serialized:
             # WEIGHT SCALE
-            scale = create_per_tensor_scale_param(output_partition_sizes,
-                                                  **extra_weight_attrs)
+            scale = PerTensorScaleParameter(data=torch.empty(
+                len(output_partition_sizes), dtype=torch.float32),
+                                            weight_loader=weight_loader)
+            scale[:] = torch.finfo(torch.float32).min
             layer.register_parameter("weight_scale", scale)
 
             # INPUT ACTIVATION SCALE
             if self.quant_config.activation_scheme == "static":
-                scale = create_per_tensor_scale_param(output_partition_sizes,
-                                                      **extra_weight_attrs)
+                scale = PerTensorScaleParameter(data=torch.empty(
+                    len(output_partition_sizes), dtype=torch.float32),
+                                                weight_loader=weight_loader)
+                scale[:] = torch.finfo(torch.float32).min
                 layer.register_parameter("input_scale", scale)
             else:
                 layer.register_parameter("input_scale", None)
 
     def process_weights_after_loading(self, layer: Module) -> None:
+        layer.weight = torch.nn.Parameter(layer.weight.data,
+                                          requires_grad=False)
         # If checkpoint not serialized fp8, quantize the weights.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
@@ -196,6 +202,11 @@ def process_weights_after_loading(self, layer: Module) -> None:
             # If checkpoint is fp8, handle that there are N scales for N
             # shards in a fused module
             else:
+                layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
+                                                        requires_grad=False)
+                if self.quant_config.activation_scheme == "static":
+                    layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
+                                                           requires_grad=False)
                 # If using marlin (w8a16), kernel uses channelwise weights,
                 # so extend the weight scales to be channelwise.
                 if self.use_marlin:
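
Note (illustrative, not part of the patch): the weight_loader_v2 special case above exists because
AutoFP8-style checkpoints often serialize per-tensor scales as 0-dim tensors, while the parameter
loaders expect at least a 1-element vector. The sketch below reproduces that normalization with plain
PyTorch only; the helper name normalize_scalar_scale is hypothetical and not part of Aphrodite.

# Standalone sketch of the 0-dim scale handling added to weight_loader_v2 above.
# Hypothetical helper name; uses only public PyTorch APIs.
import torch


def normalize_scalar_scale(loaded_weight: torch.Tensor) -> torch.Tensor:
    # Scales loaded off disk may have no shape (e.g. AutoFP8 checkpoints);
    # reshape the single element to a 1-D tensor so a per-tensor scale
    # parameter can be indexed like a vector downstream.
    if len(loaded_weight.shape) == 0:
        assert loaded_weight.numel() == 1
        loaded_weight = loaded_weight.reshape(1)
    return loaded_weight


if __name__ == "__main__":
    scalar_scale = torch.tensor(0.0123, dtype=torch.float32)  # 0-dim, as stored on disk
    print(normalize_scalar_scale(scalar_scale).shape)  # torch.Size([1])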