quants: update compressed tensors lifecycle to remove prefix from `create_weights` (#924)

AlpinDale authored Dec 18, 2024
1 parent 0c6d90d commit 5cb2e99
Showing 4 changed files with 17 additions and 74 deletions.
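In short: scheme selection moves out of `CompressedTensorsLinearMethod.create_weights` (which previously fished the layer name out of `extra_weight_attrs["prefix"]`) and into `CompressedTensorsConfig.get_quant_method`, which already receives the layer's `prefix`. Layers on the `ignore` list are skipped there as well, by returning the generic `UnquantizedLinearMethod` instead of a dedicated unquantized scheme. The sketch below illustrates the resulting order of operations; the class and method names mirror the diff, but the bodies are illustrative stubs, not the actual aphrodite code.

    # Illustrative stubs only -- the real classes live in
    # aphrodite/quantization/compressed_tensors/ and
    # aphrodite/modeling/layers/linear.py.
    from types import SimpleNamespace


    class DummyScheme:
        """Stand-in for a CompressedTensorsScheme (e.g. W8A8, WNA16)."""

        def create_weights(self, layer, **weight_kwargs):
            layer.weight = ("allocated", weight_kwargs)


    class UnquantizedLinearMethod:
        """What ignored layers now get; no scheme is ever attached to them."""


    class CompressedTensorsLinearMethod:
        def __init__(self, config):
            self.config = config

        def create_weights(self, layer, **weight_kwargs):
            # After this commit there is no prefix lookup here: the scheme
            # was already resolved and attached by get_quant_method() below.
            layer.scheme.create_weights(layer=layer, **weight_kwargs)


    class CompressedTensorsConfig:
        def __init__(self, ignore):
            self.ignore = ignore

        def get_quant_method(self, layer, prefix):
            # 1. Ignored layers short-circuit before any scheme is chosen
            #    (stand-in for should_ignore_layer(prefix, ignore=self.ignore)).
            if prefix in self.ignore:
                return UnquantizedLinearMethod()
            # 2. Resolve the scheme once, up front, and hang it on the layer.
            layer.scheme = DummyScheme()  # really: self.get_scheme(layer, prefix)
            return CompressedTensorsLinearMethod(self)


    # Order of operations for one linear layer:
    config = CompressedTensorsConfig(ignore=["lm_head"])
    layer = SimpleNamespace()
    method = config.get_quant_method(layer, prefix="model.layers.0.mlp.down_proj")
    method.create_weights(layer, output_size=4096, params_dtype="float16")

    assert isinstance(config.get_quant_method(SimpleNamespace(), "lm_head"),
                      UnquantizedLinearMethod)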
8 changes: 3 additions & 5 deletions aphrodite/modeling/layers/linear.py
@@ -212,7 +212,7 @@ def __init__(self,
             self.input_size,
             self.output_size,
             self.params_dtype,
-            prefix=prefix)
+            weight_loader=self.weight_loader)
 
         if bias:
             self.bias = Parameter(
@@ -318,8 +318,7 @@ def __init__(self,
             params_dtype=self.params_dtype,
             weight_loader=(
                 self.weight_loader_v2 if self.quant_method.__class__.__name__
-                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader),
-            prefix=prefix)
+                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
         if bias:
             self.bias = Parameter(
                 torch.empty(self.output_size_per_partition,
@@ -1035,8 +1034,7 @@ def __init__(self,
             params_dtype=self.params_dtype,
             weight_loader=(
                 self.weight_loader_v2 if self.quant_method.__class__.__name__
-                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader),
-            prefix=prefix)
+                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
         if not reduce_results and (bias and not skip_bias_add):
             raise ValueError("When not reduce the results, adding bias to the "
                              "results can lead to incorrect results")
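All three call sites stop passing `prefix` into `quant_method.create_weights`; the first hunk swaps it for the layer's own `weight_loader`, matching the other two call sites, which already supplied one and merely drop the extra keyword.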
32 changes: 14 additions & 18 deletions aphrodite/quantization/compressed_tensors/compressed_tensors.py
@@ -3,16 +3,16 @@
 import torch
 from pydantic import BaseModel
 
-from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
+from aphrodite.modeling.layers.linear import (LinearBase, LinearMethodBase,
+                                              UnquantizedLinearMethod)
 from aphrodite.platforms import current_platform
 from aphrodite.quantization.base_config import (  # noqa: E501
     QuantizationConfig, QuantizeMethodBase)
 from aphrodite.quantization.compressed_tensors.schemes import (
     W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS,
-    CompressedTensorsScheme, CompressedTensorsUnquantized,
-    CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
-    CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
-    CompressedTensorsWNA16)
+    CompressedTensorsScheme, CompressedTensorsW4A16Sparse24,
+    CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
+    CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
 from aphrodite.quantization.compressed_tensors.utils import (
     CompressionFormat, QuantizationArgs, QuantizationStrategy,
     QuantizationType, find_matched_target, is_activation_quantization_format,
@@ -52,16 +52,21 @@ def get_min_capability(cls) -> int:
     def get_name(self) -> str:
         return "compressed_tensors"
 
+    # TODO: do layer skipping through here
+    # rather than through create_weights to match other methods
     def get_quant_method(
         self,
         layer: torch.nn.Module,
         prefix: str,
     ) -> Optional["QuantizeMethodBase"]:
         from aphrodite.attention.layer import (
             Attention)  # Avoid circular import
 
+        # Check if the layer is skipped for quantization.
+        # TODO: support module names
+        if should_ignore_layer(prefix, ignore=self.ignore):
+            return UnquantizedLinearMethod()
         if isinstance(layer, LinearBase):
+            scheme = self.get_scheme(layer=layer, layer_name=prefix)
+            layer.scheme = scheme
             return CompressedTensorsLinearMethod(self)
         if isinstance(layer, Attention):
             return CompressedTensorsKVCacheMethod(self)
@@ -283,15 +288,11 @@ def get_scheme(
         to select the CompressedTensorsScheme used for inference.
         """
 
-        # Check if the layer is skipped for quantization.
-        # TODO: support module names
-        if should_ignore_layer(layer_name, ignore=self.ignore):
-            return CompressedTensorsUnquantized()
-
         # Find the "target" in the compressed-tensors config
         # that our layer conforms to.
         # TODO: add compressed-tensors as dep
         # so we do not have to re-write these functions
         # need to make accelerate optional in ct to do this
         matched_target = find_matched_target(
             layer_name=layer_name,
             module=layer,
@@ -329,10 +330,7 @@ def create_weights(self, layer: torch.nn.Module,
         details
         """
         weight_loader = extra_weight_attrs.get("weight_loader")
-        layer_name = extra_weight_attrs.get("prefix")
-
-        scheme = self.quantization_config.get_scheme(layer, layer_name)
-        scheme.create_weights(
+        layer.scheme.create_weights(
             layer=layer,
             input_size=input_size,
             input_size_per_partition=input_size_per_partition,
@@ -341,8 +339,6 @@
             params_dtype=params_dtype,
             weight_loader=weight_loader)
 
-        layer.scheme = scheme
-
     def apply(self,
               layer: torch.nn.Module,
               x: torch.Tensor,
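Since layer skipping now hinges entirely on `should_ignore_layer(prefix, ignore=...)`, it helps to recall how compressed-tensors ignore lists are typically matched: plain entries compare against the module name, and `re:`-prefixed entries are treated as regular expressions. The helper below is a toy re-implementation for illustration only, not aphrodite's actual `should_ignore_layer` (which also has to deal with fused layers such as `qkv_proj`).

    import re
    from typing import Iterable


    def matches_ignore(layer_name: str, ignore: Iterable[str]) -> bool:
        """Toy version of compressed-tensors ignore matching: exact module
        names match directly; entries starting with "re:" are regexes."""
        for target in ignore:
            if target.startswith("re:"):
                if re.match(target[3:], layer_name):
                    return True
            elif layer_name == target:
                return True
        return False


    # An ignore list like the ones that appear in ct quantization configs:
    ignore = ["lm_head", "re:.*gate$"]
    assert matches_ignore("lm_head", ignore)
    assert matches_ignore("model.layers.3.mlp.gate", ignore)
    assert not matches_ignore("model.layers.3.mlp.down_proj", ignore)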
2 changes: 0 additions & 2 deletions aphrodite/quantization/compressed_tensors/schemes/__init__.py
@@ -1,5 +1,4 @@
 from .compressed_tensors_scheme import CompressedTensorsScheme
-from .compressed_tensors_unquantized import CompressedTensorsUnquantized
 from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS,
                                           CompressedTensorsW4A16Sparse24)
 from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
@@ -10,7 +9,6 @@
 
 __all__ = [
     "CompressedTensorsScheme",
-    "CompressedTensorsUnquantized",
     "CompressedTensorsWNA16",
     "CompressedTensorsW8A16Fp8",
     "CompressedTensorsW4A16Sparse24",

49 changes: 0 additions & 49 deletions aphrodite/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py

This file was deleted.

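The deleted scheme is the counterpart of the `get_quant_method` change above: once ignored layers are diverted to `UnquantizedLinearMethod` before a scheme is ever selected, `CompressedTensorsUnquantized` has no remaining callers, so the scheme class, its entry in `__all__`, and its module can all go, accounting for the bulk of the 74 deleted lines.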