Expand inference_mode compatibility
Giuseppe5 committed Dec 15, 2024
1 parent dd80cf4 commit b08f120
Showing 2 changed files with 24 additions and 5 deletions.
25 changes: 21 additions & 4 deletions src/brevitas/export/inference/handler.py
@@ -21,6 +21,7 @@
 from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector
 from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector
 from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector
+from brevitas.proxy.runtime_quant import DynamicActQuantProxyFromInjector
 from brevitas.quant.experimental.mx_quant_ocp import GroupwiseActQuantProxyFromInjector
 from brevitas.utils.torch_utils import float_internal_scale

@@ -95,6 +96,17 @@ def forward(self, x) -> Tuple[torch.Tensor]:
         return x, self.scale, self.zero_point, self.bit_width


+class DynamicIntInferenceHandler(IntInferencetHandler):
+    handled_layer = DynamicActQuantProxyFromInjector
+
+    def prepare_for_export(self, module):
+        if module.is_quant_enabled:
+            self.module_forward = module.fused_activation_quant_proxy
+
+    def forward(self, x, unused_scale=None):
+        return self.module_forward(x)
+
+
 class GroupwiseIntInferenceHandler(IntInferencetHandler):
     handled_layer = GroupwiseActQuantProxyFromInjector

@@ -119,9 +131,10 @@ def prepare_for_export(self, module):
         super().prepare_for_export(module)
         self.input_view = module.input_view_impl
         self.flattened_view = module.apply_input_view
+        if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only:
+            self.cached_weight = module._cached_weight.quant_tensor.value_

     def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]:
-        x = self.input_view(x)
         if self.scale.shape != ():
             scale = self.input_view(self.scale)
         else:
@@ -130,9 +143,13 @@ def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]:
             zero_point = self.input_view(self.zero_point)
         else:
             zero_point = self.zero_point
-        out = self.dequantize(self.quantize(x, scale, zero_point), scale, zero_point)
-        if is_dynamo_compiling():
-            out = self.flattened_view(out)
+        if self.cached_weight is not None:
+            out = self.cached_weight
+        else:
+            x = self.input_view(x)
+            out = self.dequantize(self.quantize(x, scale, zero_point), scale, zero_point)
+        if is_dynamo_compiling():
+            out = self.flattened_view(out)
         return out, scale, zero_point, self.bit_width

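For context: each handler above names the quantization proxy it serves through its handled_layer class attribute, and the manager-side registration in the next file relies on that attribute to pick a handler per proxy. The sketch below illustrates that dispatch pattern in isolation; it is a simplified assumption, not the Brevitas implementation verbatim, and match_handler is a hypothetical name.

# Simplified sketch of handled_layer-based dispatch (illustrative only).
# Each handler class declares the proxy type it handles; a manager can then
# instantiate the first handler whose handled_layer matches a given proxy.
from typing import Optional, Sequence, Type

import torch.nn as nn


def match_handler(handler_classes: Sequence[Type[nn.Module]],
                  proxy: nn.Module) -> Optional[nn.Module]:
    for handler_cls in handler_classes:
        if isinstance(proxy, handler_cls.handled_layer):
            return handler_cls()  # this instance becomes the proxy's export handler
    return None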
4 changes: 3 additions & 1 deletion src/brevitas/export/inference/manager.py
@@ -4,6 +4,7 @@
 from torch.nn import Module
 import torch.nn as nn

+from brevitas.export.inference.handler import DynamicIntInferenceHandler
 from brevitas.export.inference.handler import FloatInferencetHandler
 from brevitas.export.inference.handler import FloatWeightInferencetHandler
 from brevitas.export.inference.handler import GroupwiseFloatInferenceHandler
@@ -69,14 +70,14 @@ def __exit__(self, type, value, traceback):
         # Disable all caching
         # deactivate export mode
         # restore return quant tensor
-        InferenceManager.set_export_mode(self.model, enabled=False)
         self.model.apply(
             lambda m: _override_bias_caching_mode(m, enabled=False, metadata_only=False))
         self.model.apply(
             lambda m: _override_act_caching_mode(m, enabled=False, metadata_only=False))
         if self.cache_quant_weight:
             self.model.apply(
                 lambda m: _override_weight_caching_mode(m, enabled=False, metadata_only=False))
+        InferenceManager.set_export_mode(self.model, enabled=False)
         restore_return_quant_tensor(self.model, self.return_quant_tensor_state)

     def hook(self, module, inp, out):
@@ -95,6 +96,7 @@ def hook(self, module, inp, out):
 class InferenceManager(BaseManager):
     handlers = [
         IntInferencetHandler,
+        DynamicIntInferenceHandler,
         FloatInferencetHandler,
         IntWeightInferencetHandler,
         FloatWeightInferencetHandler,
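A hedged usage sketch of the entry point these handlers plug into. It assumes quant_inference_mode is importable from brevitas.export.inference and that a Brevitas quantized layer such as brevitas.nn.QuantLinear is a sufficient stand-in for a full model; everything outside the diff is an assumption, not part of this commit.

# Usage sketch (assumptions noted above): run a quantized module under
# inference mode so the registered handlers, including the new dynamic one,
# replace the regular quantization proxies at inference time.
import torch

from brevitas.export.inference import quant_inference_mode  # assumed import path
from brevitas.nn import QuantLinear

model = QuantLinear(16, 8, bias=True).eval()
inp = torch.randn(2, 16)
with torch.no_grad(), quant_inference_mode(model):
    model(inp)        # first pass caches quantization metadata and enables export mode
    out = model(inp)  # subsequent passes go through the lightweight inference handlers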
