Update
Giuseppe5 committed Dec 17, 2024
1 parent 895ce80 commit b8c4877
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions src/brevitas/export/inference/handler.py
@@ -127,6 +127,10 @@ def forward(self, x: Tensor, unused_scale: Tensor = None) -> Tuple[Tensor]:
 class GroupwiseIntWeightInferenceHandler(IntWeightInferencetHandler):
     handled_layer = GroupwiseWeightQuantProxyFromInjector
 
+    def __init__(self):
+        super().__init__()
+        self.skip_create_quant_tensor = False
+
     def prepare_for_export(self, module):
         super().prepare_for_export(module)
         if module.is_quant_enabled:
@@ -151,7 +155,9 @@ def forward(self, x: Tensor) -> Tuple[Tensor]:
         else:
             x = self.input_view(x)
         out = self.dequantize(self.quantize(x, scale, zero_point), scale, zero_point)
-        if is_dynamo_compiling():
+
+        # If we skip quant tensor, we return the flattened version of the groupwise tensor
+        if self.skip_create_quant_tensor:
             out = self.flattened_view(out)
         return out, scale, zero_point, self.bit_width
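
Note (not part of this commit): when skip_create_quant_tensor is set, forward returns the dequantized groupwise tensor in its flattened layout instead of a quant tensor. A minimal sketch of what that flattening amounts to, assuming flattened_view collapses the expanded (out_features, groups, group_size) layout back to the original 2D weight shape:

import torch

# Hypothetical shapes, for illustration only
out_features, groups, group_size = 8, 4, 16
grouped = torch.randn(out_features, groups, group_size)  # dequantized groupwise output

# Assumed equivalent of self.flattened_view(out): collapse the group dimensions
flattened = grouped.reshape(out_features, groups * group_size)
assert flattened.shape == (8, 64)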

@@ -259,6 +265,10 @@ def forward(self, x: Tensor) -> Tuple[Tensor]:
 class GroupwiseFloatWeightInferenceHandler(FloatWeightInferencetHandler):
     handled_layer = GroupwiseWeightFloatQuantProxyFromInjector
 
+    def __init__(self):
+        super().__init__()
+        self.skip_create_quant_tensor = False
+
     def prepare_for_export(self, module: nn.Module):
         super().prepare_for_export(module)
         if module.is_quant_enabled:
@@ -283,6 +293,9 @@ def forward(self, x: Tensor) -> Tuple[Tensor]:
         else:
             x = self.input_view(x)
         out = self.dequantize(self.quantize(x, scale, zero_point), scale, zero_point)
-        if is_dynamo_compiling():
+
+        # If we skip quant tensor, we return the flattened version of the groupwise tensor
+        if self.skip_create_quant_tensor:
             out = self.flattened_view(out)
+
         return out, scale, zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values
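
Note (not part of this commit): with the is_dynamo_compiling() check gone, the caller decides when quant tensor creation is skipped. A hedged sketch of how an export flow might toggle the new flag on every inference handler before compilation, assuming the handlers are reachable as submodules of the prepared model:

import torch

def set_skip_create_quant_tensor(model: torch.nn.Module, skip: bool = True) -> None:
    # Hypothetical helper: flip the flag on any module that exposes it,
    # e.g. right before running the model under torch.compile.
    for m in model.modules():
        if hasattr(m, "skip_create_quant_tensor"):
            m.skip_create_quant_tensor = skip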
