diff --git a/docs/index.rst b/docs/index.rst
index f07ba086..53b9c159 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -63,6 +63,9 @@ Install in editable mode in a venv:
 
     pip install -e .[testing, docs, notebooks]
 
+Test suite
+++++++++++
+
 Run entire test suite, parallelized across CPU cores:
 ::
 
diff --git a/docs/license.rst b/docs/license.rst
index e647e180..a5103f77 100644
--- a/docs/license.rst
+++ b/docs/license.rst
@@ -1,7 +1,7 @@
 .. _license:
 
-=======
+========
 License
-=======
+========
 
 .. include:: ../LICENSE
diff --git a/notebooks/4_quant_lstm_helper/function.py b/notebooks/4_quant_lstm_helper/function.py
index 6ba2e9dd..935bf78a 100644
--- a/notebooks/4_quant_lstm_helper/function.py
+++ b/notebooks/4_quant_lstm_helper/function.py
@@ -2,26 +2,24 @@
 # SPDX-License-Identifier: BSD-3-Clause
 
 import torch
-from torch.autograd import Function
-
 from brevitas.export.onnx import onnx_export_opset
+from torch.autograd import Function
 
 AXIS_OPSET = 13
 DOMAIN_STRING = "onnx.brevitas"
 
 
 class DequantizeLinearFn(Function):
-
     @staticmethod
     def symbolic(g, x, input_scale, input_zero_point, input_axis):
         opset_version = onnx_export_opset()
         if input_axis is not None and opset_version < AXIS_OPSET:
-            raise RuntimeError('ONNX Opset 13 is required for per-channel quantization')
+            raise RuntimeError("ONNX Opset 13 is required for per-channel quantization")
         elif input_axis is not None and opset_version >= AXIS_OPSET:
-            ret = g.op('DequantizeLinear', x, input_scale, input_zero_point, axis_i=input_axis)
+            ret = g.op("DequantizeLinear", x, input_scale, input_zero_point, axis_i=input_axis)
         else:
-            ret = g.op('DequantizeLinear', x, input_scale, input_zero_point)
+            ret = g.op("DequantizeLinear", x, input_scale, input_zero_point)
         return ret
 
     @staticmethod
@@ -30,10 +28,9 @@ def forward(ctx, int_x, input_scale, input_zero_point, input_axis):
 
 
 class IntClipFn(Function):
-
     @staticmethod
     def symbolic(g, int_x, min_int_val, max_int_val):
-        ret = g.op('Clip', int_x, min_int_val, max_int_val)
+        ret = g.op("Clip", int_x, min_int_val, max_int_val)
         return ret
 
     @staticmethod
@@ -42,116 +39,115 @@ def forward(ctx, int_x, min_int_val, max_int_val):
 
 
 class QuantizeLinearFn(Function):
-
     @staticmethod
     def symbolic(g, x, output_scale, ouput_zero_point, output_dtype, output_axis):
         opset_version = onnx_export_opset()
         if output_axis is not None and opset_version < AXIS_OPSET:
-            raise RuntimeError('ONNX Opset 13 is required for per-channel quantization')
+            raise RuntimeError("ONNX Opset 13 is required for per-channel quantization")
         elif output_axis is not None and opset_version >= AXIS_OPSET:
-            ret = g.op('QuantizeLinear', x, output_scale, ouput_zero_point, axis_i=output_axis)
+            ret = g.op("QuantizeLinear", x, output_scale, ouput_zero_point, axis_i=output_axis)
         else:
-            ret = g.op('QuantizeLinear', x, output_scale, ouput_zero_point)
+            ret = g.op("QuantizeLinear", x, output_scale, ouput_zero_point)
         return ret
 
     @staticmethod
     def forward(ctx, x, output_scale, ouput_zero_point, output_dtype, output_axis):
         return x.type(output_dtype)
 
 
-class BrevitasQuantLSTMCellFn(Function):
-
+class BrevitasQuantLSTMCellFn(Function):
     @staticmethod
     def symbolic(
-            g,  # args and kwargs passed from _QuantLSTMLayer
-            quant_input,
-            quant_hidden_state,
-            quant_cell_state,
-            quant_weight_ii,
-            quant_weight_if,
-            quant_weight_ic,
-            quant_weight_io,
-            quant_weight_hi,
-            quant_weight_hf,
-            quant_weight_hc,
-            quant_weight_ho,
-            quant_bias_input,
-            quant_bias_forget,
-            quant_bias_cell,
-            quant_bias_output,  # Symbolic kwargs passed from BrevitasQuantLSTMLayerHandler
-            batch_first,
-            reverse_input,
-            cifg,  # Output quant
-            output_scale,
-            output_zero_point,
-            output_bit_width,
-            output_narrow_range,
-            output_signed,
-            output_rounding_mode,  # Cell state quant
-            cell_state_scale,
-            cell_state_zero_point,
-            cell_state_bit_width,
-            cell_state_narrow_range,
-            cell_state_signed,
-            cell_state_rounding_mode,  # Input gate accumulator quant
-            input_acc_scale,
-            input_acc_zero_point,
-            input_acc_bit_width,
-            input_acc_narrow_range,
-            input_acc_signed,
-            input_acc_rounding_mode,  # Forget gate accumulator quant
-            forget_acc_scale,
-            forget_acc_zero_point,
-            forget_acc_bit_width,
-            forget_acc_narrow_range,
-            forget_acc_signed,
-            forget_acc_rounding_mode,  # Cell gate accumulator quant
-            cell_acc_scale,
-            cell_acc_zero_point,
-            cell_acc_bit_width,
-            cell_acc_narrow_range,
-            cell_acc_signed,
-            cell_acc_rounding_mode,  # Output gate accumulator quant
-            output_acc_scale,
-            output_acc_zero_point,
-            output_acc_bit_width,
-            output_acc_narrow_range,
-            output_acc_signed,
-            output_acc_rounding_mode,  # Input gate sigmoid quant
-            input_sigmoid_scale,
-            input_sigmoid_zero_point,
-            input_sigmoid_bit_width,
-            input_sigmoid_narrow_range,
-            input_sigmoid_signed,
-            input_sigmoid_rounding_mode,  # Forget gate sigmoid quant
-            forget_sigmoid_scale,
-            forget_sigmoid_zero_point,
-            forget_sigmoid_bit_width,
-            forget_sigmoid_narrow_range,
-            forget_sigmoid_signed,
-            forget_sigmoid_rounding_mode,  # Cell gate tanh quant
-            cell_tanh_scale,
-            cell_tanh_zero_point,
-            cell_tanh_bit_width,
-            cell_tanh_narrow_range,
-            cell_tanh_signed,
-            cell_tanh_rounding_mode,  # Output gate sigmoid quant
-            output_sigmoid_scale,
-            output_sigmoid_zero_point,
-            output_sigmoid_bit_width,
-            output_sigmoid_narrow_range,
-            output_sigmoid_signed,
-            output_sigmoid_rounding_mode,  # Hidden state tanh quant
-            hidden_state_tanh_scale,
-            hidden_state_tanh_zero_point,
-            hidden_state_tanh_bit_width,
-            hidden_state_tanh_narrow_range,
-            hidden_state_tanh_signed,
-            hidden_state_tanh_rounding_mode):
+        g,  # args and kwargs passed from _QuantLSTMLayer
+        quant_input,
+        quant_hidden_state,
+        quant_cell_state,
+        quant_weight_ii,
+        quant_weight_if,
+        quant_weight_ic,
+        quant_weight_io,
+        quant_weight_hi,
+        quant_weight_hf,
+        quant_weight_hc,
+        quant_weight_ho,
+        quant_bias_input,
+        quant_bias_forget,
+        quant_bias_cell,
+        quant_bias_output,  # Symbolic kwargs passed from BrevitasQuantLSTMLayerHandler
+        batch_first,
+        reverse_input,
+        cifg,  # Output quant
+        output_scale,
+        output_zero_point,
+        output_bit_width,
+        output_narrow_range,
+        output_signed,
+        output_rounding_mode,  # Cell state quant
+        cell_state_scale,
+        cell_state_zero_point,
+        cell_state_bit_width,
+        cell_state_narrow_range,
+        cell_state_signed,
+        cell_state_rounding_mode,  # Input gate accumulator quant
+        input_acc_scale,
+        input_acc_zero_point,
+        input_acc_bit_width,
+        input_acc_narrow_range,
+        input_acc_signed,
+        input_acc_rounding_mode,  # Forget gate accumulator quant
+        forget_acc_scale,
+        forget_acc_zero_point,
+        forget_acc_bit_width,
+        forget_acc_narrow_range,
+        forget_acc_signed,
+        forget_acc_rounding_mode,  # Cell gate accumulator quant
+        cell_acc_scale,
+        cell_acc_zero_point,
+        cell_acc_bit_width,
+        cell_acc_narrow_range,
+        cell_acc_signed,
+        cell_acc_rounding_mode,  # Output gate accumulator quant
+        output_acc_scale,
+        output_acc_zero_point,
+        output_acc_bit_width,
+        output_acc_narrow_range,
+        output_acc_signed,
+        output_acc_rounding_mode,  # Input gate sigmoid quant
+        input_sigmoid_scale,
+        input_sigmoid_zero_point,
+        input_sigmoid_bit_width,
+        input_sigmoid_narrow_range,
+        input_sigmoid_signed,
+        input_sigmoid_rounding_mode,  # Forget gate sigmoid quant
+        forget_sigmoid_scale,
+        forget_sigmoid_zero_point,
+        forget_sigmoid_bit_width,
+        forget_sigmoid_narrow_range,
+        forget_sigmoid_signed,
+        forget_sigmoid_rounding_mode,  # Cell gate tanh quant
+        cell_tanh_scale,
+        cell_tanh_zero_point,
+        cell_tanh_bit_width,
+        cell_tanh_narrow_range,
+        cell_tanh_signed,
+        cell_tanh_rounding_mode,  # Output gate sigmoid quant
+        output_sigmoid_scale,
+        output_sigmoid_zero_point,
+        output_sigmoid_bit_width,
+        output_sigmoid_narrow_range,
+        output_sigmoid_signed,
+        output_sigmoid_rounding_mode,  # Hidden state tanh quant
+        hidden_state_tanh_scale,
+        hidden_state_tanh_zero_point,
+        hidden_state_tanh_bit_width,
+        hidden_state_tanh_narrow_range,
+        hidden_state_tanh_signed,
+        hidden_state_tanh_rounding_mode,
+    ):
         return g.op(
-            f'{DOMAIN_STRING}::QuantLSTMCell',  # Tensors
-            ## Input values
+            f"{DOMAIN_STRING}::QuantLSTMCell",  # Tensors
+            # Input values
             quant_input,
             quant_hidden_state,
             quant_cell_state,
@@ -166,37 +162,37 @@ def symbolic(
             quant_bias_input,
             quant_bias_forget,
             quant_bias_cell,
-            quant_bias_output,  ## Output quant
+            quant_bias_output,  # Output quant
             output_scale,
             output_zero_point,
-            output_bit_width,  ## Cell state quant
+            output_bit_width,  # Cell state quant
             cell_state_scale,
             cell_state_zero_point,
-            cell_state_bit_width,  ## Input gate accumulator quant
+            cell_state_bit_width,  # Input gate accumulator quant
             input_acc_scale,
             input_acc_zero_point,
-            input_acc_bit_width,  ## Forget gate accumulator quant
+            input_acc_bit_width,  # Forget gate accumulator quant
             forget_acc_scale,
             forget_acc_zero_point,
-            forget_acc_bit_width,  ## Cell gate accumulator quant
+            forget_acc_bit_width,  # Cell gate accumulator quant
             cell_acc_scale,
             cell_acc_zero_point,
-            cell_acc_bit_width,  ## Output gate accumulator quant
+            cell_acc_bit_width,  # Output gate accumulator quant
             output_acc_scale,
             output_acc_zero_point,
-            output_acc_bit_width,  ## Input gate sigmoid quant
+            output_acc_bit_width,  # Input gate sigmoid quant
             input_sigmoid_scale,
             input_sigmoid_zero_point,
-            input_sigmoid_bit_width,  ## Forget gate sigmoid quant
+            input_sigmoid_bit_width,  # Forget gate sigmoid quant
             forget_sigmoid_scale,
             forget_sigmoid_zero_point,
-            forget_sigmoid_bit_width,  ## Cell gate tanh quant
+            forget_sigmoid_bit_width,  # Cell gate tanh quant
             cell_tanh_scale,
             cell_tanh_zero_point,
-            cell_tanh_bit_width,  ## Output gate sigmoid quant
+            cell_tanh_bit_width,  # Output gate sigmoid quant
             output_sigmoid_scale,
             output_sigmoid_zero_point,
-            output_sigmoid_bit_width,  ## Hidden state tanh quant
+            output_sigmoid_bit_width,  # Hidden state tanh quant
             hidden_state_tanh_scale,
             hidden_state_tanh_zero_point,
             hidden_state_tanh_bit_width,
@@ -238,103 +234,102 @@ def symbolic(
             hidden_state_tanh_signed_i=hidden_state_tanh_signed,
             hidden_state_tanh_rounding_mode_s=hidden_state_tanh_rounding_mode,
             # PyTorch requires to specify the number of outputs manually
-            outputs=3)
-
+            outputs=3,
+        )
 
     @staticmethod
     def forward(
-            ctx,  # args and kwargs passed from _QuantLSTMLayer
-            quant_input,
-            quant_hidden_state,
-            quant_cell_state,
-            quant_weight_ii,
-            quant_weight_if,
-            quant_weight_ic,
-            quant_weight_io,
-            quant_weight_hi,
-            quant_weight_hf,
-            quant_weight_hc,
-            quant_weight_ho,
-            quant_bias_input,
-            quant_bias_forget,
-            quant_bias_cell,
-            quant_bias_output,  # Symbolic kwargs passed from BrevitasQuantLSTMLayerHandler
-            batch_first,
-            reverse_input,
-            cifg,  # Output quant
-            output_scale,
-            output_zero_point,
-            output_bit_width,
-            output_narrow_range,
-            output_signed,
-            output_rounding_mode,  # Cell state quant
-            cell_state_scale,
-            cell_state_zero_point,
-            cell_state_bit_width,
-            cell_state_narrow_range,
-            cell_state_signed,
-            cell_state_rounding_mode,  # Input gate accumulator quant
-            input_acc_scale,
-            input_acc_zero_point,
-            input_acc_bit_width,
-            input_acc_narrow_range,
-            input_acc_signed,
-            input_acc_rounding_mode,  # Forget gate accumulator quant
-            forget_acc_scale,
-            forget_acc_zero_point,
-            forget_acc_bit_width,
-            forget_acc_narrow_range,
-            forget_acc_signed,
-            forget_acc_rounding_mode,  # Cell gate accumulator quant
-            cell_acc_scale,
-            cell_acc_zero_point,
-            cell_acc_bit_width,
-            cell_acc_narrow_range,
-            cell_acc_signed,
-            cell_acc_rounding_mode,  # Output gate accumulator quant
-            output_acc_scale,
-            output_acc_zero_point,
-            output_acc_bit_width,
-            output_acc_narrow_range,
-            output_acc_signed,
-            output_acc_rounding_mode,  # Input gate sigmoid quant
-            input_sigmoid_scale,
-            input_sigmoid_zero_point,
-            input_sigmoid_bit_width,
-            input_sigmoid_narrow_range,
-            input_sigmoid_signed,
-            input_sigmoid_rounding_mode,  # Forget gate sigmoid quant
-            forget_sigmoid_scale,
-            forget_sigmoid_zero_point,
-            forget_sigmoid_bit_width,
-            forget_sigmoid_narrow_range,
-            forget_sigmoid_signed,
-            forget_sigmoid_rounding_mode,  # Cell gate tanh quant
-            cell_tanh_scale,
-            cell_tanh_zero_point,
-            cell_tanh_bit_width,
-            cell_tanh_narrow_range,
-            cell_tanh_signed,
-            cell_tanh_rounding_mode,  # Output gate sigmoid quant
-            output_sigmoid_scale,
-            output_sigmoid_zero_point,
-            output_sigmoid_bit_width,
-            output_sigmoid_narrow_range,
-            output_sigmoid_signed,
-            output_sigmoid_rounding_mode,  # Hidden state tanh quant
-            hidden_state_tanh_scale,
-            hidden_state_tanh_zero_point,
-            hidden_state_tanh_bit_width,
-            hidden_state_tanh_narrow_range,
-            hidden_state_tanh_signed,
-            hidden_state_tanh_rounding_mode):
+        ctx,  # args and kwargs passed from _QuantLSTMLayer
+        quant_input,
+        quant_hidden_state,
+        quant_cell_state,
+        quant_weight_ii,
+        quant_weight_if,
+        quant_weight_ic,
+        quant_weight_io,
+        quant_weight_hi,
+        quant_weight_hf,
+        quant_weight_hc,
+        quant_weight_ho,
+        quant_bias_input,
+        quant_bias_forget,
+        quant_bias_cell,
+        quant_bias_output,  # Symbolic kwargs passed from BrevitasQuantLSTMLayerHandler
+        batch_first,
+        reverse_input,
+        cifg,  # Output quant
+        output_scale,
+        output_zero_point,
+        output_bit_width,
+        output_narrow_range,
+        output_signed,
+        output_rounding_mode,  # Cell state quant
+        cell_state_scale,
+        cell_state_zero_point,
+        cell_state_bit_width,
+        cell_state_narrow_range,
+        cell_state_signed,
+        cell_state_rounding_mode,  # Input gate accumulator quant
+        input_acc_scale,
+        input_acc_zero_point,
+        input_acc_bit_width,
+        input_acc_narrow_range,
+        input_acc_signed,
+        input_acc_rounding_mode,  # Forget gate accumulator quant
+        forget_acc_scale,
+        forget_acc_zero_point,
+        forget_acc_bit_width,
+        forget_acc_narrow_range,
+        forget_acc_signed,
+        forget_acc_rounding_mode,  # Cell gate accumulator quant
+        cell_acc_scale,
+        cell_acc_zero_point,
+        cell_acc_bit_width,
+        cell_acc_narrow_range,
+        cell_acc_signed,
+        cell_acc_rounding_mode,  # Output gate accumulator quant
+        output_acc_scale,
+        output_acc_zero_point,
+        output_acc_bit_width,
+        output_acc_narrow_range,
+        output_acc_signed,
+        output_acc_rounding_mode,  # Input gate sigmoid quant
+        input_sigmoid_scale,
+        input_sigmoid_zero_point,
+        input_sigmoid_bit_width,
+        input_sigmoid_narrow_range,
+        input_sigmoid_signed,
+        input_sigmoid_rounding_mode,  # Forget gate sigmoid quant
+        forget_sigmoid_scale,
+        forget_sigmoid_zero_point,
+        forget_sigmoid_bit_width,
+        forget_sigmoid_narrow_range,
+        forget_sigmoid_signed,
+        forget_sigmoid_rounding_mode,  # Cell gate tanh quant
+        cell_tanh_scale,
+        cell_tanh_zero_point,
+        cell_tanh_bit_width,
+        cell_tanh_narrow_range,
+        cell_tanh_signed,
+        cell_tanh_rounding_mode,  # Output gate sigmoid quant
+        output_sigmoid_scale,
+        output_sigmoid_zero_point,
+        output_sigmoid_bit_width,
+        output_sigmoid_narrow_range,
+        output_sigmoid_signed,
+        output_sigmoid_rounding_mode,  # Hidden state tanh quant
+        hidden_state_tanh_scale,
+        hidden_state_tanh_zero_point,
+        hidden_state_tanh_bit_width,
+        hidden_state_tanh_narrow_range,
+        hidden_state_tanh_signed,
+        hidden_state_tanh_rounding_mode,
+    ):
         # To simplify things, here we are returning the outputs
         # as if they were already concatenated. Scale/zp/bw are avoided too.
         # This preserves output shapes but not values.
         # See _QuantLSTMCell for the actual implementation.
         quant_outputs = torch.zeros(
-            quant_input.size(0),
-            quant_input.size(1),
-            quant_hidden_state.size(1),
-            device=quant_hidden_state.device)
+            quant_input.size(0), quant_input.size(1), quant_hidden_state.size(1), device=quant_hidden_state.device
+        )
         return quant_outputs, quant_hidden_state, quant_cell_state
diff --git a/notebooks/4_quant_lstm_helper/handler.py b/notebooks/4_quant_lstm_helper/handler.py
index 948eb647..71cbdeb1 100644
--- a/notebooks/4_quant_lstm_helper/handler.py
+++ b/notebooks/4_quant_lstm_helper/handler.py
@@ -1,32 +1,23 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
+import torch
 from abc import ABC
-from copy import copy
+from brevitas.export.common.handler.qcdq import (
+    DQMixin,
+    QCDQActQuantProxyHandlerMixin,
+    QCDQBiasQuantProxyHandlerMixin,
+    QCDQDecoupledWeightQuantProxyHandlerMixin,
+    QCDQMixin,
+    QCDQTruncQuantProxyHandlerMixin,
+    QCDQWeightQuantProxyHandlerMixin,
+)
+from brevitas.export.onnx.handler import ONNXBaseHandler, QuantLSTMLayerHandler
 
-import torch
-from torch import Tensor
-
-from brevitas.export.common.handler.base import QuantAxisMixin
-from brevitas.export.common.handler.qcdq import DQMixin
-from brevitas.export.common.handler.qcdq import QCDQActQuantProxyHandlerMixin
-from brevitas.export.common.handler.qcdq import QCDQBiasQuantProxyHandlerMixin
-from brevitas.export.common.handler.qcdq import QCDQDecoupledWeightQuantProxyHandlerMixin
-from brevitas.export.common.handler.qcdq import QCDQMixin
-from brevitas.export.common.handler.qcdq import QCDQTruncQuantProxyHandlerMixin
-from brevitas.export.common.handler.qcdq import QCDQWeightQuantProxyHandlerMixin
-from brevitas.export.common.handler.qcdq import ZeroPointHandlerMixin
-from brevitas.export.onnx.handler import ONNXBaseHandler
-from brevitas.export.onnx.handler import QuantLSTMLayerHandler
-
-from ..function import DequantizeLinearFn
-from ..function import IntClipFn
-from ..function import QuantizeLinearFn
-from ..function import BrevitasQuantLSTMCellFn
+from ..function import BrevitasQuantLSTMCellFn, DequantizeLinearFn, IntClipFn, QuantizeLinearFn
 
 
 class StdDQONNXMixin(DQMixin, ABC):
-
     def dequantize_fn(self, x, scale, zero_point, axis):
         return DequantizeLinearFn.apply(x, scale, zero_point, axis)
 
@@ -40,7 +31,6 @@ def itemize_quantize_scalar_params(self):
 
 
 class StdQCDQONNXMixin(QCDQMixin, StdDQONNXMixin, ABC):
-
     @property
     def clip_over_integers(self):
         return True
@@ -59,8 +49,8 @@ def int32_dtype(cls):
 
     def validate(self, module):
         self.validate_8b_bit_width(module.bit_width(), le_then=True)
-        assert module.bit_width() > 1., 'Binary quant not supported'
-        assert module.rounding_mode.upper() == 'ROUND', 'Only round to nearest even supported'
+        assert module.bit_width() > 1.0, "Binary quant not supported"
+        assert module.rounding_mode.upper() == "ROUND", "Only round to nearest even supported"
 
     def quantize_fn(self, x, scale, zero_point, dtype, axis):
         return QuantizeLinearFn.apply(x, scale, zero_point, dtype, axis)
@@ -69,55 +59,47 @@ def clip_fn(self, x, min_val, max_val):
         return IntClipFn.apply(x, min_val, max_val)
 
 
-class StdQCDQONNXWeightQuantProxyHandler(StdQCDQONNXMixin,
-                                         QCDQWeightQuantProxyHandlerMixin,
-                                         ONNXBaseHandler):
+class StdQCDQONNXWeightQuantProxyHandler(StdQCDQONNXMixin, QCDQWeightQuantProxyHandlerMixin, ONNXBaseHandler):
     pass
 
 
-class StdQCDQONNXDecoupledWeightQuantProxyHandler(StdQCDQONNXMixin,
-                                                  QCDQDecoupledWeightQuantProxyHandlerMixin,
-                                                  ONNXBaseHandler):
+class StdQCDQONNXDecoupledWeightQuantProxyHandler(
+    StdQCDQONNXMixin, QCDQDecoupledWeightQuantProxyHandlerMixin, ONNXBaseHandler
+):
     pass
 
 
-class StdQCDQONNXActQuantProxyHandler(StdQCDQONNXMixin,
-                                      QCDQActQuantProxyHandlerMixin,
-                                      ONNXBaseHandler):
+class StdQCDQONNXActQuantProxyHandler(StdQCDQONNXMixin, QCDQActQuantProxyHandlerMixin, ONNXBaseHandler):
     pass
 
 
-class StdQCDQONNXBiasQuantProxyHandler(StdDQONNXMixin,
-                                       QCDQBiasQuantProxyHandlerMixin,
-                                       ONNXBaseHandler):
+class StdQCDQONNXBiasQuantProxyHandler(StdDQONNXMixin, QCDQBiasQuantProxyHandlerMixin, ONNXBaseHandler):
     pass
 
 
-class StdQCDQONNXTruncQuantProxyHandler(StdQCDQONNXMixin,
-                                        QCDQTruncQuantProxyHandlerMixin,
-                                        ONNXBaseHandler):
+class StdQCDQONNXTruncQuantProxyHandler(StdQCDQONNXMixin, QCDQTruncQuantProxyHandlerMixin, ONNXBaseHandler):
     pass
 
 
 class StdQCDQONNXQuantLSTMLayerHandler(QuantLSTMLayerHandler):
-
     def quantized_cell_symbolic_execution(
-            self,
-            quant_input,
-            quant_hidden_state,
-            quant_cell_state,
-            quant_weight_ii,
-            quant_weight_if,
-            quant_weight_ic,
-            quant_weight_io,
-            quant_weight_hi,
-            quant_weight_hf,
-            quant_weight_hc,
-            quant_weight_ho,
-            quant_bias_input,
-            quant_bias_forget,
-            quant_bias_cell,
-            quant_bias_output):
+        self,
+        quant_input,
+        quant_hidden_state,
+        quant_cell_state,
+        quant_weight_ii,
+        quant_weight_if,
+        quant_weight_ic,
+        quant_weight_io,
+        quant_weight_hi,
+        quant_weight_hf,
+        quant_weight_hc,
+        quant_weight_ho,
+        quant_bias_input,
+        quant_bias_forget,
+        quant_bias_cell,
+        quant_bias_output,
+    ):
         return BrevitasQuantLSTMCellFn.apply(
             quant_input,
             quant_hidden_state,
@@ -134,7 +116,8 @@ def quantized_cell_symbolic_execution(
             quant_bias_forget,
             quant_bias_cell,
             quant_bias_output,
-            *self.symbolic_kwargs.values())
+            *self.symbolic_kwargs.values()
+        )
         # raise RuntimeError(
         #     "Quantized LSTM cell is not supported for ONNX QCDQ "
Use export_qonnx.") diff --git a/src/qonnx/transformation/lower_convs_to_matmul.py b/src/qonnx/transformation/lower_convs_to_matmul.py index bf95d537..49700cd7 100644 --- a/src/qonnx/transformation/lower_convs_to_matmul.py +++ b/src/qonnx/transformation/lower_convs_to_matmul.py @@ -42,167 +42,184 @@ class LowerConvsToMatMul(Transformation): def apply(self, model): model = model.transform(ExtractBiasFromConv()) graph = model.graph - node_ind = 0 graph_modified = False - for n in graph.node: - node_ind += 1 - if n.op_type == "Conv": - if len(n.input) == 3: - warnings.warn("Found Conv node with bias, skipping") - continue - cnv_input = n.input[0] - cnv_output = n.output[0] - idt = model.get_tensor_datatype(cnv_input) - odt = model.get_tensor_datatype(cnv_output) - # extract conv parameters - k = get_by_name(n.attribute, "kernel_shape").ints - k_h = k[0] - k_w = k[1] - stride_h = get_by_name(n.attribute, "strides").ints[0] - stride_w = get_by_name(n.attribute, "strides").ints[1] - group = get_by_name(n.attribute, "group").i - weight_name = n.input[1] - W_conv = model.get_initializer(weight_name) - ifm_ch = model.get_tensor_shape(n.input[0])[1] # assume NCHW - ofm_ch = model.get_tensor_shape(n.output[0])[1] # assume NCHW - ifm_dim_h = model.get_tensor_shape(n.input[0])[2] # assume NCHW - ifm_dim_w = model.get_tensor_shape(n.input[0])[3] - ofm_dim_h = model.get_tensor_shape(n.output[0])[2] # assume NCHW - ofm_dim_w = model.get_tensor_shape(n.output[0])[3] - dilation_attr = get_by_name(n.attribute, "dilations") - if dilation_attr is not None: - dilation = dilation_attr.ints - else: - dilation = [1, 1] # default value - # handle both auto_pad and explicit padding - auto_pad = get_by_name(n.attribute, "auto_pad") - if auto_pad is not None: - # find equivalent specified padding - auto_pad = auto_pad.s.decode("utf-8") - if auto_pad == "NOTSET": - # use specified padding - pad = get_by_name(n.attribute, "pads").ints - else: - pad = auto_pad_to_explicit_padding( - auto_pad, - ifm_dim_h, - ifm_dim_w, - k_h, - k_w, - stride_h, - stride_w, - len(model.get_tensor_shape(n.input[0])) - 2, - ) - else: - # use specified padding - pad = get_by_name(n.attribute, "pads").ints - - # If len(pad) == 2, assume no padding for other dimension - if len(pad) == 2: # only one dimension should be padded - assert ifm_dim_h == 1 or ifm_dim_w == 1, "Padding is assumed to be 1D, image is 2D" - - # if depthwise conv create sparse matrix and variable "dw" - # to store as attribute in Im2Col that indicates that the created + for node_ind, node in enumerate(graph.node, start=1): + if node.op_type != "Conv": + continue + + if len(node.input) == 3: + warnings.warn("Found Conv node with bias, skipping") + continue + + # extract parameters of node + ( + cnv_input, + cnv_output, + cnv_input_datatype, + cnv_output_datatype, + k_h, + k_w, + stride_h, + stride_w, + group, + weight_name, + W_conv, + ifm_ch, + ofm_ch, + ifm_dim_h, + ifm_dim_w, + ofm_dim_h, + ofm_dim_w, + dilation, + pad, + ) = self.extract_conv_params(model, node) + + # if depthwise conv create sparse matrix and variable "dw" + # to store as attribute in Im2Col that indicates that the created + # Im2Col node belongs to a depthwise convolution + dw = False + if group == ifm_ch and ofm_ch == ifm_ch: + W_sparse = np.zeros((ofm_ch, ifm_ch, k_h, k_w)) # (OFM, IFM, k_H, k_W) + for ch in range(ifm_ch): + W_sparse[ch][ch] = W_conv[ch][0] # W_conv = [OFM, IFM, k_H, k_W] + W_conv = W_sparse.astype(np.float32) + # we need to store information of the + # sparsity of the weight matrix. 
+                # sparsity of the weight matrix. For this
+                # we use the sparsity annotation of the
+                # weight tensor
+                sparsity = {"dw": {"kernel_shape": [k_h, k_w]}}
+                model.set_tensor_sparsity(weight_name, sparsity)
+                # additionally create variable "dw" to store
+                # as attribute in Im2Col that indicates that the created
                 # Im2Col node belongs to a depthwise convolution
-                dw = False
-                if group == ifm_ch and ofm_ch == ifm_ch:
-                    W_sparse = np.zeros((ofm_ch, ifm_ch, k_h, k_w))  # (OFM, IFM, k_H, k_W)
-                    for ch in range(ifm_ch):
-                        W_sparse[ch][ch] = W_conv[ch][0]  # W_conv = [OFM, IFM, k_H, k_W]
-                    W_conv = W_sparse.astype(np.float32)
-                    # we need to store information of the
-                    # sparsity of the weight matrix. For this
-                    # we use the sparsity annotation of the
-                    # weight tensor
-                    sparsity = {"dw": {"kernel_shape": [k_h, k_w]}}
-                    model.set_tensor_sparsity(weight_name, sparsity)
-                    # additionally create variable "dw" to store
-                    # as attribute in Im2Col that indicates that the created
-                    # Im2Col node belongs to a depthwise convolution
-                    dw = True
-
-                # reuse conv weights for new matmul weights
-                # conv weights are [OFM][IFM][k][k]
-                # first convert to [OFM][k][k][IFM] (to remain compatible with
-                # finn-hlslib and how it does im2col/sliding window)
-                W_matmul = W_conv.transpose(0, 2, 3, 1)  # W_conv = [OFM, IFM, k_H, k_W]
-                # reshape into [OFM][k*k*IFM] matrix
-                W_matmul = W_matmul.reshape(ofm_ch, ifm_ch * k_h * k_w)
-                # transpose to get ONNX-compatible [k*k*IFM][OFM] matrix
-                W_matmul = W_matmul.T
-                model.set_initializer(weight_name, W_matmul)
-
-                # create new intermediate values
-                inp_trans_out = helper.make_tensor_value_info(
-                    model.make_new_valueinfo_name(),
-                    TensorProto.FLOAT,
-                    (1, ifm_dim_h, ifm_dim_w, ifm_ch),  # NHWC
+                dw = True
+
+            # reuse conv weights for new matmul weights
+            # conv weights are [OFM][IFM][k][k]
+            # first convert to [OFM][k_h][k_w][IFM] (to remain compatible with
+            # finn-hlslib and how it does im2col/sliding window)
+            W_matmul = W_conv.transpose(0, 2, 3, 1)  # W_conv = [OFM, IFM, k_H, k_W]
+            # reshape into [OFM][k_h*k_w*IFM] matrix
+            W_matmul = W_matmul.reshape(ofm_ch, ifm_ch * k_h * k_w)
+            # transpose to get ONNX-compatible [k_h*k_w*IFM][OFM] matrix
+            W_matmul = W_matmul.T
+            model.set_initializer(weight_name, W_matmul)
+
+            # create new intermediate values
+            inp_trans_out = helper.make_tensor_value_info(
+                model.make_new_valueinfo_name(),
+                TensorProto.FLOAT,
+                (1, ifm_dim_h, ifm_dim_w, ifm_ch),  # NHWC
+            )
+            graph.value_info.append(inp_trans_out)
+            inp_trans_out = inp_trans_out.name
+            model.set_tensor_datatype(inp_trans_out, cnv_input_datatype)
+
+            # k_h=k_w==1: pointwise convolution, thus no im2col needed
+            need_im2col = any(p != 0 for p in pad) or k_h != 1 or k_w != 1 or stride_h != 1 or stride_w != 1
+
+            # create new intermediate values
+            matmul_out = helper.make_tensor_value_info(
+                model.make_new_valueinfo_name(), TensorProto.FLOAT, (1, ofm_dim_h, ofm_dim_w, ofm_ch)
+            )
+            graph.value_info.append(matmul_out)
+            matmul_out = matmul_out.name
+            model.set_tensor_datatype(matmul_out, cnv_output_datatype)
+
+            # create new nodes
+            # NCHW -> NHWC
+            inp_trans_node = helper.make_node("Transpose", [cnv_input], [inp_trans_out], perm=[0, 2, 3, 1])
+            nodes_to_insert = [inp_trans_node]
+
+            if need_im2col:
+                im2col_out = helper.make_tensor_value_info(
+                    model.make_new_valueinfo_name(), TensorProto.FLOAT, (1, ofm_dim_h, ofm_dim_w, ifm_ch * k_h * k_w)
                 )
-                graph.value_info.append(inp_trans_out)
-                inp_trans_out = inp_trans_out.name
-                model.set_tensor_datatype(inp_trans_out, idt)
-
-                need_im2col = True
-                if all(p == 0 for p in pad):
-                    padding = 0
-
-                # k_h=k_w==1: pointwise convolution, thus no im2col needed
-                if k_h == 1 and k_w == 1 and padding == 0 and stride_h == 1 and stride_w == 1:
-                    need_im2col = False
-
-                if need_im2col:
-                    im2col_out = helper.make_tensor_value_info(
-                        model.make_new_valueinfo_name(),
-                        TensorProto.FLOAT,
-                        (1, ofm_dim_h, ofm_dim_w, ifm_ch * k_h * k_w),
-                    )
-                    graph.value_info.append(im2col_out)
-                    im2col_out = im2col_out.name
-                    model.set_tensor_datatype(im2col_out, idt)
-
-                matmul_out = helper.make_tensor_value_info(
-                    model.make_new_valueinfo_name(),
-                    TensorProto.FLOAT,
-                    (1, ofm_dim_h, ofm_dim_w, ofm_ch),
+                graph.value_info.append(im2col_out)
+                im2col_out = im2col_out.name
+                model.set_tensor_datatype(im2col_out, cnv_input_datatype)
+                im2col_node = helper.make_node(
+                    "Im2Col",
+                    [inp_trans_out],
+                    [im2col_out],
+                    domain="qonnx.custom_op.general",
+                    stride=[stride_h, stride_w],
+                    kernel_size=[k_h, k_w],
+                    pad_amount=pad,
+                    input_shape="(1,{},{},{})".format(ifm_dim_h, ifm_dim_w, ifm_ch),
+                    depthwise=dw,
+                    dilations=dilation,
                 )
-                graph.value_info.append(matmul_out)
-                matmul_out = matmul_out.name
-                model.set_tensor_datatype(matmul_out, odt)
-
-                # create new nodes
-                # NCHW -> NHWC
-                inp_trans_node = helper.make_node("Transpose", [cnv_input], [inp_trans_out], perm=[0, 2, 3, 1])
-                # lower input tensor
-                matmul_input = inp_trans_out
-                if need_im2col:
-                    matmul_input = im2col_out
-                    im2col_node = helper.make_node(
-                        "Im2Col",
-                        [inp_trans_out],
-                        [im2col_out],
-                        domain="qonnx.custom_op.general",
-                        stride=[stride_h, stride_w],
-                        kernel_size=[k_h, k_w],
-                        pad_amount=pad,
-                        input_shape="(1,{},{},{})".format(ifm_dim_h, ifm_dim_w, ifm_ch),
-                        depthwise=dw,
-                        dilations=dilation,
-                    )
-
-                # do matmul
-                matmul_node = helper.make_node("MatMul", [matmul_input, weight_name], [matmul_out])
-                # NHWC -> NCHW
-                out_trans_node = helper.make_node("Transpose", [matmul_out], [cnv_output], perm=[0, 3, 1, 2])
-                # insert nodes where the conv is to preserve topological ordering
-                graph.node.insert(node_ind, inp_trans_node)
-                if need_im2col:
-                    graph.node.insert(node_ind + 1, im2col_node)
-                    graph.node.insert(node_ind + 2, matmul_node)
-                    graph.node.insert(node_ind + 3, out_trans_node)
-                else:
-                    graph.node.insert(node_ind + 1, matmul_node)
-                    graph.node.insert(node_ind + 2, out_trans_node)
-                # remove old nodes
-                graph.node.remove(n)
+                nodes_to_insert.append(im2col_node)
+
+            matmul_input = im2col_out if need_im2col else inp_trans_out
+            # do matmul
+            matmul_node = helper.make_node("MatMul", [matmul_input, weight_name], [matmul_out])
+            # NHWC -> NCHW
+            out_trans_node = helper.make_node("Transpose", [matmul_out], [cnv_output], perm=[0, 3, 1, 2])
+
+            nodes_to_insert.extend([matmul_node, out_trans_node])
+
+            # insert nodes where the conv is to preserve topological ordering
+            for i, insert_node in enumerate(nodes_to_insert):
+                graph.node.insert(node_ind + i, insert_node)
+            graph.node.remove(node)
 
         return (model, graph_modified)
+
+    def extract_conv_params(self, model, node):
+        cnv_input = node.input[0]
+        cnv_output = node.output[0]
+        cnv_input_datatype = model.get_tensor_datatype(cnv_input)
+        cnv_output_datatype = model.get_tensor_datatype(cnv_output)
+        k_h = get_by_name(node.attribute, "kernel_shape").ints[0]
+        k_w = get_by_name(node.attribute, "kernel_shape").ints[1]
+        stride_h = get_by_name(node.attribute, "strides").ints[0]
+        stride_w = get_by_name(node.attribute, "strides").ints[1]
+        group = get_by_name(node.attribute, "group").i
+        weight_name = node.input[1]
+        W_conv = model.get_initializer(weight_name)
+        ifm_ch = model.get_tensor_shape(cnv_input)[1]  # assume NCHW
+        ofm_ch = model.get_tensor_shape(cnv_output)[1]  # assume NCHW
+        ifm_dim_h = model.get_tensor_shape(cnv_input)[2]  # assume NCHW
+        ifm_dim_w = model.get_tensor_shape(cnv_input)[3]  # assume NCHW
+        ofm_dim_h = model.get_tensor_shape(cnv_output)[2]  # assume NCHW
+        ofm_dim_w = model.get_tensor_shape(cnv_output)[3]  # assume NCHW
+        dilation_attr = get_by_name(node.attribute, "dilations")
+        dilation = dilation_attr.ints if dilation_attr is not None else [1, 1]  # default value
+        auto_pad = get_by_name(node.attribute, "auto_pad")
+        if auto_pad is not None:
+            auto_pad = auto_pad.s.decode("utf-8")
+            if auto_pad == "NOTSET":
+                pad = get_by_name(node.attribute, "pads").ints
+            else:
+                pad = auto_pad_to_explicit_padding(
+                    auto_pad, ifm_dim_h, ifm_dim_w, k_h, k_w, stride_h, stride_w, len(model.get_tensor_shape(cnv_input)) - 2
+                )
+        else:
+            pad = get_by_name(node.attribute, "pads").ints
+
+        if len(pad) == 2:  # only one dimension should be padded
+            assert ifm_dim_h == 1 or ifm_dim_w == 1, "Padding is assumed to be 1D, image is 2D"
+
+        return (
+            cnv_input,
+            cnv_output,
+            cnv_input_datatype,
+            cnv_output_datatype,
+            k_h,
+            k_w,
+            stride_h,
+            stride_w,
+            group,
+            weight_name,
+            W_conv,
+            ifm_ch,
+            ofm_ch,
+            ifm_dim_h,
+            ifm_dim_w,
+            ofm_dim_h,
+            ofm_dim_w,
+            dilation,
+            pad,
+        )
diff --git a/tests/transformation/test_conv_lowering.py b/tests/transformation/test_conv_lowering.py
index 78da6213..788d6993 100644
--- a/tests/transformation/test_conv_lowering.py
+++ b/tests/transformation/test_conv_lowering.py
@@ -65,7 +65,7 @@ def test_conv_lowering_convmnist():
     model = model.transform(InferShapes())
     output_dict_p = oxe.execute_onnx(model, input_dict)
    produced = output_dict_p[output_name]
-    assert np.isclose(produced, expected).all()
+    assert np.isclose(produced, expected, rtol=1.0e-4).all()
 
 
 def run_conv_lowering_test(idt, k_h, k_w, ifm_dim_h, ifm_dim_w, ifm_ch, stride, padding, dilations, dw, bias):
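
Reviewer note (not part of the diff): the weight handling that `LowerConvsToMatMul` keeps applying above — `transpose(0, 2, 3, 1)`, reshape to `[OFM][k_h*k_w*IFM]`, then `.T` — can be checked in isolation. Below is a minimal numpy sketch of that equivalence under assumed shapes (stride 1, no padding, no dilation, no groups); variable names are illustrative and none of this is qonnx API.

```python
import numpy as np

ifm_ch, ofm_ch, k, dim = 3, 4, 2, 5
x = np.random.rand(1, ifm_ch, dim, dim).astype(np.float32)        # NCHW input
W_conv = np.random.rand(ofm_ch, ifm_ch, k, k).astype(np.float32)  # [OFM, IFM, kH, kW]

# reference: direct convolution, stride 1, no padding
out_dim = dim - k + 1
direct = np.zeros((ofm_ch, out_dim, out_dim), dtype=np.float32)
for oc in range(ofm_ch):
    for i in range(out_dim):
        for j in range(out_dim):
            direct[oc, i, j] = np.sum(x[0, :, i : i + k, j : j + k] * W_conv[oc])

# lowered form: im2col patches flattened in [kH][kW][IFM] order, then one MatMul
patches = np.zeros((out_dim * out_dim, ifm_ch * k * k), dtype=np.float32)
for i in range(out_dim):
    for j in range(out_dim):
        # CHW patch -> HWC so its flattening order matches the weights below
        patches[i * out_dim + j] = x[0, :, i : i + k, j : j + k].transpose(1, 2, 0).reshape(-1)

# the same three steps the transformation applies to the conv initializer
W_matmul = W_conv.transpose(0, 2, 3, 1).reshape(ofm_ch, ifm_ch * k * k).T
lowered = (patches @ W_matmul).reshape(out_dim, out_dim, ofm_ch).transpose(2, 0, 1)

assert np.allclose(direct, lowered, atol=1e-5)
```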
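Reviewer note (not part of the diff): `extract_conv_params` resolves `auto_pad` through `auto_pad_to_explicit_padding`. The helper below is an assumed, illustrative reimplementation of the ONNX `SAME_UPPER`/`SAME_LOWER` rule for one spatial dimension — it is not the qonnx function itself, just a sketch of the arithmetic it relies on.

```python
def explicit_padding(auto_pad, in_dim, k, stride):
    """Per-dimension (begin, end) padding for an ONNX auto_pad setting."""
    if auto_pad == "VALID":
        return (0, 0)
    # SAME_*: output size is ceil(in_dim / stride), per the ONNX Conv spec
    out_dim = -(-in_dim // stride)  # ceil division
    total = max((out_dim - 1) * stride + k - in_dim, 0)
    small, big = total // 2, total - total // 2
    # SAME_UPPER puts the extra pixel at the end, SAME_LOWER at the beginning
    return (small, big) if auto_pad == "SAME_UPPER" else (big, small)

print(explicit_padding("SAME_UPPER", 5, 3, 2))  # (1, 1)
print(explicit_padding("SAME_UPPER", 4, 3, 2))  # (0, 1)
```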
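Reviewer note (not part of the diff): the notebook handlers export a QuantizeLinear -> Clip -> DequantizeLinear chain (`QuantizeLinearFn`, `IntClipFn`, `DequantizeLinearFn`). A minimal numpy sketch of the numerics that chain encodes, with illustrative scale/zero-point values:

```python
import numpy as np

x = np.array([-1.3, -0.2, 0.0, 0.7, 1.9], dtype=np.float32)
scale, zero_point = np.float32(0.05), np.int32(0)

# QuantizeLinear: divide by scale, round half to even (the 'ROUND' mode that
# validate() above insists on), add the zero point, saturate to int8
q = np.clip(np.rint(x / scale) + zero_point, -128, 127).astype(np.int8)

# IntClipFn's role: narrow-range or sub-8-bit tensors are expressed as an
# explicit integer-domain Clip, e.g. a signed 4-bit range [-8, 7]
q4 = np.clip(q, -8, 7)

# DequantizeLinear: back to float32; roundtrip error is bounded by scale/2
# inside the representable range
x_hat = (q4.astype(np.float32) - zero_point) * scale
print(q.tolist(), q4.tolist(), x_hat.tolist())
```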