
Commit

Merge branch 'main' into compressed-lifecycle
Sara Adkins committed May 3, 2024
2 parents df94b5e + 05c1487 commit bfc1136
Showing 4 changed files with 180 additions and 12 deletions.
91 changes: 84 additions & 7 deletions src/compressed_tensors/quantization/lifecycle/forward.py
@@ -13,10 +13,14 @@
# limitations under the License.

from functools import wraps
from math import ceil
from typing import Optional

import torch
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.quant_args import (
QuantizationArgs,
QuantizationStrategy,
)
from compressed_tensors.quantization.quant_config import QuantizationStatus
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
from torch.nn import Module
@@ -38,9 +42,7 @@ def quantize(
q_min = torch.tensor(-bit_range / 2, device=x.device)

quantized_value = torch.clamp(
torch.round(
x / scale + zero_point,
),
torch.round(x / scale + zero_point),
q_min,
q_max,
)
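
For intuition, a small worked example of the rounding-and-clamping step above; the input values and the 8-bit q_min/q_max bounds are illustrative assumptions, not taken from this diff:

import torch

x = torch.tensor([0.04, -20.0, 15.0])
scale, zero_point = 0.1, 0
q_min, q_max = -128, 127  # assuming a signed 8-bit range

# round onto the integer grid, then clamp into the representable range
q = torch.clamp(torch.round(x / scale + zero_point), q_min, q_max)
print(q)  # tensor([   0., -128.,  127.])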
@@ -67,8 +69,84 @@ def fake_quantize(
zero_point: torch.Tensor,
args: QuantizationArgs,
) -> torch.Tensor:
x_quant = quantize(x, scale, zero_point, args)
return dequantize(x_quant, scale, zero_point)
"""
Fake quantize the input tensor x depending on the group_size.
if group_size is greater than 0, then q/dq by groups. The groups
must be divisible by the column size
if group_size is -1, then channel wise q/dq. THe input scale and
zero_points are reshaped to support vectorization (Assumes 1 is
the channel dimension)
:param x: Input tensor
:param scale: scale tensor
:param zero_point: zero point tensor
:param args: quantization args that contain group_size info
:return: fake quantized tensor
"""
group_size = args.group_size

# group
if args.strategy == QuantizationStrategy.GROUP:

DQ = torch.zeros_like(x)

# TODO: vectorize the for loop
# TODO: fix the generic assumption about the tensor size when computing groups

# TODO: make validation step for inputs

while scale.ndim < 2:
# pad scale and zero point dims for slicing
scale = scale.unsqueeze(1)
zero_point = zero_point.unsqueeze(1)

columns = x.shape[1]
if columns >= group_size:
if columns % group_size != 0:
raise ValueError(
"tesnor column shape must be divisble "
f"by the given group_size {group_size}"
)
for i in range(ceil(columns / group_size)):
# scale.shape should be [nchan, ndim]
# sc.shape should be [nchan, 1] after unsqueeze

sc = scale[:, i].unsqueeze(1)
zp = zero_point[:, i].unsqueeze(1)

idx = i * group_size
Q = quantize(x[:, idx : (idx + group_size)], sc, zp, args)
DQ[:, idx : (idx + group_size)] = dequantize(Q, sc, zp)

# channel-wise
elif args.strategy == QuantizationStrategy.CHANNEL: # group_size == -1
# before: scale shape = [channel_size]
# after: scale shape = [1, channel_size]
scale = scale.unsqueeze(0)
zero_point = zero_point.unsqueeze(0)

Q = quantize(x, scale, zero_point, args)
DQ = dequantize(Q, scale, zero_point)

# per-token
elif args.strategy == QuantizationStrategy.TOKEN:
# before: scale shape = [num_tokens]
# after: scale shape = [num_tokens, 1]
# x.shape = [1, num_tokens, 1]
# scale gets broadcast as expected without needing a [1, num_tokens, 1] shape

scale = scale.unsqueeze(1)
zero_point = zero_point.unsqueeze(1)

Q = quantize(x, scale, zero_point, args)
DQ = dequantize(Q, scale, zero_point)

else:
Q = quantize(x, scale, zero_point, args)
DQ = dequantize(Q, scale, zero_point)

return DQ


def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
@@ -145,5 +223,4 @@ def maybe_calibrate_or_quantize(
device = next(module.parameters()).device
scale.data = updated_scale.to(device)
zero_point.data = updated_zero_point.to(device)

return fake_quantize(value, scale, zero_point, args)
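
A minimal usage sketch of the new group-wise path in fake_quantize. The import path is assumed from the file location, and the tensor shapes, values, and the per-(row, group) scale layout are hypothetical illustrations of the loop above:

import torch
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.lifecycle.forward import fake_quantize  # assumed module path

x = torch.randn(4, 8)                               # 4 rows, 8 columns
args = QuantizationArgs(num_bits=8, group_size=4)   # strategy is inferred as "group"
scale = torch.full((4, 2), 0.05)                    # one scale per (row, group): 8 cols / group_size 4 = 2 groups
zero_point = torch.zeros(4, 2)

x_dq = fake_quantize(x, scale, zero_point, args)
assert x_dq.shape == x.shape                        # q/dq preserves the input shape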
67 changes: 64 additions & 3 deletions src/compressed_tensors/quantization/observers/base.py
@@ -14,7 +14,11 @@

from typing import Optional, Tuple

from compressed_tensors.quantization.quant_args import QuantizationArgs
import torch
from compressed_tensors.quantization.quant_args import (
QuantizationArgs,
QuantizationStrategy,
)
from compressed_tensors.registry.registry import RegistryMixin
from torch import FloatTensor, IntTensor, Tensor
from torch.nn import Module
@@ -52,6 +56,12 @@ def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
"""
raise NotImplementedError(f"{self.__class__} must implement calculate_qparams")

def post_calculate_qparams(self) -> None:
"""
Run any logic specific to the observer after calculate_qparams has been run
"""
...

def get_qparams(
self, observed: Optional[Tensor] = None
) -> Tuple[FloatTensor, IntTensor]:
@@ -64,6 +74,57 @@ def get_qparams(
:return: tuple of scale and zero point based on last observed value
"""
if observed is not None:
# re-calcualte scale and zero point, update the stored value
self._scale, self._zero_point = self.calculate_qparams(observed)
group_size = self.quantization_args.group_size

if self.quantization_args.strategy == QuantizationStrategy.TENSOR:

# re-calculate scale and zero point, update the stored value
self._scale, self._zero_point = self.calculate_qparams(observed)

elif self.quantization_args.strategy == QuantizationStrategy.GROUP:
columns = observed.shape[1]
scales, zero_points = [], []
for i in range(0, columns, self.quantization_args.group_size):
scale, zero_point = self.get_qparams_along_dim(
observed[:, i : (i + group_size)],
0,
)
scales.append(scale)
zero_points.append(zero_point)

self._scale = torch.stack(scales, dim=1)
self._zero_point = torch.stack(zero_points, dim=1)

elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
# assume observed is transposed because it is the output, hence use dim 0
self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0)

elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:

# use dim 1, assume observed.shape = [batch, token, hidden]
# the resulting scale shape should be [batch, token]

self._scale, self._zero_point = self.get_qparams_along_dim(
observed, dim=1
)

return self._scale, self._zero_point

def get_qparams_along_dim(self, observed, dim: int):
# TODO: add documentation that specifies the shape must
# be padded with 1-dims so the scales are along the right channel
# TODO: generalize the logic for reduce_dims
scales, zero_points = [], []

# TODO: make a more generic way to get the channel
num_dims = observed.shape[dim]

for dim_idx in range(num_dims):
scale, zero_point = self.calculate_qparams(
observed.select(dim=dim, index=dim_idx)
)

scales.append(scale)
zero_points.append(zero_point)
return torch.stack(scales), torch.stack(zero_points)
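
To make the per-strategy shapes concrete, here is a hedged sketch using a toy observer subclass. The Observer base class name and its quantization_args constructor keyword are assumptions based on how it is referenced elsewhere in this diff, and the min-max math is purely illustrative:

import torch
from compressed_tensors.quantization.observers.base import Observer  # assumed class name
from compressed_tensors.quantization.quant_args import QuantizationArgs

class ToyMinMaxObserver(Observer):
    def calculate_qparams(self, observed):
        # illustrative only: symmetric 8-bit scale from the max absolute value
        scale = observed.abs().max() / 127.0
        zero_point = torch.tensor(0)
        return scale, zero_point

args = QuantizationArgs(num_bits=8, group_size=4)   # strategy inferred as "group"
obs = ToyMinMaxObserver(quantization_args=args)     # constructor kwarg assumed
scale, zero_point = obs.get_qparams(observed=torch.randn(4, 8))
print(scale.shape)  # expected torch.Size([4, 2]): one scale per (row, group)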
33 changes: 31 additions & 2 deletions src/compressed_tensors/quantization/quant_args.py
@@ -15,7 +15,7 @@
from enum import Enum
from typing import Any, Dict, Optional

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, validator


__all__ = ["QuantizationType", "QuantizationStrategy", "QuantizationArgs"]
@@ -39,6 +39,7 @@ class QuantizationStrategy(str, Enum):
CHANNEL = "channel"
GROUP = "group"
BLOCK = "block"
TOKEN = "token"


class QuantizationArgs(BaseModel):
@@ -63,8 +64,8 @@ class QuantizationArgs(BaseModel):
num_bits: int = 8
type: QuantizationType = QuantizationType.INT
symmetric: bool = True
strategy: QuantizationStrategy = QuantizationStrategy.TENSOR
group_size: Optional[int] = None
strategy: Optional[QuantizationStrategy] = None
block_structure: Optional[str] = None
dynamic: bool = False
observer: str = Field(
@@ -94,3 +95,31 @@ def get_observer(self):
self.observer = "memoryless"

return Observer.load_from_registry(self.observer, quantization_args=self)

@validator("strategy", pre=True, always=True)
def validate_strategy(cls, value, values):
group_size = values.get("group_size")

# use group_size to determine strategy if not given explicitly
if group_size is not None and value is None:
if group_size > 0:
return QuantizationStrategy.GROUP

elif group_size == -1:
return QuantizationStrategy.CHANNEL

else:
raise ValueError(
f"group_size={group_size} with strategy {value} is invalid. "
"Use group_size > 0 for strategy='group' and "
"group_size = -1 for strategy='channel'."
)

if value == QuantizationStrategy.GROUP:
if group_size is None:
raise ValueError(f"strategy {value} requires group_size to be set.")

if value is None:
return QuantizationStrategy.TENSOR

return value
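
A short hedged sketch of the inference behaviour this validator adds (the field values are illustrative):

from compressed_tensors.quantization.quant_args import QuantizationArgs, QuantizationStrategy

assert QuantizationArgs().strategy == QuantizationStrategy.TENSOR             # default when nothing is given
assert QuantizationArgs(group_size=128).strategy == QuantizationStrategy.GROUP
assert QuantizationArgs(group_size=-1).strategy == QuantizationStrategy.CHANNEL

try:
    QuantizationArgs(strategy="group")   # group strategy without a group_size
except ValueError as err:
    print(err)                           # pydantic surfaces the validator's ValueError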
1 change: 1 addition & 0 deletions src/compressed_tensors/quantization/utils/helpers.py
@@ -108,6 +108,7 @@ def calculate_compression_ratio(model: Module) -> float:
compressed_bits = uncompressed_bits
if is_module_quantized(submodule):
compressed_bits = submodule.quantization_scheme.weights.num_bits

num_weights = parameter.numel()
total_compressed += compressed_bits * num_weights
total_uncompressed += uncompressed_bits * num_weights
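
For intuition, a hedged arithmetic sketch of the accumulation above; it assumes the helper ultimately returns total uncompressed bits divided by total compressed bits, which is not shown in this hunk, and the layer sizes are made up:

uncompressed_bits = 16   # e.g. fp16 weights
layers = {
    "quantized_linear": (1_000_000, 4),              # (num_weights, quantized num_bits)
    "unquantized_head": (250_000, uncompressed_bits),
}

total_uncompressed = sum(n * uncompressed_bits for n, _ in layers.values())
total_compressed = sum(n * bits for n, bits in layers.values())
print(total_uncompressed / total_compressed)          # 2.5 (assumed return value)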
