Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Dec 19, 2024
1 parent ae52c79 commit b8b08d1
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 78 deletions.
2 changes: 1 addition & 1 deletion src/brevitas/proxy/parameter_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]:
out.detach(),
metadata_only=self.cache_inference_quant_weight_metadata_only)
else: # quantization disabled
out = self.apply_input_view(x)
out = x
return out


Expand Down
3 changes: 1 addition & 2 deletions src/brevitas/proxy/runtime_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,7 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
elif not self.is_quant_enabled:
# A tuple helps later with control flows
# The second None value is used later
# If quant is not enabled, we still apply input_view in the case of groupwise + padding
y = self.apply_input_view(self.fused_activation_quant_proxy.activation_impl(y))
y = self.fused_activation_quant_proxy.activation_impl(y)
y = (y, None)
else:
y = self.fused_activation_quant_proxy(y)
Expand Down
33 changes: 3 additions & 30 deletions src/brevitas/quant_tensor/groupwise_float_quant_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,36 +91,9 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
return func(*args, **kwargs)

def expand(self):
final_shape = self.dequant_shape
curr_shape = self.value_.shape
start_dim = self.group_dim if self.group_dim != -1 else -2
new_value = self.value_.flatten(start_dim, start_dim + 1)
if self.scale_.shape != ():
new_scale = self.scale_.expand(curr_shape).flatten(start_dim, start_dim + 1)
else:
new_scale = self.scale_
if self.zero_point_.shape != ():
new_zp = self.zero_point_.expand(curr_shape).flatten(start_dim, start_dim + 1)
else:
new_zp = self.zero_point_

# If we padded during quantization, we unpad here:
# First, we compute how much we padded along the group_dim shape
# Then, we unbind the tensor along the group_dim shape, and drop the padded columns
# Finally, we stack the remaining tensors
unpadding_shape = final_shape[self.group_dim]
residual = new_value.shape[self.group_dim] - unpadding_shape

if residual > 0:
new_value = torch.stack(
torch.unbind(new_value, dim=self.group_dim)[:unpadding_shape], dim=self.group_dim)
new_scale = torch.stack(
torch.unbind(new_scale, dim=self.group_dim)[:unpadding_shape], dim=self.group_dim)
if self.zero_point_.shape != ():
new_zp = torch.stack(
torch.unbind(new_zp, dim=self.group_dim)[:unpadding_shape], dim=self.group_dim)

return new_value, new_scale, new_zp
from brevitas.utils.quant_utils import groupwise_dequant
return groupwise_dequant(
self.value_, self.scale_, self.zero_point_, self.group_dim, self.dequant_shape)

@staticmethod
def from_expanded(value, group_size, group_dim, compress=False):
Expand Down
33 changes: 3 additions & 30 deletions src/brevitas/quant_tensor/groupwise_int_quant_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,36 +77,9 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
return func(*args, **kwargs)

def expand(self):
final_shape = self.dequant_shape
curr_shape = self.value_.shape
start_dim = self.group_dim if self.group_dim != -1 else -2
new_value = self.value_.flatten(start_dim, start_dim + 1)
if self.scale_.shape != ():
new_scale = self.scale_.expand(curr_shape).flatten(start_dim, start_dim + 1)
else:
new_scale = self.scale_
if self.zero_point_.shape != ():
new_zp = self.zero_point_.expand(curr_shape).flatten(start_dim, start_dim + 1)
else:
new_zp = self.zero_point_

# If we padded during quantization, we unpad here:
# First, we compute how much we padded along the group_dim shape
# Then, we unbind the tensor along the group_dim shape, and drop the padded columns
# Finally, we stack the remaining tensors
unpadding_shape = final_shape[self.group_dim]
residual = new_value.shape[self.group_dim] - unpadding_shape

if residual > 0:
new_value = torch.stack(
torch.unbind(new_value, dim=self.group_dim)[:unpadding_shape], dim=self.group_dim)
new_scale = torch.stack(
torch.unbind(new_scale, dim=self.group_dim)[:unpadding_shape], dim=self.group_dim)
if self.zero_point_.shape != ():
new_zp = torch.stack(
torch.unbind(new_zp, dim=self.group_dim)[:unpadding_shape], dim=self.group_dim)

return new_value, new_scale, new_zp
from brevitas.utils.quant_utils import groupwise_dequant
return groupwise_dequant(
self.value_, self.scale_, self.zero_point_, self.group_dim, self.dequant_shape)

@staticmethod
def from_expanded(value, group_size, group_dim, compress=False):
Expand Down
35 changes: 35 additions & 0 deletions src/brevitas/utils/quant_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

import torch

from brevitas.core.bit_width import BitWidthParameter
from brevitas.core.function_wrapper import *
from brevitas.core.quant import RescalingIntQuant
Expand Down Expand Up @@ -221,3 +223,36 @@ def float_to_int_impl_to_enum(module):
return FloatToIntImplType.STOCHASTIC_ROUND
else:
return None


def groupwise_dequant(value_, scale_, zero_point_, group_dim, dequant_shape):
final_shape = dequant_shape
curr_shape = value_.shape
start_dim = group_dim if group_dim != -1 else -2
new_value = value_.flatten(start_dim, start_dim + 1)
if scale_.shape != ():
new_scale = scale_.expand(curr_shape).flatten(start_dim, start_dim + 1)
else:
new_scale = scale_
if zero_point_.shape != ():
new_zp = zero_point_.expand(curr_shape).flatten(start_dim, start_dim + 1)
else:
new_zp = zero_point_

# If we padded during quantization, we unpad here:
# First, we compute how much we padded along the group_dim shape
# Then, we unbind the tensor along the group_dim shape, and drop the padded columns
# Finally, we stack the remaining tensors
unpadding_shape = final_shape[group_dim]
residual = new_value.shape[group_dim] - unpadding_shape

if residual > 0:
new_value = torch.stack(
torch.unbind(new_value, dim=group_dim)[:unpadding_shape], dim=group_dim)
new_scale = torch.stack(
torch.unbind(new_scale, dim=group_dim)[:unpadding_shape], dim=group_dim)
if zero_point_.shape != ():
new_zp = torch.stack(
torch.unbind(new_zp, dim=group_dim)[:unpadding_shape], dim=group_dim)

return new_value, new_scale, new_zp
20 changes: 5 additions & 15 deletions tests/brevitas/graph/test_gpxq.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,18 +89,8 @@ def test_toymodels(toy_quant_model, act_order, use_quant_activations, apply_gpxq
dataset = TensorDataset(inp, inp)
calib_loader = DataLoader(dataset, batch_size=16, num_workers=0, pin_memory=True, shuffle=True)

if ((name == 'gptq' or name == 'gpfq2') and torch_version < version.parse('1.10')):
# Usage of linalg_cholesky() is not compatible with torch 1.9.1 and below
with pytest.raises(AssertionError):
apply_gpxq(
calib_loader=calib_loader,
model=model,
act_order=act_order,
use_quant_activations=use_quant_activations)

else:
apply_gpxq(
calib_loader=calib_loader,
model=model,
act_order=act_order,
use_quant_activations=use_quant_activations)
apply_gpxq(
calib_loader=calib_loader,
model=model,
act_order=act_order,
use_quant_activations=use_quant_activations)

0 comments on commit b8b08d1

Please sign in to comment.