diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py
index 5c4e447d4..805d66bae 100644
--- a/src/brevitas/proxy/parameter_quant.py
+++ b/src/brevitas/proxy/parameter_quant.py
@@ -148,7 +148,7 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]:
                     out.detach(), metadata_only=self.cache_inference_quant_weight_metadata_only)
         else:  # quantization disabled
-            out = self.apply_input_view(x)
+            out = x
         return out
diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py
index 6cd2c03ed..95949add7 100644
--- a/src/brevitas/proxy/runtime_quant.py
+++ b/src/brevitas/proxy/runtime_quant.py
@@ -180,8 +180,7 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
         elif not self.is_quant_enabled:
             # A tuple helps later with control flows
             # The second None value is used later
-            # If quant is not enabled, we still apply input_view in the case of groupwise + padding
-            y = self.apply_input_view(self.fused_activation_quant_proxy.activation_impl(y))
+            y = self.fused_activation_quant_proxy.activation_impl(y)
             y = (y, None)
         else:
             y = self.fused_activation_quant_proxy(y)
diff --git a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py
index f5be0ec67..8166f1b15 100644
--- a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py
+++ b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py
@@ -91,36 +91,9 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
         return func(*args, **kwargs)
 
     def expand(self):
-        final_shape = self.dequant_shape
-        curr_shape = self.value_.shape
-        start_dim = self.group_dim if self.group_dim != -1 else -2
-        new_value = self.value_.flatten(start_dim, start_dim + 1)
-        if self.scale_.shape != ():
-            new_scale = self.scale_.expand(curr_shape).flatten(start_dim, start_dim + 1)
-        else:
-            new_scale = self.scale_
-        if self.zero_point_.shape != ():
-            new_zp = self.zero_point_.expand(curr_shape).flatten(start_dim, start_dim + 1)
-        else:
-            new_zp = self.zero_point_
-
-        # If we padded during quantization, we unpad here:
-        # First, we compute how much we padded along the group_dim shape
-        # Then, we unbind the tensor along the group_dim shape, and drop the padded columns
-        # Finally, we stack the remaining tensors
-        unpadding_shape = final_shape[self.group_dim]
-        residual = new_value.shape[self.group_dim] - unpadding_shape
-
-        if residual > 0:
-            new_value = torch.stack(
-                torch.unbind(new_value, dim=self.group_dim)[:unpadding_shape], dim=self.group_dim)
-            new_scale = torch.stack(
-                torch.unbind(new_scale, dim=self.group_dim)[:unpadding_shape], dim=self.group_dim)
-            if self.zero_point_.shape != ():
-                new_zp = torch.stack(
-                    torch.unbind(new_zp, dim=self.group_dim)[:unpadding_shape], dim=self.group_dim)
-
-        return new_value, new_scale, new_zp
+        from brevitas.utils.quant_utils import groupwise_dequant
+        return groupwise_dequant(
+            self.value_, self.scale_, self.zero_point_, self.group_dim, self.dequant_shape)
 
     @staticmethod
     def from_expanded(value, group_size, group_dim, compress=False):
diff --git a/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py
index c92cc01fd..7d97ad4cb 100644
--- a/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py
+++ b/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py
@@ -77,36 +77,9 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
         return func(*args, **kwargs)
 
     def expand(self):
-        final_shape = self.dequant_shape
-        curr_shape = self.value_.shape
-        start_dim = self.group_dim if self.group_dim != -1 else -2
-        new_value = self.value_.flatten(start_dim, start_dim + 1)
-        if self.scale_.shape != ():
-            new_scale = self.scale_.expand(curr_shape).flatten(start_dim, start_dim + 1)
-        else:
-            new_scale = self.scale_
-        if self.zero_point_.shape != ():
-            new_zp = self.zero_point_.expand(curr_shape).flatten(start_dim, start_dim + 1)
-        else:
-            new_zp = self.zero_point_
-
-        # If we padded during quantization, we unpad here:
-        # First, we compute how much we padded along the group_dim shape
-        # Then, we unbind the tensor along the group_dim shape, and drop the padded columns
-        # Finally, we stack the remaining tensors
-        unpadding_shape = final_shape[self.group_dim]
-        residual = new_value.shape[self.group_dim] - unpadding_shape
-
-        if residual > 0:
-            new_value = torch.stack(
-                torch.unbind(new_value, dim=self.group_dim)[:unpadding_shape], dim=self.group_dim)
-            new_scale = torch.stack(
-                torch.unbind(new_scale, dim=self.group_dim)[:unpadding_shape], dim=self.group_dim)
-            if self.zero_point_.shape != ():
-                new_zp = torch.stack(
-                    torch.unbind(new_zp, dim=self.group_dim)[:unpadding_shape], dim=self.group_dim)
-
-        return new_value, new_scale, new_zp
+        from brevitas.utils.quant_utils import groupwise_dequant
+        return groupwise_dequant(
+            self.value_, self.scale_, self.zero_point_, self.group_dim, self.dequant_shape)
 
     @staticmethod
     def from_expanded(value, group_size, group_dim, compress=False):
diff --git a/src/brevitas/utils/quant_utils.py b/src/brevitas/utils/quant_utils.py
index 62290b1de..7adea43d9 100644
--- a/src/brevitas/utils/quant_utils.py
+++ b/src/brevitas/utils/quant_utils.py
@@ -1,6 +1,8 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
+import torch
+
 from brevitas.core.bit_width import BitWidthParameter
 from brevitas.core.function_wrapper import *
 from brevitas.core.quant import RescalingIntQuant
@@ -221,3 +223,36 @@ def float_to_int_impl_to_enum(module):
         return FloatToIntImplType.STOCHASTIC_ROUND
     else:
         return None
+
+
+def groupwise_dequant(value_, scale_, zero_point_, group_dim, dequant_shape):
+    final_shape = dequant_shape
+    curr_shape = value_.shape
+    start_dim = group_dim if group_dim != -1 else -2
+    new_value = value_.flatten(start_dim, start_dim + 1)
+    if scale_.shape != ():
+        new_scale = scale_.expand(curr_shape).flatten(start_dim, start_dim + 1)
+    else:
+        new_scale = scale_
+    if zero_point_.shape != ():
+        new_zp = zero_point_.expand(curr_shape).flatten(start_dim, start_dim + 1)
+    else:
+        new_zp = zero_point_
+
+    # If we padded during quantization, we unpad here:
+    # First, we compute how much we padded along the group_dim shape
+    # Then, we unbind the tensor along the group_dim shape, and drop the padded columns
+    # Finally, we stack the remaining tensors
+    unpadding_shape = final_shape[group_dim]
+    residual = new_value.shape[group_dim] - unpadding_shape
+
+    if residual > 0:
+        new_value = torch.stack(
+            torch.unbind(new_value, dim=group_dim)[:unpadding_shape], dim=group_dim)
+        new_scale = torch.stack(
+            torch.unbind(new_scale, dim=group_dim)[:unpadding_shape], dim=group_dim)
+        if zero_point_.shape != ():
+            new_zp = torch.stack(
+                torch.unbind(new_zp, dim=group_dim)[:unpadding_shape], dim=group_dim)
+
+    return new_value, new_scale, new_zp
diff --git a/tests/brevitas/graph/test_gpxq.py b/tests/brevitas/graph/test_gpxq.py
index 33116332c..aa2ec9f97 100644
--- a/tests/brevitas/graph/test_gpxq.py
+++ b/tests/brevitas/graph/test_gpxq.py
@@ -89,18 +89,8 @@ def test_toymodels(toy_quant_model, act_order, use_quant_activations, apply_gpxq
     dataset = TensorDataset(inp, inp)
     calib_loader = DataLoader(dataset, batch_size=16, num_workers=0, pin_memory=True, shuffle=True)
 
-    if ((name == 'gptq' or name == 'gpfq2') and torch_version < version.parse('1.10')):
-        # Usage of linalg_cholesky() is not compatible with torch 1.9.1 and below
-        with pytest.raises(AssertionError):
-            apply_gpxq(
-                calib_loader=calib_loader,
-                model=model,
-                act_order=act_order,
-                use_quant_activations=use_quant_activations)
-
-    else:
-        apply_gpxq(
-            calib_loader=calib_loader,
-            model=model,
-            act_order=act_order,
-            use_quant_activations=use_quant_activations)
+    apply_gpxq(
+        calib_loader=calib_loader,
+        model=model,
+        act_order=act_order,
+        use_quant_activations=use_quant_activations)
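Note for reviewers: below is a minimal sketch of how the extracted `groupwise_dequant` helper behaves. The shapes are illustrative assumptions, not taken from the diff: a hypothetical (4, 10) weight quantized with group_size=4 along dim 1, which gets padded to 12 columns and stored groupwise as (4, 3, 4).

```python
import torch

from brevitas.utils.quant_utils import groupwise_dequant

group_dim = 1
value = torch.randn(4, 3, 4)  # (rows, num_groups, group_size), includes 2 padded columns
scale = torch.rand(4, 3, 1)  # one scale per group, broadcast over the group
zero_point = torch.tensor(0.)  # scalar zero point, returned untouched by the helper
dequant_shape = (4, 10)  # original, unpadded shape

new_value, new_scale, new_zp = groupwise_dequant(
    value, scale, zero_point, group_dim, dequant_shape)

assert new_value.shape == dequant_shape  # groups flattened, padded columns dropped
assert new_scale.shape == dequant_shape  # scale expanded to one entry per element
```

Since both groupwise quant tensor classes now delegate to this helper, the flatten/unpad logic lives in a single place instead of being duplicated across the float and int variants.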