From f2189c7d5555f17ba85ba7550855d97b8d5a2730 Mon Sep 17 00:00:00 2001 From: George Ohashi Date: Fri, 19 Apr 2024 20:16:44 +0000 Subject: [PATCH 01/21] group size --- .../quantization/observers/base.py | 4 +- .../quantization/observers/min_max.py | 52 +++++++++++++------ 2 files changed, 40 insertions(+), 16 deletions(-) diff --git a/src/compressed_tensors/quantization/observers/base.py b/src/compressed_tensors/quantization/observers/base.py index 96fe1049..efe500a9 100644 --- a/src/compressed_tensors/quantization/observers/base.py +++ b/src/compressed_tensors/quantization/observers/base.py @@ -65,5 +65,7 @@ def get_qparams( """ if observed is not None: # re-calcualte scale and zero point, update the stored value - self._scale, self._zero_point = self.calculate_qparams(observed) + self._scale, self._zero_point = self.calculate_qparams( + observed, group_size=self.quantization_args.group_size + ) return self._scale, self._zero_point diff --git a/src/compressed_tensors/quantization/observers/min_max.py b/src/compressed_tensors/quantization/observers/min_max.py index 3496bb77..8083b298 100644 --- a/src/compressed_tensors/quantization/observers/min_max.py +++ b/src/compressed_tensors/quantization/observers/min_max.py @@ -38,26 +38,48 @@ def __init__(self, quantization_args: QuantizationArgs): self.max_val = -float("inf") self.counter = 0 - def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]: + def calculate_qparams( + self, observed: Tensor, group_size: int = 0 + ) -> Tuple[FloatTensor, IntTensor]: """ + :param observed: observed tensor to calculate quantization parameters for :return: tuple of scale and zero point derived from the observed tensor """ - min_val = torch.tensor([observed.min()]) - max_val = torch.tensor([observed.max()]) + # quantize by groups + if group_size > 0: + columns = observed.shape[1] + scales, zero_points = [], [] + for i in range(0, columns, self.quantization_args.group_size): + scale, zero_point = self.calculate_qparams( + observed[:, i : (i + group_size)], 0 + ) + scales.append(scale) + zero_points.append(zero_point) + + return torch.cat(scales), torch.cat(zero_points) + + # channel-wise quantization + if group_size < 0: + ... 
+ + if group_size == 0: + + min_val = torch.tensor([observed.min()]) + max_val = torch.tensor([observed.max()]) - # update global min and max - if self.counter > 0: - self.min_val = torch.min(min_val, self.min_val) - self.max_val = torch.max(max_val, self.max_val) - else: - self.min_val = min_val - self.max_val = max_val + # update global min and max + if self.counter > 0: + self.min_val = torch.min(min_val, self.min_val) + self.max_val = torch.max(max_val, self.max_val) + else: + self.min_val = min_val + self.max_val = max_val - # ensure that the zeros are in the range - min_val = torch.min(self.min_val, torch.zeros_like(self.min_val)) - max_val = torch.max(self.max_val, torch.zeros_like(self.max_val)) + # ensure that the zeros are in the range + min_val = torch.min(self.min_val, torch.zeros_like(self.min_val)) + max_val = torch.max(self.max_val, torch.zeros_like(self.max_val)) - self.counter += 1 - return calculate_qparams(min_val, max_val, self.quantization_args) + self.counter += 1 + return calculate_qparams(min_val, max_val, self.quantization_args) From 81954b6903fffda533962d9c4d8b5dc75a14fbd3 Mon Sep 17 00:00:00 2001 From: George Ohashi Date: Fri, 19 Apr 2024 20:29:35 +0000 Subject: [PATCH 02/21] add logic in base observer --- .../quantization/observers/base.py | 34 +++++++++++-- .../quantization/observers/min_max.py | 50 +++++++------------ 2 files changed, 47 insertions(+), 37 deletions(-) diff --git a/src/compressed_tensors/quantization/observers/base.py b/src/compressed_tensors/quantization/observers/base.py index efe500a9..b089ae3f 100644 --- a/src/compressed_tensors/quantization/observers/base.py +++ b/src/compressed_tensors/quantization/observers/base.py @@ -14,6 +14,7 @@ from typing import Optional, Tuple +import torch from compressed_tensors.quantization.quant_args import QuantizationArgs from compressed_tensors.registry.registry import RegistryMixin from torch import FloatTensor, IntTensor, Tensor @@ -64,8 +65,33 @@ def get_qparams( :return: tuple of scale and zero point based on last observed value """ if observed is not None: - # re-calcualte scale and zero point, update the stored value - self._scale, self._zero_point = self.calculate_qparams( - observed, group_size=self.quantization_args.group_size - ) + group_size = self.quantization_args.group_size + + if group_size > 0: # quantize by groups + columns = observed.shape[1] + scales, zero_points = [], [] + for i in range(0, columns, self.quantization_args.group_size): + scale, zero_point = self.calculate_qparams( + observed[:, i : (i + group_size)] + ) + scales.append(scale) + zero_points.append(zero_point) + + if hasattr(self, "inc"): + self.inc() + + self._scale = torch.cat(scales) + self._zero_point = torch.cat(zero_points) + + elif group_size < 0: # channel-wise quantization + # TODO: Import channel wise logic here + + if hasattr(self, "inc"): + self.inc() + + else: + # re-calcualte scale and zero point, update the stored value + self._scale, self._zero_point = self.calculate_qparams(observed) + if hasattr(self, "inc"): + self.inc() return self._scale, self._zero_point diff --git a/src/compressed_tensors/quantization/observers/min_max.py b/src/compressed_tensors/quantization/observers/min_max.py index 8083b298..27768c76 100644 --- a/src/compressed_tensors/quantization/observers/min_max.py +++ b/src/compressed_tensors/quantization/observers/min_max.py @@ -39,7 +39,8 @@ def __init__(self, quantization_args: QuantizationArgs): self.counter = 0 def calculate_qparams( - self, observed: Tensor, group_size: int = 0 + self, 
+ observed: Tensor, ) -> Tuple[FloatTensor, IntTensor]: """ @@ -47,39 +48,22 @@ def calculate_qparams( :return: tuple of scale and zero point derived from the observed tensor """ - # quantize by groups - if group_size > 0: - columns = observed.shape[1] - scales, zero_points = [], [] - for i in range(0, columns, self.quantization_args.group_size): - scale, zero_point = self.calculate_qparams( - observed[:, i : (i + group_size)], 0 - ) - scales.append(scale) - zero_points.append(zero_point) + min_val = torch.tensor([observed.min()]) + max_val = torch.tensor([observed.max()]) - return torch.cat(scales), torch.cat(zero_points) + # update global min and max + if self.counter > 0: + self.min_val = torch.min(min_val, self.min_val) + self.max_val = torch.max(max_val, self.max_val) + else: + self.min_val = min_val + self.max_val = max_val - # channel-wise quantization - if group_size < 0: - ... + # ensure that the zeros are in the range + min_val = torch.min(self.min_val, torch.zeros_like(self.min_val)) + max_val = torch.max(self.max_val, torch.zeros_like(self.max_val)) - if group_size == 0: + return calculate_qparams(min_val, max_val, self.quantization_args) - min_val = torch.tensor([observed.min()]) - max_val = torch.tensor([observed.max()]) - - # update global min and max - if self.counter > 0: - self.min_val = torch.min(min_val, self.min_val) - self.max_val = torch.max(max_val, self.max_val) - else: - self.min_val = min_val - self.max_val = max_val - - # ensure that the zeros are in the range - min_val = torch.min(self.min_val, torch.zeros_like(self.min_val)) - max_val = torch.max(self.max_val, torch.zeros_like(self.max_val)) - - self.counter += 1 - return calculate_qparams(min_val, max_val, self.quantization_args) + def inc(self): + self.counter += 1 From 803f49565b9c2e7723dc42e9cd2137ebced2dc16 Mon Sep 17 00:00:00 2001 From: George Ohashi Date: Tue, 23 Apr 2024 17:44:39 +0000 Subject: [PATCH 03/21] group size full lifecycle run --- .../quantization/lifecycle/forward.py | 21 +++++++++++++------ .../quantization/observers/base.py | 13 ++++++------ .../quantization/observers/helpers.py | 2 +- 3 files changed, 23 insertions(+), 13 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index 48b93e02..ccd232fa 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -33,9 +33,7 @@ def quantize( q_max: torch.Tensor, ) -> torch.Tensor: return torch.clamp( - torch.round( - x / scale + zero_point, - ), + torch.round(x / scale + zero_point), q_min, q_max, ) @@ -60,9 +58,20 @@ def fake_quantize( bit_range = 2**args.num_bits max_q = torch.tensor(bit_range / 2 - 1, device=x.device) min_q = torch.tensor(-bit_range / 2, device=x.device) - Q = torch.zeros_like(x) - Q = quantize(x, scale, zero_point, min_q, max_q) - return dequantize(Q, scale, zero_point) + # Q = torch.zeros_like(x) + DQ = torch.zeros_like(x) + num_groups = len(scale) + group_size = int(x.shape[1] / num_groups) + for i in range(num_groups): + sc = scale[i] + zp = zero_point[i] + + idx = i * group_size + Q = quantize(x[:, idx : (idx + group_size)], sc, zp, min_q, max_q) + DQ[:, idx : (idx + group_size)] = dequantize(Q, sc, zp) + breakpoint() + # Q = quantize(x, scale, zero_point, min_q, max_q) + return DQ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme): diff --git a/src/compressed_tensors/quantization/observers/base.py 
b/src/compressed_tensors/quantization/observers/base.py index b089ae3f..db821511 100644 --- a/src/compressed_tensors/quantization/observers/base.py +++ b/src/compressed_tensors/quantization/observers/base.py @@ -66,8 +66,14 @@ def get_qparams( """ if observed is not None: group_size = self.quantization_args.group_size + if group_size is None: - if group_size > 0: # quantize by groups + # re-calcualte scale and zero point, update the stored value + self._scale, self._zero_point = self.calculate_qparams(observed) + if hasattr(self, "inc"): + self.inc() + + elif group_size > 0: # quantize by groups columns = observed.shape[1] scales, zero_points = [], [] for i in range(0, columns, self.quantization_args.group_size): @@ -89,9 +95,4 @@ def get_qparams( if hasattr(self, "inc"): self.inc() - else: - # re-calcualte scale and zero point, update the stored value - self._scale, self._zero_point = self.calculate_qparams(observed) - if hasattr(self, "inc"): - self.inc() return self._scale, self._zero_point diff --git a/src/compressed_tensors/quantization/observers/helpers.py b/src/compressed_tensors/quantization/observers/helpers.py index d0fca813..bf3b7e7f 100644 --- a/src/compressed_tensors/quantization/observers/helpers.py +++ b/src/compressed_tensors/quantization/observers/helpers.py @@ -38,7 +38,7 @@ def calculate_qparams( if quantization_args.symmetric: symmetric_range = 2 * max(min_vals.abs(), max_vals.abs()) scales = symmetric_range / bit_range - zero_points = torch.tensor(0).to(torch.int8) + zero_points = torch.tensor([0]).to(torch.int8) else: # non-symmetric observed_range = max_vals - min_vals From cda1c48d9d8274f85f338b07eeed2ec31878c314 Mon Sep 17 00:00:00 2001 From: George Ohashi Date: Wed, 24 Apr 2024 15:53:32 +0000 Subject: [PATCH 04/21] before vectorize the for loop --- src/compressed_tensors/quantization/lifecycle/forward.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index ccd232fa..67990dde 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -58,10 +58,12 @@ def fake_quantize( bit_range = 2**args.num_bits max_q = torch.tensor(bit_range / 2 - 1, device=x.device) min_q = torch.tensor(-bit_range / 2, device=x.device) - # Q = torch.zeros_like(x) + DQ = torch.zeros_like(x) num_groups = len(scale) group_size = int(x.shape[1] / num_groups) + + # TODO: vectorize the for loop for i in range(num_groups): sc = scale[i] zp = zero_point[i] @@ -69,8 +71,7 @@ def fake_quantize( idx = i * group_size Q = quantize(x[:, idx : (idx + group_size)], sc, zp, min_q, max_q) DQ[:, idx : (idx + group_size)] = dequantize(Q, sc, zp) - breakpoint() - # Q = quantize(x, scale, zero_point, min_q, max_q) + return DQ From 3cc730d9b4ccdd1fc32d53a07e320c5f2303a810 Mon Sep 17 00:00:00 2001 From: George Ohashi Date: Wed, 24 Apr 2024 16:21:39 +0000 Subject: [PATCH 05/21] comments, todo add channelwise --- .../quantization/lifecycle/forward.py | 34 ++++++++++++------- 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index 67990dde..7fb92c33 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -59,18 +59,28 @@ def fake_quantize( max_q = torch.tensor(bit_range / 2 - 1, device=x.device) min_q 
= torch.tensor(-bit_range / 2, device=x.device)

-    DQ = torch.zeros_like(x)
-    num_groups = len(scale)
-    group_size = int(x.shape[1] / num_groups)
-
-    # TODO: vectorize the for loop
-    for i in range(num_groups):
-        sc = scale[i]
-        zp = zero_point[i]
-
-        idx = i * group_size
-        Q = quantize(x[:, idx : (idx + group_size)], sc, zp, min_q, max_q)
-        DQ[:, idx : (idx + group_size)] = dequantize(Q, sc, zp)
+    columns = x.shape[1]
+    group_size = args.group_size
+
+    if group_size is None or group_size == 0:
+        Q = quantize(x, scale, zero_point, min_q, max_q)
+        DQ = dequantize(Q, scale, zero_point)
+
+    elif group_size > 0:
+
+        DQ = torch.zeros_like(x)
+
+        # TODO: vectorize the for loop
+        # TODO: fix generic assumption about the tensor size for computing group
+        for i in range(int(columns / group_size)):
+            sc = scale[i]
+            zp = zero_point[i]
+
+            idx = i * group_size
+            Q = quantize(x[:, idx : (idx + group_size)], sc, zp, min_q, max_q)
+            DQ[:, idx : (idx + group_size)] = dequantize(Q, sc, zp)
+
+    else:  # group_size < 0
+        ...

     return DQ

From bd67232604353d0ba571f8d43c595d718200a8ec Mon Sep 17 00:00:00 2001
From: George Ohashi
Date: Wed, 24 Apr 2024 19:00:52 +0000
Subject: [PATCH 06/21] chan wise impl

---
 .../quantization/lifecycle/forward.py                 |  8 ++++++--
 src/compressed_tensors/quantization/observers/base.py | 11 ++++++++++-
 2 files changed, 16 insertions(+), 3 deletions(-)

diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py
index 7fb92c33..081ac4e7 100644
--- a/src/compressed_tensors/quantization/lifecycle/forward.py
+++ b/src/compressed_tensors/quantization/lifecycle/forward.py
@@ -66,6 +66,7 @@ def fake_quantize(
         Q = quantize(x, scale, zero_point, min_q, max_q)
         DQ = dequantize(Q, scale, zero_point)

+    # group
     elif group_size > 0:

         DQ = torch.zeros_like(x)
@@ -79,8 +80,11 @@ def fake_quantize(
         Q = quantize(x[:, idx : (idx + group_size)], sc, zp, min_q, max_q)
         DQ[:, idx : (idx + group_size)] = dequantize(Q, sc, zp)

-    else:  # group_size < 0
-        ...
+ # channel-wise + else: # group_size == -1 + DQ = torch.zeros_like(x) + for i in range(len(x)): + DQ[i, :] = quantize(x[i, :], scale[i], zero_point[i], min_q, max_q) return DQ diff --git a/src/compressed_tensors/quantization/observers/base.py b/src/compressed_tensors/quantization/observers/base.py index db821511..ce04c8eb 100644 --- a/src/compressed_tensors/quantization/observers/base.py +++ b/src/compressed_tensors/quantization/observers/base.py @@ -90,9 +90,18 @@ def get_qparams( self._zero_point = torch.cat(zero_points) elif group_size < 0: # channel-wise quantization - # TODO: Import channel wise logic here + + # TODO: generalize the logic for reduce_dims + scales, zero_points = [], [] + for observed_c in observed: + scale, zero_point = self.calculate_qparams(observed_c) + scales.append(scale) + zero_points.append(zero_point) if hasattr(self, "inc"): self.inc() + self._scale = torch.cat(scales) + self._zero_point = torch.cat(zero_points) + return self._scale, self._zero_point From 5bf66ad2b2eb456faa867bddba4308f8a72e4430 Mon Sep 17 00:00:00 2001 From: George Ohashi Date: Wed, 24 Apr 2024 19:08:45 +0000 Subject: [PATCH 07/21] comments --- .../quantization/lifecycle/forward.py | 6 ++++-- .../quantization/observers/base.py | 16 ++++++++++------ .../quantization/observers/min_max.py | 2 +- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index 081ac4e7..6badc14f 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -13,6 +13,7 @@ # limitations under the License. from functools import wraps +from math import ceil import torch from compressed_tensors.quantization.quant_args import QuantizationArgs @@ -59,7 +60,6 @@ def fake_quantize( max_q = torch.tensor(bit_range / 2 - 1, device=x.device) min_q = torch.tensor(-bit_range / 2, device=x.device) - columns = x.shape[1] group_size = args.group_size if group_size is None or group_size == 0: @@ -68,11 +68,13 @@ def fake_quantize( # group elif group_size > 0: + DQ = torch.zeros_like(x) # TODO: vectorize the for loop # TODO: fix genetric assumption about the tensor size for computing group - for i in range(int(columns / group_size)): + columns = x.shape[1] + for i in range(ceil(columns / group_size)): sc = scale[i] zp = zero_point[i] diff --git a/src/compressed_tensors/quantization/observers/base.py b/src/compressed_tensors/quantization/observers/base.py index ce04c8eb..fc9230fa 100644 --- a/src/compressed_tensors/quantization/observers/base.py +++ b/src/compressed_tensors/quantization/observers/base.py @@ -53,6 +53,12 @@ def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]: """ raise NotImplementedError(f"{self.__class__} must implement calculate_qparams") + def post_calculate_qparams(self) -> None: + """ + Run any logic specific to its observers after running calculate_qparams + """ + ... 
+
     def get_qparams(
         self, observed: Optional[Tensor] = None
     ) -> Tuple[FloatTensor, IntTensor]:
@@ -70,8 +76,8 @@ def get_qparams(
             # re-calcualte scale and zero point, update the stored value
             self._scale, self._zero_point = self.calculate_qparams(observed)

-            if hasattr(self, "inc"):
-                self.inc()
+
+            self.post_calculate_qparams()

         elif group_size > 0:  # quantize by groups
             columns = observed.shape[1]
@@ -83,8 +89,7 @@ def get_qparams(
                 scales.append(scale)
                 zero_points.append(zero_point)

-            if hasattr(self, "inc"):
-                self.inc()
+            self.post_calculate_qparams()

             self._scale = torch.cat(scales)
             self._zero_point = torch.cat(zero_points)
@@ -98,8 +103,7 @@ def get_qparams(
                 scales.append(scale)
                 zero_points.append(zero_point)

-            if hasattr(self, "inc"):
-                self.inc()
+            self.post_calculate_qparams()

             self._scale = torch.cat(scales)
             self._zero_point = torch.cat(zero_points)
diff --git a/src/compressed_tensors/quantization/observers/min_max.py b/src/compressed_tensors/quantization/observers/min_max.py
index 27768c76..b42919e6 100644
--- a/src/compressed_tensors/quantization/observers/min_max.py
+++ b/src/compressed_tensors/quantization/observers/min_max.py
@@ -65,5 +65,5 @@ def calculate_qparams(

         return calculate_qparams(min_val, max_val, self.quantization_args)

-    def inc(self):
+    def post_calculate_qparams(self):
         self.counter += 1

From 666adeac9dd43a3d1b046ebe4d61eae3585a45c4 Mon Sep 17 00:00:00 2001
From: George Ohashi
Date: Thu, 25 Apr 2024 15:17:48 +0000
Subject: [PATCH 08/21] fix channel wise

---
 .../quantization/lifecycle/forward.py | 28 ++++++++++++--
 .../quantization/observers/base.py    | 38 ++++++++++++-------
 2 files changed, 49 insertions(+), 17 deletions(-)

diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py
index 6badc14f..e31735a2 100644
--- a/src/compressed_tensors/quantization/lifecycle/forward.py
+++ b/src/compressed_tensors/quantization/lifecycle/forward.py
@@ -56,6 +56,21 @@ def fake_quantize(
     zero_point: torch.Tensor,
     args: QuantizationArgs,
 ) -> torch.Tensor:
+    """
+    Fake quantize the input tensor x depending on the group_size.
+    if group_size is greater than 0, then q/dq by groups; the column
+    size must be divisible by the group_size.
+    if group_size is -1, then channel-wise q/dq. The input scale and
+    zero_points are reshaped to support vectorization (Assumes 1 is
+    the channel dimension)
+
+    :param x: Input tensor
+    :param scale: scale tensor
+    :param zero_point: zero point tensor
+    :param args: quantization args that contain group_size info
+    :return: fake quantized tensor
+
+    """
     bit_range = 2**args.num_bits
     max_q = torch.tensor(bit_range / 2 - 1, device=x.device)
     min_q = torch.tensor(-bit_range / 2, device=x.device)
@@ -74,6 +89,9 @@ def fake_quantize(
         # TODO: vectorize the for loop
         # TODO: fix generic assumption about the tensor size for computing group
         columns = x.shape[1]
+
+        # TODO: make validation step for inputs
+        assert columns % group_size == 0
         for i in range(ceil(columns / group_size)):
             sc = scale[i]
             zp = zero_point[i]
@@ -84,9 +102,13 @@ def fake_quantize(

     # channel-wise
     else:  # group_size == -1
-        DQ = torch.zeros_like(x)
-        for i in range(len(x)):
-            DQ[i, :] = quantize(x[i, :], scale[i], zero_point[i], min_q, max_q)
+        # before: scale shape = [channel_size]
+        # after: scale shape = [1, channel_size]
+        scale = scale.unsqueeze(0)
+        zero_point = zero_point.unsqueeze(0)
+
+        Q = quantize(x, scale, zero_point, min_q, max_q)
+        DQ = dequantize(Q, scale, zero_point)

     return DQ
diff --git a/src/compressed_tensors/quantization/observers/base.py b/src/compressed_tensors/quantization/observers/base.py
index fc9230fa..1da9579f 100644
--- a/src/compressed_tensors/quantization/observers/base.py
+++ b/src/compressed_tensors/quantization/observers/base.py
@@ -77,8 +77,6 @@ def get_qparams(
             # re-calcualte scale and zero point, update the stored value
             self._scale, self._zero_point = self.calculate_qparams(observed)

-            self.post_calculate_qparams()
-
         elif group_size > 0:  # quantize by groups
             columns = observed.shape[1]
             scales, zero_points = [], []
@@ -89,23 +87,35 @@ def get_qparams(
                 scales.append(scale)
                 zero_points.append(zero_point)

-            self.post_calculate_qparams()
-
             self._scale = torch.cat(scales)
             self._zero_point = torch.cat(zero_points)

         elif group_size < 0:  # channel-wise quantization

-            # TODO: generalize the logic for reduce_dims
-            scales, zero_points = [], []
-            for observed_c in observed:
-                scale, zero_point = self.calculate_qparams(observed_c)
-                scales.append(scale)
-                zero_points.append(zero_point)
+            # TODO: make a generic way to get the channel
+            channel = 1
+            self._scale, self._zero_point = self.get_qparams_per_channel(
+                observed, channel
+            )

-            self.post_calculate_qparams()
+        self.post_calculate_qparams()
+        return self._scale, self._zero_point

-            self._scale = torch.cat(scales)
-            self._zero_point = torch.cat(zero_points)
+    def get_qparams_per_channel(self, observed, channel: int):
+        # TODO: add documentation that specifies the shape must
+        # be padded with 1-dims so the scales are along the right channel
+        # TODO: generalize the logic for reduce_dims
+        scales, zero_points = [], []

-        return self._scale, self._zero_point
+        # TODO: make a more generic way to get the channel
+        num_channels = observed.shape[channel]
+
+        for channel_idx in range(num_channels):
+            scale, zero_point = self.calculate_qparams(
+                observed.select(dim=channel, index=channel_idx)
+            )
+
+            scales.append(scale)
+            zero_points.append(zero_point)
+
+        return torch.cat(scales), torch.cat(zero_points)

From 407ab026c84868a52a641b05134af14ff2f63ca1 Mon Sep 17 00:00:00 2001
From: George Ohashi
Date: Thu, 25 Apr 2024 18:48:35 +0000
Subject: [PATCH 09/21] comments, validators

---
 .../quantization/observers/base.py | 14 ++++++---
 .../quantization/quant_args.py     | 31 ++++++++++++++++++-
 2 files changed, 40 insertions(+), 5 deletions(-)
diff --git a/src/compressed_tensors/quantization/observers/base.py b/src/compressed_tensors/quantization/observers/base.py
index 1da9579f..4438c554 100644
--- a/src/compressed_tensors/quantization/observers/base.py
+++ b/src/compressed_tensors/quantization/observers/base.py
@@ -15,7 +15,10 @@
 from typing import Optional, Tuple

 import torch
-from compressed_tensors.quantization.quant_args import QuantizationArgs
+from compressed_tensors.quantization.quant_args import (
+    QuantizationArgs,
+    QuantizationStrategy,
+)
 from compressed_tensors.registry.registry import RegistryMixin
 from torch import FloatTensor, IntTensor, Tensor
 from torch.nn import Module
@@ -72,12 +75,13 @@ def get_qparams(
         """
         if observed is not None:
             group_size = self.quantization_args.group_size
-            if group_size is None:
+            # if group_size is None:
+            if self.quantization_args.strategy == QuantizationStrategy.TENSOR:

                 # re-calcualte scale and zero point, update the stored value
                 self._scale, self._zero_point = self.calculate_qparams(observed)

-            elif group_size > 0:  # quantize by groups
+            elif self.quantization_args.strategy == QuantizationStrategy.GROUP:
                 columns = observed.shape[1]
                 scales, zero_points = [], []
                 for i in range(0, columns, self.quantization_args.group_size):
@@ -90,7 +94,9 @@ def get_qparams(
                 self._scale = torch.cat(scales)
                 self._zero_point = torch.cat(zero_points)

-            elif group_size < 0:  # channel-wise quantization
+            elif (
+                self.quantization_args.strategy == QuantizationStrategy.CHANNEL
+            ):  # channel-wise quantization

                 # TODO: make a generic way to get the channel
                 channel = 1
diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py
index 64b5005f..b8c5c293 100644
--- a/src/compressed_tensors/quantization/quant_args.py
+++ b/src/compressed_tensors/quantization/quant_args.py
@@ -15,7 +15,7 @@
 from enum import Enum
 from typing import Any, Dict, Optional

-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, validator

 __all__ = ["QuantizationType", "QuantizationStrategy", "QuantizationArgs"]
@@ -83,3 +83,32 @@ def get_observer(self):
         from compressed_tensors.quantization.observers.base import Observer

         return Observer.load_from_registry(self.observer, quantization_args=self)
+
+    @validator("strategy", pre=True)
+    def validate_strategy(cls, value, values):
+        group_size = values.get("group_size")
+        if group_size is not None:
+            if group_size > 0:
+                if value != QuantizationStrategy.GROUP:
+                    raise ValueError(
+                        f"group_size={group_size} with strategy {value} is invalid. "
+                        "Please set strategy to 'group'"
+                    )
+                return QuantizationStrategy.GROUP
+
+            elif group_size == -1:
+                if value != QuantizationStrategy.CHANNEL:
+                    raise ValueError(
+                        f"group_size={group_size} with strategy {value} is invalid. "
+                        "Please set strategy to 'channel'"
+                    )
+                return QuantizationStrategy.CHANNEL
+
+            else:
+                raise ValueError(
+                    f"group_size={group_size} with strategy {value} is invalid. "
" + "group_size > 0 for strategy='group' and " + "group_size = -1 for 'channel'" + ) + + return value From 309ebe27af2f30f803b2773043a83b24a2bd755f Mon Sep 17 00:00:00 2001 From: George Ohashi Date: Thu, 25 Apr 2024 18:49:48 +0000 Subject: [PATCH 10/21] fix typo --- src/compressed_tensors/quantization/observers/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compressed_tensors/quantization/observers/base.py b/src/compressed_tensors/quantization/observers/base.py index 4438c554..69145076 100644 --- a/src/compressed_tensors/quantization/observers/base.py +++ b/src/compressed_tensors/quantization/observers/base.py @@ -78,7 +78,7 @@ def get_qparams( # if group_size is None: if self.quantization_args.strategy == QuantizationStrategy.TENSOR: - # re-calcualte scale and zero point, update the stored value + # re-calculate scale and zero point, update the stored value self._scale, self._zero_point = self.calculate_qparams(observed) elif self.quantization_args.strategy == QuantizationStrategy.GROUP: From d3f0803c6aca5c7d548d17222237c051f6470c99 Mon Sep 17 00:00:00 2001 From: George Ohashi Date: Mon, 29 Apr 2024 15:36:02 +0000 Subject: [PATCH 11/21] tensor return error fix --- src/compressed_tensors/quantization/observers/base.py | 2 +- src/compressed_tensors/quantization/observers/min_max.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/compressed_tensors/quantization/observers/base.py b/src/compressed_tensors/quantization/observers/base.py index 69145076..6e21a309 100644 --- a/src/compressed_tensors/quantization/observers/base.py +++ b/src/compressed_tensors/quantization/observers/base.py @@ -124,4 +124,4 @@ def get_qparams_per_channel(self, observed, channel: int): scales.append(scale) zero_points.append(zero_point) - return torch.cat(scales), torch.cat(zero_points) + return torch.stack(scales), torch.stack(zero_points) diff --git a/src/compressed_tensors/quantization/observers/min_max.py b/src/compressed_tensors/quantization/observers/min_max.py index f33b7419..de8735ed 100644 --- a/src/compressed_tensors/quantization/observers/min_max.py +++ b/src/compressed_tensors/quantization/observers/min_max.py @@ -40,7 +40,6 @@ def __init__( self.max_val = -float("inf") self.averaging_constant = averaging_constant - def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]: """ Updates the observed min and max using a moving average smoothed by the From 182195f68510618af74469c56c824924ea8cc20f Mon Sep 17 00:00:00 2001 From: George Ohashi Date: Mon, 29 Apr 2024 16:25:22 +0000 Subject: [PATCH 12/21] fix sparseml-side of code and add per channel --- .../quantization/observers/base.py | 19 +++++++++---------- .../quantization/quant_args.py | 1 + 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/compressed_tensors/quantization/observers/base.py b/src/compressed_tensors/quantization/observers/base.py index 6e21a309..60e3299f 100644 --- a/src/compressed_tensors/quantization/observers/base.py +++ b/src/compressed_tensors/quantization/observers/base.py @@ -97,28 +97,27 @@ def get_qparams( elif ( self.quantization_args.strategy == QuantizationStrategy.CHANNEL ): # channel-wise quantization - - # TODO: make a genertic way to get the channel - channel = 1 - self._scale, self._zero_point = self.get_qparams_per_channel( - observed, channel + self._scale, self._zero_point = self.get_qparams_along_dim(observed, 1) + elif self.quantization_args.strategy == QuantizationStrategy.TOKEN: + dims = observed.ndim + self._scale, 
+                    observed, dim=dims - 1
+                )

         self.post_calculate_qparams()
         return self._scale, self._zero_point

-    def get_qparams_per_channel(self, observed, channel: int):
+    def get_qparams_along_dim(self, observed, dim: int):
         # TODO: add documentation that specifies the shape must
         # be padded with 1-dims so the scales are along the right channel
         # TODO: generalize the logic for reduce_dims
         scales, zero_points = [], []

         # TODO: make a more generic way to get the channel
-        num_channels = observed.shape[channel]
+        num_dims = observed.shape[dim]

-        for channel_idx in range(num_channels):
+        for dim_idx in range(num_dims):
             scale, zero_point = self.calculate_qparams(
-                observed.select(dim=channel, index=channel_idx)
+                observed.select(dim=dim, index=dim_idx)
             )

             scales.append(scale)
             zero_points.append(zero_point)
diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py
index 0917c380..112739d3 100644
--- a/src/compressed_tensors/quantization/quant_args.py
+++ b/src/compressed_tensors/quantization/quant_args.py
@@ -39,6 +39,7 @@ class QuantizationStrategy(str, Enum):
     CHANNEL = "channel"
     GROUP = "group"
     BLOCK = "block"
+    TOKEN = "token"


 class QuantizationArgs(BaseModel):

From f35e4c9284c5529a17683fd2728a64b981b517e8 Mon Sep 17 00:00:00 2001
From: George Ohashi
Date: Mon, 29 Apr 2024 17:33:18 +0000
Subject: [PATCH 13/21] pydantic defaults

---
 .../quantization/lifecycle/forward.py | 27 ++++++++++++-----
 .../quantization/observers/base.py    |  2 +-
 .../quantization/quant_args.py        | 22 +++++++--------
 3 files changed, 31 insertions(+), 20 deletions(-)

diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py
index 2ecccb97..81c4b9d9 100644
--- a/src/compressed_tensors/quantization/lifecycle/forward.py
+++ b/src/compressed_tensors/quantization/lifecycle/forward.py
@@ -16,7 +16,10 @@
 from math import ceil

 import torch
-from compressed_tensors.quantization.quant_args import QuantizationArgs
+from compressed_tensors.quantization.quant_args import (
+    QuantizationArgs,
+    QuantizationStrategy,
+)
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from torch.nn import Module
@@ -77,12 +80,8 @@ def fake_quantize(

     group_size = args.group_size

-    if group_size is None or group_size == 0:
-        Q = quantize(x, scale, zero_point, min_q, max_q)
-        DQ = dequantize(Q, scale, zero_point)
-
     # group
-    elif group_size > 0:
+    if args.strategy == QuantizationStrategy.GROUP:

         DQ = torch.zeros_like(x)
@@ -101,7 +100,7 @@ def fake_quantize(
         DQ[:, idx : (idx + group_size)] = dequantize(Q, sc, zp)

     # channel-wise
-    else:  # group_size == -1
+    elif args.strategy == QuantizationStrategy.CHANNEL:  # group_size == -1
         # before: scale shape = [channel_size]
         # after: scale shape = [1, channel_size]
         scale = scale.unsqueeze(0)
@@ -110,6 +109,20 @@ def fake_quantize(
         Q = quantize(x, scale, zero_point, min_q, max_q)
         DQ = dequantize(Q, scale, zero_point)

+    # per-token
+    elif args.strategy == QuantizationStrategy.TOKEN:
+        # before: scale shape = [channel_size]
+        # after: scale shape = [channel_size, 1]
+        scale = scale.unsqueeze(0)[::-1]
+        zero_point = zero_point.unsqueeze(0)[::-1]
+
+        Q = quantize(x, scale, zero_point, min_q, max_q)
+        DQ = dequantize(Q, scale, zero_point)
+
+    else:
+        Q = quantize(x, scale, zero_point, min_q, max_q)
+        DQ = dequantize(Q, scale, zero_point)
+
     return DQ
diff --git a/src/compressed_tensors/quantization/observers/base.py b/src/compressed_tensors/quantization/observers/base.py
index 60e3299f..a15033ba 100644
--- a/src/compressed_tensors/quantization/observers/base.py
+++ b/src/compressed_tensors/quantization/observers/base.py
@@ -101,7 +101,7 @@ def get_qparams(
         elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:
             dims = observed.ndim
             self._scale, self._zero_point = self.get_qparams_along_dim(
-                observed, dim=dims - 1
+                observed, dim=dims - 2
             )

         return self._scale, self._zero_point
diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py
index 112739d3..ead4bba1 100644
--- a/src/compressed_tensors/quantization/quant_args.py
+++ b/src/compressed_tensors/quantization/quant_args.py
@@ -64,7 +64,7 @@ class QuantizationArgs(BaseModel):
     num_bits: int = 8
     type: QuantizationType = QuantizationType.INT
     symmetric: bool = True
-    strategy: QuantizationStrategy = QuantizationStrategy.TENSOR
+    strategy: Optional[QuantizationStrategy] = None
     group_size: Optional[int] = None
     block_structure: Optional[str] = None
     dynamic: bool = False
@@ -99,21 +99,13 @@ def get_observer(self):
     @validator("strategy", pre=True)
     def validate_strategy(cls, value, values):
         group_size = values.get("group_size")
-        if group_size is not None:
+
+        # use group_size to determine strategy if not given explicitly
+        if group_size is not None and value is None:
             if group_size > 0:
-                if value != QuantizationStrategy.GROUP:
-                    raise ValueError(
-                        f"group_size={group_size} with strategy {value} is invalid. "
-                        "Please set strategy to 'group'"
-                    )
                 return QuantizationStrategy.GROUP

             elif group_size == -1:
-                if value != QuantizationStrategy.CHANNEL:
-                    raise ValueError(
-                        f"group_size={group_size} with strategy {value} is invalid.
" - "Please set strategy to 'channel'" - ) return QuantizationStrategy.CHANNEL else: @@ -122,5 +114,11 @@ def validate_strategy(cls, value, values): "group_size > 0 for strategy='group' and " "group_size = -1 for 'channel'" ) + if value == QuantizationStrategy.GROUP: + if group_size is None: + raise ValueError(f"strategy {value} is need group_size to be set.") + + if value is None: + return QuantizationStrategy.TENSOR return value From f26d7f8404726b48f27c259023fbd62b40327e00 Mon Sep 17 00:00:00 2001 From: George Ohashi Date: Mon, 29 Apr 2024 17:49:10 +0000 Subject: [PATCH 14/21] token wise quant --- src/compressed_tensors/quantization/lifecycle/forward.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index 81c4b9d9..d88a7a82 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -103,6 +103,9 @@ def fake_quantize( elif args.strategy == QuantizationStrategy.CHANNEL: # group_size == -1 # before: scale shape = [channel_size] # after: scale shape = [1, channel_size] + + breakpoint() + scale = scale.unsqueeze(0) zero_point = zero_point.unsqueeze(0) @@ -113,8 +116,8 @@ def fake_quantize( elif args.strategy == QuantizationStrategy.TOKEN: # before: scale shape = [channel_size] # after: scale shape = [channel_size, 1] - scale = scale.unsqueeze(0)[::-1] - zero_point = zero_point.unsqueeze(0)[::-1] + scale = scale.unsqueeze(1) + zero_point = zero_point.unsqueeze(1) Q = quantize(x, scale, zero_point, min_q, max_q) DQ = dequantize(Q, scale, zero_point) From 98a0f8b86b9109aa8f5b6a11a48bab826ce12ef0 Mon Sep 17 00:00:00 2001 From: George Date: Mon, 29 Apr 2024 15:04:35 -0400 Subject: [PATCH 15/21] Update src/compressed_tensors/quantization/quant_args.py Co-authored-by: Benjamin Fineran --- src/compressed_tensors/quantization/quant_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py index ead4bba1..7fe8fcb7 100644 --- a/src/compressed_tensors/quantization/quant_args.py +++ b/src/compressed_tensors/quantization/quant_args.py @@ -116,7 +116,7 @@ def validate_strategy(cls, value, values): ) if value == QuantizationStrategy.GROUP: if group_size is None: - raise ValueError(f"strategy {value} is need group_size to be set.") + raise ValueError(f"strategy {value} requires group_size to be set.") if value is None: return QuantizationStrategy.TENSOR From 176713a62bec5a634dcf9e232379914feb460e0f Mon Sep 17 00:00:00 2001 From: George Ohashi Date: Mon, 29 Apr 2024 19:11:49 +0000 Subject: [PATCH 16/21] comments' --- .../quantization/lifecycle/forward.py | 2 -- src/compressed_tensors/quantization/observers/base.py | 10 ++++------ 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index d88a7a82..8f70d737 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -104,8 +104,6 @@ def fake_quantize( # before: scale shape = [channel_size] # after: scale shape = [1, channel_size] - breakpoint() - scale = scale.unsqueeze(0) zero_point = zero_point.unsqueeze(0) diff --git a/src/compressed_tensors/quantization/observers/base.py b/src/compressed_tensors/quantization/observers/base.py 
index a15033ba..5bb8703a 100644 --- a/src/compressed_tensors/quantization/observers/base.py +++ b/src/compressed_tensors/quantization/observers/base.py @@ -75,7 +75,7 @@ def get_qparams( """ if observed is not None: group_size = self.quantization_args.group_size - # if group_size is None: + if self.quantization_args.strategy == QuantizationStrategy.TENSOR: # re-calculate scale and zero point, update the stored value @@ -94,14 +94,12 @@ def get_qparams( self._scale = torch.cat(scales) self._zero_point = torch.cat(zero_points) - elif ( - self.quantization_args.strategy == QuantizationStrategy.CHANNEL - ): # channel-wise quantization + elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL: self._scale, self._zero_point = self.get_qparams_along_dim(observed, 1) + elif self.quantization_args.strategy == QuantizationStrategy.TOKEN: - dims = observed.ndim self._scale, self._zero_point = self.get_qparams_along_dim( - observed, dim=dims - 2 + observed, dim=0 ) return self._scale, self._zero_point From 50671469f03e6c72024064cc99943b72b5e47a2a Mon Sep 17 00:00:00 2001 From: George Ohashi Date: Tue, 30 Apr 2024 18:56:52 +0000 Subject: [PATCH 17/21] update dim --- .../quantization/lifecycle/forward.py | 6 +++++- .../quantization/observers/base.py | 16 ++++++++++++---- .../quantization/quant_args.py | 2 ++ .../quantization/utils/helpers.py | 4 +++- 4 files changed, 22 insertions(+), 6 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index 8f70d737..17dd8730 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -103,7 +103,7 @@ def fake_quantize( elif args.strategy == QuantizationStrategy.CHANNEL: # group_size == -1 # before: scale shape = [channel_size] # after: scale shape = [1, channel_size] - + breakpoint() scale = scale.unsqueeze(0) zero_point = zero_point.unsqueeze(0) @@ -114,6 +114,7 @@ def fake_quantize( elif args.strategy == QuantizationStrategy.TOKEN: # before: scale shape = [channel_size] # after: scale shape = [channel_size, 1] + scale = scale.unsqueeze(1) zero_point = zero_point.unsqueeze(1) @@ -145,6 +146,8 @@ def wrapped_forward(self, *args, **kwargs): if scheme.weights is not None: # calibrate and (fake) quantize weights when applicable unquantized_weight = self.weight.data.clone() + print(11111111) + breakpoint() self.weight.data = _maybe_calibrate_or_quantize( module, self.weight, "weight", scheme.weights ) @@ -194,6 +197,7 @@ def _maybe_calibrate_or_quantize( if module.quantization_status == QuantizationStatus.CALIBRATION: # calibration mode - get new quant params from observer observer = getattr(module, f"{base_name}_observer") + breakpoint() updated_scale, updated_zero_point = observer(value) # update scale and zero point diff --git a/src/compressed_tensors/quantization/observers/base.py b/src/compressed_tensors/quantization/observers/base.py index 5bb8703a..3edcbe83 100644 --- a/src/compressed_tensors/quantization/observers/base.py +++ b/src/compressed_tensors/quantization/observers/base.py @@ -88,18 +88,26 @@ def get_qparams( scale, zero_point = self.calculate_qparams( observed[:, i : (i + group_size)] ) + # 2048 x 16 scales.append(scale) zero_points.append(zero_point) + print(i, scales) - self._scale = torch.cat(scales) - self._zero_point = torch.cat(zero_points) + breakpoint() + self._scale = torch.stack(scales) + self._zero_point = torch.stack(zero_points) elif self.quantization_args.strategy 
== QuantizationStrategy.CHANNEL:
-            self._scale, self._zero_point = self.get_qparams_along_dim(observed, 1)
+            # assume observed is transposed, because its the output, hence use dim 0
+            self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0)

         elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:
+
+            # use dim 1, assume the observed.shape = [batch, token, hidden]
+            # should be batch, token
+
             self._scale, self._zero_point = self.get_qparams_along_dim(
-                observed, dim=0
+                observed, dim=1
             )
diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py
index 7fe8fcb7..140cfdb8 100644
--- a/src/compressed_tensors/quantization/quant_args.py
+++ b/src/compressed_tensors/quantization/quant_args.py
@@ -114,6 +114,8 @@ def validate_strategy(cls, value, values):
                     "group_size > 0 for strategy='group' and "
                     "group_size = -1 for 'channel'"
                 )
+            # breakpoint()
+            group_size = 128
         if value == QuantizationStrategy.GROUP:
             if group_size is None:
                 raise ValueError(f"strategy {value} requires group_size to be set.")
diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py
index 3c00cdbe..6f4925da 100644
--- a/src/compressed_tensors/quantization/utils/helpers.py
+++ b/src/compressed_tensors/quantization/utils/helpers.py
@@ -107,7 +107,9 @@ def calculate_compression_ratio(model: Module) -> float:
             uncompressed_bits = torch.iinfo(parameter.dtype).bits
             compressed_bits = uncompressed_bits
             if is_module_quantized(submodule):
-                compressed_bits = submodule.quantization_scheme.weights.num_bits
+                # compressed_bits = submodule.quantization_scheme.weights.num_bits
+                compressed_bits = 4
+
             num_weights = parameter.numel()
             total_compressed += compressed_bits * num_weights
             total_uncompressed += uncompressed_bits * num_weights

From 0fd1c8d4b7639dcb7a7b5553cd04c1262a132793 Mon Sep 17 00:00:00 2001
From: George Ohashi
Date: Wed, 1 May 2024 21:21:25 +0000
Subject: [PATCH 18/21] shape consistency

---
 .../quantization/lifecycle/forward.py | 23 +++++++++++--------
 .../quantization/observers/base.py    | 14 +++++------
 .../quantization/utils/helpers.py     |  3 +--
 3 files changed, 21 insertions(+), 19 deletions(-)

diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py
index 17dd8730..5046f507 100644
--- a/src/compressed_tensors/quantization/lifecycle/forward.py
+++ b/src/compressed_tensors/quantization/lifecycle/forward.py
@@ -36,6 +36,7 @@ def quantize(
     q_min: torch.Tensor,
     q_max: torch.Tensor,
 ) -> torch.Tensor:
+
     return torch.clamp(
         torch.round(x / scale + zero_point),
         q_min,
@@ -87,13 +88,22 @@ def fake_quantize(

         # TODO: vectorize the for loop
        # TODO: fix generic assumption about the tensor size for computing group
-        columns = x.shape[1]

         # TODO: make validation step for inputs
-        assert columns % group_size == 0
+
+        while scale.ndim < 2:
+            scale = scale.unsqueeze(1)
+            zero_point = zero_point.unsqueeze(1)
+
+        columns = x.shape[1]
+        if columns >= group_size:
+            assert columns % group_size == 0
         for i in range(ceil(columns / group_size)):
-            sc = scale[i]
-            zp = zero_point[i]
+            # scale.shape should be [nchan, ndim]
+            # sc.shape should be [nchan, 1] after unsqueeze
+            sc = scale[:, i].unsqueeze(1)
+            zp = zero_point[:, i].unsqueeze(1)

             idx = i * group_size
             Q = quantize(x[:, idx : (idx + group_size)], sc, zp, min_q, max_q)
@@ -103,7 +113,6 @@ def fake_quantize(

     # channel-wise
     elif args.strategy ==
QuantizationStrategy.CHANNEL: # group_size == -1 # before: scale shape = [channel_size] # after: scale shape = [1, channel_size] - breakpoint() scale = scale.unsqueeze(0) zero_point = zero_point.unsqueeze(0) @@ -146,8 +155,6 @@ def wrapped_forward(self, *args, **kwargs): if scheme.weights is not None: # calibrate and (fake) quantize weights when applicable unquantized_weight = self.weight.data.clone() - print(11111111) - breakpoint() self.weight.data = _maybe_calibrate_or_quantize( module, self.weight, "weight", scheme.weights ) @@ -197,12 +204,10 @@ def _maybe_calibrate_or_quantize( if module.quantization_status == QuantizationStatus.CALIBRATION: # calibration mode - get new quant params from observer observer = getattr(module, f"{base_name}_observer") - breakpoint() updated_scale, updated_zero_point = observer(value) # update scale and zero point device = next(module.parameters()).device scale.data = updated_scale.to(device) zero_point.data = updated_zero_point.to(device) - return fake_quantize(value, scale, zero_point, args) diff --git a/src/compressed_tensors/quantization/observers/base.py b/src/compressed_tensors/quantization/observers/base.py index 3edcbe83..87d7c0e2 100644 --- a/src/compressed_tensors/quantization/observers/base.py +++ b/src/compressed_tensors/quantization/observers/base.py @@ -85,17 +85,15 @@ def get_qparams( columns = observed.shape[1] scales, zero_points = [], [] for i in range(0, columns, self.quantization_args.group_size): - scale, zero_point = self.calculate_qparams( - observed[:, i : (i + group_size)] + scale, zero_point = self.get_qparams_along_dim( + observed[:, i : (i + group_size)], + 0, ) - # 2048 x 16 scales.append(scale) zero_points.append(zero_point) - print(i, scales) - breakpoint() - self._scale = torch.stack(scales) - self._zero_point = torch.stack(zero_points) + self._scale = torch.stack(scales, dim=1) + self._zero_point = torch.stack(zero_points, dim=1) elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL: # assume observed is transposed, because its the output, hence use dim 0 @@ -128,5 +126,5 @@ def get_qparams_along_dim(self, observed, dim: int): scales.append(scale) zero_points.append(zero_point) - + # breakpoint() return torch.stack(scales), torch.stack(zero_points) diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index 6f4925da..8676ef15 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -107,8 +107,7 @@ def calculate_compression_ratio(model: Module) -> float: uncompressed_bits = torch.iinfo(parameter.dtype).bits compressed_bits = uncompressed_bits if is_module_quantized(submodule): - # compressed_bits = submodule.quantization_scheme.weights.num_bits - compressed_bits = 4 + compressed_bits = submodule.quantization_scheme.weights.num_bits num_weights = parameter.numel() total_compressed += compressed_bits * num_weights From e62de872212e68c696482ca60793decd9b6236ba Mon Sep 17 00:00:00 2001 From: George Date: Thu, 2 May 2024 12:31:58 -0400 Subject: [PATCH 19/21] Update src/compressed_tensors/quantization/lifecycle/forward.py Co-authored-by: Benjamin Fineran --- src/compressed_tensors/quantization/lifecycle/forward.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index 5046f507..db2ca074 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ 
b/src/compressed_tensors/quantization/lifecycle/forward.py
@@ -92,6 +92,7 @@ def fake_quantize(
         # TODO: make validation step for inputs

         while scale.ndim < 2:
+            # pad scale and zero point dims for slicing
             scale = scale.unsqueeze(1)
             zero_point = zero_point.unsqueeze(1)

From e929df2fbb1a6fe1738d09ce79d70ebdcb86a73d Mon Sep 17 00:00:00 2001
From: George Ohashi
Date: Thu, 2 May 2024 17:24:15 +0000
Subject: [PATCH 20/21] comments

---
 .../quantization/lifecycle/forward.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py
index db2ca074..2c393c70 100644
--- a/src/compressed_tensors/quantization/lifecycle/forward.py
+++ b/src/compressed_tensors/quantization/lifecycle/forward.py
@@ -98,11 +98,15 @@ def fake_quantize(

         columns = x.shape[1]
         if columns >= group_size:
-            assert columns % group_size == 0
+            if columns % group_size != 0:
+                raise ValueError(
+                    "tensor column shape must be divisible "
+                    f"by the given group_size {group_size}"
+                )
         for i in range(ceil(columns / group_size)):
-            # scale.shape should be [nchan, ndim]
             # sc.shape should be [nchan, 1] after unsqueeze
+
             sc = scale[:, i].unsqueeze(1)
             zp = zero_point[:, i].unsqueeze(1)
@@ -122,8 +126,10 @@ def fake_quantize(

     # per-token
     elif args.strategy == QuantizationStrategy.TOKEN:
-        # before: scale shape = [channel_size]
-        # after: scale shape = [channel_size, 1]
+        # before: scale shape = [num_tokens]
+        # after: scale shape = [num_tokens, 1]
+        # x.shape = [1, num_tokens, 1]
+        # scale gets broadcasted as expected without having [1, num_tokens, 1] shape

         scale = scale.unsqueeze(1)
         zero_point = zero_point.unsqueeze(1)

From 1229c5a344310d69e0173316cc5199bf4e3df294 Mon Sep 17 00:00:00 2001
From: George Ohashi
Date: Fri, 3 May 2024 15:36:37 +0000
Subject: [PATCH 21/21] pass test_quant_args

---
 src/compressed_tensors/quantization/quant_args.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py
index 140cfdb8..f8c82d8a 100644
--- a/src/compressed_tensors/quantization/quant_args.py
+++ b/src/compressed_tensors/quantization/quant_args.py
@@ -64,8 +64,8 @@ class QuantizationArgs(BaseModel):
     num_bits: int = 8
     type: QuantizationType = QuantizationType.INT
     symmetric: bool = True
-    strategy: Optional[QuantizationStrategy] = None
     group_size: Optional[int] = None
+    strategy: Optional[QuantizationStrategy] = None
     block_structure: Optional[str] = None
     dynamic: bool = False
     observer: str = Field(
@@ -96,7 +96,7 @@ def get_observer(self):

         return Observer.load_from_registry(self.observer, quantization_args=self)

-    @validator("strategy", pre=True)
+    @validator("strategy", pre=True, always=True)
     def validate_strategy(cls, value, values):
         group_size = values.get("group_size")

@@ -114,8 +114,7 @@ def validate_strategy(cls, value, values):
                     "group_size > 0 for strategy='group' and "
                     "group_size = -1 for 'channel'"
                 )
-        # breakpoint()
-        group_size = 128
+
         if value == QuantizationStrategy.GROUP:
             if group_size is None:
                 raise ValueError(f"strategy {value} requires group_size to be set.")
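
Taken together, the series ends with strategy-driven quantization parameters: validate_strategy infers QuantizationStrategy.GROUP for group_size > 0, CHANNEL for group_size == -1, and TENSOR when neither strategy nor group_size is given, while fake_quantize dispatches on args.strategy and slices the [num_channels, num_groups] scales into [num_channels, 1] columns per group. The standalone script below illustrates just the group-strategy round trip. It is a minimal sketch, not the compressed-tensors API: it assumes symmetric int8 min-max observation, it collapses each group to a single scale/zero-point pair for brevity (the series ends with per-row parameters per group), and the helper names minmax_qparams and fake_quantize_group are illustrative stand-ins.

import torch


def minmax_qparams(t: torch.Tensor, num_bits: int = 8):
    # Symmetric qparams from the observed min/max; zero is kept inside the
    # observed range, mirroring the clamping done by the observers above.
    min_val = torch.min(t.min(), torch.zeros(()))
    max_val = torch.max(t.max(), torch.zeros(()))
    bit_range = 2**num_bits - 1
    scale = 2 * torch.max(min_val.abs(), max_val.abs()) / bit_range
    zero_point = torch.zeros((), dtype=torch.int8)
    return scale, zero_point


def fake_quantize_group(x: torch.Tensor, group_size: int, num_bits: int = 8):
    # Group strategy: independent quantize/dequantize per group_size-wide
    # column slice (the per-group loop the patches leave a TODO to vectorize).
    columns = x.shape[1]
    if columns % group_size != 0:
        raise ValueError("column count must be divisible by group_size")
    q_max = 2 ** (num_bits - 1) - 1
    q_min = -(2 ** (num_bits - 1))
    dq = torch.zeros_like(x)
    for i in range(columns // group_size):
        xg = x[:, i * group_size : (i + 1) * group_size]
        scale, zp = minmax_qparams(xg, num_bits)
        q = torch.clamp(torch.round(xg / scale + zp), q_min, q_max)
        dq[:, i * group_size : (i + 1) * group_size] = (q - zp) * scale
    return dq


x = torch.randn(4, 8)
print((x - fake_quantize_group(x, group_size=4)).abs().max())  # small round-trip error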