chore: add AphroditeParameter support for FP8 quant (#902)
AlpinDale authored Dec 16, 2024
1 parent 2a60b8f commit afc9a28
Showing 3 changed files with 49 additions and 16 deletions.
15 changes: 15 additions & 0 deletions aphrodite/modeling/layers/linear.py
@@ -26,6 +26,7 @@
WEIGHT_LOADER_V2_SUPPORTED = [
"CompressedTensorsLinearMethod", "GPTQMarlinLinearMethod",
"AWQMarlinLinearMethod", "AWQLinearMethod", "HQQMarlinMethod",
"Fp8LinearMethod",
]


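Adding "Fp8LinearMethod" to this list opts the FP8 linear method into the newer AphroditeParameter-based loading path. As a rough, hedged sketch of what membership in the list implies (the actual dispatch lives elsewhere in linear.py and may differ):

# Illustrative sketch only: a linear layer would route weight loading through
# the v2 (AphroditeParameter-based) loader when its quant method's class name
# appears in WEIGHT_LOADER_V2_SUPPORTED, and fall back to the legacy loader
# otherwise. The helper below is hypothetical, not code from this commit.
WEIGHT_LOADER_V2_SUPPORTED = [
    "CompressedTensorsLinearMethod", "GPTQMarlinLinearMethod",
    "AWQMarlinLinearMethod", "AWQLinearMethod", "HQQMarlinMethod",
    "Fp8LinearMethod",
]

def uses_weight_loader_v2(quant_method) -> bool:
    return type(quant_method).__name__ in WEIGHT_LOADER_V2_SUPPORTED
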
@@ -359,6 +360,12 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
param_data.copy_(loaded_weight)

def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
# Special case for loading scales off disk, which often do not
# have a shape (such as in the case of AutoFP8).
if len(loaded_weight.shape) == 0:
assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1)

param.load_column_parallel_weight(loaded_weight=loaded_weight)

def forward(self, input_):
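The guard added above handles scales stored as 0-dim tensors: AutoFP8 checkpoints often serialize a per-tensor scale as a bare scalar, which has an empty shape but exactly one element. A standalone illustration of what the reshape does:

import torch

# A per-tensor scale serialized as a bare scalar loads as a 0-dim tensor.
scale = torch.tensor(0.0123, dtype=torch.float32)
assert scale.shape == torch.Size([])  # no dimensions
assert scale.numel() == 1             # but exactly one element

# The loader promotes it to a length-1 vector so the downstream shard logic
# can treat it like any other 1-D scale tensor.
scale = scale.reshape(1)
assert scale.shape == torch.Size([1])
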
@@ -1081,8 +1088,16 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)


def weight_loader_v2(self, param: BaseAphroditeParameter,
loaded_weight: torch.Tensor):

# Special case for loading scales off disk, which often do not
# have a shape (such as in the case of AutoFP8).
if len(loaded_weight.shape) == 0:
assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1)

param.load_row_parallel_weight(loaded_weight=loaded_weight)

def forward(self, input_):
9 changes: 8 additions & 1 deletion aphrodite/modeling/parameter.py
@@ -208,18 +208,25 @@ def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
if isinstance(shard_id, int):
return shard_id

# if not int, assume shard_id for qkv
# map to int and return
assert isinstance(shard_id, str)
assert shard_id in self.qkv_idxs
return self.qkv_idxs[shard_id]

# For row parallel layers, no sharding needed
# load weight into parameter as is
def load_row_parallel_weight(self, *args, **kwargs):
super().load_row_parallel_weight(*args, **kwargs)

def load_merged_column_weight(self, *args, **kwargs):
self._load_into_shard_id(*args, **kwargs)

def load_qkv_weight(self, *args, **kwargs):
self._load_into_shard_id(*args, **kwargs)

def load_column_parallel_weight(self, *args, **kwargs):
self._load_into_shard_id(*args, **kwargs)
super().load_row_parallel_weight(*args, **kwargs)

def _load_into_shard_id(self, loaded_weight: torch.Tensor,
shard_id: Union[str, int], **kwargs):
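The PerTensorScaleParameter changes above route every load_* entry point through _load_into_shard_id, so that a fused layer (for example QKV) can collect one scale per logical shard. A hedged sketch of the idea, with illustrative names:

import torch

# Hedged sketch of the per-tensor-scale idea: a fused layer keeps one float32
# scale slot per logical shard, and each checkpoint scale is written into the
# slot matching its shard id ("q"/"k"/"v" map to integer indices).
qkv_idxs = {"q": 0, "k": 1, "v": 2}
scales = torch.full((3,), torch.finfo(torch.float32).min)  # sentinel: unloaded

def load_scale(shard_id, loaded_scale: torch.Tensor) -> None:
    idx = qkv_idxs[shard_id] if isinstance(shard_id, str) else shard_id
    scales[idx] = loaded_scale

load_scale("q", torch.tensor(0.02))
load_scale("k", torch.tensor(0.03))
load_scale("v", torch.tensor(0.05))
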
41 changes: 26 additions & 15 deletions aphrodite/quantization/fp8.py
@@ -11,6 +11,8 @@
from aphrodite.modeling.layers.fused_moe import FusedMoE, FusedMoEMethodBase
from aphrodite.modeling.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from aphrodite.modeling.parameter import (ModelWeightParameter,
PerTensorScaleParameter)
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.platforms import current_platform
from aphrodite.quantization.base_config import (QuantizationConfig,
@@ -21,8 +23,7 @@
from aphrodite.quantization.utils.quant_utils import is_layer_skipped
from aphrodite.quantization.utils.w8a8_utils import (
all_close_1d, apply_fp8_linear, convert_to_channelwise,
create_per_tensor_scale_param, cutlass_fp8_supported,
normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
requantize_with_max_scale)

ACTIVATION_SCHEMES = ["static", "dynamic"]
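For context, the two entries in ACTIVATION_SCHEMES differ in where the activation scale comes from: "static" loads a per-tensor input_scale from the checkpoint, while "dynamic" derives it from the current activations at runtime. A hedged sketch of the dynamic case (this mirrors the idea, not the exact kernel behind apply_fp8_linear):

import torch

# Hedged sketch: a "dynamic" per-tensor activation scale derived from the
# batch's maximum magnitude, sized so the largest value maps to the FP8
# (e4m3fn) representable maximum.
def dynamic_activation_scale(x: torch.Tensor) -> torch.Tensor:
    finfo = torch.finfo(torch.float8_e4m3fn)
    amax = x.abs().max().clamp(min=1e-12)
    return (amax / finfo.max).to(torch.float32)
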
@@ -136,6 +137,7 @@ def create_weights(
):
del input_size, output_size
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")

layer.logical_widths = output_partition_sizes

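The new line above pulls the loader callback out of extra_weight_attrs so it can be passed directly to the AphroditeParameter constructors, instead of being attached afterwards with set_weight_attrs. A hedged sketch of that constructor pattern (the class below is an illustrative stand-in, not the real ModelWeightParameter):

import torch

# Illustrative stand-in for an AphroditeParameter-style class: shard metadata
# and the weight_loader callback travel through the constructor rather than
# being set on the parameter after the fact.
class IllustrativeModelWeightParameter(torch.nn.Parameter):
    def __new__(cls, data, input_dim, output_dim, weight_loader):
        param = super().__new__(cls, data, requires_grad=False)
        param.input_dim = input_dim
        param.output_dim = output_dim
        param.weight_loader = weight_loader
        return param
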
@@ -147,34 +149,38 @@
weight_dtype = (torch.float8_e4m3fn
if self.quant_config.is_checkpoint_fp8_serialized else
params_dtype)
weight = Parameter(torch.empty(output_size_per_partition,
input_size_per_partition,
dtype=weight_dtype),
requires_grad=False)
weight = ModelWeightParameter(data=torch.empty(
output_size_per_partition,
input_size_per_partition,
dtype=weight_dtype),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight", weight)
set_weight_attrs(weight, {
**extra_weight_attrs,
"input_dim": 1,
"output_dim": 0,
})

# If checkpoint is serialized fp8, load them.
# Otherwise, wait until process_weights_after_loading.
if self.quant_config.is_checkpoint_fp8_serialized:
# WEIGHT SCALE
scale = create_per_tensor_scale_param(output_partition_sizes,
**extra_weight_attrs)
scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale", scale)

# INPUT ACTIVATION SCALE
if self.quant_config.activation_scheme == "static":
scale = create_per_tensor_scale_param(output_partition_sizes,
**extra_weight_attrs)
scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("input_scale", scale)
else:
layer.register_parameter("input_scale", None)

def process_weights_after_loading(self, layer: Module) -> None:
layer.weight = torch.nn.Parameter(layer.weight.data,
requires_grad=False)
# If checkpoint not serialized fp8, quantize the weights.
if not self.quant_config.is_checkpoint_fp8_serialized:
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
@@ -196,6 +202,11 @@ def process_weights_after_loading(self, layer: Module) -> None:
# If checkpoint is fp8, handle that there are N scales for N
# shards in a fused module
else:
layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
requires_grad=False)
if self.quant_config.activation_scheme == "static":
layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
requires_grad=False)
# If using marlin (w8a16), kernel uses channelwise weights,
# so extend the weight scales to be channelwise.
if self.use_marlin:
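After loading, process_weights_after_loading re-wraps weight, weight_scale, and input_scale as plain torch.nn.Parameter objects: the loader-specific parameter subclasses are only needed while the checkpoint is being read, and the forward path just wants ordinary frozen tensors. A minimal sketch of that conversion:

import torch

# Minimal sketch of the post-loading conversion: replace a loader-specific
# parameter wrapper with a plain frozen nn.Parameter holding the same data.
def freeze_as_plain_parameter(module: torch.nn.Module, name: str) -> None:
    wrapped = getattr(module, name, None)
    if wrapped is not None:
        setattr(module, name,
                torch.nn.Parameter(wrapped.data, requires_grad=False))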
