Commit

float Groupwise
Giuseppe5 committed Dec 14, 2024
1 parent dadfa1e commit ed6a20a
Showing 1 changed file with 58 additions and 14 deletions.
72 changes: 58 additions & 14 deletions src/brevitas/export/inference/handler.py
@@ -14,6 +14,9 @@
from brevitas.proxy.float_parameter_quant import WeightFloatQuantProxyFromInjector
from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjector
from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjectorBase
from brevitas.proxy.groupwise_float_parameter_quant import \
GroupwiseWeightFloatQuantProxyFromInjector
from brevitas.proxy.groupwise_float_runtime_quant import GroupwiseActFloatQuantProxyFromInjector
from brevitas.proxy.groupwise_int_parameter_quant import GroupwiseWeightQuantProxyFromInjector
from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector
from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector
@@ -92,6 +95,48 @@ def forward(self, x) -> Tuple[torch.Tensor]:
return x, self.scale, self.zero_point, self.bit_width


class GroupwiseIntInferenceHandler(IntInferencetHandler):
handled_layer = GroupwiseActQuantProxyFromInjector
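    # Activation variant: the fused activation proxy performs the actual
    # quant/dequant; this handler only reshapes the grouped output for export.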

def prepare_for_export(self, module):
super().prepare_for_export(module)
if module.is_quant_enabled:
self.module_forward = module.fused_activation_quant_proxy
self.group_dim = module.group_dim

def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]:
x, *other = self.module_forward(x)
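        # Under torch.compile/export tracing, collapse the (groups, group size)
        # pair of dims created by the groupwise view back into one dim; a
        # group_dim of -1 maps to -2 because grouping appends a trailing
        # group-size dim.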
if is_dynamo_compiling:
start_dim = self.group_dim if self.group_dim != -1 else -2
x = x.flatten(start_dim, start_dim + 1)
output_args = tuple([x] + list(other))
return output_args


class GroupwiseIntWeightInferenceHandler(IntWeightInferencetHandler):
handled_layer = GroupwiseWeightQuantProxyFromInjector

def prepare_for_export(self, module):
super().prepare_for_export(module)
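        # The proxy's view helpers: input_view_impl reshapes a tensor into the
        # groupwise view, while apply_input_view restores the flat layout used
        # for export under dynamo.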
self.input_view = module.input_view_impl
self.flattened_view = module.apply_input_view

def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]:
x = self.input_view(x)
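        # A scalar scale/zero_point broadcasts as-is; tensor-valued ones must be
        # reshaped to the same groupwise view as the weight.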
if self.scale.shape != ():
scale = self.input_view(self.scale)
else:
scale = self.scale
if self.zero_point.shape != ():
zero_point = self.input_view(self.zero_point)
else:
zero_point = self.zero_point
out = self.dequantize(self.quantize(x, scale, zero_point), scale, zero_point)
if is_dynamo_compiling:
out = self.flattened_view(out)
return out, scale, zero_point, self.bit_width


class FloatInferencetHandler(InferenceHandler):
handled_layer = (ActFloatQuantProxyFromInjector, BiasQuantProxyFromInjector)

@@ -121,13 +166,13 @@ def prepare_for_export(self, module):
self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias)
self.min_value = torch.tensor(0.) if not module.is_signed else -self.max_value

def quantize(self, x):
def quantize(self, x, scale, zero_point):
# Compute masks
inf_mask = x.isinf()
p_max_val_mask = x > self.max_value
n_max_val_mask = -x > self.max_value
# Quantize
x = x / self.scale
x = x / scale
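        # float_internal_scale produces a per-element power-of-two scale so the
        # rounding below happens at the precision of the target mantissa width.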
internal_scale = float_internal_scale(
x, self.mantissa_bit_width, self.fp_internal_scale_min, self.eps)
x = internal_scale * self.float_to_int_impl(x / internal_scale)
@@ -139,11 +184,11 @@

return x

def dequantize(self, x):
return (x - self.zero_point) * self.scale
def dequantize(self, x, scale, zero_point):
return (x - zero_point) * scale

def forward(self, x) -> Tuple[torch.Tensor]:
return self.dequantize(self.quantize(x)), self.scale, self.zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values
return self.dequantize(self.quantize(x, self.scale, self.zero_point), self.scale, self.zero_point), self.scale, self.zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values


class FloatWeightInferencetHandler(FloatInferencetHandler):
@@ -160,19 +205,19 @@ def forward(self, x) -> Tuple[torch.Tensor]:
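        # Prefer the cached dequantized weight when available; otherwise re-run
        # the quantize/dequantize pair with the cached scale and zero-point.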
if self.cached_weight is not None:
x = self.cached_weight
else:
x = self.dequantize(self.quantize(x))
x = self.dequantize(
self.quantize(x, self.scale, self.zero_point), self.scale, self.zero_point)
return x, self.scale, self.zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values


class GroupwiseIntInferenceHandler(IntInferencetHandler):
handled_layer = GroupwiseActQuantProxyFromInjector
class GroupwiseFloatInferenceHandler(FloatInferencetHandler):
handled_layer = GroupwiseActFloatQuantProxyFromInjector

def prepare_for_export(self, module):
super().prepare_for_export(module)
if module.is_quant_enabled:
self.module_forward = module.fused_activation_quant_proxy
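            # Keep both view helpers: input_view to reshape per-group metadata,
            # apply_input_view (bound as flattened_view) to restore the flat
            # output layout for export.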
self.flattened_view = module.apply_input_view
self.input_view = module.input_view_impl
self.group_dim = module.group_dim
self.group_dim = module.group_dim

def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]:
x, *other = self.module_forward(x)
@@ -183,8 +228,8 @@ def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]:
return output_args


class GroupwiseIntWeightInferenceHandler(IntWeightInferencetHandler):
handled_layer = GroupwiseWeightQuantProxyFromInjector
class GroupwiseFloatWeightInferenceHandler(FloatWeightInferencetHandler):
handled_layer = GroupwiseWeightFloatQuantProxyFromInjector

def prepare_for_export(self, module):
super().prepare_for_export(module)
@@ -197,7 +242,6 @@ def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]:
scale = self.input_view(self.scale)
else:
scale = self.scale

if self.zero_point.shape != ():
zero_point = self.input_view(self.zero_point)
else:
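For orientation, a minimal usage sketch (not part of this commit): it assumes brevitas' quant_inference_mode manager from brevitas.export.inference, which swaps quant proxies for the inference handlers defined in this file; the layer and shapes below are illustrative.

import torch
from brevitas.export.inference import quant_inference_mode
from brevitas.nn import QuantLinear

model = QuantLinear(16, 32, bias=False)  # any quantized Brevitas layer
model.eval()
x = torch.randn(8, 16)
# Inside the context, forwards run through the handlers above; a groupwise
# weight quantizer would route through the Groupwise* handlers added here.
with torch.no_grad(), quant_inference_mode(model):
    y = model(x)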
