Commit

[Misc] Revert compressed-tensors code reuse (#7521)
kylesayrs authored Aug 14, 2024
1 parent 951fdd6 commit f55a9ae
Showing 8 changed files with 85 additions and 13 deletions.
1 change: 0 additions & 1 deletion requirements-common.txt
@@ -24,4 +24,3 @@ librosa # Required for audio processing
 soundfile # Required for audio processing
 gguf == 0.9.1
 importlib_metadata
-compressed-tensors == 0.5.0
2 changes: 1 addition & 1 deletion requirements-test.txt
@@ -17,7 +17,7 @@ peft
 requests
 ray
 sentence-transformers # required for embedding
-compressed-tensors==0.5.0 # required for compressed-tensors
+compressed-tensors==0.4.0 # required for compressed-tensors
 timm # required for internvl test
 
 # TODO: Add this after fully implementing llava(mantis)
3 changes: 2 additions & 1 deletion tests/quantization/test_compressed_tensors.py
@@ -5,12 +5,13 @@
 
 import pytest
 import torch
-from compressed_tensors.quantization import QuantizationType
 
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
     CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24,
     CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
     CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
+from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
+    QuantizationType)
 
 
 @pytest.mark.parametrize("model_args", [
7 changes: 2 additions & 5 deletions
@@ -1,10 +1,6 @@
 from typing import Any, Dict, List, Optional
 
 import torch
-from compressed_tensors.config import CompressionFormat
-from compressed_tensors.quantization import (QuantizationArgs,
-                                             QuantizationStrategy,
-                                             QuantizationType)
 from pydantic import BaseModel
 
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
@@ -17,7 +13,8 @@
     CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
     CompressedTensorsWNA16)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
-    find_matched_target, is_activation_quantization_format,
+    CompressionFormat, QuantizationArgs, QuantizationStrategy,
+    QuantizationType, find_matched_target, is_activation_quantization_format,
     should_ignore_layer)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.platforms import current_platform
3 changes: 2 additions & 1 deletion
@@ -1,10 +1,11 @@
 from typing import Callable, List, Optional
 
 import torch
-from compressed_tensors.quantization import QuantizationStrategy
 
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
+from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
+    QuantizationStrategy)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
     apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
3 changes: 2 additions & 1 deletion
@@ -1,11 +1,12 @@
 from typing import Callable, List, Optional
 
 import torch
-from compressed_tensors.quantization import QuantizationStrategy
 from torch.nn import Parameter
 
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
+from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
+    QuantizationStrategy)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     apply_fp8_linear, cutlass_fp8_supported, requantize_with_max_scale)
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
3 changes: 2 additions & 1 deletion
@@ -1,11 +1,12 @@
 from typing import Callable, List, Optional
 
 import torch
-from compressed_tensors.quantization import QuantizationStrategy
 from torch.nn import Parameter
 
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
+from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
+    QuantizationStrategy)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     apply_int8_linear, convert_to_channelwise)
 from vllm.model_executor.parameter import (BasevLLMParameter,
76 changes: 74 additions & 2 deletions
@@ -1,13 +1,85 @@
 import re
-from typing import Iterable, Optional
+from enum import Enum
+from typing import Any, Dict, Iterable, Optional
 
-from compressed_tensors import CompressionFormat
+from pydantic import BaseModel, Field
 from torch.nn import Module
 
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     FUSED_LAYER_NAME_MAPPING)
 
 
+class CompressionFormat(Enum):
+    dense = "dense"
+    sparse_bitmask = "sparse-bitmask"
+    naive_quantized = "naive-quantized"
+    float_quantized = "float-quantized"
+    int_quantized = "int-quantized"
+    pack_quantized = "pack-quantized"
+    marlin_24 = "marlin-24"
+
+
+class QuantizationType(str, Enum):
+    """
+    Enum storing quantization type options
+    """
+
+    INT = "int"
+    FLOAT = "float"
+
+
+class QuantizationStrategy(str, Enum):
+    """
+    Enum storing quantization strategy options
+    """
+
+    TENSOR = "tensor"
+    CHANNEL = "channel"
+    GROUP = "group"
+    BLOCK = "block"
+    TOKEN = "token"
+
+
+class QuantizationArgs(BaseModel):
+    """
+    User facing arguments used to define a quantization config
+    for weights or activations
+    :param num_bits: quantization bit depth
+    :param type: dtype to quantized to, either int or float
+    :param symmetric: whether or not quantization scale is symmetric
+    :param strategy: string determining the scope of scale/zero-point to apply
+    :param group_size: group length to use for the group strategy
+    :param block_structure: 2d block structure to use for the block
+        strategy, must be of the format "2x4", "8x16", etc.
+    :param dynamic: set True to perform dynamic quantization -
+        values will not be calibrated during calibration phase,
+        instead during inference new quantization ranges will be
+        observed with every sample. Defaults to False for static
+        quantization. Note that enabling dynamic quantization
+        will change the default observer to a memoryless one
+    """
+
+    num_bits: int = 8
+    type: QuantizationType = QuantizationType.INT
+    symmetric: bool = True
+    group_size: Optional[int] = None
+    strategy: Optional[QuantizationStrategy] = None
+    block_structure: Optional[str] = None
+    dynamic: bool = False
+    observer: str = Field(
+        default="minmax",
+        description=("The class to use to compute the quantization param - "
+                     "scale and zero-point'"),
+    )
+    observer_kwargs: Dict[str, Any] = Field(
+        default_factory=dict,
+        description=
+        ("optional dict of kwargs to be passed directly to torch quantization "
+         "Observers constructor excluding quantization range or symmetry"),
+    )
+
+
 def is_activation_quantization_format(format: str) -> bool:
     _ACTIVATION_QUANTIZATION_FORMATS = [
         CompressionFormat.naive_quantized.value,
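
For reference, below is a minimal usage sketch of the config classes vendored back into vllm.model_executor.layers.quantization.compressed_tensors.utils by this commit. It is not part of the commit itself; it assumes a vLLM checkout that includes this change, and the argument values are illustrative only.

from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
    CompressionFormat, QuantizationArgs, QuantizationStrategy,
    QuantizationType, is_activation_quantization_format)

# Defaults mirror the pydantic model above: 8-bit, symmetric, static INT
# quantization with the "minmax" observer.
default_args = QuantizationArgs()
assert default_args.num_bits == 8
assert default_args.type == QuantizationType.INT

# An illustrative per-channel float (FP8-style) weight config.
fp8_weight_args = QuantizationArgs(type=QuantizationType.FLOAT,
                                   strategy=QuantizationStrategy.CHANNEL)

# "naive-quantized" appears in _ACTIVATION_QUANTIZATION_FORMATS above,
# so this is expected to print True.
print(is_activation_quantization_format(CompressionFormat.naive_quantized.value))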
