Feat: functionalize QuantLayers + QuantTensor (#878)
Giuseppe5 authored Apr 18, 2024
1 parent 6d0a8d7 commit 29ed408
Showing 16 changed files with 791 additions and 627 deletions.
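
At a high level, the commit removes per-layer QuantTensor bookkeeping (manual .value unpacking, max_acc_bit_width accounting, hand-built output QuantTensors) in favour of passing the QuantTensor straight through the underlying torch op and letting the quant proxies re-quantize the result. A hedged reading of the pattern, with the before/after lines quoted from the TruncAvgPool2d hunk below:

# Before (removed in this commit): unpack the value, run the float op, then
# rebuild the QuantTensor metadata by hand.
#     x = x.set(value=super(TruncAvgPool2d, self).forward(x.value))
#     rescaled_value = x.value * self._avg_scaling  # remove avg scaling
#     x = x.set(value=rescaled_value)
#     x = x.set(bit_width=self.max_acc_bit_width(x.bit_width))
#     x = self.trunc_quant(x)
# After: the QuantTensor flows through the op as-is and the truncation proxy
# handles re-quantization.
#     y = AvgPool2d.forward(self, x)
#     y = self.trunc_quant(y)
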
1 change: 0 additions & 1 deletion src/brevitas/graph/calibrate.py
@@ -13,7 +13,6 @@
from brevitas.nn import QuantHardTanh
from brevitas.nn import QuantLinear
from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL
from brevitas.nn.utils import compute_channel_view_shape
from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector
from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector
from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector
4 changes: 2 additions & 2 deletions src/brevitas/nn/mixin/base.py
@@ -17,9 +17,9 @@
from brevitas.common import ExportMixin
from brevitas.inject import ExtendedInjector
from brevitas.inject import Injector
from brevitas.nn.utils import compute_channel_view_shape
from brevitas.quant_tensor import _unpack_quant_tensor
from brevitas.quant_tensor import QuantTensor
from brevitas.utils.torch_utils import compute_channel_view_shape

from .utils import filter_kwargs

@@ -86,7 +86,7 @@ def unpack_input(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe
def pack_output(self, quant_output: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
self._set_global_is_quant_layer(False)
if self.return_quant_tensor:
assert isinstance(quant_output, QuantTensor)
assert isinstance(quant_output, QuantTensor), 'QuantLayer is not correctly configured, check if warnings were raised'
return quant_output
else:
return _unpack_quant_tensor(quant_output)
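
In practice the stricter assert means that a layer built with return_quant_tensor=True must actually produce a QuantTensor by the time pack_output runs. A hedged usage sketch (QuantIdentity and its arguments come from the public Brevitas API, not from this diff):

import torch
from brevitas.nn import QuantIdentity

# return_quant_tensor=True: pack_output expects a QuantTensor and now asserts with
# a hint to check earlier warnings if it gets a plain Tensor instead.
act = QuantIdentity(return_quant_tensor=True)
out = act(torch.randn(4, 8))              # expected: a QuantTensor (value, scale, zero_point, ...)

# return_quant_tensor=False: pack_output unpacks to a plain torch.Tensor.
act_plain = QuantIdentity(return_quant_tensor=False)
out_plain = act_plain(torch.randn(4, 8))  # plain torch.Tensor
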
50 changes: 9 additions & 41 deletions src/brevitas/nn/quant_avg_pool.py
@@ -47,38 +47,19 @@ def channelwise_separable(self) -> bool:
def requires_export_handler(self):
return True

@property
def _avg_scaling(self):
if isinstance(self.kernel_size, tuple):
return self.kernel_size[0] * self.kernel_size[1]
else:
return self.kernel_size * self.kernel_size

def forward(self, input: Union[Tensor, QuantTensor]):
x = self.unpack_input(input)

if self.export_mode:
return self.export_handler(_unpack_quant_tensor(x))

if isinstance(x, QuantTensor):
x = x.set(value=super(TruncAvgPool2d, self).forward(x.value))
if self.is_trunc_quant_enabled:
# remove avg scaling
rescaled_value = x.value * self._avg_scaling
x = x.set(value=rescaled_value)
x = x.set(bit_width=self.max_acc_bit_width(x.bit_width))
x = self.trunc_quant(x)
if isinstance(x, QuantTensor) and self.is_trunc_quant_enabled:
y = AvgPool2d.forward(self, x)
y = self.trunc_quant(y)
else:
assert not self.is_trunc_quant_enabled
x = super(TruncAvgPool2d, self).forward(x)
y = AvgPool2d.forward(self, _unpack_quant_tensor(x))

return self.pack_output(x)

def max_acc_bit_width(self, input_bit_width):
max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False)
max_uint_output = max_uint_input * self._avg_scaling
max_output_bit_width = ceil_ste(torch.log2(max_uint_output))
return max_output_bit_width
return self.pack_output(y)
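
For reference, the removed _avg_scaling / max_acc_bit_width pair estimated the width of the un-averaged accumulator. A small self-contained re-derivation of what they returned (the kernel size and input bit width are illustrative):

import math

input_bit_width = 8                                     # unsigned input, narrow_range=False
max_uint_input = 2 ** input_bit_width - 1               # 255
avg_scaling = 2 * 2                                     # kernel_size = (2, 2)
max_uint_output = max_uint_input * avg_scaling          # 1020
acc_bit_width = math.ceil(math.log2(max_uint_output))   # 10 bits to hold the raw sum
print(acc_bit_width)

Where that growth is tracked after this commit is not visible in the hunk; presumably the QuantTensor handling invoked by AvgPool2d.forward together with trunc_quant takes care of it.
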


class TruncAdaptiveAvgPool2d(TruncMixin, QuantLayerMixin, AdaptiveAvgPool2d):
@@ -130,23 +111,10 @@ def forward(self, input: Union[Tensor, QuantTensor]):
self._set_global_is_quant_layer(False)
return out

if isinstance(x, QuantTensor):
y = x.set(value=super(TruncAdaptiveAvgPool2d, self).forward(x.value))
k_size, stride = self.compute_kernel_size_stride(x.value.shape[2:], y.value.shape[2:])
if self.is_trunc_quant_enabled:
reduce_size = reduce(mul, k_size, 1)
rescaled_value = y.value * reduce_size # remove avg scaling
y = y.set(value=rescaled_value)
y = y.set(bit_width=self.max_acc_bit_width(y.bit_width, reduce_size))
y = self.trunc_quant(y)
if isinstance(x, QuantTensor) and self.is_trunc_quant_enabled:
y = AdaptiveAvgPool2d.forward(self, x)
y = self.trunc_quant(y)
else:
assert not self.is_trunc_quant_enabled
y = super(TruncAdaptiveAvgPool2d, self).forward(x)
y = AdaptiveAvgPool2d.forward(self, _unpack_quant_tensor(x))

return self.pack_output(y)

def max_acc_bit_width(self, input_bit_width, reduce_size):
max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False)
max_uint_output = max_uint_input * reduce_size
max_output_bit_width = ceil_ste(torch.log2(max_uint_output))
return max_output_bit_width
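
The adaptive variant derived the same bound from the effective window it computed via compute_kernel_size_stride. A hedged re-derivation of the removed logic (shapes are illustrative and the window is assumed to divide evenly):

import math
from functools import reduce
from operator import mul

in_spatial, out_spatial = (8, 8), (2, 2)                             # input/output H, W
k_size = tuple(i // o for i, o in zip(in_spatial, out_spatial))      # (4, 4)
reduce_size = reduce(mul, k_size, 1)                                 # 16
max_uint_input = 2 ** 8 - 1                                          # 8-bit unsigned input
acc_bit_width = math.ceil(math.log2(max_uint_input * reduce_size))   # 12
print(k_size, reduce_size, acc_bit_width)
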
26 changes: 0 additions & 26 deletions src/brevitas/nn/quant_conv.py
@@ -109,14 +109,6 @@ def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Option
else:
return self._conv_forward(x, quant_weight, quant_bias)

def max_acc_bit_width(self, input_bit_width, weight_bit_width):
max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False)
max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width)
group_size = self.in_channels // self.groups
max_uint_output = max_uint_input * max_kernel_val * self.kernel_size[0] * group_size
max_output_bit_width = ceil_ste(torch.log2(max_uint_output))
return max_output_bit_width


class QuantConv2d(QuantWBIOL, Conv2d):

@@ -205,15 +197,6 @@ def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Option
else:
return self._conv_forward(x, quant_weight, quant_bias)

def max_acc_bit_width(self, input_bit_width: Tensor, weight_bit_width: Tensor):
max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False)
max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width)
group_size = self.in_channels // self.groups
kernel_size = self.kernel_size[0] * self.kernel_size[1]
max_uint_output = max_uint_input * max_kernel_val * kernel_size * group_size
max_output_bit_width = ceil_ste(torch.log2(max_uint_output))
return max_output_bit_width


class QuantConv3d(QuantWBIOL, Conv3d):

@@ -314,12 +297,3 @@ def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Option
return self.conv3d_same_zeros_pad_stride(x, quant_weight, quant_bias)
else:
return self._conv_forward(x, quant_weight, quant_bias)

def max_acc_bit_width(self, input_bit_width: Tensor, weight_bit_width: Tensor):
max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False)
max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width)
group_size = self.in_channels // self.groups
kernel_size = self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2]
max_uint_output = max_uint_input * max_kernel_val * kernel_size * group_size
max_output_bit_width = ceil_ste(torch.log2(max_uint_output))
return max_output_bit_width
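
The three removed conv max_acc_bit_width methods implemented the same worst-case accumulator estimate, differing only in how many kernel taps they multiplied in. A self-contained re-derivation for the 2d case (values are illustrative, and max_kernel_val assumes an 8-bit unsigned weight bound in place of weight_quant.max_uint_value):

import math

max_uint_input = 2 ** 8 - 1             # 8-bit unsigned input
max_kernel_val = 2 ** 8 - 1             # assumed weight magnitude bound
kernel_size = 3 * 3                     # kernel_size[0] * kernel_size[1]
group_size = 64 // 1                    # in_channels // groups
max_uint_output = max_uint_input * max_kernel_val * kernel_size * group_size
acc_bit_width = math.ceil(math.log2(max_uint_output))   # 26
print(acc_bit_width)
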
57 changes: 24 additions & 33 deletions src/brevitas/nn/quant_convtranspose.py
@@ -103,7 +103,14 @@ def compute_output_padding(self, inp, output_size):
def conv_transpose1d_zeros_pad(
self, x: Tensor, weight: Tensor, bias: Optional[Tensor], output_padding):
out = conv_transpose1d(
x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
x,
weight,
bias,
stride=self.stride,
padding=self.padding,
output_padding=output_padding,
groups=self.groups,
dilation=self.dilation)
return out
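
Switching the functional call to keyword arguments guards against the long positional tail of the transposed-conv API, where output_padding and dilation are easy to swap. For reference, torch.nn.functional.conv_transpose1d accepts (input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1); a minimal standalone call:

import torch
from torch.nn.functional import conv_transpose1d

x = torch.randn(1, 4, 10)           # (N, C_in, L)
w = torch.randn(4, 8, 3)            # (C_in, C_out // groups, kernel)
y = conv_transpose1d(
    x, w, bias=None, stride=2, padding=1, output_padding=1, groups=1, dilation=1)
print(y.shape)                      # torch.Size([1, 8, 20])

The same keyword-argument style is applied to the 2d and 3d variants below.
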

def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Optional[Tensor]):
@@ -115,15 +122,6 @@ def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Option
else:
raise NotImplementedError(f"Padding mode {self.padding_mode} not supported.")

def max_acc_bit_width(self, input_bit_width, weight_bit_width):
max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False)
max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width)
group_size = self.out_channels // self.groups
patch_size = max(math.ceil(self.kernel_size[0] / self.stride[0]), 1)
max_uint_output = max_uint_input * max_kernel_val * patch_size * group_size
max_output_bit_width = ceil_ste(torch.log2(max_uint_output))
return max_output_bit_width


class QuantConvTranspose2d(QuantWBIOL, ConvTranspose2d):

@@ -200,7 +198,14 @@ def compute_output_padding(self, inp, output_size):
def conv_transpose2d_zeros_pad(
self, x: Tensor, weight: Tensor, bias: Optional[Tensor], output_padding):
out = conv_transpose2d(
x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
x,
weight,
bias,
stride=self.stride,
padding=self.padding,
output_padding=output_padding,
groups=self.groups,
dilation=self.dilation)
return out

def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Optional[Tensor]):
@@ -212,16 +217,6 @@ def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Option
else:
raise NotImplementedError(f"Padding mode {self.padding_mode} not supported.")

def max_acc_bit_width(self, input_bit_width, weight_bit_width):
max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False)
max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width)
group_size = self.out_channels // self.groups
patch_size = max(math.ceil(self.kernel_size[0] / self.stride[0]), 1) * max(
math.ceil(self.kernel_size[1] / self.stride[1]), 1)
max_uint_output = max_uint_input * max_kernel_val * patch_size * group_size
max_output_bit_width = ceil_ste(torch.log2(max_uint_output))
return max_output_bit_width


class QuantConvTranspose3d(QuantWBIOL, ConvTranspose3d):

@@ -298,7 +293,14 @@ def compute_output_padding(self, inp, output_size):
def conv_transpose3d_zeros_pad(
self, x: Tensor, weight: Tensor, bias: Optional[Tensor], output_padding):
out = conv_transpose3d(
x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
x,
weight,
bias,
stride=self.stride,
padding=self.padding,
output_padding=output_padding,
groups=self.groups,
dilation=self.dilation)
return out

def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Optional[Tensor]):
@@ -309,14 +311,3 @@ def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Option
return out
else:
raise NotImplementedError(f"Padding mode {self.padding_mode} not supported.")

def max_acc_bit_width(self, input_bit_width, weight_bit_width):
max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False)
max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width)
group_size = self.out_channels // self.groups
patch_size = max(math.ceil(self.kernel_size[0] / self.stride[0]), 1) * max(
math.ceil(self.kernel_size[1] / self.stride[1]), 1) * max(
math.ceil(self.kernel_size[2] / self.stride[2]), 1)
max_uint_output = max_uint_input * max_kernel_val * patch_size * group_size
max_output_bit_width = ceil_ste(torch.log2(max_uint_output))
return max_output_bit_width
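
The removed transposed-conv variants bounded the accumulator with a per-output patch rather than the full kernel, since a stride-s transposed convolution only overlaps ceil(k / s) taps per output position and dimension. A hedged re-derivation for the 2d case (values are illustrative):

import math

max_uint_input = 2 ** 8 - 1                     # 8-bit unsigned input
max_kernel_val = 2 ** 8 - 1                     # assumed weight magnitude bound
kernel_size, stride = (4, 4), (2, 2)
patch_size = max(math.ceil(kernel_size[0] / stride[0]), 1) * \
    max(math.ceil(kernel_size[1] / stride[1]), 1)               # 4
group_size = 64 // 1                            # out_channels // groups
acc_bit_width = math.ceil(
    math.log2(max_uint_input * max_kernel_val * patch_size * group_size))
print(patch_size, acc_bit_width)                # 4 24
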
18 changes: 8 additions & 10 deletions src/brevitas/nn/quant_embedding.py
@@ -1,14 +1,15 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Optional, Type, Union
from typing import Optional

import torch
from torch import Tensor
from torch.nn import Embedding
from torch.nn.functional import embedding

from brevitas.inject.defaults import Int8WeightPerTensorFloat
from brevitas.quant_tensor import _unpack_quant_tensor
from brevitas.quant_tensor import QuantTensor

from .mixin.parameter import QuantWeightMixin
@@ -62,19 +63,16 @@ def forward(self, inp):
quant_weight = self.quant_weight()
out = embedding(
inp,
quant_weight.value,
quant_weight,
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse)
if self.return_quant_tensor:
scale = quant_weight.scale
zero_point = quant_weight.zero_point
bit_width = quant_weight.bit_width
if any(t.numel() > 1 for t in [scale, zero_point, bit_width]):
raise RuntimeError("Only per-tensor quantization is supported.")
signed = quant_weight.signed
training = quant_weight.training
out = QuantTensor(out, scale, zero_point, bit_width, signed, training)
assert isinstance(out, QuantTensor), "Enable weight quantization to return QuantTensor"
return out
else:
out = _unpack_quant_tensor(out)

return out
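
A hedged usage sketch of the simplified embedding forward; the constructor arguments are assumed from the public Brevitas API rather than taken from this diff, and Int8WeightPerTensorFloat is the default weight quantizer already imported in this file:

import torch
from brevitas.inject.defaults import Int8WeightPerTensorFloat
from brevitas.nn import QuantEmbedding

emb = QuantEmbedding(
    num_embeddings=100,
    embedding_dim=16,
    weight_quant=Int8WeightPerTensorFloat,
    return_quant_tensor=True)       # weight quantization must be enabled for this to hold
idx = torch.tensor([[1, 2, 3]])
out = emb(idx)                      # expected: a QuantTensor gathered from the quantized weight
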
65 changes: 9 additions & 56 deletions src/brevitas/nn/quant_layer.py
@@ -9,11 +9,10 @@
from torch import Tensor
from torch.nn import Module

from brevitas.quant_tensor import _unpack_quant_tensor
from brevitas.quant_tensor import QuantTensor
from brevitas.utils.torch_utils import compute_channel_view_shape

from .mixin import *
from .utils import compute_channel_view_shape
from .utils import merge_bn
from .utils import rename_state_dict_by_prefix

@@ -47,7 +46,7 @@ def forward(self, input: Union[Tensor, QuantTensor]):
quant_input = self.input_quant(input)
# shortcut execution through the export impl during export
if self.export_mode:
out = self.export_handler(_unpack_quant_tensor(quant_input))
out = self.export_handler(quant_input)
self._set_global_is_quant_layer(False)
return out
out = self.act_quant(quant_input)
@@ -121,7 +120,8 @@ def max_acc_bit_width(self, input_bit_width: Tensor, quant_weight_bit_width: Ten

def quant_output_scale_impl(
self, inp: Tensor, quant_input_scale: Tensor, quant_weight_scale: Tensor):
output_scale_shape = compute_channel_view_shape(inp, channel_dim=1)
channel_dim = -1 if isinstance(self, torch.nn.Linear) else 1
output_scale_shape = compute_channel_view_shape(inp, channel_dim=channel_dim)
output_scale = quant_weight_scale.view(output_scale_shape)
output_scale = output_scale * quant_input_scale.view(output_scale_shape)
return output_scale
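
The new channel_dim switch reflects where the channel axis lives in the layer output: last for Linear (N, out_features), dim 1 for convolutions (N, C, ...). A minimal sketch of the assumed behaviour of compute_channel_view_shape (the real helper lives in brevitas.utils.torch_utils; this stand-in is only illustrative):

def channel_view_shape(tensor_rank: int, channel_dim: int) -> tuple:
    # assumed behaviour: ones everywhere, -1 at the channel dimension,
    # so a per-channel scale broadcasts against the layer output
    shape = [1] * tensor_rank
    shape[channel_dim] = -1
    return tuple(shape)

print(channel_view_shape(4, 1))     # (1, -1, 1, 1) for conv activations (N, C, H, W)
print(channel_view_shape(2, -1))    # (1, -1) for Linear activations (N, out_features)
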
@@ -140,16 +140,12 @@ def merge_bn_in(self, bn):
merge_bn(self, bn, output_channel_dim=self.output_channel_dim)

def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
output_scale = None
output_bit_width = None
output_zero_point = None
output_signed = None

inp = self.unpack_input(inp)

# shortcut execution through the export impl during export
if self.export_mode:
out = self.export_handler(_unpack_quant_tensor(inp))
out = self.export_handler(inp)
self._set_global_is_quant_layer(False)
return out

@@ -163,58 +159,15 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe
self.output_quant.is_quant_enabled) and self.return_quant_tensor:
raise RuntimeError("QuantLayer is not correctly configured")

output_scale = None
if isinstance(quant_input, QuantTensor) and isinstance(quant_weight, QuantTensor):
output_bit_width = self.max_acc_bit_width(quant_input.bit_width, quant_weight.bit_width)
output_scale = self.quant_output_scale_impl(inp, quant_input.scale, quant_weight.scale)
output_signed = quant_input.signed or quant_weight.signed

if self.bias is not None:
quant_bias = self.bias_quant(self.bias, output_scale)

output_tensor = self.inner_forward_impl(
_unpack_quant_tensor(quant_input),
_unpack_quant_tensor(quant_weight),
_unpack_quant_tensor(quant_bias))

if output_scale is not None:
if (isinstance(quant_bias, QuantTensor) and
quant_bias.scale.data_ptr() != output_scale.data_ptr()) or not isinstance(
quant_bias, QuantTensor):
channel_dim = -1 if isinstance(self, torch.nn.Linear) else 1
output_scale_broadcast_shape = compute_channel_view_shape(
inp, channel_dim=channel_dim)
output_zero_point = -_unpack_quant_tensor(quant_bias).view(
output_scale_broadcast_shape) / output_scale

if output_bit_width is not None and isinstance(quant_bias, QuantTensor):
output_bit_width = torch.where(
quant_bias.bit_width > output_bit_width, quant_bias.bit_width, output_bit_width)
output_bit_width = output_bit_width + 1
else:
output_tensor = self.inner_forward_impl(
_unpack_quant_tensor(quant_input), _unpack_quant_tensor(quant_weight), None)

if not self.output_quant.is_quant_enabled and self.return_quant_tensor:
if compute_output_quant_tensor:
if (quant_input.zero_point != 0.0).any() or (quant_weight.zero_point != 0.0).any():
raise RuntimeError(
"Computing zero point of output accumulator not supported yet.")
elif output_zero_point is None:
output_zero_point = quant_input.zero_point

elif output_zero_point is None:
output_zero_point = torch.zeros(1).type_as(output_tensor)

if compute_output_quant_tensor:
quant_output = QuantTensor(
output_tensor,
scale=output_scale,
zero_point=output_zero_point,
bit_width=output_bit_width,
signed=output_signed,
training=self.training)
else:
quant_output = output_tensor
quant_bias = None
output_tensor = self.inner_forward_impl(quant_input, quant_weight, quant_bias)

quant_output = self.output_quant(quant_output)
quant_output = self.output_quant(output_tensor)
return self.pack_output(quant_output)
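
A hedged end-to-end sketch of the slimmed-down WBIOL path; the imports and arguments come from the public Brevitas API and are not part of this diff:

import torch
from brevitas.nn import QuantLinear
from brevitas.quant import Int8ActPerTensorFloat, Int8WeightPerTensorFloat

layer = QuantLinear(
    in_features=16,
    out_features=32,
    bias=False,
    weight_quant=Int8WeightPerTensorFloat,
    input_quant=Int8ActPerTensorFloat,
    return_quant_tensor=True)
out = layer(torch.randn(4, 16))     # expected: a QuantTensor produced by QuantTensor arithmetic
# With return_quant_tensor=True but no quantizer able to produce a QuantTensor,
# forward_impl now raises RuntimeError("QuantLayer is not correctly configured")
# instead of returning partially-populated metadata.
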