diff --git a/.isort.cfg b/.isort.cfg index 5378b88fad..efb7a4a352 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -9,3 +9,6 @@ sections=FUTURE,STDLIB,TEST,THIRDPARTY,FIRSTPARTY,LOCALFOLDER default_section=THIRDPARTY multi_line_output=3 profile=black +ignore_comments=true +ignore_whitespace=true +honor_noqa=true diff --git a/custom_hls/flatten.hpp b/custom_hls/flatten.hpp new file mode 100644 index 0000000000..ccb15d5af6 --- /dev/null +++ b/custom_hls/flatten.hpp @@ -0,0 +1,47 @@ +#ifndef FLATTEN_HPP +#define FLATTEN_HPP + +// HLS arbitrary precision types +#include <ap_int.h> + +// Flattens an array of N elements of Type into a single bitvector +template<std::size_t N, class Type> + ap_uint<N * Type::width> flatten(const Type buffer[N]) { +// Inline this small piece of bit merging logic +#pragma HLS INLINE + // Fill a flat word of N times the bit-width of the element type + ap_uint<N * Type::width> flat; + // Merge all N chunks of the tile into the flat bitvector + for(unsigned j = 0; j < N; ++j) { +// Do the merging of all chunks in parallel +#pragma HLS UNROLL + // Insert the chunk into the right place of the + // bitvector + flat((j + 1) * Type::width - 1, j * Type::width) = buffer[j]; + } + // Return the buffer flattened into a single bitvector + return flat; + } + +// Flattens an array of N elements of float into a single bitvector +template<std::size_t N> + ap_uint<N * 32> flatten(const float buffer[N]) { +// Inline this small piece of bit merging logic +#pragma HLS INLINE + // Fill a flat word of N times the bit-width of the element type + ap_uint<N * 32> flat; + // Merge all N chunks of the tile into the flat bitvector + for(unsigned j = 0; j < N; ++j) { +// Do the merging of all chunks in parallel +#pragma HLS UNROLL + // Insert the chunk into the right place of the + // bitvector + flat((j + 1) * 32 - 1, j * 32) = + // Note: Reinterpret the float as a 32-bit unsigned bit-vector + *reinterpret_cast<const ap_uint<32>*>(&buffer[j]); + } + // Return the buffer flattened into a single bitvector + return flat; + } + +#endif // FLATTEN_HPP diff --git a/src/finn/custom_op/fpgadataflow/__init__.py b/src/finn/custom_op/fpgadataflow/__init__.py index aed2ab7fe1..4f2f69445e 100644 --- a/src/finn/custom_op/fpgadataflow/__init__.py +++ b/src/finn/custom_op/fpgadataflow/__init__.py @@ -27,6 +27,33 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# The base class of all generic custom operations before specializing to either +# HLS or RTL backend +from finn.custom_op.fpgadataflow.hwcustomop import HWCustomOp + +# Dictionary of HWCustomOp implementations +custom_op = dict() + + +# Registers a class into the custom_op dictionary +# Note: This must be defined first, before importing any custom op +# implementation to avoid "importing partially initialized module" issues. +def register_custom_op(cls): + # The class must actually implement HWCustomOp + assert issubclass(cls, HWCustomOp), f"{cls} must subclass {HWCustomOp}" + # Insert the class into the custom_op dictionary by its name + custom_op[cls.__name__] = cls # noqa: Some weird type annotation issue?
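Editor's note on the flatten() helper above: the HLS slice assignment places element j in bits j*W through (j+1)*W-1 of the packed word. A minimal Python sketch of the same packing order, handy when checking packed simulation data (function name and values are illustrative only, not part of this PR):

```python
# Pack a list of unsigned integer elements into one wide word, element 0 in the
# least significant bits, mirroring flatten() in custom_hls/flatten.hpp.
def flatten_words(buffer, width):
    flat = 0
    for j, elem in enumerate(buffer):
        flat |= (int(elem) & ((1 << width) - 1)) << (j * width)
    return flat

# Two 4-bit elements: element 0 lands in the low nibble, element 1 above it
assert flatten_words([0x1, 0x2], width=4) == 0x21
```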
+ # Pass through the class unmodified + return cls + + +# flake8: noqa +# Disable linting from here, as all import will be flagged E402 and maybe F401 + + +# Import the submodule containing specializations of ElementwiseBinaryOperation +# Note: This will automatically register all decorated classes into this domain +import finn.custom_op.fpgadataflow.elementwise_binary from finn.custom_op.fpgadataflow.addstreams import AddStreams from finn.custom_op.fpgadataflow.channelwise_op import ChannelwiseOp from finn.custom_op.fpgadataflow.concat import StreamingConcat @@ -55,8 +82,6 @@ from finn.custom_op.fpgadataflow.upsampler import UpsampleNearestNeighbour from finn.custom_op.fpgadataflow.vectorvectoractivation import VVAU -custom_op = dict() - # make sure new HLSCustomOp subclasses are imported here so that they get # registered and plug in correctly into the infrastructure custom_op["MVAU"] = MVAU diff --git a/src/finn/custom_op/fpgadataflow/elementwise_binary.py b/src/finn/custom_op/fpgadataflow/elementwise_binary.py new file mode 100644 index 0000000000..ad204f416e --- /dev/null +++ b/src/finn/custom_op/fpgadataflow/elementwise_binary.py @@ -0,0 +1,809 @@ +# fmt: off +# Disable formatter. This is deliberately formatted to stay within 80 characters +# per line. Black, however, formats some lines going beyond this. + +# Numpy math and arrays +import numpy as np + +# Operating system stuff, e.g. paths +import os + +# Python warning subsystem +import warnings + +# Helper for creating ONNX nodes +from onnx import helper as oh + +# QONNX/FINN datatypes +from qonnx.core.datatype import DataType + +# QONNX wrapper to ONNX model graphs +from qonnx.core.modelwrapper import ModelWrapper + +# Utility for registering HWCustomOp implementations into the module scope +from finn.custom_op.fpgadataflow import register_custom_op + +# Derive custom operators form the FINN base custom op +from finn.custom_op.fpgadataflow.hwcustomop import HWCustomOp + +# Converts inputs/outputs to/from RTL simulation format +from finn.util.data_packing import npy_to_rtlsim_input, rtlsim_output_to_npy + + +# Generic implementation for elementwise binary operations +class ElementwiseBinaryOperation(HWCustomOp): + # Specifies the elementwise operation to be implemented + # Format: (Identifier, Python, C++, RTL) + _operation: tuple[str, np.ufunc, str, str] | None = None + + # Numpy operation available as property + @property + def npy_op(self) -> np.ufunc: + return self._operation[1] + + # C++ operation template available as property + @property + def cpp_op(self) -> str: + return self._operation[2] + + # RTL operation template available as property + @property + def rtl_op(self) -> str: + return self._operation[3] + + # Initializes the operator given an onnx graph node + def __init__(self, onnx_node, **kwargs): + # Just forward all arguments to the init method of the CustomOp base + super().__init__(onnx_node, **kwargs) + + # Defines attributes which must be present on this node + def get_nodeattr_types(self): + # Start from parent operator class attributes + attrs = HWCustomOp.get_nodeattr_types(self) + # Update attributes dictionary for new custom operator + attrs.update({ + # Data type of the left-hand-side input elements + "lhs_dtype": ("s", True, ""), + # Data type of the right-hand-side input elements + "rhs_dtype": ("s", True, ""), + # Data type of the output elements + "out_dtype": ("s", True, ""), + # Shape of the left-hand-side input + "lhs_shape": ("ints", True, [1]), + # Shape of the right-hand-side input + 
"rhs_shape": ("ints", True, [1]), + # Shape of the output, mus correspond to multi-directional + # broadcasting of the left- and right-hand-side + "out_shape": ("ints", True, [1]), + # Style specifies how the left-hand-side input is provided + # Note: Might be inferred from the context + "lhs_style": ("s", False, "input", {"input", "const"}), + # Style specifies how the right-hand-side input is provided + # Note: Might be inferred from the context + "rhs_style": ("s", False, "input", {"input", "const"}), + # Number of elements in the last dimensions processed in parallel + "PE": ("i", False, 1), + # Possible execution modes for simulating this node + # Note: Override to support python mode + "exec_mode": ( + "s", False, "python", {"", "rtlsim", "cppsim", "python"} + ), + # FPGA resource type for memories/internal buffers of the operator + "ram_style": ( + "s", False, "auto", {"auto", "block", "distributed", "ultra"} + ), + # Input and output FIFO depths for multi-I/O nodes + # Note: Need to override here as there might be two inputs + "inFIFODepths": ("ints", False, [2, 2]), + "outFIFODepths": ("ints", False, [2]), + }) + # Return updated attribute dictionary + return attrs + + # Datatype attribute as property for convenience + @property + def lhs_dtype(self): + # Note: Converts from string to QONNX data type + return DataType[self.get_nodeattr("lhs_dtype")] + + # Datatype attribute as property for convenience + @property + def rhs_dtype(self): + # Note: Converts from string to QONNX data type + return DataType[self.get_nodeattr("rhs_dtype")] + + # Datatype attribute as property for convenience + @property + def out_dtype(self): + # Note: Converts from string to QONNX data type + return DataType[self.get_nodeattr("out_dtype")] + + # Shape attribute as property for convenience + @property + def lhs_shape(self): + return self.get_nodeattr("lhs_shape") + + # Shape attribute as property for convenience + @property + def rhs_shape(self): + return self.get_nodeattr("rhs_shape") + + # Shape attribute as property for convenience + @property + def out_shape(self): + return self.get_nodeattr("out_shape") + + # Style attribute as property for convenience + @property + def lhs_style(self): + return self.get_nodeattr("lhs_style") + + # Style attribute as property for convenience + @property + def rhs_style(self): + return self.get_nodeattr("rhs_style") + + # Number of parallel processed elements as property for convenience + @property + def pe(self): + return self.get_nodeattr("PE") + + # Checks whether the last axis is broadcast + @property + def broadcast_last_axis(self): + return (self.lhs_shape[-1] == 1) != (self.rhs_shape[-1] == 1) + + # Makes an operation compatible with the output shape for shape inference + # Note: Propagates shape forward, i.e., never asks for the shape of the + # output, even if it seems easier. 
+ def make_shape_compatible_op(self, model: ModelWrapper): # noqa + # Get the node wrapped by this custom op + node = self.onnx_node + # There must be exactly two inputs to the binary operation + assert len(node.input) == 2, \ + f"Binary operation {node.name} requires exactly two inputs" + # Validate input shapes match what is stored as attributes + assert model.get_tensor_shape(node.input[0]) == self.lhs_shape, \ + f"Input shape mismatch: {node.name} {node.input[0]}" + assert model.get_tensor_shape(node.input[1]) == self.rhs_shape, \ + f"Input shape mismatch: {node.name} {node.input[1]}" + # Validate broadcasting of inputs to the output shape + assert (list(np.broadcast_shapes(self.lhs_shape, self.rhs_shape)) + == self.out_shape), f"Shape broadcast mismatch: {node.name}" + # Simulate behavior via the standard ONNX add operation + return oh.make_node("Add", node.input, node.output) + + # Infers the datatype of the node output + def infer_node_datatype(self, model: ModelWrapper): # noqa + # Get the node wrapped by this custom op # noqa Duplicate + node = self.onnx_node + # Test for changing left-hand-side input datatype + if model.get_tensor_datatype(node.input[0]) != self.lhs_dtype: + # Get the new datatype + new_dtype = model.get_tensor_datatype(node.input[0]) + # Issue a warning message + warnings.warn( + f"{node.name}: lhs_dtype changing from" + f" {self.lhs_dtype} to {new_dtype}" + ) + # Set the new datatype attribute + self.set_nodeattr("lhs_dtype", new_dtype.name) + # Test for changing right-hand-side input datatype + if model.get_tensor_datatype(node.input[1]) != self.rhs_dtype: + # Get the new datatype + new_dtype = model.get_tensor_datatype(node.input[1]) + # Issue a warning message + warnings.warn( + f"{node.name}: rhs_dtype changing from" + f" {self.rhs_dtype} to {new_dtype}" + ) + # Set the new datatype attribute + self.set_nodeattr("rhs_dtype", new_dtype.name) + # Force the output data type stored as a node attribute + model.set_tensor_datatype(node.output[0], self.out_dtype) + + # Executes elementwise operation in python + def _execute_node_python(self, context, graph): # noqa: graph unused + # Get the node wrapped by this custom op + node = self.onnx_node + # Get the inputs out of the execution context + lhs = context[node.input[0]] + rhs = context[node.input[1]] + # Note: Need to make sure these have the right type for the Numpy API + # Note: Always simulate integer inputs in int64, numpy casting is + # weird.... + lhs = lhs.astype(np.int64) if self.lhs_dtype.is_integer() else lhs + rhs = rhs.astype(np.int64) if self.rhs_dtype.is_integer() else rhs + # Apply elementwise operation with broadcasting in numpy and insert + # result into the execution context + out = self.npy_op(lhs, rhs) + # Make sure the output has the right type, e.g. turn all booleans into + # integers (actually floats as the container type) + # Note: This is relevant for logical ops, ==, <=, >=, etc. + # Note: Somehow QONNX does not like boolean tensors + context[node.output[0]] = out.astype(np.float32) + + # Executes elementwise operation in C++ simulation + def _execute_node_cppsim(self, context, graph): # noqa: graph unused + # C++ Simulation needs to be implemented in HLS backend specialization + raise NotImplementedError( + f"exec_mode cppsim of {self.__class__.__name__} is not implemented!" 
+ ) + + # Executes elementwise operation in RTL simulation + def _execute_node_rtlsim(self, context, graph): # noqa: graph unused + # Get the node wrapped by this custom op # noqa Duplicate + node = self.onnx_node + # Input data is stored in numpy files in the code generation dictionary + code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen") + # Get the inputs out of the execution context + lhs = context[node.input[0]] # noqa: Duplicate code prepare simulation + rhs = context[node.input[1]] # noqa: Duplicate code prepare simulation + # Validate the shape of the inputs + assert list(lhs.shape) == self.get_normal_input_shape(ind=0), \ + f"Input shape mismatch for {node.input[0]}" + assert list(rhs.shape) == self.get_normal_input_shape(ind=1), \ + f"Input shape mismatch for {node.input[1]} {rhs.shape=}" + # Reshape the inputs into folded form + lhs = lhs.reshape(self.get_folded_input_shape(ind=0)) + rhs = rhs.reshape(self.get_folded_input_shape(ind=1)) + # Path to store the intermediate inputs in numpy format + lhs_filename = os.path.join(code_gen_dir, "lhs.npy") + rhs_filename = os.path.join(code_gen_dir, "rhs.npy") + # Save the folded inputs to file to be used by simulation + np.save(lhs_filename, lhs) + np.save(rhs_filename, rhs) + # Start collecting inputs/outputs to the RTL simulation in a dictionary + # Note: Prepare one output empty output list + io_dict = { + "inputs": {}, + "outputs": {"out": []} + } + # Type and width of the input tensors + lhs_dtype = self.get_input_datatype(ind=0) + lhs_width = self.get_instream_width(ind=0) + rhs_dtype = self.get_input_datatype(ind=1) + rhs_width = self.get_instream_width(ind=1) + + # If the left-hand-side is provided as runtime input it needs to be + # inserted into the RTL simulation inputs + if self.lhs_style == "input": + # Convert inputs to RTL simulation format + io_dict["inputs"]["lhs"] = npy_to_rtlsim_input( + lhs_filename, lhs_dtype, lhs_width + ) + + # If the right-hand-side is provided as runtime input it needs to be + # inserted into the RTL simulation inputs + if self.rhs_style == "input": + # Convert inputs to RTL simulation format + io_dict["inputs"]["rhs"] = npy_to_rtlsim_input( + rhs_filename, rhs_dtype, rhs_width + ) + + # Setup PyVerilator simulation of the node + sim = self.get_rtlsim() # noqa: Duplicate code prepare simulation + # Reset the RTL simulation + super().reset_rtlsim(sim) + super().toggle_clk(sim) + # Run the RTL Simulation + self.rtlsim_multi_io(sim, io_dict) + + # Collect the output from RTL simulation + out = io_dict["outputs"]["out"] + # Type and sizes of the output tensor + dtype = self.get_output_datatype(ind=0) # noqa: Duplicate readout code + width = self.get_outstream_width(ind=0) + shape = self.get_folded_output_shape(ind=0) + # Path to store the intermediate numpy file + filename = os.path.join(code_gen_dir, "out.npy") + # Convert from RTL simulation format to numpy format + rtlsim_output_to_npy( + out, filename, dtype, shape, width, dtype.bitwidth() + ) + # Load the generated output numpy file + out = np.load(filename) + # Reshape the folded output and insert into the execution context + context[node.output[0]] = out.reshape( + self.get_normal_output_shape(ind=0) + ) + + # Executes elementwise op in simulation (either python c++ or rtl sim) + def execute_node(self, context, graph): + # Get the configured execution mode + mode = self.get_nodeattr("exec_mode") + # Lookup table mapping execution modes to implementing methods + exec_fns = { + "python": self._execute_node_python, + "cppsim": 
self._execute_node_cppsim, + "rtlsim": self._execute_node_rtlsim, + } + # Select and execute the function by mode string + exec_fns[mode](context, graph) + + # Verifies the node attributes, inputs and outputs + def verify_node(self): + # TODO: Implement + return [] + + # Note: End of QONNX CustomOp region, below is FINN HWCustomOp stuff + + # Gets the datatype of input at index ind + def get_input_datatype(self, ind=0): + # Get input data type by index, order inputs from left to right + return [self.lhs_dtype, self.rhs_dtype][ind] + + # Gets the datatype of the output at index ind + def get_output_datatype(self, ind=0): + # There is only one output, the type is set as an attribute + return self.out_dtype + + # Gets the shape of the input at index ind without folding + def get_normal_input_shape(self, ind=0): + # Input shapes are stored as a node attributes + return [self.lhs_shape, self.rhs_shape][ind] + + # Gets the shape of the output at index ind without folding + def get_normal_output_shape(self, ind=0): + # The output shape is stored as a node attribute + return self.out_shape + + # Gets the shape of the input at index ind with folding + def get_folded_input_shape(self, ind=0): + # Get the normal shape before applying folding + *num_inputs, num_elems = self.get_normal_input_shape(ind=ind) + # Folding only applies if the folded axis is not broadcast + if not self.broadcast_last_axis or num_elems != 1: + # Valid folding requires the PE to divide the number of elements + assert num_elems % self.pe == 0, "PE must divide last axis" + # Folding along the last dimension + return *num_inputs, num_elems // self.pe, self.pe + # For broadcast axes return the non-folded shape with dummy axis + # inserted + return *num_inputs, 1, num_elems + + # Gets the shape of the output at index ind with folding + def get_folded_output_shape(self, ind=0): + # Get the normal shape before applying folding + *num_inputs, num_elems = self.get_normal_output_shape(ind=ind) + # Valid folding requires the PE to divide the number of elements + assert num_elems % self.pe == 0, "PE must divide last axis" + # Folding along the last dimension + return *num_inputs, num_elems // self.pe, self.pe + + # Widths of the input data stream of the input at index ind + def get_instream_width(self, ind=0): + # Get the number of bits used to represent the input + i_bits = self.get_input_datatype(ind).bitwidth() + # Parallelism is the number of elements in the last dimension of the + # folded input + *_, elems = self.get_folded_input_shape(ind) + # Width of a stream receiving input elements in parallel + return elems * i_bits + + # Widths of the output data stream of the output at index ind + def get_outstream_width(self, ind=0): + # Get the number of bits used to represent the output + o_bits = self.get_output_datatype(ind).bitwidth() + # Parallelism is the number of elements in the last dimension of the + # folded output + *_, elems = self.get_folded_output_shape(ind) + # Width of a stream producing output elements in parallel + return elems * o_bits + + # Gets the number of expected output values, i.e. how many times read() + # could/should be called on any output stream of this operator + def get_number_output_values(self): + # Elements over all but the last dimension of the output folded along + # the embedding dimension. 
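Editor's note on the folding helpers above: the last axis is split into (elements // PE, PE) fold groups, and the stream width scales with PE. A quick sketch with made-up shape, PE and element width:

```python
import numpy as np

# Hypothetical folding example: 32 channels of 8-bit data processed with PE = 4
normal_shape, pe, elem_bits = (1, 8, 8, 32), 4, 8
*outer, channels = normal_shape
assert channels % pe == 0                            # PE must divide last axis
folded_shape = (*outer, channels // pe, pe)          # (1, 8, 8, 8, 4)
instream_width = pe * elem_bits                      # 32-bit wide input stream
n_stream_words = int(np.prod(folded_shape[:-1]))     # transactions per inference
print(folded_shape, instream_width, n_stream_words)  # (1, 8, 8, 8, 4) 32 512
```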
+ return np.prod(self.get_folded_output_shape()[:-1]) + + # Minimizes the width of the accumulator data type, 'accumulator width' here + # due to convention, it is actually the output data type + def minimize_accumulator_width(self, model: ModelWrapper): + # If any of the inputs is not an integer, the bit-width cannot be + # minimized + if not all([self.lhs_dtype.is_integer(), self.rhs_dtype.is_integer()]): + # Check the annotated tensor data type corresponds to the stored + # attribute + assert (model.get_tensor_datatype(self.onnx_node.output[0]) + == self.out_dtype), \ + f"Output type mismatch for {self.onnx_node.name}" + # Exit here, returning the not-minimized data type + return self.out_dtype + # Call the output type derivation specialized by the concrete operator + # implementation + out_dtype = self._derive_out_dtype(model) + # Set the new output data type as attribute + self.set_nodeattr("out_dtype", out_dtype.name) + # Annotate the output tensor with the new data type + model.set_tensor_datatype(self.onnx_node.output[0], out_dtype) + # Return the minimized output data type + # Note: Probably not required by MinimizeAccumulatorWidth transformation + return out_dtype + + # Derives the optimal width of the output data type + def _derive_out_dtype(self, model: ModelWrapper): + # Depends on the actual operation performed and must be specialized by + # the concrete implementations + raise NotImplementedError( + f"_derive_out_dtype of {self.__class__.__name__}" + f" is not implemented!" + ) + + # Minimizes the width of the weight data type, 'weight' here due to + # convention, it actually applies to any constant initializer input + def minimize_weight_bit_width(self, model: ModelWrapper): + # Check for an initializer providing the left hand side input + lhs = model.get_initializer(self.onnx_node.input[0]) + # If the left hand side input is provided as initializer, minimize the + # bits used for storing this + if lhs is not None: + # Remember the "style" of receiving the input for further code + # generation + self.set_nodeattr("lhs_style", "const") + # Minimum and maximum "weight" on the left hand side, determining + # the range of values which needs to be represented + _min = lhs.min() + _max = lhs.max() + # Determine whether signed or unsigned type is required for + # representing the weights and select the largest "signed magnitude" + _mag = _max if _min > 0 else \ + _min if (abs(_min) > _max) else (-_max - 1) + # Smallest data type large enough to represent this range of values + dtype = DataType.get_smallest_possible(_mag) + # Update the corresponding data type attribute of the node + self.set_nodeattr("lhs_dtype", dtype.name) + # Annotate the tensor with the new data type + model.set_tensor_datatype(self.onnx_node.input[0], dtype) + + # Check for an initializer providing the right hand side input + rhs = model.get_initializer(self.onnx_node.input[1]) + # If the right hand side input is provided as initializer, minimize the + # bits used for storing this + if rhs is not None: + # Remember the "style" of receiving the input for further code + # generation + self.set_nodeattr("rhs_style", "const") + # Minimum and maximum "weight" on the right hand side, determining + # the range of values which needs to be represented + _min = rhs.min() + _max = rhs.max() + assert _min != 0 + assert _max != 0 + # Determine whether signed or unsigned type is required for + # representing the weights and select the largest "signed magnitude" + _mag = _max if _min > 0 else \ + _min if (abs(_min) > _max) 
else (-_max - 1) + # Smallest data type large enough to represent this range of values + dtype = DataType.get_smallest_possible(_mag) + # Update the corresponding data type attribute of the node + self.set_nodeattr("rhs_dtype", dtype.name) + # Annotate the tensor with the new data type + model.set_tensor_datatype(self.onnx_node.input[1], dtype) + + # TODO: MVAU returns the data type here, which does not make sense for + # potentially two data types changing and apparently, the + # MinimizeWeightBitWidth transformations does not even use the returned + # value. + + # Derives the expected cycles for the elementwise binary operation given the + # folding configuration + def get_exp_cycles(self): + # Number of iterations required to process the whole folded input stream + # Note: This is all but the PE (last, parallelized) dimension + return np.prod(self.get_folded_output_shape()[:-1]) + + +# Derive a specialization to implement elementwise addition of two inputs +@register_custom_op +class ElementwiseAdd(ElementwiseBinaryOperation): + # Specialize to implement the addition operation of left hand side and right + # hand side input + _operation = "Add", np.add, "({0} + {1})", None + + # Derives the output data type according to UG1399 + def _derive_out_dtype(self, model: ModelWrapper): + # Get the width of the data types of the inputs and the larger of the + # two widths + lhs_width = self.lhs_dtype.bitwidth() + rhs_width = self.rhs_dtype.bitwidth() + max_width = max(lhs_width, rhs_width) + # Check whether the addition operation is a signed addition + signed = any([self.lhs_dtype.signed(), self.rhs_dtype.signed()]) + # By default, the output is one bit more than the widest of the inputs + out_width = max_width + 1 + # If the addition is signed, the output might be wider depending on + # which of the inputs is signed + if signed: + # Find the wider and narrower of the two inputs by assuming left to + # right order first + wider, narrower = self.lhs_dtype, self.rhs_dtype + # Swap if the order is not correct + if narrower.bitwidth() > wider.bitwidth(): + wider, narrower = narrower, wider + # If and only if the wider is unsigned and the narrower is signed, + # add two bits to the output width + if not wider.signed() and narrower.signed(): + # Out has two bits more than the widest input + out_width = max_width + 2 + # The new output type is a signed integer of the calculated + # bit-width + return DataType[f"INT{out_width}"] + # By default, if both inputs are unsigned, the output is unsigned as + # well + return DataType[f"UINT{out_width}"] + + +# Derive a specialization to implement elementwise subtraction of two inputs +@register_custom_op +class ElementwiseSub(ElementwiseBinaryOperation): + # Specialize to implement the subtraction operation of left hand side and + # right hand side input + _operation = "Sub", np.subtract, "({0} - {1})", None + + # Derives the output data type according to UG1399 + def _derive_out_dtype(self, model: ModelWrapper): + # Get the width of the data types of the inputs and the larger of the + # two widths + lhs_width = self.lhs_dtype.bitwidth() + rhs_width = self.rhs_dtype.bitwidth() + max_width = max(lhs_width, rhs_width) + # Check whether the addition operation is a signed addition + signed = any([self.lhs_dtype.signed(), self.rhs_dtype.signed()]) + # By default, the output is one bit more than the widest of the inputs + out_width = max_width + 1 + # If the operation is signed, the output might be wider depending on + # which of the inputs is signed + if signed: + # 
Find the wider and narrower of the two inputs by assuming left to + # right order first + wider, narrower = self.lhs_dtype, self.rhs_dtype + # Swap if the order is not correct + if narrower.bitwidth() > wider.bitwidth(): + wider, narrower = narrower, wider + # If and only if the wider is unsigned and the narrower is signed, + # add two bits to the output width + if not wider.signed() and narrower.signed(): + # Out has two bits more than the widest input + out_width = max_width + 2 + # For subtraction, the output data type is always signed + return DataType[f"INT{out_width}"] + + +# Derive a specialization to implement elementwise multiplication of two inputs +@register_custom_op +class ElementwiseMul(ElementwiseBinaryOperation): + # Specialize to implement the multiplication operation of left hand side and + # right hand side input + _operation = "Mul", np.multiply, "({0} * {1})", None + + # Derives the output data type according to UG1399 + def _derive_out_dtype(self, model: ModelWrapper): + # Get the width of the data types of the inputs + lhs_width = self.lhs_dtype.bitwidth() + rhs_width = self.rhs_dtype.bitwidth() + # Check whether the addition operation is a signed addition + signed = any([self.lhs_dtype.signed(), self.rhs_dtype.signed()]) + # The width of the product is the sum of the widths of the operands. + out_width = lhs_width + rhs_width + # The product is treated as a signed type if either of the operands is + # of a signed type. + return DataType[f"INT{out_width}" if signed else f"UINT{out_width}"] + + +# Derive a specialization to implement elementwise division of two inputs +@register_custom_op +class ElementwiseDiv(ElementwiseBinaryOperation): + # TODO: Not tested due to divide by zero from randomly generated inputs... + # Specialize to implement the division operation of left hand side and + # right hand side input + _operation = "Div", np.divide, "({0} / {1})", None + + # Derives the output data type according to UG1399 + def _derive_out_dtype(self, model: ModelWrapper): + # Get the width of the data types of the inputs + lhs_width = self.lhs_dtype.bitwidth() + # Check whether the addition operation is a signed addition + signed = any([self.lhs_dtype.signed(), self.rhs_dtype.signed()]) + # The width of the quotient is the width of the dividend if the divisor + # is an unsigned type. Otherwise, it is the width of the dividend plus + # one. + out_width = lhs_width if not self.rhs_dtype.signed() else lhs_width + 1 + # The quotient is treated as a signed type if either of the operands is + # of a signed type. 
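Editor's note: a few worked examples of the UG1399-style output-type rules implemented by the _derive_out_dtype methods in this file, plus a standalone mirror of the multiplication rule. The methods above remain authoritative; the example values are hypothetical:

```python
from qonnx.core.datatype import DataType

# Worked examples of the rules above (illustrative):
#   Add: UINT8 + UINT8 -> UINT9   (max width + 1)
#   Add: UINT8 + INT4  -> INT10   (unsigned operand at least as wide as the
#                                  signed one, so one extra bit, result signed)
#   Sub: UINT8 - UINT8 -> INT9    (difference of unsigned values can be negative)
#   Mul: INT8  * UINT8 -> INT16   (sum of the operand widths, signed)
#   Div: UINT8 / INT4  -> INT9    (dividend width + 1 for a signed divisor)

def mul_out_dtype(lhs: DataType, rhs: DataType) -> DataType:
    # Mirrors ElementwiseMul._derive_out_dtype for integer operands
    width = lhs.bitwidth() + rhs.bitwidth()
    signed = lhs.signed() or rhs.signed()
    return DataType[f"INT{width}" if signed else f"UINT{width}"]

assert mul_out_dtype(DataType["INT8"], DataType["UINT8"]) == DataType["INT16"]
```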
+ return DataType[f"INT{out_width}" if signed else f"UINT{out_width}"] + + +# TODO: ElementwiseMod - Requires extra attribute selecting the function + + +# Derive a specialization to implement elementwise logical and of two inputs +@register_custom_op +class ElementwiseAnd(ElementwiseBinaryOperation): + # Specialize to implement the logical and operation of left hand side and + # right hand side input + _operation = "And", np.logical_and, "({0} && {1})", None + + # Derives the output data type + def _derive_out_dtype(self, model: ModelWrapper): + # Treat the boolean output of a logical operation as unsigned integer of + # width 1, i.e., a single bit True/False + return DataType["BINARY"] + + +# Derive a specialization to implement elementwise logical or of two inputs +@register_custom_op +class ElementwiseOr(ElementwiseBinaryOperation): + # Specialize to implement the logical or operation of left hand side and + # right hand side input + _operation = "Or", np.logical_or, "({0} || {1})", None + + # Derives the output data type + def _derive_out_dtype(self, model: ModelWrapper): + # Treat the boolean output of a logical operation as unsigned integer of + # width 1, i.e., a single bit True/False + return DataType["BINARY"] + + +# Derive a specialization to implement elementwise logical xor of two inputs +@register_custom_op +class ElementwiseXor(ElementwiseBinaryOperation): + # Specialize to implement the logical xor operation of left hand side and + # right hand side input + _operation = "Xor", np.logical_xor, "(bool({0}) != bool({1}))", None + + # Derives the output data type + def _derive_out_dtype(self, model: ModelWrapper): + # Treat the boolean output of a logical operation as unsigned integer of + # width 1, i.e., a single bit True/False + return DataType["BINARY"] + + +# Derive a specialization to implement elementwise equality of two inputs +@register_custom_op +class ElementwiseEqual(ElementwiseBinaryOperation): + # Specialize to implement the logical equal operation of left hand side and + # right hand side input + _operation = "Equal", np.equal, "({0} == {1})", None + + # Derives the output data type + def _derive_out_dtype(self, model: ModelWrapper): + # Treat the boolean output of a logical operation as unsigned integer of + # width 1, i.e., a single bit True/False + return DataType["BINARY"] + + +# Derive a specialization to implement elementwise less of two inputs +@register_custom_op +class ElementwiseLess(ElementwiseBinaryOperation): + # Specialize to implement the logical less operation of left hand side and + # right hand side input + _operation = "Less", np.less, "({0} < {1})", None + + # Derives the output data type + def _derive_out_dtype(self, model: ModelWrapper): + # Treat the boolean output of a logical operation as unsigned integer of + # width 1, i.e., a single bit True/False + return DataType["BINARY"] + + +# Derive a specialization to implement elementwise less or equal of two inputs +@register_custom_op +class ElementwiseLessOrEqual(ElementwiseBinaryOperation): + # Specialize to implement the logical less or equal operation of left hand + # side and right hand side input + _operation = "LessOrEqual", np.less_equal, "({0} <= {1})", None + + # Derives the output data type + def _derive_out_dtype(self, model: ModelWrapper): + # Treat the boolean output of a logical operation as unsigned integer of + # width 1, i.e., a single bit True/False + return DataType["BINARY"] + + +# Derive a specialization to implement elementwise greater of two inputs +@register_custom_op 
+class ElementwiseGreater(ElementwiseBinaryOperation): + # Specialize to implement the logical greater operation of left hand side + # and right hand side input + _operation = "Greater", np.greater, "({0} > {1})", None + + # Derives the output data type + def _derive_out_dtype(self, model: ModelWrapper): + # Treat the boolean output of a logical operation as unsigned integer of + # width 1, i.e., a single bit True/False + return DataType["BINARY"] + + +# Derive a specialization to implement elementwise greater or equal of two +# inputs +@register_custom_op +class ElementwiseGreaterOrEqual(ElementwiseBinaryOperation): + # Specialize to implement the logical greater or equal operation of left + # hand side and right hand side input + _operation = "GreaterOrEqual", np.greater_equal, "({0} >= {1})", None + + # Derives the output data type + def _derive_out_dtype(self, model: ModelWrapper): + # Treat the boolean output of a logical operation as unsigned integer of + # width 1, i.e., a single bit True/False + return DataType["BINARY"] + + +# Derive a specialization to implement elementwise bitwise and of two inputs +@register_custom_op +class ElementwiseBitwiseAnd(ElementwiseBinaryOperation): + # Specialize to implement the bitwise and operation of left hand side and + # right hand side input + _operation = "BitwiseAnd", np.bitwise_and, "({0} & {1})", None + + # Derives the output data type according to UG1399 + def _derive_out_dtype(self, model: ModelWrapper): + # Get the width of the data types of the inputs # noqa: Duplicate + lhs_width = self.lhs_dtype.bitwidth() + rhs_width = self.rhs_dtype.bitwidth() + # Check whether the addition operation is a signed addition + signed = any([self.lhs_dtype.signed(), self.rhs_dtype.signed()]) + # The bitwise logical operators all return a value with a width that is + # the maximum of the widths of the two operands. + out_width = max(lhs_width, rhs_width) + # The product is treated as a signed type if either of the operands is + # of a signed type. + return DataType[f"INT{out_width}" if signed else f"UINT{out_width}"] + + +# Derive a specialization to implement elementwise bitwise or of two inputs +@register_custom_op +class ElementwiseBitwiseOr(ElementwiseBinaryOperation): + # Specialize to implement the bitwise or operation of left hand side and + # right hand side input + _operation = "BitwiseOr", np.bitwise_or, "({0} | {1})", None + + # Derives the output data type according to UG1399 + def _derive_out_dtype(self, model: ModelWrapper): + # Get the width of the data types of the inputs # noqa: Duplicate + lhs_width = self.lhs_dtype.bitwidth() + rhs_width = self.rhs_dtype.bitwidth() + # Check whether the addition operation is a signed addition + signed = any([self.lhs_dtype.signed(), self.rhs_dtype.signed()]) + # The bitwise logical operators all return a value with a width that is + # the maximum of the widths of the two operands. + out_width = max(lhs_width, rhs_width) + # The product is treated as a signed type if either of the operands is + # of a signed type. 
+ return DataType[f"INT{out_width}" if signed else f"UINT{out_width}"] + + +# Derive a specialization to implement elementwise bitwise xor of two inputs +@register_custom_op +class ElementwiseBitwiseXor(ElementwiseBinaryOperation): + # Specialize to implement the bitwise xor operation of left hand side and + # right hand side input + _operation = "BitwiseXor", np.bitwise_xor, "({0} ^ {1})", None + + # Derives the output data type according to UG1399 + def _derive_out_dtype(self, model: ModelWrapper): + # Get the width of the data types of the inputs # noqa: Duplicate + lhs_width = self.lhs_dtype.bitwidth() + rhs_width = self.rhs_dtype.bitwidth() + # Check whether the addition operation is a signed addition + signed = any([self.lhs_dtype.signed(), self.rhs_dtype.signed()]) + # The bitwise logical operators all return a value with a width that is + # the maximum of the widths of the two operands. + out_width = max(lhs_width, rhs_width) + # The product is treated as a signed type if either of the operands is + # of a signed type. + return DataType[f"INT{out_width}" if signed else f"UINT{out_width}"] + +# TODO: ElementwiseBitShift - Requires extra attribute selecting the direction + + +# # Derive a specialization to implement elementwise power of two inputs +# TODO: std::pow does not work for HLS types and hls::pow fails to link for some +# reason +# @register_custom_op +# class ElementwisePow(ElementwiseBinaryOperation): +# # Specialize to implement the power operation of left hand side and +# # right hand side input +# _operation = "Pow", np.power, "(std::pow({0}, {1}))", None diff --git a/src/finn/custom_op/fpgadataflow/hls/__init__.py b/src/finn/custom_op/fpgadataflow/hls/__init__.py index 405c47a08d..3fb958a99e 100644 --- a/src/finn/custom_op/fpgadataflow/hls/__init__.py +++ b/src/finn/custom_op/fpgadataflow/hls/__init__.py @@ -26,6 +26,37 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# The base class of all HWCustomOp specializations to HLS backend implementation +from finn.custom_op.fpgadataflow.hlsbackend import HLSBackend + +# The base class of all generic custom operations before specializing to either +# HLS or RTL backend +from finn.custom_op.fpgadataflow.hwcustomop import HWCustomOp + +# Dictionary of HLSBackend implementations +custom_op = dict() + + +# Registers a class into the custom_op dictionary +# Note: This must be defined first, before importing any custom op +# implementation to avoid "importing partially initialized module" issues. +def register_custom_op(cls): + # The class must actually implement HWCustomOp + assert issubclass(cls, HWCustomOp), f"{cls} must subclass {HWCustomOp}" + # The class must also implement the HLSBackend + assert issubclass(cls, HLSBackend), f"{cls} must subclass {HLSBackend}" + # Insert the class into the custom_op dictionary by its name + custom_op[cls.__name__] = cls # noqa: Some weird type annotation issue? 
+ # Pass through the class unmodified + return cls + + +# flake8: noqa +# Disable linting from here, as all import will be flagged E402 and maybe F401 + +# Import the submodule containing specializations of ElementwiseBinaryOperation +# Note: This will automatically register all decorated classes into this domain +import finn.custom_op.fpgadataflow.hls.elementwise_binary_hls from finn.custom_op.fpgadataflow.hls.addstreams_hls import AddStreams_hls from finn.custom_op.fpgadataflow.hls.channelwise_op_hls import ChannelwiseOp_hls from finn.custom_op.fpgadataflow.hls.checksum_hls import CheckSum_hls @@ -53,8 +84,6 @@ from finn.custom_op.fpgadataflow.hls.upsampler_hls import UpsampleNearestNeighbour_hls from finn.custom_op.fpgadataflow.hls.vectorvectoractivation_hls import VVAU_hls -custom_op = dict() - # make sure new HLSCustomOp subclasses are imported here so that they get # registered and plug in correctly into the infrastructure custom_op["AddStreams_hls"] = AddStreams_hls diff --git a/src/finn/custom_op/fpgadataflow/hls/elementwise_binary_hls.py b/src/finn/custom_op/fpgadataflow/hls/elementwise_binary_hls.py new file mode 100644 index 0000000000..cffb964baf --- /dev/null +++ b/src/finn/custom_op/fpgadataflow/hls/elementwise_binary_hls.py @@ -0,0 +1,766 @@ +# fmt: off +# Disable formatter. This is deliberately formatted to stay within 80 characters +# per line. Black, however, formats some lines going beyond this. + +# Numpy math and arrays +import numpy as np + +# Operating system stuff, e.g. paths +import os + +# Cleanup post-processing of generated code +import textwrap + +# QONNX wrapper to ONNX model graphs +from qonnx.core.modelwrapper import ModelWrapper + +# Specializations of the generic HW operator +import finn.custom_op.fpgadataflow.elementwise_binary as elementwise_binary + +# Utility for registering HLSBackend HWCustomOp implementations into the module +# scope +from finn.custom_op.fpgadataflow.hls import register_custom_op + +# Base class for specializing HW operators as implemented via HLS +from finn.custom_op.fpgadataflow.hlsbackend import HLSBackend + +# Convert and pack (numpy) data for C++ code generation +from finn.util.data_packing import numpy_to_hls_code + +# The generic HW custom operator version of the operator as a base class +from finn.custom_op.fpgadataflow.elementwise_binary import ( # noqa + ElementwiseBinaryOperation +) + +# Mapping of memory resource attributes to the corresponding C++ HLS +# pragma directives +RAM_STYLES = { + "auto": "AUTO", "block": "BRAM", "distributed": "LUTRAM", "ultra": "URAM" +} + + +# HLS Backend specialization of the binary elementwise operation operator +class ElementwiseBinaryOperation_hls( # noqa: Class name does not follow + # CapWords convention + ElementwiseBinaryOperation, HLSBackend +): + # Node attributes matching the HLS operator + def get_nodeattr_types(self): + # Start from parent operator class attributes + attrs = ElementwiseBinaryOperation.get_nodeattr_types(self) + # Add the HLSBackend default attributes on top + attrs.update(HLSBackend.get_nodeattr_types(self)) + # Add/Specialize implementation specific attributes here... 
+ # Return the updated attributes dictionary + return attrs + + # Executes elementwise operation in C++ simulation + def _execute_node_cppsim(self, context, graph): # noqa: graph unused + # Get the node wrapped by this custom op + node = self.onnx_node + # Input data is stored in numpy files in the code generation dictionary + code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim") + # Get the inputs out of the execution context + lhs = context[node.input[0]] # noqa: Duplicate code prepare simulation + rhs = context[node.input[1]] + # Validate the shape of the inputs + assert list(lhs.shape) == self.get_normal_input_shape(ind=0), \ + f"Input shape mismatch for {node.input[0]}" + assert list(rhs.shape) == self.get_normal_input_shape(ind=1), \ + f"Input shape mismatch for {node.input[1]} {rhs.shape=}" + # Reshape the inputs into folded form + lhs = lhs.reshape(self.get_folded_input_shape(ind=0)) + rhs = rhs.reshape(self.get_folded_input_shape(ind=1)) + # Save the folded inputs to file to be used by simulation + np.save(os.path.join(code_gen_dir, "lhs.npy"), lhs) + np.save(os.path.join(code_gen_dir, "rhs.npy"), rhs) + + # Execute the precompiled model + super().exec_precompiled_singlenode_model() + + # Load the output numpy file generated by the C++ simulation + out = np.load(os.path.join(code_gen_dir, "out.npy")) + # Reshape the folded output and insert into the execution context + context[node.output[0]] = out.reshape( + self.get_normal_output_shape(ind=0) + ) + + # Maximum width of any ap_int used in this operator + def get_ap_int_max_w(self): + # Find the widths of the widest of the two inputs + i_bits_max = max( + self.get_instream_width(ind=0), + self.get_instream_width(ind=1) + ) + # Width of the output, there is just one output + # Note: there is one output per replica + o_bits_max = self.get_outstream_width(ind=0) + # Find the biggest of the inputs/outputs + return max([i_bits_max, o_bits_max]) + + # Note: End of shape and datatype utilities + + # Generates list of C++ includes to be placed at the top of the generated + # code + def global_includes(self): + # Currently nothing to include + self.code_gen_dict["$GLOBALS$"] = ['#include "flatten.hpp"'] + + # Generates C++ parameters file, i.e., constant initializer inputs + def generate_params(self, model: ModelWrapper, path: str): + # The code generation directory is specified as an argument, so this + # will work for both RTL and C++ simulation + code_gen_dir = path + # By default, assume runtime inputs not requiring code to be generated + lhs_code = rhs_code = "" + # Check for an initializer providing the left hand side input + lhs = model.get_initializer(self.onnx_node.input[0]) + # Folded output shape for broadcasting/aligning the input shapes + out_shape = self.get_folded_output_shape(ind=0) + # Type of memory to use for storing constant parameters + ram_style = RAM_STYLES[self.get_nodeattr("ram_style")] + + # Check whether there are already pragmas in the code generation + # dictionary + if "$PRAGMAS$" not in self.code_gen_dict: + # If not, insert an empty list to collect more pragmas + # Note: Do this here as it is easier to add the array partition and + # bind storage pragmas for generated parameter here, where the shape + # is computed. 
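Editor's note: the constant-input handling below broadcasts parameters "offline" so every PE lane has its own copy. A minimal sketch of that replication step with made-up shapes and PE value:

```python
import numpy as np

# A folded constant whose last axis is 1 is replicated to PE parallel copies,
# matching the np.broadcast_to call used for const inputs below (PE = 4 here).
pe = 4
param = np.full((1, 1, 1, 1), 3.0)                    # folded scalar constant
pe_copies = np.broadcast_to(param, param.shape[:-1] + (pe,))
print(pe_copies.shape)                                # (1, 1, 1, 4)
```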
+ self.code_gen_dict["$PRAGMAS$"] = [] + + # If the left hand side input is provided as initializer, generate + # initializer parameters code + if lhs is not None: + # Remember the "style" of receiving the input for further code + # generation + self.set_nodeattr("lhs_style", "const") + # Reshape the parameter tensor into folded shape + lhs = lhs.reshape(*self.get_folded_input_shape(ind=0)) + # Need to make sure there are PE many elements which can be accessed + # in parallel + if lhs.shape[-1] != self.pe: # noqa: Duplicate + # Broadcast the parameter tensor "offline" to have PE elements + # TODO: This replicates all parameters and might be inefficient + # in terms of memory utilization. It might be ore efficient to + # replicate the PEs when needed in docompute, probably at the + # cost of some latency for extra reads and registers. + lhs = np.broadcast_to(lhs, lhs.shape[:-1] + (self.pe,)) + # Current, maybe non-aligned input shape + lhs_shape = lhs.shape + # Fill up shape from the left to match the broadcast output shape + lhs_shape = (len(out_shape) - len(lhs_shape)) * (1,) + lhs_shape + # Reshape the input to align with the output shape + lhs = lhs.reshape(*lhs_shape) + # Generate C++ array initialization code + # Note: no packing, but with variable name/type declaration + lhs_code = numpy_to_hls_code( + lhs, self.lhs_dtype, "lhs", False, False + ) + # Add pragma configuring the storage type to use for the parameter + # tensors: This is a constant parameter implemented as dual-port ROM + self.code_gen_dict["$PRAGMAS$"].append( + f"#pragma HLS BIND_STORAGE" + f" variable=lhs type=ROM_2P impl={ram_style}" + ) + # Add pragma to partition the parameter tensor along the last + # dimensions, i.e., the PE dimension for parallel access + self.code_gen_dict["$PRAGMAS$"].append( + f"#pragma HLS ARRAY_PARTITION" + f" variable=lhs complete dim={len(lhs_shape)}" + ) + + # Check for an initializer providing the right hand side input + rhs = model.get_initializer(self.onnx_node.input[1]) + # If the right hand side input is provided as initializer, generate + # initializer parameters code + if rhs is not None: + # Remember the "style" of receiving the input for further code + # generation + self.set_nodeattr("rhs_style", "const") + # Reshape the parameter tensor into folded shape + rhs = rhs.reshape(*self.get_folded_input_shape(ind=1)) + # Need to make sure there are PE many elements which can be accessed + # in parallel + if rhs.shape[-1] != self.pe: # noqa: Duplicate + # Broadcast the parameter tensor "offline" to have PE elements + # TODO: This replicates all parameters and might be inefficient + # in terms of memory utilization. It might be ore efficient to + # replicate the PEs when needed in docompute, probably at the + # cost of some latency for extra reads and registers. 
+ rhs = np.broadcast_to(rhs, rhs.shape[:-1] + (self.pe,)) + # Current, maybe non-aligned input shape + rhs_shape = rhs.shape + # Fill up shape from the left to match the broadcast output shape + rhs_shape = (len(out_shape) - len(rhs_shape)) * (1,) + rhs_shape + # Reshape the input to align with the output shape + rhs = rhs.reshape(*rhs_shape) + # Generate C++ array initialization code + # Note: no packing, but with variable name/type declaration + rhs_code = numpy_to_hls_code( + rhs, self.rhs_dtype, "rhs", False, False + ) + # Add pragma configuring the storage type to use for the parameter + # tensors: This is a constant parameter implemented as dual-port ROM + self.code_gen_dict["$PRAGMAS$"].append( + f"#pragma HLS BIND_STORAGE" + f" variable=rhs type=ROM_2P impl={ram_style}" + ) + # Add pragma to partition the parameter tensor along the last + # dimensions, i.e., the PE dimension for parallel access + self.code_gen_dict["$PRAGMAS$"].append( + f"#pragma HLS ARRAY_PARTITION" + f" variable=rhs complete dim={len(rhs_shape)}" + ) + + # Open a file to store the thresholds parameters as C++ code + with open(f"{code_gen_dir}/params.hpp", "w") as file: + # Write lines of C++ code separated by newlines to the file + file.write("\n".join([ + # Insert left-hand-side and right-hand-side parameter code and + # append a newline at the end of the file (to avoid problems + # when including, required by C standard?) + lhs_code, rhs_code, "\n" + ])) + + # Generates C++ code of type alias, global constant and macro definitions + def defines(self, var): + # Insert constants and type aliases into the dictionary + self.code_gen_dict["$DEFINES$"] = [ + # Input and output element datatypes + f"using LhsType = {self.lhs_dtype.get_hls_datatype_str()};", + f"using RhsType = {self.rhs_dtype.get_hls_datatype_str()};", + f"using OutType = {self.out_dtype.get_hls_datatype_str()};", + # Width of single elements to avoid using ::width attribute which is + # not present for datatype float + f"static constexpr auto LhsWidth = {self.lhs_dtype.bitwidth()};", + f"static constexpr auto RhsWidth = {self.rhs_dtype.bitwidth()};", + f"static constexpr auto OutWidth = {self.out_dtype.bitwidth()};", + # Datatype of elements packed into the input stream + f"using LhsPacked = ap_uint<{self.get_instream_width(ind=0)}>;", + f"using RhsPacked = ap_uint<{self.get_instream_width(ind=1)}>;", + # Datatype of elements packed into the output stream + f"using OutPacked = ap_uint<{self.get_outstream_width(ind=0)}>;", + # Include the activation function type definitions and parameters + # Note: The typedefs in this header require the typedefs above, + # thus adding this to the global includes is not possible. 
+ '#include "params.hpp"', + # Input and output HLS stream datatypes + "using LhsStream = hls::stream;", + "using RhsStream = hls::stream;", + "using OutStream = hls::stream;", + ] + + # Generates C++ code for reading data from .npy (numpy format) for testing + # in C++ simulation + def read_npy_data(self): + # Input data is stored in numpy files in the code generation dictionary + code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim") + # Prepare empty stream reading to append optionals + self.code_gen_dict["$READNPYDATA$"] = [] + # If the left-hand-side is provided as runtime input, read code needs + # to be generated + if self.lhs_style == "input": + # Generate function calls for reading the input files into the input + # streams + self.code_gen_dict["$READNPYDATA$"] += [ + # Generate function call reading from file into the input stream + # Note: Inputs are always represented as numpy floats + 'npy2apintstream(', + f'"{code_gen_dir}/lhs.npy", lhs_{self.hls_sname()}, false', + ');' + ] + # If the right-hand-side is provided as runtime input, read code needs + # to be generated + if self.rhs_style == "input": + # Generate function calls for reading the input files into the input + # streams + self.code_gen_dict["$READNPYDATA$"] += [ + # Generate function call reading from file into the input stream + # Note: Inputs are always represented as numpy floats + 'npy2apintstream(', + f'"{code_gen_dir}/rhs.npy", rhs_{self.hls_sname()}, false', + ');' + ] + + # Generates C++ code for declaring all streams involved in C++ simulation + # for testing + def strm_decl(self): + # Allways add the output stream to the declarations + self.code_gen_dict["$STREAMDECLARATIONS$"] = [ + # Note: Assumes stream type aliases to be set in defines + f"OutStream out_{self.hls_sname()};" + ] + # If the left-hand-side is provided as runtime input, read code needs + # to be generated + if self.lhs_style == "input": + # Generate a stream declaration + self.code_gen_dict["$STREAMDECLARATIONS$"] += [ + # Note: Assumes stream type aliases to be set in defines + f"LhsStream lhs_{self.hls_sname()};" + ] + # If the right-hand-side is provided as runtime input, read code needs + # to be generated + if self.rhs_style == "input": + # Generate a stream declaration + self.code_gen_dict["$STREAMDECLARATIONS$"] += [ + # Note: Assumes stream type aliases to be set in defines + f"RhsStream rhs_{self.hls_sname()};" + ] + + # Generates C++ code for calling the computation part of the operator + def docompute(self): + # Add padding ones to a shape to match the broadcast output shape + def pad_shape(shape): + return (len(out_shape) - len(shape)) * (1,) + shape + + # Get the folded shapes of all tensors involved without PE axis + lhs_shape = self.get_folded_input_shape(ind=0)[:-1] + rhs_shape = self.get_folded_input_shape(ind=1)[:-1] + out_shape = self.get_folded_output_shape(ind=0)[:-1] + # Expanded shape of the inputs, filling with dimensions of size 1 from + # the left to align the shape with the broadcast shape + lhs_shape = pad_shape(lhs_shape) + rhs_shape = pad_shape(rhs_shape) + + # Removes contiguous matching dimensions from a shape + def drop_matching_dims(shape, like): + # Core functionality for this is implemented in itertools + from itertools import dropwhile + + # Compare shapes from left to right removing dimensions as long as + # they match + return *[ + size for size, _ in dropwhile( + lambda x: x[0] == x[1], zip(shape, like) + ) + ], + + # Take away all contiguous dimensions where these align with the output + # shape, as 
these can be consumed directly without buffering to be + # repeated + lhs_buffer_shape = drop_matching_dims(lhs_shape, out_shape) + rhs_buffer_shape = drop_matching_dims(rhs_shape, out_shape) + # Expand once again, filling with dimensions of size 1 from the left to + # align the shape with the broadcast shape + lhs_buffer_shape = pad_shape(lhs_buffer_shape) + rhs_buffer_shape = pad_shape(rhs_buffer_shape) + + # Code generation of array index strings with broadcasting + def make_index_string(shape): + # Generate index operation [i] for "normal" dimensions but reduce to + # hardcoded [0] for broadcast dimensions to repeat from a single + # buffer slot + return "".join([ + f"[i{d}]" if s != 1 else "[0]" for d, s in enumerate(shape) + ]) + + # Generate the C++ code for indexing the buffers + lhs_index = { + "input": make_index_string(lhs_buffer_shape), + "const": make_index_string(lhs_shape) + }[self.lhs_style] + rhs_index = { + "input": make_index_string(rhs_buffer_shape), + "const": make_index_string(rhs_shape) + }[self.rhs_style] + + # Generate C++ code for declaring an array of the buffer shapes + lhs_buffer_shape = "".join([f'[{size}]' for size in lhs_buffer_shape]) + rhs_buffer_shape = "".join([f'[{size}]' for size in rhs_buffer_shape]) + + # Number of dimensions of the (broadcast) output. All shapes will be + # aligned to this number of dimensions. + # Note: +1 for the PE dimension + ndim = len(out_shape) + 1 + + # For-Loop template for nested loops over arbitrary many levels + def for_loop(level, size): + return f"for(std::size_t i{level} = 0; i{level}<{size}; ++i{level})" + + # Generate code testing for the condition when the next element needs to + # be read from the input stream according to broadcasting semantics + def read_stream_condition(shape): + # Start with the assumption that none of the dimensions is + # broadcast, meaning each individual element needs to be read from + # the stream + condition = "true" + # Search for the dimensions which are broadcast + for dim, size in enumerate(shape): + # If this dimension has a size of 1 in the input but not in the + # output, it is broadcast and contributes to the conjunctive + # reading condition if this index wraps around + if size == 1 and out_shape[dim] != 1: + # Add testing for index wrap-around to the condition + condition += f" && (i{dim} == 0)" + # Return the composed reading condition + return condition + + # Generate code for unpacking elements read from the stream into the PE- + # parallel buffer according to broadcasting semantics + def unpack_buffer(shape): + # Unpacking behavior depends on whether the last, i.e., folded PE + # dimension is broadcast + if shape[-1] == 1 and self.pe != self.out_shape[-1]: + # PE axis is broadcast, i.e., slice yields just one element + # which needs to be replicated + return "buffer(0, 0)" + # PE axis is not broadcast, i.e., slice actually yields parallel + # elements to be unpacked + return "buffer(pe, 0)" + + # Type of memory to use for storing constant parameters + ram_style = RAM_STYLES[self.get_nodeattr("ram_style")] + + # Write the body of the top-level function + self.code_gen_dict["$DOCOMPUTE$"] = [ + # @formatter:off Disable formatter for mixed Python and C++ + # For streamed inputs, generate local buffer of non-broadcast size + # but broadcasts dimensions un-squeezed to size 1. For constant + # inputs, use the generated parameters of the same name. + # For streamed inputs, implement a simple dual-port RAM partitioned + # on the last, i.e., the PE, axis for parallel access. 
+ f""" + LhsType lhs{lhs_buffer_shape}[{self.pe}]; + #pragma HLS ARRAY_PARTITION variable=lhs complete dim={ndim} + #pragma HLS BIND_STORAGE variable=lhs type=RAM_S2P impl={ram_style} + """ if self.lhs_style == "input" else """""", + f""" + RhsType rhs{rhs_buffer_shape}[{self.pe}]; + #pragma HLS ARRAY_PARTITION variable=rhs complete dim={ndim} + #pragma HLS BIND_STORAGE variable=rhs type=RAM_S2P impl={ram_style} + """ if self.rhs_style == "input" else """""", + # Buffer to hold the parallel output elements: Implement a simple + # dual-port RAM for the output buffer, partitioned on the last, + # i.e., the PE, axis for parallel access. + # Note: The PE output should be rather small, force this into + # distributed memory here. + # TODO: Maybe reconsider this later? + f""" + OutType out[{self.pe}]; + #pragma HLS ARRAY_PARTITION variable=out complete dim=1 + #pragma HLS BIND_STORAGE variable=out type=RAM_S2P impl=LUTRAM + """, + # Perfect loop nest over all folded output dimensions + *[for_loop(dim, size) + " {" for dim, size in enumerate(out_shape)], + # Pipeline the loops. This should be possible as there is no code + # between the loop levels, i.e., this is a perfect loop nest. + """ + #pragma HLS pipeline II=1 style=flp + """, + # Read from the left-hand-side input stream if new elements are + # needed according to broadcasting semantics + f""" + if({read_stream_condition(lhs_shape)}) {{ + const auto buffer = Slice{{}}( + lhs_{self.hls_sname()}.read() + ); + for(std::size_t pe = 0; pe < {self.pe}; ++pe) {{ + #pragma HLS unroll + lhs{lhs_index}[pe] = {unpack_buffer(lhs_shape)}; + }} + }} + """ if self.lhs_style == "input" else """""", + # Read from the right-hand-side input stream if new elements are + # needed according to broadcasting semantics + f""" + if({read_stream_condition(rhs_shape)}) {{ + const auto buffer = Slice{{}}( + rhs_{self.hls_sname()}.read() + ); + for(std::size_t pe = 0; pe < {self.pe}; ++pe) {{ + #pragma HLS unroll + rhs{rhs_index}[pe] = {unpack_buffer(rhs_shape)}; + }} + }} + """ if self.rhs_style == "input" else """""", + # Apply PE parallel elementwise operations by filling the operation + # template + f""" + for(std::size_t pe = 0; pe < {self.pe}; ++pe) {{ + #pragma HLS unroll + out[pe] = {self.cpp_op.format( + f"lhs{lhs_index}[pe]", f"rhs{rhs_index}[pe]" + )}; + }} + """, + # Write the PE group into the output stream + f""" + out_{self.hls_sname()}.write(flatten<{self.pe}>(out)); + """, + # Close all for-loop bodies of the generated nest + *["}" for _ in enumerate(out_shape)] + # @formatter:on End of code generation + ] + + # Post-process the generated code to remove unnecessary white space + self.code_gen_dict["$DOCOMPUTE$"] = [ + textwrap.dedent(code) for code in self.code_gen_dict["$DOCOMPUTE$"] + ] + + # Generates C++ code for reading the output stream and converting back to + # numpy format for testing in C** simulation + def dataoutstrm(self): + # Output data will be stored in numpy files in the code generation + # dictionary + code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim") + # Get the expected shape of the folded output array formatted as a C++ + # vector initializer + # Note: Valid formatting relies on correct placement of curly braces + # and line breaks: Open/close all three braces on the same line of code + # to avoid '\n' to be inserted into the string + shape = f"""{{{ + ','.join((str(i) for i in self.get_folded_output_shape(ind=0))) + }}}""" + # Generate function call for reading from the output stream into the + # output file + 
self.code_gen_dict["$DATAOUTSTREAM$"] = [ + # Generate function call reading from stream into the output file + # Note: Outputs are always represented as numpy floats + 'apintstream2npy(', + f'out_{self.hls_sname()}, {shape}, "{code_gen_dir}/out.npy", false', + ');', + ] + + # Generates C++ code for saving the output of C++ simulation to a file in + # numpy format + def save_as_npy(self): + # Note: This seems to be empty in ALL HLSBackends. Probably it was used + # for something before, which is now integrated into dataoutstrm()? + self.code_gen_dict["$SAVEASCNPY$"] = [] + + # Generates essentially the head of the C++ function from which the IP block + # will be generated during ipgen, i.e. actual synthesis + def blackboxfunction(self): + # Check whether the inputs are provided at runtime to generate stream + # inputs to the toplevel interface + runtime_lhs = self.lhs_style == "input" + runtime_rhs = self.rhs_style == "input" + # Insert function head describing the top level interface of the + # attention operator + self.code_gen_dict["$BLACKBOXFUNCTION$"] = [ + # Note: Assumes stream type aliases to be set in defines + f"void {self.onnx_node.name} (", + f" LhsStream &lhs_{self.hls_sname()}," if runtime_lhs else "", + f" RhsStream &rhs_{self.hls_sname()}," if runtime_rhs else "", + f" OutStream &out_{self.hls_sname()}", + ")", + ] + + # Generates C++ pragmas to be inserted into the main function of the C++ + # simulation and the ipgen-blackboxfunction as well + def pragmas(self): + # Check whether there are already pragmas in the code generation + # dictionary + if "$PRAGMAS$" not in self.code_gen_dict: + # If not, insert an empty list to collect more pragmas + self.code_gen_dict["$PRAGMAS$"] = [] + + # Add HLS interface directives specifying how to create RTL ports for + # the top-level function arguments + self.code_gen_dict["$PRAGMAS$"] += [ + # Connect the output stream with an axi stream interface + f"#pragma HLS INTERFACE axis port=out_{self.hls_sname()}", + ] + + # If the left-hand-side is provided as runtime input interface pragmas + # need to be inserted + if self.lhs_style == "input": + # Connect the lhs input stream with an axi stream interface + self.code_gen_dict["$PRAGMAS$"] += [ + f"#pragma HLS INTERFACE axis port=lhs_{self.hls_sname()}", + ] + + # If the right-hand-side is provided as runtime input interface pragmas + # need to be inserted + if self.rhs_style == "input": + # Connect the rhs input stream with an axi stream interface + self.code_gen_dict["$PRAGMAS$"] += [ + f"#pragma HLS INTERFACE axis port=rhs_{self.hls_sname()}", + ] + + # No block-level I/O protocol for the function return value + self.code_gen_dict["$PRAGMAS$"].append( + "#pragma HLS INTERFACE ap_ctrl_none port=return" + ) + + # Returns the names of input and output interfaces grouped by protocol + def get_verilog_top_module_intf_names(self): + # Start collecting interface names in a dictionary starting with clock + # and reset + intf_names = {"clk": ["ap_clk"], "rst": ["ap_rst_n"]} # noqa + # AXI stream input interfaces + intf_names["s_axis"] = [] + # If the left-hand-side is provided as runtime input interface names + # need to be inserted + if self.lhs_style == "input": + intf_names["s_axis"] += [( + f"lhs_{self.hls_sname()}", self.get_instream_width_padded(ind=0) + )] + # If the right-hand-side is provided as runtime input interface names + # need to be inserted + if self.rhs_style == "input": + intf_names["s_axis"] += [( + f"rhs_{self.hls_sname()}", self.get_instream_width_padded(ind=1) + )] + # AXI 
stream output interfaces + intf_names["m_axis"] = [ + (f"out_{self.hls_sname()}", self.get_outstream_width_padded(ind=0)) + ] + # No AXI-MM, AXI-Lite or protocol-less interfaces + intf_names["aximm"] = [] + intf_names["axilite"] = [] + intf_names["ap_none"] = [] + # Return the interface name dictionary + return intf_names + + +# Derive a specialization to implement elementwise addition of two inputs +@register_custom_op # noqa: PyCharm sees all these specializations as duplicate +class ElementwiseAdd_hls( # noqa: Class name does not follow + ElementwiseBinaryOperation_hls, elementwise_binary.ElementwiseAdd +): + pass + + +# Derive a specialization to implement elementwise subtraction of two inputs +@register_custom_op +class ElementwiseSub_hls( # noqa: Class name does not follow + # CapWords convention + ElementwiseBinaryOperation_hls, elementwise_binary.ElementwiseSub +): + pass + + +# Derive a specialization to implement elementwise multiplication of two inputs +@register_custom_op +class ElementwiseMul_hls( # noqa: Class name does not follow + # CapWords convention + ElementwiseBinaryOperation_hls, elementwise_binary.ElementwiseMul +): + pass + + +# Derive a specialization to implement elementwise division of two inputs +@register_custom_op +class ElementwiseDiv_hls( # noqa: Class name does not follow + # CapWords convention + ElementwiseBinaryOperation_hls, elementwise_binary.ElementwiseDiv +): + pass + + +# TODO: ElementwiseMod_hls - Requires extra attribute selecting the function + +# Derive a specialization to implement elementwise logical and of two inputs +@register_custom_op +class ElementwiseAnd_hls( # noqa: Class name does not follow + # CapWords convention + ElementwiseBinaryOperation_hls, elementwise_binary.ElementwiseAnd +): + pass + + +# Derive a specialization to implement elementwise logical or of two inputs +@register_custom_op +class ElementwiseOr_hls( # noqa: Class name does not follow + # CapWords convention + ElementwiseBinaryOperation_hls, elementwise_binary.ElementwiseOr +): + pass + + +# Derive a specialization to implement elementwise logical xor of two inputs +@register_custom_op +class ElementwiseXor_hls( # noqa: Class name does not follow + # CapWords convention + ElementwiseBinaryOperation_hls, elementwise_binary.ElementwiseXor +): + pass + + +# Derive a specialization to implement elementwise equal of two inputs +@register_custom_op # noqa: PyCharm sees all these specializations as duplicate +class ElementwiseEqual_hls( # noqa: Class name does not follow + # CapWords convention + ElementwiseBinaryOperation_hls, elementwise_binary.ElementwiseEqual +): + pass + + +# Derive a specialization to implement elementwise less of two inputs +@register_custom_op +class ElementwiseLess_hls( # noqa: Class name does not follow + # CapWords convention + ElementwiseBinaryOperation_hls, elementwise_binary.ElementwiseLess +): + pass + + +# Derive a specialization to implement elementwise less or equal of two inputs +@register_custom_op +class ElementwiseLessOrEqual_hls( # noqa: Class name does not follow + # CapWords convention + ElementwiseBinaryOperation_hls, elementwise_binary.ElementwiseLessOrEqual +): + pass + + +# Derive a specialization to implement elementwise greater of two inputs +@register_custom_op +class ElementwiseGreater_hls( # noqa: Class name does not follow + # CapWords convention + ElementwiseBinaryOperation_hls, elementwise_binary.ElementwiseGreater +): + pass + + +# Derive a specialization to implement elementwise greater or equal of two +# inputs 
+@register_custom_op +class ElementwiseGreaterOrEqual_hls( # noqa: Class name does not follow + # CapWords convention + ElementwiseBinaryOperation_hls, elementwise_binary.ElementwiseGreaterOrEqual +): + pass + + +# Derive a specialization to implement elementwise bitwise and of two inputs +@register_custom_op +class ElementwiseBitwiseAnd_hls( # noqa: Class name does not follow + # CapWords convention + ElementwiseBinaryOperation_hls, elementwise_binary.ElementwiseBitwiseAnd +): + pass + + +# Derive a specialization to implement elementwise bitwise or of two inputs +@register_custom_op +class ElementwiseBitwiseOr_hls( # noqa: Class name does not follow + # CapWords convention + ElementwiseBinaryOperation_hls, elementwise_binary.ElementwiseBitwiseOr +): + pass + + +# Derive a specialization to implement elementwise bitwise xor of two inputs +@register_custom_op +class ElementwiseBitwiseXor_hls( # noqa: Class name does not follow + # CapWords convention + ElementwiseBinaryOperation_hls, elementwise_binary.ElementwiseBitwiseXor +): + pass + +# TODO: ElementwiseBitShift_hls - Requires extra attribute selecting the +# direction + + +# # Derive a specialization to implement elementwise power of two inputs +# TODO: std::pow does not work for HLS types and hls::pow fails to link for some +# reason +# @register_custom_op +# class ElementwisePow_hls( # noqa: Class name does not follow +# # CapWords convention +# ElementwiseBinaryOperation_hls, elementwise_binary.ElementwisePow +# ): +# pass diff --git a/src/finn/custom_op/fpgadataflow/templates.py b/src/finn/custom_op/fpgadataflow/templates.py index 3d89a0ab23..e5787bfd2a 100644 --- a/src/finn/custom_op/fpgadataflow/templates.py +++ b/src/finn/custom_op/fpgadataflow/templates.py @@ -29,6 +29,7 @@ # template for single node execution docompute_template = """ +#define HLS_CONSTEXPR_ENABLE #define AP_INT_MAX_W $AP_INT_MAX_W$ #include "cnpy.h" #include "npy2apintstream.hpp" @@ -62,6 +63,7 @@ # cpp file ipgen_template = """ +#define HLS_CONSTEXPR_ENABLE #define AP_INT_MAX_W $AP_INT_MAX_W$ #include "bnn-library.h" diff --git a/src/finn/custom_op/fpgadataflow/thresholding.py b/src/finn/custom_op/fpgadataflow/thresholding.py index dde813a293..363c1572cf 100644 --- a/src/finn/custom_op/fpgadataflow/thresholding.py +++ b/src/finn/custom_op/fpgadataflow/thresholding.py @@ -259,7 +259,7 @@ def execute_node(self, context, graph): else: # signed offset y += act.min() - context[node.output[0]] = y + context[node.output[0]] = y.astype(np.float32) def calc_tmem(self): """Calculates and returns TMEM.""" diff --git a/src/finn/transformation/fpgadataflow/convert_to_hw_layers.py b/src/finn/transformation/fpgadataflow/convert_to_hw_layers.py index 897d714bf8..7aa28999de 100644 --- a/src/finn/transformation/fpgadataflow/convert_to_hw_layers.py +++ b/src/finn/transformation/fpgadataflow/convert_to_hw_layers.py @@ -30,8 +30,11 @@ import numpy as np import qonnx.core.data_layout as DataLayout import warnings -from onnx import TensorProto, helper +from onnx import NodeProto, TensorProto, helper from qonnx.core.datatype import DataType + +# QONNX wrapper to ONNX model graphs +from qonnx.core.modelwrapper import ModelWrapper from qonnx.custom_op.registry import getCustomOp from qonnx.transformation.base import Transformation from qonnx.transformation.general import SortGraph @@ -40,6 +43,12 @@ from qonnx.util.basic import get_by_name from qonnx.util.onnx import nchw_to_nhwc +# Module containing specializations of elementwise binary operations +import 
finn.custom_op.fpgadataflow.elementwise_binary as elementwise_binary + +# Base class for all FINN custom ops, here just used for type-hinting +from finn.custom_op.fpgadataflow.hwcustomop import HWCustomOp + class InferConvInpGen(Transformation): """Convert Im2Col layers to ConvolutionInputGenerator layers.""" @@ -1691,3 +1700,130 @@ def apply(self, model): model = model.transform(InferShapes()) model = model.transform(InferDataTypes()) return (model, graph_modified) + + +# Lifts scalar to rank-1 tensor +def lift_to_rank1(name: str, model: ModelWrapper): + # Scalars have a shape of lengths zero + if len(model.get_tensor_shape(name)) == 0: + # Lift shape to rank-1 tensor with single element + model.set_tensor_shape(name, [1]) + # Check whether this tensor has an initializer + if (tensor := model.get_initializer(name)) is not None: + # Set new initializer tensor of shape [1] + model.set_initializer(name, tensor.reshape(1)) + + +# Converts supported elementwise binary operations to their FINN custom +# operation +class InferElementwiseBinaryOperation(Transformation): + # Filter function to filter out the last elementwise Mul operation, + # typically corresponding to output de-quantization, which should happen + # off-chip + @staticmethod + def reject_output_dequant(model: ModelWrapper, node: NodeProto): + # The operator must be a Mul and have no successor nodes + if node.op_type == "Mul" and not model.find_direct_successors(node): + # If the output is a floating-point tensors, reject this + if model.get_tensor_datatype(node.output[0]) == "FLOAT32": + # Filter False rejects this node + return False + # Filter True accepts this node + return True + + # Filter function to filter out any operation involving any floating-point + # tensor + @staticmethod + def reject_floats(model: ModelWrapper, node: NodeProto): + # Check for any input being floating-point + if any(model.get_tensor_datatype(x) == "FLOAT32" for x in node.input): + # Filter False rejects this node + return False + # Check for any output being floating-point + if any(model.get_tensor_datatype(x) == "FLOAT32" for x in node.output): + # Filter False rejects this node + return False + # Filter True accepts this node + return True + + # Initializes the transformation method with an optional filter function + def __init__(self, _filter=None): + # Initialize the base class Transformation object + super().__init__() + # Register the filter function as attribute + self._filter = _filter if _filter is not None else lambda *_: True + + # Applies the transform to a whole model graph + def apply(self, model: ModelWrapper): # noqa + # Get the model graph out of the model wrapper object + graph = model.graph + # Keep track of whether the graph has been modified + graph_modified = False + # Iterate all nodes in the graph keeping track of the index + for index, node in enumerate(graph.node): + # Skip transforming nodes rejected by the filter + if not self._filter(model, node): + continue + # If a custom operation with corresponding name is implemented in + # the module, this operator is supported for conversion + if f"Elementwise{node.op_type}" in dir(elementwise_binary): + # Transplant this operator into our FINN domain + node.domain = "finn.custom_op.fpgadataflow" + # Adapt the op-type prefixing it with Elementwise + # TODO: Consider dropping the prefix? 
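+                # Example: a standard ONNX "Add" node becomes the op_type
+                # "ElementwiseAdd", which is registered in the
+                # finn.custom_op.fpgadataflow domain set above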
+ node.op_type = f"Elementwise{node.op_type}" + # Now we can get the CustomOp wrapper instance providing easier + # attribute access + inst: HWCustomOp = getCustomOp(node) + # Set the backend attribute to mark this an operation supported + # to be implemented on an FPGA by FINN + inst.set_nodeattr("backend", "fpgadataflow") + # Need to "lift" potential scalar inputs to rank-1 tensors + lift_to_rank1(node.input[0], model) + lift_to_rank1(node.input[1], model) + + # fmt: off + # Disable formatter. This is deliberately formatted to stay + # within 80 characters per line. Black, however, formats some + # lines going beyond this. + + # Insert data type attributes from "context" into the CustomOp + # node + # TODO: Find a way to handle this via data type inference? + inst.set_nodeattr( + "lhs_dtype", str(model.get_tensor_datatype(node.input[0])) + ) + inst.set_nodeattr( + "rhs_dtype", str(model.get_tensor_datatype(node.input[1])) + ) + inst.set_nodeattr( + "out_dtype", str(model.get_tensor_datatype(node.output[0])) + ) + # Insert shape attributes from "context" into the CustomOp node + # TODO: Find a way to handle this via shape inference? + inst.set_nodeattr( + "lhs_shape", model.get_tensor_shape(node.input[0]) + ) + inst.set_nodeattr( + "rhs_shape", model.get_tensor_shape(node.input[1]) + ) + inst.set_nodeattr( + "out_shape", model.get_tensor_shape(node.output[0]) + ) + + # fmt: on + + # Consider the graph to be modified, triggering exhaustive + # re-application of this transformation + graph_modified = True + # Exiting here triggers type and shape inference and cleanup + # after each transformed node. This helps QONNX to behave + # better / more consistent in certain cases... + break + # Re-do shape and data type annotations after potential changes to the + # model graph + model = model.transform(InferShapes()) + model = model.transform(InferDataTypes()) + # Return the transformed model and indicate whether the graph actually + # has been transformed + return model, graph_modified diff --git a/src/finn/transformation/fpgadataflow/set_folding.py b/src/finn/transformation/fpgadataflow/set_folding.py index eaee499e6a..0ae425975c 100644 --- a/src/finn/transformation/fpgadataflow/set_folding.py +++ b/src/finn/transformation/fpgadataflow/set_folding.py @@ -27,12 +27,17 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
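Before moving on to the folding changes, a minimal usage sketch of the InferElementwiseBinaryOperation transformation defined above (not part of this patch; the helper name and the assumption of an already streamlined ModelWrapper called "model" are illustrative only). It shows how the optional filter keeps the trailing floating-point de-quantization Mul off-chip:

from qonnx.core.modelwrapper import ModelWrapper

from finn.transformation.fpgadataflow.convert_to_hw_layers import (
    InferElementwiseBinaryOperation,
)


def convert_elementwise_to_hw(model: ModelWrapper) -> ModelWrapper:
    # Convert all supported elementwise binary ops to FINN HW layers, but
    # reject the final Mul node that de-quantizes the network output
    return model.transform(InferElementwiseBinaryOperation(
        InferElementwiseBinaryOperation.reject_output_dequant
    ))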
+# Inspect information on Python objects like modules +import inspect import numpy as np import warnings from qonnx.custom_op.registry import getCustomOp from qonnx.transformation.base import Transformation from qonnx.transformation.general import GiveUniqueNodeNames +# Import the elementwise binary operation module to extract names of all +# specializations (which require PE parallelism to be configured) +import finn.custom_op.fpgadataflow.hls.elementwise_binary_hls as elementwise_binary_hls from finn.analysis.fpgadataflow.dataflow_performance import dataflow_performance from finn.transformation.fpgadataflow.annotate_cycles import AnnotateCycles from finn.util.fpgadataflow import is_hls_node, is_rtl_node @@ -44,6 +49,15 @@ def divisors(num): yield x +# Find the op-type names for all HLS specializations of elementwise binary +# operations +ELEMENTWISE_BINARY_OPS = [ + op_type + for op_type, cls in inspect.getmembers(elementwise_binary_hls, inspect.isclass) + if issubclass(cls, elementwise_binary_hls.ElementwiseBinaryOperation_hls) +] + + class SetFolding(Transformation): """Attempt to set parallelism attributes in all nodes to meet a specific target expressed as cycles per frame target_cycles_per_frame. For each @@ -106,6 +120,7 @@ def apply(self, model): "GlobalAccPool_hls", "Thresholding_hls", "Thresholding_rtl", + *ELEMENTWISE_BINARY_OPS, ] # these ops use SIMD parallelism, up to a max value of NumChannels # ConvolutionInputGenerator* has a special case when depthwise=1 @@ -151,7 +166,16 @@ def apply(self, model): # increase PE until target met or reached max_pe self.optimize_attribute_val(node_inst, max_pe, "PE") elif op_type in pe_ops: - max_pe = node_inst.get_nodeattr("NumChannels") + # Note: Keep original behavior for all custom-ops defining the + # NumChannels attribute as it is + try: + max_pe = node_inst.get_nodeattr("NumChannels") + # Note: Some of the recent additions do not define the + # NumChannels attribute + except AttributeError: + # We can extract the channels from the normal, i.e., not + # folded, shape of the input in these cases + max_pe = node_inst.get_normal_input_shape()[-1] self.optimize_attribute_val(node_inst, max_pe, "PE") elif op_type == "LabelSelect_hls": max_pe = node_inst.get_nodeattr("Labels") diff --git a/src/finn/transformation/qonnx/fold_quant_weights.py b/src/finn/transformation/qonnx/fold_quant_weights.py index 0f6cbacb82..59ebe4eea3 100644 --- a/src/finn/transformation/qonnx/fold_quant_weights.py +++ b/src/finn/transformation/qonnx/fold_quant_weights.py @@ -149,7 +149,8 @@ def apply(self, model): mul_tensor = helper.make_tensor_value_info( model.make_new_valueinfo_name(), TensorProto.FLOAT, - mul_shape, + mul_shape, # Note: This shape is known exactly as + # it is an initializer with known shape ) graph.value_info.append(mul_tensor) model.set_initializer(mul_tensor.name, scale) @@ -168,7 +169,9 @@ def apply(self, model): act_mul_tensor = helper.make_tensor_value_info( model.make_new_valueinfo_name(), TensorProto.FLOAT, - output_shape, + None, # Note: Explicitly delete the shape + # annotation to be redone by the next shape + # inference ) graph.value_info.append(act_mul_tensor) successor.output[0] = act_mul_tensor.name @@ -186,19 +189,37 @@ def apply(self, model): div_tensor = helper.make_tensor_value_info( model.make_new_valueinfo_name(), TensorProto.FLOAT, - mul_shape, + None, # Note: Explicitly delete the shape + # annotation to be redone by the next shape + # inference ) graph.value_info.append(div_tensor) model.set_initializer(div_tensor.name, 
scale) - succ_input_name = successor.input[0] + # Detect which input of the add-like successor is + # fed by the quantizer node to select the other + # branch to insert the scale factor + if successor.input[0] == node_out: + succ_input_name = successor.input[1] + else: + succ_input_name = successor.input[0] + act_mul_tensor = helper.make_tensor_value_info( model.make_new_valueinfo_name(), TensorProto.FLOAT, - output_shape, + None, # Note: Explicitly delete the shape + # annotation to be redone by the next shape + # inference ) graph.value_info.append(act_mul_tensor) - successor.input[0] = act_mul_tensor.name + + # Detect which input of the add-like successor is + # fed by the quantizer node to select the other + # branch to insert the scale factor + if successor.input[0] == node_out: + successor.input[1] = act_mul_tensor.name + else: + successor.input[0] = act_mul_tensor.name div_node = helper.make_node( "Div", @@ -210,6 +231,8 @@ def apply(self, model): # remove old node graph.node.remove(n) graph_modified = True + # Note: Running shape inference is necessary as shape + # annotations have been deleted above model = model.transform(InferShapes()) return (model, graph_modified) return (model, graph_modified) diff --git a/src/finn/transformation/qonnx/qonnx_activation_handlers.py b/src/finn/transformation/qonnx/qonnx_activation_handlers.py index 323e391df4..451ba52c29 100644 --- a/src/finn/transformation/qonnx/qonnx_activation_handlers.py +++ b/src/finn/transformation/qonnx/qonnx_activation_handlers.py @@ -25,8 +25,8 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - import numpy as np +import warnings from abc import ABC, abstractmethod from onnx import TensorProto, helper from qonnx.core.modelwrapper import ModelWrapper @@ -70,7 +70,7 @@ def _check_compatibility(self): @abstractmethod def _calculate_act_bias(self): """Calculate the activation bias, - which is introduced as an Add node behind the MultiTrheshold node. + which is introduced as an Add node behind the MultiThreshold node. """ raise NotImplementedError() @@ -82,7 +82,7 @@ def _calculate_thresholds(self): @abstractmethod def _calculate_act_scale(self): """Calculate the activation scale, - which is indroduced as a Mul node behind the Add node + which is introduced as a Mul node behind the Add node for the activation bias. """ raise NotImplementedError() @@ -157,7 +157,7 @@ def replace_quant_node(self): # Set scale and bias # If these values are scalar then they can be set as attributes # of the MultiThreshold node, if not they get inserted as adder and mul nodes - # behind the MultiTrheshold nodes. + # behind the MultiThreshold nodes. 
bias_scalar = adder_bias.shape == (1,) or len(adder_bias.shape) == 0 scale_scalar = mul_scale.shape == (1,) or len(mul_scale.shape) == 0 if scale_scalar and bias_scalar and self._q_node.op_type == "BipolarQuant": @@ -355,7 +355,7 @@ def _calculate_thresholds(self): act_node = self._model.find_direct_predecessors(self._q_node) act_node = act_node[0] if act_node.op_type == "Relu": - # Calculate thersholds, see: https://github.com/Xilinx/brevitas/blob/ + # Calculate thresholds, see: https://github.com/Xilinx/brevitas/blob/ # a5bfd6dc5e030f0047ac1ee47932b60e8e873e17/src/brevitas/export/ # onnx/finn/handler/act.py#L21 num_distinct_values = 2**bit_width @@ -395,8 +395,27 @@ def _calculate_thresholds(self): else: thresholds[c][t] = step / selu_scale + # First try to consider the tensor layout of the output for determining + # the number of output channels + layout = self._model.get_tensor_layout(self._q_node.output[0]) + # If there is a layout annotation, use this to determine the index of + # the channel dimension + if layout is not None and "C" in layout: + # Lookup the index in list + cdim = layout.index("C") + # If no layout has been annotated or there is no channel dimension, fall + # back to the previous default assumption + else: + # Assume the channels to be in axis 1 + cdim = 1 + # Issue a warning to the user, so they are aware of this + warnings.warn( + f"No layout annotations for {self._q_node.output[0]}:" + f" Assuming channel dimension at index {cdim}" + ) + # ToDo: The index 1 needs to be changed to -1 for the channels last format - num_output_channels = self._model.get_tensor_shape(self._q_node.output[0])[1] + num_output_channels = self._model.get_tensor_shape(self._q_node.output[0])[cdim] final_shape = (num_output_channels, num_thresholds) if thresholds.shape != final_shape: thresholds = np.broadcast_to(thresholds, final_shape) @@ -417,12 +436,12 @@ def _remove_activation_node(self, multi_threshold_node): act_node = self._model.find_direct_predecessors(self._q_node) if act_node is None: raise RuntimeError( - "For handling of Relu activations a predecesor to " "the Quant node must exist." + "For handling of Relu activations a predecessor to " "the Quant node must exist." ) act_node = act_node[0] if act_node.op_type not in self.valid_predecessor_op_types(): raise RuntimeError( - "The predecesor of the Quant node must be Relu or Selu for handling " + "The predecessor of the Quant node must be Relu or Selu for handling " "of activations." 
) @@ -509,7 +528,7 @@ def _calculate_thresholds(self): else: raise RuntimeError("Got an unexpected quantizer node type") - # Calculate thersholds, see: https://github.com/Xilinx/brevitas/ + # Calculate thresholds, see: https://github.com/Xilinx/brevitas/ # blob/a5bfd6dc5e030f0047ac1ee47932b60e8e873e17/src/brevitas/ # export/onnx/finn/handler/act.py#L76 if bit_width == 1.0: @@ -537,8 +556,28 @@ def _calculate_thresholds(self): for t in range(num_thresholds): thresholds[c][t] = min_threshold[c] + step[c] * t + # First try to consider the tensor layout of the output for + # determining the number of output channels + layout = self._model.get_tensor_layout(self._q_node.output[0]) + # If there is a layout annotation, use this to determine the index + # of the channel dimension + if layout is not None and "C" in layout: + # Lookup the index in list + cdim = layout.index("C") + # If no layout has been annotated or there is no channel dimension, + # fall back to the previous default assumption + else: + # Assume the channels to be in axis 1 + cdim = 1 + # Issue a warning to the user, so they are aware of this + warnings.warn( + f"No layout annotations for {self._q_node.output[0]}:" + f" Assuming channel dimension at index {cdim}" + ) + # ToDo: The index 1 needs to be changed to -1 for the channels last format - num_output_channels = self._model.get_tensor_shape(self._q_node.output[0])[1] + num_output_channels = self._model.get_tensor_shape(self._q_node.output[0])[cdim] + final_shape = (num_output_channels, num_thresholds) if thresholds.shape != final_shape: thresholds = np.broadcast_to(thresholds, final_shape) diff --git a/src/finn/transformation/streamline/absorb.py b/src/finn/transformation/streamline/absorb.py index e3e2468bba..4c280d8f28 100644 --- a/src/finn/transformation/streamline/absorb.py +++ b/src/finn/transformation/streamline/absorb.py @@ -31,12 +31,18 @@ import warnings from onnx import helper as oh from qonnx.core.datatype import DataType + +# QONNX wrapper of ONNX model graphs +from qonnx.core.modelwrapper import ModelWrapper from qonnx.custom_op.registry import getCustomOp from qonnx.transformation.base import Transformation from qonnx.transformation.infer_datatypes import InferDataTypes from qonnx.transformation.infer_shapes import InferShapes from qonnx.util.basic import get_by_name +# Protobuf onnx graph node type +from onnx import NodeProto # noqa + class AbsorbSignBiasIntoMultiThreshold(Transformation): """Absorb scalar bias originating from signed int export back into @@ -100,6 +106,19 @@ def apply(self, model): return (model, graph_modified) +# Groups inputs by categories, i.e., groups dynamic inputs first, followed by +# initializers. Keeps order of inputs in each category. +def group_inputs_by_category(node: NodeProto, model: ModelWrapper): # noqa + # First select all dynamic inputs, which are those without initializer + # tensor + dynamics = [i for i in node.input if model.get_initializer(i) is None] + # Select all input which are initializers, which, by exclusion, are all + # those not among the dynamic inputs + initializers = [i for i in node.input if i not in dynamics] + # Return lists of dynamic anc initializer inputs + return dynamics, initializers + + class AbsorbAddIntoMultiThreshold(Transformation): """Absorb preceding Add ops into MultiThreshold by updating the threshold values. 
Only scalar/1D add vectors can be absorbed.""" @@ -113,28 +132,55 @@ def apply(self, model): if n.op_type == "Add" and not model.is_fork_node(n) and not model.is_join_node(n): consumer = model.find_consumer(n.output[0]) if consumer is not None and consumer.op_type == "MultiThreshold": - add_weight_name = n.input[1] - threshold_name = consumer.input[1] - A = model.get_initializer(add_weight_name) - T = model.get_initializer(threshold_name) - assert A is not None, "Initializer for add weights is not set." + # As Add is not a join node, there must be one initializer + # and one dynamic input. We do not know their order, but + # can group them accordingly to extract the tensor names + (start,), (add_weight,) = group_inputs_by_category(n, model) + threshold = consumer.input[1] + A = model.get_initializer(add_weight) + T = model.get_initializer(threshold) + # Test for the thresholds actually being initializers + # Note: No need to validate the add_weights anymore, this + # is already handled by the grouping and is_join_node test. assert T is not None, "Initializer for thresholds is not set." - start_name = n.input[0] # we can only absorb 0d or 1d adds is_scalar = A.ndim == 0 or all(x == 1 for x in A.shape) actual_ndims = len(tuple(filter(lambda x: x > 1, A.shape))) is_1d = actual_ndims == 1 + + def can_broadcast_shapes(lhs, rhs): + # Broadcasting might raise an exception + try: + # Try broadcasting the shapes + if len(np.broadcast_shapes(lhs, rhs)) == 2: + # These tensors can be broadcast, preserving the + # left-hand-side shape + return True + # These tensors cannot be broadcast + return False + # Failing to broadcast the tensors raises ValueError + except ValueError: + # These tensors cannot be broadcast + return False + if is_scalar or is_1d: - Tnew = T - A.reshape(-1, 1) - # Tnew = T - A.reshape(-1, T.shape[1]) - # compute new thresholds and set initializer - model.set_initializer(threshold_name, Tnew) - # wire add input directly to MultiThreshold - consumer.input[0] = start_name - # remove the add node - graph.node.remove(n) - graph_modified = True - return (model, graph_modified) + # Reshape addition parameters to have the elements/PE + # dimension first, aligned with the thresholds. + A = A.reshape(-1, 1) # noqa: Not lowercase + # Check that we can actually broadcast the addition + # weights to the thresholds tensors, i.e., it is adding + # along the right axis + if can_broadcast_shapes(T.shape, A.shape): + Tnew = T - A # noqa: Not lowercase + # Tnew = T - A.reshape(-1, T.shape[1]) + # compute new thresholds and set initializer + model.set_initializer(threshold, Tnew) + # wire add input directly to MultiThreshold + consumer.input[0] = start + # remove the add node + graph.node.remove(n) + graph_modified = True + return model, graph_modified class AbsorbMulIntoMultiThreshold(Transformation): @@ -186,7 +232,7 @@ def apply(self, model): graph_modified = False for n in graph.node: node_ind += 1 - if n.op_type == "Mul": + if n.op_type == "Mul" and not model.is_join_node(n): mul_weight_name = n.input[1] A = model.get_initializer(mul_weight_name) assert A is not None, "Initializer for mul weights is not set." 
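To make the broadcast guard above concrete, here is a small standalone sketch (all values are made up for illustration) of the check and the threshold update performed when an Add node is absorbed:

import numpy as np

# Thresholds have shape (channels, num_thresholds); the absorbed Add parameter
# is reshaped so its channel axis lines up with the threshold rows
T = np.arange(12, dtype=np.float32).reshape(4, 3)  # 4 channels, 3 thresholds
A = np.array([1.0, 2.0, 3.0, 4.0]).reshape(-1, 1)  # per-channel add weights

# Absorption is only legal if A broadcasts against T without adding axes
assert len(np.broadcast_shapes(T.shape, A.shape)) == 2

# Adding A in front of the thresholds is equivalent to shifting the thresholds
T_new = T - A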
diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py index 8ac2d7dad6..4cfc4cfff7 100644 --- a/src/finn/transformation/streamline/reorder.py +++ b/src/finn/transformation/streamline/reorder.py @@ -91,13 +91,27 @@ def apply(self, model): graph.node.insert(node_ind + 1, new_add) # replace add value model.set_initializer(add_weight_name, BA) + # Delete the datatype annotation of the parameter tensor + # TODO: Maybe we should derive the new type properly... + model.set_tensor_datatype(add_weight_name, None) + # Delete the shape annotation of the connecting tensors + # to be re-done later. This prevents shapes from propagating + # backwards. + # Note: Do not delete annotation for the input tensor, as + # this prevents future shape inference. + model.set_tensor_shape(middle_name, None) + model.set_tensor_shape(end_name, None) # remove old nodes graph.node.remove(n) graph.node.remove(consumer) graph_modified = True - + # Note: Running shape inference is necessary as shape + # annotations have been deleted above model = model.transform(InferShapes()) - return (model, graph_modified) + # Note. Running datatype inference is necessary as datatype + # annotations have been deleted above + model = model.transform(InferDataTypes()) + return model, graph_modified class MoveScalarMulPastMatMul(Transformation): @@ -580,6 +594,11 @@ def apply(self, model): if prod0.op_type == "Mul" and prod1.op_type == "Mul": if np.array_equal(init0, init1): self.move_node(graph, n, prod0, prod1, node_ind) + # Delete shape annotations of connecting tensors to be + # re-done later. This prevents wrong shape propagation, + # for example in cases where the Add broadcasts shapes. + model.set_tensor_shape(n.output[0], None) + model.set_tensor_shape(prod0.output[0], None) node_ind -= 1 graph_modified = True elif prod0.op_type == "Add" and prod1.op_type == "Add": @@ -587,12 +606,20 @@ def apply(self, model): # update initializer of prod0, which we'll move model.set_initializer(prod0.input[1], init) self.move_node(graph, n, prod0, prod1, node_ind) + # Delete shape annotations of connecting tensors to be + # re-done later. This prevents wrong shape propagation, + # for example in cases where the Add broadcasts shapes. + model.set_tensor_shape(n.output[0], None) + model.set_tensor_shape(prod0.output[0], None) node_ind -= 1 graph_modified = True else: continue + # Note: Running shape inference is necessary as shape annotations have + # been deleted above model = model.transform(InferShapes()) - return (model, graph_modified) + model = model.transform(InferDataTypes()) + return model, graph_modified class MoveScalarLinearPastInvariants(Transformation): diff --git a/src/finn/transformation/streamline/round_thresholds.py b/src/finn/transformation/streamline/round_thresholds.py index 5ba5ee0ff5..2666242730 100644 --- a/src/finn/transformation/streamline/round_thresholds.py +++ b/src/finn/transformation/streamline/round_thresholds.py @@ -26,43 +26,90 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
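The reorder changes above repeat the same pattern several times: clear the stale annotations of rewired tensors, then re-run inference so the annotations are derived in forward direction only. A condensed sketch of that pattern (the helper name and argument list are illustrative, not part of the patch):

from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.infer_datatypes import InferDataTypes
from qonnx.transformation.infer_shapes import InferShapes


def refresh_annotations(model: ModelWrapper, tensor_names) -> ModelWrapper:
    # Drop the (potentially stale) shape annotations of the rewired tensors
    for name in tensor_names:
        model.set_tensor_shape(name, None)
    # Re-derive shape and datatype annotations from the graph inputs forward
    model = model.transform(InferShapes())
    model = model.transform(InferDataTypes())
    return model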
+# Need numpy for modifying the onnx graph tensors, which are numpy style arrays import numpy as np + +# QONNX wrapper of ONNX model graphs +from qonnx.core.modelwrapper import ModelWrapper + +# QONNX graph transformation base class from qonnx.transformation.base import Transformation +# Transformation running qonnx datatype inference +from qonnx.transformation.infer_datatypes import InferDataTypes + +# Rounds and clips thresholds to integer values if the node inputs are integer, +# respecting range, representability and data type (promotion) of the container +# data type class RoundAndClipThresholds(Transformation): """For MultiThreshold nodes operating on integer inputs, round up thresholds values to the nearest integer. Additionally, if the input - is unsigned, sets negative thresholds to zero.""" + is unsigned, sets negative thresholds to zero. Type-casts thresholds (back) + to the float32 container type (this is separate from the quantization + annotation). Runs InferDataTypes() afterward to propagate any changes to the + quantization data types.""" - def apply(self, model): + # Applies the transform to a whole model graph + def apply(self, model: ModelWrapper): # noqa + # Get the model graph out of the model wrapper object graph = model.graph + # Keep track of whether the graph has been modified graph_modified = False - for n in graph.node: - if n.op_type == "MultiThreshold": - idtype = model.get_tensor_datatype(n.input[0]) - T = model.get_initializer(n.input[1]) - Tnew = np.ceil(T) - if idtype.is_integer() and (T != Tnew).any(): - # round up the thresholds to nearest integer - model.set_initializer(n.input[1], Tnew) - # use same datatype as inputs for thresholds - model.set_tensor_datatype(n.input[1], idtype) - graph_modified = True - if idtype.is_integer() and not idtype.signed() and (Tnew < 0).any(): - # clip any negative thresholds if input is unsigned - Tnew = np.clip(Tnew, 0, None) - model.set_initializer(n.input[1], Tnew) - # use same datatype as inputs for thresholds - model.set_tensor_datatype(n.input[1], idtype) - graph_modified = True - if idtype.is_integer() and ( - (Tnew < (idtype.min() - 1)).any() or (Tnew > (idtype.max() + 1)).any() - ): - # clip any large thresholds to input range + 1 - Tnew = np.clip(Tnew, idtype.min() - 1, idtype.max() + 1) - model.set_initializer(n.input[1], Tnew) - # use same datatype as inputs for thresholds - model.set_tensor_datatype(n.input[1], idtype) + # Iterate all nodes in the graph keeping track of the index + for index, node in enumerate(graph.node): + # Applies to initializer tensors of MultiThreshold operations + if node.op_type == "MultiThreshold": + # Try to get the thresholds initializer tensor + thresholds = model.get_initializer(node.input[1]) + # There might be no constant thresholds stored as initializer + # tensor inside the model + if thresholds is None: + # Nothing we can do, skip to the next node + continue + # Get the data type of the inputs to this operation + dtype = model.get_tensor_datatype(node.input[0]) + # This transformation only applies to thresholding operations + # operating on integer inputs + if not dtype.is_integer(): + # Nothing we can do, skip to the next node + continue + # Round thresholds up to nearest integer and clip thresholds + # outside the input range + # Note: This might promote the thresholds to float64 and + # introduce extra inaccuracies due to large integers not being + # exactly representable in floating-point representation. 
+ # See for example: np.ceil(np.float32(16777217)) == 16777216 + # fmt: off + new_thresholds = np.clip( + np.ceil(thresholds), dtype.min(), dtype.max() + ) + # fmt: on + # Convert back to the preferred float32 container type + # Note: np.clip might have promoted the thresholds to float64 + # TODO: Maybe consider an int64 container type for thresholds + # rounded to integer? Need to check all other transformations + # and code generation through the whole FINN and QONNX stack + # first, as these probably assume a float32 container type. + new_thresholds = new_thresholds.astype(np.float32) + # Insert the rounded and clipped thresholds back into the model + model.set_initializer(node.input[1], new_thresholds) + # The rounded and clipped thresholds now fit into the input data + # type + model.set_tensor_datatype(node.input[1], dtype) + # Test whether the new thresholds actually differ from the old + # ones + if np.any(new_thresholds != thresholds): + # Track the graph has been modified to inform the transform + # container to exhaustively repeat this transformation until + # no changes are possible graph_modified = True - return (model, graph_modified) + # Immediately exit here to propagate the data type changes + # before considering the next node + break + # Some data types might have changed, do one pass of data type inference + # to propagate these changes through the graph + model = model.transform(InferDataTypes()) + # Return the transformed model and indicate whether the graph actually + # has been transformed to exhaustively apply this transformation again. + return model, graph_modified diff --git a/tests/fpgadataflow/test_elementwise_binary.py b/tests/fpgadataflow/test_elementwise_binary.py new file mode 100644 index 0000000000..0222be62a4 --- /dev/null +++ b/tests/fpgadataflow/test_elementwise_binary.py @@ -0,0 +1,835 @@ +# fmt: off +# Disable formatter. This is deliberately formatted to stay within 80 characters +# per line. Black, however, formats some lines going beyond this. 
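As a short worked example of the behaviour implemented above (input values chosen purely for illustration): for INT8 inputs the thresholds are rounded up, clipped into the representable range [-128, 127] and cast back to the float32 container type:

import numpy as np
from qonnx.core.datatype import DataType

dtype = DataType["INT8"]
thresholds = np.array([-200.5, -3.2, 0.1, 300.0])
rounded = np.clip(np.ceil(thresholds), dtype.min(), dtype.max()).astype(np.float32)
# rounded is now [-128., -3., 1., 127.], which fits the INT8 input range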
+ +# Testing framework +import pytest + +# Numpy math and arrays +import numpy as np + +# Create temporary files automatically deleted after integration test +import tempfile + +# PyTorch required for integration test +import torch + +# Export brevitas models to QONNX representation in integration test +from brevitas.export import export_qonnx + +# Test the quantized elementwise addition operation from brevitas in integration +# test: this one should be representative enough for the operator pattern +from brevitas.nn import QuantEltwiseAdd + +# ONNX graph and tensor utility +from onnx import TensorProto +from onnx import helper as oh + +# QONNX/FINN datatypes +from qonnx.core.datatype import DataType + +# QONNX wrapper to ONNX model graphs +from qonnx.core.modelwrapper import ModelWrapper + +# Execute onnx model graphs +from qonnx.core.onnx_exec import execute_onnx + +# Registry of all QONNX CustomOps +from qonnx.custom_op.registry import getCustomOp + +# Cleanup transformations required after QONNX model import +from qonnx.transformation.general import ( + ApplyConfig, + GiveReadableTensorNames, + GiveUniqueNodeNames, + GiveUniqueParameterTensors, + RemoveUnusedTensors, +) + +# Adds data layout annotations to the model graph to correctly convert +# quantizers to multi-thresholds +from qonnx.transformation.infer_data_layouts import InferDataLayouts + +# QONNX graph transformations for inferring datatypes and shapes +from qonnx.transformation.infer_datatypes import InferDataTypes +from qonnx.transformation.infer_shapes import InferShapes + +# Utility for wrapping onnx graphs and generating tensor of FINN datatypes +from qonnx.util.basic import gen_finn_dt_tensor, qonnx_make_model + +# FINN graph transformations for preparing simulation (cppsim or rtlsim) +from finn.transformation.fpgadataflow.compile_cppsim import CompileCppSim + +# Mapping to hardware operators of the two operations relevant for the +# integration test +# Note: The integration test serves as the test-case for +# InferElementwiseBinaryOperation +from finn.transformation.fpgadataflow.convert_to_hw_layers import ( + InferElementwiseBinaryOperation, + InferThresholdingLayer, +) +# Synthesizes HLS code generated from an operator to IP block +from finn.transformation.fpgadataflow.hlssynth_ip import HLSSynthIP +# Bit-width optimization transformations +from finn.transformation.fpgadataflow.minimize_accumulator_width import ( + MinimizeAccumulatorWidth, +) +from finn.transformation.fpgadataflow.minimize_weight_bit_width import ( + MinimizeWeightBitWidth, +) +# Transformations preparing the operators for C++ and RTL simulation +from finn.transformation.fpgadataflow.prepare_cppsim import PrepareCppSim +from finn.transformation.fpgadataflow.prepare_ip import PrepareIP +from finn.transformation.fpgadataflow.prepare_rtlsim import PrepareRTLSim +from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode +from finn.transformation.fpgadataflow.specialize_layers import SpecializeLayers + +# Converts between QONNX and FINN dialect of ONNX representation +from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN + +# Standard set of streamlining transformations delivered with FINN +from finn.transformation.streamline import Streamline + +# Specific streamlining transformations which needs to be applied manually in +# integration test +from finn.transformation.streamline.absorb import ( + AbsorbMulIntoMultiThreshold, + AbsorbSignBiasIntoMultiThreshold, +) +from finn.transformation.streamline.reorder import 
MoveLinearPastEltwiseAdd + +# Checks whether a node is a fpgadataflow backend node handled by FINN +from finn.util.fpgadataflow import is_fpgadataflow_node + + +# Specializes all nodes to be implemented as HLS backend +def specialize_hls(model: ModelWrapper): + # Mark all nodes to be specialized as HLS backend implementations + for node in model.graph.node: # noqa: Duplicate test setup code + # Skip non-fpgadataflow backend operators as these do not have the + # preferred_impl_style attribute + if is_fpgadataflow_node(node): + # Get the CustomOp instance of the node to get access to the node + # attributes + inst = getCustomOp(node) + # Note: only HLS-based layers execute C++ Simulation + inst.set_nodeattr("preferred_impl_style", "hls") + # Turn all HWCustomOp layers into HLS specializations + return model.transform(SpecializeLayers("xczu7ev-ffvc1156-2-e")) + + +# Mapping of ElementwiseBinaryOperation specializations to numpy reference +# implementation functions +NUMPY_REFERENCES = { + "ElementwiseAdd": np.add, + "ElementwiseSub": np.subtract, + "ElementwiseMul": np.multiply, + # TODO: "ElementwiseDiv": np.divide, Cannot guarantee non-zero test input + # TODO: "ElementwiseMod": np.mode / np.fmod + "ElementwiseAnd": np.logical_and, + "ElementwiseOr": np.logical_or, + "ElementwiseXor": np.logical_xor, + "ElementwiseEqual": np.equal, + "ElementwiseLess": np.less, + "ElementwiseLessOrEqual": np.less_equal, + "ElementwiseGreater": np.greater, + "ElementwiseGreaterOrEqual": np.greater_equal, + "ElementwiseBitwiseAnd": np.bitwise_and, + "ElementwiseBitwiseOr": np.bitwise_or, + "ElementwiseBitwiseXor": np.bitwise_xor, + # TODO: "ElementwiseBitShift": np.left_shift / np.right_shift + # TODO: "ElementwisePow": np.power +} + +# Names of bitwise operations which somtimes require special treatment +BITWISE = [ + "ElementwiseBitwiseAnd", "ElementwiseBitwiseOr", "ElementwiseBitwiseXor" +] + + +# Creates a model executing a binary elementwise operation +def mock_elementwise_binary_operation( + op_type, lhs_dtype, rhs_dtype, out_dtype, lhs_shape, rhs_shape, pe +): + # Automatically derive the output shape by broadcasting the inputs + out_shape = np.broadcast_shapes(lhs_shape, rhs_shape) + # Create a node representing the binary elementwise operation + node = oh.make_node( + # Operator type from the name of the fpgadataflow hlscustomop + op_type=op_type, + # Specify the domain, i.e., the package to look for the custom operator + # implementation + domain="finn.custom_op.fpgadataflow", + # Execution backend: Required attribute inherited from HLSCustomOp + backend="fpgadataflow", + # Just one input + inputs=["lhs", "rhs"], + # Enumerate the outputs + outputs=["out"], + # Data type of the left-hand-side input elements + lhs_dtype=lhs_dtype, + # Data type of the right-hand-side input elements + rhs_dtype=rhs_dtype, + # Data type of the output elements + out_dtype=out_dtype, + # Shape of the left-hand-side input + lhs_shape=lhs_shape, + # Shape of the right-hand-side input + rhs_shape=rhs_shape, + # Shape of the output, mus correspond to multi-directional + # broadcasting of the left- and right-hand-side + out_shape=out_shape, + # Number of elements to process in parallel + PE=pe, + ) + # Construct the input tensor value infos + lhs = oh.make_tensor_value_info("lhs", TensorProto.FLOAT, lhs_shape) + rhs = oh.make_tensor_value_info("rhs", TensorProto.FLOAT, rhs_shape) + # Construct output tensor value infos + out = oh.make_tensor_value_info("out", TensorProto.FLOAT, out_shape) + # Create a graph connecting the 
node to the inputs and outputs + graph = oh.make_graph( + [node], inputs=[lhs, rhs], outputs=[out], name="elementwise-binary" + ) + # Wrap the ONNX graph in QONNX model wrapper + model = ModelWrapper( + qonnx_make_model(graph, producer_name="elementwise-binary") + ) + + # Add datatype annotation to the value info of input tensors + model.set_tensor_datatype("lhs", DataType[lhs_dtype]) + model.set_tensor_datatype("rhs", DataType[rhs_dtype]) + model.set_tensor_datatype("out", DataType[out_dtype]) + + # Return the wrapped onnx model + return model + + +# Operator type to be tested +@pytest.mark.parametrize("op_type", [ # noqa: Duplicate test setup + # Test all Numpy references specified above + *NUMPY_REFERENCES.keys() +]) +# Data type of the left-hand-side input elements +@pytest.mark.parametrize("lhs_dtype", ["INT8"]) +# Data type of the right-hand-side input elements +@pytest.mark.parametrize("rhs_dtype", ["INT8"]) +# Data type of the output elements +@pytest.mark.parametrize("out_dtype", ["INT32"]) +# Shape of the left-hand-side input +@pytest.mark.parametrize("lhs_shape", [ + [3, 1, 7, 1], [1] +]) +# Shape of the right-hand-side input +@pytest.mark.parametrize("rhs_shape", [ + [3, 32, 1, 16], +]) +# Which inputs to set as initializers +@pytest.mark.parametrize("initializers", [ + [], ["lhs"], ["rhs"], ["lhs", "rhs"] +]) +# Number of elements to process in parallel +@pytest.mark.parametrize("pe", [1, 2, 4]) +def test_elementwise_binary_operation_python( + op_type, lhs_dtype, rhs_dtype, out_dtype, lhs_shape, rhs_shape, pe, + initializers +): + # Make dummy model for testing + model = mock_elementwise_binary_operation( # noqa: Duplicate test setup + op_type, lhs_dtype, rhs_dtype, out_dtype, lhs_shape, rhs_shape, pe + ) + # Prepare the execution context + context = { + "lhs": gen_finn_dt_tensor(DataType[lhs_dtype], lhs_shape), + "rhs": gen_finn_dt_tensor(DataType[rhs_dtype], rhs_shape) + } + + # Turn selected inputs into initializers + for name in initializers: + model.set_initializer(name, context[name]) + + # Get the numpy reference implementation for this operation + numpy_reference = NUMPY_REFERENCES[op_type] + + # Test running shape and data type inference on the model graph + model = model.transform(InferDataTypes()) + model = model.transform(InferShapes()) + + # Try to minimize the bit-widths of all data types involved + model = model.transform(MinimizeWeightBitWidth()) + model = model.transform(MinimizeAccumulatorWidth()) + + # Set model execution mode to python simulation + model = model.transform(SetExecMode("python")) + model = model.transform(GiveUniqueNodeNames()) + + # Compute ground-truth output in software + o_expected = numpy_reference( + # Note: Need to make sure these have the right type for the Numpy API + # Note: Assume all test cases fit into int64 without loss of precision + context["lhs"].astype(np.int64), + context["rhs"].astype(np.int64) + ) + # Execute the onnx model to collect the result + o_produced = execute_onnx(model, context)["out"] + + # Compare the expected to the produced for exact equality + assert np.all(o_produced == o_expected) + + +# Operator type to be tested +@pytest.mark.parametrize("op_type", [ # noqa: Duplicate test setup + # Test all Numpy references specified above, except for the bitwise + # operations, for which floating-point doe not make sense + *sorted((NUMPY_REFERENCES.keys() - BITWISE)), +]) +# Data type of the left-hand-side input elements +@pytest.mark.parametrize("lhs_dtype", ["FLOAT32"]) +# Data type of the right-hand-side input 
elements +@pytest.mark.parametrize("rhs_dtype", ["FLOAT32"]) +# Data type of the output elements +@pytest.mark.parametrize("out_dtype", ["FLOAT32"]) +# Shape of the left-hand-side input +@pytest.mark.parametrize("lhs_shape", [ + [3, 1, 7, 1], [1] +]) +# Shape of the right-hand-side input +@pytest.mark.parametrize("rhs_shape", [ + [3, 32, 1, 16], +]) +# Which inputs to set as initializers +@pytest.mark.parametrize("initializers", [ + [], ["lhs"], ["rhs"], ["lhs", "rhs"] +]) +# Number of elements to process in parallel +@pytest.mark.parametrize("pe", [1, 2, 4]) +def test_elementwise_binary_operation_float_python( + op_type, lhs_dtype, rhs_dtype, out_dtype, lhs_shape, rhs_shape, pe, + initializers +): + # Make dummy model for testing + model = mock_elementwise_binary_operation( # noqa: Duplicate test setup + op_type, lhs_dtype, rhs_dtype, out_dtype, lhs_shape, rhs_shape, pe + ) + # Prepare the execution context + context = { + "lhs": gen_finn_dt_tensor(DataType[lhs_dtype], lhs_shape), + "rhs": gen_finn_dt_tensor(DataType[rhs_dtype], rhs_shape) + } + + # Turn selected inputs into initializers + for name in initializers: + model.set_initializer(name, context[name]) + + # Get the numpy reference implementation for this operation + numpy_reference = NUMPY_REFERENCES[op_type] + + # Test running shape and data type inference on the model graph + model = model.transform(InferDataTypes()) + model = model.transform(InferShapes()) + + # Try to minimize the bit-widths of all data types involved + model = model.transform(MinimizeWeightBitWidth()) + model = model.transform(MinimizeAccumulatorWidth()) + + # Set model execution mode to python simulation + model = model.transform(SetExecMode("python")) + model = model.transform(GiveUniqueNodeNames()) + + # Compute ground-truth output in software + o_expected = numpy_reference(context["lhs"], context["rhs"]) + # Execute the onnx model to collect the result + o_produced = execute_onnx(model, context)["out"] + + # Compare the expected to the produced for exact equality + assert np.all(o_produced == o_expected) + + +# Operator type to be tested +@pytest.mark.parametrize("op_type", [ # noqa: Duplicate test setup + # Test all Numpy references specified above + *NUMPY_REFERENCES.keys(), +]) +# Data type of the left-hand-side input elements +@pytest.mark.parametrize("lhs_dtype", ["INT8"]) +# Data type of the right-hand-side input elements +@pytest.mark.parametrize("rhs_dtype", ["INT8"]) +# Data type of the output elements +@pytest.mark.parametrize("out_dtype", ["INT32"]) +# Shape of the left-hand-side input +@pytest.mark.parametrize("lhs_shape", [ + [3, 1, 7, 1], [1] +]) +# Shape of the right-hand-side input +@pytest.mark.parametrize("rhs_shape", [ + [3, 32, 1, 16], +]) +# Which inputs to set as initializers +@pytest.mark.parametrize("initializers", [ + [], ["lhs"], ["rhs"], ["lhs", "rhs"] +]) +# Number of elements to process in parallel +@pytest.mark.parametrize("pe", [1, 2, 4]) +# This is a slow running fpgadataflow type of test which requires vivado +@pytest.mark.fpgadataflow +@pytest.mark.slow +def test_elementwise_binary_operation_cppsim( + op_type, lhs_dtype, rhs_dtype, out_dtype, lhs_shape, rhs_shape, pe, + initializers +): + # Make dummy model for testing + model = mock_elementwise_binary_operation( # noqa: Duplicate test setup + op_type, lhs_dtype, rhs_dtype, out_dtype, lhs_shape, rhs_shape, pe + ) + # Prepare the execution context + context = { + "lhs": gen_finn_dt_tensor(DataType[lhs_dtype], lhs_shape), + "rhs": gen_finn_dt_tensor(DataType[rhs_dtype], 
rhs_shape) + } + + # Turn selected inputs into initializers + for name in initializers: + model.set_initializer(name, context[name]) + + # Get the numpy reference implementation for this operation + numpy_reference = NUMPY_REFERENCES[op_type] + + # Test running shape and data type inference on the model graph + model = model.transform(InferDataTypes()) + model = model.transform(InferShapes()) + # Specializes all nodes to be implemented as HLS backend + model = specialize_hls(model) + + # Try to minimize the bit-widths of all data types involved + model = model.transform(MinimizeWeightBitWidth()) + model = model.transform(MinimizeAccumulatorWidth()) + + # Set model execution mode to C++ simulation + model = model.transform(SetExecMode("cppsim")) + # Generates the C++ source and compiles the C++ simulation + model = model.transform(GiveUniqueNodeNames()) + model = model.transform(PrepareCppSim()) + model = model.transform(CompileCppSim()) + + # Compute ground-truth output in software + o_expected = numpy_reference( + # Note: Need to make sure these have the right type for the Numpy API + # Note: Assume all test cases fit into int64 without loss of precision + context["lhs"].astype(np.int64), + context["rhs"].astype(np.int64) + ) + # Execute the onnx model to collect the result + o_produced = execute_onnx(model, context)["out"] + + # Compare the expected to the produced for exact equality + assert np.all(o_produced == o_expected) + + +# Operator type to be tested +@pytest.mark.parametrize("op_type", [ # noqa: Duplicate test setup + # Test all Numpy references specified above, except for the bitwise + # operations, for which floating-point does not make sense + *sorted((NUMPY_REFERENCES.keys() - BITWISE)), +]) +# Data type of the left-hand-side input elements +@pytest.mark.parametrize("lhs_dtype", ["FLOAT32"]) +# Data type of the right-hand-side input elements +@pytest.mark.parametrize("rhs_dtype", ["FLOAT32"]) +# Data type of the output elements +@pytest.mark.parametrize("out_dtype", ["FLOAT32"]) +# Shape of the left-hand-side input +@pytest.mark.parametrize("lhs_shape", [ + [3, 1, 7, 1], [1] +]) +# Shape of the right-hand-side input +@pytest.mark.parametrize("rhs_shape", [ + [3, 32, 1, 16], +]) +# Which inputs to set as initializers +@pytest.mark.parametrize("initializers", [ + [], ["lhs"], ["rhs"], ["lhs", "rhs"] +]) +# Number of elements to process in parallel +@pytest.mark.parametrize("pe", [1, 2, 4]) +# This is a slow running fpgadataflow type of test which requires vivado +@pytest.mark.fpgadataflow +@pytest.mark.slow +def test_elementwise_binary_operation_float_cppsim( + op_type, lhs_dtype, rhs_dtype, out_dtype, lhs_shape, rhs_shape, pe, + initializers +): + # Make dummy model for testing + model = mock_elementwise_binary_operation( # noqa: Duplicate test setup + op_type, lhs_dtype, rhs_dtype, out_dtype, lhs_shape, rhs_shape, pe + ) + # Prepare the execution context + context = { + "lhs": gen_finn_dt_tensor(DataType[lhs_dtype], lhs_shape), + "rhs": gen_finn_dt_tensor(DataType[rhs_dtype], rhs_shape) + } + + # Turn selected inputs into initializers + for name in initializers: + model.set_initializer(name, context[name]) + + # Get the numpy reference implementation for this operation + numpy_reference = NUMPY_REFERENCES[op_type] + + # Test running shape and data type inference on the model graph + model = model.transform(InferDataTypes()) + model = model.transform(InferShapes()) + # Specializes all nodes to be implemented as HLS backend + model = specialize_hls(model) + + # Try to minimize
the bit-widths of all data types involved + model = model.transform(MinimizeWeightBitWidth()) + model = model.transform(MinimizeAccumulatorWidth()) + + # Set model execution mode to C++ simulation + model = model.transform(SetExecMode("cppsim")) + # Generates the C++ source and compiles the C++ simulation + model = model.transform(GiveUniqueNodeNames()) + model = model.transform(PrepareCppSim()) + model = model.transform(CompileCppSim()) + + # Compute ground-truth output in software + o_expected = numpy_reference(context["lhs"], context["rhs"]) + # Execute the onnx model to collect the result + o_produced = execute_onnx(model, context)["out"] + + # Compare the expected to the produced for exact equality + assert np.all(o_produced == o_expected) + + +# Operator type to be tested +@pytest.mark.parametrize("op_type", [ # noqa: Duplicate test setup + # Test all Numpy references specified above + *NUMPY_REFERENCES.keys() +]) +# Data type of the left-hand-side input elements +@pytest.mark.parametrize("lhs_dtype", ["INT8"]) +# Data type of the right-hand-side input elements +@pytest.mark.parametrize("rhs_dtype", ["INT8"]) +# Data type of the output elements +@pytest.mark.parametrize("out_dtype", ["INT32"]) +# Shape of the left-hand-side input +@pytest.mark.parametrize("lhs_shape", [ + [3, 1, 7, 1], [1] +]) +# Shape of the right-hand-side input +@pytest.mark.parametrize("rhs_shape", [ + [3, 32, 1, 16], +]) +# Which inputs to set as initializers +@pytest.mark.parametrize("initializers", [ + [], ["lhs"], ["rhs"], ["lhs", "rhs"] +]) +# Number of elements to process in parallel +@pytest.mark.parametrize("pe", [1, 2, 4]) +# This is a slow running fpgadataflow type of test which requires vivado +@pytest.mark.fpgadataflow +@pytest.mark.slow +def test_elementwise_binary_operation_rtlsim( + op_type, lhs_dtype, rhs_dtype, out_dtype, lhs_shape, rhs_shape, pe, + initializers +): + # Make dummy model for testing + model = mock_elementwise_binary_operation( # noqa: Duplicate test setup + op_type, lhs_dtype, rhs_dtype, out_dtype, lhs_shape, rhs_shape, pe + ) + # Prepare the execution context + context = { + "lhs": gen_finn_dt_tensor(DataType[lhs_dtype], lhs_shape), + "rhs": gen_finn_dt_tensor(DataType[rhs_dtype], rhs_shape) + } + + # Turn selected inputs into initializers + for name in initializers: + model.set_initializer(name, context[name]) + + # Get the numpy reference implementation for this operation + numpy_reference = NUMPY_REFERENCES[op_type] + + # Test running shape and data type inference on the model graph + model = model.transform(InferDataTypes()) + model = model.transform(InferShapes()) + # Specializes all nodes to be implemented as HLS backend + model = specialize_hls(model) + + # Try to minimize the bit-widths of all data types involved + model = model.transform(MinimizeWeightBitWidth()) + model = model.transform(MinimizeAccumulatorWidth()) + + # Set model execution mode to RTL simulation + model = model.transform(SetExecMode("rtlsim")) + # Generates the C++ source and compiles the RTL simulation + model = model.transform(GiveUniqueNodeNames()) + model = model.transform(PrepareIP("xczu7ev-ffvc1156-2-e", 10)) # noqa + model = model.transform(HLSSynthIP()) + model = model.transform(PrepareRTLSim()) + + # Compute ground-truth output in software + o_expected = numpy_reference( + # Note: Need to make sure these have the right type for the Numpy API + # Note: Assume all test cases fit into int64 without loss of precision + context["lhs"].astype(np.int64), + context["rhs"].astype(np.int64) + ) + # 
Execute the onnx model to collect the result + o_produced = execute_onnx(model, context)["out"] + + # Compare the expected to the produced for exact equality + assert np.all(o_produced == o_expected) + + +# TODO: No floating-point support in RTL simulation +# # Operator type to be tested +# @pytest.mark.parametrize("op_type", [ # noqa: Duplicate test setup +# # Test all Numpy references specified above, except for the bitwise +# # operations, for which floating-point does not make sense +# *sorted((NUMPY_REFERENCES.keys() - BITWISE)), +# ]) +# # Data type of the left-hand-side input elements +# @pytest.mark.parametrize("lhs_dtype", ["FLOAT32"]) +# # Data type of the right-hand-side input elements +# @pytest.mark.parametrize("rhs_dtype", ["FLOAT32"]) +# # Data type of the output elements +# @pytest.mark.parametrize("out_dtype", ["FLOAT32"]) +# # Shape of the left-hand-side input +# @pytest.mark.parametrize("lhs_shape", [ +# [3, 1, 7, 1], [1] +# ]) +# # Shape of the right-hand-side input +# @pytest.mark.parametrize("rhs_shape", [ +# [3, 32, 1, 16], +# ]) +# # Which inputs to set as initializers +# @pytest.mark.parametrize("initializers", [ +# [], ["lhs"], ["rhs"], ["lhs", "rhs"] +# ]) +# # Number of elements to process in parallel +# @pytest.mark.parametrize("pe", [1, 2, 4]) +# # This is a slow running fpgadataflow type of test which requires vivado +# @pytest.mark.fpgadataflow +# @pytest.mark.slow +# def test_elementwise_binary_operation_float_rtlsim( +# op_type, lhs_dtype, rhs_dtype, out_dtype, lhs_shape, rhs_shape, pe, +# initializers +# ): +# # Make dummy model for testing +# model = mock_elementwise_binary_operation( # noqa: Duplicate test setup +# op_type, lhs_dtype, rhs_dtype, out_dtype, lhs_shape, rhs_shape, pe +# ) +# # Prepare the execution context +# context = { +# "lhs": gen_finn_dt_tensor(DataType[lhs_dtype], lhs_shape), +# "rhs": gen_finn_dt_tensor(DataType[rhs_dtype], rhs_shape) +# } +# +# # Turn selected inputs into initializers +# for name in initializers: +# model.set_initializer(name, context[name]) +# +# # Get the numpy reference implementation for this operation +# numpy_reference = NUMPY_REFERENCES[op_type] +# +# # Test running shape and data type inference on the model graph +# model = model.transform(InferDataTypes()) +# model = model.transform(InferShapes()) +# # Specializes all nodes to be implemented as HLS backend +# model = specialize_hls(model) +# +# # Try to minimize the bit-widths of all data types involved +# model = model.transform(MinimizeWeightBitWidth()) +# model = model.transform(MinimizeAccumulatorWidth()) +# +# # Set model execution mode to RTL simulation +# model = model.transform(SetExecMode("rtlsim")) +# # Generates the C++ source and compiles the RTL simulation +# model = model.transform(GiveUniqueNodeNames()) +# model = model.transform(PrepareIP("xczu7ev-ffvc1156-2-e", 10)) # noqa +# model = model.transform(HLSSynthIP()) +# model = model.transform(PrepareRTLSim()) +# +# # Compute ground-truth output in software +# o_expected = numpy_reference(context["lhs"], context["rhs"]) +# # Execute the onnx model to collect the result +# o_produced = execute_onnx(model, context)["out"] +# +# # Compare the expected to the produced for exact equality +# assert np.all(o_produced == o_expected) + + +# Test-case setting up a complete dummy model containing various elementwise +# binary operations in PyTorch, converting to QONNX and verifying in Python, C++ +# and RTL simulation +# Shape of the left-hand-side input +# Note: Stripped down test of broadcasting
semantics due to rather poor support +# for arbitrary data layouts in QONNX and FINN: Only 2d and 4d layouts (with +# certain assumptions/restrictions) are really supported. +# Note: Cannot test scalar shapes (or effectively scalar shapes like [1,1]), due +# to streamlining integrating those into MultiThresholds (removing the operator +# to be tested), leading to consecutive quantizers. Consecutive quantizers +# should be avoided as this sometimes can cause range and precision errors. +@pytest.mark.parametrize("lhs_shape", [[32, 1]]) +# Shape of the right-hand-side input +@pytest.mark.parametrize("rhs_shape", [[32, 16]]) +# Which inputs to set as initializers +@pytest.mark.parametrize("initializers", [[], ["lhs"], ["rhs"]]) +# Number of elements to process in parallel +@pytest.mark.parametrize("pe", [1, 2, 4]) +# This is a slow running fpgadataflow type of test which requires vivado +@pytest.mark.fpgadataflow +@pytest.mark.slow +def test_elementwise_binary_operation_integration_elementwise_add( + lhs_shape, rhs_shape, initializers, pe +): + # PyTorch model wrapping the component(s) to be tested + class Dummy(torch.nn.Module): + # Sets up the test model and initializes parameters + def __init__(self): + # Initialize the PyTorch Module superclass + super().__init__() + # Elementwise addition component to be tested + self.add = QuantEltwiseAdd() + # Left- and right-hand-side input tensors in case these are set to + # be initializers + self.lhs = torch.randn(*lhs_shape) + self.rhs = torch.randn(*rhs_shape) + + # Model forward pass taking multiple inputs as arguments + def forward(self, *xs): + # Depending on the test configuration, extract inputs to the add + # operation from model inputs or from model parameters + _lhs = self.lhs if "lhs" in initializers else xs[0] + _rhs = self.rhs if "rhs" in initializers else xs[1] + # Quantized elementwise addition of the two inputs + return self.add(_lhs, _rhs) + + # Create the test instance of the dummy model + model = Dummy() + # Create dummy test inputs + lhs = torch.randn(*lhs_shape) + rhs = torch.randn(*rhs_shape) + # Do a forward pass with model in training mode to calibrate the quantizers + _ = model(lhs, rhs) + # Switch model to evaluation mode to keep parameters fixed for export + model = model.eval() + # Do not accumulate gradients while generating test output + with torch.no_grad(): + # Model forward pass generating the expected output for verification + out_expected = model(lhs, rhs).numpy().astype(np.float32) + # Generate a temporary directory for running this test + with tempfile.TemporaryDirectory() as tmp: + # Export the model to ONNX format to be consumed by FINN + export_qonnx(model, (lhs, rhs), tmp + "/model.onnx") + # Wrap the model with QONNX wrapper for transformations + model = ModelWrapper(tmp + "/model.onnx") + # Cleanup transformations preparing the model to be consumed by FINN + model = model.transform(InferDataTypes()) + model = model.transform(InferShapes()) + model = model.transform(InferDataLayouts()) + model = model.transform(ConvertQONNXtoFINN()) + model = model.transform(GiveUniqueNodeNames()) + model = model.transform(GiveUniqueParameterTensors()) + model = model.transform(GiveReadableTensorNames()) + model = model.transform(RemoveUnusedTensors()) + # Need to absorb scalar multiplication into the thresholding layer + # first, to prevent large rounding error due to moving these in front of + # add operations later.
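+ # Note: AbsorbMulIntoMultiThreshold is assumed to fold such a preceding + # (scalar or channel-wise) Mul directly into the threshold values instead + # of keeping it as a standalone floating-point node in the graph.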
+ model = model.transform(AbsorbMulIntoMultiThreshold()) + # Need to absorb the sign bias of the quantizer back into the + # corresponding thresholds first instead of moving them past the next + # operator to avoid sign and range issues. + model = model.transform(AbsorbSignBiasIntoMultiThreshold()) + # There might be identical Mul nodes in front of the joining Add node + model = model.transform(MoveLinearPastEltwiseAdd()) + model = model.transform(AbsorbMulIntoMultiThreshold()) + # Do a single round of standard streamlining of the model graph + model = model.transform(Streamline()) + # Convert layers to hardware custom operations + model = model.transform(InferThresholdingLayer()) + model = model.transform(InferElementwiseBinaryOperation( + # We want to keep the output de-quantization off-chip + _filter=InferElementwiseBinaryOperation.reject_floats + )) + + # Apply folding config to set the PE parallelism for hardware layers + model = model.transform(ApplyConfig({ + "Defaults": {"PE": [pe, ["ElementwiseAdd", "Thresholding"]]} + })) + + # Try to minimize the bit-widths of all data types involved + model = model.transform(MinimizeWeightBitWidth()) + model = model.transform(MinimizeAccumulatorWidth()) + + # Prepare the execution context with dummy data from above and input + # node names extracted from the transformed model graph + context = {} + + # Convert verification inputs to numpy format used by ONNX execution + lhs = lhs.numpy().astype(np.float32) + rhs = rhs.numpy().astype(np.float32) + + # If the left-hand-side is not an initializer, it must be an input + # inserted into the execution context + if "lhs" not in initializers: + # Left-hand-side is always the first input + context[model.graph.input[0].name] = lhs + + # If the right-hand-side is not an initializer, it must be an input + # inserted into the execution context + if "rhs" not in initializers: + # Index of the right-hand-side input depends on whether there is a + # left-hand-side input + rhs_index = int("lhs" not in initializers) + context[model.graph.input[rhs_index].name] = rhs + + # Set model execution mode to python simulation + model = model.transform(SetExecMode("python")) + model = model.transform(GiveUniqueNodeNames()) + # Execute the onnx model to collect the result + out_produced = execute_onnx(model, context)[model.graph.output[0].name] + # Compare the expected to the produced + # Note: Only test for closeness up to some tolerance as the model has been + # streamlined, which may involve rounding + assert np.allclose(out_produced, out_expected, atol=1e-3), \ + "Python simulation verification failed" + + # Apply folding config to implement Thresholding layers in RTL mode + # Note: Must be done in RTL for now to avoid test failing due to + # PE-parallel stream being too wide for Vitis HLS.
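+ # Note: The "Defaults" entry of ApplyConfig is assumed to follow the + # pattern {attribute: [value, [op types]]}, i.e. the config below sets + # preferred_impl_style = "rtl" for all Thresholding nodes, analogous to + # the PE default applied above.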
+ model = model.transform(ApplyConfig({ + "Defaults": {"preferred_impl_style": ["rtl", ["Thresholding"]]} + })) + # Specializes all nodes to their backend implementation + model = model.transform(SpecializeLayers("xczu7ev-ffvc1156-2-e")) + + # Set model execution mode to C++ simulation + model = model.transform(SetExecMode("cppsim")) + model = model.transform(GiveUniqueNodeNames()) + # Generates the C++ source and compiles the C++ simulation + model = model.transform(PrepareCppSim()) + model = model.transform(CompileCppSim()) + # Execute the onnx model to collect the result + out_produced = execute_onnx(model, context)[model.graph.output[0].name] + # Compare the expected to the produced + # Note: Only test for closeness up to some tolerance as the model has been + # streamlined, which may involve rounding + assert np.allclose(out_produced, out_expected, atol=1e-3), \ + "C++ simulation verification failed" + + # Set model execution mode to RTL simulation + model = model.transform(SetExecMode("rtlsim")) + model = model.transform(GiveUniqueNodeNames()) + # Generates the C++ source and compiles the RTL simulation + model = model.transform(PrepareIP("xczu7ev-ffvc1156-2-e", 10)) # noqa + model = model.transform(HLSSynthIP()) + model = model.transform(PrepareRTLSim()) + # Execute the onnx model to collect the result + out_produced = execute_onnx(model, context)[model.graph.output[0].name] + # Compare the expected to the produced + # Note: Only test for closeness up to some tolerance as the model has been + # streamlined, which may involve rounding + assert np.allclose(out_produced, out_expected, atol=1e-3), \ + "RTL simulation verification failed" diff --git a/tests/transformation/streamline/test_round_thresholds.py b/tests/transformation/streamline/test_round_thresholds.py index 85c60b37d5..63375598a0 100644 --- a/tests/transformation/streamline/test_round_thresholds.py +++ b/tests/transformation/streamline/test_round_thresholds.py @@ -26,45 +26,242 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# fmt: off +# Disable formatter. This is deliberately formatted to stay within 80 characters +# per line. Black, however, formats some lines going beyond this. + +# Testing framework import pytest +# Use numpy for python execution / computing the ground truth expected values import numpy as np + +# Utility types and function for creating onnx nodes and graphs from onnx import TensorProto, helper + +# QONNX data types like INT25 from qonnx.core.datatype import DataType + +# QONNX wrapper of ONNX model graphs from qonnx.core.modelwrapper import ModelWrapper -from qonnx.util.basic import qonnx_make_model +# Generate random tensors of QONNX/FINN data types for testing +from qonnx.util.basic import gen_finn_dt_tensor + +# Execution of onnx graphs within FINN import finn.core.onnx_exec as oxe + +# The transformation to be tested from finn.transformation.streamline import RoundAndClipThresholds -@pytest.mark.streamline -def test_round_thresholds(): - v = helper.make_tensor_value_info("v", TensorProto.FLOAT, [1, 4]) - thresholds = helper.make_tensor_value_info("thresholds", TensorProto.FLOAT, [4, 1]) - out = helper.make_tensor_value_info("out", TensorProto.FLOAT, [1, 4]) - node_def = helper.make_node( - "MultiThreshold", ["v", "thresholds"], ["out"], domain="qonnx.custom_op.general" +# Tests the RoundAndClipThresholds transformation under various input, output +# data type combinations with purely integer inputs. Without proper rounding,
Without proper rounding, +# this tests only the clipping, range and type-casting behavior of the +# transformation. +@pytest.mark.parametrize("i_dtype", [ + # Explanation for selecting these test configurations: + # 1. Below 24-bit thresholds we will not observe any interesting rounding + # behavior, as all integers < 2^24 can be exactly represented in 32-bit + # floating-point. Thus, we test thresholds at 25-bit signed integers and + # generate test inputs slightly above and below this. + # 2. We want to test out-of-range clipping of thresholds, in particular + # clipping of the negative portion of signed thresholds. Thus, we only + # generate signed thresholds, but test with signed and unsigned + # inputs of smaller, larger and equal range. + # 3. Testing proper floating-point thresholds requires a separate test-case + "INT23", "UINT23", "INT24", "UINT24", "INT25", "UINT25", "INT26", "UINT26" +]) +@pytest.mark.parametrize("o_dtype", [ + # Explanation for selecting these test configurations: + # 1. Outputs of MultiThreshold are typically much smaller bit-width than the + # inputs and thresholds. + # 2. However, with randomly samples thresholds from a rather large range due + # to the selected input bit-widths (see above), we risk not adequately + # covering the input range if we sample too few thresholds. The number of + # thresholds sampled depends on the bit-width of the output, thus we use + # rather high bit-width for testing. + # 3. For a "real" model, the quantization procedure *should* take care of + # adequately covering the true input range. + "INT8", "UINT8" +]) +@pytest.mark.parametrize("n_elems", [ + # Explanation for selecting these test configurations: + # 1. Small edge cases and quickly running through tests: 1, 2, 3, 4 + # 2. Large test case 256, hopefully amplifying any rarely occurring errors + 1, 2, 3, 4, 256 +]) +def test_round_and_clip_thresholds_ints(i_dtype, o_dtype, n_elems): + # Convert string representation of data type to onnx DataType + i_dtype = DataType[i_dtype] + t_dtype = DataType["INT25"] # Note: Matches configuration above + o_dtype = DataType[o_dtype] # noqa: Duplicate model setup code + # Create a dummy MultiThreshold operation to be tested + node = helper.make_node( + # Op-Type of the node + "MultiThreshold", + # MultiThreshold is implemented under the qonnx domain + domain="qonnx.custom_op.general", + # List the names of the input tensors + inputs=["inp", "thresholds"], + # List the names of the output tensors + outputs=["out"], + # The CustomOp needs to know the data type of the output to be produced + out_dtype=str(o_dtype) + ) + # Number of threshold values required to produce outputs of type o_dtype + n_thresholds = o_dtype.get_num_possible_values() - 1 + # Create tensor value infos for all input/output tensors involved + inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, n_elems]) + out = helper.make_tensor_value_info("out", TensorProto.FLOAT, [1, n_elems]) + # Create a tensor value info for the thresholds parameter tensor + # Note: Number of thresholds is determined by the output data type + thresholds = helper.make_tensor_value_info( + "thresholds", TensorProto.FLOAT, [n_elems, n_thresholds] + ) + # Combine node and tensor value infos into an onnx graph + graph = helper.make_graph([node], "thresholds", [inp, thresholds], [out]) + # Wrap the model graph in a ModelWrapper container + model = ModelWrapper(helper.make_model(graph)) + # Sample random tensors of the configured input data type + inp = gen_finn_dt_tensor(i_dtype, [1, 
n_elems]) + # Generate sorted thresholds for each of the input channels + thresholds = np.sort(gen_finn_dt_tensor(t_dtype, [n_elems, n_thresholds])) + # Set data type annotations for the input and thresholds tensor + model.set_tensor_datatype("inp", i_dtype) # noqa: Duplicate model execution + model.set_tensor_datatype("thresholds", t_dtype) + model.set_tensor_datatype("out", o_dtype) + # Set the thresholds as initializer input to the model + model.set_initializer("thresholds", thresholds) + # Execute the model before running the RoundAndClipThresholds transformation + out_expected = oxe.execute_onnx(model, {"inp": inp})["out"] + # Before rounding the threshold data type must be as annotated + assert model.get_tensor_datatype("thresholds") == t_dtype + # Run the transformation to be tested + model = model.transform(RoundAndClipThresholds()) + # After this transformation, the thresholds and output data type should be + # inferred correctly + assert model.get_tensor_datatype("thresholds") == i_dtype + assert model.get_tensor_datatype("out") == o_dtype + # After this transformation, the container type used to store the thresholds + # values must be float32. No other type-cast or type promotion may happen. + assert model.get_initializer("thresholds").dtype == np.float32 + # After rounding, all thresholds must be integers represented as float32 + assert all( + x.is_integer() for x in model.get_initializer("thresholds").flatten() + ) + # Execute the model after running the RoundAndClipThresholds transformation + out_produced = oxe.execute_onnx(model, {"inp": inp})["out"] + # Compare the results before and after: This is the pure integer test-case + # and no actual rounding should happen, thus the rounded operation should + # produce outputs exactly equal. + assert np.all(out_produced == out_expected) + + +# Tests the RoundAndClipThresholds transformation under various input, output +# data type combinations with purely integer inputs. In contrast to the test +# above, this test case covers actual rounding of floating-point thresholds. +@pytest.mark.parametrize("i_dtype", [ + # Explanation for selecting these test configurations: + # 1. Below 24-bit thresholds we will not observe any interesting rounding + # behavior, as all integers < 2^24 can be exactly represented in 32-bit + # floating-point. Thus, we test thresholds at 25-bit signed integers and + # generate test inputs slightly above and below this. + # 2. We want to test out-of-range clipping of thresholds, in particular + # clipping of the negative portion of signed thresholds. Thus, we only + # generate signed thresholds, but test with signed and unsigned + # inputs of smaller, larger and equal range. + # 3. Testing proper floating-point thresholds requires a separate test-case + "INT23", "UINT23", "INT24", "UINT24", "INT25", "UINT25", "INT26", "UINT26" +]) +@pytest.mark.parametrize("o_dtype", [ + # Explanation for selecting these test configurations: + # 1. Outputs of MultiThreshold are typically much smaller bit-width than the + # inputs and thresholds. + # 2. However, with randomly sampled thresholds from a rather large range due + # to the selected input bit-widths (see above), we risk not adequately + # covering the input range if we sample too few thresholds. The number of + # thresholds sampled depends on the bit-width of the output, thus we use + # rather high bit-width for testing. + # 3. For a "real" model, the quantization procedure *should* take care of + # adequately covering the true input range.
+ "INT8", "UINT8" +]) +@pytest.mark.parametrize("n_elems", [ + # Explanation for selecting these test configurations: + # 1. Small edge cases and quickly running through tests: 1, 2, 3, 4 + # 2. Large test case 256, hopefully amplifying any rarely occurring errors + 1, 2, 3, 4, 256 +]) +def test_round_and_clip_thresholds_floats(i_dtype, o_dtype, n_elems): + # Convert string representation of data type to onnx DataType + i_dtype = DataType[i_dtype] + t_dtype = DataType["FLOAT32"] + o_dtype = DataType[o_dtype] # noqa: Duplicate model setup code + # Create a dummy MultiThreshold operation to be tested + node = helper.make_node( + # Op-Type of the node + "MultiThreshold", + # MultiThreshold is implemented under the qonnx domain + domain="qonnx.custom_op.general", + # List the names of the input tensors + inputs=["inp", "thresholds"], + # List the names of the output tensors + outputs=["out"], + # The CustomOp needs to know the data type of the output to be produced + out_dtype=str(o_dtype) + ) + # Number of threshold values required to produce outputs of type o_dtype + n_thresholds = o_dtype.get_num_possible_values() - 1 + # Create tensor value infos for all input/output tensors involved + inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, n_elems]) + out = helper.make_tensor_value_info("out", TensorProto.FLOAT, [1, n_elems]) + # Create a tensor value info for the thresholds parameter tensor + # Note: Number of thresholds is determined by the output data type + thresholds = helper.make_tensor_value_info( + "thresholds", TensorProto.FLOAT, [n_elems, n_thresholds] + ) + # Combine node and tensor value infos into an onnx graph + graph = helper.make_graph([node], "thresholds", [inp, thresholds], [out]) + # Wrap the model graph in a ModelWrapper container + model = ModelWrapper(helper.make_model(graph)) + # Sample random tensors of the configured input data type + inp = gen_finn_dt_tensor(i_dtype, [1, n_elems]) + # Draw uniformly random prototype thresholds in [0,+1] range + thresholds = np.random.rand(n_elems, n_thresholds) + # Type alias to 25-bit signed integer type used to set the range of the + # thresholds + INT25 = DataType["INT25"] # noqa: Variable name not lowercase + # Map the prototype thresholds into the test integer range and sort + thresholds = np.sort((INT25.max() - INT25.min()) * thresholds + INT25.min()) + # Set data type annotations for the input and thresholds tensor + model.set_tensor_datatype("inp", i_dtype) # noqa: Duplicate model execution + model.set_tensor_datatype("thresholds", t_dtype) + model.set_tensor_datatype("out", o_dtype) + # Set the thresholds as initializer input to the model + model.set_initializer("thresholds", thresholds) + # Execute the model before running the RoundAndClipThresholds transformation + out_expected = oxe.execute_onnx(model, {"inp": inp})["out"] + # Before rounding the threshold data type must be as annotated + assert model.get_tensor_datatype("thresholds") == t_dtype + # Run the transformation to be tested + model = model.transform(RoundAndClipThresholds()) + # After this transformation, the thresholds and output data type should be + # inferred correctly + assert model.get_tensor_datatype("thresholds") == i_dtype + assert model.get_tensor_datatype("out") == o_dtype + # After this transformation, the container type used to store the thresholds + # values must be float32. No other type-cast or type promotion may happen. 
+ assert model.get_initializer("thresholds").dtype == np.float32 + # After rounding, all thresholds must be integers represented as float32 + assert all( + x.is_integer() for x in model.get_initializer("thresholds").flatten() ) - graph_def = helper.make_graph([node_def], "test_model", [v, thresholds], [out]) - model_def = qonnx_make_model(graph_def) - model = ModelWrapper(model_def) - threshold_val = np.asarray([[-1.1], [0.7], [2.3], [5.1]], dtype=np.float32) - model.set_initializer("thresholds", threshold_val) - model.set_tensor_datatype("v", DataType["INT8"]) - inp_dict_f = {"v": np.floor(threshold_val).T} - inp_dict_n = {"v": np.round(threshold_val).T} - inp_dict_c = {"v": np.ceil(threshold_val).T} - orig_f = oxe.execute_onnx(model, inp_dict_f)["out"] - orig_n = oxe.execute_onnx(model, inp_dict_n)["out"] - orig_c = oxe.execute_onnx(model, inp_dict_c)["out"] - assert model.get_tensor_datatype("thresholds") == DataType["FLOAT32"] - new_model = model.transform(RoundAndClipThresholds()) - # rounded up thresholds should have same dtype as input - assert new_model.get_tensor_datatype("thresholds") == DataType["INT8"] - new_f = oxe.execute_onnx(new_model, inp_dict_f)["out"] - new_n = oxe.execute_onnx(new_model, inp_dict_n)["out"] - new_c = oxe.execute_onnx(new_model, inp_dict_c)["out"] - assert np.isclose(orig_f, new_f, atol=1e-3).all() - assert np.isclose(orig_n, new_n, atol=1e-3).all() - assert np.isclose(orig_c, new_c, atol=1e-3).all() + # Execute the model after running the RoundAndClipThresholds transformation + out_produced = oxe.execute_onnx(model, {"inp": inp})["out"] + # Compare the results before and after: This is the floating-point test with + # actual rounding, thus the transformed result may only be equal within some + # tolerance. + # Note: In practice this tolerance has not been needed; for all test + # configurations, exact equality seems to hold, probably due to only integer + # inputs being tested. + assert np.allclose(out_produced, out_expected, atol=1.0e-3)