diff --git a/src/brevitas/export/onnx/standard/qoperator/handler/pool.py b/src/brevitas/export/onnx/standard/qoperator/handler/pool.py
deleted file mode 100644
index f82bd62f4..000000000
--- a/src/brevitas/export/onnx/standard/qoperator/handler/pool.py
+++ /dev/null
@@ -1,42 +0,0 @@
-from abc import ABC
-from typing import Union
-
-import torch
-from torch import Tensor
-# early import of max_pool to avoid being affected by monkeypatching
-from torch.nn.functional import max_pool1d
-from torch.nn.functional import max_pool2d
-
-from brevitas.nn import QuantMaxPool1d
-from brevitas.nn import QuantMaxPool2d
-
-from .base import StdQOpONNXQuantWrapperHandler
-
-
-class StdQOpONNXQuantMaxPoolNd(StdQOpONNXQuantWrapperHandler, ABC):
-
-    @classmethod
-    def op_symbolic_kwargs(cls, module: Union[QuantMaxPool1d, QuantMaxPool2d]):
-        return {
-            'kernel_size': module.kernel_size,
-            'stride': module.stride,
-            'padding': module.padding,
-            'dilation': module.dilation,
-            'ceil_mode': module.ceil_mode,
-            'return_indices': module.return_indices}
-
-
-class StdQOpONNXQuantMaxPool1d(StdQOpONNXQuantMaxPoolNd):
-    handled_layer = QuantMaxPool1d
-
-    def op_symbolic_execution(self, inp: Tensor):
-        op_symbolic_kwargs = self.symbolic_kwargs['op_symbolic_kwargs']
-        return max_pool1d(inp, *op_symbolic_kwargs.values())
-
-
-class StdQOpONNXQuantMaxPool2d(StdQOpONNXQuantMaxPoolNd):
-    handled_layer = QuantMaxPool2d
-
-    def op_symbolic_execution(self, inp: Tensor):
-        op_symbolic_kwargs = self.symbolic_kwargs['op_symbolic_kwargs']
-        return max_pool2d(inp, *op_symbolic_kwargs.values())
diff --git a/src/brevitas/export/onnx/standard/qoperator/manager.py b/src/brevitas/export/onnx/standard/qoperator/manager.py
index 4c45df04c..12f16cba3 100644
--- a/src/brevitas/export/onnx/standard/qoperator/manager.py
+++ b/src/brevitas/export/onnx/standard/qoperator/manager.py
@@ -27,8 +27,6 @@
 from .handler.parameter import StdQOpONNXQuantConv1dHandler
 from .handler.parameter import StdQOpONNXQuantConv2dHandler
 from .handler.parameter import StdQOpONNXQuantLinearHandler
-from .handler.pool import StdQOpONNXQuantMaxPool1d
-from .handler.pool import StdQOpONNXQuantMaxPool2d
 
 
 class StdQOpONNXManager(StdONNXBaseManager):
@@ -43,7 +41,7 @@ class StdQOpONNXManager(StdONNXBaseManager):
         F.max_pool3d,
         F.adaptive_max_pool1d,
         F.adaptive_max_pool2d,
