Quant tensor not empty #819

Merged: 32 commits, Feb 23, 2024
551 changes: 258 additions & 293 deletions notebooks/01_quant_tensor_quant_conv2d_overview.ipynb

Large diffs are not rendered by default.

169 changes: 48 additions & 121 deletions notebooks/02_quant_activation_overview.ipynb

Large diffs are not rendered by default.

489 changes: 214 additions & 275 deletions notebooks/03_anatomy_of_a_quantizer.ipynb

Large diffs are not rendered by default.

946 changes: 830 additions & 116 deletions notebooks/Brevitas_TVMCon2021.ipynb

Large diffs are not rendered by default.

89 changes: 55 additions & 34 deletions notebooks/ONNX_export_tutorial.ipynb

Large diffs are not rendered by default.

639 changes: 302 additions & 337 deletions notebooks/quantized_recurrent.ipynb

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion src/brevitas/core/quant/binary.py
@@ -48,8 +48,9 @@ class BinaryQuant(brevitas.jit.ScriptModule):
Set env variable BREVITAS_JIT=1 to enable TorchScript compilation of this module.
"""

def __init__(self, scaling_impl: Module, quant_delay_steps: int = 0):
def __init__(self, scaling_impl: Module, signed: bool = True, quant_delay_steps: int = 0):
super(BinaryQuant, self).__init__()
assert signed, "Unsigned binary quant not supported"
self.scaling_impl = scaling_impl
self.bit_width = BitWidthConst(1)
self.zero_point = StatelessBuffer(torch.tensor(0.0))
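A minimal usage sketch of the new `signed` flag (not part of this PR's diff; `ConstScaling` is assumed here as a simple stand-in for `scaling_impl`):

import torch
from brevitas.core.quant.binary import BinaryQuant
from brevitas.core.scaling import ConstScaling

# signed defaults to True; anything else is rejected by the new assert
quant = BinaryQuant(scaling_impl=ConstScaling(0.1), signed=True)
value, scale, zero_point, bit_width = quant(torch.randn(4))
# value lies in {-0.1, 0.1}, scale == 0.1, zero_point == 0.0, bit_width == 1

# BinaryQuant(scaling_impl=ConstScaling(0.1), signed=False)  # raises AssertionError: Unsigned binary quant not supported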
4 changes: 2 additions & 2 deletions src/brevitas/core/stats/stats_op.py
@@ -12,6 +12,7 @@
from brevitas import config
from brevitas.core.utils import StatelessBuffer
from brevitas.function.ops import max_int
from brevitas.quant_tensor import _unpack_quant_tensor
# Use custom implementation of kthvalue as work around to (b)float16 kernel limitations
from brevitas.utils.torch_utils import kthvalue

@@ -478,8 +479,7 @@ def evaluate_loss(self, x, candidate):
# Set to local_loss_mode before calling the proxy
self.set_local_loss_mode(True)
quant_value = self.proxy_forward(x)
if isinstance(quant_value, tuple):
quant_value = quant_value[0]
quant_value = _unpack_quant_tensor(quant_value)
loss = self.mse_loss_fn(x, quant_value)
self.set_local_loss_mode(False)
return loss
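For context, my reading of the helper (an illustrative sketch, not code from this PR): `_unpack_quant_tensor` returns the wrapped torch.Tensor when given a QuantTensor and passes plain tensors through unchanged, which is why the tuple special-casing in `evaluate_loss` can be dropped. Assuming a default `QuantIdentity` quantizer:

import torch
from brevitas.nn import QuantIdentity
from brevitas.quant_tensor import _unpack_quant_tensor

qt = QuantIdentity(return_quant_tensor=True)(torch.randn(4))
assert torch.equal(_unpack_quant_tensor(qt), qt.value)   # QuantTensor -> underlying tensor
plain = torch.randn(4)
assert torch.equal(_unpack_quant_tensor(plain), plain)   # plain tensors pass through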
5 changes: 2 additions & 3 deletions src/brevitas/export/manager.py
@@ -207,10 +207,9 @@ def _cache_fn_dispatcher(cls, fn, input, *args, **kwargs):
if isinstance(input, QuantTensor):
inp_cache = None
out_cache = None
if input.is_not_none:
inp_cache = _CachedIO(input, metadata_only=True)
inp_cache = _CachedIO(input, metadata_only=True)
output = fn(input, *args, **kwargs)
if isinstance(output, QuantTensor) and output.is_not_none:
if isinstance(output, QuantTensor):
out_cache = _CachedIO(output, metadata_only=True)
cached_io = (inp_cache, out_cache)
if fn in cls._cached_io_handler_map:
@@ -104,7 +104,7 @@ def input_quant_symbolic_kwargs(cls, module):

@classmethod
def input_dequant_symbolic_kwargs(cls, module):
if module._cached_inp.scale is not None:
if module._cached_inp is not None:
return cls.dequant_symbolic_kwargs_from_cached_io(module._cached_inp)
else:
return None
20 changes: 19 additions & 1 deletion src/brevitas/graph/calibrate.py
@@ -48,6 +48,21 @@
BN_LAYERS = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def disable_return_quant_tensor(model):
previous_state = {}
for module in model.modules():
if hasattr(module, 'return_quant_tensor'):
previous_state[module] = module.return_quant_tensor
module.return_quant_tensor = False
return previous_state


def restore_return_quant_tensor(model, previous_state):
for module in model.modules():
if hasattr(module, 'return_quant_tensor'):
module.return_quant_tensor = previous_state[module]


def extend_collect_stats_steps(module):
if hasattr(module, 'collect_stats_steps'):
# We extend the collect steps in PTQ to match potentially long calibrations
@@ -75,18 +90,21 @@ def __init__(self, model, enabled=True):
self.previous_training_state = model.training
self.disable_quant_inference = DisableEnableQuantization(call_act_quantizer_impl=True)
self.enabled = enabled
self.return_quant_tensor_state = dict()

def __enter__(self):
if self.enabled:
self.model.apply(extend_collect_stats_steps)
self.model.apply(set_collect_stats_to_average)
self.return_quant_tensor_state = disable_return_quant_tensor(self.model)
self.disable_quant_inference.apply(
self.model, is_training=True, quantization_enabled=False)

def __exit__(self, type, value, traceback):
self.model.apply(finalize_collect_stats)
self.disable_quant_inference.apply(
self.model, is_training=self.previous_training_state, quantization_enabled=True)
restore_return_quant_tensor(self.model, self.return_quant_tensor_state)


class load_quant_model:
@@ -168,7 +186,7 @@ def disable_act_quant_hook(self, module, inp, output):
if isinstance(module.tracked_module_list[0], QuantHardTanh):
inp = F.hardtanh(
inp, min_val=module.quant_injector.min_val, max_val=module.quant_injector.max_val)
return QuantTensor(value=inp, training=module.training)
return inp

def disable_act_quantization(self, model, is_training):
# If self.call_act_quantizer_impl is set to True, the quantization will be performed but the output
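A short usage sketch of the two helpers added above, `disable_return_quant_tensor` and `restore_return_quant_tensor` (the toy model below is an arbitrary example, not from the PR):

import torch
import torch.nn as nn
from brevitas.graph.calibrate import disable_return_quant_tensor, restore_return_quant_tensor
from brevitas.nn import QuantIdentity, QuantReLU
from brevitas.quant_tensor import QuantTensor

model = nn.Sequential(
    QuantIdentity(return_quant_tensor=True), QuantReLU(return_quant_tensor=True))

state = disable_return_quant_tensor(model)      # force plain tensor outputs, remember old flags
out = model(torch.randn(2, 4))
assert not isinstance(out, QuantTensor)         # plain torch.Tensor while disabled
restore_return_quant_tensor(model, state)       # put each module's original flag back
out = model(torch.randn(2, 4))
assert isinstance(out, QuantTensor)             # original behaviour restored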
12 changes: 5 additions & 7 deletions src/brevitas/nn/hadamard_classifier.py
@@ -49,15 +49,13 @@ def forward(self, inp):
out = inp.value / norm
out = nn.functional.linear(out, self.proj[:self.out_channels, :self.in_channels])
out = -self.scale * out
if inp.scale is not None:
if isinstance(inp, QuantTensor):
output_scale = inp.scale * self.scale / norm
if inp.bit_width is not None:
output_bit_width = self.max_output_bit_width(inp.bit_width)
if (self.return_quant_tensor and inp.zero_point is not None and
(inp.zero_point != 0.0).any()):
raise RuntimeError("Computing zero point of output accumulator not supported yet.")
else:
output_zp = inp.zero_point
if (self.return_quant_tensor and inp.zero_point != 0.0).any():
raise RuntimeError("Computing zero point of output accumulator not supported yet.")
else:
output_zp = inp.zero_point
out = QuantTensor(
value=out,
scale=output_scale,
44 changes: 22 additions & 22 deletions src/brevitas/nn/mixin/base.py
@@ -18,6 +18,7 @@
from brevitas.inject import ExtendedInjector
from brevitas.inject import Injector
from brevitas.nn.utils import compute_channel_view_shape
from brevitas.quant_tensor import _unpack_quant_tensor
from brevitas.quant_tensor import QuantTensor

from .utils import filter_kwargs
@@ -154,7 +155,7 @@ def quant_output_bit_width(self):
else:
return None

def unpack_input(self, inp: Union[Tensor, QuantTensor]):
def unpack_input(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
self._set_global_is_quant_layer(True)
# Hack to recognize a QuantTensor that has decayed to a tuple
# when used as input to tracing (e.g. during ONNX export)
@@ -166,25 +167,23 @@ def unpack_input(self, inp: Union[Tensor, QuantTensor]):
if not self.training and not self._export_mode and self.cache_inference_quant_inp:
cached_inp = _CachedIO(inp.detach(), self.cache_quant_io_metadata_only)
self._cached_inp = cached_inp
else:
inp = QuantTensor(inp, training=self.training)
if not self.training and self.cache_inference_quant_inp:
cached_inp = _CachedIO(inp.detach(), self.cache_quant_io_metadata_only)
self._cached_inp = cached_inp
# Remove any naming metadata to avoid downstream errors
# Avoid inplace operations on the input in case of forward hooks
if not torch._C._get_tracing_state():
inp = inp.set(value=inp.value.rename(None))
if isinstance(inp, QuantTensor):
inp = inp.set(value=inp.value.rename(None))
else:
inp = inp.rename(None)
return inp

def pack_output(self, quant_output: QuantTensor):
if not self.training and self.cache_inference_quant_out:
def pack_output(self, quant_output: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
if not self.training and self.cache_inference_quant_out and isinstance(quant_output,
QuantTensor):
self._cached_out = _CachedIO(quant_output.detach(), self.cache_quant_io_metadata_only)
self._set_global_is_quant_layer(False)
if self.return_quant_tensor:
assert isinstance(quant_output, QuantTensor)
return quant_output
else:
return quant_output.value
return _unpack_quant_tensor(quant_output)


class QuantRecurrentLayerMixin(ExportMixin):
@@ -246,9 +245,9 @@ def gate_params_fwd(gate, quant_input):
acc_bit_width = None
quant_weight_ih = gate.input_weight()
quant_weight_hh = gate.hidden_weight()
if quant_input.bit_width is not None:
if isinstance(quant_input, QuantTensor):
acc_bit_width = None # TODO
if quant_input.scale is not None and quant_weight_ih.scale is not None:
if isinstance(quant_input, QuantTensor) and isinstance(quant_weight_ih, QuantTensor):
acc_scale_shape = compute_channel_view_shape(quant_input.value, channel_dim=1)
acc_scale = quant_weight_ih.scale.view(acc_scale_shape)
acc_scale = acc_scale * quant_input.scale.view(acc_scale_shape)
@@ -267,24 +266,23 @@ def maybe_quantize_input(self, inp):
quant_input = inp
if not self.quantize_output_only:
quant_input = self.io_quant(quant_input)
elif not isinstance(inp, QuantTensor):
quant_input = QuantTensor(quant_input)
return quant_input

def maybe_quantize_state(self, inp, state, quant):
if state is None:
batch_size = inp.size(0) if self.cell.batch_first else inp.size(1)
quant_state = torch.zeros(
int(batch_size), self.hidden_size, dtype=inp.dtype, device=inp.device)
quant_state = QuantTensor(quant_state)
else:
quant_state = quant(state)
return quant_state

def pack_quant_outputs(self, quant_outputs):
# In export mode, quant_outputs has the shape of the output concatenated value
# Even though we check that return_quant_tensor can be enabled only with io_quant != None,
# inner layers in a deep network override it, so we check again.
if self.export_mode:
if self.return_quant_tensor:
if self.return_quant_tensor and self.io_quant.is_quant_enabled:
return QuantTensor(
quant_outputs,
self.io_quant.scale(),
@@ -295,7 +293,7 @@ def pack_quant_outputs(self, quant_outputs):
else:
return quant_outputs
seq_dim = 1 if self.cell.batch_first else 0
if self.return_quant_tensor:
if self.return_quant_tensor and self.io_quant.is_quant_enabled:
outputs = [
QuantTensor(
torch.unsqueeze(quant_output[0], dim=seq_dim),
@@ -312,8 +310,10 @@ def pack_quant_outputs(self, quant_outputs):
return torch.cat(outputs, dim=seq_dim)

def pack_quant_state(self, quant_state, quant):
# Even though we check that return_quant_tensor can be enabled only with quant != None,
# inner layers in a deep network override it, so we check again.
if self.export_mode:
if self.return_quant_tensor:
if self.return_quant_tensor and quant.is_quant_enabled:
quant_state = QuantTensor(
torch.unsqueeze(quant_state, dim=0),
quant.scale(),
@@ -324,14 +324,14 @@ def pack_quant_state(self, quant_state, quant):
else:
quant_state = torch.unsqueeze(quant_state, dim=0)
else:
if self.return_quant_tensor:
if self.return_quant_tensor and quant.is_quant_enabled:
quant_state = QuantTensor(
torch.unsqueeze(quant_state[0], dim=0),
quant_state[1],
quant_state[2],
quant_state[3],
quant.is_signed,
self.training)
training=self.training)
else:
quant_state = torch.unsqueeze(quant_state[0], dim=0)
return quant_state
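To illustrate the reworked unpack_input/pack_output contract, a sketch under the assumption of the default 8-bit weight and activation quantizers (not code from the PR; bias is disabled to keep the output metadata trivially valid):

import torch
from brevitas.nn import QuantConv2d, QuantIdentity
from brevitas.quant_tensor import QuantTensor

x = torch.randn(1, 3, 8, 8)

# return_quant_tensor=False: pack_output unpacks the result to a plain tensor
conv = QuantConv2d(3, 4, kernel_size=3, bias=False, return_quant_tensor=False)
assert not isinstance(conv(x), QuantTensor)

# Quantized input + return_quant_tensor=True: the QuantTensor is kept end to end
inp_quant = QuantIdentity(return_quant_tensor=True)
conv_qt = QuantConv2d(3, 4, kernel_size=3, bias=False, return_quant_tensor=True)
assert isinstance(conv_qt(inp_quant(x)), QuantTensor)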
6 changes: 5 additions & 1 deletion src/brevitas/nn/mixin/parameter.py
@@ -198,7 +198,11 @@ def quant_bias_zero_point(self):
if self.bias is None:
return None
if not self.bias_quant.requires_input_scale and not self.bias_quant.requires_input_bit_width:
return self.bias_quant(self.bias).zero_point
bias_quant = self.bias_quant(self.bias)
if isinstance(bias_quant, QuantTensor):
return bias_quant.zero_point
else:
return None
else:
if self._cached_bias is None:
raise RuntimeError(
52 changes: 33 additions & 19 deletions src/brevitas/nn/quant_avg_pool.py
@@ -13,6 +13,7 @@
from brevitas.function.ops import max_int
from brevitas.function.ops_ste import ceil_ste
from brevitas.inject.defaults import RoundTo8bit
from brevitas.quant_tensor import _unpack_quant_tensor
from brevitas.quant_tensor import QuantTensor

from .mixin.acc import AccQuantType
@@ -55,16 +56,22 @@ def _avg_scaling(self):

def forward(self, input: Union[Tensor, QuantTensor]):
x = self.unpack_input(input)

if self.export_mode:
return self.export_handler(x.value)
x = x.set(value=super(TruncAvgPool2d, self).forward(x.value))
if self.is_trunc_quant_enabled:
assert x.is_not_none # check input quant tensor is filled with values
# remove avg scaling
rescaled_value = x.value * self._avg_scaling
x = x.set(value=rescaled_value)
x = x.set(bit_width=self.max_acc_bit_width(x.bit_width))
x = self.trunc_quant(x)
return self.export_handler(_unpack_quant_tensor(x))

if isinstance(x, QuantTensor):
x = x.set(value=super(TruncAvgPool2d, self).forward(x.value))
if self.is_trunc_quant_enabled:
# remove avg scaling
rescaled_value = x.value * self._avg_scaling
x = x.set(value=rescaled_value)
x = x.set(bit_width=self.max_acc_bit_width(x.bit_width))
x = self.trunc_quant(x)
else:
assert not self.is_trunc_quant_enabled
x = super(TruncAvgPool2d, self).forward(x)

return self.pack_output(x)

def max_acc_bit_width(self, input_bit_width):
@@ -127,23 +134,30 @@ def compute_kernel_size_stride(self, input_shape, output_shape):

def forward(self, input: Union[Tensor, QuantTensor]):
x = self.unpack_input(input)

# shortcut execution through the export impl during export
if self.export_mode:
out = self.export_handler(x.value)
out = self.export_handler(_unpack_quant_tensor(x))
self._set_global_is_quant_layer(False)
return out
y = x.set(value=super(TruncAdaptiveAvgPool2d, self).forward(x.value))
k_size, stride = self.compute_kernel_size_stride(x.value.shape[2:], y.value.shape[2:])

if self.cache_kernel_size_stride:
self._cached_kernel_size = k_size
self._cached_kernel_stride = stride
if self.is_trunc_quant_enabled:
assert y.is_not_none # check input quant tensor is filled with values
reduce_size = reduce(mul, k_size, 1)
rescaled_value = y.value * reduce_size # remove avg scaling
y = y.set(value=rescaled_value)
y = y.set(bit_width=self.max_acc_bit_width(y.bit_width, reduce_size))
y = self.trunc_quant(y)

if isinstance(x, QuantTensor):
y = x.set(value=super(TruncAdaptiveAvgPool2d, self).forward(x.value))
k_size, stride = self.compute_kernel_size_stride(x.value.shape[2:], y.value.shape[2:])
if self.is_trunc_quant_enabled:
reduce_size = reduce(mul, k_size, 1)
rescaled_value = y.value * reduce_size # remove avg scaling
y = y.set(value=rescaled_value)
y = y.set(bit_width=self.max_acc_bit_width(y.bit_width, reduce_size))
y = self.trunc_quant(y)
else:
assert not self.is_trunc_quant_enabled
y = super(TruncAdaptiveAvgPool2d, self).forward(x)

return self.pack_output(y)

def max_acc_bit_width(self, input_bit_width, reduce_size):
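A sketch of the two forward paths above (assumptions: the constructor mirrors nn.AvgPool2d's kernel_size plus the return_quant_tensor flag, with the default 8-bit truncation quantizer; not code from the PR):

import torch
from brevitas.nn import QuantIdentity
from brevitas.nn.quant_avg_pool import TruncAvgPool2d

pool = TruncAvgPool2d(kernel_size=2, return_quant_tensor=False)

# QuantTensor input: avg scaling is removed, the accumulator bit width is grown,
# and the truncation quantizer runs before pack_output unpacks the result
qt = QuantIdentity(return_quant_tensor=True)(torch.randn(1, 3, 8, 8))
print(pool(qt).shape)  # torch.Size([1, 3, 4, 4]), plain tensor out

# A plain float input now falls through to the ordinary AvgPool2d path instead,
# but only when the truncation quantizer is disabled; otherwise the assert above fires.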