Skip to content

Commit

Permalink
Last fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Dec 19, 2024
1 parent b8b08d1 commit 2ac05a1
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 8 deletions.
9 changes: 6 additions & 3 deletions src/brevitas/export/inference/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector
from brevitas.proxy.runtime_quant import DynamicActQuantProxyFromInjector
from brevitas.quant.experimental.mx_quant_ocp import GroupwiseActQuantProxyFromInjector
from brevitas.utils.quant_utils import groupwise_dequant_expand
from brevitas.utils.torch_utils import float_internal_scale


Expand Down Expand Up @@ -139,8 +140,8 @@ def __init__(self):
def prepare_for_export(self, module):
super().prepare_for_export(module)
if module.is_quant_enabled:
self.group_dim = module.group_dim
self.input_view = module.input_view_impl
self.flattened_view = module.apply_input_view
if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only:
self.cached_weight = module._cached_weight.quant_tensor.value_
else:
Expand All @@ -158,12 +159,13 @@ def forward(self, x: Tensor) -> Tuple[Tensor]:
if self.cached_weight is not None:
out = self.cached_weight
else:
inp_shape = x.shape
x = self.input_view(x)
out = self.dequantize(self.quantize(x, scale, zero_point), scale, zero_point)

# If we skip quant tensor, we return the flattened version of the groupwise tensor
if self.skip_create_quant_tensor:
out = self.flattened_view(out)
out = groupwise_dequant_expand(out, scale, zero_point, self.group_dim, inp_shape)[0]
return out, scale, zero_point, self.bit_width


Expand Down Expand Up @@ -302,11 +304,12 @@ def forward(self, x: Tensor) -> Tuple[Tensor]:
if self.cached_weight is not None:
out = self.cached_weight
else:
inp_shape = x.shape
x = self.input_view(x)
out = self.dequantize(self.quantize(x, scale, zero_point), scale, zero_point)

# If we skip quant tensor, we return the flattened version of the groupwise tensor
if self.skip_create_quant_tensor:
out = self.flattened_view(out)
out = groupwise_dequant_expand(out, scale, zero_point, self.group_dim, inp_shape)[0]

return out, scale, zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values
4 changes: 2 additions & 2 deletions src/brevitas/quant_tensor/groupwise_float_quant_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
return func(*args, **kwargs)

def expand(self):
    """Return the dequantized tensor expanded back to its pre-grouping shape.

    Delegates to ``groupwise_dequant_expand``, which broadcasts the per-group
    ``scale_`` / ``zero_point_`` along ``group_dim`` and reshapes the flattened
    groupwise value to ``dequant_shape``.
    """
    # NOTE: local import avoids a circular dependency between the
    # quant_tensor and quant_utils modules at import time.
    from brevitas.utils.quant_utils import groupwise_dequant_expand
    return groupwise_dequant_expand(
        self.value_, self.scale_, self.zero_point_, self.group_dim, self.dequant_shape)

@staticmethod
Expand Down
4 changes: 2 additions & 2 deletions src/brevitas/quant_tensor/groupwise_int_quant_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
return func(*args, **kwargs)

def expand(self):
    """Return the dequantized tensor expanded back to its pre-grouping shape.

    Delegates to ``groupwise_dequant_expand``, which broadcasts the per-group
    ``scale_`` / ``zero_point_`` along ``group_dim`` and reshapes the flattened
    groupwise value to ``dequant_shape``.
    """
    # NOTE: local import avoids a circular dependency between the
    # quant_tensor and quant_utils modules at import time.
    from brevitas.utils.quant_utils import groupwise_dequant_expand
    return groupwise_dequant_expand(
        self.value_, self.scale_, self.zero_point_, self.group_dim, self.dequant_shape)

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion src/brevitas/utils/quant_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def float_to_int_impl_to_enum(module):
return None


def groupwise_dequant(value_, scale_, zero_point_, group_dim, dequant_shape):
def groupwise_dequant_expand(value_, scale_, zero_point_, group_dim, dequant_shape):
final_shape = dequant_shape
curr_shape = value_.shape
start_dim = group_dim if group_dim != -1 else -2
Expand Down

0 comments on commit 2ac05a1

Please sign in to comment.