diff --git a/aphrodite/modeling/layers/linear.py b/aphrodite/modeling/layers/linear.py index bf4ee67a1..281c1a03f 100644 --- a/aphrodite/modeling/layers/linear.py +++ b/aphrodite/modeling/layers/linear.py @@ -212,7 +212,7 @@ def __init__(self, self.input_size, self.output_size, self.params_dtype, - prefix=prefix) + weight_loader=self.weight_loader) if bias: self.bias = Parameter( @@ -318,8 +318,7 @@ def __init__(self, params_dtype=self.params_dtype, weight_loader=( self.weight_loader_v2 if self.quant_method.__class__.__name__ - in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader), - prefix=prefix) + in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader)) if bias: self.bias = Parameter( torch.empty(self.output_size_per_partition, @@ -1035,8 +1034,7 @@ def __init__(self, params_dtype=self.params_dtype, weight_loader=( self.weight_loader_v2 if self.quant_method.__class__.__name__ - in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader), - prefix=prefix) + in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader)) if not reduce_results and (bias and not skip_bias_add): raise ValueError("When not reduce the results, adding bias to the " "results can lead to incorrect results") diff --git a/aphrodite/quantization/compressed_tensors/compressed_tensors.py b/aphrodite/quantization/compressed_tensors/compressed_tensors.py index 6239f062e..757c7712f 100644 --- a/aphrodite/quantization/compressed_tensors/compressed_tensors.py +++ b/aphrodite/quantization/compressed_tensors/compressed_tensors.py @@ -3,16 +3,16 @@ import torch from pydantic import BaseModel -from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase +from aphrodite.modeling.layers.linear import (LinearBase, LinearMethodBase, + UnquantizedLinearMethod) from aphrodite.platforms import current_platform from aphrodite.quantization.base_config import ( # noqa: E501 QuantizationConfig, QuantizeMethodBase) from aphrodite.quantization.compressed_tensors.schemes import ( W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, - CompressedTensorsScheme, CompressedTensorsUnquantized, - CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8, - CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8, - CompressedTensorsWNA16) + CompressedTensorsScheme, CompressedTensorsW4A16Sparse24, + CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, + CompressedTensorsW8A16Fp8, CompressedTensorsWNA16) from aphrodite.quantization.compressed_tensors.utils import ( CompressionFormat, QuantizationArgs, QuantizationStrategy, QuantizationType, find_matched_target, is_activation_quantization_format, @@ -52,8 +52,6 @@ def get_min_capability(cls) -> int: def get_name(self) -> str: return "compressed_tensors" - # TODO: do layer skipping though here - # rather than though create_weights to match other methods def get_quant_method( self, layer: torch.nn.Module, @@ -61,7 +59,14 @@ def get_quant_method( ) -> Optional["QuantizeMethodBase"]: from aphrodite.attention.layer import ( Attention) # Avoid circular import + + # Check if the layer is skipped for quantization. + # TODO: support module names + if should_ignore_layer(prefix, ignore=self.ignore): + return UnquantizedLinearMethod() if isinstance(layer, LinearBase): + scheme = self.get_scheme(layer=layer, layer_name=prefix) + layer.scheme = scheme return CompressedTensorsLinearMethod(self) if isinstance(layer, Attention): return CompressedTensorsKVCacheMethod(self) @@ -283,15 +288,11 @@ def get_scheme( to select the CompressedTensorsScheme used for infernece. """ - # Check if the layer is skipped for quantization. 
- # TODO: support module names - if should_ignore_layer(layer_name, ignore=self.ignore): - return CompressedTensorsUnquantized() - # Find the "target" in the compressed-tensors config # that our layer conforms to. # TODO: add compressed-tensors as dep # so we do not have to re-write these functions + # need to make accelerate optional in ct to do this matched_target = find_matched_target( layer_name=layer_name, module=layer, @@ -329,10 +330,7 @@ def create_weights(self, layer: torch.nn.Module, details """ weight_loader = extra_weight_attrs.get("weight_loader") - layer_name = extra_weight_attrs.get("prefix") - - scheme = self.quantization_config.get_scheme(layer, layer_name) - scheme.create_weights( + layer.scheme.create_weights( layer=layer, input_size=input_size, input_size_per_partition=input_size_per_partition, @@ -341,8 +339,6 @@ def create_weights(self, layer: torch.nn.Module, params_dtype=params_dtype, weight_loader=weight_loader) - layer.scheme = scheme - def apply(self, layer: torch.nn.Module, x: torch.Tensor, diff --git a/aphrodite/quantization/compressed_tensors/schemes/__init__.py b/aphrodite/quantization/compressed_tensors/schemes/__init__.py index ca9e286ce..5d259ec72 100644 --- a/aphrodite/quantization/compressed_tensors/schemes/__init__.py +++ b/aphrodite/quantization/compressed_tensors/schemes/__init__.py @@ -1,5 +1,4 @@ from .compressed_tensors_scheme import CompressedTensorsScheme -from .compressed_tensors_unquantized import CompressedTensorsUnquantized from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS, CompressedTensorsW4A16Sparse24) from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8 @@ -10,7 +9,6 @@ __all__ = [ "CompressedTensorsScheme", - "CompressedTensorsUnquantized", "CompressedTensorsWNA16", "CompressedTensorsW8A16Fp8", "CompressedTensorsW4A16Sparse24", diff --git a/aphrodite/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py b/aphrodite/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py deleted file mode 100644 index 941e1df20..000000000 --- a/aphrodite/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py +++ /dev/null @@ -1,49 +0,0 @@ -from typing import Callable, List, Optional - -import torch -import torch.nn.functional as F - -from aphrodite.modeling.parameter import ModelWeightParameter -from aphrodite.quantization.compressed_tensors.schemes import ( - CompressedTensorsScheme) - -__all__ = ["CompressedTensorsUnquantized"] - - -class CompressedTensorsUnquantized(CompressedTensorsScheme): - """ - Implements the scheme for all layers which are ignored - in the CompressedTensors config. The input and loaded weight are used - in a linear transformation. 
- """ - - @classmethod - def get_min_capability(cls) -> int: - # volta and up - return 70 - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # required by torch.compile to be torch.nn.Parameter - layer.weight = torch.nn.Parameter(layer.weight.data, - requires_grad=False) - - def create_weights(self, layer: torch.nn.Module, - output_partition_sizes: List[int], - input_size_per_partition: int, - params_dtype: torch.dtype, weight_loader: Callable, - **kwargs): - - weight = ModelWeightParameter(data=torch.empty( - sum(output_partition_sizes), - input_size_per_partition, - dtype=params_dtype), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) - - layer.register_parameter("weight", weight) - - def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: - - return F.linear(x, layer.weight, bias)