-        F.adaptive_max_pool3d,]
+        F.adaptive_max_pool3d]
 
     handlers = [
         StdQOpONNXQuantConv1dHandler,
@@ -53,9 +51,7 @@ class StdQOpONNXManager(StdONNXBaseManager):
         StdQOpONNXQuantHardTanhHandler,
         StdQOpONNXQuantIdentityHandler,
         StdQOpONNXQuantTanhHandler,
-        StdQOpONNXQuantSigmoidHandler,
-        StdQOpONNXQuantMaxPool1d,
-        StdQOpONNXQuantMaxPool2d]
+        StdQOpONNXQuantSigmoidHandler]
 
     onnx_passes = [
         # remove unused graph inputs & initializers
diff --git a/src/brevitas/export/torch/qoperator/handler/pool.py b/src/brevitas/export/torch/qoperator/handler/pool.py
deleted file mode 100644
index a6e3a4550..000000000
--- a/src/brevitas/export/torch/qoperator/handler/pool.py
+++ /dev/null
@@ -1,60 +0,0 @@
-# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
-# SPDX-License-Identifier: BSD-3-Clause
-
-from abc import ABC
-from typing import Union
-
-from brevitas.nn import QuantMaxPool1d
-from brevitas.nn import QuantMaxPool2d
-
-from . import qF
-from .base import PytorchQuantLayerHandler
-
-
-class PytorchQuantMaxPoolNd(PytorchQuantLayerHandler, ABC):
-
-    @classmethod
-    def validate(cls, module):
-        # nothing to do here, pytorch's quant max pool is standard max pool
-        pass
-
-    @classmethod
-    def explicit_output_dtype(cls) -> bool:
-        return False
-
-    @classmethod
-    def prepare_qf_kwargs(cls, module: Union[QuantMaxPool1d, QuantMaxPool2d]):
-        return {
-            'kernel_size': module.kernel_size,
-            'stride': module.stride,
-            'padding': module.padding,
-            'dilation': module.dilation,
-            'ceil_mode': module.ceil_mode,
-            'return_indices': module.return_indices}
-
-    def prepare_for_export(self, module):
-        self.qf_impl, self.qf_kwargs = self.prepare_qf(module)
-        self.return_quant_tensor = module.return_quant_tensor
-
-    def forward(self, inp):
-        out = self.qf_impl(inp, **self.qf_kwargs)
-        # We are being tolerant here to non quantized tensors
-        if out.is_quantized and not self.return_quant_tensor:
-            out = out.dequantize()
-        return out
-
-
-class PytorchQuantMaxPool1d(PytorchQuantMaxPoolNd):
-    handled_layer = QuantMaxPool1d
-
-    @classmethod
-    def prepare_qf(cls, module: QuantMaxPool1d):
-        return qF.max_pool1d, cls.prepare_qf_kwargs(module)
-
-
-class PytorchQuantMaxPool2d(PytorchQuantMaxPoolNd):
-    handled_layer = QuantMaxPool2d
-
-    @classmethod
-    def prepare_qf(cls, module: QuantMaxPool1d):
-        return qF.max_pool2d, cls.prepare_qf_kwargs(module)
diff --git a/src/brevitas/export/torch/qoperator/manager.py b/src/brevitas/export/torch/qoperator/manager.py
index aea97957c..f244bbcf7 100644
--- a/src/brevitas/export/torch/qoperator/manager.py
+++ b/src/brevitas/export/torch/qoperator/manager.py
@@ -20,16 +20,12 @@
 from .handler.parameter import PytorchQuantConv1dHandler
 from .handler.parameter import PytorchQuantConv2dHandler
 from .handler.parameter import PytorchQuantLinearHandler
-from .handler.pool import PytorchQuantMaxPool1d
-from .handler.pool import PytorchQuantMaxPool2d
 
 
 class TorchQOpManager(BaseManager):
     target_name = 'torch'
 
     handlers = [
-        PytorchQuantMaxPool1d,
-        PytorchQuantMaxPool2d,
         PytorchQuantHardTanhHandler,
         PytorchQuantIdentityHandler,
         PytorchQuantReLUHandler,
diff --git a/src/brevitas/nn/__init__.py b/src/brevitas/nn/__init__.py
index 7138da2bd..5375ab609 100644
--- a/src/brevitas/nn/__init__.py
+++ b/src/brevitas/nn/__init__.py
@@ -22,8 +22,6 @@
 from .quant_eltwise import QuantEltwiseAdd
 from .quant_embedding import QuantEmbedding
 from .quant_linear import QuantLinear
-from .quant_max_pool import QuantMaxPool1d
-from .quant_max_pool import QuantMaxPool2d
 from .quant_mha import QuantMultiheadAttention
 from .quant_rnn import QuantLSTM
 from .quant_rnn import QuantRNN
diff --git a/src/brevitas/nn/quant_max_pool.py b/src/brevitas/nn/quant_max_pool.py
deleted file mode 100644
index 99b1e900c..000000000
--- a/src/brevitas/nn/quant_max_pool.py
+++ /dev/null
@@ -1,88 +0,0 @@
-# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
-# SPDX-License-Identifier: BSD-3-Clause
-
-from typing import Union
-
-from torch import Tensor
-from torch.nn import MaxPool1d
-from torch.nn import MaxPool2d
-
-from brevitas.quant_tensor import QuantTensor
-
-from .mixin.base import QuantLayerMixin
-
-
-class QuantMaxPool1d(QuantLayerMixin, MaxPool1d):
-
-    def __init__(
-            self,
-            kernel_size,
-            stride=None,
-            padding=0,
-            dilation=1,
-            return_indices=False,
-            ceil_mode=False,
-            return_quant_tensor: bool = True):
-        MaxPool1d.__init__(
-            self,
-            kernel_size=kernel_size,
-            stride=stride,
-            padding=padding,
-            dilation=dilation,
-            return_indices=return_indices,
-            ceil_mode=ceil_mode)
-        QuantLayerMixin.__init__(self, return_quant_tensor=return_quant_tensor)
-
-    @property
-    def channelwise_separable(self) -> bool:
-        return True
-
-    @property
-    def requires_export_handler(self):
-        return False
-
-    def forward(self, input: Union[Tensor, QuantTensor]):
-        x = self.unpack_input(input)
-        if self.export_mode:
-            return self.export_handler(x.value)
-        x = x.set(value=super().forward(x.value))
-        return self.pack_output(x)
-
-
-class QuantMaxPool2d(QuantLayerMixin, MaxPool2d):
-
-    def __init__(
-            self,
-            kernel_size,
-            stride=None,
-            padding=0,
-            dilation=1,
-            return_indices=False,
-            ceil_mode=False,
-            return_quant_tensor: bool = True):
-        MaxPool2d.__init__(
-            self,
-            kernel_size=kernel_size,
-            stride=stride,
-            padding=padding,
-            dilation=dilation,
-            return_indices=return_indices,
-            ceil_mode=ceil_mode)
-        QuantLayerMixin.__init__(self, return_quant_tensor=return_quant_tensor)
-
-    @property
-    def channelwise_separable(self) -> bool:
-        return True
-
-    @property
-    def requires_export_handler(self):
-        return False
-
-    def forward(self, input: Union[Tensor, QuantTensor]):
-        x = self.unpack_input(input)
-        if self.export_mode:
-            out = self.export_handler(x.value)
-            self._set_global_is_quant_layer(False)
-            return out
-        x = x.set(value=super().forward(x.value))
-        return self.pack_output(x)
diff --git a/tests/brevitas/export/test_qonnx_export.py b/tests/brevitas/export/test_qonnx_export.py
index 4a40f046d..70329db66 100644
--- a/tests/brevitas/export/test_qonnx_export.py
+++ b/tests/brevitas/export/test_qonnx_export.py
@@ -9,7 +9,6 @@
 from brevitas.nn import QuantConv2d
 from brevitas.nn import QuantIdentity
 from brevitas.nn import QuantLinear
-from brevitas.nn import QuantMaxPool2d
 from brevitas.nn import QuantReLU
 from brevitas.nn import TruncAvgPool2d
 from brevitas.quant.scaled_int import Int4WeightPerTensorFloatDecoupled
diff --git a/tests/brevitas/export/test_torch_qop.py b/tests/brevitas/export/test_torch_qop.py
index 7b0844912..e01bd93cb 100644
--- a/tests/brevitas/export/test_torch_qop.py
+++ b/tests/brevitas/export/test_torch_qop.py
@@ -7,7 +7,6 @@
 from brevitas.nn import QuantConv2d
 from brevitas.nn import QuantIdentity
 from brevitas.nn import QuantLinear
-from brevitas.nn import QuantMaxPool2d
 from brevitas.nn import QuantReLU
 from brevitas.quant.scaled_int import Int8WeightPerTensorFloat
 from brevitas.quant.scaled_int import Int16Bias
@@ -203,65 +202,3 @@ def forward(self, x):
     pytorch_out = pytorch_qf_model(inp)
     atol = model.act2.quant_output_scale().item() * TOLERANCE
     assert pytorch_out.isclose(brevitas_out, rtol=0.0, atol=atol).all()
-
-
-@jit_disabled_for_export()
-def test_quant_max_pool2d_export():
-    IN_SIZE = (1, 1, IN_CH, IN_CH)
-    KERNEL_SIZE = 3
-
-    class Model(torch.nn.Module):
-
-        def __init__(self):
-            super().__init__()
-            self.act = QuantIdentity(
-                bit_width=8, act_quant=ShiftedUint8ActPerTensorFloat, return_quant_tensor=True)
-            self.pool = QuantMaxPool2d(
-                kernel_size=KERNEL_SIZE, stride=KERNEL_SIZE, return_quant_tensor=False)
-
-        def forward(self, x):
-            return self.pool(self.act(x))
-
-    inp = torch.randn(IN_SIZE)
-    model = Model()
-    model(inp)  # collect scale factors
-    model.eval()
-    inp = torch.randn(IN_SIZE) * RANDN_STD + RANDN_MEAN  # New input with bigger range
-    brevitas_out = model(inp)
-    pytorch_qf_model = export_torch_qop(model, input_t=inp)
-    pytorch_out = pytorch_qf_model(inp)
-    atol = model.act.quant_output_scale().item() * TOLERANCE
-    assert pytorch_out.isclose(brevitas_out, rtol=0.0, atol=atol).all()
-
-
-@requires_pt_ge('9999', 'Darwin')
-@jit_disabled_for_export()
-def test_func_quant_max_pool2d_export():
-    IN_SIZE = (1, 1, IN_CH, IN_CH)
-    KERNEL_SIZE = 2
-
-    class Model(torch.nn.Module):
-
-        def __init__(self):
-            super().__init__()
-            self.act1 = QuantIdentity(
-                bit_width=8, act_quant=ShiftedUint8ActPerTensorFloat, return_quant_tensor=True)
-            self.act2 = QuantIdentity(
-                bit_width=8, act_quant=ShiftedUint8ActPerTensorFloat, return_quant_tensor=False)
-
-        def forward(self, x):
-            x = self.act1(x)
-            x = torch.nn.functional.max_pool2d(x, KERNEL_SIZE)
-            x = self.act2(x)
-            return x
-
-    inp = torch.randn(IN_SIZE)
-    model = Model()
-    model(inp)  # collect scale factors
-    model.eval()
-    inp = torch.randn(IN_SIZE) * RANDN_STD + RANDN_MEAN  # New input with bigger range
-    brevitas_out = model(inp)
-    pytorch_qf_model = export_torch_qop(model, input_t=inp)
-    pytorch_out = pytorch_qf_model(inp)
-    atol = model.act2.quant_output_scale().item() * TOLERANCE
-    assert pytorch_out.isclose(brevitas_out, rtol=0.0, atol=atol).all()