Skip to content

Commit

Permalink
Feat!: remove quant metadata from quantlayer (#883)
Browse files Browse the repository at this point in the history
Breaking change: The interface to access quant metadata has changed and now everything is directly delegated to the underlying proxies.
  • Loading branch information
Giuseppe5 authored Mar 28, 2024
1 parent 3364a92 commit e1da07b
Show file tree
Hide file tree
Showing 24 changed files with 241 additions and 638 deletions.
50 changes: 21 additions & 29 deletions notebooks/01_quant_tensor_quant_conv2d_overview.ipynb

Large diffs are not rendered by default.

16 changes: 8 additions & 8 deletions notebooks/03_anatomy_of_a_quantizer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -775,7 +775,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Indeed we can verify that `quant_weight_scale()` is equal to `weight.abs().max()`:"
"Indeed we can verify that `weight_quant.scale()` is equal to `weight.abs().max()`:"
]
},
{
Expand All @@ -792,7 +792,7 @@
}
],
"source": [
"assert_with_message((param_from_max_quant_conv.quant_weight_scale() == param_from_max_quant_conv.weight.abs().max()).item())"
"assert_with_message((param_from_max_quant_conv.weight_quant.scale() == param_from_max_quant_conv.weight.abs().max()).item())"
]
},
{
Expand Down Expand Up @@ -1024,7 +1024,7 @@
}
],
"source": [
"assert_with_message((quant_conv1.quant_weight_scale() == quant_conv2.quant_weight_scale()).item())"
"assert_with_message((quant_conv1.weight_quant.scale() == quant_conv2.weight_quant.scale()).item())"
]
},
{
Expand Down Expand Up @@ -1059,9 +1059,9 @@
" return module.weight.abs().mean()\n",
" \n",
"quant_conv1 = QuantConv2d(3, 2, (3, 3), weight_quant=SharedParamFromMeanWeightQuantizer)\n",
"old_quant_conv1_scale = quant_conv1.quant_weight_scale()\n",
"old_quant_conv1_scale = quant_conv1.weight_quant.scale()\n",
"quant_conv2 = QuantConv2d(3, 2, (3, 3), weight_quant=quant_conv1.weight_quant)\n",
"new_quant_conv1_scale = quant_conv1.quant_weight_scale()\n",
"new_quant_conv1_scale = quant_conv1.weight_quant.scale()\n",
"\n",
"assert_with_message(not (old_quant_conv1_scale == new_quant_conv1_scale).item())"
]
Expand All @@ -1080,7 +1080,7 @@
}
],
"source": [
"assert_with_message((new_quant_conv1_scale == quant_conv2.quant_weight_scale()).item())"
"assert_with_message((new_quant_conv1_scale == quant_conv2.weight_quant.scale()).item())"
]
},
{
Expand Down Expand Up @@ -1134,7 +1134,7 @@
"quant_conv_w_init = QuantConv2d(3, 2, (3, 3), weight_quant=ParamFromMaxWeightQuantizer)\n",
"torch.nn.init.uniform_(quant_conv_w_init.weight)\n",
"\n",
"assert_with_message(not (quant_conv_w_init.weight.abs().max() == quant_conv_w_init.quant_weight_scale()).item())"
"assert_with_message(not (quant_conv_w_init.weight.abs().max() == quant_conv_w_init.weight_quant.scale()).item())"
]
},
{
Expand All @@ -1160,7 +1160,7 @@
"source": [
"quant_conv_w_init.weight_quant.init_tensor_quant()\n",
"\n",
"assert_with_message((quant_conv_w_init.weight.abs().max() == quant_conv_w_init.quant_weight_scale()).item())"
"assert_with_message((quant_conv_w_init.weight.abs().max() == quant_conv_w_init.weight_quant.scale()).item())"
]
},
{
Expand Down
14 changes: 7 additions & 7 deletions notebooks/Brevitas_TVMCon2021.ipynb

Large diffs are not rendered by default.

21 changes: 0 additions & 21 deletions src/brevitas/export/common/handler/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,24 +127,3 @@ def zero_point_with_dtype(cls, signed, bit_width, zero_point):
return zero_point.type(torch.int8)
else:
return zero_point.type(torch.int32)

@classmethod
def quant_input_zero_point(cls, module):
signed = module.is_quant_input_signed
zero_point = module.quant_input_zero_point()
bit_width = module.quant_input_bit_width()
return cls.zero_point_with_dtype(signed, bit_width, zero_point)

@classmethod
def quant_weight_zero_point(cls, module):
signed = module.is_quant_weight_signed
zero_point = module.quant_weight_zero_point()
bit_width = module.quant_weight_bit_width()
return cls.zero_point_with_dtype(signed, bit_width, zero_point)

@classmethod
def quant_output_zero_point(cls, module):
signed = module.is_quant_output_signed
zero_point = module.quant_output_zero_point()
bit_width = module.quant_output_bit_width()
return cls.zero_point_with_dtype(signed, bit_width, zero_point)
39 changes: 12 additions & 27 deletions src/brevitas/export/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@
from torch.nn import Module

from brevitas import config
from brevitas.nn.mixin.base import _CachedIO
from brevitas.nn.mixin.base import QuantLayerMixin
from brevitas.nn.mixin.base import QuantRecurrentLayerMixin
from brevitas.proxy.quant_proxy import QuantProxyProtocol
from brevitas.quant_tensor import QuantTensor
from brevitas.utils.jit_utils import clear_class_registry
from brevitas.utils.python_utils import patch
from brevitas.utils.quant_utils import _CachedIO


class _JitTraceExportWrapper(nn.Module):
Expand Down Expand Up @@ -64,18 +64,11 @@ def _override_bias_caching_mode(m: Module, enabled: bool):
m.cache_inference_quant_bias = enabled


def _override_inp_caching_mode(m: Module, enabled: bool):
if hasattr(m, 'cache_inference_quant_inp'):
if not hasattr(m, "cache_inference_quant_inp_backup"):
m.cache_inference_quant_inp_backup = m.cache_inference_quant_inp
m.cache_inference_quant_inp = enabled


def _override_out_caching_mode(m: Module, enabled: bool):
if hasattr(m, 'cache_inference_quant_out'):
if not hasattr(m, "cache_inference_quant_out_backup"):
m.cache_inference_quant_out_backup = m.cache_inference_quant_out
m.cache_inference_quant_out = enabled
def _override_act_caching_mode(m: Module, enabled: bool):
if hasattr(m, 'cache_inference_quant_act'):
if not hasattr(m, "cache_inference_quant_act_backup"):
m.cache_inference_quant_act_backup = m.cache_inference_quant_act
m.cache_inference_quant_act = enabled


def _restore_quant_metadata_caching_mode(m: Module):
Expand All @@ -90,16 +83,10 @@ def _restore_bias_caching_mode(m: Module):
del m.cache_inference_quant_bias_backup


def _restore_inp_caching_mode(m: Module):
if hasattr(m, "cache_inference_quant_inp_backup"):
m.cache_inference_quant_inp = m.cache_inference_quant_inp_backup
del m.cache_inference_quant_inp_backup


def _restore_out_caching_mode(m: Module):
if hasattr(m, "cache_inference_quant_out_backup"):
m.cache_inference_quant_out = m.cache_inference_quant_out_backup
del m.cache_inference_quant_out_backup
def _restore_act_caching_mode(m: Module):
if hasattr(m, "cache_inference_quant_act_backup"):
m.cache_inference_quant_act = m.cache_inference_quant_act_backup
del m.cache_inference_quant_act_backup


def _set_recurrent_layer_export_mode(model: Module, enabled: bool):
Expand Down Expand Up @@ -202,14 +189,12 @@ def _cache_inp_out(cls, module, *args, **kwargs):
# force enable caching
module.apply(lambda m: _override_quant_metadata_caching_mode(m, enabled=True))
module.apply(lambda m: _override_bias_caching_mode(m, enabled=True))
module.apply(lambda m: _override_inp_caching_mode(m, enabled=True))
module.apply(lambda m: _override_out_caching_mode(m, enabled=True))
module.apply(lambda m: _override_act_caching_mode(m, enabled=True))
_ = module.forward(*args, **kwargs)
# Restore previous caching properties
module.apply(lambda m: _restore_quant_metadata_caching_mode(m))
module.apply(lambda m: _restore_bias_caching_mode(m))
module.apply(lambda m: _restore_inp_caching_mode(m))
module.apply(lambda m: _restore_out_caching_mode(m))
module.apply(lambda m: _restore_act_caching_mode(m))

@classmethod
def jit_inference_trace(
Expand Down
8 changes: 4 additions & 4 deletions src/brevitas/export/onnx/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
from brevitas import torch_version
from brevitas.quant_tensor import QuantTensor

from ..manager import _override_inp_caching_mode
from ..manager import _restore_inp_caching_mode
from ..manager import _override_act_caching_mode
from ..manager import _restore_act_caching_mode
from ..manager import BaseManager
from ..manager import ExportContext

Expand Down Expand Up @@ -120,7 +120,7 @@ def export_onnx(
# enable export mode, this triggers collecting export values into handlers
cls.set_export_mode(module, enabled=True)
# temporarily disable input caching to avoid collectives empty debug values
module.apply(lambda m: _override_inp_caching_mode(m, enabled=False))
module.apply(lambda m: _override_act_caching_mode(m, enabled=False))
# perform export pass
if export_path is not None:
export_target = export_path
Expand All @@ -130,7 +130,7 @@ def export_onnx(
torch.onnx.export(module, args, export_target, **onnx_export_kwargs)

# restore the model to previous properties
module.apply(lambda m: _restore_inp_caching_mode(m))
module.apply(lambda m: _restore_act_caching_mode(m))
cls.set_export_mode(module, enabled=False)
module.train(training_state)

Expand Down
4 changes: 2 additions & 2 deletions src/brevitas/graph/gpfq.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def update_batch(self, module, input, current_layer):
raise StopFwdException

def single_layer_update(self):
assert not self.layer.weight_quant_requires_quant_input, "Error: GPFQ does not support weight quantizers that require quantized inputs."
assert not self.layer.weight_quant.requires_quant_input, "Error: GPFQ does not support weight quantizers that require quantized inputs."
weight = self.layer.weight.data
dev = weight.device
dtype = weight.dtype
Expand Down Expand Up @@ -360,7 +360,7 @@ def single_layer_update(self):
input_is_signed = self.quant_input.signed
T = get_upper_bound_on_l1_norm(
torch.tensor(self.accumulator_bit_width), input_bit_width, input_is_signed)
s = self.layer.quant_weight_scale()
s = self.layer.weight_quant.scale()
if s.ndim > 1:
s = s.view(self.groups, -1) # [Groups, OC/Groups]

Expand Down
2 changes: 1 addition & 1 deletion src/brevitas/graph/gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def update_batch(self, module, input, current_layer):
raise StopFwdException

def single_layer_update(self, percdamp=.01):
assert not self.layer.weight_quant_requires_quant_input, "Error: GPTQ does not support weight quantizers that require quantized inputs."
assert not self.layer.weight_quant.requires_quant_input, "Error: GPTQ does not support weight quantizers that require quantized inputs."
if hasattr(self.layer, 'allocate_params'):
self.layer.allocate_params(self.layer)
weight = self.layer.weight.data
Expand Down
4 changes: 2 additions & 2 deletions src/brevitas/graph/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,10 +237,10 @@ def align_input_quant(
# If it is a QuantIdentity already, simply modify tensor_quant or the scaling implementations
# based on whether we need to align the sign or not
if isinstance(module, qnn.QuantIdentity):
if align_sign or module.is_quant_act_signed == shared_quant_identity.is_quant_act_signed:
if align_sign or module.input_quant.is_signed == shared_quant_identity.input_quant.is_signed:
return shared_quant_identity
else:
assert not module.is_quant_act_signed and shared_quant_identity.is_quant_act_signed
assert not module.input_quant.is_signed and shared_quant_identity.input_quant.is_signed
quant_module_class, quant_module_kwargs = quant_identity_map['unsigned']
return (
quant_module_class,
Expand Down
4 changes: 2 additions & 2 deletions src/brevitas/graph/quantize_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ def are_inputs_unsigned(model, node, is_unsigned_list, quant_act_map, unsigned_a
elif isinstance(inp_module, tuple(SIGN_PRESERVING_MODULES)):
are_inputs_unsigned(
model, inp_node, is_unsigned_list, quant_act_map, unsigned_act_tuple)
elif hasattr(inp_module, 'is_quant_act_signed'):
is_unsigned_list.append(not inp_module.is_quant_act_signed)
elif hasattr(inp_module, 'input_quant'):
is_unsigned_list.append(not inp_module.input_quant.is_signed)
else:
is_unsigned_list.append(False)
elif inp_node.op == 'call_function':
Expand Down
77 changes: 0 additions & 77 deletions src/brevitas/nn/mixin/act.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@
from abc import ABCMeta
from abc import abstractmethod
from typing import Optional, Type, Union
from warnings import warn

from torch.nn import Module

from brevitas.inject import ExtendedInjector
from brevitas.inject import Injector
from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector
from brevitas.proxy.runtime_quant import ActQuantProxyProtocol
from brevitas.quant import NoneActQuant

Expand All @@ -34,31 +32,6 @@ def __init__(self, input_quant: Optional[ActQuantType], **kwargs):
input_passthrough_act=True,
**kwargs)

@property
def is_input_quant_enabled(self):
return self.input_quant.is_quant_enabled

@property
def is_quant_input_narrow_range(self): # TODO make abstract once narrow range can be cached
return self.input_quant.is_narrow_range

@property
@abstractmethod
def is_quant_input_signed(self):
pass

@abstractmethod
def quant_input_scale(self):
pass

@abstractmethod
def quant_input_zero_point(self):
pass

@abstractmethod
def quant_input_bit_width(self):
pass


class QuantOutputMixin(QuantProxyMixin):
__metaclass__ = ABCMeta
Expand All @@ -75,31 +48,6 @@ def __init__(self, output_quant: Optional[ActQuantType], **kwargs):
output_passthrough_act=True,
**kwargs)

@property
def is_output_quant_enabled(self):
return self.output_quant.is_quant_enabled

@property
def is_quant_output_narrow_range(self): # TODO make abstract once narrow range can be cached
return self.output_quant.is_narrow_range

@property
@abstractmethod
def is_quant_output_signed(self):
pass

@abstractmethod
def quant_output_scale(self):
pass

@abstractmethod
def quant_output_zero_point(self):
pass

@abstractmethod
def quant_output_bit_width(self):
pass


class QuantNonLinearActMixin(QuantProxyMixin):
__metaclass__ = ABCMeta
Expand All @@ -124,28 +72,3 @@ def __init__(
none_quant_injector=NoneActQuant,
**prefixed_kwargs,
**kwargs)

@property
def is_act_quant_enabled(self):
return self.act_quant.is_quant_enabled

@property
def is_quant_act_narrow_range(self): # TODO make abstract once narrow range can be cached
return self.act_quant.is_narrow_range

@property
@abstractmethod
def is_quant_act_signed(self):
pass

@abstractmethod
def quant_act_scale(self):
pass

@abstractmethod
def quant_act_zero_point(self):
pass

@abstractmethod
def quant_act_bit_width(self):
pass
Loading

0 comments on commit e1da07b

Please sign in to comment.