From 64d95ac94f2c3aeb1de55a82e6793d151a52f9de Mon Sep 17 00:00:00 2001 From: shashwat1198 Date: Sun, 22 Oct 2023 10:12:04 +0100 Subject: [PATCH 01/28] QuantLSTM ONNX representation --- notebooks/4_quant_lstm.ipynb | 2934 ++++++++++++++++++++++++++++++++++ 1 file changed, 2934 insertions(+) create mode 100644 notebooks/4_quant_lstm.ipynb diff --git a/notebooks/4_quant_lstm.ipynb b/notebooks/4_quant_lstm.ipynb new file mode 100644 index 00000000..72cac7e9 --- /dev/null +++ b/notebooks/4_quant_lstm.ipynb @@ -0,0 +1,2934 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "5ef5f772-f48a-4bb1-bb68-4e8e9236fd2e", + "metadata": {}, + "source": [ + "# QuantLSTM - ONNX (QCDQ) representation" + ] + }, + { + "cell_type": "markdown", + "id": "e5a747f9-fd74-4ebc-8d74-17bf06ff2d48", + "metadata": {}, + "source": [ + "This notebook is divided into `five` parts:\n", + "\n", + "
Part 1 : Introduction to LSTMs.\n", + "
\n", + "
Part 2 : Model creation with brevitas QuantLSTM layer. \n", + "
\n", + "
Part 3 : Build an ONNX model representing the LSTM computation used to process a single input with `QCDQ quantization` (weights/inputs/activations).\n",
    "
\n", + "
Part 4 : Integration of the QCDQ-LSTM graph with the `SCAN` operator. This operator enables the cyclic computation (required for state updates in recurrent neural networks) that plain ONNX graphs, which must be acyclic, cannot directly express.\n",
    "
\n", + "
Part 5 : Functional verification of the `QCDQ-LSTM` model against the brevitas `QuantLSTM` model output."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "69ae7154-8cf3-4ee7-88c3-3bec0550008a",
   "metadata": {},
   "source": [
    "# Introduction to LSTMs"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "e7a903ef-1680-4a20-8c61-267884b76c96",
   "metadata": {},
   "source": [
    "`LSTMs (Long Short-Term Memory)` are sequential neural networks that are capable of learning long-term dependencies, especially in sequence prediction problems. They are deployed in machine translation, speech recognition and image captioning, and are especially useful for time-series analysis applications.\n",
    "\n",
    "LSTMs have `feedback connections`, unlike conventional feed-forward neural networks (where the compute path goes only in the forward direction). This makes them capable of processing time-series data like video streams or analyzing network traffic patterns.\n",
    "Such feedback connections, however, also make their hardware implementations complicated, as they require state updates, unlike feed-forward neural networks.\n",
    "
\n", + "
\n",
    "The LSTM computation requires the following six equations:\n",
    "$$\n",
    " f_t = \\sigma (W_f * x_t + U_f * H_{t-1} + b_f) \n",
    "$$\n",
    "$$\n",
    " i_t = \\sigma (W_i * x_t + U_i * H_{t-1} + b_i)\n",
    "$$\n",
    "$$\n",
    " \\tilde{C_t} = \\tanh(W_c * x_t + U_c * H_{t-1} + b_c)\n",
    "$$\n",
    "$$\n",
    " o_t = \\sigma (W_o * x_t + U_o * H_{t-1} + b_o)\n",
    "$$\n",
    "$$\n",
    " C_t = f_t \\odot C_{t-1} + i_t \\odot \\tilde{C_t}\n",
    "$$\n",
    "$$\n",
    " H_t = \\tanh(C_t) \\odot o_t \n",
    "$$\n",
    "\n",
    "The first four equations represent the `gate computations`.\n",
    "We compute the `cell state` and the `hidden state` in the last two equations, respectively.\n",
    "These two states are then fed back into the LSTM cell for the computation of the next input."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "70d052c8-e5cd-4eb1-89e5-f8ae956cb853",
   "metadata": {},
   "source": [
    "# QuantLSTM model creation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6a64be7c",
   "metadata": {},
   "source": [
    "In the 2nd part of the notebook, we will create a single-layer `QuantLSTM` model in brevitas and evaluate it on a given set of inputs. We then export this model to `QONNX` so that the same parameters (weights/biases/scales) can be extracted and used in the `QCDQ-LSTM` implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "84d66548-365d-46a5-9eaa-bb767085f9aa",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'\n"
     ]
    }
   ],
   "source": [
    "# We import the required libraries to execute different functions in the notebook.\n",
    "# The first four imports are required to build the QuantLSTM model in brevitas. 
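\n",
    "\n",
    "# --- Aside: a minimal, unquantized NumPy sketch of the six LSTM equations above. ---\n",
    "# This is an illustrative sketch only: the names W, U, b, x_t, h_prev, c_prev are\n",
    "# assumptions for the example, not parameters taken from the brevitas model. It simply\n",
    "# mirrors the gate/state equations one-to-one.\n",
    "import numpy as np\n",
    "\n",
    "def sigmoid(z):\n",
    "    return 1.0 / (1.0 + np.exp(-z))\n",
    "\n",
    "def lstm_step(x_t, h_prev, c_prev, W, U, b):\n",
    "    # W, U, b are dicts keyed by gate: 'f' (forget), 'i' (input), 'c' (cell), 'o' (output)\n",
    "    f_t = sigmoid(W['f'] @ x_t + U['f'] @ h_prev + b['f'])      # forget gate\n",
    "    i_t = sigmoid(W['i'] @ x_t + U['i'] @ h_prev + b['i'])      # input gate\n",
    "    c_tilde = np.tanh(W['c'] @ x_t + U['c'] @ h_prev + b['c'])  # candidate cell state\n",
    "    o_t = sigmoid(W['o'] @ x_t + U['o'] @ h_prev + b['o'])      # output gate\n",
    "    c_t = f_t * c_prev + i_t * c_tilde                          # new cell state\n",
    "    h_t = np.tanh(c_t) * o_t                                    # new hidden state\n",
    "    return h_t, c_t\n",
    "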
\n", + "# The model created will then be exported to QONNX and it's parameters used in the QCDQ implementation.\n", + "\n", + "import torch\n", + "from torch import nn\n", + "from brevitas.nn import QuantLSTM\n", + "from brevitas.export import export_onnx_qcdq\n", + "\n", + "#We need the onnx and onnx helper nodes to build the onnx graph for the LSTM compute.\n", + "import onnx\n", + "from onnx import numpy_helper\n", + "from onnx.helper import make_tensor_value_info, make_node, make_graph, make_model, make_tensor\n", + "#onnxruntime will be used to execute our onnx model.\n", + "import onnxruntime as rt \n", + "from qonnx.util.basic import qonnx_make_model\n", + "#numpy allows us to manipulate outputs from the brevitas and the ONNX model\n", + "import numpy as np \n", + "# Netron visualization tool will help us view interactable graphs\n", + "import netron" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "23a7682c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "quant_input_supplied to brevitas = tensor([[-1.0000, -0.5000, -1.0000, 0.5156, -1.0000, 0.9922, -0.8047, -1.0000,\n", + " 0.2188, 0.9922]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[-0.7266, -0.9531, 0.9922, 0.9922, -1.0000, 0.9922, -0.7734, -1.0000,\n", + " -0.0859, 0.6250]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[-0.6719, -1.0000, 0.0547, -0.5234, -0.0000, 0.1250, -1.0000, 0.3047,\n", + " -0.0312, -1.0000]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[-1.0000, -0.1797, 0.3516, -0.1328, -1.0000, -1.0000, 0.8750, -0.2812,\n", + " 0.4844, -0.3203]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[ 0.6719, -0.1484, 0.5078, 0.5312, -0.2969, 0.1719, -1.0000, 0.4688,\n", + " -0.2500, 0.8672]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[ 0.3125, 0.9922, 0.8281, -0.4297, -1.0000, 0.9922, -1.0000, 0.9922,\n", + " -1.0000, 0.2578]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[-0.3125, -1.0000, -0.4688, 0.2656, -1.0000, -1.0000, -1.0000, -0.7266,\n", + " 0.9922, 0.8984]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[-0.5625, 0.8359, -1.0000, 0.1875, -1.0000, -1.0000, 0.1562, 0.3438,\n", + " 0.6172, -1.0000]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[-1.0000, -0.0781, 0.3203, 0.1797, -1.0000, -0.1875, 0.9219, -0.4609,\n", + " -0.3125, 0.2031]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[ 0.8750, -1.0000, 0.6016, -1.0000, -0.7656, -0.1484, 0.9922, 0.6406,\n", + " -1.0000, 0.9922]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[-0.9922, -1.0000, 0.5078, -1.0000, -1.0000, 0.4453, -1.0000, 0.6719,\n", + " -1.0000, -1.0000]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[ 0.0703, -1.0000, -0.6797, -1.0000, -1.0000, -0.8750, -0.6797, 0.3672,\n", + " -0.5938, -0.2031]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[-0.6641, 0.9922, 0.1641, 0.9922, 0.9922, -1.0000, -1.0000, 0.9922,\n", + " 0.3438, 0.4688]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[-0.1875, 0.0000, -0.2812, -1.0000, -1.0000, -0.0391, 0.0781, 0.9922,\n", + " 
-0.2188, 0.9922]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[ 0.2578, 0.9922, -1.0000, 0.4297, -0.7500, 0.2891, -1.0000, -1.0000,\n", + " 0.6484, 0.3828]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[ 0.3594, -0.0000, -1.0000, 0.4688, -0.2734, -1.0000, -0.2969, 0.9922,\n", + " 0.9922, 0.9062]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[-0.0938, -1.0000, 0.1016, -0.7109, -0.3203, 0.7578, 0.9922, 0.3359,\n", + " 0.1328, 0.4062]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[ 0.4141, -0.6328, -0.7422, 0.9609, -0.9062, -0.4297, 0.7031, 0.9922,\n", + " -1.0000, -0.7969]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[ 0.3203, -1.0000, -0.7109, 0.3281, 0.6016, -0.2031, -0.6172, 0.7031,\n", + " -0.5078, -1.0000]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[-1.0000, -0.2500, -0.9766, -1.0000, 0.3984, -0.6484, -1.0000, 0.7188,\n", + " 0.9922, 0.9453]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[-0.5234, 0.9922, -0.3984, 0.1328, -0.0625, -0.8047, -0.1562, -0.1250,\n", + " -0.1172, 0.6328]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[ 0.0547, 0.0156, 0.0703, -0.8750, -1.0000, 0.5156, -0.0938, -0.2969,\n", + " -0.9922, 0.9922]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[ 0.9922, -1.0000, 0.3438, 0.9922, 0.1328, 0.2891, 0.0469, -0.3438,\n", + " -0.9531, -1.0000]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[ 0.2969, -1.0000, 0.1250, -1.0000, -0.5469, -1.0000, 0.5000, 0.7344,\n", + " -1.0000, 0.7109]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[ 0.4219, 0.4922, 0.7266, 0.0078, 0.0469, 0.9844, -0.5391, -0.0781,\n", + " 0.9922, -1.0000]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", + " 0.7969]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", + " 0.7969]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", + " 0.7969]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", + " 0.7969]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", + " 0.7969]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", + " 0.7969]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", + " 0.7969]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", + " 0.7969]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 
0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", + " 0.7969]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", + " 0.7969]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", + " 0.7969]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", + " 0.7969]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", + " 0.7969]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", + " 0.7969]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", + " 0.7969]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", + " 0.7969]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", + " 0.7969]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", + " 0.7969]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", + " 0.7969]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", + " 0.7969]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", + " 0.7969]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", + " 0.7969]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", + " 0.7969]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", + " 0.7969]])\n", + "----------------------------\n", + "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", + " 0.7969]])\n", + "----------------------------\n", + "[[[ 0.1484375 -0.0078125 0.0390625 0.140625 0.0078125 0.\n", + " 0.109375 -0.09375 0.0390625 -0.0625 0.015625 -0.1171875\n", + " 0.1015625 0.03125 0.1640625 -0.015625 -0.0234375 -0.015625\n", + " -0.046875 0.0078125]]\n", + "\n", + " [[ 0.2109375 -0.0234375 0.03125 0.2109375 0.0234375 -0.015625\n", + " 0.1875 -0.1484375 0.046875 -0.09375 0.0234375 -0.1640625\n", + " 0.1484375 0.0625 0.2578125 -0.015625 -0.03125 -0.0234375\n", + " -0.0703125 0.015625 ]]\n", + "\n", + " [[ 0.2421875 -0.0390625 0.015625 0.25 0.03125 -0.0234375\n", + " 0.234375 -0.1796875 0.0546875 -0.109375 0.015625 -0.1875\n", + " 
0.1796875 0.09375 0.3125 0. -0.03125 -0.03125\n", + " -0.078125 0.0078125]]\n", + "\n", + " [[ 0.25 -0.0390625 0.015625 0.265625 0.0390625 -0.03125\n", + " 0.265625 -0.1875 0.0546875 -0.125 0.015625 -0.1953125\n", + " 0.1953125 0.1171875 0.3359375 0.015625 -0.03125 -0.0390625\n", + " -0.078125 0.015625 ]]\n", + "\n", + " [[ 0.2578125 -0.046875 0.015625 0.2734375 0.046875 -0.0390625\n", + " 0.2890625 -0.1953125 0.0546875 -0.125 0.015625 -0.203125\n", + " 0.203125 0.125 0.359375 0.0234375 -0.03125 -0.046875\n", + " -0.0703125 0.015625 ]]\n", + "\n", + " [[ 0.2578125 -0.046875 0.0078125 0.2734375 0.046875 -0.0390625\n", + " 0.296875 -0.1953125 0.0546875 -0.1328125 0.015625 -0.203125\n", + " 0.2109375 0.1328125 0.3671875 0.03125 -0.0234375 -0.046875\n", + " -0.0703125 0.015625 ]]\n", + "\n", + " [[ 0.2578125 -0.046875 0.015625 0.28125 0.0546875 -0.046875\n", + " 0.3046875 -0.1953125 0.0546875 -0.140625 0.0078125 -0.203125\n", + " 0.2109375 0.140625 0.375 0.0390625 -0.0234375 -0.046875\n", + " -0.0703125 0.015625 ]]\n", + "\n", + " [[ 0.2578125 -0.0546875 0.0078125 0.28125 0.0546875 -0.046875\n", + " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", + " 0.2109375 0.140625 0.3828125 0.0390625 -0.015625 -0.0546875\n", + " -0.0703125 0.015625 ]]\n", + "\n", + " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", + " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", + " 0.2109375 0.1484375 0.390625 0.0390625 -0.015625 -0.0546875\n", + " -0.0703125 0.015625 ]]\n", + "\n", + " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", + " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", + " 0.21875 0.1484375 0.390625 0.0390625 -0.015625 -0.0546875\n", + " -0.0703125 0.015625 ]]\n", + "\n", + " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", + " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", + " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", + " -0.0703125 0.015625 ]]\n", + "\n", + " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", + " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", + " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", + " -0.0703125 0.015625 ]]\n", + "\n", + " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", + " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", + " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", + " -0.0703125 0.015625 ]]\n", + "\n", + " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", + " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", + " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", + " -0.0703125 0.015625 ]]\n", + "\n", + " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", + " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", + " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", + " -0.0703125 0.015625 ]]\n", + "\n", + " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", + " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", + " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", + " -0.0703125 0.015625 ]]\n", + "\n", + " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", + " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", + " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", + " -0.0703125 0.015625 ]]\n", + "\n", + " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", + " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 
-0.203125\n", + " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", + " -0.0703125 0.015625 ]]\n", + "\n", + " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", + " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", + " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", + " -0.0703125 0.015625 ]]\n", + "\n", + " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", + " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", + " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", + " -0.0703125 0.015625 ]]\n", + "\n", + " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", + " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", + " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", + " -0.0703125 0.015625 ]]\n", + "\n", + " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", + " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", + " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", + " -0.0703125 0.015625 ]]\n", + "\n", + " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", + " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", + " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", + " -0.0703125 0.015625 ]]\n", + "\n", + " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", + " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", + " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", + " -0.0703125 0.015625 ]]\n", + "\n", + " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", + " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", + " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", + " -0.0703125 0.015625 ]]]\n" + ] + } + ], + "source": [ + "# In this block of code we will create the QuantLSTM model using the brevitas layer\n", + "torch.manual_seed(0) #Setting the manual seeds to 0 for consistency in outputs.\n", + "\n", + "# Initializing attributes that can be changed accordingly depending on users requirements.\n", + "\n", + "num_inputs = 25 #Defining the number of inputs \n", + "num_features_brevitas = 10 #This attribute defines number of features each input comprises of\n", + "num_hidden_cells_brevitas = 20 #This attribute defines the number of hidden cells in the QuantLSTM layer\n", + "\n", + "# Creating a sequential model\n", + "\n", + "model_lstm = nn.Sequential( \n", + " QuantLSTM(input_size = num_features_brevitas, hidden_size = num_hidden_cells_brevitas, bias_quant=None) \n", + " ) #No other feature described here implies quantization of inputs/parametersers/activations to 8-bits.\n", + "model_lstm.eval() #Setting the model to eval mode to make sure all the parameters and scales are frozen and not updated on runtime.\n", + "export_path = './quant_lstm_quantization_qcdq.onnx' #Setting export path for the model\n", + "export_onnx_qcdq(model_lstm,(torch.randn(num_inputs, 1, num_features_brevitas)), opset_version=14, export_path=export_path) #Exporting the model to QCDQ representation. 
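\n",
    "\n",
    "# Aside (a hedged sketch, not part of the original flow): once exported, the QCDQ graph\n",
    "# can be inspected with the onnx package imported above, e.g. to list the operator types\n",
    "# it contains (expect Quant/DequantizeLinear-style nodes around the LSTM compute).\n",
    "qcdq_model = onnx.load(export_path)\n",
    "print(sorted({node.op_type for node in qcdq_model.graph.node}))\n",
    "\n",
    "# Note on the quantized inputs printed above: they lie on a 1/128 = 0.0078125 grid,\n",
    "# so a test value of 0.8 snaps to round(0.8/0.0078125) = 102 -> 102*0.0078125 = 0.796875,\n",
    "# which is why 0.8 appears as 0.7969 in the printed quant_input tensors.\n",
    "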
\n", + "\n", + "# Creating a test input to execute the above created model\n", + "\n", + "in_qcdq_node = np.empty([num_inputs,1,num_features_brevitas],dtype=np.float32).reshape([num_inputs,1,num_features_brevitas])\n", + "in_qcdq_node.fill(0.8) #Using the fill function to fill the numpy array with a value of 0.8\n", + "test_input = torch.from_numpy(in_qcdq_node) #Converting the array to a torch tensor\n", + "brevitas_output = model_lstm(test_input) #Executing the model with the set input\n", + "brevitas_output = brevitas_output[0].detach().numpy()\n", + "print(brevitas_output)" + ] + }, + { + "cell_type": "markdown", + "id": "347ef1f5-36e8-4103-9b13-efa7fe93eb5e", + "metadata": {}, + "source": [ + "`Abbreviations` : Short-forms defined in the next code block can be referenced here for definitions.\n", + "\n", + "* Wi = \"Weight matrix for the input gate\" (Similarily for the other three gates)\n", + "* Ui = \"Recurrence matrix for the input gate\" (Similarily for the other three gates)\n", + "* bi = \"Bias for the input gate\" (Similarily for the other three gates)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "0bfbf5a3-8556-4190-a28f-4fe9859c55a9", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.layers.0.0.input_gate_params.bias\n", + "(20,)\n", + "[-0.02587563 -0.18425222 -0.18189065 0.02914573 -0.21827428 0.0595416\n", + " -0.20598626 -0.15559138 -0.04639753 -0.2133838 0.18059207 0.18321364\n", + " -0.11679631 0.04684116 0.11439164 0.07105622 -0.02995344 -0.21090843\n", + " 0.1625932 -0.19612479] , 0\n", + "-------------------------\n", + "0.layers.0.0.input_gate_params.input_weight.weight\n", + "(20, 10)\n", + "[[-4.14119214e-02 1.38706667e-02 -7.36431107e-02 -8.17852393e-02\n", + " -1.93256751e-01 1.23205660e-02 -2.53894478e-02 1.94940954e-01\n", + " -7.36160800e-02 1.72829047e-01]\n", + " [ 1.05855539e-02 -1.00462548e-01 -5.31778559e-02 -2.53751595e-02\n", + " 2.31616711e-03 -3.68398018e-02 6.63604736e-02 1.84143797e-01\n", + " 3.51473056e-02 8.09932351e-02]\n", + " [ 1.38081744e-01 4.81988601e-02 1.03076197e-01 1.17293097e-01\n", + " 2.09298924e-01 -2.04075590e-01 7.65163079e-02 -1.01319486e-02\n", + " -4.01576199e-02 -8.62098187e-02]\n", + " [ 1.34432539e-01 2.04552680e-01 -1.82483241e-01 1.20810278e-01\n", + " 1.54187992e-01 3.90806384e-02 2.63404008e-03 1.72071218e-01\n", + " 6.62961556e-03 -5.57729751e-02]\n", + " [-1.65121444e-02 7.17408881e-02 5.59775345e-02 -1.20642958e-02\n", + " 7.05851838e-02 6.02219440e-02 -1.81134686e-01 5.57176135e-02\n", + " 1.36812523e-01 2.56436393e-02]\n", + " [-2.04101056e-02 1.71289816e-01 -1.95361048e-01 -1.02062307e-01\n", + " -1.01068199e-01 1.93207934e-01 -2.16277346e-01 2.21768115e-02\n", + " -2.16605455e-01 -7.35303294e-03]\n", + " [ 8.33466202e-02 -5.22914641e-02 2.17063010e-01 7.11822009e-04\n", + " -1.14001475e-01 5.76605424e-02 1.16289847e-01 -4.44249017e-04\n", + " 1.91289768e-01 -1.41524345e-01]\n", + " [ 9.54081938e-02 1.26971915e-01 1.11063533e-01 -8.20205314e-05\n", + " 6.38317242e-02 -1.75422058e-01 -1.75476715e-01 -1.38986288e-02\n", + " -2.80253254e-02 1.66033790e-01]\n", + " [ 1.62366882e-01 1.51616067e-01 -1.02419287e-01 -1.75539613e-01\n", + " -2.09742919e-01 8.09257179e-02 -2.01488122e-01 -2.23217383e-01\n", + " -1.13006435e-01 -1.88792080e-01]\n", + " [-8.81207064e-02 -1.40770882e-01 -1.14718042e-01 2.12588429e-01\n", + " -4.21379767e-02 1.85490459e-01 4.96126944e-03 -2.87544206e-02\n", + " -6.54680878e-02 -1.59840211e-01]\n", + " [-1.79656431e-01 
1.54830217e-01 -6.89065754e-02 -2.18012080e-01\n", + " 2.05210581e-01 4.14780807e-03 -1.49626598e-01 -1.75766915e-01\n", + " -1.87781662e-01 -1.96070760e-01]\n", + " [ 2.02346548e-01 1.54175445e-01 1.82888191e-02 -1.90574318e-01\n", + " -5.84847443e-02 -2.10055038e-01 7.70593956e-02 -5.93719892e-02\n", + " -4.78506237e-02 -6.97683394e-02]\n", + " [ 1.04838371e-01 1.21036001e-01 4.89832126e-02 -2.80011501e-02\n", + " -2.20977236e-02 -3.90723767e-03 -1.66511953e-01 2.18188778e-01\n", + " -9.64377001e-02 1.30095944e-01]\n", + " [-1.25353500e-01 1.50110642e-03 7.65467212e-02 -2.05311388e-01\n", + " 1.02568395e-01 -1.71158642e-01 3.12034953e-02 -4.43410687e-02\n", + " 1.28176615e-01 2.17323676e-01]\n", + " [ 5.03933132e-02 -6.38488680e-03 -1.10784821e-01 8.33686888e-02\n", + " -1.07626989e-01 9.23645869e-02 -9.69173536e-02 1.51675642e-01\n", + " 1.71514452e-01 1.37112319e-01]\n", + " [ 2.23987759e-03 1.03696242e-01 -2.03757793e-01 1.81339085e-01\n", + " -5.80957830e-02 8.15173239e-02 -3.78652588e-02 -7.50842392e-02\n", + " -1.05006970e-01 1.44231498e-01]\n", + " [-1.21653110e-01 -3.94320451e-02 1.12798467e-01 2.25366149e-02\n", + " -1.88142627e-01 -2.22348958e-01 -1.08711593e-01 2.06236228e-01\n", + " -1.58990204e-01 1.23237595e-01]\n", + " [ 1.60061240e-01 -9.26844329e-02 -9.87462699e-02 -1.60870835e-01\n", + " 3.48785594e-02 -3.12594734e-02 1.08638955e-02 9.69918296e-02\n", + " 9.38790441e-02 -7.05472827e-02]\n", + " [ 1.53575651e-02 5.31169996e-02 4.75974986e-03 4.47460003e-02\n", + " -9.05808210e-02 1.83284596e-01 -2.29354147e-02 -2.86094397e-02\n", + " -2.00689927e-01 -1.62085444e-01]\n", + " [ 6.95567206e-03 -3.45815569e-02 -1.12424992e-01 1.17047116e-01\n", + " -2.00185552e-02 7.86398575e-02 1.88336477e-01 -1.02802545e-01\n", + " -1.10053055e-01 -4.49331515e-02]] , 1\n", + "-------------------------\n", + "0.layers.0.0.input_gate_params.hidden_weight.weight\n", + "(20, 20)\n", + "[[-1.89352538e-02 -1.11839756e-01 -5.36844507e-02 -6.44523604e-03\n", + " 1.00301303e-01 2.06872717e-01 1.65582791e-01 2.36654170e-02\n", + " -1.40909785e-02 5.72774969e-02 -9.12800338e-03 -2.93454379e-02\n", + " 7.68917575e-02 -1.81926534e-01 -1.90163419e-01 9.05744440e-04\n", + " -6.77747875e-02 -1.10600702e-01 -2.08165124e-01 1.49785221e-01]\n", + " [-8.90937075e-03 -1.20138384e-01 -9.10849124e-02 5.87869175e-02\n", + " -1.62167445e-01 1.43613769e-02 -2.75748386e-03 7.61744976e-02\n", + " 8.87038633e-02 -1.46100059e-01 9.65513662e-02 1.68849513e-01\n", + " 1.43956831e-02 1.13917463e-01 -8.46547335e-02 4.44148518e-02\n", + " 6.53375536e-02 -1.03280008e-01 1.38058737e-01 -2.11419612e-01]\n", + " [-8.39947835e-02 -1.31567493e-01 -1.32741287e-01 -1.35494858e-01\n", + " -2.10702628e-01 3.83746810e-02 -4.42331657e-02 -1.88279316e-01\n", + " -9.19632221e-05 -3.72487307e-02 9.22437534e-02 -1.75148100e-01\n", + " -6.29062578e-02 4.60259691e-02 9.47839618e-02 1.69158224e-02\n", + " 6.05970472e-02 2.23524958e-01 -7.74600878e-02 1.52398065e-01]\n", + " [ 1.92612275e-01 -1.97806209e-01 5.40891960e-02 1.26661941e-01\n", + " -3.48797850e-02 1.23408221e-01 7.60573195e-03 1.70228094e-01\n", + " 4.81458148e-03 -1.43158093e-01 1.69815615e-01 6.65016174e-02\n", + " 1.90237820e-01 5.55088967e-02 1.18736811e-01 1.39421389e-01\n", + " 3.76524106e-02 -5.19809462e-02 4.61825170e-02 -1.55909836e-01]\n", + " [ 7.63913197e-03 -7.18704611e-02 1.41373863e-02 -1.77042618e-01\n", + " 1.36628836e-01 -2.06302434e-01 9.57576782e-02 1.47258580e-01\n", + " -2.04934716e-01 2.02031001e-01 -1.66225716e-01 -4.39088680e-02\n", + " 1.15872569e-01 
-7.09063411e-02 1.99275032e-01 -9.86447409e-02\n", + " -2.99374424e-02 -1.46168455e-01 -1.03737742e-01 2.18205780e-01]\n", + " [ 1.68166518e-01 1.64642967e-02 1.83855016e-02 -1.89751670e-01\n", + " 1.68811426e-01 -3.35250199e-02 -9.32650268e-02 -1.77951321e-01\n", + " 1.83845311e-01 1.06031545e-01 1.34684831e-01 2.31534615e-02\n", + " -1.51732951e-01 9.15970504e-02 2.57883817e-02 7.50367939e-02\n", + " -5.56799732e-02 -1.05523452e-01 1.83565930e-01 7.49567226e-02]\n", + " [-9.07528847e-02 1.99678559e-02 -4.86066155e-02 -1.91221125e-02\n", + " 1.25389591e-01 -1.77972749e-01 2.02371553e-01 1.50499865e-01\n", + " 1.92136504e-04 -9.14627835e-02 4.55915295e-02 -1.48007214e-01\n", + " 1.45243973e-01 -1.18256845e-01 4.27256078e-02 -2.19991282e-01\n", + " 1.07079633e-01 1.51370272e-01 1.67834863e-01 1.82519276e-02]\n", + " [ 1.32025823e-01 7.62412176e-02 1.49954304e-01 1.26183063e-01\n", + " -1.95639879e-01 2.35728398e-02 -7.62314126e-02 -1.06771380e-01\n", + " 1.56516239e-01 -3.20035741e-02 3.47357877e-02 1.40789405e-01\n", + " 1.50514722e-01 1.19332708e-01 -3.90392952e-02 -1.99321926e-01\n", + " -2.14659125e-01 7.02862144e-02 -2.65357876e-03 -1.41277447e-01]\n", + " [ 9.76564139e-02 2.02965632e-01 1.29328549e-01 -3.15438919e-02\n", + " 3.02148778e-02 -1.42630830e-01 1.05540812e-01 -1.73283800e-01\n", + " 1.54376432e-01 -1.02132224e-01 -8.86853859e-02 -1.87295631e-01\n", + " -5.40727489e-02 -2.16292981e-02 -1.03067294e-01 1.59174219e-01\n", + " 1.28328785e-01 -1.97347268e-01 -2.23675612e-02 7.51795396e-02]\n", + " [ 2.15735227e-01 -5.34672327e-02 1.37278914e-01 -1.25270970e-02\n", + " -8.57628211e-02 1.36838645e-01 -1.99253812e-01 1.87337860e-01\n", + " 2.23344907e-01 -6.10500947e-02 8.83295834e-02 2.22981662e-01\n", + " 6.74140528e-02 8.74451399e-02 8.21070075e-02 -9.14832279e-02\n", + " 5.45820408e-02 -1.19176529e-01 1.90940976e-01 -9.58186984e-02]\n", + " [ 5.11176400e-02 -6.47741258e-02 1.11825228e-01 3.68577940e-03\n", + " 1.22950912e-01 -6.05489872e-02 -1.31215081e-01 8.57292935e-02\n", + " -1.25841707e-01 -1.83588028e-01 8.63927826e-02 -1.34484172e-01\n", + " -8.40481222e-02 -5.58335669e-02 1.58777572e-02 -7.74438009e-02\n", + " -8.04765150e-02 -5.62009923e-02 1.56701818e-01 6.69540018e-02]\n", + " [-1.07652791e-01 -1.54563770e-01 5.18102152e-03 7.16358349e-02\n", + " -4.67919558e-03 1.30897254e-01 1.88077956e-01 6.55371249e-02\n", + " 7.37451240e-02 1.29728526e-01 -7.66031295e-02 3.96637134e-02\n", + " 1.80782616e-01 -1.07077263e-01 1.74031202e-02 -8.74211192e-02\n", + " -1.71936572e-01 1.18438050e-01 1.78673968e-01 -1.20800309e-01]\n", + " [ 8.38049129e-02 6.85676187e-02 8.73105526e-02 1.23087496e-01\n", + " 2.08757341e-01 1.69717655e-01 -1.95658267e-01 -8.76599625e-02\n", + " 1.18758187e-01 -1.27650708e-01 4.39067073e-02 -9.58611295e-02\n", + " 4.44106422e-02 1.09106824e-01 7.02822655e-02 1.62435979e-01\n", + " -2.69077457e-02 1.21389672e-01 7.22895712e-02 -7.04701096e-02]\n", + " [-1.57925934e-01 2.04573229e-01 -6.66687265e-02 1.68426275e-01\n", + " 1.40947536e-01 -9.00426600e-03 -1.84701070e-01 1.80013608e-02\n", + " -1.08096078e-01 5.81858531e-02 -8.88810679e-02 1.72345534e-01\n", + " -2.01746121e-01 -6.01959564e-02 3.52624580e-02 2.13314164e-02\n", + " 1.83701098e-01 -7.06517771e-02 -1.78495154e-01 1.48046315e-01]\n", + " [ 6.24824539e-02 1.47299409e-01 -1.32342920e-01 -1.31334439e-01\n", + " -9.03252959e-02 1.58978552e-02 7.57712200e-02 -1.28496692e-01\n", + " -2.10528076e-02 -3.86467576e-02 2.04027027e-01 -8.06416422e-02\n", + " 2.16690734e-01 -1.37144789e-01 -9.21397135e-02 
-1.68184295e-01\n", + " 1.64731190e-01 -1.53769597e-01 9.25582647e-02 -8.21671411e-02]\n", + " [ 2.22826257e-01 3.15412283e-02 -1.94183901e-01 3.84835452e-02\n", + " 2.71859560e-02 -2.16274336e-01 4.48757894e-02 2.13342309e-01\n", + " 6.43487200e-02 -1.18915108e-03 -4.63541821e-02 5.94213046e-02\n", + " -9.96202976e-02 2.20200241e-01 1.93298727e-01 1.04461670e-01\n", + " -8.32887441e-02 -2.09956676e-01 -1.28724366e-01 2.17411697e-01]\n", + " [-2.05243871e-01 -2.13502616e-01 -1.61161683e-02 7.11405650e-02\n", + " -2.22554103e-01 -2.07601383e-01 1.21570053e-03 -7.50053376e-02\n", + " 1.55782372e-01 6.41999543e-02 -1.94095746e-01 -2.01538876e-01\n", + " 1.53562352e-01 -3.96501981e-02 -9.78184044e-02 7.04318583e-02\n", + " -4.39465865e-02 1.06939368e-01 5.67044728e-02 -9.68158469e-02]\n", + " [-1.79218486e-01 1.21047780e-01 -1.34345368e-01 -2.47318167e-02\n", + " 3.05733737e-02 -1.30131751e-01 1.21804118e-01 -1.57282248e-01\n", + " 5.49192652e-02 2.39149425e-02 8.20437744e-02 -2.19451547e-01\n", + " 1.29167549e-02 1.09009661e-01 -1.43156886e-01 5.53317666e-02\n", + " 8.76156322e-04 1.89696804e-01 -4.73480262e-02 1.52765575e-03]\n", + " [-9.72549468e-02 -5.51085509e-02 6.40134960e-02 -2.15656430e-01\n", + " 1.69629768e-01 1.60795882e-01 9.46965069e-02 1.67391464e-01\n", + " -6.96057901e-02 5.09320870e-02 1.13759311e-02 -1.54622883e-01\n", + " -8.59646648e-02 -7.93827102e-02 -5.52875437e-02 -1.98549107e-01\n", + " -1.57260388e-01 -2.12343093e-02 -3.40157561e-02 -2.02978238e-01]\n", + " [ 4.77774814e-02 1.21752672e-01 1.86222807e-01 1.88188314e-01\n", + " -1.56248853e-01 -7.16619864e-02 -1.06078379e-01 4.10118401e-02\n", + " 5.99195063e-02 4.97494638e-02 1.30669191e-01 1.17969945e-01\n", + " -1.20020248e-01 1.53502032e-01 1.50838137e-01 2.95910202e-02\n", + " -1.94543302e-01 -1.37143746e-01 6.23138808e-02 7.73103088e-02]] , 2\n", + "-------------------------\n", + "0.layers.0.0.forget_gate_params.bias\n", + "(20,)\n", + "[ 0.20850217 0.11380532 0.08104482 -0.00762655 0.15247074 -0.08138975\n", + " 0.0910454 -0.10650107 -0.00208706 0.13215044 0.10260209 -0.05017841\n", + " -0.00283135 -0.12413156 0.10357434 0.15046087 0.07697045 -0.21637587\n", + " -0.16006967 0.14969489] , 3\n", + "-------------------------\n", + "0.layers.0.0.forget_gate_params.input_weight.weight\n", + "(20, 10)\n", + "[[-0.03201701 0.13732338 0.16482215 -0.06550063 -0.13119501 -0.2103679\n", + " 0.08553377 0.11468438 -0.0387658 -0.21708311]\n", + " [-0.14402747 -0.01204806 0.10205487 -0.07492673 -0.14435105 -0.15566948\n", + " 0.2000676 0.08097311 -0.1815501 -0.13809344]\n", + " [-0.18981868 0.03235186 -0.09079897 -0.00075695 -0.0353742 -0.1957324\n", + " -0.19982079 -0.17343585 -0.09364887 0.03477862]\n", + " [-0.10515709 -0.00797041 -0.02678433 0.20449734 -0.10193561 0.21008612\n", + " -0.17165995 -0.18656294 0.07271551 -0.13013807]\n", + " [ 0.11469334 -0.12370986 0.17608246 0.21651667 0.01431521 0.04778921\n", + " 0.20847315 0.13255776 -0.19520605 -0.00715788]\n", + " [-0.20184483 0.17081025 -0.04095714 -0.00155866 -0.13738167 -0.12158713\n", + " 0.02901981 0.18449156 -0.1123966 0.02112942]\n", + " [ 0.20241037 0.20039941 -0.04371644 0.20957804 0.08143061 0.20365277\n", + " 0.00663433 -0.1895056 -0.06086665 0.06706649]\n", + " [ 0.1192437 -0.22275887 0.17393245 -0.20059223 0.13101582 0.22062524\n", + " 0.05510434 -0.0422016 0.12311912 -0.06636703]\n", + " [-0.16563286 -0.15869099 0.10513588 0.1707739 0.00905446 -0.2168069\n", + " -0.21971782 -0.05049207 0.12070725 -0.1490105 ]\n", + " [ 0.06027115 -0.12221678 0.18192975 
-0.05859193 -0.04659947 -0.19612114\n", + " -0.20028274 0.01511241 0.03615525 0.12080745]\n", + " [-0.19552828 0.03918052 -0.03230212 0.1311668 -0.1016731 0.06661848\n", + " 0.09010674 0.11232612 -0.07669472 0.07195909]\n", + " [-0.04382298 0.06021269 -0.13749652 -0.17768005 -0.18290731 -0.1405653\n", + " -0.09463658 0.03328432 -0.04891114 -0.12729394]\n", + " [ 0.00187842 -0.07061429 0.13783802 -0.18416376 -0.08253521 -0.1436971\n", + " 0.02759105 0.01219904 -0.0128632 0.22186181]\n", + " [-0.08530237 -0.03213883 0.05777045 0.18662488 0.16948868 0.02554451\n", + " -0.08459641 0.07345897 0.14069013 -0.00477207]\n", + " [ 0.12276765 0.18300453 -0.11980148 -0.04943415 -0.20131664 0.05132969\n", + " 0.15936238 -0.04342245 0.03568069 0.07144996]\n", + " [-0.00476937 0.17384104 0.0325843 -0.21979333 -0.18465139 -0.22154187\n", + " 0.00921626 0.12087465 -0.02950055 0.20104776]\n", + " [-0.04022751 0.04571649 0.20163535 0.11316557 -0.00713371 0.2153832\n", + " -0.1335971 0.08328808 0.14121595 -0.13845547]\n", + " [-0.21004361 0.07152335 -0.08483391 -0.1128413 0.04447659 -0.16221067\n", + " 0.2011128 -0.02007227 -0.07161061 0.18693109]\n", + " [ 0.06226142 0.04260208 -0.10691333 0.21311398 -0.06810362 0.18598051\n", + " -0.016437 0.11216957 0.15722302 -0.1664758 ]\n", + " [-0.14903465 -0.22111452 0.16127922 0.19229865 -0.08172148 -0.10951796\n", + " 0.03742959 0.12038527 0.05519409 -0.04660187]] , 4\n", + "-------------------------\n", + "0.layers.0.0.forget_gate_params.hidden_weight.weight\n", + "(20, 20)\n", + "[[-0.14223064 0.19124371 -0.14481081 -0.21607104 -0.08928006 0.04458899\n", + " 0.0831126 0.08646142 -0.12953514 -0.08581803 -0.09943341 -0.10828371\n", + " -0.18833804 0.04577223 -0.06502874 -0.2152229 -0.13056786 -0.13428617\n", + " -0.09645564 -0.13816758]\n", + " [-0.03877772 0.08013236 -0.18096809 -0.01915519 -0.06435173 -0.11432081\n", + " -0.0496515 -0.09477154 0.00718846 -0.16141057 0.04240454 0.20530063\n", + " 0.18528308 -0.10025615 0.06892193 -0.21135406 0.18826427 -0.22283866\n", + " -0.19982089 -0.20071597]\n", + " [-0.20765333 0.03028304 -0.05912894 0.05351972 -0.01383548 -0.00480333\n", + " -0.08078498 -0.13266474 -0.18721604 0.11282834 -0.11529152 -0.04547688\n", + " 0.10860465 -0.05537887 -0.05637903 -0.14906646 -0.19131811 0.10732386\n", + " -0.05044974 0.14060505]\n", + " [ 0.01471702 -0.00028402 -0.20187245 0.0049368 -0.0505344 -0.12759772\n", + " -0.05175107 0.01168989 -0.16848378 0.03718214 0.15558895 0.04417289\n", + " 0.21344449 0.10434435 -0.17634727 -0.08801483 -0.05380939 0.06689031\n", + " -0.00637761 0.17993565]\n", + " [ 0.02597556 -0.14161254 -0.08197778 -0.18603216 -0.061655 0.10993782\n", + " 0.00215927 -0.21323241 -0.19348647 0.08106777 -0.19626026 -0.1783532\n", + " -0.1333177 0.21312374 -0.06358164 -0.09219337 -0.15098219 0.14304285\n", + " -0.03610551 0.04311918]\n", + " [ 0.05341741 0.06306308 0.14312816 0.01160373 0.02312934 -0.01452105\n", + " -0.17375752 -0.05117204 0.21281871 -0.15847513 -0.14112028 -0.22188812\n", + " 0.013559 -0.20914444 -0.11453009 0.20604049 0.09261008 0.11913135\n", + " 0.03828845 -0.19001652]\n", + " [-0.10404866 -0.18102278 -0.13826925 0.076148 -0.06201827 0.2185227\n", + " -0.16299975 -0.19082828 0.2207899 -0.19316407 0.19027402 0.06021235\n", + " -0.20380671 0.1947569 -0.06087566 -0.09220145 -0.17443547 -0.1891369\n", + " 0.04978558 -0.21964009]\n", + " [ 0.09188584 -0.05525529 0.0784739 -0.05474811 0.07732737 -0.00610806\n", + " 0.06572182 -0.09097287 -0.15380703 0.02847747 -0.14272346 -0.13861606\n", + " -0.21501313 
-0.07127416 -0.14941145 0.17413448 0.1611419 0.05305404\n", + " 0.18168166 0.10766982]\n", + " [-0.21064265 -0.022373 -0.03629636 -0.13576584 0.06368566 -0.06979065\n", + " -0.10692404 -0.00260666 -0.14866948 0.18506847 0.14149404 0.21166477\n", + " -0.03960523 0.07302888 -0.00899392 -0.18503006 0.10116354 -0.15618756\n", + " -0.08071785 -0.10013654]\n", + " [-0.21814388 0.00802042 0.03663212 -0.01662389 0.1644524 0.01072139\n", + " -0.0407296 -0.12196475 -0.13280123 -0.03179033 -0.1312358 -0.14750735\n", + " -0.02957479 -0.03948133 -0.13649467 0.13065115 0.18963577 -0.15246144\n", + " 0.09794185 -0.10375587]\n", + " [-0.02321799 0.20873794 0.02861272 -0.21320319 0.20555921 -0.00946067\n", + " -0.11196752 -0.11808899 0.19175017 0.00377388 0.12350584 0.14696068\n", + " -0.08678884 0.01897924 -0.14464125 0.18672368 -0.11824197 0.14852415\n", + " 0.05665502 0.1379358 ]\n", + " [-0.1575466 -0.00695391 0.11586404 -0.00892534 -0.0032084 0.10896464\n", + " -0.16712412 -0.04483069 0.10185106 0.10966767 0.20768207 -0.04423303\n", + " 0.05298113 -0.11002054 -0.03752897 -0.11225442 0.16570821 0.0013621\n", + " 0.09096613 0.12299404]\n", + " [ 0.04166875 0.02379598 -0.01636612 -0.1894117 0.03602695 -0.04953878\n", + " -0.18794785 0.20833082 -0.02383836 -0.11159918 -0.21768506 -0.20595226\n", + " 0.08515022 -0.1020775 -0.09659212 -0.12938367 0.18049696 -0.05375253\n", + " 0.14493793 0.17751718]\n", + " [-0.17336273 0.16682073 -0.04269946 0.21416363 0.11421449 -0.21660405\n", + " 0.04154139 0.07860353 -0.08111839 0.16956337 -0.1851744 -0.07095176\n", + " 0.2130592 0.21838497 0.11170101 -0.13348123 -0.19239157 -0.1818077\n", + " -0.05589887 0.12667239]\n", + " [ 0.07079396 -0.02715501 0.20110089 0.17559125 -0.10450983 -0.09683432\n", + " -0.00262346 0.04640241 -0.00160075 0.08632647 0.15427703 -0.04031902\n", + " 0.10981148 0.03041176 0.08583194 0.09205452 -0.05976621 -0.09969731\n", + " 0.09557738 -0.14316456]\n", + " [ 0.1173941 -0.1434708 0.15340208 0.08971985 -0.05478028 0.12781222\n", + " -0.07363954 0.04763815 0.06583516 0.02283663 0.04490386 -0.00443905\n", + " -0.0645991 0.1247524 0.08819748 0.08340425 0.15096036 -0.11699554\n", + " -0.0519524 -0.00637345]\n", + " [ 0.18044722 -0.1780605 -0.12826072 -0.05326315 -0.19100511 -0.17666493\n", + " 0.15317535 0.01043098 -0.17988645 -0.03692174 -0.00735149 -0.07949581\n", + " -0.18703558 0.12169496 -0.02761802 0.21831468 -0.17125311 -0.12275734\n", + " -0.01161703 -0.15571442]\n", + " [ 0.16295849 0.17292082 0.2025731 -0.14115438 0.15909635 0.15525764\n", + " -0.08897205 0.02453648 0.10655329 0.16001071 -0.20884806 0.2226173\n", + " -0.05621968 0.09110746 -0.13887972 -0.17207511 -0.15143432 0.13178375\n", + " -0.11029776 0.12998497]\n", + " [ 0.0675995 0.08894558 -0.04973555 -0.07073203 -0.10462123 -0.12498911\n", + " 0.20617247 -0.01215215 -0.09589054 -0.20804486 0.0097276 -0.22196051\n", + " -0.00263305 0.14118703 -0.12879056 0.12285849 -0.07132839 -0.1719783\n", + " -0.22146888 0.11108326]\n", + " [-0.1710799 0.10918202 0.03201576 0.12152903 -0.16808327 0.19554281\n", + " -0.22271936 -0.16972543 0.13409424 0.00759949 -0.12556304 -0.04690479\n", + " -0.19899549 -0.194607 -0.04797396 0.17057896 0.06677905 0.04216573\n", + " -0.05926214 0.20352075]] , 5\n", + "-------------------------\n", + "0.layers.0.0.cell_gate_params.bias\n", + "(20,)\n", + "[ 0.00214154 0.07550146 0.00355405 0.03489293 0.07456551 0.17159154\n", + " 0.12870987 0.0286169 0.08939798 -0.06724557 0.15284362 0.06277069\n", + " 0.16875166 -0.03491265 -0.18256952 0.04417255 0.09094475 
0.18067895\n", + " 0.08666804 0.08261736] , 6\n", + "-------------------------\n", + "0.layers.0.0.cell_gate_params.input_weight.weight\n", + "(20, 10)\n", + "[[ 0.17794745 -0.07684495 0.19742867 0.11464191 0.14933479 0.15947415\n", + " -0.18268393 0.11646748 0.20825341 -0.15708849]\n", + " [-0.01916463 -0.1364658 -0.05399449 0.03332363 0.11960924 -0.06491657\n", + " -0.21173826 0.12073942 0.12545025 -0.04053707]\n", + " [ 0.19142465 0.17237733 -0.04928424 0.00863487 0.03938841 -0.04381773\n", + " -0.05508858 -0.10093604 -0.12716216 0.11167222]\n", + " [-0.06639788 -0.10727276 0.19697405 0.03575112 0.16133724 0.2037714\n", + " -0.03149954 0.03335407 0.20731461 -0.15384933]\n", + " [-0.06704343 0.03181893 -0.01517017 0.05953267 0.11757869 -0.09199598\n", + " 0.01741112 0.20230028 -0.1265286 -0.15163381]\n", + " [-0.17148444 0.13366292 -0.20509928 -0.1087402 0.15102275 -0.13404797\n", + " 0.1818403 -0.10452814 0.03537463 0.02927051]\n", + " [-0.00548471 0.13927223 0.18991414 -0.13961166 0.12540615 0.0597448\n", + " -0.00416681 -0.15634763 0.06633033 0.1623022 ]\n", + " [-0.19193047 -0.20651296 -0.21982425 0.05166686 -0.06424998 -0.06945844\n", + " 0.20821334 -0.05703437 -0.14200093 0.02011372]\n", + " [-0.12272914 -0.06551553 0.11811562 0.05160707 -0.1534436 0.21288224\n", + " 0.15128401 -0.15242937 0.09739923 0.09188432]\n", + " [-0.16044928 -0.1571494 -0.18515183 0.09960561 0.03895786 0.09450045\n", + " -0.09821384 0.1681353 0.02855213 -0.17842196]\n", + " [-0.056282 0.11411482 0.04916727 -0.03420792 -0.15622441 -0.13909872\n", + " 0.19286813 -0.12808998 0.15845725 -0.07484471]\n", + " [ 0.00223508 -0.21774605 -0.07268656 0.18849593 -0.20075409 0.11251042\n", + " -0.188184 0.03261365 -0.20273004 -0.17701481]\n", + " [-0.18051723 -0.07753571 0.03044572 -0.16394225 0.05667006 0.13467607\n", + " 0.18228398 0.19799176 0.14722027 -0.06584404]\n", + " [-0.02060739 0.19784163 0.11123517 -0.05929887 0.16882291 -0.19541554\n", + " 0.1913779 0.12510933 -0.16400692 -0.18237662]\n", + " [ 0.17486629 0.22059093 0.01951262 -0.08737109 0.12732458 0.1008788\n", + " -0.0279066 0.17902343 0.14493623 0.05574536]\n", + " [ 0.11610299 -0.20945168 -0.10473937 0.02451142 0.06080827 -0.03056943\n", + " 0.08443112 0.06811719 -0.20665829 0.07052966]\n", + " [-0.01818041 -0.15387398 0.00754629 -0.05499369 -0.11874414 -0.20375879\n", + " 0.18706112 -0.13579562 0.0300329 0.17913137]\n", + " [-0.02817055 -0.14655502 -0.21633011 0.03715306 -0.11219743 0.01630673\n", + " 0.07142475 -0.06335549 0.1516163 -0.02909804]\n", + " [-0.08923855 -0.14784832 0.06784268 -0.13824603 0.04700406 -0.02822138\n", + " 0.1536749 -0.10962173 -0.11015368 -0.02889775]\n", + " [-0.13657494 0.08524874 -0.08190698 0.09174035 0.12977527 0.13057181\n", + " -0.04105001 0.12203032 -0.11840606 -0.22279048]] , 7\n", + "-------------------------\n", + "0.layers.0.0.cell_gate_params.hidden_weight.weight\n", + "(20, 20)\n", + "[[-2.12806370e-02 -1.62129834e-01 -1.73234463e-01 5.68399914e-02\n", + " 1.91077381e-01 -8.79967287e-02 -1.26489419e-02 -1.62001878e-01\n", + " 3.90813835e-02 6.37496263e-02 -3.43248062e-02 1.70126632e-01\n", + " -1.79964885e-01 -3.00010163e-02 -1.24117516e-01 1.96340203e-01\n", + " 1.89398184e-01 2.19951704e-01 2.05728129e-01 8.85609612e-02]\n", + " [-1.71218976e-01 -1.51676044e-01 5.36037646e-02 -1.99636862e-01\n", + " 1.41561761e-01 9.72114205e-02 5.33513576e-02 -1.95168942e-01\n", + " 1.62662312e-01 -2.36655492e-02 -9.38338637e-02 1.16747312e-01\n", + " 1.88960433e-02 -9.94693190e-02 5.23358434e-02 -1.49113968e-01\n", + " 
2.07823291e-01 1.95990741e-01 1.03123404e-01 1.18294187e-01]\n", + " [-2.22277910e-01 -1.24300212e-01 -2.15169474e-01 -1.16545178e-01\n", + " -1.85386583e-01 1.64590582e-01 1.20638609e-01 1.31684974e-01\n", + " -9.92668644e-02 1.70430213e-01 -3.23111340e-02 -5.79339787e-02\n", + " 1.20397158e-01 1.48079455e-01 -1.60713032e-01 2.12880254e-01\n", + " -2.25685220e-02 5.95554635e-02 -2.22653463e-01 2.48931386e-02]\n", + " [-1.10666625e-01 -1.40009314e-01 -9.33616757e-02 -1.04158348e-03\n", + " -6.37013763e-02 -1.43241197e-01 1.60099015e-01 6.65228367e-02\n", + " -2.08098441e-01 4.69054580e-02 5.49288094e-02 8.21655430e-03\n", + " 5.42974621e-02 -1.87213402e-02 9.77927893e-02 -1.57414630e-01\n", + " -9.53418463e-02 1.67505234e-01 -1.38533488e-01 1.09708525e-01]\n", + " [ 2.06897184e-01 -2.04468444e-01 -9.79631692e-02 1.90820277e-01\n", + " -1.35208331e-02 4.41430137e-02 3.18236202e-02 9.21481624e-02\n", + " -9.21330750e-02 2.90291384e-02 1.52316689e-01 -1.88640561e-02\n", + " -2.05149427e-01 7.72908777e-02 -5.70836812e-02 -4.71739881e-02\n", + " 1.16618834e-01 3.91878746e-02 -1.35271400e-01 -1.03187911e-01]\n", + " [-3.39903794e-02 -5.52454554e-02 -4.73374985e-02 -1.52837262e-01\n", + " 1.61986634e-01 1.15967356e-01 4.41279002e-02 5.06293550e-02\n", + " 2.61772387e-02 1.67198420e-01 5.05979806e-02 3.40624861e-02\n", + " -1.22919112e-01 7.45933205e-02 -2.09194586e-01 7.05230013e-02\n", + " -1.93819985e-01 -9.25445408e-02 1.18050657e-01 -1.33182898e-01]\n", + " [ 1.78052112e-01 -1.23547316e-01 2.11798310e-01 6.89183101e-02\n", + " -9.69009325e-02 1.36373073e-01 -1.98024541e-01 -1.41652852e-01\n", + " -1.40091866e-01 2.94355899e-02 2.19678022e-02 -1.92325816e-01\n", + " 2.15771765e-01 -2.13701205e-04 -1.19405292e-01 5.34111727e-03\n", + " -9.59839672e-02 6.16913289e-02 8.09477344e-02 -6.34285584e-02]\n", + " [ 1.30358534e-02 1.33047834e-01 -1.45440847e-01 -4.98616323e-02\n", + " -3.29875015e-02 -1.47941127e-01 1.82121564e-02 8.21812730e-03\n", + " -1.80613607e-01 4.58700024e-02 2.13425189e-01 1.18935056e-01\n", + " -1.21292830e-01 2.04682201e-01 -1.53705969e-01 -1.13691926e-01\n", + " 9.86314118e-02 1.77888468e-01 2.13384852e-01 1.92508563e-01]\n", + " [-1.23128124e-01 5.11671938e-02 -1.40405849e-01 4.93797194e-03\n", + " 1.85259327e-01 1.10102132e-01 -2.06472665e-01 -9.62342396e-02\n", + " -1.88666239e-01 1.05334759e-01 -2.83857696e-02 -1.63461700e-01\n", + " -7.14522004e-02 7.33797774e-02 2.07014289e-02 2.09811881e-01\n", + " -2.96870619e-03 7.03370497e-02 -6.77365363e-02 2.66825557e-02]\n", + " [ 8.01036973e-03 1.92074046e-01 9.36935991e-02 -1.27431735e-01\n", + " -1.98687479e-01 -2.12748200e-01 -8.12046453e-02 2.89045740e-02\n", + " 2.10361689e-01 -2.19703875e-02 8.74281824e-02 1.13642633e-01\n", + " -1.71282887e-01 -1.84971020e-01 8.47281963e-02 1.04225203e-01\n", + " -1.04119189e-01 3.50410007e-02 -2.18935862e-01 2.81849946e-03]\n", + " [ 5.48111200e-02 2.11656699e-03 -3.54930870e-02 9.30717662e-02\n", + " -6.14620335e-02 1.66451484e-01 -1.92599118e-01 -1.27790585e-01\n", + " -1.86674312e-01 -2.02230543e-01 1.65771663e-01 -5.53366169e-02\n", + " -1.75649151e-01 4.63781990e-02 -1.69327542e-01 1.15589779e-02\n", + " 1.06298663e-01 -4.72831465e-02 1.14950888e-01 4.58941013e-02]\n", + " [-1.79431096e-01 4.40098420e-02 1.44146204e-01 -5.18364720e-02\n", + " 2.11329088e-02 2.85264328e-02 1.92284174e-02 5.81263304e-02\n", + " -2.14094386e-01 1.69653893e-01 9.75249708e-02 2.76133306e-02\n", + " 4.06875163e-02 -1.80331707e-01 -6.38444126e-02 -9.72616393e-03\n", + " 5.31534106e-02 -1.22661509e-01 
2.37256587e-02 -6.93958476e-02]\n", + " [ 1.62758812e-01 -1.91935405e-01 2.33742520e-02 1.51492402e-01\n", + " -1.73671409e-01 -6.40887721e-03 1.03327051e-01 9.02309865e-02\n", + " 2.62962040e-02 9.03898776e-02 -1.55875593e-01 1.86238810e-01\n", + " 4.98715229e-03 1.44541100e-01 4.94662710e-02 -2.48756800e-02\n", + " 9.57791656e-02 2.12270051e-01 2.20569506e-01 -1.88220173e-01]\n", + " [ 1.35616167e-02 -1.60633817e-01 1.30284145e-01 1.60526067e-01\n", + " -1.57016143e-01 -1.29234986e-02 1.54731110e-01 1.47872686e-01\n", + " -1.68123141e-01 1.50136366e-01 -3.95872369e-02 -1.90171361e-01\n", + " 4.45422679e-02 1.04169942e-01 1.34101674e-01 -1.52035385e-01\n", + " -1.61954522e-01 -1.50239438e-01 1.26720712e-01 -1.95428118e-01]\n", + " [-1.88556593e-03 -6.57092705e-02 9.76277590e-02 4.39127870e-02\n", + " -1.12915963e-01 3.90566476e-02 2.05778107e-01 3.68154384e-02\n", + " -1.10807024e-01 7.48633966e-03 -2.05102757e-01 -1.43465236e-01\n", + " -4.15345095e-02 -1.39340952e-01 1.89353585e-01 4.34043780e-02\n", + " 1.73192978e-01 -5.09172641e-02 -3.10981516e-02 5.64037636e-02]\n", + " [-6.64871484e-02 -7.62214959e-02 -2.19352797e-01 1.68453470e-01\n", + " 2.02370644e-01 -2.21398085e-01 -7.39822015e-02 -1.69133484e-01\n", + " -9.07677040e-02 1.70234248e-01 1.19611956e-01 -1.73501018e-02\n", + " 9.55028459e-02 6.67780936e-02 1.22115597e-01 -1.79690495e-01\n", + " 6.91184700e-02 -2.11776465e-01 -1.47058472e-01 -8.33279863e-02]\n", + " [-2.17858739e-02 -2.11018786e-01 5.56494808e-03 3.57002839e-02\n", + " -8.87419507e-02 7.25275800e-02 1.95392817e-01 -3.81953120e-02\n", + " -1.19088188e-01 -1.98077247e-01 -1.63278311e-01 -1.23674117e-01\n", + " -1.65306747e-01 -8.79110843e-02 1.23181596e-01 6.99715093e-02\n", + " 2.01542184e-01 2.22007304e-01 -8.05223361e-02 -8.75686854e-02]\n", + " [ 3.05994693e-02 -1.78054109e-01 1.21623978e-01 -4.02442813e-02\n", + " -1.87232435e-01 -1.68819025e-01 -1.54080361e-01 6.14588112e-02\n", + " 1.71410367e-01 1.77153081e-01 -6.15712442e-02 -1.29883334e-01\n", + " -9.92444977e-02 -1.52750149e-01 -5.76506779e-02 -2.01948732e-01\n", + " 1.19517274e-01 -2.10457653e-01 -1.39095634e-01 1.50062576e-01]\n", + " [-1.67259946e-01 5.34564890e-02 1.67486787e-01 2.20412284e-01\n", + " 1.13142729e-01 -6.00084551e-02 1.27776846e-01 -7.37963570e-03\n", + " -6.89469650e-02 7.28242099e-04 5.01570366e-02 1.49932787e-01\n", + " 9.38621163e-02 1.06770106e-01 3.34510244e-02 -1.12544857e-02\n", + " 9.38917845e-02 5.37824407e-02 -2.13967159e-01 3.61516774e-02]\n", + " [-9.93019715e-02 -1.18578210e-01 8.64755288e-02 4.57250476e-02\n", + " 3.78663242e-02 -1.06075369e-01 1.03322893e-01 2.09839717e-01\n", + " 2.73554083e-02 9.19082835e-02 -1.96176514e-01 1.32933155e-01\n", + " 7.76783228e-02 1.00741126e-01 9.32467878e-02 -5.88140823e-02\n", + " -1.34220198e-02 2.16287613e-01 1.63621128e-01 -1.60278752e-01]] , 8\n", + "-------------------------\n", + "0.layers.0.0.output_gate_params.bias\n", + "(20,)\n", + "[ 0.17741492 0.22254053 0.02940683 -0.17445402 0.04334408 -0.04515981\n", + " 0.16077036 -0.21483785 0.05722176 -0.00262266 0.01760296 0.15381731\n", + " 0.0040394 -0.18002152 -0.13043821 -0.08953302 0.02384774 0.08628984\n", + " -0.04173774 -0.08825271] , 9\n", + "-------------------------\n", + "0.layers.0.0.output_gate_params.input_weight.weight\n", + "(20, 10)\n", + "[[ 9.81200710e-02 -2.17414662e-01 1.56252235e-01 -2.59936582e-02\n", + " 1.55592158e-01 1.68960407e-01 2.38872208e-02 7.07329437e-02\n", + " -1.26473457e-01 1.60210714e-01]\n", + " [ 1.30875960e-01 -3.51194218e-02 8.71568248e-02 
-1.25249382e-02\n", + " 1.74701765e-01 9.20466036e-02 1.63019851e-01 -2.03253865e-01\n", + " 2.17866078e-01 8.33117217e-02]\n", + " [ 1.08713590e-01 4.98261265e-02 1.46862045e-01 2.10508242e-01\n", + " -1.90491565e-02 -1.83473915e-01 2.05329910e-01 -4.71567698e-02\n", + " -1.07840233e-01 1.37649149e-01]\n", + " [ 1.24790154e-01 2.99369618e-02 -1.40363071e-02 -4.27761748e-02\n", + " 2.05027208e-01 1.36240214e-01 1.33165866e-01 1.42589167e-01\n", + " -1.17026694e-01 4.66880240e-02]\n", + " [-1.93439931e-01 1.29910931e-01 -2.21640781e-01 -2.23473564e-01\n", + " -2.21031293e-01 1.37891039e-01 2.32707467e-02 5.08490019e-04\n", + " 3.55657227e-02 -8.46242681e-02]\n", + " [-6.79011941e-02 -1.50619775e-01 -5.46085611e-02 -1.37593433e-01\n", + " 5.88322058e-03 1.75689265e-01 -1.84854001e-01 1.09963417e-01\n", + " -1.66318297e-01 -9.26456451e-02]\n", + " [ 4.37250473e-02 3.84753868e-02 1.83374569e-01 -8.36465479e-05\n", + " -8.51647705e-02 -9.24766734e-02 6.55569835e-03 -1.67666823e-01\n", + " -1.75320774e-01 -9.56731290e-02]\n", + " [ 5.74407633e-03 -1.51010871e-01 -1.27642184e-01 1.59654185e-01\n", + " 2.06639260e-01 -7.00415373e-02 -1.91840678e-01 -8.56086463e-02\n", + " 9.02482048e-02 7.25704432e-02]\n", + " [-6.93180412e-02 -1.96934849e-01 -6.72358871e-02 -4.99973148e-02\n", + " 1.28766835e-01 -1.10879898e-01 1.34200945e-01 3.10183968e-02\n", + " -3.74761075e-02 -1.99273914e-01]\n", + " [ 2.20759660e-01 -3.98728549e-02 1.40693069e-01 -1.15664735e-01\n", + " -2.17755169e-01 -1.78237423e-01 -1.14595190e-01 -7.12116584e-02\n", + " -3.15762796e-02 1.86491266e-01]\n", + " [-2.06223264e-01 1.11605875e-01 1.88149154e-01 1.43918453e-03\n", + " -1.39450610e-01 7.15188682e-03 5.30482270e-02 9.89372358e-02\n", + " -6.79695681e-02 -7.67354444e-02]\n", + " [-1.05491146e-01 -2.16275647e-01 7.85326734e-02 -1.69050053e-01\n", + " -1.07421041e-01 -2.30107992e-03 1.72379389e-01 1.98816836e-01\n", + " -1.62642673e-01 1.93931282e-01]\n", + " [ 2.00302720e-01 1.80637628e-01 1.94676816e-02 1.79588884e-01\n", + " 1.08642928e-01 -1.60451204e-01 -1.17858045e-01 4.20530513e-03\n", + " -1.58465564e-01 -7.36296773e-02]\n", + " [ 1.80281103e-01 1.04106739e-01 1.94734529e-01 1.71422120e-03\n", + " -1.14017285e-01 1.47993699e-01 1.64847951e-02 3.76562215e-02\n", + " -9.47417393e-02 9.18511599e-02]\n", + " [-1.65143967e-01 1.78432971e-01 1.95620790e-01 8.06822702e-02\n", + " 1.74128443e-01 1.35722205e-01 -8.53993148e-02 -1.93941638e-01\n", + " 2.94244476e-02 1.40397370e-01]\n", + " [-2.28753053e-02 1.88145563e-02 1.65735826e-01 9.23255607e-02\n", + " 1.67166159e-01 3.28338295e-02 2.50651501e-02 -5.34861833e-02\n", + " -3.77333388e-02 -1.18839331e-01]\n", + " [ 1.49498299e-01 2.03940362e-01 8.29838291e-02 6.35351241e-03\n", + " -7.38137364e-02 -2.20774114e-01 -4.14042696e-02 -1.58739850e-01\n", + " -1.65080443e-01 -4.42778133e-02]\n", + " [-4.39881422e-02 4.51072417e-02 -1.62074581e-01 1.60696968e-01\n", + " -2.03583151e-01 -1.05898663e-01 -8.48927200e-02 1.37860607e-02\n", + " 9.24347416e-02 -5.89275286e-02]\n", + " [ 3.48980725e-02 -5.29355779e-02 -8.79468024e-02 -3.12774107e-02\n", + " 4.50214110e-02 -2.17200696e-01 -1.55640006e-01 1.74693078e-01\n", + " 1.01111621e-01 -5.97870257e-03]\n", + " [ 7.06157601e-03 3.08655780e-02 5.19711897e-02 -1.52664930e-01\n", + " -6.09524250e-02 -2.05220923e-01 -1.75796479e-01 -4.20728028e-02\n", + " -2.95243543e-02 1.74893185e-01]] , 10\n", + "-------------------------\n", + "0.layers.0.0.output_gate_params.hidden_weight.weight\n", + "(20, 20)\n", + "[[ 0.03851524 -0.03625689 -0.00619491 
0.12488268 -0.06773603 -0.0418019\n", + " -0.04485707 -0.18031046 -0.03125188 -0.20671144 -0.12019279 -0.14232881\n", + " 0.16657048 -0.20598304 0.21545227 0.08384079 -0.15111198 0.18525589\n", + " -0.0492739 -0.18939163]\n", + " [-0.03105276 0.11050874 -0.21741039 -0.01675669 0.09098183 -0.08714523\n", + " 0.02036562 -0.0876366 -0.15001732 0.17511557 -0.1587715 -0.00262151\n", + " 0.07447443 -0.12496222 0.10796666 -0.18569624 0.21355589 0.09958527\n", + " -0.03165689 -0.18600492]\n", + " [ 0.00689578 0.0793154 -0.12144296 -0.02816021 -0.22284126 -0.22354037\n", + " -0.02428471 0.187102 -0.01052416 0.07010341 -0.08937916 -0.07301357\n", + " -0.02457852 -0.11304034 0.13682817 0.13944101 -0.17383203 0.06858449\n", + " -0.09237309 -0.12858376]\n", + " [-0.02727968 -0.0693544 -0.12731954 0.03295429 0.12762886 -0.03450404\n", + " -0.01564156 0.01682661 -0.09610138 0.11838 0.2063172 -0.02043679\n", + " 0.01520035 0.18016809 0.18314716 -0.16634111 -0.10355289 -0.21934243\n", + " 0.13695723 0.17452586]\n", + " [-0.08138426 0.07172713 0.05416519 -0.19238184 0.0892937 0.10971964\n", + " 0.00491766 0.02293088 0.05196048 0.16108814 0.19757238 0.03213832\n", + " 0.09531388 -0.05850127 0.13331535 -0.08795608 -0.18431664 0.1049106\n", + " 0.08293276 0.0492176 ]\n", + " [ 0.09513766 0.02660845 0.0761021 0.09111597 -0.12062387 -0.01198089\n", + " 0.03369791 -0.03394864 -0.188005 0.02121117 0.13665509 -0.11958458\n", + " 0.21953909 0.0509951 0.09510146 -0.08634473 -0.18291326 -0.08321758\n", + " 0.00683159 -0.10189173]\n", + " [ 0.19913672 -0.14311586 -0.15060481 -0.0793146 0.20060927 -0.10224532\n", + " 0.20686573 0.10745841 -0.03397548 0.11565119 0.10630453 -0.11381406\n", + " -0.04603498 0.21659105 0.12819836 -0.10921414 -0.0601254 0.12532982\n", + " 0.11351746 0.01772486]\n", + " [-0.14387828 -0.16492477 -0.04719649 0.08221286 -0.02383876 -0.18695372\n", + " -0.05480145 0.22319667 -0.18481532 -0.17354017 0.14056584 0.22249034\n", + " -0.21510145 -0.20223859 -0.06991865 0.22294378 -0.1269095 0.01911828\n", + " 0.18253623 -0.0791588 ]\n", + " [-0.06857247 -0.15009233 0.0085855 0.20870976 0.0914357 0.157171\n", + " -0.01481424 -0.03551737 -0.03994827 0.12753342 -0.02932107 -0.19100396\n", + " -0.07851914 0.08750965 0.21801063 -0.04065894 -0.19468635 -0.16464569\n", + " -0.1759353 0.09013668]\n", + " [ 0.16482699 0.06612828 0.07709847 0.14567545 0.15288451 0.13352284\n", + " 0.12504087 0.06050573 0.11541758 -0.1534312 -0.14473058 0.06013739\n", + " 0.03479816 -0.19657765 -0.16289718 -0.17800786 0.17759389 0.14619377\n", + " -0.11769552 0.033738 ]\n", + " [-0.05143119 0.19438726 -0.20252845 -0.16313015 -0.18616724 0.13013433\n", + " -0.11177826 0.13318242 0.07558636 -0.10929734 -0.06023749 -0.09048979\n", + " 0.09864956 -0.08967353 0.07588523 0.01597441 -0.17857382 -0.1405619\n", + " -0.1550431 0.1171688 ]\n", + " [ 0.0484514 -0.00562237 -0.1331447 -0.22155127 -0.07913139 -0.17113578\n", + " -0.22241357 -0.21326728 -0.14605871 -0.21737726 0.069704 0.08366753\n", + " 0.0901287 -0.22259942 0.13826938 0.04359518 0.11433873 -0.05495736\n", + " 0.10737925 -0.21207204]\n", + " [ 0.0761621 0.17731208 0.09399657 -0.21077465 -0.06277167 -0.02776839\n", + " 0.11715963 -0.08461329 0.03216063 -0.07849736 -0.03552182 -0.00445118\n", + " -0.1283987 -0.15520401 0.1845957 0.18787426 -0.00676964 0.19354711\n", + " 0.17230819 -0.14084579]\n", + " [-0.08885217 -0.15358365 0.07229424 0.00565505 -0.03066478 0.16602065\n", + " -0.08740129 -0.12237797 -0.15895672 -0.11375529 0.21551864 -0.10871551\n", + " -0.06152614 0.10078279 
-0.17173737 -0.13572007 0.16457646 -0.08576282\n", + " -0.1160312 -0.02892987]\n", + " [-0.03186222 0.04086494 0.08197901 -0.17241116 0.2032053 -0.21259488\n", + " 0.07573222 -0.06309208 -0.09442816 0.20916638 -0.2154794 0.01527144\n", + " 0.1432838 0.19990316 -0.18904059 0.02694101 0.22123207 -0.21902935\n", + " 0.0546164 -0.14010552]\n", + " [ 0.03629959 -0.20227122 0.11001531 -0.04960475 0.13363701 -0.0033625\n", + " -0.03187283 -0.05428797 -0.2047436 -0.09497944 0.00742607 -0.1729926\n", + " 0.19623755 -0.14542621 -0.08711543 -0.02990268 -0.1811355 -0.00176668\n", + " -0.10767633 -0.1871676 ]\n", + " [ 0.00548474 0.19795649 0.05506302 0.18442854 -0.0021867 -0.07804751\n", + " 0.1802177 -0.11907462 -0.20685978 0.0489392 0.11143997 -0.13366425\n", + " 0.07870162 -0.07933193 -0.02713096 -0.04951058 -0.04782786 -0.18194063\n", + " 0.05480235 -0.05881837]\n", + " [ 0.17097771 0.03732251 -0.18287036 -0.17010981 -0.11653572 0.10708019\n", + " -0.14437075 -0.10229405 0.04059571 -0.15502611 -0.11010965 0.20276332\n", + " -0.11821949 -0.07449946 0.1599237 0.05010674 0.17550889 -0.19699533\n", + " 0.11176885 -0.03420243]\n", + " [-0.14325288 -0.09576999 -0.21628909 0.15468563 -0.04290593 -0.2192564\n", + " 0.19123225 0.14483131 0.09245753 0.21885075 0.20192903 0.20897363\n", + " 0.2002456 0.18172018 0.05853782 -0.01872608 0.00850361 -0.09292599\n", + " 0.10506337 0.00647802]\n", + " [ 0.05275466 -0.14403579 -0.08419433 0.16763861 0.02174832 0.07716487\n", + " -0.1952104 -0.09575427 -0.00569092 -0.0234643 0.14273825 -0.06748112\n", + " 0.18662164 -0.04324729 0.08697162 -0.15742545 0.03795354 -0.21800253\n", + " -0.19185208 -0.14310952]] , 11\n", + "-------------------------\n", + "/0/layers.0.0/output_quant/export_handler/Constant_output_0\n", + "()\n", + "0.0078125 , 12\n", + "-------------------------\n", + "/0/layers.0.0/output_quant/export_handler/Constant_1_output_0\n", + "()\n", + "0 , 13\n", + "-------------------------\n", + "/0/layers.0.0/output_quant/export_handler/Constant_2_output_0\n", + "()\n", + "8.0 , 14\n", + "-------------------------\n", + "/0/layers.0.0/input_weight/weight_quant/export_handler/Constant_output_0\n", + "()\n", + "0.001760039 , 15\n", + "-------------------------\n", + "/0/layers.0.0/input_weight/weight_quant/export_handler/Constant_1_output_0\n", + "()\n", + "-127 , 16\n", + "-------------------------\n", + "/0/layers.0.0/input_weight/weight_quant/export_handler/Constant_2_output_0\n", + "()\n", + "127 , 17\n", + "-------------------------\n", + "/0/layers.0.0/input_weight/weight_quant/export_handler_1/Constant_output_0\n", + "()\n", + "0.0017542557 , 18\n", + "-------------------------\n", + "/0/layers.0.0/input_weight/weight_quant/export_handler_2/Constant_output_0\n", + "()\n", + "0.0017601603 , 19\n", + "-------------------------\n", + "/0/layers.0.0/input_weight/weight_quant/export_handler_3/Constant_output_0\n", + "()\n", + "0.0017546351 , 20\n", + "-------------------------\n", + "onnx.brevitas::QuantLSTMCell_48\n", + "(1, 20)\n", + "[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 
0.]] , 21\n",
+ "-------------------------\n",
+ "/0/layers.0.0/export_handler/Constant_output_0\n",
+ "()\n",
+ "0.003921569 , 22\n",
+ "-------------------------\n",
+ "/0/layers.0.0/export_handler/Constant_1_output_0\n",
+ "()\n",
+ "0 , 23\n",
+ "-------------------------\n",
+ "/0/layers.0.0/Constant_output_0\n",
+ "(1,)\n",
+ "[0] , 24\n",
+ "-------------------------\n",
+ "/0/layers.0.0/Constant_1_output_0\n",
+ "(1,)\n",
+ "[0] , 25\n",
+ "-------------------------\n"
+ ]
+ }
+ ],
+ "source": [
+ "# In this block of code we store all the parameters (weight matrices, recurrence matrices, biases, scales and zero-points) that we will need to import into the QCDQ implementation.\n",
+ "# Importing the exported quantized model from brevitas\n",
+ "brevitas_lstm_export = onnx.load(\"./quant_lstm_quantization_qcdq.onnx\")\n",
+ "parameters = brevitas_lstm_export.graph.initializer # Extracting all the parameters from the loaded graph\n",
+ "\n",
+ "# In this loop we print all the parameters so that each parameter value can be matched to the right variable\n",
+ "for i in range(len(parameters)):\n",
+ "    w = numpy_helper.to_array(parameters[i])\n",
+ "    print(brevitas_lstm_export.graph.initializer[i].name)\n",
+ "    print(w.shape)\n",
+ "    print(w,',',i)\n",
+ "    print(\"-------------------------\")\n",
+ " \n",
+ "# Storing the extracted parameters (weights/biases/scales) in the right variables, based on the order in which they are exported. \n",
+ "# The abbreviations described in the above block will help in understanding what each variable denotes\n",
+ "\n",
+ "bi_val = numpy_helper.to_array(parameters[0])\n",
+ "Wi_val = numpy_helper.to_array(parameters[1])\n",
+ "Ui_val = numpy_helper.to_array(parameters[2])\n",
+ "bf_val = numpy_helper.to_array(parameters[3])\n",
+ "Wf_val = numpy_helper.to_array(parameters[4])\n",
+ "Uf_val = numpy_helper.to_array(parameters[5])\n",
+ "bc_val = numpy_helper.to_array(parameters[6])\n",
+ "Wc_val = numpy_helper.to_array(parameters[7])\n",
+ "Uc_val = numpy_helper.to_array(parameters[8])\n",
+ "bo_val = numpy_helper.to_array(parameters[9])\n",
+ "Wo_val = numpy_helper.to_array(parameters[10])\n",
+ "Uo_val = numpy_helper.to_array(parameters[11])\n",
+ "# Scalar values can either be int or float\n",
+ "inp_scale_val = float(numpy_helper.to_array(parameters[12])) \n",
+ "w1_scale_val = float(numpy_helper.to_array(parameters[15]))\n",
+ "w2_scale_val = float(numpy_helper.to_array(parameters[18]))\n",
+ "w3_scale_val = float(numpy_helper.to_array(parameters[19]))\n",
+ "w4_scale_val = float(numpy_helper.to_array(parameters[20]))\n",
+ "eq_scale_val_1 = float(numpy_helper.to_array(parameters[12]))\n",
+ "eq_scale_val_2 = float(numpy_helper.to_array(parameters[22]))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "10237589-f84e-423a-829e-3e2c2e806ed7",
+ "metadata": {},
+ "source": [
+ "# LSTM ONNX model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "367547b8",
+ "metadata": {},
+ "source": [
+ "In the 3rd part of the notebook, we will construct the `QCDQ-LSTM` model with standard ONNX operators. 
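Before that, as a reference for intuition, the next cell sketches one floating-point LSTM step with the matrices extracted above: a plain NumPy version of the part-1 equations, with no quantization. The helper names `sigmoid` and `lstm_step_float` are ours, purely for illustration; this cell is not part of the exported model.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "0aa1b2c3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# A minimal float32 sketch of one (unquantized) LSTM step, using the\n",
+ "# matrices and biases extracted above. The biases are 1-D as printed,\n",
+ "# so they are reshaped into column vectors to match the (20,1) states.\n",
+ "import numpy as np\n",
+ "\n",
+ "def sigmoid(x):\n",
+ "    return 1.0 / (1.0 + np.exp(-x))\n",
+ "\n",
+ "def lstm_step_float(x, h_prev, c_prev):\n",
+ "    f = sigmoid(Wf_val @ x + Uf_val @ h_prev + bf_val.reshape(-1, 1))\n",
+ "    i = sigmoid(Wi_val @ x + Ui_val @ h_prev + bi_val.reshape(-1, 1))\n",
+ "    c_tilde = np.tanh(Wc_val @ x + Uc_val @ h_prev + bc_val.reshape(-1, 1))\n",
+ "    o = sigmoid(Wo_val @ x + Uo_val @ h_prev + bo_val.reshape(-1, 1))\n",
+ "    c = f * c_prev + i * c_tilde\n",
+ "    h = np.tanh(c) * o\n",
+ "    return h, c\n",
+ "\n",
+ "# One step with a constant input and zero-initialized states\n",
+ "h, c = lstm_step_float(np.full((10, 1), 0.8, dtype=np.float32), np.zeros((20, 1)), np.zeros((20, 1)))\n",
+ "print(h.shape, c.shape)  # (20, 1) (20, 1)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0bb2c3d4",
+ "metadata": {},
+ "source": [
+ "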
After loading all the parameters in the above block we can now start building our ONNX model with QCDQ quantization to represent the LSTM computations described in part-1.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "02fe4d94-af24-4d5e-a809-7d8c49e7fd90",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Setting parameters : Matching the input/output lengths exported from brevitas\n",
+ "num_features = 10\n",
+ "num_hidden_cells = 20\n",
+ "activation_bit_width = 8\n",
+ "\n",
+ "# The below two parameters are for the 'Clip' operation. \n",
+ "# Clip node parameters\n",
+ "max_clip_val = (2 ** (activation_bit_width -1) - 1)\n",
+ "min_clip_val = -(2 ** (activation_bit_width -1) - 1)\n",
+ "\n",
+ "# The zero-point datatype decides the datatype of the output tensor of a quantization operation, hence we define two: one signed and one unsigned.\n",
+ "# Zero point values for quantization\n",
+ "zero_point_signed_val = 0\n",
+ "zero_point_unsigned_val = 0"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "15098a9e-4187-4987-82cc-275eba650923",
+ "metadata": {},
+ "source": [
+ "`Abbreviations` : These describe different short-forms used in the next two blocks.\n",
+ "\n",
+ "* ql = \"QuantizeLinear\"\n",
+ "* dql = \"DequantizeLinear\"\n",
+ "* clp = \"Clip\"\n",
+ "* id = \"Identity\"\n",
+ "* matmul = \"Matrix Multiplication\"\n",
+ "* el_mul = \"Elementwise Multiplication\"\n",
+ "* sig = \"Sigmoid\"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f2edc0cc",
+ "metadata": {},
+ "source": [
+ "We start defining the model by declaring the `inputs` and `outputs` as value_info tensors in ONNX.\n",
+ "For LSTMs we need three inputs : `inputs`, `previous hidden state` and `previous cell state`. \n",
+ "We get three outputs : `hidden_state`, `cell_state` and `concatenated_hidden_states`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "02761646-4c6d-440f-8e90-4935beebab56",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Defining the inputs 'value info' tensors for the compute graph.\n",
+ "hidden_state = make_tensor_value_info(\"h_t-1\",onnx.TensorProto.FLOAT, [num_hidden_cells,1])\n",
+ "cell_state = make_tensor_value_info(\"c_t-1\", onnx.TensorProto.FLOAT, [num_hidden_cells,1])\n",
+ "inputs = make_tensor_value_info(\"inp\",onnx.TensorProto.FLOAT, [num_features,1])\n",
+ "\n",
+ "# Output value info tensor definitions\n",
+ "out_hidden_state = make_tensor_value_info(\"h_t\", onnx.TensorProto.FLOAT, [num_hidden_cells,1])\n",
+ "out_cell_state = make_tensor_value_info(\"c_t\", onnx.TensorProto.FLOAT, [num_hidden_cells,1])\n",
+ "out_hidden_state_concat = make_tensor_value_info(\"h_t_concat\", onnx.TensorProto.FLOAT, [num_hidden_cells,1]) # Gains a leading sequence dimension once the graph is wrapped in the Scan operator"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "c08e5a23-ef2e-4bca-9293-c800350c2c62",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Once we have defined the inputs and outputs, we will now start defining the operations in the LSTM compute graph.\n",
+ "# We start by quantizing the input with the standard QDQ operation which is 8-bit quantization. 
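Numerically, QuantizeLinear computes q = saturate(round(x / scale) + zero_point) and DequantizeLinear recovers x_hat = scale * (q - zero_point).\n",
+ "# With zero_point = 0 the dequantized tensor is simply scale * q, i.e. float values restricted to a quantized grid.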
\n", + "# Note: For quantization to lower bit-width's we can use the clip node.\n", + "\n", + "# Input quantization\n", + "ql_input = make_node(\"QuantizeLinear\", inputs=[\"inp\",\"inp_scale\",\"zero_point_signed\"], outputs=[\"ql_input_out\"],name=\"ql_input\")\n", + "id_0_input = make_node(\"Identity\", inputs=[\"ql_input_out\"], outputs=[\"first_input_out\"], name=\"id_0_input\")\n", + "dql_input = make_node(\"DequantizeLinear\", inputs=[\"ql_input_out\", 'inp_scale', \"zero_point_signed\"], outputs=[\"dql_input_out\"],name=\"dql_input\")\n", + "\n", + "# Quantization of the four weight matrices showing QCDQ quantization\n", + "ql_w1 = make_node(\"QuantizeLinear\", inputs=[\"W_f\",\"scale_f\",\"zero_point_signed\"], outputs=[\"ql_wf_out\"], name=\"ql_w1\")\n", + "clp_w1 = make_node(\"Clip\", inputs=[\"ql_wf_out\",\"min\",\"max\"], outputs=[\"clp_wf\"], name=\"clp_w1\")\n", + "dql_w1 = make_node(\"DequantizeLinear\", inputs=[\"clp_wf\",\"scale_f\",\"zero_point_signed\"], outputs=[\"dql_wf_out\"], name=\"dql_w1\")\n", + "\n", + "ql_w2 = make_node(\"QuantizeLinear\", inputs=[\"W_i\",\"scale_i\",\"zero_point_signed\"], outputs=[\"ql_wi_out\"], name=\"ql_w2\")\n", + "clp_w2 = make_node(\"Clip\", inputs=[\"ql_wi_out\",\"min\",\"max\"], outputs=[\"clp_wi\"], name=\"clp_w2\")\n", + "dql_w2 = make_node(\"DequantizeLinear\", inputs=[\"clp_wi\",\"scale_i\",\"zero_point_signed\"], outputs=[\"dql_wi_out\"], name=\"dql_w2\")\n", + "\n", + "ql_w3 = make_node(\"QuantizeLinear\", inputs=[\"W_c\",\"scale_c\",\"zero_point_signed\"], outputs=[\"ql_wc_out\"], name=\"ql_w3\")\n", + "clp_w3 = make_node(\"Clip\", inputs=[\"ql_wc_out\",\"min\",\"max\"], outputs=[\"clp_wc\"], name=\"clp_w3\")\n", + "dql_w3 = make_node(\"DequantizeLinear\", inputs=[\"clp_wc\",\"scale_c\",\"zero_point_signed\"], outputs=[\"dql_wc_out\"], name=\"dql_w3\")\n", + "\n", + "ql_w4 = make_node(\"QuantizeLinear\", inputs=[\"W_o\",\"scale_o\",\"zero_point_signed\"], outputs=[\"ql_wo_out\"], name=\"ql_w4\")\n", + "clp_w4 = make_node(\"Clip\", inputs=[\"ql_wo_out\",\"min\",\"max\"], outputs=[\"clp_wo\"], name=\"clp_w4\")\n", + "dql_w4 = make_node(\"DequantizeLinear\", inputs=[\"clp_wo\",\"scale_o\",\"zero_point_signed\"], outputs=[\"dql_wo_out\"], name=\"dql_w4\")\n", + "\n", + "# Quantizations for the four recurrence weight matrices showing QCDQ quantization\n", + "ql_u1 = make_node(\"QuantizeLinear\", inputs=[\"U_f\",\"scale_f\",\"zero_point_signed\"], outputs=[\"ql_uf_out\"], name=\"ql_u1\")\n", + "clp_u1 = make_node(\"Clip\", inputs=[\"ql_uf_out\",\"min\",\"max\"], outputs=[\"clp_uf\"], name=\"clp_u1\")\n", + "dql_u1 = make_node(\"DequantizeLinear\", inputs=[\"clp_uf\",\"scale_f\",\"zero_point_signed\"], outputs=[\"dql_uf_out\"], name=\"dql_u1\")\n", + "\n", + "ql_u2 = make_node(\"QuantizeLinear\", inputs=[\"U_i\",\"scale_i\",\"zero_point_signed\"], outputs=[\"ql_ui_out\"], name=\"ql_u2\")\n", + "clp_u2 = make_node(\"Clip\", inputs=[\"ql_ui_out\",\"min\",\"max\"], outputs=[\"clp_ui\"], name=\"clp_u2\")\n", + "dql_u2 = make_node(\"DequantizeLinear\", inputs=[\"clp_ui\",\"scale_i\",\"zero_point_signed\"], outputs=[\"dql_ui_out\"], name=\"dql_u2\")\n", + "\n", + "ql_u3 = make_node(\"QuantizeLinear\", inputs=[\"U_c\",\"scale_c\",\"zero_point_signed\"], outputs=[\"ql_uc_out\"], name=\"ql_u3\")\n", + "clp_u3 = make_node(\"Clip\", inputs=[\"ql_uc_out\",\"min\",\"max\"], outputs=[\"clp_uc\"], name=\"clp_u3\")\n", + "dql_u3 = make_node(\"DequantizeLinear\", inputs=[\"clp_uc\",\"scale_c\",\"zero_point_signed\"], outputs=[\"dql_uc_out\"], 
name=\"dql_u3\")\n", + "\n", + "ql_u4 = make_node(\"QuantizeLinear\", inputs=[\"U_o\",\"scale_o\",\"zero_point_signed\"], outputs=[\"ql_uo_out\"], name=\"ql_u4\")\n", + "clp_u4 = make_node(\"Clip\", inputs=[\"ql_uo_out\",\"min\",\"max\"], outputs=[\"clp_uo\"], name=\"clp_u4\")\n", + "dql_u4 = make_node(\"DequantizeLinear\", inputs=[\"clp_uo\",\"scale_o\",\"zero_point_signed\"], outputs=[\"dql_uo_out\"], name=\"dql_u4\")\n", + "\n", + "# Once we have quantized the weights and inputs we can now start defining the operations for the 6 LSTM equations.\n", + "# The first four gate equations have a very similar compute structure. We define the first four gate computations in this order : Forget, Input, Output, Cell \n", + "\n", + "# 1st Equation : Forget gate\n", + "matmul_1_e1 = make_node(\"MatMul\", inputs=[\"dql_wf_out\",\"dql_input_out\"], outputs=[\"out_m1_e1\"], name=\"matmul_1_e1\")\n", + "matmul_2_e1 = make_node(\"MatMul\", inputs=[\"dql_uf_out\",\"h_t-1\"], outputs=[\"out_m2_e1\"],name=\"matmul_2_e1\")\n", + "add_1_e1 = make_node(\"Add\", inputs=[\"out_m1_e1\",\"out_m2_e1\"], outputs=[\"out_add1_e1\"],name=\"add_1_e1\")\n", + "add_2_e1 = make_node(\"Add\", inputs=[\"out_add1_e1\",\"b_f\"], outputs=[\"f_t_ba\"],name=\"add_2_e1\")\n", + "ql_1_e1 = make_node(\"QuantizeLinear\", inputs=[\"f_t_ba\",\"scale_3\",\"zero_point_signed\"], outputs=[\"f_t_ql1\"],name=\"ql_1_e1\")\n", + "dql_1_e1 = make_node(\"DequantizeLinear\", inputs=[\"f_t_ql1\", \"scale_4\", \"zero_point_signed\"], outputs=[\"f_t_dql1\"], name=\"dql_1_e1\")\n", + "sig_f_e1 = make_node(\"Sigmoid\", inputs=[\"f_t_dql1\"], outputs=[\"f_t\"],name=\"sig_f_e1\")\n", + "ql_2_e1 = make_node(\"QuantizeLinear\", inputs=[\"f_t\",\"scale_4\",\"zero_point_unsigned\"], outputs=[\"f_t_ql2\"],name=\"ql_2_e1\")\n", + "dql_2_e1 = make_node(\"DequantizeLinear\", inputs=[\"f_t_ql2\", \"scale_4\", \"zero_point_unsigned\"], outputs=[\"f_t_dql2\"], name=\"dql_2_e1\")\n", + "\n", + "# 2nd Equation : Input gate\n", + "matmul_1_e2 = make_node(\"MatMul\", inputs=[\"dql_wi_out\",\"dql_input_out\"], outputs=[\"out_m1_e2\"], name=\"matmul_1_e2\")\n", + "matmul_2_e2 = make_node(\"MatMul\", inputs=[\"dql_ui_out\",\"h_t-1\"], outputs=[\"out_m2_e2\"],name=\"matmul_2_e2\")\n", + "add_1_e2 = make_node(\"Add\", inputs=[\"out_m1_e2\",\"out_m2_e2\"], outputs=[\"out_add1_e2\"],name=\"add_1_e2\")\n", + "add_2_e2 = make_node(\"Add\", inputs=[\"out_add1_e2\",\"b_i\"], outputs=[\"i_t_ba\"],name=\"add_2_e2\")\n", + "ql_1_e2 = make_node(\"QuantizeLinear\", inputs=[\"i_t_ba\",\"scale_1\",\"zero_point_signed\"], outputs=[\"i_t_ql1\"],name=\"ql_1_e2\")\n", + "dql_1_e2 = make_node(\"DequantizeLinear\", inputs=[\"i_t_ql1\",\"scale_1\", \"zero_point_signed\"], outputs=[\"i_t_dql1\"], name=\"dql_1_e2\")\n", + "sig_i_e2 = make_node(\"Sigmoid\", inputs=[\"i_t_dql1\"], outputs=[\"i_t\"],name=\"sig_i_e2\")\n", + "ql_2_e2 = make_node(\"QuantizeLinear\", inputs=[\"i_t\",\"scale_2\",\"zero_point_unsigned\"], outputs=[\"i_t_ql2\"],name=\"ql_2_e2\")\n", + "dql_2_e2 = make_node(\"DequantizeLinear\", inputs=[\"i_t_ql2\", \"scale_2\", \"zero_point_unsigned\"], outputs=[\"i_t_dql2\"], name=\"dql_2_e2\")\n", + "\n", + "# 3rd Equation : Output gate\n", + "matmul_1_e3 = make_node(\"MatMul\", inputs=[\"dql_wo_out\",\"dql_input_out\"], outputs=[\"out_m1_e3\"], name=\"matmul_1_e3\")\n", + "matmul_2_e3 = make_node(\"MatMul\", inputs=[\"dql_uo_out\",\"h_t-1\"], outputs=[\"out_m2_e3\"],name=\"matmul_2_e3\")\n", + "add_1_e3 = make_node(\"Add\", inputs=[\"out_m1_e3\",\"out_m2_e3\"], 
outputs=[\"out_add1_e3\"],name=\"add_1_e3\")\n", + "add_2_e3 = make_node(\"Add\", inputs=[\"out_add1_e3\",\"b_o\"], outputs=[\"o_t_ba\"],name=\"add_2_e3\" )\n", + "ql_1_e3 = make_node(\"QuantizeLinear\", inputs=[\"o_t_ba\",\"scale_7\",\"zero_point_signed\"], outputs=[\"o_t_ql1\"],name=\"ql_1_e3\")\n", + "dql_1_e3 = make_node(\"DequantizeLinear\", inputs=[\"o_t_ql1\",\"scale_7\", \"zero_point_signed\"], outputs=[\"o_t_dql1\"], name=\"dql_1_e3\")\n", + "sig_o_e3 = make_node(\"Sigmoid\", inputs=[\"o_t_dql1\"], outputs=[\"o_t\"],name=\"sig_o_e3\")\n", + "ql_2_e3 = make_node(\"QuantizeLinear\", inputs=[\"o_t\",\"scale_8\",\"zero_point_unsigned\"], outputs=[\"o_t_ql2\"],name=\"ql_2_e3\")\n", + "dql_2_e3 = make_node(\"DequantizeLinear\", inputs=[\"o_t_ql2\", \"scale_8\", \"zero_point_unsigned\"], outputs=[\"o_t_dql2\"], name=\"dql_2_e3\")\n", + "\n", + "# 4th Equation : Cell gate\n", + "matmul_1_e4 = make_node(\"MatMul\", inputs=[\"dql_wc_out\",\"dql_input_out\"], outputs=[\"out_m1_e4\"], name=\"matmul_1_e4\")\n", + "matmul_2_e4 = make_node(\"MatMul\", inputs=[\"dql_uc_out\",\"h_t-1\"], outputs=[\"out_m2_e4\"],name=\"matmul_2_e4\")\n", + "add_1_e4 = make_node(\"Add\", inputs=[\"out_m1_e4\",\"out_m2_e4\"], outputs=[\"out_add1_e4\"],name=\"add_1_e4\")\n", + "add_2_e4 = make_node(\"Add\", inputs=[\"out_add1_e4\",\"b_c\"], outputs=[\"c_t_ba\"],name=\"add_2_e4\")\n", + "ql_1_e4 = make_node(\"QuantizeLinear\", inputs=[\"c_t_ba\",\"scale_5\",\"zero_point_signed\"], outputs=[\"c_t_ql1\"],name=\"ql_1_e4\")\n", + "dql_1_e4 = make_node(\"DequantizeLinear\", inputs=[\"c_t_ql1\",\"scale_5\", \"zero_point_signed\"], outputs=[\"c_t_dql1\"], name=\"dql_1_e4\")\n", + "tanh_c_e4 = make_node(\"Tanh\", inputs=[\"c_t_dql1\"], outputs=[\"c_t_partial\"],name=\"tanh_c_e4\")\n", + "ql_2_e4 = make_node(\"QuantizeLinear\", inputs=[\"c_t_partial\",\"scale_6\",\"zero_point_signed\"], outputs=[\"c_t_ql2\"],name=\"ql_2_e4\")\n", + "dql_2_e4 = make_node(\"DequantizeLinear\", inputs=[\"c_t_ql2\", \"scale_6\", \"zero_point_signed\"], outputs=[\"c_t_dql2\"], name=\"dql_2_e4\")\n", + "\n", + "# Once we have the first four gate computations we can procedd with the computation of the cell_state and the hidden_state in the 5th and the 6th equations.\n", + "# 5th Equation : Cell state compute\n", + "el_mul_1_e5 = make_node(\"Mul\", inputs=[\"f_t_dql2\",\"c_t-1\"], outputs=[\"out_el_mul1_e5\"],name=\"el_mul_1_e5\")\n", + "ql_1_e5 = make_node(\"QuantizeLinear\", inputs=[\"out_el_mul1_e5\",\"scale_9\",\"zero_point_signed\"], outputs=[\"fifth_ql1\"],name=\"ql_1_e5\")\n", + "dql_1_e5 = make_node(\"DequantizeLinear\", inputs=[\"fifth_ql1\",\"scale_9\", \"zero_point_signed\"], outputs=[\"fifth_dql1\"], name=\"dql_1_e5\")\n", + "el_mul_2_e5 = make_node(\"Mul\", inputs=[\"i_t_dql2\",\"c_t_dql2\"], outputs=[\"out_el_mul2_e5\"], name=\"el_mul_2_e5\") \n", + "ql_2_e5 = make_node(\"QuantizeLinear\", inputs=[\"out_el_mul2_e5\",\"scale_9\",\"zero_point_signed\"], outputs=[\"fifth_ql2\"],name=\"ql_2_e5\")\n", + "dql_2_e5 = make_node(\"DequantizeLinear\", inputs=[\"fifth_ql2\",\"scale_9\", \"zero_point_signed\"], outputs=[\"fifth_dql2\"], name=\"dql_2_e5\")\n", + "add_1_e5 = make_node(\"Add\", inputs=[\"fifth_dql1\",\"fifth_dql2\"], outputs=[\"c_t\"], name=\"add_1_e5\") #-----------------> The first output is computed here.\n", + "ql_3_e5 = make_node(\"QuantizeLinear\", inputs=[\"c_t\",\"scale_9\",\"zero_point_signed\"], outputs=[\"h_t_ql\"], name=\"ql_3_e5\")\n", + "dql_3_e5 = make_node(\"DequantizeLinear\", 
inputs=[\"h_t_ql\",\"scale_9\",\"zero_point_signed\"], outputs=[\"h_t_dql\"], name=\"dql_3_e5\")\n", + "\n", + "# 6th Equation : Hidden state compute\n", + "tanh_node_e6 = make_node(\"Tanh\", inputs=[\"h_t_dql\"], outputs=[\"out_tanh_e6\"], name=\"tanh_node_e6\") \n", + "ql_1_e6 = make_node(\"QuantizeLinear\", inputs=[\"out_tanh_e6\",\"scale_10\",\"zero_point_signed\"], outputs=[\"sixth_ql1\"], name=\"ql_1_e6\")\n", + "dql_1_e6 = make_node(\"DequantizeLinear\", inputs=[\"sixth_ql1\",\"scale_10\",\"zero_point_signed\"], outputs=[\"sixth_dql1\"], name=\"dql_1_e6\")\n", + "el_mul_1_e6 = make_node(\"Mul\", inputs=[\"sixth_dql1\",\"o_t_dql2\"], outputs=[\"h_t_inter\"], name=\"el_mul_1_e6\")#h_t_inter\n", + "ql_2_e6 = make_node(\"QuantizeLinear\", inputs=[\"h_t_inter\",\"scale_11\",\"zero_point_signed\"], outputs=[\"sixth_ql2\"], name=\"ql_2_e6\")\n", + "dql_2_e6 = make_node(\"DequantizeLinear\", inputs=[\"sixth_ql2\",\"scale_11\",\"zero_point_signed\"], outputs=[\"h_t\"], name=\"dql_2_e6\") #-----------------> The second output is computed here.\n", + "id_1_e6 = make_node(\"Identity\", inputs=[\"h_t\"], outputs=[\"h_t_concat\"], name=\"id_1_e6\") #-----------------> The third output is computed here." + ] + }, + { + "cell_type": "markdown", + "id": "3d10867f", + "metadata": {}, + "source": [ + "After defining the above operations we now connect them and create a graph with the help of onnx.helper `make_graph` utility function" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "79839558-8752-4fc8-9b0e-8fed47c91701", + "metadata": {}, + "outputs": [], + "source": [ + "lstm_body = make_graph(\n", + " nodes=[\n", + " ql_input,\n", + " dql_input, \n", + " ql_w1,\n", + " clp_w1, \n", + " dql_w1,\n", + " ql_w2,\n", + " clp_w2, \n", + " dql_w2,\n", + " ql_w3,\n", + " clp_w3, \n", + " dql_w3,\n", + " ql_w4,\n", + " clp_w4, \n", + " dql_w4,\n", + " ql_u1,\n", + " clp_u1, \n", + " dql_u1,\n", + " ql_u2,\n", + " clp_u2,\n", + " dql_u2, \n", + " ql_u3,\n", + " clp_u3,\n", + " dql_u3, \n", + " ql_u4,\n", + " clp_u4,\n", + " dql_u4, \n", + " matmul_1_e1,\n", + " matmul_2_e1, \n", + " add_1_e1, \n", + " add_2_e1,\n", + " ql_1_e1,\n", + " dql_1_e1,\n", + " sig_f_e1,\n", + " ql_2_e1, \n", + " dql_2_e1, \n", + " matmul_1_e2,\n", + " matmul_2_e2, \n", + " add_1_e2, \n", + " add_2_e2,\n", + " ql_1_e2,\n", + " dql_1_e2,\n", + " sig_i_e2,\n", + " ql_2_e2, \n", + " dql_2_e2, \n", + " matmul_1_e3,\n", + " matmul_2_e3, \n", + " add_1_e3, \n", + " add_2_e3,\n", + " ql_1_e3,\n", + " dql_1_e3,\n", + " sig_o_e3,\n", + " ql_2_e3, \n", + " dql_2_e3, \n", + " matmul_1_e4,\n", + " matmul_2_e4, \n", + " add_1_e4, \n", + " add_2_e4,\n", + " ql_1_e4,\n", + " dql_1_e4,\n", + " tanh_c_e4,\n", + " ql_2_e4, \n", + " dql_2_e4, \n", + " el_mul_1_e5,\n", + " ql_1_e5, \n", + " dql_1_e5,\n", + " el_mul_2_e5,\n", + " ql_2_e5,\n", + " dql_2_e5,\n", + " add_1_e5,\n", + " ql_3_e5, \n", + " dql_3_e5,\n", + " tanh_node_e6,\n", + " ql_1_e6, \n", + " dql_1_e6,\n", + " el_mul_1_e6,\n", + " ql_2_e6,\n", + " dql_2_e6, \n", + " id_1_e6\n", + " ],\n", + " name = \"qcdq-lsmt-body\",\n", + " inputs=[hidden_state,cell_state,inputs], #The order in which the inputs are defined here should match the input order when the scan node is defined.\n", + " outputs = [out_hidden_state, out_cell_state, out_hidden_state_concat],\n", + " value_info=[\n", + " make_tensor_value_info(\"ql_input_out\",onnx.TensorProto.INT8, [num_features,1]),\n", + " make_tensor_value_info(\"dql_input_out\",onnx.TensorProto.FLOAT, [num_features,1]),\n", + " 
make_tensor_value_info(\"out_m1_e1\",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"out_m2_e1\",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"out_add1_e1\",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"f_t_ba\",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"f_t_ql1\",onnx.TensorProto.INT8, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"f_t_dql1\", onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"f_t_ql2\",onnx.TensorProto.UINT8, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"f_t_dql2\", onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"out_m1_e2\",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"out_m2_e2\",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"out_add1_e2\",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"i_t_ba\",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"i_t_ql1\",onnx.TensorProto.INT8, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"i_t_dql1\", onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"i_t_ql2\",onnx.TensorProto.UINT8, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"i_t_dql2\", onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"out_m1_e3\",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"out_m2_e3\",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"out_add1_e3\",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"o_t_ba\",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"o_t_ql1\",onnx.TensorProto.INT8, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"o_t_dql1\", onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"o_t_ql2\",onnx.TensorProto.UINT8, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"o_t_dql2\", onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"out_m1_e4\",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"out_m2_e4\",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"out_add1_e4\",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"c_t_ba\",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"c_t_ql1\",onnx.TensorProto.INT8, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"c_t_dql1\", onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"c_t_ql2\",onnx.TensorProto.INT8, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"c_t_dql2\", onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"f_t\",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"i_t\",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"o_t\",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"c_t_partial\",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"out_el_mul1_e5\",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"out_el_mul2_e5\",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"fifth_ql1\",onnx.TensorProto.INT8, [num_hidden_cells,1]),\n", + 
" make_tensor_value_info(\"fifth_dql1\", onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"fifth_ql2\",onnx.TensorProto.INT8, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"fifth_dql2\", onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"h_t_ql\",onnx.TensorProto.INT8, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"h_t_dql\", onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"out_tanh_e6\",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"sixth_ql1\",onnx.TensorProto.INT8, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"sixth_dql1\", onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"sixth_ql2\",onnx.TensorProto.INT8, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"h_t_inter\", onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"ql_wf_out\", onnx.TensorProto.INT8, [num_hidden_cells,num_features]),\n", + " make_tensor_value_info(\"dql_wf_out\",onnx.TensorProto.FLOAT, [num_hidden_cells,num_features]),\n", + " make_tensor_value_info(\"ql_wi_out\", onnx.TensorProto.INT8, [num_hidden_cells,num_features]),\n", + " make_tensor_value_info(\"dql_wi_out\",onnx.TensorProto.FLOAT, [num_hidden_cells,num_features]),\n", + " make_tensor_value_info(\"ql_wc_out\", onnx.TensorProto.INT8, [num_hidden_cells,num_features]),\n", + " make_tensor_value_info(\"dql_wc_out\",onnx.TensorProto.FLOAT, [num_hidden_cells,num_features]),\n", + " make_tensor_value_info(\"ql_wo_out\", onnx.TensorProto.INT8, [num_hidden_cells,num_features]),\n", + " make_tensor_value_info(\"dql_wo_out\",onnx.TensorProto.FLOAT, [num_hidden_cells,num_features]),\n", + " make_tensor_value_info(\"ql_uf_out\",onnx.TensorProto.INT8, [num_hidden_cells,num_hidden_cells]),\n", + " make_tensor_value_info(\"dql_uf_out\",onnx.TensorProto.FLOAT, [num_hidden_cells,num_hidden_cells]),\n", + " make_tensor_value_info(\"ql_ui_out\",onnx.TensorProto.INT8, [num_hidden_cells,num_hidden_cells]),\n", + " make_tensor_value_info(\"dql_ui_out\",onnx.TensorProto.FLOAT, [num_hidden_cells,num_hidden_cells]),\n", + " make_tensor_value_info(\"ql_uc_out\",onnx.TensorProto.INT8, [num_hidden_cells,num_hidden_cells]),\n", + " make_tensor_value_info(\"dql_uc_out\",onnx.TensorProto.FLOAT, [num_hidden_cells,num_hidden_cells]),\n", + " make_tensor_value_info(\"ql_uo_out\",onnx.TensorProto.INT8, [num_hidden_cells,num_hidden_cells]),\n", + " make_tensor_value_info(\"dql_uo_out\",onnx.TensorProto.FLOAT, [num_hidden_cells,num_hidden_cells]),\n", + " make_tensor_value_info(\"clp_wf\",onnx.TensorProto.INT8, [num_hidden_cells,num_features]),\n", + " make_tensor_value_info(\"clp_wi\",onnx.TensorProto.INT8, [num_hidden_cells,num_features]),\n", + " make_tensor_value_info(\"clp_wc\",onnx.TensorProto.INT8, [num_hidden_cells,num_features]),\n", + " make_tensor_value_info(\"clp_wo\",onnx.TensorProto.INT8, [num_hidden_cells,num_features]),\n", + " make_tensor_value_info(\"clp_uf\",onnx.TensorProto.INT8, [num_hidden_cells,num_hidden_cells]), \n", + " make_tensor_value_info(\"clp_ui\",onnx.TensorProto.INT8, [num_hidden_cells,num_hidden_cells]),\n", + " make_tensor_value_info(\"clp_uc\",onnx.TensorProto.INT8, [num_hidden_cells,num_hidden_cells]),\n", + " make_tensor_value_info(\"clp_uo\",onnx.TensorProto.INT8, [num_hidden_cells,num_hidden_cells]),\n", + " ],\n", + " initializer=[\n", + " # Initializing the weight and recurrecne matrices\n", + " make_tensor('W_f',onnx.TensorProto.FLOAT, 
[num_hidden_cells,num_features], (Wf_val)),\n", + " make_tensor('U_f',onnx.TensorProto.FLOAT, [num_hidden_cells,num_hidden_cells], (Uf_val)),\n", + " make_tensor('b_f',onnx.TensorProto.FLOAT, [num_hidden_cells,1], (bf_val)),\n", + " make_tensor('W_i',onnx.TensorProto.FLOAT, [num_hidden_cells,num_features], (Wi_val)),\n", + " make_tensor('U_i',onnx.TensorProto.FLOAT, [num_hidden_cells,num_hidden_cells], (Ui_val)),\n", + " make_tensor('b_i',onnx.TensorProto.FLOAT, [num_hidden_cells,1], (bi_val)),\n", + " make_tensor('W_o',onnx.TensorProto.FLOAT, [num_hidden_cells,num_features], (Wo_val)),\n", + " make_tensor('U_o',onnx.TensorProto.FLOAT, [num_hidden_cells,num_hidden_cells], (Uo_val)),\n", + " make_tensor('b_o',onnx.TensorProto.FLOAT, [num_hidden_cells,1], (bo_val)),\n", + " make_tensor('W_c',onnx.TensorProto.FLOAT, [num_hidden_cells,num_features], (Wc_val)),\n", + " make_tensor('U_c',onnx.TensorProto.FLOAT, [num_hidden_cells,num_hidden_cells], (Uc_val)),\n", + " make_tensor('b_c',onnx.TensorProto.FLOAT, [num_hidden_cells,1], (bc_val)),\n", + " # Input scale value\n", + " make_tensor('inp_scale',onnx.TensorProto.FLOAT, [],[inp_scale_val]),\n", + " # Scale weight values\n", + " make_tensor('scale_i',onnx.TensorProto.FLOAT, [],[w1_scale_val]),\n", + " make_tensor('scale_c',onnx.TensorProto.FLOAT, [],[w2_scale_val]),\n", + " make_tensor('scale_o',onnx.TensorProto.FLOAT, [],[w3_scale_val]),\n", + " make_tensor('scale_f',onnx.TensorProto.FLOAT, [],[w4_scale_val]),\n", + " # Scale values for the six equations\n", + " make_tensor('scale_1',onnx.TensorProto.FLOAT, [],[eq_scale_val_1]),\n", + " make_tensor('scale_2',onnx.TensorProto.FLOAT, [],[eq_scale_val_1]), \n", + " make_tensor('scale_3',onnx.TensorProto.FLOAT, [],[eq_scale_val_1]),\n", + " make_tensor('scale_test',onnx.TensorProto.FLOAT, [],[eq_scale_val_1]),\n", + " make_tensor('scale_4',onnx.TensorProto.FLOAT, [],[eq_scale_val_1]),\n", + " make_tensor('scale_5',onnx.TensorProto.FLOAT, [],[eq_scale_val_1]),\n", + " make_tensor('scale_6',onnx.TensorProto.FLOAT, [],[eq_scale_val_1]),\n", + " make_tensor('scale_7',onnx.TensorProto.FLOAT, [],[eq_scale_val_2]), \n", + " make_tensor('scale_8',onnx.TensorProto.FLOAT, [],[eq_scale_val_2]),\n", + " make_tensor('scale_9',onnx.TensorProto.FLOAT, [],[eq_scale_val_1]),\n", + " make_tensor('scale_10',onnx.TensorProto.FLOAT, [],[eq_scale_val_2]),\n", + " make_tensor('scale_11',onnx.TensorProto.FLOAT, [],[eq_scale_val_1]),\n", + " # Scales for zero-points : Zero-point datatype defines the dataype of the output for that quantization\n", + " make_tensor('zero_point_signed',onnx.TensorProto.INT8,[],[zero_point_signed_val]),\n", + " make_tensor('zero_point_unsigned',onnx.TensorProto.UINT8,[],[zero_point_unsigned_val]),\n", + " # Introducing scalars for the clip operators.\n", + " make_tensor('min', onnx.TensorProto.INT8, [], [min_clip_val]),\n", + " make_tensor('max', onnx.TensorProto.INT8, [], [max_clip_val]),\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "b1b16751", + "metadata": {}, + "source": [ + "The above created graph can now be converted into a qonnx model with the `qonnx_make_model` utility. We save the model with `onnx.save` utility and then view it in Netron with the help of `showInNetron` utility. 
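As a quick optional sanity check before saving, the next cell counts the nodes per operator type in the body graph we just assembled. This is only a convenience check (`Counter` is from the standard library), not part of the export flow.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7a8b9cad",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Eyeball the graph without opening Netron: how many QuantizeLinear /\n",
+ "# Clip / DequantizeLinear nodes (and MatMul/Add/Sigmoid/Tanh/Mul nodes\n",
+ "# for the six equations) ended up in the body graph.\n",
+ "from collections import Counter\n",
+ "print(Counter(node.op_type for node in lstm_body.node))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "8b9cadbe",
+ "metadata": {},
+ "source": [
+ "The model is then created, saved, and opened in Netron below.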
\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "c6ec7b2a-456d-4452-97ec-df9a471d5391",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Serving './lstm_full_graph.onnx' at http://localhost:8080\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "('localhost', 8080)"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "lstm_model = qonnx_make_model(lstm_body, producer_name=\"QuantizeLSTM_scan\")\n",
+ "onnx.save(lstm_model, './lstm_full_graph.onnx')\n",
+ "netron.start('./lstm_full_graph.onnx')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "40b49257",
+ "metadata": {},
+ "source": [
+ "In this block of code we execute the ONNX graph to check that it runs without any errors. We perform its functional verification in the later part of the notebook."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "db5892bc-ac8d-4972-afcf-20bf880f5e86",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[array([[ 0.1484375],\n",
+ " [-0.0078125],\n",
+ " [ 0.0390625],\n",
+ " [ 0.140625 ],\n",
+ " [ 0.015625 ],\n",
+ " [ 0. ],\n",
+ " [ 0.1015625],\n",
+ " [-0.1015625],\n",
+ " [ 0.0390625],\n",
+ " [-0.0625 ],\n",
+ " [ 0.015625 ],\n",
+ " [-0.125 ],\n",
+ " [ 0.1015625],\n",
+ " [ 0.03125 ],\n",
+ " [ 0.1640625],\n",
+ " [-0.015625 ],\n",
+ " [-0.0234375],\n",
+ " [-0.015625 ],\n",
+ " [-0.046875 ],\n",
+ " [ 0.0078125]], dtype=float32), array([[ 0.2421875],\n",
+ " [-0.0078125],\n",
+ " [ 0.0625 ],\n",
+ " [ 0.2421875],\n",
+ " [ 0.03125 ],\n",
+ " [ 0.0078125],\n",
+ " [ 0.2265625],\n",
+ " [-0.234375 ],\n",
+ " [ 0.0859375],\n",
+ " [-0.1328125],\n",
+ " [ 0.0390625],\n",
+ " [-0.2421875],\n",
+ " [ 0.1875 ],\n",
+ " [ 0.0546875],\n",
+ " [ 0.296875 ],\n",
+ " [-0.03125 ],\n",
+ " [-0.0546875],\n",
+ " [-0.03125 ],\n",
+ " [-0.109375 ],\n",
+ " [ 0.0234375]], dtype=float32), array([[ 0.1484375],\n",
+ " [-0.0078125],\n",
+ " [ 0.0390625],\n",
+ " [ 0.140625 ],\n",
+ " [ 0.015625 ],\n",
+ " [ 0. ],\n",
+ " [ 0.1015625],\n",
+ " [-0.1015625],\n",
+ " [ 0.0390625],\n",
+ " [-0.0625 ],\n",
+ " [ 0.015625 ],\n",
+ " [-0.125 ],\n",
+ " [ 0.1015625],\n",
+ " [ 0.03125 ],\n",
+ " [ 0.1640625],\n",
+ " [-0.015625 ],\n",
+ " [-0.0234375],\n",
+ " [-0.015625 ],\n",
+ " [-0.046875 ],\n",
+ " [ 0.0078125]], dtype=float32)]\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2023-10-20 11:07:46.350885612 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'scale_test'. It is not used by any node and should be removed from the model.\n",
+ "2023-10-20 11:07:46.370978980 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'scale_test'. It is not used by any node and should be removed from the model.\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Before the model can be executed, its opset version needs to be set to a minimum of '14' to accommodate clip nodes with INT8 and UINT8 input. 
Otherwise ONNX cannot create an execution session and we get errors.\n",
+ "lstm_model.opset_import[0].version = 14\n",
+ "\n",
+ "# Defining dummy inputs and the model parameters for a dummy execution\n",
+ "X_inp = np.empty([num_features,1],dtype=np.float32).reshape([num_features,1])\n",
+ "X_inp.fill(0.8)\n",
+ "hidden_state_input = np.zeros((num_hidden_cells, 1)).astype(np.float32)\n",
+ "cell_state_input = np.zeros((num_hidden_cells, 1)).astype(np.float32)\n",
+ "\n",
+ "# Assigning the above defined values to the input dictionary of the ONNX model.\n",
+ "input_dict = {}\n",
+ "input_dict[\"inp\"] = X_inp\n",
+ "input_dict[\"h_t-1\"] = hidden_state_input\n",
+ "input_dict[\"c_t-1\"] = cell_state_input \n",
+ "\n",
+ "# Creating the inference session for the updated model and executing it.\n",
+ "sess = rt.InferenceSession(lstm_model.SerializeToString())\n",
+ "output = sess.run(None, input_dict)\n",
+ "print(output)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5d2b5a1e-654e-46a5-9d4f-8708611a6d1e",
+ "metadata": {},
+ "source": [
+ "# SCAN Operation Integration"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7365329a-f3d2-4f74-8e2f-9076771e07a7",
+ "metadata": {},
+ "source": [
+ "### Introduction to ONNX Scan operation\n",
+ "Observations regarding the `Scan` operator in ONNX:\n",
+ "\n",
+ "1. `Scan` can be used to iterate over one or more scan input tensors, constructing zero or more scan output tensors. It combines ideas from general recurrences and functional programming constructs such as scan, fold, map and zip.\n",
+ "2. The attribute `body` of the node must be a graph specifying the computation to be performed in every iteration.\n",
+ "3. Each iteration consumes the current values of the `state variables` and the current `iterated element` of the scan input, and returns updated `state variables` along with `scan output element tensors` (there can be more than one of each).\n",
+ "4. The values of the scan output tensors are concatenated over all the iterations to produce the scan output values of the scan construct.\n",
+ "5. The properties that make a scan node unique and different from a normal compute node are:\n",
+ "* It allows the state variables to be updated after each input computation, to be used in the processing of the next input.\n",
+ "* It scans the inputs along their leading axis, computing the output with the updated hidden state for every input, while storing all the intermediate outputs in the form of hidden states.\n",
+ "\n",
+ "More information regarding this op can be found in these links:\n",
+ "\n",
+ "* https://github.com/onnx/onnx/blob/main/docs/Operators.md#Scan\n",
+ "* https://onnx.ai/onnx/intro/python.html#scan"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "17f247f7",
+ "metadata": {},
+ "source": [
+ "The `Scan` operation is essentially a container operator which will consume the LSTM graph that we created above in its body.\n",
+ "To create it, we need to define separate input and output value info tensors just for the Scan operator. We will then follow the same steps as the `QCDQ-LSTM` graph creation to convert the above graph into an executable ONNX model.\n",
+ "
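
\n",
+ "Before wiring up the real operator, the next cell sketches `Scan`'s iteration scheme in plain Python: two state variables are threaded through every step, and the per-step outputs are stacked along a new leading axis. This is only an illustration; the helper names `scan_like` and `demo_step` are ours and are not Brevitas or ONNX API.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3c4d5e6f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Plain-Python picture of the Scan semantics we rely on (illustration only).\n",
+ "import numpy as np\n",
+ "\n",
+ "def scan_like(step_fn, h_init, c_init, xs):\n",
+ "    h, c = h_init, c_init\n",
+ "    stacked = []\n",
+ "    for x_t in xs:  # iterate over the leading (sequence) axis\n",
+ "        h, c, h_copy = step_fn(h, c, x_t)\n",
+ "        stacked.append(h_copy)  # per-step scan output\n",
+ "    return h, c, np.stack(stacked)  # final states + concatenated outputs\n",
+ "\n",
+ "# Tiny dummy step just to show the wiring; the real body is the QCDQ-LSTM graph.\n",
+ "demo_step = lambda h, c, x: (h + x.sum(), c - 1.0, h + x.sum())\n",
+ "h_f, c_f, h_all = scan_like(demo_step, np.zeros((2, 1)), np.zeros((2, 1)), np.ones((3, 2, 1)))\n",
+ "print(h_all.shape)  # (3, 2, 1): one stacked output per scanned input"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "4d5e6f70",
+ "metadata": {},
+ "source": [
+ "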

\n",
+ "We start by defining the input and output value info tensors for the `scan_graph` creation. These tensors act as the wrapper to the previously defined graph.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "700a93a8-f757-4fa1-88dd-47a3f2a7f171",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Inputs\n",
+ "scan_input = make_tensor_value_info(\"scan_input\",onnx.TensorProto.FLOAT, [None,num_features,1]) # X ; scan input. Here None allows a variable number of inputs (the sequence length) to be supplied for input processing.\n",
+ "scan_hidden_state = make_tensor_value_info(\"scan_hidden_state\",onnx.TensorProto.FLOAT, [num_hidden_cells,1])# h_t-1\n",
+ "scan_cell_state = make_tensor_value_info(\"scan_cell_state\",onnx.TensorProto.FLOAT, [num_hidden_cells,1])# c_t-1\n",
+ "\n",
+ "# Outputs\n",
+ "scan_out_hidden_state = make_tensor_value_info(\"scan_out_hidden_state\", onnx.TensorProto.FLOAT, [num_hidden_cells,1])#h_t\n",
+ "scan_out_cell_state = make_tensor_value_info(\"scan_out_cell_state\", onnx.TensorProto.FLOAT, [num_hidden_cells,1])#c_t\n",
+ "scan_out_hidden_state_concat = make_tensor_value_info(\"scan_out_hidden_state_concat\", onnx.TensorProto.FLOAT, [None,num_hidden_cells,1])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "572f191e",
+ "metadata": {},
+ "source": [
+ "We now create the Scan operator utilizing the `make_node` utility from ONNX.\n",
+ "Note that in the body of the operation we have included the `lstm_body` graph we created in the steps above."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "111fdce4-464f-40c1-ac4d-3022b05f153e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "scan_node_lstm = make_node(\n",
+ "    \"Scan\", \n",
+ "    inputs=[\"scan_hidden_state\",\"scan_cell_state\",\"scan_input\"], \n",
+ "    outputs=[\"scan_out_hidden_state\",\"scan_out_cell_state\",\"scan_out_hidden_state_concat\"], \n",
+ "    num_scan_inputs=1,\n",
+ "    body=lstm_body, domain=''\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ea8a05d9",
+ "metadata": {},
+ "source": [
+ "We can now define the graph for the scan operator utilizing the `make_graph` utility.\n",
+ "Before assembling it, a quick optional check (sketched below): `Scan` treats the first `len(inputs) - num_scan_inputs` node inputs as state variables and scans the remaining ones along their leading axis, so their order must match the input order of the body graph."
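+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "5e6f7081",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Optional ordering check (our own sanity check, not required by ONNX):\n",
+ "# the body graph's inputs should line up with the Scan node inputs as\n",
+ "# [state, state, scanned input].\n",
+ "body_input_names = [vi.name for vi in lstm_body.input]\n",
+ "print(body_input_names)  # expected: ['h_t-1', 'c_t-1', 'inp']"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6f708192",
+ "metadata": {},
+ "source": [
+ "With the ordering confirmed, we assemble the Scan graph, create and save the model, and execute it."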
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "4668cf2b-524e-4768-8dc8-9d619f6273da",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Serving './lstm_scan_node_model.onnx' at http://localhost:8081\n",
+ "[]\n"
+ ]
+ }
+ ],
+ "source": [
+ "scan_lstm_node_graph = make_graph(\n",
+ "    nodes = [scan_node_lstm],\n",
+ "    name=\"lstm-scan-node\",\n",
+ "    inputs=[scan_hidden_state,scan_cell_state,scan_input],#h_t-1, c_t-1, X\n",
+ "    outputs=[scan_out_hidden_state,scan_out_cell_state,scan_out_hidden_state_concat]#h_t,c_t,h_t_concat\n",
+ ")\n",
+ "\n",
+ "# Creating the model from the above created graph and saving it.\n",
+ "lstm_scan_node_model = qonnx_make_model(scan_lstm_node_graph, producer_name=\"scan-lstm\")\n",
+ "onnx.save(lstm_scan_node_model, './lstm_scan_node_model.onnx')\n",
+ "netron.start('./lstm_scan_node_model.onnx')\n",
+ "\n",
+ "# Checking the model for any errors\n",
+ "onnx.checker.check_model(lstm_scan_node_model)\n",
+ "print(lstm_scan_node_model.graph.value_info)\n",
+ "\n",
+ "# Conversion to opset version 14 of ONNX to accommodate the clip nodes, as done for the LSTM graph earlier.\n",
+ "lstm_scan_node_model.opset_import[0].version = 14"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0673e335",
+ "metadata": {},
+ "source": [
+ "Now that we have the Scan-based quantized LSTM model ready, we can test it with the same set of inputs we used for testing the brevitas model.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "818d2a81-686f-4a4a-8e78-17dbf75d8451",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Final Hidden State [[ 0.25 ]\n",
+ " [-0.046875 ]\n",
+ " [ 0.015625 ]\n",
+ " [ 0.2734375]\n",
+ " [ 0.0546875]\n",
+ " [-0.0390625]\n",
+ " [ 0.25 ]\n",
+ " [-0.1953125]\n",
+ " [ 0.0546875]\n",
+ " [-0.140625 ]\n",
+ " [ 0.015625 ]\n",
+ " [-0.203125 ]\n",
+ " [ 0.203125 ]\n",
+ " [ 0.140625 ]\n",
+ " [ 0.2734375]\n",
+ " [ 0.03125 ]\n",
+ " [-0.03125 ]\n",
+ " [-0.046875 ]\n",
+ " [-0.0703125]\n",
+ " [ 0.0078125]]\n",
+ "------------------------\n",
+ "Final Cell State [[ 0.421875 ]\n",
+ " [-0.078125 ]\n",
+ " [ 0.0234375]\n",
+ " [ 0.4921875]\n",
+ " [ 0.1484375]\n",
+ " [-0.09375 ]\n",
+ " [ 0.75 ]\n",
+ " [-0.59375 ]\n",
+ " [ 0.1171875]\n",
+ " [-0.3125 ]\n",
+ " [ 0.0390625]\n",
+ " [-0.421875 ]\n",
+ " [ 0.3984375]\n",
+ " [ 0.2578125]\n",
+ " [ 0.828125 ]\n",
+ " [ 0.0625 ]\n",
+ " [-0.0703125]\n",
+ " [-0.109375 ]\n",
+ " [-0.1484375]\n",
+ " [ 0.0234375]]\n",
+ "------------------------\n",
+ "All Hidden States [[[ 0.1484375]\n",
+ " [-0.0078125]\n",
+ " [ 0.0390625]\n",
+ " [ 0.140625 ]\n",
+ " [ 0.015625 ]\n",
+ " [ 0. 
]\n", + " [ 0.1015625]\n", + " [-0.1015625]\n", + " [ 0.0390625]\n", + " [-0.0625 ]\n", + " [ 0.015625 ]\n", + " [-0.125 ]\n", + " [ 0.1015625]\n", + " [ 0.03125 ]\n", + " [ 0.1640625]\n", + " [-0.015625 ]\n", + " [-0.0234375]\n", + " [-0.015625 ]\n", + " [-0.046875 ]\n", + " [ 0.0078125]]\n", + "\n", + " [[ 0.203125 ]\n", + " [-0.0234375]\n", + " [ 0.03125 ]\n", + " [ 0.2109375]\n", + " [ 0.0234375]\n", + " [-0.015625 ]\n", + " [ 0.1875 ]\n", + " [-0.1484375]\n", + " [ 0.046875 ]\n", + " [-0.09375 ]\n", + " [ 0.0234375]\n", + " [-0.1640625]\n", + " [ 0.1484375]\n", + " [ 0.0703125]\n", + " [ 0.2578125]\n", + " [-0.015625 ]\n", + " [-0.03125 ]\n", + " [-0.0234375]\n", + " [-0.0703125]\n", + " [ 0.0078125]]\n", + "\n", + " [[ 0.2265625]\n", + " [-0.03125 ]\n", + " [ 0.015625 ]\n", + " [ 0.2421875]\n", + " [ 0.03125 ]\n", + " [-0.0234375]\n", + " [ 0.234375 ]\n", + " [-0.1796875]\n", + " [ 0.0546875]\n", + " [-0.109375 ]\n", + " [ 0.0234375]\n", + " [-0.1875 ]\n", + " [ 0.1796875]\n", + " [ 0.09375 ]\n", + " [ 0.2734375]\n", + " [ 0. ]\n", + " [-0.03125 ]\n", + " [-0.03125 ]\n", + " [-0.0703125]\n", + " [ 0.015625 ]]\n", + "\n", + " [[ 0.234375 ]\n", + " [-0.0390625]\n", + " [ 0.015625 ]\n", + " [ 0.2578125]\n", + " [ 0.0390625]\n", + " [-0.03125 ]\n", + " [ 0.25 ]\n", + " [-0.1875 ]\n", + " [ 0.0546875]\n", + " [-0.125 ]\n", + " [ 0.015625 ]\n", + " [-0.1953125]\n", + " [ 0.1953125]\n", + " [ 0.1171875]\n", + " [ 0.2734375]\n", + " [ 0.015625 ]\n", + " [-0.03125 ]\n", + " [-0.0390625]\n", + " [-0.078125 ]\n", + " [ 0.0078125]]\n", + "\n", + " [[ 0.2421875]\n", + " [-0.046875 ]\n", + " [ 0.015625 ]\n", + " [ 0.2734375]\n", + " [ 0.0390625]\n", + " [-0.03125 ]\n", + " [ 0.25 ]\n", + " [-0.1953125]\n", + " [ 0.0546875]\n", + " [-0.1328125]\n", + " [ 0.015625 ]\n", + " [-0.1953125]\n", + " [ 0.203125 ]\n", + " [ 0.1328125]\n", + " [ 0.2734375]\n", + " [ 0.0234375]\n", + " [-0.03125 ]\n", + " [-0.046875 ]\n", + " [-0.078125 ]\n", + " [ 0.0078125]]\n", + "\n", + " [[ 0.2421875]\n", + " [-0.046875 ]\n", + " [ 0.015625 ]\n", + " [ 0.2734375]\n", + " [ 0.046875 ]\n", + " [-0.0390625]\n", + " [ 0.25 ]\n", + " [-0.1953125]\n", + " [ 0.0546875]\n", + " [-0.140625 ]\n", + " [ 0.015625 ]\n", + " [-0.203125 ]\n", + " [ 0.203125 ]\n", + " [ 0.1328125]\n", + " [ 0.2734375]\n", + " [ 0.03125 ]\n", + " [-0.03125 ]\n", + " [-0.046875 ]\n", + " [-0.0703125]\n", + " [ 0.0078125]]\n", + "\n", + " [[ 0.2421875]\n", + " [-0.046875 ]\n", + " [ 0.015625 ]\n", + " [ 0.2734375]\n", + " [ 0.046875 ]\n", + " [-0.0390625]\n", + " [ 0.25 ]\n", + " [-0.1953125]\n", + " [ 0.0546875]\n", + " [-0.140625 ]\n", + " [ 0.015625 ]\n", + " [-0.203125 ]\n", + " [ 0.203125 ]\n", + " [ 0.140625 ]\n", + " [ 0.2734375]\n", + " [ 0.03125 ]\n", + " [-0.03125 ]\n", + " [-0.046875 ]\n", + " [-0.0703125]\n", + " [ 0.0078125]]\n", + "\n", + " [[ 0.2421875]\n", + " [-0.046875 ]\n", + " [ 0.015625 ]\n", + " [ 0.2734375]\n", + " [ 0.0546875]\n", + " [-0.0390625]\n", + " [ 0.25 ]\n", + " [-0.1953125]\n", + " [ 0.0546875]\n", + " [-0.140625 ]\n", + " [ 0.015625 ]\n", + " [-0.203125 ]\n", + " [ 0.203125 ]\n", + " [ 0.140625 ]\n", + " [ 0.2734375]\n", + " [ 0.03125 ]\n", + " [-0.03125 ]\n", + " [-0.046875 ]\n", + " [-0.0703125]\n", + " [ 0.0078125]]\n", + "\n", + " [[ 0.2421875]\n", + " [-0.046875 ]\n", + " [ 0.015625 ]\n", + " [ 0.2734375]\n", + " [ 0.0546875]\n", + " [-0.0390625]\n", + " [ 0.25 ]\n", + " [-0.1953125]\n", + " [ 0.0546875]\n", + " [-0.140625 ]\n", + " [ 0.015625 ]\n", + " [-0.203125 ]\n", + " [ 0.203125 ]\n", + " [ 0.140625 ]\n", + " 
[ 0.2734375]\n", + " [ 0.03125 ]\n", + " [-0.03125 ]\n", + " [-0.046875 ]\n", + " [-0.0703125]\n", + " [ 0.0078125]]\n", + "\n", + " [[ 0.25 ]\n", + " [-0.046875 ]\n", + " [ 0.015625 ]\n", + " [ 0.2734375]\n", + " [ 0.0546875]\n", + " [-0.0390625]\n", + " [ 0.25 ]\n", + " [-0.1953125]\n", + " [ 0.0546875]\n", + " [-0.140625 ]\n", + " [ 0.015625 ]\n", + " [-0.203125 ]\n", + " [ 0.203125 ]\n", + " [ 0.140625 ]\n", + " [ 0.2734375]\n", + " [ 0.03125 ]\n", + " [-0.03125 ]\n", + " [-0.046875 ]\n", + " [-0.0703125]\n", + " [ 0.0078125]]\n", + "\n", + " [[ 0.25 ]\n", + " [-0.046875 ]\n", + " [ 0.015625 ]\n", + " [ 0.2734375]\n", + " [ 0.0546875]\n", + " [-0.0390625]\n", + " [ 0.25 ]\n", + " [-0.1953125]\n", + " [ 0.0546875]\n", + " [-0.140625 ]\n", + " [ 0.015625 ]\n", + " [-0.203125 ]\n", + " [ 0.203125 ]\n", + " [ 0.140625 ]\n", + " [ 0.2734375]\n", + " [ 0.03125 ]\n", + " [-0.03125 ]\n", + " [-0.046875 ]\n", + " [-0.0703125]\n", + " [ 0.0078125]]\n", + "\n", + " [[ 0.25 ]\n", + " [-0.046875 ]\n", + " [ 0.015625 ]\n", + " [ 0.2734375]\n", + " [ 0.0546875]\n", + " [-0.0390625]\n", + " [ 0.25 ]\n", + " [-0.1953125]\n", + " [ 0.0546875]\n", + " [-0.140625 ]\n", + " [ 0.015625 ]\n", + " [-0.203125 ]\n", + " [ 0.203125 ]\n", + " [ 0.140625 ]\n", + " [ 0.2734375]\n", + " [ 0.03125 ]\n", + " [-0.03125 ]\n", + " [-0.046875 ]\n", + " [-0.0703125]\n", + " [ 0.0078125]]\n", + "\n", + " [[ 0.25 ]\n", + " [-0.046875 ]\n", + " [ 0.015625 ]\n", + " [ 0.2734375]\n", + " [ 0.0546875]\n", + " [-0.0390625]\n", + " [ 0.25 ]\n", + " [-0.1953125]\n", + " [ 0.0546875]\n", + " [-0.140625 ]\n", + " [ 0.015625 ]\n", + " [-0.203125 ]\n", + " [ 0.203125 ]\n", + " [ 0.140625 ]\n", + " [ 0.2734375]\n", + " [ 0.03125 ]\n", + " [-0.03125 ]\n", + " [-0.046875 ]\n", + " [-0.0703125]\n", + " [ 0.0078125]]\n", + "\n", + " [[ 0.25 ]\n", + " [-0.046875 ]\n", + " [ 0.015625 ]\n", + " [ 0.2734375]\n", + " [ 0.0546875]\n", + " [-0.0390625]\n", + " [ 0.25 ]\n", + " [-0.1953125]\n", + " [ 0.0546875]\n", + " [-0.140625 ]\n", + " [ 0.015625 ]\n", + " [-0.203125 ]\n", + " [ 0.203125 ]\n", + " [ 0.140625 ]\n", + " [ 0.2734375]\n", + " [ 0.03125 ]\n", + " [-0.03125 ]\n", + " [-0.046875 ]\n", + " [-0.0703125]\n", + " [ 0.0078125]]\n", + "\n", + " [[ 0.25 ]\n", + " [-0.046875 ]\n", + " [ 0.015625 ]\n", + " [ 0.2734375]\n", + " [ 0.0546875]\n", + " [-0.0390625]\n", + " [ 0.25 ]\n", + " [-0.1953125]\n", + " [ 0.0546875]\n", + " [-0.140625 ]\n", + " [ 0.015625 ]\n", + " [-0.203125 ]\n", + " [ 0.203125 ]\n", + " [ 0.140625 ]\n", + " [ 0.2734375]\n", + " [ 0.03125 ]\n", + " [-0.03125 ]\n", + " [-0.046875 ]\n", + " [-0.0703125]\n", + " [ 0.0078125]]\n", + "\n", + " [[ 0.25 ]\n", + " [-0.046875 ]\n", + " [ 0.015625 ]\n", + " [ 0.2734375]\n", + " [ 0.0546875]\n", + " [-0.0390625]\n", + " [ 0.25 ]\n", + " [-0.1953125]\n", + " [ 0.0546875]\n", + " [-0.140625 ]\n", + " [ 0.015625 ]\n", + " [-0.203125 ]\n", + " [ 0.203125 ]\n", + " [ 0.140625 ]\n", + " [ 0.2734375]\n", + " [ 0.03125 ]\n", + " [-0.03125 ]\n", + " [-0.046875 ]\n", + " [-0.0703125]\n", + " [ 0.0078125]]\n", + "\n", + " [[ 0.25 ]\n", + " [-0.046875 ]\n", + " [ 0.015625 ]\n", + " [ 0.2734375]\n", + " [ 0.0546875]\n", + " [-0.0390625]\n", + " [ 0.25 ]\n", + " [-0.1953125]\n", + " [ 0.0546875]\n", + " [-0.140625 ]\n", + " [ 0.015625 ]\n", + " [-0.203125 ]\n", + " [ 0.203125 ]\n", + " [ 0.140625 ]\n", + " [ 0.2734375]\n", + " [ 0.03125 ]\n", + " [-0.03125 ]\n", + " [-0.046875 ]\n", + " [-0.0703125]\n", + " [ 0.0078125]]\n", + "\n", + " [[ 0.25 ]\n", + " [-0.046875 ]\n", + " [ 0.015625 ]\n", + " [ 
0.2734375]\n", + " [ 0.0546875]\n", + " [-0.0390625]\n", + " [ 0.25 ]\n", + " [-0.1953125]\n", + " [ 0.0546875]\n", + " [-0.140625 ]\n", + " [ 0.015625 ]\n", + " [-0.203125 ]\n", + " [ 0.203125 ]\n", + " [ 0.140625 ]\n", + " [ 0.2734375]\n", + " [ 0.03125 ]\n", + " [-0.03125 ]\n", + " [-0.046875 ]\n", + " [-0.0703125]\n", + " [ 0.0078125]]\n", + "\n", + " [[ 0.25 ]\n", + " [-0.046875 ]\n", + " [ 0.015625 ]\n", + " [ 0.2734375]\n", + " [ 0.0546875]\n", + " [-0.0390625]\n", + " [ 0.25 ]\n", + " [-0.1953125]\n", + " [ 0.0546875]\n", + " [-0.140625 ]\n", + " [ 0.015625 ]\n", + " [-0.203125 ]\n", + " [ 0.203125 ]\n", + " [ 0.140625 ]\n", + " [ 0.2734375]\n", + " [ 0.03125 ]\n", + " [-0.03125 ]\n", + " [-0.046875 ]\n", + " [-0.0703125]\n", + " [ 0.0078125]]\n", + "\n", + " [[ 0.25 ]\n", + " [-0.046875 ]\n", + " [ 0.015625 ]\n", + " [ 0.2734375]\n", + " [ 0.0546875]\n", + " [-0.0390625]\n", + " [ 0.25 ]\n", + " [-0.1953125]\n", + " [ 0.0546875]\n", + " [-0.140625 ]\n", + " [ 0.015625 ]\n", + " [-0.203125 ]\n", + " [ 0.203125 ]\n", + " [ 0.140625 ]\n", + " [ 0.2734375]\n", + " [ 0.03125 ]\n", + " [-0.03125 ]\n", + " [-0.046875 ]\n", + " [-0.0703125]\n", + " [ 0.0078125]]\n", + "\n", + " [[ 0.25 ]\n", + " [-0.046875 ]\n", + " [ 0.015625 ]\n", + " [ 0.2734375]\n", + " [ 0.0546875]\n", + " [-0.0390625]\n", + " [ 0.25 ]\n", + " [-0.1953125]\n", + " [ 0.0546875]\n", + " [-0.140625 ]\n", + " [ 0.015625 ]\n", + " [-0.203125 ]\n", + " [ 0.203125 ]\n", + " [ 0.140625 ]\n", + " [ 0.2734375]\n", + " [ 0.03125 ]\n", + " [-0.03125 ]\n", + " [-0.046875 ]\n", + " [-0.0703125]\n", + " [ 0.0078125]]\n", + "\n", + " [[ 0.25 ]\n", + " [-0.046875 ]\n", + " [ 0.015625 ]\n", + " [ 0.2734375]\n", + " [ 0.0546875]\n", + " [-0.0390625]\n", + " [ 0.25 ]\n", + " [-0.1953125]\n", + " [ 0.0546875]\n", + " [-0.140625 ]\n", + " [ 0.015625 ]\n", + " [-0.203125 ]\n", + " [ 0.203125 ]\n", + " [ 0.140625 ]\n", + " [ 0.2734375]\n", + " [ 0.03125 ]\n", + " [-0.03125 ]\n", + " [-0.046875 ]\n", + " [-0.0703125]\n", + " [ 0.0078125]]\n", + "\n", + " [[ 0.25 ]\n", + " [-0.046875 ]\n", + " [ 0.015625 ]\n", + " [ 0.2734375]\n", + " [ 0.0546875]\n", + " [-0.0390625]\n", + " [ 0.25 ]\n", + " [-0.1953125]\n", + " [ 0.0546875]\n", + " [-0.140625 ]\n", + " [ 0.015625 ]\n", + " [-0.203125 ]\n", + " [ 0.203125 ]\n", + " [ 0.140625 ]\n", + " [ 0.2734375]\n", + " [ 0.03125 ]\n", + " [-0.03125 ]\n", + " [-0.046875 ]\n", + " [-0.0703125]\n", + " [ 0.0078125]]\n", + "\n", + " [[ 0.25 ]\n", + " [-0.046875 ]\n", + " [ 0.015625 ]\n", + " [ 0.2734375]\n", + " [ 0.0546875]\n", + " [-0.0390625]\n", + " [ 0.25 ]\n", + " [-0.1953125]\n", + " [ 0.0546875]\n", + " [-0.140625 ]\n", + " [ 0.015625 ]\n", + " [-0.203125 ]\n", + " [ 0.203125 ]\n", + " [ 0.140625 ]\n", + " [ 0.2734375]\n", + " [ 0.03125 ]\n", + " [-0.03125 ]\n", + " [-0.046875 ]\n", + " [-0.0703125]\n", + " [ 0.0078125]]\n", + "\n", + " [[ 0.25 ]\n", + " [-0.046875 ]\n", + " [ 0.015625 ]\n", + " [ 0.2734375]\n", + " [ 0.0546875]\n", + " [-0.0390625]\n", + " [ 0.25 ]\n", + " [-0.1953125]\n", + " [ 0.0546875]\n", + " [-0.140625 ]\n", + " [ 0.015625 ]\n", + " [-0.203125 ]\n", + " [ 0.203125 ]\n", + " [ 0.140625 ]\n", + " [ 0.2734375]\n", + " [ 0.03125 ]\n", + " [-0.03125 ]\n", + " [-0.046875 ]\n", + " [-0.0703125]\n", + " [ 0.0078125]]]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-10-20 10:50:38.892379706 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'scale_test'. 
It is not used by any node and should be removed from the model.\n",
+ "2023-10-20 10:50:38.894726380 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'ql_uo_out'. It is not used by any node and should be removed from the model.\n",
+ "2023-10-20 10:50:38.894741924 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'ql_wf_out'. It is not used by any node and should be removed from the model.\n",
+ "2023-10-20 10:50:38.894750521 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'ql_ui_out'. It is not used by any node and should be removed from the model.\n",
+ "2023-10-20 10:50:38.894758793 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'max'. It is not used by any node and should be removed from the model.\n",
+ "2023-10-20 10:50:38.894767212 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'W_c'. It is not used by any node and should be removed from the model.\n",
+ "2023-10-20 10:50:38.894775093 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'U_c'. It is not used by any node and should be removed from the model.\n",
+ "2023-10-20 10:50:38.894782542 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'U_i'. It is not used by any node and should be removed from the model.\n",
+ "2023-10-20 10:50:38.894790413 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'ql_uc_out'. It is not used by any node and should be removed from the model.\n",
+ "2023-10-20 10:50:38.894797986 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'W_i'. It is not used by any node and should be removed from the model.\n",
+ "2023-10-20 10:50:38.894805922 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'ql_wi_out'. It is not used by any node and should be removed from the model.\n",
+ "2023-10-20 10:50:38.894813725 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'U_o'. It is not used by any node and should be removed from the model.\n",
+ "2023-10-20 10:50:38.894821378 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'W_f'. It is not used by any node and should be removed from the model.\n",
+ "2023-10-20 10:50:38.894829187 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'W_o'. It is not used by any node and should be removed from the model.\n",
+ "2023-10-20 10:50:38.894837744 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'ql_uf_out'. It is not used by any node and should be removed from the model.\n",
+ "2023-10-20 10:50:38.894845343 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'ql_wc_out'. It is not used by any node and should be removed from the model.\n",
+ "2023-10-20 10:50:38.894852862 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'U_f'. It is not used by any node and should be removed from the model.\n",
+ "2023-10-20 10:50:38.894861070 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'ql_wo_out'. 
It is not used by any node and should be removed from the model.\n",
+ "2023-10-20 10:50:38.894868719 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'min'. It is not used by any node and should be removed from the model.\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Defining the values of the variables to test the execution of the scan model\n",
+ "num_inputs = 25\n",
+ "\n",
+ "# Initializing the hidden state and the cell state to zeros.\n",
+ "# Also assigning the same input as the one used for the brevitas execution.\n",
+ "\n",
+ "hidden_state_inp = np.zeros((num_hidden_cells, 1)).astype(np.float32) # 'h_t-1'\n",
+ "cell_state_inp = np.zeros((num_hidden_cells, 1)).astype(np.float32) # 'c_t-1'\n",
+ "scan_inp = np.empty([num_inputs, num_features, 1], dtype=np.float32)\n",
+ "scan_inp.fill(0.8)\n",
+ "\n",
+ "# Assigning the defined input values to the input dictionary of the scan model\n",
+ "input_dict = {}\n",
+ "input_dict[\"scan_hidden_state\"] = hidden_state_inp\n",
+ "input_dict[\"scan_cell_state\"] = cell_state_inp\n",
+ "input_dict[\"scan_input\"] = scan_inp\n",
+ "\n",
+ "# We can now set up the inference session and execute the scan ONNX model.\n",
+ "# The inference session emits some warnings, which can be ignored.\n",
+ "\n",
+ "sess = rt.InferenceSession(lstm_scan_node_model.SerializeToString())\n",
+ "scan_output = sess.run(None, input_dict)\n",
+ "print('Final Hidden State',scan_output[0])\n",
+ "print(\"------------------------\")\n",
+ "print('Final Cell State',scan_output[1])\n",
+ "print(\"------------------------\")\n",
+ "print('All Hidden States',scan_output[2])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "907d2ff9-f605-4aec-891e-0c77a1a92346",
+ "metadata": {},
+ "source": [
+ "# Functional Verification"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b6bb6c60",
+ "metadata": {},
+ "source": [
+ "In the final part of the notebook, we compare the output of the 8-bit quantized `(QCDQ)-LSTM` implementation with the `QuantLSTM` brevitas model.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "2fe07395-6cf9-4c99-a0d3-a27aa6a326b5",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Brevitas Output shape : (25, 1, 20)\n",
+ "SCAN-QCDQ-LSTM output shape : (25, 1, 20)\n",
+ "-----------------------------------\n",
+ "Brevitas Output = [[[ 0.1484375 -0.0078125 0.0390625 0.140625 0.0078125 0.\n",
+ " 0.109375 -0.09375 0.0390625 -0.0625 0.015625 -0.1171875\n",
+ " 0.1015625 0.03125 0.1640625 -0.015625 -0.0234375 -0.015625\n",
+ " -0.046875 0.0078125]]\n",
+ "\n",
+ " [[ 0.2109375 -0.0234375 0.03125 0.2109375 0.0234375 -0.015625\n",
+ " 0.1875 -0.1484375 0.046875 -0.09375 0.0234375 -0.1640625\n",
+ " 0.1484375 0.0625 0.2578125 -0.015625 -0.03125 -0.0234375\n",
+ " -0.0703125 0.015625 ]]\n",
+ "\n",
+ " [[ 0.2421875 -0.0390625 0.015625 0.25 0.03125 -0.0234375\n",
+ " 0.234375 -0.1796875 0.0546875 -0.109375 0.015625 -0.1875\n",
+ " 0.1796875 0.09375 0.3125 0. 
-0.03125 -0.03125\n", + " -0.078125 0.0078125]]\n", + "\n", + " [[ 0.25 -0.0390625 0.015625 0.265625 0.0390625 -0.03125\n", + " 0.265625 -0.1875 0.0546875 -0.125 0.015625 -0.1953125\n", + " 0.1953125 0.1171875 0.3359375 0.015625 -0.03125 -0.0390625\n", + " -0.078125 0.015625 ]]\n", + "\n", + " [[ 0.2578125 -0.046875 0.015625 0.2734375 0.046875 -0.0390625\n", + " 0.2890625 -0.1953125 0.0546875 -0.125 0.015625 -0.203125\n", + " 0.203125 0.125 0.359375 0.0234375 -0.03125 -0.046875\n", + " -0.0703125 0.015625 ]]\n", + "\n", + " [[ 0.2578125 -0.046875 0.0078125 0.2734375 0.046875 -0.0390625\n", + " 0.296875 -0.1953125 0.0546875 -0.1328125 0.015625 -0.203125\n", + " 0.2109375 0.1328125 0.3671875 0.03125 -0.0234375 -0.046875\n", + " -0.0703125 0.015625 ]]\n", + "\n", + " [[ 0.2578125 -0.046875 0.015625 0.28125 0.0546875 -0.046875\n", + " 0.3046875 -0.1953125 0.0546875 -0.140625 0.0078125 -0.203125\n", + " 0.2109375 0.140625 0.375 0.0390625 -0.0234375 -0.046875\n", + " -0.0703125 0.015625 ]]\n", + "\n", + " [[ 0.2578125 -0.0546875 0.0078125 0.28125 0.0546875 -0.046875\n", + " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", + " 0.2109375 0.140625 0.3828125 0.0390625 -0.015625 -0.0546875\n", + " -0.0703125 0.015625 ]]\n", + "\n", + " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", + " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", + " 0.2109375 0.1484375 0.390625 0.0390625 -0.015625 -0.0546875\n", + " -0.0703125 0.015625 ]]\n", + "\n", + " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", + " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", + " 0.21875 0.1484375 0.390625 0.0390625 -0.015625 -0.0546875\n", + " -0.0703125 0.015625 ]]\n", + "\n", + " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", + " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", + " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", + " -0.0703125 0.015625 ]]\n", + "\n", + " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", + " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", + " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", + " -0.0703125 0.015625 ]]\n", + "\n", + " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", + " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", + " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", + " -0.0703125 0.015625 ]]\n", + "\n", + " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", + " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", + " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", + " -0.0703125 0.015625 ]]\n", + "\n", + " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", + " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", + " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", + " -0.0703125 0.015625 ]]\n", + "\n", + " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", + " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", + " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", + " -0.0703125 0.015625 ]]\n", + "\n", + " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", + " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", + " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", + " -0.0703125 0.015625 ]]\n", + "\n", + " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", + " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", + " 0.21875 0.1484375 
0.390625 0.046875 -0.015625 -0.0546875\n", + " -0.0703125 0.015625 ]]\n", + "\n", + " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", + " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", + " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", + " -0.0703125 0.015625 ]]\n", + "\n", + " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", + " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", + " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", + " -0.0703125 0.015625 ]]\n", + "\n", + " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", + " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", + " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", + " -0.0703125 0.015625 ]]\n", + "\n", + " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", + " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", + " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", + " -0.0703125 0.015625 ]]\n", + "\n", + " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", + " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", + " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", + " -0.0703125 0.015625 ]]\n", + "\n", + " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", + " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", + " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", + " -0.0703125 0.015625 ]]\n", + "\n", + " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", + " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", + " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", + " -0.0703125 0.015625 ]]]\n", + "-----------------------------------\n", + "SCAN-QCDQ-LSTM output [[[ 0.1484375 -0.0078125 0.0390625 0.140625 0.015625 0.\n", + " 0.1015625 -0.1015625 0.0390625 -0.0625 0.015625 -0.125\n", + " 0.1015625 0.03125 0.1640625 -0.015625 -0.0234375 -0.015625\n", + " -0.046875 0.0078125]]\n", + "\n", + " [[ 0.203125 -0.0234375 0.03125 0.2109375 0.0234375 -0.015625\n", + " 0.1875 -0.1484375 0.046875 -0.09375 0.0234375 -0.1640625\n", + " 0.1484375 0.0703125 0.2578125 -0.015625 -0.03125 -0.0234375\n", + " -0.0703125 0.0078125]]\n", + "\n", + " [[ 0.2265625 -0.03125 0.015625 0.2421875 0.03125 -0.0234375\n", + " 0.234375 -0.1796875 0.0546875 -0.109375 0.0234375 -0.1875\n", + " 0.1796875 0.09375 0.2734375 0. 
-0.03125 -0.03125\n", + " -0.0703125 0.015625 ]]\n", + "\n", + " [[ 0.234375 -0.0390625 0.015625 0.2578125 0.0390625 -0.03125\n", + " 0.25 -0.1875 0.0546875 -0.125 0.015625 -0.1953125\n", + " 0.1953125 0.1171875 0.2734375 0.015625 -0.03125 -0.0390625\n", + " -0.078125 0.0078125]]\n", + "\n", + " [[ 0.2421875 -0.046875 0.015625 0.2734375 0.0390625 -0.03125\n", + " 0.25 -0.1953125 0.0546875 -0.1328125 0.015625 -0.1953125\n", + " 0.203125 0.1328125 0.2734375 0.0234375 -0.03125 -0.046875\n", + " -0.078125 0.0078125]]\n", + "\n", + " [[ 0.2421875 -0.046875 0.015625 0.2734375 0.046875 -0.0390625\n", + " 0.25 -0.1953125 0.0546875 -0.140625 0.015625 -0.203125\n", + " 0.203125 0.1328125 0.2734375 0.03125 -0.03125 -0.046875\n", + " -0.0703125 0.0078125]]\n", + "\n", + " [[ 0.2421875 -0.046875 0.015625 0.2734375 0.046875 -0.0390625\n", + " 0.25 -0.1953125 0.0546875 -0.140625 0.015625 -0.203125\n", + " 0.203125 0.140625 0.2734375 0.03125 -0.03125 -0.046875\n", + " -0.0703125 0.0078125]]\n", + "\n", + " [[ 0.2421875 -0.046875 0.015625 0.2734375 0.0546875 -0.0390625\n", + " 0.25 -0.1953125 0.0546875 -0.140625 0.015625 -0.203125\n", + " 0.203125 0.140625 0.2734375 0.03125 -0.03125 -0.046875\n", + " -0.0703125 0.0078125]]\n", + "\n", + " [[ 0.2421875 -0.046875 0.015625 0.2734375 0.0546875 -0.0390625\n", + " 0.25 -0.1953125 0.0546875 -0.140625 0.015625 -0.203125\n", + " 0.203125 0.140625 0.2734375 0.03125 -0.03125 -0.046875\n", + " -0.0703125 0.0078125]]\n", + "\n", + " [[ 0.25 -0.046875 0.015625 0.2734375 0.0546875 -0.0390625\n", + " 0.25 -0.1953125 0.0546875 -0.140625 0.015625 -0.203125\n", + " 0.203125 0.140625 0.2734375 0.03125 -0.03125 -0.046875\n", + " -0.0703125 0.0078125]]\n", + "\n", + " [[ 0.25 -0.046875 0.015625 0.2734375 0.0546875 -0.0390625\n", + " 0.25 -0.1953125 0.0546875 -0.140625 0.015625 -0.203125\n", + " 0.203125 0.140625 0.2734375 0.03125 -0.03125 -0.046875\n", + " -0.0703125 0.0078125]]\n", + "\n", + " [[ 0.25 -0.046875 0.015625 0.2734375 0.0546875 -0.0390625\n", + " 0.25 -0.1953125 0.0546875 -0.140625 0.015625 -0.203125\n", + " 0.203125 0.140625 0.2734375 0.03125 -0.03125 -0.046875\n", + " -0.0703125 0.0078125]]\n", + "\n", + " [[ 0.25 -0.046875 0.015625 0.2734375 0.0546875 -0.0390625\n", + " 0.25 -0.1953125 0.0546875 -0.140625 0.015625 -0.203125\n", + " 0.203125 0.140625 0.2734375 0.03125 -0.03125 -0.046875\n", + " -0.0703125 0.0078125]]\n", + "\n", + " [[ 0.25 -0.046875 0.015625 0.2734375 0.0546875 -0.0390625\n", + " 0.25 -0.1953125 0.0546875 -0.140625 0.015625 -0.203125\n", + " 0.203125 0.140625 0.2734375 0.03125 -0.03125 -0.046875\n", + " -0.0703125 0.0078125]]\n", + "\n", + " [[ 0.25 -0.046875 0.015625 0.2734375 0.0546875 -0.0390625\n", + " 0.25 -0.1953125 0.0546875 -0.140625 0.015625 -0.203125\n", + " 0.203125 0.140625 0.2734375 0.03125 -0.03125 -0.046875\n", + " -0.0703125 0.0078125]]\n", + "\n", + " [[ 0.25 -0.046875 0.015625 0.2734375 0.0546875 -0.0390625\n", + " 0.25 -0.1953125 0.0546875 -0.140625 0.015625 -0.203125\n", + " 0.203125 0.140625 0.2734375 0.03125 -0.03125 -0.046875\n", + " -0.0703125 0.0078125]]\n", + "\n", + " [[ 0.25 -0.046875 0.015625 0.2734375 0.0546875 -0.0390625\n", + " 0.25 -0.1953125 0.0546875 -0.140625 0.015625 -0.203125\n", + " 0.203125 0.140625 0.2734375 0.03125 -0.03125 -0.046875\n", + " -0.0703125 0.0078125]]\n", + "\n", + " [[ 0.25 -0.046875 0.015625 0.2734375 0.0546875 -0.0390625\n", + " 0.25 -0.1953125 0.0546875 -0.140625 0.015625 -0.203125\n", + " 0.203125 0.140625 0.2734375 0.03125 -0.03125 -0.046875\n", + " -0.0703125 0.0078125]]\n", + "\n", + 
" [[ 0.25 -0.046875 0.015625 0.2734375 0.0546875 -0.0390625\n", + " 0.25 -0.1953125 0.0546875 -0.140625 0.015625 -0.203125\n", + " 0.203125 0.140625 0.2734375 0.03125 -0.03125 -0.046875\n", + " -0.0703125 0.0078125]]\n", + "\n", + " [[ 0.25 -0.046875 0.015625 0.2734375 0.0546875 -0.0390625\n", + " 0.25 -0.1953125 0.0546875 -0.140625 0.015625 -0.203125\n", + " 0.203125 0.140625 0.2734375 0.03125 -0.03125 -0.046875\n", + " -0.0703125 0.0078125]]\n", + "\n", + " [[ 0.25 -0.046875 0.015625 0.2734375 0.0546875 -0.0390625\n", + " 0.25 -0.1953125 0.0546875 -0.140625 0.015625 -0.203125\n", + " 0.203125 0.140625 0.2734375 0.03125 -0.03125 -0.046875\n", + " -0.0703125 0.0078125]]\n", + "\n", + " [[ 0.25 -0.046875 0.015625 0.2734375 0.0546875 -0.0390625\n", + " 0.25 -0.1953125 0.0546875 -0.140625 0.015625 -0.203125\n", + " 0.203125 0.140625 0.2734375 0.03125 -0.03125 -0.046875\n", + " -0.0703125 0.0078125]]\n", + "\n", + " [[ 0.25 -0.046875 0.015625 0.2734375 0.0546875 -0.0390625\n", + " 0.25 -0.1953125 0.0546875 -0.140625 0.015625 -0.203125\n", + " 0.203125 0.140625 0.2734375 0.03125 -0.03125 -0.046875\n", + " -0.0703125 0.0078125]]\n", + "\n", + " [[ 0.25 -0.046875 0.015625 0.2734375 0.0546875 -0.0390625\n", + " 0.25 -0.1953125 0.0546875 -0.140625 0.015625 -0.203125\n", + " 0.203125 0.140625 0.2734375 0.03125 -0.03125 -0.046875\n", + " -0.0703125 0.0078125]]\n", + "\n", + " [[ 0.25 -0.046875 0.015625 0.2734375 0.0546875 -0.0390625\n", + " 0.25 -0.1953125 0.0546875 -0.140625 0.015625 -0.203125\n", + " 0.203125 0.140625 0.2734375 0.03125 -0.03125 -0.046875\n", + " -0.0703125 0.0078125]]]\n", + "-----------------------------------\n", + "[[[ 0. 0. 0. 0. 1. 0. -1. -1. 0. 0. 0. -1. 0. 0.\n", + " 0. 0. 0. 0. 0. 0.]]\n", + "\n", + " [[ -1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.\n", + " 0. 0. 0. 0. 0. -1.]]\n", + "\n", + " [[ -2. 1. 0. -1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.\n", + " -5. 0. 0. 0. 1. 1.]]\n", + "\n", + " [[ -2. 0. 0. -1. 0. 0. -2. 0. 0. 0. 0. 0. 0. 0.\n", + " -8. 0. 0. 0. 0. -1.]]\n", + "\n", + " [[ -2. 0. 0. 0. -1. 1. -5. 0. 0. -1. 0. 1. 0. 1.\n", + " -11. 0. 0. 0. -1. -1.]]\n", + "\n", + " [[ -2. 0. 1. 0. 0. 0. -6. 0. 0. -1. 0. 0. -1. 0.\n", + " -12. 0. -1. 0. 0. -1.]]\n", + "\n", + " [[ -2. 0. 0. -1. -1. 1. -7. 0. 0. 0. 1. 0. -1. 0.\n", + " -13. -1. -1. 0. 0. -1.]]\n", + "\n", + " [[ -2. 1. 1. -1. 0. 1. -8. 1. 0. 0. 1. 0. -1. 0.\n", + " -14. -1. -2. 1. 0. -1.]]\n", + "\n", + " [[ -2. 0. 1. -1. 0. 1. -8. 1. 0. 0. 1. 0. -1. -1.\n", + " -15. -1. -2. 1. 0. -1.]]\n", + "\n", + " [[ -1. 0. 1. -1. 0. 1. -8. 1. 0. 0. 1. 0. -2. -1.\n", + " -15. -1. -2. 1. 0. -1.]]\n", + "\n", + " [[ -1. 0. 1. -1. 0. 1. -8. 1. 0. 0. 1. 0. -2. -1.\n", + " -15. -2. -2. 1. 0. -1.]]\n", + "\n", + " [[ -1. 0. 1. -1. 0. 1. -8. 1. 0. 0. 1. 0. -2. -1.\n", + " -15. -2. -2. 1. 0. -1.]]\n", + "\n", + " [[ -1. 0. 1. -1. 0. 1. -8. 1. 0. 0. 1. 0. -2. -1.\n", + " -15. -2. -2. 1. 0. -1.]]\n", + "\n", + " [[ -1. 0. 1. -1. 0. 1. -8. 1. 0. 0. 1. 0. -2. -1.\n", + " -15. -2. -2. 1. 0. -1.]]\n", + "\n", + " [[ -1. 0. 1. -1. 0. 1. -8. 1. 0. 0. 1. 0. -2. -1.\n", + " -15. -2. -2. 1. 0. -1.]]\n", + "\n", + " [[ -1. 0. 1. -1. 0. 1. -8. 1. 0. 0. 1. 0. -2. -1.\n", + " -15. -2. -2. 1. 0. -1.]]\n", + "\n", + " [[ -1. 0. 1. -1. 0. 1. -8. 1. 0. 0. 1. 0. -2. -1.\n", + " -15. -2. -2. 1. 0. -1.]]\n", + "\n", + " [[ -1. 0. 1. -1. 0. 1. -8. 1. 0. 0. 1. 0. -2. -1.\n", + " -15. -2. -2. 1. 0. -1.]]\n", + "\n", + " [[ -1. 0. 1. -1. 0. 1. -8. 1. 0. 0. 1. 0. -2. -1.\n", + " -15. -2. -2. 1. 0. -1.]]\n", + "\n", + " [[ -1. 0. 1. -1. 0. 1. -8. 1. 0. 0. 1. 0. -2. 
-1.\n", + " -15. -2. -2. 1. 0. -1.]]\n", + "\n", + " [[ -1. 0. 1. -1. 0. 1. -8. 1. 0. 0. 1. 0. -2. -1.\n", + " -15. -2. -2. 1. 0. -1.]]\n", + "\n", + " [[ -1. 0. 1. -1. 0. 1. -8. 1. 0. 0. 1. 0. -2. -1.\n", + " -15. -2. -2. 1. 0. -1.]]\n", + "\n", + " [[ -1. 0. 1. -1. 0. 1. -8. 1. 0. 0. 1. 0. -2. -1.\n", + " -15. -2. -2. 1. 0. -1.]]\n", + "\n", + " [[ -1. 0. 1. -1. 0. 1. -8. 1. 0. 0. 1. 0. -2. -1.\n", + " -15. -2. -2. 1. 0. -1.]]\n", + "\n", + " [[ -1. 0. 1. -1. 0. 1. -8. 1. 0. 0. 1. 0. -2. -1.\n", + " -15. -2. -2. 1. 0. -1.]]]\n" + ] + } + ], + "source": [ + "# We first match the shape of both the outputs to perform the functional verification correctly\n", + "\n", + "print('Brevitas Output shape : ', brevitas_output.shape)\n", + "all_hidden_states = np.array(scan_output[2])\n", + "all_hidden_states = all_hidden_states.reshape([num_inputs,1,num_hidden_cells])\n", + "print('SCAN-QCDQ-LSTM output shape :', all_hidden_states.shape)\n", + "print('-----------------------------------')\n", + "print('Brevitas Output = ',brevitas_output)\n", + "print('-----------------------------------')\n", + "print('SCAN-QCDQ-LSTM output',all_hidden_states)\n", + "print('-----------------------------------')\n", + "\n", + "# Comparison between the 'Scan-LSTM output' and the brevitas 'QuantLSTM' ouptut\n", + "# Since the outputs from both models are floating-point, to get a better understanding of the differences we scale the outputs to INT8 precision and then compare their differences.\n", + "# The scale used to do that is the last scale of the LSTM graph.\n", + "\n", + "scale = inp_scale_val #The scale value is equal to the value of the inp_scale_val\n", + "all_hidden_states = np.array(scan_output[2])\n", + "all_hidden_states = all_hidden_states.reshape([num_inputs,1,num_hidden_cells])\n", + "all_hidden_state_diff = (all_hidden_states - brevitas_output)\n", + "print(all_hidden_state_diff/scale)" + ] + }, + { + "cell_type": "markdown", + "id": "7bcca933", + "metadata": {}, + "source": [ + "Note the difference in outputs increases as we progress with processing the inputs. The first two outputs are very close to one another, but as we get the outputs for more inputs we see for some values differ from the brevitas output by a considerable amount.\n", + "This behaviour can be attributed to some values being slightly different in the first few outputs (which are not visible) which eventually cause an increase in differences between both values as more inputs are processed." 
+ ] + }, + { + "cell_type": "markdown", + "id": "81c6d531", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From e2e07371c5e1a03de0dc0dc68e9ee55df0270a95 Mon Sep 17 00:00:00 2001 From: shashwat1198 Date: Sun, 22 Oct 2023 10:29:44 +0100 Subject: [PATCH 02/28] Clean QuantLSTM --- notebooks/4_quant_lstm.ipynb | 2040 +--------------------------------- 1 file changed, 22 insertions(+), 2018 deletions(-) diff --git a/notebooks/4_quant_lstm.ipynb b/notebooks/4_quant_lstm.ipynb index 72cac7e9..186be984 100644 --- a/notebooks/4_quant_lstm.ipynb +++ b/notebooks/4_quant_lstm.ipynb @@ -89,18 +89,10 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "84d66548-365d-46a5-9eaa-bb767085f9aa", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'\n" - ] - } - ], + "outputs": [], "source": [ "# We import the required libraries to execute different functions in the notebook.\n", "# The first four imports are required to build the QuantLSTM model in brevitas. \n", @@ -126,291 +118,10 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "23a7682c", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "quant_input_supplied to brevitas = tensor([[-1.0000, -0.5000, -1.0000, 0.5156, -1.0000, 0.9922, -0.8047, -1.0000,\n", - " 0.2188, 0.9922]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[-0.7266, -0.9531, 0.9922, 0.9922, -1.0000, 0.9922, -0.7734, -1.0000,\n", - " -0.0859, 0.6250]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[-0.6719, -1.0000, 0.0547, -0.5234, -0.0000, 0.1250, -1.0000, 0.3047,\n", - " -0.0312, -1.0000]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[-1.0000, -0.1797, 0.3516, -0.1328, -1.0000, -1.0000, 0.8750, -0.2812,\n", - " 0.4844, -0.3203]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[ 0.6719, -0.1484, 0.5078, 0.5312, -0.2969, 0.1719, -1.0000, 0.4688,\n", - " -0.2500, 0.8672]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[ 0.3125, 0.9922, 0.8281, -0.4297, -1.0000, 0.9922, -1.0000, 0.9922,\n", - " -1.0000, 0.2578]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[-0.3125, -1.0000, -0.4688, 0.2656, -1.0000, -1.0000, -1.0000, -0.7266,\n", - " 0.9922, 0.8984]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[-0.5625, 0.8359, -1.0000, 0.1875, -1.0000, -1.0000, 0.1562, 0.3438,\n", - " 0.6172, -1.0000]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[-1.0000, -0.0781, 0.3203, 0.1797, -1.0000, -0.1875, 0.9219, -0.4609,\n", - " -0.3125, 0.2031]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[ 0.8750, -1.0000, 0.6016, -1.0000, -0.7656, -0.1484, 0.9922, 0.6406,\n", - " -1.0000, 0.9922]])\n", - 
"----------------------------\n", - "quant_input_supplied to brevitas = tensor([[-0.9922, -1.0000, 0.5078, -1.0000, -1.0000, 0.4453, -1.0000, 0.6719,\n", - " -1.0000, -1.0000]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[ 0.0703, -1.0000, -0.6797, -1.0000, -1.0000, -0.8750, -0.6797, 0.3672,\n", - " -0.5938, -0.2031]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[-0.6641, 0.9922, 0.1641, 0.9922, 0.9922, -1.0000, -1.0000, 0.9922,\n", - " 0.3438, 0.4688]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[-0.1875, 0.0000, -0.2812, -1.0000, -1.0000, -0.0391, 0.0781, 0.9922,\n", - " -0.2188, 0.9922]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[ 0.2578, 0.9922, -1.0000, 0.4297, -0.7500, 0.2891, -1.0000, -1.0000,\n", - " 0.6484, 0.3828]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[ 0.3594, -0.0000, -1.0000, 0.4688, -0.2734, -1.0000, -0.2969, 0.9922,\n", - " 0.9922, 0.9062]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[-0.0938, -1.0000, 0.1016, -0.7109, -0.3203, 0.7578, 0.9922, 0.3359,\n", - " 0.1328, 0.4062]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[ 0.4141, -0.6328, -0.7422, 0.9609, -0.9062, -0.4297, 0.7031, 0.9922,\n", - " -1.0000, -0.7969]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[ 0.3203, -1.0000, -0.7109, 0.3281, 0.6016, -0.2031, -0.6172, 0.7031,\n", - " -0.5078, -1.0000]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[-1.0000, -0.2500, -0.9766, -1.0000, 0.3984, -0.6484, -1.0000, 0.7188,\n", - " 0.9922, 0.9453]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[-0.5234, 0.9922, -0.3984, 0.1328, -0.0625, -0.8047, -0.1562, -0.1250,\n", - " -0.1172, 0.6328]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[ 0.0547, 0.0156, 0.0703, -0.8750, -1.0000, 0.5156, -0.0938, -0.2969,\n", - " -0.9922, 0.9922]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[ 0.9922, -1.0000, 0.3438, 0.9922, 0.1328, 0.2891, 0.0469, -0.3438,\n", - " -0.9531, -1.0000]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[ 0.2969, -1.0000, 0.1250, -1.0000, -0.5469, -1.0000, 0.5000, 0.7344,\n", - " -1.0000, 0.7109]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[ 0.4219, 0.4922, 0.7266, 0.0078, 0.0469, 0.9844, -0.5391, -0.0781,\n", - " 0.9922, -1.0000]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", - " 0.7969]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", - " 0.7969]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", - " 0.7969]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", - " 0.7969]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 
0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", - " 0.7969]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", - " 0.7969]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", - " 0.7969]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", - " 0.7969]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", - " 0.7969]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", - " 0.7969]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", - " 0.7969]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", - " 0.7969]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", - " 0.7969]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", - " 0.7969]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", - " 0.7969]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", - " 0.7969]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", - " 0.7969]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", - " 0.7969]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", - " 0.7969]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", - " 0.7969]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", - " 0.7969]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", - " 0.7969]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", - " 0.7969]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", - " 0.7969]])\n", - "----------------------------\n", - "quant_input_supplied to brevitas = tensor([[0.7969, 0.7969, 0.7969, 
0.7969, 0.7969, 0.7969, 0.7969, 0.7969, 0.7969,\n", - " 0.7969]])\n", - "----------------------------\n", - "[[[ 0.1484375 -0.0078125 0.0390625 0.140625 0.0078125 0.\n", - " 0.109375 -0.09375 0.0390625 -0.0625 0.015625 -0.1171875\n", - " 0.1015625 0.03125 0.1640625 -0.015625 -0.0234375 -0.015625\n", - " -0.046875 0.0078125]]\n", - "\n", - " [[ 0.2109375 -0.0234375 0.03125 0.2109375 0.0234375 -0.015625\n", - " 0.1875 -0.1484375 0.046875 -0.09375 0.0234375 -0.1640625\n", - " 0.1484375 0.0625 0.2578125 -0.015625 -0.03125 -0.0234375\n", - " -0.0703125 0.015625 ]]\n", - "\n", - " [[ 0.2421875 -0.0390625 0.015625 0.25 0.03125 -0.0234375\n", - " 0.234375 -0.1796875 0.0546875 -0.109375 0.015625 -0.1875\n", - " 0.1796875 0.09375 0.3125 0. -0.03125 -0.03125\n", - " -0.078125 0.0078125]]\n", - "\n", - " [[ 0.25 -0.0390625 0.015625 0.265625 0.0390625 -0.03125\n", - " 0.265625 -0.1875 0.0546875 -0.125 0.015625 -0.1953125\n", - " 0.1953125 0.1171875 0.3359375 0.015625 -0.03125 -0.0390625\n", - " -0.078125 0.015625 ]]\n", - "\n", - " [[ 0.2578125 -0.046875 0.015625 0.2734375 0.046875 -0.0390625\n", - " 0.2890625 -0.1953125 0.0546875 -0.125 0.015625 -0.203125\n", - " 0.203125 0.125 0.359375 0.0234375 -0.03125 -0.046875\n", - " -0.0703125 0.015625 ]]\n", - "\n", - " [[ 0.2578125 -0.046875 0.0078125 0.2734375 0.046875 -0.0390625\n", - " 0.296875 -0.1953125 0.0546875 -0.1328125 0.015625 -0.203125\n", - " 0.2109375 0.1328125 0.3671875 0.03125 -0.0234375 -0.046875\n", - " -0.0703125 0.015625 ]]\n", - "\n", - " [[ 0.2578125 -0.046875 0.015625 0.28125 0.0546875 -0.046875\n", - " 0.3046875 -0.1953125 0.0546875 -0.140625 0.0078125 -0.203125\n", - " 0.2109375 0.140625 0.375 0.0390625 -0.0234375 -0.046875\n", - " -0.0703125 0.015625 ]]\n", - "\n", - " [[ 0.2578125 -0.0546875 0.0078125 0.28125 0.0546875 -0.046875\n", - " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", - " 0.2109375 0.140625 0.3828125 0.0390625 -0.015625 -0.0546875\n", - " -0.0703125 0.015625 ]]\n", - "\n", - " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", - " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", - " 0.2109375 0.1484375 0.390625 0.0390625 -0.015625 -0.0546875\n", - " -0.0703125 0.015625 ]]\n", - "\n", - " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", - " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", - " 0.21875 0.1484375 0.390625 0.0390625 -0.015625 -0.0546875\n", - " -0.0703125 0.015625 ]]\n", - "\n", - " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", - " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", - " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", - " -0.0703125 0.015625 ]]\n", - "\n", - " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", - " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", - " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", - " -0.0703125 0.015625 ]]\n", - "\n", - " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", - " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", - " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", - " -0.0703125 0.015625 ]]\n", - "\n", - " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", - " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", - " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", - " -0.0703125 0.015625 ]]\n", - "\n", - " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", - " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 
-0.203125\n", - " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", - " -0.0703125 0.015625 ]]\n", - "\n", - " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", - " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", - " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", - " -0.0703125 0.015625 ]]\n", - "\n", - " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", - " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", - " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", - " -0.0703125 0.015625 ]]\n", - "\n", - " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", - " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", - " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", - " -0.0703125 0.015625 ]]\n", - "\n", - " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", - " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", - " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", - " -0.0703125 0.015625 ]]\n", - "\n", - " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", - " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", - " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", - " -0.0703125 0.015625 ]]\n", - "\n", - " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", - " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", - " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", - " -0.0703125 0.015625 ]]\n", - "\n", - " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", - " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", - " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", - " -0.0703125 0.015625 ]]\n", - "\n", - " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", - " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", - " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", - " -0.0703125 0.015625 ]]\n", - "\n", - " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", - " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", - " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", - " -0.0703125 0.015625 ]]\n", - "\n", - " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", - " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", - " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", - " -0.0703125 0.015625 ]]]\n" - ] - } - ], + "outputs": [], "source": [ "# In this block of code we will create the QuantLSTM model using the brevitas layer\n", "torch.manual_seed(0) #Setting the manual seeds to 0 for consistency in outputs.\n", @@ -454,685 +165,10 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "0bfbf5a3-8556-4190-a28f-4fe9859c55a9", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "0.layers.0.0.input_gate_params.bias\n", - "(20,)\n", - "[-0.02587563 -0.18425222 -0.18189065 0.02914573 -0.21827428 0.0595416\n", - " -0.20598626 -0.15559138 -0.04639753 -0.2133838 0.18059207 0.18321364\n", - " -0.11679631 0.04684116 0.11439164 0.07105622 -0.02995344 -0.21090843\n", - " 0.1625932 -0.19612479] , 0\n", - "-------------------------\n", - "0.layers.0.0.input_gate_params.input_weight.weight\n", - "(20, 10)\n", - "[[-4.14119214e-02 1.38706667e-02 -7.36431107e-02 -8.17852393e-02\n", - " -1.93256751e-01 1.23205660e-02 -2.53894478e-02 1.94940954e-01\n", - " -7.36160800e-02 
1.72829047e-01]\n", - " [ 1.05855539e-02 -1.00462548e-01 -5.31778559e-02 -2.53751595e-02\n", - " 2.31616711e-03 -3.68398018e-02 6.63604736e-02 1.84143797e-01\n", - " 3.51473056e-02 8.09932351e-02]\n", - " [ 1.38081744e-01 4.81988601e-02 1.03076197e-01 1.17293097e-01\n", - " 2.09298924e-01 -2.04075590e-01 7.65163079e-02 -1.01319486e-02\n", - " -4.01576199e-02 -8.62098187e-02]\n", - " [ 1.34432539e-01 2.04552680e-01 -1.82483241e-01 1.20810278e-01\n", - " 1.54187992e-01 3.90806384e-02 2.63404008e-03 1.72071218e-01\n", - " 6.62961556e-03 -5.57729751e-02]\n", - " [-1.65121444e-02 7.17408881e-02 5.59775345e-02 -1.20642958e-02\n", - " 7.05851838e-02 6.02219440e-02 -1.81134686e-01 5.57176135e-02\n", - " 1.36812523e-01 2.56436393e-02]\n", - " [-2.04101056e-02 1.71289816e-01 -1.95361048e-01 -1.02062307e-01\n", - " -1.01068199e-01 1.93207934e-01 -2.16277346e-01 2.21768115e-02\n", - " -2.16605455e-01 -7.35303294e-03]\n", - " [ 8.33466202e-02 -5.22914641e-02 2.17063010e-01 7.11822009e-04\n", - " -1.14001475e-01 5.76605424e-02 1.16289847e-01 -4.44249017e-04\n", - " 1.91289768e-01 -1.41524345e-01]\n", - " [ 9.54081938e-02 1.26971915e-01 1.11063533e-01 -8.20205314e-05\n", - " 6.38317242e-02 -1.75422058e-01 -1.75476715e-01 -1.38986288e-02\n", - " -2.80253254e-02 1.66033790e-01]\n", - " [ 1.62366882e-01 1.51616067e-01 -1.02419287e-01 -1.75539613e-01\n", - " -2.09742919e-01 8.09257179e-02 -2.01488122e-01 -2.23217383e-01\n", - " -1.13006435e-01 -1.88792080e-01]\n", - " [-8.81207064e-02 -1.40770882e-01 -1.14718042e-01 2.12588429e-01\n", - " -4.21379767e-02 1.85490459e-01 4.96126944e-03 -2.87544206e-02\n", - " -6.54680878e-02 -1.59840211e-01]\n", - " [-1.79656431e-01 1.54830217e-01 -6.89065754e-02 -2.18012080e-01\n", - " 2.05210581e-01 4.14780807e-03 -1.49626598e-01 -1.75766915e-01\n", - " -1.87781662e-01 -1.96070760e-01]\n", - " [ 2.02346548e-01 1.54175445e-01 1.82888191e-02 -1.90574318e-01\n", - " -5.84847443e-02 -2.10055038e-01 7.70593956e-02 -5.93719892e-02\n", - " -4.78506237e-02 -6.97683394e-02]\n", - " [ 1.04838371e-01 1.21036001e-01 4.89832126e-02 -2.80011501e-02\n", - " -2.20977236e-02 -3.90723767e-03 -1.66511953e-01 2.18188778e-01\n", - " -9.64377001e-02 1.30095944e-01]\n", - " [-1.25353500e-01 1.50110642e-03 7.65467212e-02 -2.05311388e-01\n", - " 1.02568395e-01 -1.71158642e-01 3.12034953e-02 -4.43410687e-02\n", - " 1.28176615e-01 2.17323676e-01]\n", - " [ 5.03933132e-02 -6.38488680e-03 -1.10784821e-01 8.33686888e-02\n", - " -1.07626989e-01 9.23645869e-02 -9.69173536e-02 1.51675642e-01\n", - " 1.71514452e-01 1.37112319e-01]\n", - " [ 2.23987759e-03 1.03696242e-01 -2.03757793e-01 1.81339085e-01\n", - " -5.80957830e-02 8.15173239e-02 -3.78652588e-02 -7.50842392e-02\n", - " -1.05006970e-01 1.44231498e-01]\n", - " [-1.21653110e-01 -3.94320451e-02 1.12798467e-01 2.25366149e-02\n", - " -1.88142627e-01 -2.22348958e-01 -1.08711593e-01 2.06236228e-01\n", - " -1.58990204e-01 1.23237595e-01]\n", - " [ 1.60061240e-01 -9.26844329e-02 -9.87462699e-02 -1.60870835e-01\n", - " 3.48785594e-02 -3.12594734e-02 1.08638955e-02 9.69918296e-02\n", - " 9.38790441e-02 -7.05472827e-02]\n", - " [ 1.53575651e-02 5.31169996e-02 4.75974986e-03 4.47460003e-02\n", - " -9.05808210e-02 1.83284596e-01 -2.29354147e-02 -2.86094397e-02\n", - " -2.00689927e-01 -1.62085444e-01]\n", - " [ 6.95567206e-03 -3.45815569e-02 -1.12424992e-01 1.17047116e-01\n", - " -2.00185552e-02 7.86398575e-02 1.88336477e-01 -1.02802545e-01\n", - " -1.10053055e-01 -4.49331515e-02]] , 1\n", - "-------------------------\n", - 
"0.layers.0.0.input_gate_params.hidden_weight.weight\n", - "(20, 20)\n", - "[[-1.89352538e-02 -1.11839756e-01 -5.36844507e-02 -6.44523604e-03\n", - " 1.00301303e-01 2.06872717e-01 1.65582791e-01 2.36654170e-02\n", - " -1.40909785e-02 5.72774969e-02 -9.12800338e-03 -2.93454379e-02\n", - " 7.68917575e-02 -1.81926534e-01 -1.90163419e-01 9.05744440e-04\n", - " -6.77747875e-02 -1.10600702e-01 -2.08165124e-01 1.49785221e-01]\n", - " [-8.90937075e-03 -1.20138384e-01 -9.10849124e-02 5.87869175e-02\n", - " -1.62167445e-01 1.43613769e-02 -2.75748386e-03 7.61744976e-02\n", - " 8.87038633e-02 -1.46100059e-01 9.65513662e-02 1.68849513e-01\n", - " 1.43956831e-02 1.13917463e-01 -8.46547335e-02 4.44148518e-02\n", - " 6.53375536e-02 -1.03280008e-01 1.38058737e-01 -2.11419612e-01]\n", - " [-8.39947835e-02 -1.31567493e-01 -1.32741287e-01 -1.35494858e-01\n", - " -2.10702628e-01 3.83746810e-02 -4.42331657e-02 -1.88279316e-01\n", - " -9.19632221e-05 -3.72487307e-02 9.22437534e-02 -1.75148100e-01\n", - " -6.29062578e-02 4.60259691e-02 9.47839618e-02 1.69158224e-02\n", - " 6.05970472e-02 2.23524958e-01 -7.74600878e-02 1.52398065e-01]\n", - " [ 1.92612275e-01 -1.97806209e-01 5.40891960e-02 1.26661941e-01\n", - " -3.48797850e-02 1.23408221e-01 7.60573195e-03 1.70228094e-01\n", - " 4.81458148e-03 -1.43158093e-01 1.69815615e-01 6.65016174e-02\n", - " 1.90237820e-01 5.55088967e-02 1.18736811e-01 1.39421389e-01\n", - " 3.76524106e-02 -5.19809462e-02 4.61825170e-02 -1.55909836e-01]\n", - " [ 7.63913197e-03 -7.18704611e-02 1.41373863e-02 -1.77042618e-01\n", - " 1.36628836e-01 -2.06302434e-01 9.57576782e-02 1.47258580e-01\n", - " -2.04934716e-01 2.02031001e-01 -1.66225716e-01 -4.39088680e-02\n", - " 1.15872569e-01 -7.09063411e-02 1.99275032e-01 -9.86447409e-02\n", - " -2.99374424e-02 -1.46168455e-01 -1.03737742e-01 2.18205780e-01]\n", - " [ 1.68166518e-01 1.64642967e-02 1.83855016e-02 -1.89751670e-01\n", - " 1.68811426e-01 -3.35250199e-02 -9.32650268e-02 -1.77951321e-01\n", - " 1.83845311e-01 1.06031545e-01 1.34684831e-01 2.31534615e-02\n", - " -1.51732951e-01 9.15970504e-02 2.57883817e-02 7.50367939e-02\n", - " -5.56799732e-02 -1.05523452e-01 1.83565930e-01 7.49567226e-02]\n", - " [-9.07528847e-02 1.99678559e-02 -4.86066155e-02 -1.91221125e-02\n", - " 1.25389591e-01 -1.77972749e-01 2.02371553e-01 1.50499865e-01\n", - " 1.92136504e-04 -9.14627835e-02 4.55915295e-02 -1.48007214e-01\n", - " 1.45243973e-01 -1.18256845e-01 4.27256078e-02 -2.19991282e-01\n", - " 1.07079633e-01 1.51370272e-01 1.67834863e-01 1.82519276e-02]\n", - " [ 1.32025823e-01 7.62412176e-02 1.49954304e-01 1.26183063e-01\n", - " -1.95639879e-01 2.35728398e-02 -7.62314126e-02 -1.06771380e-01\n", - " 1.56516239e-01 -3.20035741e-02 3.47357877e-02 1.40789405e-01\n", - " 1.50514722e-01 1.19332708e-01 -3.90392952e-02 -1.99321926e-01\n", - " -2.14659125e-01 7.02862144e-02 -2.65357876e-03 -1.41277447e-01]\n", - " [ 9.76564139e-02 2.02965632e-01 1.29328549e-01 -3.15438919e-02\n", - " 3.02148778e-02 -1.42630830e-01 1.05540812e-01 -1.73283800e-01\n", - " 1.54376432e-01 -1.02132224e-01 -8.86853859e-02 -1.87295631e-01\n", - " -5.40727489e-02 -2.16292981e-02 -1.03067294e-01 1.59174219e-01\n", - " 1.28328785e-01 -1.97347268e-01 -2.23675612e-02 7.51795396e-02]\n", - " [ 2.15735227e-01 -5.34672327e-02 1.37278914e-01 -1.25270970e-02\n", - " -8.57628211e-02 1.36838645e-01 -1.99253812e-01 1.87337860e-01\n", - " 2.23344907e-01 -6.10500947e-02 8.83295834e-02 2.22981662e-01\n", - " 6.74140528e-02 8.74451399e-02 8.21070075e-02 -9.14832279e-02\n", - " 5.45820408e-02 -1.19176529e-01 
1.90940976e-01 -9.58186984e-02]\n", - " [ 5.11176400e-02 -6.47741258e-02 1.11825228e-01 3.68577940e-03\n", - " 1.22950912e-01 -6.05489872e-02 -1.31215081e-01 8.57292935e-02\n", - " -1.25841707e-01 -1.83588028e-01 8.63927826e-02 -1.34484172e-01\n", - " -8.40481222e-02 -5.58335669e-02 1.58777572e-02 -7.74438009e-02\n", - " -8.04765150e-02 -5.62009923e-02 1.56701818e-01 6.69540018e-02]\n", - " [-1.07652791e-01 -1.54563770e-01 5.18102152e-03 7.16358349e-02\n", - " -4.67919558e-03 1.30897254e-01 1.88077956e-01 6.55371249e-02\n", - " 7.37451240e-02 1.29728526e-01 -7.66031295e-02 3.96637134e-02\n", - " 1.80782616e-01 -1.07077263e-01 1.74031202e-02 -8.74211192e-02\n", - " -1.71936572e-01 1.18438050e-01 1.78673968e-01 -1.20800309e-01]\n", - " [ 8.38049129e-02 6.85676187e-02 8.73105526e-02 1.23087496e-01\n", - " 2.08757341e-01 1.69717655e-01 -1.95658267e-01 -8.76599625e-02\n", - " 1.18758187e-01 -1.27650708e-01 4.39067073e-02 -9.58611295e-02\n", - " 4.44106422e-02 1.09106824e-01 7.02822655e-02 1.62435979e-01\n", - " -2.69077457e-02 1.21389672e-01 7.22895712e-02 -7.04701096e-02]\n", - " [-1.57925934e-01 2.04573229e-01 -6.66687265e-02 1.68426275e-01\n", - " 1.40947536e-01 -9.00426600e-03 -1.84701070e-01 1.80013608e-02\n", - " -1.08096078e-01 5.81858531e-02 -8.88810679e-02 1.72345534e-01\n", - " -2.01746121e-01 -6.01959564e-02 3.52624580e-02 2.13314164e-02\n", - " 1.83701098e-01 -7.06517771e-02 -1.78495154e-01 1.48046315e-01]\n", - " [ 6.24824539e-02 1.47299409e-01 -1.32342920e-01 -1.31334439e-01\n", - " -9.03252959e-02 1.58978552e-02 7.57712200e-02 -1.28496692e-01\n", - " -2.10528076e-02 -3.86467576e-02 2.04027027e-01 -8.06416422e-02\n", - " 2.16690734e-01 -1.37144789e-01 -9.21397135e-02 -1.68184295e-01\n", - " 1.64731190e-01 -1.53769597e-01 9.25582647e-02 -8.21671411e-02]\n", - " [ 2.22826257e-01 3.15412283e-02 -1.94183901e-01 3.84835452e-02\n", - " 2.71859560e-02 -2.16274336e-01 4.48757894e-02 2.13342309e-01\n", - " 6.43487200e-02 -1.18915108e-03 -4.63541821e-02 5.94213046e-02\n", - " -9.96202976e-02 2.20200241e-01 1.93298727e-01 1.04461670e-01\n", - " -8.32887441e-02 -2.09956676e-01 -1.28724366e-01 2.17411697e-01]\n", - " [-2.05243871e-01 -2.13502616e-01 -1.61161683e-02 7.11405650e-02\n", - " -2.22554103e-01 -2.07601383e-01 1.21570053e-03 -7.50053376e-02\n", - " 1.55782372e-01 6.41999543e-02 -1.94095746e-01 -2.01538876e-01\n", - " 1.53562352e-01 -3.96501981e-02 -9.78184044e-02 7.04318583e-02\n", - " -4.39465865e-02 1.06939368e-01 5.67044728e-02 -9.68158469e-02]\n", - " [-1.79218486e-01 1.21047780e-01 -1.34345368e-01 -2.47318167e-02\n", - " 3.05733737e-02 -1.30131751e-01 1.21804118e-01 -1.57282248e-01\n", - " 5.49192652e-02 2.39149425e-02 8.20437744e-02 -2.19451547e-01\n", - " 1.29167549e-02 1.09009661e-01 -1.43156886e-01 5.53317666e-02\n", - " 8.76156322e-04 1.89696804e-01 -4.73480262e-02 1.52765575e-03]\n", - " [-9.72549468e-02 -5.51085509e-02 6.40134960e-02 -2.15656430e-01\n", - " 1.69629768e-01 1.60795882e-01 9.46965069e-02 1.67391464e-01\n", - " -6.96057901e-02 5.09320870e-02 1.13759311e-02 -1.54622883e-01\n", - " -8.59646648e-02 -7.93827102e-02 -5.52875437e-02 -1.98549107e-01\n", - " -1.57260388e-01 -2.12343093e-02 -3.40157561e-02 -2.02978238e-01]\n", - " [ 4.77774814e-02 1.21752672e-01 1.86222807e-01 1.88188314e-01\n", - " -1.56248853e-01 -7.16619864e-02 -1.06078379e-01 4.10118401e-02\n", - " 5.99195063e-02 4.97494638e-02 1.30669191e-01 1.17969945e-01\n", - " -1.20020248e-01 1.53502032e-01 1.50838137e-01 2.95910202e-02\n", - " -1.94543302e-01 -1.37143746e-01 6.23138808e-02 7.73103088e-02]] , 
2\n", - "-------------------------\n", - "0.layers.0.0.forget_gate_params.bias\n", - "(20,)\n", - "[ 0.20850217 0.11380532 0.08104482 -0.00762655 0.15247074 -0.08138975\n", - " 0.0910454 -0.10650107 -0.00208706 0.13215044 0.10260209 -0.05017841\n", - " -0.00283135 -0.12413156 0.10357434 0.15046087 0.07697045 -0.21637587\n", - " -0.16006967 0.14969489] , 3\n", - "-------------------------\n", - "0.layers.0.0.forget_gate_params.input_weight.weight\n", - "(20, 10)\n", - "[[-0.03201701 0.13732338 0.16482215 -0.06550063 -0.13119501 -0.2103679\n", - " 0.08553377 0.11468438 -0.0387658 -0.21708311]\n", - " [-0.14402747 -0.01204806 0.10205487 -0.07492673 -0.14435105 -0.15566948\n", - " 0.2000676 0.08097311 -0.1815501 -0.13809344]\n", - " [-0.18981868 0.03235186 -0.09079897 -0.00075695 -0.0353742 -0.1957324\n", - " -0.19982079 -0.17343585 -0.09364887 0.03477862]\n", - " [-0.10515709 -0.00797041 -0.02678433 0.20449734 -0.10193561 0.21008612\n", - " -0.17165995 -0.18656294 0.07271551 -0.13013807]\n", - " [ 0.11469334 -0.12370986 0.17608246 0.21651667 0.01431521 0.04778921\n", - " 0.20847315 0.13255776 -0.19520605 -0.00715788]\n", - " [-0.20184483 0.17081025 -0.04095714 -0.00155866 -0.13738167 -0.12158713\n", - " 0.02901981 0.18449156 -0.1123966 0.02112942]\n", - " [ 0.20241037 0.20039941 -0.04371644 0.20957804 0.08143061 0.20365277\n", - " 0.00663433 -0.1895056 -0.06086665 0.06706649]\n", - " [ 0.1192437 -0.22275887 0.17393245 -0.20059223 0.13101582 0.22062524\n", - " 0.05510434 -0.0422016 0.12311912 -0.06636703]\n", - " [-0.16563286 -0.15869099 0.10513588 0.1707739 0.00905446 -0.2168069\n", - " -0.21971782 -0.05049207 0.12070725 -0.1490105 ]\n", - " [ 0.06027115 -0.12221678 0.18192975 -0.05859193 -0.04659947 -0.19612114\n", - " -0.20028274 0.01511241 0.03615525 0.12080745]\n", - " [-0.19552828 0.03918052 -0.03230212 0.1311668 -0.1016731 0.06661848\n", - " 0.09010674 0.11232612 -0.07669472 0.07195909]\n", - " [-0.04382298 0.06021269 -0.13749652 -0.17768005 -0.18290731 -0.1405653\n", - " -0.09463658 0.03328432 -0.04891114 -0.12729394]\n", - " [ 0.00187842 -0.07061429 0.13783802 -0.18416376 -0.08253521 -0.1436971\n", - " 0.02759105 0.01219904 -0.0128632 0.22186181]\n", - " [-0.08530237 -0.03213883 0.05777045 0.18662488 0.16948868 0.02554451\n", - " -0.08459641 0.07345897 0.14069013 -0.00477207]\n", - " [ 0.12276765 0.18300453 -0.11980148 -0.04943415 -0.20131664 0.05132969\n", - " 0.15936238 -0.04342245 0.03568069 0.07144996]\n", - " [-0.00476937 0.17384104 0.0325843 -0.21979333 -0.18465139 -0.22154187\n", - " 0.00921626 0.12087465 -0.02950055 0.20104776]\n", - " [-0.04022751 0.04571649 0.20163535 0.11316557 -0.00713371 0.2153832\n", - " -0.1335971 0.08328808 0.14121595 -0.13845547]\n", - " [-0.21004361 0.07152335 -0.08483391 -0.1128413 0.04447659 -0.16221067\n", - " 0.2011128 -0.02007227 -0.07161061 0.18693109]\n", - " [ 0.06226142 0.04260208 -0.10691333 0.21311398 -0.06810362 0.18598051\n", - " -0.016437 0.11216957 0.15722302 -0.1664758 ]\n", - " [-0.14903465 -0.22111452 0.16127922 0.19229865 -0.08172148 -0.10951796\n", - " 0.03742959 0.12038527 0.05519409 -0.04660187]] , 4\n", - "-------------------------\n", - "0.layers.0.0.forget_gate_params.hidden_weight.weight\n", - "(20, 20)\n", - "[[-0.14223064 0.19124371 -0.14481081 -0.21607104 -0.08928006 0.04458899\n", - " 0.0831126 0.08646142 -0.12953514 -0.08581803 -0.09943341 -0.10828371\n", - " -0.18833804 0.04577223 -0.06502874 -0.2152229 -0.13056786 -0.13428617\n", - " -0.09645564 -0.13816758]\n", - " [-0.03877772 0.08013236 -0.18096809 -0.01915519 
-0.06435173 -0.11432081\n", - " -0.0496515 -0.09477154 0.00718846 -0.16141057 0.04240454 0.20530063\n", - " 0.18528308 -0.10025615 0.06892193 -0.21135406 0.18826427 -0.22283866\n", - " -0.19982089 -0.20071597]\n", - " [-0.20765333 0.03028304 -0.05912894 0.05351972 -0.01383548 -0.00480333\n", - " -0.08078498 -0.13266474 -0.18721604 0.11282834 -0.11529152 -0.04547688\n", - " 0.10860465 -0.05537887 -0.05637903 -0.14906646 -0.19131811 0.10732386\n", - " -0.05044974 0.14060505]\n", - " [ 0.01471702 -0.00028402 -0.20187245 0.0049368 -0.0505344 -0.12759772\n", - " -0.05175107 0.01168989 -0.16848378 0.03718214 0.15558895 0.04417289\n", - " 0.21344449 0.10434435 -0.17634727 -0.08801483 -0.05380939 0.06689031\n", - " -0.00637761 0.17993565]\n", - " [ 0.02597556 -0.14161254 -0.08197778 -0.18603216 -0.061655 0.10993782\n", - " 0.00215927 -0.21323241 -0.19348647 0.08106777 -0.19626026 -0.1783532\n", - " -0.1333177 0.21312374 -0.06358164 -0.09219337 -0.15098219 0.14304285\n", - " -0.03610551 0.04311918]\n", - " [ 0.05341741 0.06306308 0.14312816 0.01160373 0.02312934 -0.01452105\n", - " -0.17375752 -0.05117204 0.21281871 -0.15847513 -0.14112028 -0.22188812\n", - " 0.013559 -0.20914444 -0.11453009 0.20604049 0.09261008 0.11913135\n", - " 0.03828845 -0.19001652]\n", - " [-0.10404866 -0.18102278 -0.13826925 0.076148 -0.06201827 0.2185227\n", - " -0.16299975 -0.19082828 0.2207899 -0.19316407 0.19027402 0.06021235\n", - " -0.20380671 0.1947569 -0.06087566 -0.09220145 -0.17443547 -0.1891369\n", - " 0.04978558 -0.21964009]\n", - " [ 0.09188584 -0.05525529 0.0784739 -0.05474811 0.07732737 -0.00610806\n", - " 0.06572182 -0.09097287 -0.15380703 0.02847747 -0.14272346 -0.13861606\n", - " -0.21501313 -0.07127416 -0.14941145 0.17413448 0.1611419 0.05305404\n", - " 0.18168166 0.10766982]\n", - " [-0.21064265 -0.022373 -0.03629636 -0.13576584 0.06368566 -0.06979065\n", - " -0.10692404 -0.00260666 -0.14866948 0.18506847 0.14149404 0.21166477\n", - " -0.03960523 0.07302888 -0.00899392 -0.18503006 0.10116354 -0.15618756\n", - " -0.08071785 -0.10013654]\n", - " [-0.21814388 0.00802042 0.03663212 -0.01662389 0.1644524 0.01072139\n", - " -0.0407296 -0.12196475 -0.13280123 -0.03179033 -0.1312358 -0.14750735\n", - " -0.02957479 -0.03948133 -0.13649467 0.13065115 0.18963577 -0.15246144\n", - " 0.09794185 -0.10375587]\n", - " [-0.02321799 0.20873794 0.02861272 -0.21320319 0.20555921 -0.00946067\n", - " -0.11196752 -0.11808899 0.19175017 0.00377388 0.12350584 0.14696068\n", - " -0.08678884 0.01897924 -0.14464125 0.18672368 -0.11824197 0.14852415\n", - " 0.05665502 0.1379358 ]\n", - " [-0.1575466 -0.00695391 0.11586404 -0.00892534 -0.0032084 0.10896464\n", - " -0.16712412 -0.04483069 0.10185106 0.10966767 0.20768207 -0.04423303\n", - " 0.05298113 -0.11002054 -0.03752897 -0.11225442 0.16570821 0.0013621\n", - " 0.09096613 0.12299404]\n", - " [ 0.04166875 0.02379598 -0.01636612 -0.1894117 0.03602695 -0.04953878\n", - " -0.18794785 0.20833082 -0.02383836 -0.11159918 -0.21768506 -0.20595226\n", - " 0.08515022 -0.1020775 -0.09659212 -0.12938367 0.18049696 -0.05375253\n", - " 0.14493793 0.17751718]\n", - " [-0.17336273 0.16682073 -0.04269946 0.21416363 0.11421449 -0.21660405\n", - " 0.04154139 0.07860353 -0.08111839 0.16956337 -0.1851744 -0.07095176\n", - " 0.2130592 0.21838497 0.11170101 -0.13348123 -0.19239157 -0.1818077\n", - " -0.05589887 0.12667239]\n", - " [ 0.07079396 -0.02715501 0.20110089 0.17559125 -0.10450983 -0.09683432\n", - " -0.00262346 0.04640241 -0.00160075 0.08632647 0.15427703 -0.04031902\n", - " 0.10981148 
0.03041176 0.08583194 0.09205452 -0.05976621 -0.09969731\n", - " 0.09557738 -0.14316456]\n", - " [ 0.1173941 -0.1434708 0.15340208 0.08971985 -0.05478028 0.12781222\n", - " -0.07363954 0.04763815 0.06583516 0.02283663 0.04490386 -0.00443905\n", - " -0.0645991 0.1247524 0.08819748 0.08340425 0.15096036 -0.11699554\n", - " -0.0519524 -0.00637345]\n", - " [ 0.18044722 -0.1780605 -0.12826072 -0.05326315 -0.19100511 -0.17666493\n", - " 0.15317535 0.01043098 -0.17988645 -0.03692174 -0.00735149 -0.07949581\n", - " -0.18703558 0.12169496 -0.02761802 0.21831468 -0.17125311 -0.12275734\n", - " -0.01161703 -0.15571442]\n", - " [ 0.16295849 0.17292082 0.2025731 -0.14115438 0.15909635 0.15525764\n", - " -0.08897205 0.02453648 0.10655329 0.16001071 -0.20884806 0.2226173\n", - " -0.05621968 0.09110746 -0.13887972 -0.17207511 -0.15143432 0.13178375\n", - " -0.11029776 0.12998497]\n", - " [ 0.0675995 0.08894558 -0.04973555 -0.07073203 -0.10462123 -0.12498911\n", - " 0.20617247 -0.01215215 -0.09589054 -0.20804486 0.0097276 -0.22196051\n", - " -0.00263305 0.14118703 -0.12879056 0.12285849 -0.07132839 -0.1719783\n", - " -0.22146888 0.11108326]\n", - " [-0.1710799 0.10918202 0.03201576 0.12152903 -0.16808327 0.19554281\n", - " -0.22271936 -0.16972543 0.13409424 0.00759949 -0.12556304 -0.04690479\n", - " -0.19899549 -0.194607 -0.04797396 0.17057896 0.06677905 0.04216573\n", - " -0.05926214 0.20352075]] , 5\n", - "-------------------------\n", - "0.layers.0.0.cell_gate_params.bias\n", - "(20,)\n", - "[ 0.00214154 0.07550146 0.00355405 0.03489293 0.07456551 0.17159154\n", - " 0.12870987 0.0286169 0.08939798 -0.06724557 0.15284362 0.06277069\n", - " 0.16875166 -0.03491265 -0.18256952 0.04417255 0.09094475 0.18067895\n", - " 0.08666804 0.08261736] , 6\n", - "-------------------------\n", - "0.layers.0.0.cell_gate_params.input_weight.weight\n", - "(20, 10)\n", - "[[ 0.17794745 -0.07684495 0.19742867 0.11464191 0.14933479 0.15947415\n", - " -0.18268393 0.11646748 0.20825341 -0.15708849]\n", - " [-0.01916463 -0.1364658 -0.05399449 0.03332363 0.11960924 -0.06491657\n", - " -0.21173826 0.12073942 0.12545025 -0.04053707]\n", - " [ 0.19142465 0.17237733 -0.04928424 0.00863487 0.03938841 -0.04381773\n", - " -0.05508858 -0.10093604 -0.12716216 0.11167222]\n", - " [-0.06639788 -0.10727276 0.19697405 0.03575112 0.16133724 0.2037714\n", - " -0.03149954 0.03335407 0.20731461 -0.15384933]\n", - " [-0.06704343 0.03181893 -0.01517017 0.05953267 0.11757869 -0.09199598\n", - " 0.01741112 0.20230028 -0.1265286 -0.15163381]\n", - " [-0.17148444 0.13366292 -0.20509928 -0.1087402 0.15102275 -0.13404797\n", - " 0.1818403 -0.10452814 0.03537463 0.02927051]\n", - " [-0.00548471 0.13927223 0.18991414 -0.13961166 0.12540615 0.0597448\n", - " -0.00416681 -0.15634763 0.06633033 0.1623022 ]\n", - " [-0.19193047 -0.20651296 -0.21982425 0.05166686 -0.06424998 -0.06945844\n", - " 0.20821334 -0.05703437 -0.14200093 0.02011372]\n", - " [-0.12272914 -0.06551553 0.11811562 0.05160707 -0.1534436 0.21288224\n", - " 0.15128401 -0.15242937 0.09739923 0.09188432]\n", - " [-0.16044928 -0.1571494 -0.18515183 0.09960561 0.03895786 0.09450045\n", - " -0.09821384 0.1681353 0.02855213 -0.17842196]\n", - " [-0.056282 0.11411482 0.04916727 -0.03420792 -0.15622441 -0.13909872\n", - " 0.19286813 -0.12808998 0.15845725 -0.07484471]\n", - " [ 0.00223508 -0.21774605 -0.07268656 0.18849593 -0.20075409 0.11251042\n", - " -0.188184 0.03261365 -0.20273004 -0.17701481]\n", - " [-0.18051723 -0.07753571 0.03044572 -0.16394225 0.05667006 0.13467607\n", - " 0.18228398 
0.19799176 0.14722027 -0.06584404]\n", - " [-0.02060739 0.19784163 0.11123517 -0.05929887 0.16882291 -0.19541554\n", - " 0.1913779 0.12510933 -0.16400692 -0.18237662]\n", - " [ 0.17486629 0.22059093 0.01951262 -0.08737109 0.12732458 0.1008788\n", - " -0.0279066 0.17902343 0.14493623 0.05574536]\n", - " [ 0.11610299 -0.20945168 -0.10473937 0.02451142 0.06080827 -0.03056943\n", - " 0.08443112 0.06811719 -0.20665829 0.07052966]\n", - " [-0.01818041 -0.15387398 0.00754629 -0.05499369 -0.11874414 -0.20375879\n", - " 0.18706112 -0.13579562 0.0300329 0.17913137]\n", - " [-0.02817055 -0.14655502 -0.21633011 0.03715306 -0.11219743 0.01630673\n", - " 0.07142475 -0.06335549 0.1516163 -0.02909804]\n", - " [-0.08923855 -0.14784832 0.06784268 -0.13824603 0.04700406 -0.02822138\n", - " 0.1536749 -0.10962173 -0.11015368 -0.02889775]\n", - " [-0.13657494 0.08524874 -0.08190698 0.09174035 0.12977527 0.13057181\n", - " -0.04105001 0.12203032 -0.11840606 -0.22279048]] , 7\n", - "-------------------------\n", - "0.layers.0.0.cell_gate_params.hidden_weight.weight\n", - "(20, 20)\n", - "[[-2.12806370e-02 -1.62129834e-01 -1.73234463e-01 5.68399914e-02\n", - " 1.91077381e-01 -8.79967287e-02 -1.26489419e-02 -1.62001878e-01\n", - " 3.90813835e-02 6.37496263e-02 -3.43248062e-02 1.70126632e-01\n", - " -1.79964885e-01 -3.00010163e-02 -1.24117516e-01 1.96340203e-01\n", - " 1.89398184e-01 2.19951704e-01 2.05728129e-01 8.85609612e-02]\n", - " [-1.71218976e-01 -1.51676044e-01 5.36037646e-02 -1.99636862e-01\n", - " 1.41561761e-01 9.72114205e-02 5.33513576e-02 -1.95168942e-01\n", - " 1.62662312e-01 -2.36655492e-02 -9.38338637e-02 1.16747312e-01\n", - " 1.88960433e-02 -9.94693190e-02 5.23358434e-02 -1.49113968e-01\n", - " 2.07823291e-01 1.95990741e-01 1.03123404e-01 1.18294187e-01]\n", - " [-2.22277910e-01 -1.24300212e-01 -2.15169474e-01 -1.16545178e-01\n", - " -1.85386583e-01 1.64590582e-01 1.20638609e-01 1.31684974e-01\n", - " -9.92668644e-02 1.70430213e-01 -3.23111340e-02 -5.79339787e-02\n", - " 1.20397158e-01 1.48079455e-01 -1.60713032e-01 2.12880254e-01\n", - " -2.25685220e-02 5.95554635e-02 -2.22653463e-01 2.48931386e-02]\n", - " [-1.10666625e-01 -1.40009314e-01 -9.33616757e-02 -1.04158348e-03\n", - " -6.37013763e-02 -1.43241197e-01 1.60099015e-01 6.65228367e-02\n", - " -2.08098441e-01 4.69054580e-02 5.49288094e-02 8.21655430e-03\n", - " 5.42974621e-02 -1.87213402e-02 9.77927893e-02 -1.57414630e-01\n", - " -9.53418463e-02 1.67505234e-01 -1.38533488e-01 1.09708525e-01]\n", - " [ 2.06897184e-01 -2.04468444e-01 -9.79631692e-02 1.90820277e-01\n", - " -1.35208331e-02 4.41430137e-02 3.18236202e-02 9.21481624e-02\n", - " -9.21330750e-02 2.90291384e-02 1.52316689e-01 -1.88640561e-02\n", - " -2.05149427e-01 7.72908777e-02 -5.70836812e-02 -4.71739881e-02\n", - " 1.16618834e-01 3.91878746e-02 -1.35271400e-01 -1.03187911e-01]\n", - " [-3.39903794e-02 -5.52454554e-02 -4.73374985e-02 -1.52837262e-01\n", - " 1.61986634e-01 1.15967356e-01 4.41279002e-02 5.06293550e-02\n", - " 2.61772387e-02 1.67198420e-01 5.05979806e-02 3.40624861e-02\n", - " -1.22919112e-01 7.45933205e-02 -2.09194586e-01 7.05230013e-02\n", - " -1.93819985e-01 -9.25445408e-02 1.18050657e-01 -1.33182898e-01]\n", - " [ 1.78052112e-01 -1.23547316e-01 2.11798310e-01 6.89183101e-02\n", - " -9.69009325e-02 1.36373073e-01 -1.98024541e-01 -1.41652852e-01\n", - " -1.40091866e-01 2.94355899e-02 2.19678022e-02 -1.92325816e-01\n", - " 2.15771765e-01 -2.13701205e-04 -1.19405292e-01 5.34111727e-03\n", - " -9.59839672e-02 6.16913289e-02 8.09477344e-02 -6.34285584e-02]\n", - " [ 
1.30358534e-02 1.33047834e-01 -1.45440847e-01 -4.98616323e-02\n", - " -3.29875015e-02 -1.47941127e-01 1.82121564e-02 8.21812730e-03\n", - " -1.80613607e-01 4.58700024e-02 2.13425189e-01 1.18935056e-01\n", - " -1.21292830e-01 2.04682201e-01 -1.53705969e-01 -1.13691926e-01\n", - " 9.86314118e-02 1.77888468e-01 2.13384852e-01 1.92508563e-01]\n", - " [-1.23128124e-01 5.11671938e-02 -1.40405849e-01 4.93797194e-03\n", - " 1.85259327e-01 1.10102132e-01 -2.06472665e-01 -9.62342396e-02\n", - " -1.88666239e-01 1.05334759e-01 -2.83857696e-02 -1.63461700e-01\n", - " -7.14522004e-02 7.33797774e-02 2.07014289e-02 2.09811881e-01\n", - " -2.96870619e-03 7.03370497e-02 -6.77365363e-02 2.66825557e-02]\n", - " [ 8.01036973e-03 1.92074046e-01 9.36935991e-02 -1.27431735e-01\n", - " -1.98687479e-01 -2.12748200e-01 -8.12046453e-02 2.89045740e-02\n", - " 2.10361689e-01 -2.19703875e-02 8.74281824e-02 1.13642633e-01\n", - " -1.71282887e-01 -1.84971020e-01 8.47281963e-02 1.04225203e-01\n", - " -1.04119189e-01 3.50410007e-02 -2.18935862e-01 2.81849946e-03]\n", - " [ 5.48111200e-02 2.11656699e-03 -3.54930870e-02 9.30717662e-02\n", - " -6.14620335e-02 1.66451484e-01 -1.92599118e-01 -1.27790585e-01\n", - " -1.86674312e-01 -2.02230543e-01 1.65771663e-01 -5.53366169e-02\n", - " -1.75649151e-01 4.63781990e-02 -1.69327542e-01 1.15589779e-02\n", - " 1.06298663e-01 -4.72831465e-02 1.14950888e-01 4.58941013e-02]\n", - " [-1.79431096e-01 4.40098420e-02 1.44146204e-01 -5.18364720e-02\n", - " 2.11329088e-02 2.85264328e-02 1.92284174e-02 5.81263304e-02\n", - " -2.14094386e-01 1.69653893e-01 9.75249708e-02 2.76133306e-02\n", - " 4.06875163e-02 -1.80331707e-01 -6.38444126e-02 -9.72616393e-03\n", - " 5.31534106e-02 -1.22661509e-01 2.37256587e-02 -6.93958476e-02]\n", - " [ 1.62758812e-01 -1.91935405e-01 2.33742520e-02 1.51492402e-01\n", - " -1.73671409e-01 -6.40887721e-03 1.03327051e-01 9.02309865e-02\n", - " 2.62962040e-02 9.03898776e-02 -1.55875593e-01 1.86238810e-01\n", - " 4.98715229e-03 1.44541100e-01 4.94662710e-02 -2.48756800e-02\n", - " 9.57791656e-02 2.12270051e-01 2.20569506e-01 -1.88220173e-01]\n", - " [ 1.35616167e-02 -1.60633817e-01 1.30284145e-01 1.60526067e-01\n", - " -1.57016143e-01 -1.29234986e-02 1.54731110e-01 1.47872686e-01\n", - " -1.68123141e-01 1.50136366e-01 -3.95872369e-02 -1.90171361e-01\n", - " 4.45422679e-02 1.04169942e-01 1.34101674e-01 -1.52035385e-01\n", - " -1.61954522e-01 -1.50239438e-01 1.26720712e-01 -1.95428118e-01]\n", - " [-1.88556593e-03 -6.57092705e-02 9.76277590e-02 4.39127870e-02\n", - " -1.12915963e-01 3.90566476e-02 2.05778107e-01 3.68154384e-02\n", - " -1.10807024e-01 7.48633966e-03 -2.05102757e-01 -1.43465236e-01\n", - " -4.15345095e-02 -1.39340952e-01 1.89353585e-01 4.34043780e-02\n", - " 1.73192978e-01 -5.09172641e-02 -3.10981516e-02 5.64037636e-02]\n", - " [-6.64871484e-02 -7.62214959e-02 -2.19352797e-01 1.68453470e-01\n", - " 2.02370644e-01 -2.21398085e-01 -7.39822015e-02 -1.69133484e-01\n", - " -9.07677040e-02 1.70234248e-01 1.19611956e-01 -1.73501018e-02\n", - " 9.55028459e-02 6.67780936e-02 1.22115597e-01 -1.79690495e-01\n", - " 6.91184700e-02 -2.11776465e-01 -1.47058472e-01 -8.33279863e-02]\n", - " [-2.17858739e-02 -2.11018786e-01 5.56494808e-03 3.57002839e-02\n", - " -8.87419507e-02 7.25275800e-02 1.95392817e-01 -3.81953120e-02\n", - " -1.19088188e-01 -1.98077247e-01 -1.63278311e-01 -1.23674117e-01\n", - " -1.65306747e-01 -8.79110843e-02 1.23181596e-01 6.99715093e-02\n", - " 2.01542184e-01 2.22007304e-01 -8.05223361e-02 -8.75686854e-02]\n", - " [ 3.05994693e-02 -1.78054109e-01 
1.21623978e-01 -4.02442813e-02\n", - " -1.87232435e-01 -1.68819025e-01 -1.54080361e-01 6.14588112e-02\n", - " 1.71410367e-01 1.77153081e-01 -6.15712442e-02 -1.29883334e-01\n", - " -9.92444977e-02 -1.52750149e-01 -5.76506779e-02 -2.01948732e-01\n", - " 1.19517274e-01 -2.10457653e-01 -1.39095634e-01 1.50062576e-01]\n", - " [-1.67259946e-01 5.34564890e-02 1.67486787e-01 2.20412284e-01\n", - " 1.13142729e-01 -6.00084551e-02 1.27776846e-01 -7.37963570e-03\n", - " -6.89469650e-02 7.28242099e-04 5.01570366e-02 1.49932787e-01\n", - " 9.38621163e-02 1.06770106e-01 3.34510244e-02 -1.12544857e-02\n", - " 9.38917845e-02 5.37824407e-02 -2.13967159e-01 3.61516774e-02]\n", - " [-9.93019715e-02 -1.18578210e-01 8.64755288e-02 4.57250476e-02\n", - " 3.78663242e-02 -1.06075369e-01 1.03322893e-01 2.09839717e-01\n", - " 2.73554083e-02 9.19082835e-02 -1.96176514e-01 1.32933155e-01\n", - " 7.76783228e-02 1.00741126e-01 9.32467878e-02 -5.88140823e-02\n", - " -1.34220198e-02 2.16287613e-01 1.63621128e-01 -1.60278752e-01]] , 8\n", - "-------------------------\n", - "0.layers.0.0.output_gate_params.bias\n", - "(20,)\n", - "[ 0.17741492 0.22254053 0.02940683 -0.17445402 0.04334408 -0.04515981\n", - " 0.16077036 -0.21483785 0.05722176 -0.00262266 0.01760296 0.15381731\n", - " 0.0040394 -0.18002152 -0.13043821 -0.08953302 0.02384774 0.08628984\n", - " -0.04173774 -0.08825271] , 9\n", - "-------------------------\n", - "0.layers.0.0.output_gate_params.input_weight.weight\n", - "(20, 10)\n", - "[[ 9.81200710e-02 -2.17414662e-01 1.56252235e-01 -2.59936582e-02\n", - " 1.55592158e-01 1.68960407e-01 2.38872208e-02 7.07329437e-02\n", - " -1.26473457e-01 1.60210714e-01]\n", - " [ 1.30875960e-01 -3.51194218e-02 8.71568248e-02 -1.25249382e-02\n", - " 1.74701765e-01 9.20466036e-02 1.63019851e-01 -2.03253865e-01\n", - " 2.17866078e-01 8.33117217e-02]\n", - " [ 1.08713590e-01 4.98261265e-02 1.46862045e-01 2.10508242e-01\n", - " -1.90491565e-02 -1.83473915e-01 2.05329910e-01 -4.71567698e-02\n", - " -1.07840233e-01 1.37649149e-01]\n", - " [ 1.24790154e-01 2.99369618e-02 -1.40363071e-02 -4.27761748e-02\n", - " 2.05027208e-01 1.36240214e-01 1.33165866e-01 1.42589167e-01\n", - " -1.17026694e-01 4.66880240e-02]\n", - " [-1.93439931e-01 1.29910931e-01 -2.21640781e-01 -2.23473564e-01\n", - " -2.21031293e-01 1.37891039e-01 2.32707467e-02 5.08490019e-04\n", - " 3.55657227e-02 -8.46242681e-02]\n", - " [-6.79011941e-02 -1.50619775e-01 -5.46085611e-02 -1.37593433e-01\n", - " 5.88322058e-03 1.75689265e-01 -1.84854001e-01 1.09963417e-01\n", - " -1.66318297e-01 -9.26456451e-02]\n", - " [ 4.37250473e-02 3.84753868e-02 1.83374569e-01 -8.36465479e-05\n", - " -8.51647705e-02 -9.24766734e-02 6.55569835e-03 -1.67666823e-01\n", - " -1.75320774e-01 -9.56731290e-02]\n", - " [ 5.74407633e-03 -1.51010871e-01 -1.27642184e-01 1.59654185e-01\n", - " 2.06639260e-01 -7.00415373e-02 -1.91840678e-01 -8.56086463e-02\n", - " 9.02482048e-02 7.25704432e-02]\n", - " [-6.93180412e-02 -1.96934849e-01 -6.72358871e-02 -4.99973148e-02\n", - " 1.28766835e-01 -1.10879898e-01 1.34200945e-01 3.10183968e-02\n", - " -3.74761075e-02 -1.99273914e-01]\n", - " [ 2.20759660e-01 -3.98728549e-02 1.40693069e-01 -1.15664735e-01\n", - " -2.17755169e-01 -1.78237423e-01 -1.14595190e-01 -7.12116584e-02\n", - " -3.15762796e-02 1.86491266e-01]\n", - " [-2.06223264e-01 1.11605875e-01 1.88149154e-01 1.43918453e-03\n", - " -1.39450610e-01 7.15188682e-03 5.30482270e-02 9.89372358e-02\n", - " -6.79695681e-02 -7.67354444e-02]\n", - " [-1.05491146e-01 -2.16275647e-01 7.85326734e-02 -1.69050053e-01\n", 
- " -1.07421041e-01 -2.30107992e-03 1.72379389e-01 1.98816836e-01\n", - " -1.62642673e-01 1.93931282e-01]\n", - " [ 2.00302720e-01 1.80637628e-01 1.94676816e-02 1.79588884e-01\n", - " 1.08642928e-01 -1.60451204e-01 -1.17858045e-01 4.20530513e-03\n", - " -1.58465564e-01 -7.36296773e-02]\n", - " [ 1.80281103e-01 1.04106739e-01 1.94734529e-01 1.71422120e-03\n", - " -1.14017285e-01 1.47993699e-01 1.64847951e-02 3.76562215e-02\n", - " -9.47417393e-02 9.18511599e-02]\n", - " [-1.65143967e-01 1.78432971e-01 1.95620790e-01 8.06822702e-02\n", - " 1.74128443e-01 1.35722205e-01 -8.53993148e-02 -1.93941638e-01\n", - " 2.94244476e-02 1.40397370e-01]\n", - " [-2.28753053e-02 1.88145563e-02 1.65735826e-01 9.23255607e-02\n", - " 1.67166159e-01 3.28338295e-02 2.50651501e-02 -5.34861833e-02\n", - " -3.77333388e-02 -1.18839331e-01]\n", - " [ 1.49498299e-01 2.03940362e-01 8.29838291e-02 6.35351241e-03\n", - " -7.38137364e-02 -2.20774114e-01 -4.14042696e-02 -1.58739850e-01\n", - " -1.65080443e-01 -4.42778133e-02]\n", - " [-4.39881422e-02 4.51072417e-02 -1.62074581e-01 1.60696968e-01\n", - " -2.03583151e-01 -1.05898663e-01 -8.48927200e-02 1.37860607e-02\n", - " 9.24347416e-02 -5.89275286e-02]\n", - " [ 3.48980725e-02 -5.29355779e-02 -8.79468024e-02 -3.12774107e-02\n", - " 4.50214110e-02 -2.17200696e-01 -1.55640006e-01 1.74693078e-01\n", - " 1.01111621e-01 -5.97870257e-03]\n", - " [ 7.06157601e-03 3.08655780e-02 5.19711897e-02 -1.52664930e-01\n", - " -6.09524250e-02 -2.05220923e-01 -1.75796479e-01 -4.20728028e-02\n", - " -2.95243543e-02 1.74893185e-01]] , 10\n", - "-------------------------\n", - "0.layers.0.0.output_gate_params.hidden_weight.weight\n", - "(20, 20)\n", - "[[ 0.03851524 -0.03625689 -0.00619491 0.12488268 -0.06773603 -0.0418019\n", - " -0.04485707 -0.18031046 -0.03125188 -0.20671144 -0.12019279 -0.14232881\n", - " 0.16657048 -0.20598304 0.21545227 0.08384079 -0.15111198 0.18525589\n", - " -0.0492739 -0.18939163]\n", - " [-0.03105276 0.11050874 -0.21741039 -0.01675669 0.09098183 -0.08714523\n", - " 0.02036562 -0.0876366 -0.15001732 0.17511557 -0.1587715 -0.00262151\n", - " 0.07447443 -0.12496222 0.10796666 -0.18569624 0.21355589 0.09958527\n", - " -0.03165689 -0.18600492]\n", - " [ 0.00689578 0.0793154 -0.12144296 -0.02816021 -0.22284126 -0.22354037\n", - " -0.02428471 0.187102 -0.01052416 0.07010341 -0.08937916 -0.07301357\n", - " -0.02457852 -0.11304034 0.13682817 0.13944101 -0.17383203 0.06858449\n", - " -0.09237309 -0.12858376]\n", - " [-0.02727968 -0.0693544 -0.12731954 0.03295429 0.12762886 -0.03450404\n", - " -0.01564156 0.01682661 -0.09610138 0.11838 0.2063172 -0.02043679\n", - " 0.01520035 0.18016809 0.18314716 -0.16634111 -0.10355289 -0.21934243\n", - " 0.13695723 0.17452586]\n", - " [-0.08138426 0.07172713 0.05416519 -0.19238184 0.0892937 0.10971964\n", - " 0.00491766 0.02293088 0.05196048 0.16108814 0.19757238 0.03213832\n", - " 0.09531388 -0.05850127 0.13331535 -0.08795608 -0.18431664 0.1049106\n", - " 0.08293276 0.0492176 ]\n", - " [ 0.09513766 0.02660845 0.0761021 0.09111597 -0.12062387 -0.01198089\n", - " 0.03369791 -0.03394864 -0.188005 0.02121117 0.13665509 -0.11958458\n", - " 0.21953909 0.0509951 0.09510146 -0.08634473 -0.18291326 -0.08321758\n", - " 0.00683159 -0.10189173]\n", - " [ 0.19913672 -0.14311586 -0.15060481 -0.0793146 0.20060927 -0.10224532\n", - " 0.20686573 0.10745841 -0.03397548 0.11565119 0.10630453 -0.11381406\n", - " -0.04603498 0.21659105 0.12819836 -0.10921414 -0.0601254 0.12532982\n", - " 0.11351746 0.01772486]\n", - " [-0.14387828 -0.16492477 -0.04719649 
0.08221286 -0.02383876 -0.18695372\n", - " -0.05480145 0.22319667 -0.18481532 -0.17354017 0.14056584 0.22249034\n", - " -0.21510145 -0.20223859 -0.06991865 0.22294378 -0.1269095 0.01911828\n", - " 0.18253623 -0.0791588 ]\n", - " [-0.06857247 -0.15009233 0.0085855 0.20870976 0.0914357 0.157171\n", - " -0.01481424 -0.03551737 -0.03994827 0.12753342 -0.02932107 -0.19100396\n", - " -0.07851914 0.08750965 0.21801063 -0.04065894 -0.19468635 -0.16464569\n", - " -0.1759353 0.09013668]\n", - " [ 0.16482699 0.06612828 0.07709847 0.14567545 0.15288451 0.13352284\n", - " 0.12504087 0.06050573 0.11541758 -0.1534312 -0.14473058 0.06013739\n", - " 0.03479816 -0.19657765 -0.16289718 -0.17800786 0.17759389 0.14619377\n", - " -0.11769552 0.033738 ]\n", - " [-0.05143119 0.19438726 -0.20252845 -0.16313015 -0.18616724 0.13013433\n", - " -0.11177826 0.13318242 0.07558636 -0.10929734 -0.06023749 -0.09048979\n", - " 0.09864956 -0.08967353 0.07588523 0.01597441 -0.17857382 -0.1405619\n", - " -0.1550431 0.1171688 ]\n", - " [ 0.0484514 -0.00562237 -0.1331447 -0.22155127 -0.07913139 -0.17113578\n", - " -0.22241357 -0.21326728 -0.14605871 -0.21737726 0.069704 0.08366753\n", - " 0.0901287 -0.22259942 0.13826938 0.04359518 0.11433873 -0.05495736\n", - " 0.10737925 -0.21207204]\n", - " [ 0.0761621 0.17731208 0.09399657 -0.21077465 -0.06277167 -0.02776839\n", - " 0.11715963 -0.08461329 0.03216063 -0.07849736 -0.03552182 -0.00445118\n", - " -0.1283987 -0.15520401 0.1845957 0.18787426 -0.00676964 0.19354711\n", - " 0.17230819 -0.14084579]\n", - " [-0.08885217 -0.15358365 0.07229424 0.00565505 -0.03066478 0.16602065\n", - " -0.08740129 -0.12237797 -0.15895672 -0.11375529 0.21551864 -0.10871551\n", - " -0.06152614 0.10078279 -0.17173737 -0.13572007 0.16457646 -0.08576282\n", - " -0.1160312 -0.02892987]\n", - " [-0.03186222 0.04086494 0.08197901 -0.17241116 0.2032053 -0.21259488\n", - " 0.07573222 -0.06309208 -0.09442816 0.20916638 -0.2154794 0.01527144\n", - " 0.1432838 0.19990316 -0.18904059 0.02694101 0.22123207 -0.21902935\n", - " 0.0546164 -0.14010552]\n", - " [ 0.03629959 -0.20227122 0.11001531 -0.04960475 0.13363701 -0.0033625\n", - " -0.03187283 -0.05428797 -0.2047436 -0.09497944 0.00742607 -0.1729926\n", - " 0.19623755 -0.14542621 -0.08711543 -0.02990268 -0.1811355 -0.00176668\n", - " -0.10767633 -0.1871676 ]\n", - " [ 0.00548474 0.19795649 0.05506302 0.18442854 -0.0021867 -0.07804751\n", - " 0.1802177 -0.11907462 -0.20685978 0.0489392 0.11143997 -0.13366425\n", - " 0.07870162 -0.07933193 -0.02713096 -0.04951058 -0.04782786 -0.18194063\n", - " 0.05480235 -0.05881837]\n", - " [ 0.17097771 0.03732251 -0.18287036 -0.17010981 -0.11653572 0.10708019\n", - " -0.14437075 -0.10229405 0.04059571 -0.15502611 -0.11010965 0.20276332\n", - " -0.11821949 -0.07449946 0.1599237 0.05010674 0.17550889 -0.19699533\n", - " 0.11176885 -0.03420243]\n", - " [-0.14325288 -0.09576999 -0.21628909 0.15468563 -0.04290593 -0.2192564\n", - " 0.19123225 0.14483131 0.09245753 0.21885075 0.20192903 0.20897363\n", - " 0.2002456 0.18172018 0.05853782 -0.01872608 0.00850361 -0.09292599\n", - " 0.10506337 0.00647802]\n", - " [ 0.05275466 -0.14403579 -0.08419433 0.16763861 0.02174832 0.07716487\n", - " -0.1952104 -0.09575427 -0.00569092 -0.0234643 0.14273825 -0.06748112\n", - " 0.18662164 -0.04324729 0.08697162 -0.15742545 0.03795354 -0.21800253\n", - " -0.19185208 -0.14310952]] , 11\n", - "-------------------------\n", - "/0/layers.0.0/output_quant/export_handler/Constant_output_0\n", - "()\n", - "0.0078125 , 12\n", - "-------------------------\n", - 
"/0/layers.0.0/output_quant/export_handler/Constant_1_output_0\n", - "()\n", - "0 , 13\n", - "-------------------------\n", - "/0/layers.0.0/output_quant/export_handler/Constant_2_output_0\n", - "()\n", - "8.0 , 14\n", - "-------------------------\n", - "/0/layers.0.0/input_weight/weight_quant/export_handler/Constant_output_0\n", - "()\n", - "0.001760039 , 15\n", - "-------------------------\n", - "/0/layers.0.0/input_weight/weight_quant/export_handler/Constant_1_output_0\n", - "()\n", - "-127 , 16\n", - "-------------------------\n", - "/0/layers.0.0/input_weight/weight_quant/export_handler/Constant_2_output_0\n", - "()\n", - "127 , 17\n", - "-------------------------\n", - "/0/layers.0.0/input_weight/weight_quant/export_handler_1/Constant_output_0\n", - "()\n", - "0.0017542557 , 18\n", - "-------------------------\n", - "/0/layers.0.0/input_weight/weight_quant/export_handler_2/Constant_output_0\n", - "()\n", - "0.0017601603 , 19\n", - "-------------------------\n", - "/0/layers.0.0/input_weight/weight_quant/export_handler_3/Constant_output_0\n", - "()\n", - "0.0017546351 , 20\n", - "-------------------------\n", - "onnx.brevitas::QuantLSTMCell_48\n", - "(1, 20)\n", - "[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]] , 21\n", - "-------------------------\n", - "/0/layers.0.0/export_handler/Constant_output_0\n", - "()\n", - "0.003921569 , 22\n", - "-------------------------\n", - "/0/layers.0.0/export_handler/Constant_1_output_0\n", - "()\n", - "0 , 23\n", - "-------------------------\n", - "/0/layers.0.0/Constant_output_0\n", - "(1,)\n", - "[0] , 24\n", - "-------------------------\n", - "/0/layers.0.0/Constant_1_output_0\n", - "(1,)\n", - "[0] , 25\n", - "-------------------------\n" - ] - } - ], + "outputs": [], "source": [ "# In this block of code we store all the parameters (weight matrices, recurrence matrices, biases, scales and zero-points) that we will need to import in the QCDQ implementation.\n", "# Importing the exported quantized model from brevitas\n", @@ -1190,7 +226,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "02fe4d94-af24-4d5e-a809-7d8c49e7fd90", "metadata": {}, "outputs": [], @@ -1239,7 +275,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "id": "02761646-4c6d-440f-8e90-4935beebab56", "metadata": {}, "outputs": [], @@ -1257,7 +293,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "id": "c08e5a23-ef2e-4bca-9293-c800350c2c62", "metadata": {}, "outputs": [], @@ -1384,7 +420,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "id": "79839558-8752-4fc8-9b0e-8fed47c91701", "metadata": {}, "outputs": [], @@ -1604,28 +640,10 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "id": "c6ec7b2a-456d-4452-97ec-df9a471d5391", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Serving './lstm_full_graph.onnx' at http://localhost:8080\n" - ] - }, - { - "data": { - "text/plain": [ - "('localhost', 8080)" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "lstm_model = qonnx_make_model(lstm_body, producer_name=\"QuantizeLSTM_scan\")\n", "onnx.save(lstm_model, './lstm_full_graph.onnx')\n", @@ -1642,83 +660,10 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "id": "db5892bc-ac8d-4972-afcf-20bf880f5e86", "metadata": {}, - "outputs": [ - { - "name": "stdout", - 
"output_type": "stream", - "text": [ - "[array([[ 0.1484375],\n", - " [-0.0078125],\n", - " [ 0.0390625],\n", - " [ 0.140625 ],\n", - " [ 0.015625 ],\n", - " [ 0. ],\n", - " [ 0.1015625],\n", - " [-0.1015625],\n", - " [ 0.0390625],\n", - " [-0.0625 ],\n", - " [ 0.015625 ],\n", - " [-0.125 ],\n", - " [ 0.1015625],\n", - " [ 0.03125 ],\n", - " [ 0.1640625],\n", - " [-0.015625 ],\n", - " [-0.0234375],\n", - " [-0.015625 ],\n", - " [-0.046875 ],\n", - " [ 0.0078125]], dtype=float32), array([[ 0.2421875],\n", - " [-0.0078125],\n", - " [ 0.0625 ],\n", - " [ 0.2421875],\n", - " [ 0.03125 ],\n", - " [ 0.0078125],\n", - " [ 0.2265625],\n", - " [-0.234375 ],\n", - " [ 0.0859375],\n", - " [-0.1328125],\n", - " [ 0.0390625],\n", - " [-0.2421875],\n", - " [ 0.1875 ],\n", - " [ 0.0546875],\n", - " [ 0.296875 ],\n", - " [-0.03125 ],\n", - " [-0.0546875],\n", - " [-0.03125 ],\n", - " [-0.109375 ],\n", - " [ 0.0234375]], dtype=float32), array([[ 0.1484375],\n", - " [-0.0078125],\n", - " [ 0.0390625],\n", - " [ 0.140625 ],\n", - " [ 0.015625 ],\n", - " [ 0. ],\n", - " [ 0.1015625],\n", - " [-0.1015625],\n", - " [ 0.0390625],\n", - " [-0.0625 ],\n", - " [ 0.015625 ],\n", - " [-0.125 ],\n", - " [ 0.1015625],\n", - " [ 0.03125 ],\n", - " [ 0.1640625],\n", - " [-0.015625 ],\n", - " [-0.0234375],\n", - " [-0.015625 ],\n", - " [-0.046875 ],\n", - " [ 0.0078125]], dtype=float32)]\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2023-10-20 11:07:46.350885612 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'scale_test'. It is not used by any node and should be removed from the model.\n", - "2023-10-20 11:07:46.370978980 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'scale_test'. It is not used by any node and should be removed from the model.\n" - ] - } - ], + "outputs": [], "source": [ "# Before the model can be executed, it'd opset version needs to be set to a minimum of '14' to accomodate clip nodes with INT8 and UINT8 input. 
Otherwise ONNX cannot create an execution session and we get errors.\n", "lstm_model.opset_import[0].version = 14\n", @@ -1787,7 +732,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "id": "700a93a8-f757-4fa1-88dd-47a3f2a7f171", "metadata": {}, "outputs": [], @@ -1814,7 +759,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "id": "111fdce4-464f-40c1-ac4d-3022b05f153e", "metadata": {}, "outputs": [], @@ -1838,19 +783,10 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "id": "4668cf2b-524e-4768-8dc8-9d619f6273da", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Serving './lstm_scan_node_model.onnx' at http://localhost:8081\n", - "[]\n" - ] - } - ], + "outputs": [], "source": [ "scan_lstm_node_graph = make_graph(\n", " nodes = [scan_node_lstm],\n", @@ -1882,608 +818,10 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "id": "818d2a81-686f-4a4a-8e78-17dbf75d8451", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Final Hidden State [[ 0.25 ]\n", - " [-0.046875 ]\n", - " [ 0.015625 ]\n", - " [ 0.2734375]\n", - " [ 0.0546875]\n", - " [-0.0390625]\n", - " [ 0.25 ]\n", - " [-0.1953125]\n", - " [ 0.0546875]\n", - " [-0.140625 ]\n", - " [ 0.015625 ]\n", - " [-0.203125 ]\n", - " [ 0.203125 ]\n", - " [ 0.140625 ]\n", - " [ 0.2734375]\n", - " [ 0.03125 ]\n", - " [-0.03125 ]\n", - " [-0.046875 ]\n", - " [-0.0703125]\n", - " [ 0.0078125]]\n", - "------------------------\n", - "Final Cell State [[ 0.421875 ]\n", - " [-0.078125 ]\n", - " [ 0.0234375]\n", - " [ 0.4921875]\n", - " [ 0.1484375]\n", - " [-0.09375 ]\n", - " [ 0.75 ]\n", - " [-0.59375 ]\n", - " [ 0.1171875]\n", - " [-0.3125 ]\n", - " [ 0.0390625]\n", - " [-0.421875 ]\n", - " [ 0.3984375]\n", - " [ 0.2578125]\n", - " [ 0.828125 ]\n", - " [ 0.0625 ]\n", - " [-0.0703125]\n", - " [-0.109375 ]\n", - " [-0.1484375]\n", - " [ 0.0234375]]\n", - "------------------------\n", - "All Hidden States [[[ 0.1484375]\n", - " [-0.0078125]\n", - " [ 0.0390625]\n", - " [ 0.140625 ]\n", - " [ 0.015625 ]\n", - " [ 0. ]\n", - " [ 0.1015625]\n", - " [-0.1015625]\n", - " [ 0.0390625]\n", - " [-0.0625 ]\n", - " [ 0.015625 ]\n", - " [-0.125 ]\n", - " [ 0.1015625]\n", - " [ 0.03125 ]\n", - " [ 0.1640625]\n", - " [-0.015625 ]\n", - " [-0.0234375]\n", - " [-0.015625 ]\n", - " [-0.046875 ]\n", - " [ 0.0078125]]\n", - "\n", - " [[ 0.203125 ]\n", - " [-0.0234375]\n", - " [ 0.03125 ]\n", - " [ 0.2109375]\n", - " [ 0.0234375]\n", - " [-0.015625 ]\n", - " [ 0.1875 ]\n", - " [-0.1484375]\n", - " [ 0.046875 ]\n", - " [-0.09375 ]\n", - " [ 0.0234375]\n", - " [-0.1640625]\n", - " [ 0.1484375]\n", - " [ 0.0703125]\n", - " [ 0.2578125]\n", - " [-0.015625 ]\n", - " [-0.03125 ]\n", - " [-0.0234375]\n", - " [-0.0703125]\n", - " [ 0.0078125]]\n", - "\n", - " [[ 0.2265625]\n", - " [-0.03125 ]\n", - " [ 0.015625 ]\n", - " [ 0.2421875]\n", - " [ 0.03125 ]\n", - " [-0.0234375]\n", - " [ 0.234375 ]\n", - " [-0.1796875]\n", - " [ 0.0546875]\n", - " [-0.109375 ]\n", - " [ 0.0234375]\n", - " [-0.1875 ]\n", - " [ 0.1796875]\n", - " [ 0.09375 ]\n", - " [ 0.2734375]\n", - " [ 0. 
]\n", - " [-0.03125 ]\n", - " [-0.03125 ]\n", - " [-0.0703125]\n", - " [ 0.015625 ]]\n", - "\n", - " [[ 0.234375 ]\n", - " [-0.0390625]\n", - " [ 0.015625 ]\n", - " [ 0.2578125]\n", - " [ 0.0390625]\n", - " [-0.03125 ]\n", - " [ 0.25 ]\n", - " [-0.1875 ]\n", - " [ 0.0546875]\n", - " [-0.125 ]\n", - " [ 0.015625 ]\n", - " [-0.1953125]\n", - " [ 0.1953125]\n", - " [ 0.1171875]\n", - " [ 0.2734375]\n", - " [ 0.015625 ]\n", - " [-0.03125 ]\n", - " [-0.0390625]\n", - " [-0.078125 ]\n", - " [ 0.0078125]]\n", - "\n", - " [[ 0.2421875]\n", - " [-0.046875 ]\n", - " [ 0.015625 ]\n", - " [ 0.2734375]\n", - " [ 0.0390625]\n", - " [-0.03125 ]\n", - " [ 0.25 ]\n", - " [-0.1953125]\n", - " [ 0.0546875]\n", - " [-0.1328125]\n", - " [ 0.015625 ]\n", - " [-0.1953125]\n", - " [ 0.203125 ]\n", - " [ 0.1328125]\n", - " [ 0.2734375]\n", - " [ 0.0234375]\n", - " [-0.03125 ]\n", - " [-0.046875 ]\n", - " [-0.078125 ]\n", - " [ 0.0078125]]\n", - "\n", - " [[ 0.2421875]\n", - " [-0.046875 ]\n", - " [ 0.015625 ]\n", - " [ 0.2734375]\n", - " [ 0.046875 ]\n", - " [-0.0390625]\n", - " [ 0.25 ]\n", - " [-0.1953125]\n", - " [ 0.0546875]\n", - " [-0.140625 ]\n", - " [ 0.015625 ]\n", - " [-0.203125 ]\n", - " [ 0.203125 ]\n", - " [ 0.1328125]\n", - " [ 0.2734375]\n", - " [ 0.03125 ]\n", - " [-0.03125 ]\n", - " [-0.046875 ]\n", - " [-0.0703125]\n", - " [ 0.0078125]]\n", - "\n", - " [[ 0.2421875]\n", - " [-0.046875 ]\n", - " [ 0.015625 ]\n", - " [ 0.2734375]\n", - " [ 0.046875 ]\n", - " [-0.0390625]\n", - " [ 0.25 ]\n", - " [-0.1953125]\n", - " [ 0.0546875]\n", - " [-0.140625 ]\n", - " [ 0.015625 ]\n", - " [-0.203125 ]\n", - " [ 0.203125 ]\n", - " [ 0.140625 ]\n", - " [ 0.2734375]\n", - " [ 0.03125 ]\n", - " [-0.03125 ]\n", - " [-0.046875 ]\n", - " [-0.0703125]\n", - " [ 0.0078125]]\n", - "\n", - " [[ 0.2421875]\n", - " [-0.046875 ]\n", - " [ 0.015625 ]\n", - " [ 0.2734375]\n", - " [ 0.0546875]\n", - " [-0.0390625]\n", - " [ 0.25 ]\n", - " [-0.1953125]\n", - " [ 0.0546875]\n", - " [-0.140625 ]\n", - " [ 0.015625 ]\n", - " [-0.203125 ]\n", - " [ 0.203125 ]\n", - " [ 0.140625 ]\n", - " [ 0.2734375]\n", - " [ 0.03125 ]\n", - " [-0.03125 ]\n", - " [-0.046875 ]\n", - " [-0.0703125]\n", - " [ 0.0078125]]\n", - "\n", - " [[ 0.2421875]\n", - " [-0.046875 ]\n", - " [ 0.015625 ]\n", - " [ 0.2734375]\n", - " [ 0.0546875]\n", - " [-0.0390625]\n", - " [ 0.25 ]\n", - " [-0.1953125]\n", - " [ 0.0546875]\n", - " [-0.140625 ]\n", - " [ 0.015625 ]\n", - " [-0.203125 ]\n", - " [ 0.203125 ]\n", - " [ 0.140625 ]\n", - " [ 0.2734375]\n", - " [ 0.03125 ]\n", - " [-0.03125 ]\n", - " [-0.046875 ]\n", - " [-0.0703125]\n", - " [ 0.0078125]]\n", - "\n", - " [[ 0.25 ]\n", - " [-0.046875 ]\n", - " [ 0.015625 ]\n", - " [ 0.2734375]\n", - " [ 0.0546875]\n", - " [-0.0390625]\n", - " [ 0.25 ]\n", - " [-0.1953125]\n", - " [ 0.0546875]\n", - " [-0.140625 ]\n", - " [ 0.015625 ]\n", - " [-0.203125 ]\n", - " [ 0.203125 ]\n", - " [ 0.140625 ]\n", - " [ 0.2734375]\n", - " [ 0.03125 ]\n", - " [-0.03125 ]\n", - " [-0.046875 ]\n", - " [-0.0703125]\n", - " [ 0.0078125]]\n", - "\n", - " [[ 0.25 ]\n", - " [-0.046875 ]\n", - " [ 0.015625 ]\n", - " [ 0.2734375]\n", - " [ 0.0546875]\n", - " [-0.0390625]\n", - " [ 0.25 ]\n", - " [-0.1953125]\n", - " [ 0.0546875]\n", - " [-0.140625 ]\n", - " [ 0.015625 ]\n", - " [-0.203125 ]\n", - " [ 0.203125 ]\n", - " [ 0.140625 ]\n", - " [ 0.2734375]\n", - " [ 0.03125 ]\n", - " [-0.03125 ]\n", - " [-0.046875 ]\n", - " [-0.0703125]\n", - " [ 0.0078125]]\n", - "\n", - " [[ 0.25 ]\n", - " [-0.046875 ]\n", - " [ 0.015625 ]\n", - " [ 
0.2734375]\n", - " [ 0.0546875]\n", - " [-0.0390625]\n", - " [ 0.25 ]\n", - " [-0.1953125]\n", - " [ 0.0546875]\n", - " [-0.140625 ]\n", - " [ 0.015625 ]\n", - " [-0.203125 ]\n", - " [ 0.203125 ]\n", - " [ 0.140625 ]\n", - " [ 0.2734375]\n", - " [ 0.03125 ]\n", - " [-0.03125 ]\n", - " [-0.046875 ]\n", - " [-0.0703125]\n", - " [ 0.0078125]]\n", - "\n", - " [[ 0.25 ]\n", - " [-0.046875 ]\n", - " [ 0.015625 ]\n", - " [ 0.2734375]\n", - " [ 0.0546875]\n", - " [-0.0390625]\n", - " [ 0.25 ]\n", - " [-0.1953125]\n", - " [ 0.0546875]\n", - " [-0.140625 ]\n", - " [ 0.015625 ]\n", - " [-0.203125 ]\n", - " [ 0.203125 ]\n", - " [ 0.140625 ]\n", - " [ 0.2734375]\n", - " [ 0.03125 ]\n", - " [-0.03125 ]\n", - " [-0.046875 ]\n", - " [-0.0703125]\n", - " [ 0.0078125]]\n", - "\n", - " [[ 0.25 ]\n", - " [-0.046875 ]\n", - " [ 0.015625 ]\n", - " [ 0.2734375]\n", - " [ 0.0546875]\n", - " [-0.0390625]\n", - " [ 0.25 ]\n", - " [-0.1953125]\n", - " [ 0.0546875]\n", - " [-0.140625 ]\n", - " [ 0.015625 ]\n", - " [-0.203125 ]\n", - " [ 0.203125 ]\n", - " [ 0.140625 ]\n", - " [ 0.2734375]\n", - " [ 0.03125 ]\n", - " [-0.03125 ]\n", - " [-0.046875 ]\n", - " [-0.0703125]\n", - " [ 0.0078125]]\n", - "\n", - " [[ 0.25 ]\n", - " [-0.046875 ]\n", - " [ 0.015625 ]\n", - " [ 0.2734375]\n", - " [ 0.0546875]\n", - " [-0.0390625]\n", - " [ 0.25 ]\n", - " [-0.1953125]\n", - " [ 0.0546875]\n", - " [-0.140625 ]\n", - " [ 0.015625 ]\n", - " [-0.203125 ]\n", - " [ 0.203125 ]\n", - " [ 0.140625 ]\n", - " [ 0.2734375]\n", - " [ 0.03125 ]\n", - " [-0.03125 ]\n", - " [-0.046875 ]\n", - " [-0.0703125]\n", - " [ 0.0078125]]\n", - "\n", - " [[ 0.25 ]\n", - " [-0.046875 ]\n", - " [ 0.015625 ]\n", - " [ 0.2734375]\n", - " [ 0.0546875]\n", - " [-0.0390625]\n", - " [ 0.25 ]\n", - " [-0.1953125]\n", - " [ 0.0546875]\n", - " [-0.140625 ]\n", - " [ 0.015625 ]\n", - " [-0.203125 ]\n", - " [ 0.203125 ]\n", - " [ 0.140625 ]\n", - " [ 0.2734375]\n", - " [ 0.03125 ]\n", - " [-0.03125 ]\n", - " [-0.046875 ]\n", - " [-0.0703125]\n", - " [ 0.0078125]]\n", - "\n", - " [[ 0.25 ]\n", - " [-0.046875 ]\n", - " [ 0.015625 ]\n", - " [ 0.2734375]\n", - " [ 0.0546875]\n", - " [-0.0390625]\n", - " [ 0.25 ]\n", - " [-0.1953125]\n", - " [ 0.0546875]\n", - " [-0.140625 ]\n", - " [ 0.015625 ]\n", - " [-0.203125 ]\n", - " [ 0.203125 ]\n", - " [ 0.140625 ]\n", - " [ 0.2734375]\n", - " [ 0.03125 ]\n", - " [-0.03125 ]\n", - " [-0.046875 ]\n", - " [-0.0703125]\n", - " [ 0.0078125]]\n", - "\n", - " [[ 0.25 ]\n", - " [-0.046875 ]\n", - " [ 0.015625 ]\n", - " [ 0.2734375]\n", - " [ 0.0546875]\n", - " [-0.0390625]\n", - " [ 0.25 ]\n", - " [-0.1953125]\n", - " [ 0.0546875]\n", - " [-0.140625 ]\n", - " [ 0.015625 ]\n", - " [-0.203125 ]\n", - " [ 0.203125 ]\n", - " [ 0.140625 ]\n", - " [ 0.2734375]\n", - " [ 0.03125 ]\n", - " [-0.03125 ]\n", - " [-0.046875 ]\n", - " [-0.0703125]\n", - " [ 0.0078125]]\n", - "\n", - " [[ 0.25 ]\n", - " [-0.046875 ]\n", - " [ 0.015625 ]\n", - " [ 0.2734375]\n", - " [ 0.0546875]\n", - " [-0.0390625]\n", - " [ 0.25 ]\n", - " [-0.1953125]\n", - " [ 0.0546875]\n", - " [-0.140625 ]\n", - " [ 0.015625 ]\n", - " [-0.203125 ]\n", - " [ 0.203125 ]\n", - " [ 0.140625 ]\n", - " [ 0.2734375]\n", - " [ 0.03125 ]\n", - " [-0.03125 ]\n", - " [-0.046875 ]\n", - " [-0.0703125]\n", - " [ 0.0078125]]\n", - "\n", - " [[ 0.25 ]\n", - " [-0.046875 ]\n", - " [ 0.015625 ]\n", - " [ 0.2734375]\n", - " [ 0.0546875]\n", - " [-0.0390625]\n", - " [ 0.25 ]\n", - " [-0.1953125]\n", - " [ 0.0546875]\n", - " [-0.140625 ]\n", - " [ 0.015625 ]\n", - " [-0.203125 ]\n", - " [ 0.203125 
]\n", - " [ 0.140625 ]\n", - " [ 0.2734375]\n", - " [ 0.03125 ]\n", - " [-0.03125 ]\n", - " [-0.046875 ]\n", - " [-0.0703125]\n", - " [ 0.0078125]]\n", - "\n", - " [[ 0.25 ]\n", - " [-0.046875 ]\n", - " [ 0.015625 ]\n", - " [ 0.2734375]\n", - " [ 0.0546875]\n", - " [-0.0390625]\n", - " [ 0.25 ]\n", - " [-0.1953125]\n", - " [ 0.0546875]\n", - " [-0.140625 ]\n", - " [ 0.015625 ]\n", - " [-0.203125 ]\n", - " [ 0.203125 ]\n", - " [ 0.140625 ]\n", - " [ 0.2734375]\n", - " [ 0.03125 ]\n", - " [-0.03125 ]\n", - " [-0.046875 ]\n", - " [-0.0703125]\n", - " [ 0.0078125]]\n", - "\n", - " [[ 0.25 ]\n", - " [-0.046875 ]\n", - " [ 0.015625 ]\n", - " [ 0.2734375]\n", - " [ 0.0546875]\n", - " [-0.0390625]\n", - " [ 0.25 ]\n", - " [-0.1953125]\n", - " [ 0.0546875]\n", - " [-0.140625 ]\n", - " [ 0.015625 ]\n", - " [-0.203125 ]\n", - " [ 0.203125 ]\n", - " [ 0.140625 ]\n", - " [ 0.2734375]\n", - " [ 0.03125 ]\n", - " [-0.03125 ]\n", - " [-0.046875 ]\n", - " [-0.0703125]\n", - " [ 0.0078125]]\n", - "\n", - " [[ 0.25 ]\n", - " [-0.046875 ]\n", - " [ 0.015625 ]\n", - " [ 0.2734375]\n", - " [ 0.0546875]\n", - " [-0.0390625]\n", - " [ 0.25 ]\n", - " [-0.1953125]\n", - " [ 0.0546875]\n", - " [-0.140625 ]\n", - " [ 0.015625 ]\n", - " [-0.203125 ]\n", - " [ 0.203125 ]\n", - " [ 0.140625 ]\n", - " [ 0.2734375]\n", - " [ 0.03125 ]\n", - " [-0.03125 ]\n", - " [-0.046875 ]\n", - " [-0.0703125]\n", - " [ 0.0078125]]\n", - "\n", - " [[ 0.25 ]\n", - " [-0.046875 ]\n", - " [ 0.015625 ]\n", - " [ 0.2734375]\n", - " [ 0.0546875]\n", - " [-0.0390625]\n", - " [ 0.25 ]\n", - " [-0.1953125]\n", - " [ 0.0546875]\n", - " [-0.140625 ]\n", - " [ 0.015625 ]\n", - " [-0.203125 ]\n", - " [ 0.203125 ]\n", - " [ 0.140625 ]\n", - " [ 0.2734375]\n", - " [ 0.03125 ]\n", - " [-0.03125 ]\n", - " [-0.046875 ]\n", - " [-0.0703125]\n", - " [ 0.0078125]]\n", - "\n", - " [[ 0.25 ]\n", - " [-0.046875 ]\n", - " [ 0.015625 ]\n", - " [ 0.2734375]\n", - " [ 0.0546875]\n", - " [-0.0390625]\n", - " [ 0.25 ]\n", - " [-0.1953125]\n", - " [ 0.0546875]\n", - " [-0.140625 ]\n", - " [ 0.015625 ]\n", - " [-0.203125 ]\n", - " [ 0.203125 ]\n", - " [ 0.140625 ]\n", - " [ 0.2734375]\n", - " [ 0.03125 ]\n", - " [-0.03125 ]\n", - " [-0.046875 ]\n", - " [-0.0703125]\n", - " [ 0.0078125]]]\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2023-10-20 10:50:38.892379706 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'scale_test'. It is not used by any node and should be removed from the model.\n", - "2023-10-20 10:50:38.894726380 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'ql_uo_out'. It is not used by any node and should be removed from the model.\n", - "2023-10-20 10:50:38.894741924 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'ql_wf_out'. It is not used by any node and should be removed from the model.\n", - "2023-10-20 10:50:38.894750521 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'ql_ui_out'. It is not used by any node and should be removed from the model.\n", - "2023-10-20 10:50:38.894758793 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'max'. It is not used by any node and should be removed from the model.\n", - "2023-10-20 10:50:38.894767212 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'W_c'. 
It is not used by any node and should be removed from the model.\n", - "2023-10-20 10:50:38.894775093 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'U_c'. It is not used by any node and should be removed from the model.\n", - "2023-10-20 10:50:38.894782542 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'U_i'. It is not used by any node and should be removed from the model.\n", - "2023-10-20 10:50:38.894790413 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'ql_uc_out'. It is not used by any node and should be removed from the model.\n", - "2023-10-20 10:50:38.894797986 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'W_i'. It is not used by any node and should be removed from the model.\n", - "2023-10-20 10:50:38.894805922 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'ql_wi_out'. It is not used by any node and should be removed from the model.\n", - "2023-10-20 10:50:38.894813725 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'U_o'. It is not used by any node and should be removed from the model.\n", - "2023-10-20 10:50:38.894821378 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'W_f'. It is not used by any node and should be removed from the model.\n", - "2023-10-20 10:50:38.894829187 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'W_o'. It is not used by any node and should be removed from the model.\n", - "2023-10-20 10:50:38.894837744 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'ql_uf_out'. It is not used by any node and should be removed from the model.\n", - "2023-10-20 10:50:38.894845343 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'ql_wc_out'. It is not used by any node and should be removed from the model.\n", - "2023-10-20 10:50:38.894852862 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'U_f'. It is not used by any node and should be removed from the model.\n", - "2023-10-20 10:50:38.894861070 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'ql_wo_out'. It is not used by any node and should be removed from the model.\n", - "2023-10-20 10:50:38.894868719 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'min'. 
It is not used by any node and should be removed from the model.\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "# Defining the values of the variables to test the execution of the scan model\n",
     "num_inputs = 25\n",
@@ -2532,344 +870,10 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 14,
+    "execution_count": null,
    "id": "2fe07395-6cf9-4c99-a0d3-a27aa6a326b5",
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Brevitas Output shape : (25, 1, 20)\n",
-      "SCAN-QCDQ-LSTM output shape : (25, 1, 20)\n",
-      "-----------------------------------\n",
-      "Brevitas Output = [[[ 0.1484375 -0.0078125  0.0390625  0.140625   0.0078125  0.\n",
-      "    0.109375  -0.09375    0.0390625 -0.0625     0.015625  -0.1171875\n",
-      "    0.1015625  0.03125    0.1640625 -0.015625  -0.0234375 -0.015625\n",
-      "   -0.046875   0.0078125]]\n",
-      "\n",
-      " [[ 0.2109375 -0.0234375  0.03125    0.2109375  0.0234375 -0.015625\n",
-      "    0.1875    -0.1484375  0.046875  -0.09375    0.0234375 -0.1640625\n",
-      "    0.1484375  0.0625     0.2578125 -0.015625  -0.03125   -0.0234375\n",
-      "   -0.0703125  0.015625 ]]\n",
-      "\n",
-      " [[ 0.2421875 -0.0390625  0.015625   0.25       0.03125   -0.0234375\n",
-      "    0.234375  -0.1796875  0.0546875 -0.109375   0.015625  -0.1875\n",
-      "    0.1796875  0.09375    0.3125     0.        -0.03125   -0.03125\n",
-      "   -0.078125   0.0078125]]\n",
-      "\n",
-      " [[ 0.25      -0.0390625  0.015625   0.265625   0.0390625 -0.03125\n",
-      "    0.265625  -0.1875     0.0546875 -0.125      0.015625  -0.1953125\n",
-      "    0.1953125  0.1171875  0.3359375  0.015625  -0.03125   -0.0390625\n",
-      "   -0.078125   0.015625 ]]\n",
-      "\n",
-      " [[ 0.2578125 -0.046875   0.015625   0.2734375  0.046875  -0.0390625\n",
-      "    0.2890625 -0.1953125  0.0546875 -0.125      0.015625  -0.203125\n",
-      "    0.203125   0.125      0.359375   0.0234375 -0.03125   -0.046875\n",
-      "   -0.0703125  0.015625 ]]\n",
-      "\n",
-      " [[ 0.2578125 -0.046875   0.0078125  0.2734375  0.046875  -0.0390625\n",
-      "    0.296875  -0.1953125  0.0546875 -0.1328125  0.015625  -0.203125\n",
-      "    0.2109375  0.1328125  0.3671875  0.03125   -0.0234375 -0.046875\n",
-      "   -0.0703125  0.015625 ]]\n",
-      "\n",
-      " [[ 0.2578125 -0.046875   0.015625   0.28125    0.0546875 -0.046875\n",
-      "    0.3046875 -0.1953125  0.0546875 -0.140625   0.0078125 -0.203125\n",
-      "    0.2109375  0.140625   0.375      0.0390625 -0.0234375 -0.046875\n",
-      "   -0.0703125  0.015625 ]]\n",
-      "\n",
-      " [[ 0.2578125 -0.0546875  0.0078125  0.28125    0.0546875 -0.046875\n",
-      "    0.3125    -0.203125   0.0546875 -0.140625   0.0078125 -0.203125\n",
-      "    0.2109375  0.140625   0.3828125  0.0390625 -0.015625  -0.0546875\n",
-      "   -0.0703125  0.015625 ]]\n",
-      "\n",
-      " [[ 0.2578125 -0.046875   0.0078125  0.28125    0.0546875 -0.046875\n",
-      "    0.3125    -0.203125   0.0546875 -0.140625   0.0078125 -0.203125\n",
-      "    0.2109375  0.1484375  0.390625   0.0390625 -0.015625  -0.0546875\n",
-      "   -0.0703125  0.015625 ]]\n",
-      "\n",
-      " [[ 0.2578125 -0.046875   0.0078125  0.28125    0.0546875 -0.046875\n",
-      "    0.3125    -0.203125   0.0546875 -0.140625   0.0078125 -0.203125\n",
-      "    0.21875    0.1484375  0.390625   0.0390625 -0.015625  -0.0546875\n",
-      "   -0.0703125  0.015625 ]]\n",
-      "\n",
-      " [[ 0.2578125 -0.046875   0.0078125  0.28125    0.0546875 -0.046875\n",
-      "    0.3125    -0.203125   0.0546875 -0.140625   0.0078125 -0.203125\n",
-      "    0.21875    0.1484375  0.390625   0.046875  -0.015625  -0.0546875\n",
-      "   -0.0703125  0.015625 ]]\n",
-      "\n",
-      " [[ 0.2578125 -0.046875   0.0078125  0.28125    0.0546875 -0.046875\n",
-      "    0.3125    -0.203125   0.0546875 -0.140625   0.0078125 -0.203125\n",
-      "    0.21875    0.1484375  0.390625   0.046875  -0.015625  -0.0546875\n",
-      "   -0.0703125  0.015625 ]]\n",
-      "\n",
-      " [[ 0.2578125 -0.046875   0.0078125  0.28125    0.0546875 -0.046875\n",
-      "    0.3125    -0.203125   0.0546875 
-0.140625 0.0078125 -0.203125\n", - " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", - " -0.0703125 0.015625 ]]\n", - "\n", - " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", - " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", - " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", - " -0.0703125 0.015625 ]]\n", - "\n", - " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", - " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", - " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", - " -0.0703125 0.015625 ]]\n", - "\n", - " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", - " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", - " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", - " -0.0703125 0.015625 ]]\n", - "\n", - " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", - " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", - " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", - " -0.0703125 0.015625 ]]\n", - "\n", - " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", - " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", - " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", - " -0.0703125 0.015625 ]]\n", - "\n", - " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", - " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", - " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", - " -0.0703125 0.015625 ]]\n", - "\n", - " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", - " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", - " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", - " -0.0703125 0.015625 ]]\n", - "\n", - " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", - " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", - " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", - " -0.0703125 0.015625 ]]\n", - "\n", - " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", - " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", - " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", - " -0.0703125 0.015625 ]]\n", - "\n", - " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", - " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", - " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", - " -0.0703125 0.015625 ]]\n", - "\n", - " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", - " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", - " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", - " -0.0703125 0.015625 ]]\n", - "\n", - " [[ 0.2578125 -0.046875 0.0078125 0.28125 0.0546875 -0.046875\n", - " 0.3125 -0.203125 0.0546875 -0.140625 0.0078125 -0.203125\n", - " 0.21875 0.1484375 0.390625 0.046875 -0.015625 -0.0546875\n", - " -0.0703125 0.015625 ]]]\n", - "-----------------------------------\n", - "SCAN-QCDQ-LSTM output [[[ 0.1484375 -0.0078125 0.0390625 0.140625 0.015625 0.\n", - " 0.1015625 -0.1015625 0.0390625 -0.0625 0.015625 -0.125\n", - " 0.1015625 0.03125 0.1640625 -0.015625 -0.0234375 -0.015625\n", - " -0.046875 0.0078125]]\n", - "\n", - " [[ 0.203125 -0.0234375 0.03125 0.2109375 0.0234375 -0.015625\n", - " 0.1875 -0.1484375 0.046875 -0.09375 0.0234375 -0.1640625\n", - " 0.1484375 0.0703125 0.2578125 -0.015625 -0.03125 -0.0234375\n", - " -0.0703125 0.0078125]]\n", - "\n", - " [[ 0.2265625 -0.03125 
-     "  ... (remaining rows elided: the two printed output tensors, 25 timesteps\n",
-     "  of 20 hidden-state values each, one scaled floating-point sequence and one\n",
-     "  integer sequence, deleted from the stored notebook outputs) ...\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "# We first match the shape of both the outputs to perform the functional verification correctly\n",
     "\n",

From 7719a3e3ec2a78148dd021dd6b47bfd7eff182ec Mon Sep 17 00:00:00 2001
From: Christoph Berganski
Date: Tue, 14 Nov 2023 14:21:13 +0100
Subject: [PATCH 03/28] Add cleanup transformation sorting inputs of commutative operations

---
 src/qonnx/core/modelwrapper.py      |  8 +++-
 src/qonnx/transformation/general.py | 57 +++++++++++++++++++++++++++++
 2 files changed, 64 insertions(+), 1 deletion(-)

diff --git a/src/qonnx/core/modelwrapper.py b/src/qonnx/core/modelwrapper.py
index f78e1334..ce621743 100644
--- a/src/qonnx/core/modelwrapper.py
+++ b/src/qonnx/core/modelwrapper.py
@@ -38,7 +38,12 @@
 import qonnx.util.onnx as onnxutil
 from qonnx.core.datatype import DataType
 from qonnx.transformation.double_to_single_float import DoubleToSingleFloat
-from qonnx.transformation.general import RemoveStaticGraphInputs, RemoveUnusedTensors, SortGraph
+from qonnx.transformation.general import (
+    RemoveStaticGraphInputs,
+    RemoveUnusedTensors,
+    SortGraph,
+    SortCommutativeInputsInitializerLast
+)


 class ModelWrapper:
@@ -149,6 +154,7 @@ def cleanup(self):
             RemoveUnusedTensors(),
             RemoveStaticGraphInputs(),
             SortGraph(),
+            SortCommutativeInputsInitializerLast(),
         ]
         for trn in cleanup_transforms:
             transformed_model = transformed_model.transform(trn, cleanup=False, make_deepcopy=False)
diff --git a/src/qonnx/transformation/general.py b/src/qonnx/transformation/general.py
index 5153e616..686bf17b 100644
--- a/src/qonnx/transformation/general.py
+++ b/src/qonnx/transformation/general.py
@@ -35,6 +35,9 @@
 import qonnx.util.basic as util
 from qonnx.transformation.base import Transformation

+# Protobuf onnx graph node type
+from onnx import NodeProto  # noqa
+

 class MovePadAttributeToTensor(Transformation):
     "Move padding info from attribute into input tensor for Pad nodes."
@@ -359,3 +362,57 @@

         # one iteration is enough
         return (model, False)
+
+
+# Groups inputs by categories, i.e., groups dynamic inputs first, followed by
+# initializers. Keeps order of inputs in each category.
+def group_inputs_by_category(node: NodeProto, model):  # noqa
+    # Select all dynamic inputs, which are those without initializer tensor
+    dynamics = [i for i in node.input if model.get_initializer(i) is None]
+    # Select all inputs which are initializers, which, by exclusion, are all
+    # those not among the dynamic inputs
+    initializers = [i for i in node.input if i not in dynamics]
+    # Return lists of dynamic and initializer inputs
+    return dynamics, initializers
+
+
+# Tidy-up transformation sorting the inputs to all commutative operations to
+# have initializer inputs last
+class SortCommutativeInputsInitializerLast(Transformation):
+    """
+    Sorts inputs of nodes describing commutative operations to have initializer
+    inputs last. This order of inputs is assumed by many other transformations.
+    """
+
+    # Set of supported commutative operations
+    # TODO: There might be more valid operations
+    SUPPORTED_COMMUTATIVE_OPS = {"Add", "Mul", "And", "Or", "Xor", "Sum"}
+
+    # Applies the transform to a whole model graph
+    def apply(self, model):  # noqa
+        # Get the model graph out of the model wrapper object
+        graph = model.graph
+        # Keep track of whether the graph has been modified
+        graph_modified = False
+        # Iterate all nodes in the graph keeping track of the index
+        for index, node in enumerate(graph.node):
+            # Check whether this node is among the supported operations
+            if node.op_type in self.SUPPORTED_COMMUTATIVE_OPS:
+                # Group node inputs by category
+                dynamics, initializers = group_inputs_by_category(node, model)
+                # Flatten the grouped input list
+                inputs = [*dynamics, *initializers]
+                # Length of sorted and original input list must match
+                assert len(inputs) == len(node.input)
+                # Reassigned inputs from sorted categories
+                # Note: ONNX does not allow direct assignment to node.input
+                for i, name in enumerate(inputs):
+                    # The graph has been modified if any input is reordered
+                    if node.input[i] != name:
+                        # Note: This is never reset back to False
+                        graph_modified = True
+                    # Reassign input name at the new index
+                    node.input[i] = name
+        # Return the transformed model and indicate whether the graph actually
+        # has been transformed
+        return model, graph_modified

From c0f5b4626118c275a8588a5a64393aa319044f5d Mon Sep 17 00:00:00 2001
From: Christoph Berganski
Date: Tue, 14 Nov 2023 14:26:29 +0100
Subject: [PATCH 04/28] Address some linting issues

---
 src/qonnx/core/modelwrapper.py      | 2 +-
 src/qonnx/transformation/general.py | 6 +++---
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/qonnx/core/modelwrapper.py b/src/qonnx/core/modelwrapper.py
index ce621743..f7cf1d19 100644
--- a/src/qonnx/core/modelwrapper.py
+++ b/src/qonnx/core/modelwrapper.py
@@ -41,8 +41,8 @@
 from qonnx.transformation.general import (
     RemoveStaticGraphInputs,
     RemoveUnusedTensors,
+    SortCommutativeInputsInitializerLast,
     SortGraph,
-    SortCommutativeInputsInitializerLast
 )


diff --git a/src/qonnx/transformation/general.py b/src/qonnx/transformation/general.py
index 686bf17b..b5ed0fca 100644
--- a/src/qonnx/transformation/general.py
+++ b/src/qonnx/transformation/general.py
@@ -29,15 +29,15 @@
 import json
 import numpy as np
 import warnings
+
+# Protobuf onnx graph node type
+from onnx import NodeProto  # noqa
 from onnx import mapping
 from toposort import toposort_flatten

 import qonnx.util.basic as util
 from qonnx.transformation.base import Transformation

-# Protobuf onnx graph node type
-from onnx import NodeProto  # noqa
-

 class MovePadAttributeToTensor(Transformation):
     "Move padding info from attribute into input tensor for 
Pad nodes." From 8902694106de98c827e38e04ffbf3f0d8dfc9675 Mon Sep 17 00:00:00 2001 From: Christoph Berganski Date: Wed, 15 Nov 2023 09:41:18 +0100 Subject: [PATCH 05/28] Fix RemoveIdentityOps not correctly handling ops following fork-nodes --- src/qonnx/transformation/remove.py | 45 ++++++++++++++++++++++-------- 1 file changed, 34 insertions(+), 11 deletions(-) diff --git a/src/qonnx/transformation/remove.py b/src/qonnx/transformation/remove.py index e745f0f0..2fc888cb 100644 --- a/src/qonnx/transformation/remove.py +++ b/src/qonnx/transformation/remove.py @@ -25,9 +25,8 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - - import numpy as np +import warnings from qonnx.core.modelwrapper import ModelWrapper from qonnx.transformation.base import Transformation @@ -58,21 +57,45 @@ def apply(self, model: ModelWrapper): def remove_node_and_rewire(model, node): + # Currently cannot remove and rewire join-nodes, probably not necessary to + # support this + if model.is_join_node(node): + # Log this as a warning, so the user is aware of this, there might be + # somthing wrong or some checks missing at the caller site + warnings.warn( + "Tried to remove join-node operation: Currently not supported" + ) + # Exit the function here without doing anything + return + # We already know that node is not a join-node, thus to rewire, we only need + # to check the single producer producer = model.find_producer(node.input[0]) - if producer is not None: - # wire output tensor to - # output of producer node + # If there is a producer which is not a fork-node, rewiring is simple + if producer is not None and not model.is_fork_node(producer): + # Rewire by skipping the node, letting the producer directly feed the + # nodes output. + # TODO: Check whether this already covers fork-node identities? producer.output[0] = node.output[0] + # If there is no producer or the producer forks, rewiring is a bit more + # complicated else: - # node is first in graph + # Now it depends on the successor nodes to rewire their inputs successors = model.find_direct_successors(node) + # Singular node detached from the rest of the graph? assert successors is not None, "Whole graph is one node." - for succ in successors: - for i, s_inp in enumerate(succ.input): + # We need to rewire the input of each successor to not detach parts of + # the graph + for successor in successors: + # Find the inputs of the successor which are produced by the node to + # be removed + for i, s_inp in enumerate(successor.input): + # Note: This might happen multiple times? 
if s_inp == node.output[0]: - # rewire successor's input directly to graph input - succ.input[i] = node.input[0] - # remove node + # Rewire successor's input directly to nodes input + # Note: Node may not be a join-node, but there is probably + # no such thing as join-node identity anyway + successor.input[i] = node.input[0] + # Remove node model.graph.node.remove(node) From c7b359062dee8b979bc22741885ac812da8fe7ce Mon Sep 17 00:00:00 2001 From: Christoph Berganski Date: Wed, 15 Nov 2023 09:50:20 +0100 Subject: [PATCH 06/28] Change error message to address some linting issue --- src/qonnx/transformation/remove.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/qonnx/transformation/remove.py b/src/qonnx/transformation/remove.py index 2fc888cb..980e80c1 100644 --- a/src/qonnx/transformation/remove.py +++ b/src/qonnx/transformation/remove.py @@ -62,9 +62,7 @@ def remove_node_and_rewire(model, node): if model.is_join_node(node): # Log this as a warning, so the user is aware of this, there might be # somthing wrong or some checks missing at the caller site - warnings.warn( - "Tried to remove join-node operation: Currently not supported" - ) + warnings.warn("Removing join-node operation is currently not supported") # Exit the function here without doing anything return # We already know that node is not a join-node, thus to rewire, we only need From 59a7ca002c992d5ff8cd5ab1086825574e58dd22 Mon Sep 17 00:00:00 2001 From: shashwat1198 Date: Fri, 1 Mar 2024 14:45:25 +0000 Subject: [PATCH 07/28] package installations added --- notebooks/4_quant_lstm.ipynb | 81 +++--- notebooks/4_quant_lstm_helper/function.py | 340 ++++++++++++++++++++++ notebooks/4_quant_lstm_helper/handler.py | 140 +++++++++ 3 files changed, 518 insertions(+), 43 deletions(-) create mode 100644 notebooks/4_quant_lstm_helper/function.py create mode 100644 notebooks/4_quant_lstm_helper/handler.py diff --git a/notebooks/4_quant_lstm.ipynb b/notebooks/4_quant_lstm.ipynb index 186be984..bc2b5e2e 100644 --- a/notebooks/4_quant_lstm.ipynb +++ b/notebooks/4_quant_lstm.ipynb @@ -2,7 +2,6 @@ "cells": [ { "cell_type": "markdown", - "id": "5ef5f772-f48a-4bb1-bb68-4e8e9236fd2e", "metadata": {}, "source": [ "# QuantLSTM - ONNX (QCDQ) representation" @@ -10,11 +9,12 @@ }, { "cell_type": "markdown", - "id": "e5a747f9-fd74-4ebc-8d74-17bf06ff2d48", "metadata": {}, "source": [ - "This notebook is divided into `five` parts:\n", + "This notebook is divided into `six` parts:\n", "\n", + "
Part 0 : Package Installations.\n", + "
\n", "
Part 1 : Introduction to LSTMs.\n", "
\n", "
Part 2 : Model creation with brevitas QuantLSTM layer. \n", @@ -28,16 +28,45 @@ }, { "cell_type": "markdown", - "id": "69ae7154-8cf3-4ee7-88c3-3bec0550008a", + "metadata": {}, + "source": [ + "# Package Installations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#Required package installations, This cell only needs to be executed once at the start\n", + "!pip install torch==1.13.1\n", + "!pip install brevitas==0.9.1\n", + "!pip install onnx==1.13.0\n", + "!pip install onnxoptimizer==0.3.13\n", + "!pip install onnxruntime==1.11.1\n", + "!pip install netron==7.2.5\n", + "!pip install qonnx==0.2.0\n", + "!pip install IPython\n", + "!pip install ipykernel\n", + "!ipython kernel install --user --name=venv\n", + "\n", + "#The below location can change depending on your installation of the 'venv' virtual environment\n", + "!cp ./4_quant_lstm_helper/function.py ../venv/lib/python3.8/site-packages/brevitas/export/onnx/standard/\n", + "!cp ./4_quant_lstm_helper/handler.py ../venv/lib/python3.8/site-packages/brevitas/export/onnx/standard/qcdq/\n", + "\n", + "#NOTE : Make sure to chnage the kernel to from \"Python 3\" to \"venv\" before running the below commands" + ] + }, + { + "cell_type": "markdown", "metadata": {}, "source": [ "# Introduction to LSTM's " ] }, { - "attachments": {}, "cell_type": "markdown", - "id": "e7a903ef-1680-4a20-8c61-267884b76c96", "metadata": {}, "source": [ "`LSTM’s (Long Short-Term Memory)` are sequential neural networks that are capable of learning long term dependencies especially in sequence prediction problems. They are deployed in machine translation, speech recognition, image captioning and especially used for time-series analysis applications.\n", @@ -73,7 +102,6 @@ }, { "cell_type": "markdown", - "id": "70d052c8-e5cd-4eb1-89e5-f8ae956cb853", "metadata": {}, "source": [ "# QuantLSTM model creation" @@ -81,7 +109,6 @@ }, { "cell_type": "markdown", - "id": "6a64be7c", "metadata": {}, "source": [ "In the 2nd part of the notebook, we will create a single layer `QuantLSTM` model in brevitas. We will evaluate with a given set of inputs. We then export this model to `QONNX` so that the same parameters (weights/biases/scales) can be extracted and used in the `QCDQ-LSTM` implementation." @@ -90,7 +117,6 @@ { "cell_type": "code", "execution_count": null, - "id": "84d66548-365d-46a5-9eaa-bb767085f9aa", "metadata": {}, "outputs": [], "source": [ @@ -119,7 +145,6 @@ { "cell_type": "code", "execution_count": null, - "id": "23a7682c", "metadata": {}, "outputs": [], "source": [ @@ -153,7 +178,6 @@ }, { "cell_type": "markdown", - "id": "347ef1f5-36e8-4103-9b13-efa7fe93eb5e", "metadata": {}, "source": [ "`Abbreviations` : Short-forms defined in the next code block can be referenced here for definitions.\n", @@ -166,7 +190,6 @@ { "cell_type": "code", "execution_count": null, - "id": "0bfbf5a3-8556-4190-a28f-4fe9859c55a9", "metadata": {}, "outputs": [], "source": [ @@ -210,7 +233,6 @@ }, { "cell_type": "markdown", - "id": "10237589-f84e-423a-829e-3e2c2e806ed7", "metadata": {}, "source": [ "# LSTM ONNX model" @@ -218,7 +240,6 @@ }, { "cell_type": "markdown", - "id": "367547b8", "metadata": {}, "source": [ "In the 3rd part of the notebook, we will construct the `QCDQ-LSTM` model with standard ONNX operators. 
After loading all the parameters in the above block we can now start building our ONNX model with QCDQ quantization to represent the LSTM computations described in part-1.\n" @@ -227,7 +248,6 @@ { "cell_type": "code", "execution_count": null, - "id": "02fe4d94-af24-4d5e-a809-7d8c49e7fd90", "metadata": {}, "outputs": [], "source": [ @@ -249,7 +269,6 @@ }, { "cell_type": "markdown", - "id": "15098a9e-4187-4987-82cc-275eba650923", "metadata": {}, "source": [ "`Abbreviations` : These describe different short-forms used in the next two blocks.\n", @@ -265,7 +284,6 @@ }, { "cell_type": "markdown", - "id": "f2edc0cc", "metadata": {}, "source": [ "We start defining the model by defining the `inputs` and `outputs` defined as value_info tensors in ONNX.\n", @@ -276,7 +294,6 @@ { "cell_type": "code", "execution_count": null, - "id": "02761646-4c6d-440f-8e90-4935beebab56", "metadata": {}, "outputs": [], "source": [ @@ -294,7 +311,6 @@ { "cell_type": "code", "execution_count": null, - "id": "c08e5a23-ef2e-4bca-9293-c800350c2c62", "metadata": {}, "outputs": [], "source": [ @@ -412,7 +428,6 @@ }, { "cell_type": "markdown", - "id": "3d10867f", "metadata": {}, "source": [ "After defining the above operations we now connect them and create a graph with the help of onnx.helper `make_graph` utility function" @@ -421,7 +436,6 @@ { "cell_type": "code", "execution_count": null, - "id": "79839558-8752-4fc8-9b0e-8fed47c91701", "metadata": {}, "outputs": [], "source": [ @@ -632,7 +646,6 @@ }, { "cell_type": "markdown", - "id": "b1b16751", "metadata": {}, "source": [ "The above created graph can now be converted into a qonnx model with the `qonnx_make_model` utility. We save the model with `onnx.save` utility and then view it in Netron with the help of `showInNetron` utility. \n" @@ -641,7 +654,6 @@ { "cell_type": "code", "execution_count": null, - "id": "c6ec7b2a-456d-4452-97ec-df9a471d5391", "metadata": {}, "outputs": [], "source": [ @@ -652,7 +664,6 @@ }, { "cell_type": "markdown", - "id": "40b49257", "metadata": {}, "source": [ "In this block of code we execute the onnx graph to check that it can execute without any errors. We perform it's functional verification in the later part of the notebook." 
@@ -661,7 +672,6 @@ { "cell_type": "code", "execution_count": null, - "id": "db5892bc-ac8d-4972-afcf-20bf880f5e86", "metadata": {}, "outputs": [], "source": [ @@ -691,7 +701,6 @@ }, { "cell_type": "markdown", - "id": "5d2b5a1e-654e-46a5-9d4f-8708611a6d1e", "metadata": {}, "source": [ "# SCAN Operation Integration" @@ -699,7 +708,6 @@ }, { "cell_type": "markdown", - "id": "7365329a-f3d2-4f74-8e2f-9076771e07a7", "metadata": {}, "source": [ "### Introduction to ONNX Scan operation\n", @@ -721,7 +729,6 @@ }, { "cell_type": "markdown", - "id": "17f247f7", "metadata": {}, "source": [ "The `Scan` operation is essentially a container operator which will consume the LSTM graph that we created above in it's body.\n", @@ -733,7 +740,6 @@ { "cell_type": "code", "execution_count": null, - "id": "700a93a8-f757-4fa1-88dd-47a3f2a7f171", "metadata": {}, "outputs": [], "source": [ @@ -750,7 +756,6 @@ }, { "cell_type": "markdown", - "id": "572f191e", "metadata": {}, "source": [ "We will now create the scan operator here now utilizing the `make_node` utility from ONNX.\n", @@ -760,7 +765,6 @@ { "cell_type": "code", "execution_count": null, - "id": "111fdce4-464f-40c1-ac4d-3022b05f153e", "metadata": {}, "outputs": [], "source": [ @@ -775,7 +779,6 @@ }, { "cell_type": "markdown", - "id": "ea8a05d9", "metadata": {}, "source": [ "We can now define the graph for the scan operator utilizing the `make_graph` utility." @@ -784,7 +787,6 @@ { "cell_type": "code", "execution_count": null, - "id": "4668cf2b-524e-4768-8dc8-9d619f6273da", "metadata": {}, "outputs": [], "source": [ @@ -810,7 +812,6 @@ }, { "cell_type": "markdown", - "id": "0673e335", "metadata": {}, "source": [ "Now that we have the SCAN based quantized LSTM model ready, we can now go forward and test it with the same sets of inputs we used for the testing of the brevitas model.\n" @@ -819,7 +820,6 @@ { "cell_type": "code", "execution_count": null, - "id": "818d2a81-686f-4a4a-8e78-17dbf75d8451", "metadata": {}, "outputs": [], "source": [ @@ -854,7 +854,6 @@ }, { "cell_type": "markdown", - "id": "907d2ff9-f605-4aec-891e-0c77a1a92346", "metadata": {}, "source": [ "# Functional Verification" @@ -862,7 +861,6 @@ }, { "cell_type": "markdown", - "id": "b6bb6c60", "metadata": {}, "source": [ "In the final part of the notebook, we compare the output of the 8-bit quantized `(QCDQ)-LSTM` implementation with the `QuantLSTM` brevitas model.\n" @@ -871,7 +869,6 @@ { "cell_type": "code", "execution_count": null, - "id": "2fe07395-6cf9-4c99-a0d3-a27aa6a326b5", "metadata": {}, "outputs": [], "source": [ @@ -900,7 +897,6 @@ }, { "cell_type": "markdown", - "id": "7bcca933", "metadata": {}, "source": [ "Note the difference in outputs increases as we progress with processing the inputs. 
The first two outputs are very close to one another, but as we get the outputs for more inputs we see for some values differ from the brevitas output by a considerable amount.\n", @@ -909,16 +905,15 @@ }, { "cell_type": "markdown", - "id": "81c6d531", "metadata": {}, "source": [] } ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "venv", "language": "python", - "name": "python3" + "name": "venv" }, "language_info": { "codemirror_mode": { @@ -930,7 +925,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.10" + "version": "3.8.0" } }, "nbformat": 4, diff --git a/notebooks/4_quant_lstm_helper/function.py b/notebooks/4_quant_lstm_helper/function.py new file mode 100644 index 00000000..6ba2e9dd --- /dev/null +++ b/notebooks/4_quant_lstm_helper/function.py @@ -0,0 +1,340 @@ +# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +import torch +from torch.autograd import Function + +from brevitas.export.onnx import onnx_export_opset + +AXIS_OPSET = 13 +DOMAIN_STRING = "onnx.brevitas" + + +class DequantizeLinearFn(Function): + + @staticmethod + def symbolic(g, x, input_scale, input_zero_point, input_axis): + opset_version = onnx_export_opset() + + if input_axis is not None and opset_version < AXIS_OPSET: + raise RuntimeError('ONNX Opset 13 is required for per-channel quantization') + elif input_axis is not None and opset_version >= AXIS_OPSET: + ret = g.op('DequantizeLinear', x, input_scale, input_zero_point, axis_i=input_axis) + else: + ret = g.op('DequantizeLinear', x, input_scale, input_zero_point) + return ret + + @staticmethod + def forward(ctx, int_x, input_scale, input_zero_point, input_axis): + return int_x.float() + + +class IntClipFn(Function): + + @staticmethod + def symbolic(g, int_x, min_int_val, max_int_val): + ret = g.op('Clip', int_x, min_int_val, max_int_val) + return ret + + @staticmethod + def forward(ctx, int_x, min_int_val, max_int_val): + return int_x + + +class QuantizeLinearFn(Function): + + @staticmethod + def symbolic(g, x, output_scale, ouput_zero_point, output_dtype, output_axis): + opset_version = onnx_export_opset() + + if output_axis is not None and opset_version < AXIS_OPSET: + raise RuntimeError('ONNX Opset 13 is required for per-channel quantization') + elif output_axis is not None and opset_version >= AXIS_OPSET: + ret = g.op('QuantizeLinear', x, output_scale, ouput_zero_point, axis_i=output_axis) + else: + ret = g.op('QuantizeLinear', x, output_scale, ouput_zero_point) + return ret + + @staticmethod + def forward(ctx, x, output_scale, ouput_zero_point, output_dtype, output_axis): + return x.type(output_dtype) + +class BrevitasQuantLSTMCellFn(Function): + + + @staticmethod + def symbolic( + g, # args and kwargs passed from _QuantLSTMLayer + quant_input, + quant_hidden_state, + quant_cell_state, + quant_weight_ii, + quant_weight_if, + quant_weight_ic, + quant_weight_io, + quant_weight_hi, + quant_weight_hf, + quant_weight_hc, + quant_weight_ho, + quant_bias_input, + quant_bias_forget, + quant_bias_cell, + quant_bias_output, # Symbolic kwargs passed from BrevitasQuantLSTMLayerHandler + batch_first, + reverse_input, + cifg, # Output quant + output_scale, + output_zero_point, + output_bit_width, + output_narrow_range, + output_signed, + output_rounding_mode, # Cell state quant + cell_state_scale, + cell_state_zero_point, + cell_state_bit_width, + cell_state_narrow_range, + cell_state_signed, + 
cell_state_rounding_mode, # Input gate accumulator quant + input_acc_scale, + input_acc_zero_point, + input_acc_bit_width, + input_acc_narrow_range, + input_acc_signed, + input_acc_rounding_mode, # Forget gate accumulator quant + forget_acc_scale, + forget_acc_zero_point, + forget_acc_bit_width, + forget_acc_narrow_range, + forget_acc_signed, + forget_acc_rounding_mode, # Cell gate accumulator quant + cell_acc_scale, + cell_acc_zero_point, + cell_acc_bit_width, + cell_acc_narrow_range, + cell_acc_signed, + cell_acc_rounding_mode, # Output gate accumulator quant + output_acc_scale, + output_acc_zero_point, + output_acc_bit_width, + output_acc_narrow_range, + output_acc_signed, + output_acc_rounding_mode, # Input gate sigmoid quant + input_sigmoid_scale, + input_sigmoid_zero_point, + input_sigmoid_bit_width, + input_sigmoid_narrow_range, + input_sigmoid_signed, + input_sigmoid_rounding_mode, # Forget gate sigmoid quant + forget_sigmoid_scale, + forget_sigmoid_zero_point, + forget_sigmoid_bit_width, + forget_sigmoid_narrow_range, + forget_sigmoid_signed, + forget_sigmoid_rounding_mode, # Cell gate tanh quant + cell_tanh_scale, + cell_tanh_zero_point, + cell_tanh_bit_width, + cell_tanh_narrow_range, + cell_tanh_signed, + cell_tanh_rounding_mode, # Output gate sigmoid quant + output_sigmoid_scale, + output_sigmoid_zero_point, + output_sigmoid_bit_width, + output_sigmoid_narrow_range, + output_sigmoid_signed, + output_sigmoid_rounding_mode, # Hidden state tanh quant + hidden_state_tanh_scale, + hidden_state_tanh_zero_point, + hidden_state_tanh_bit_width, + hidden_state_tanh_narrow_range, + hidden_state_tanh_signed, + hidden_state_tanh_rounding_mode): + return g.op( + f'{DOMAIN_STRING}::QuantLSTMCell', # Tensors + ## Input values + quant_input, + quant_hidden_state, + quant_cell_state, + quant_weight_ii, + quant_weight_if, + quant_weight_ic, + quant_weight_io, + quant_weight_hi, + quant_weight_hf, + quant_weight_hc, + quant_weight_ho, + quant_bias_input, + quant_bias_forget, + quant_bias_cell, + quant_bias_output, ## Output quant + output_scale, + output_zero_point, + output_bit_width, ## Cell state quant + cell_state_scale, + cell_state_zero_point, + cell_state_bit_width, ## Input gate accumulator quant + input_acc_scale, + input_acc_zero_point, + input_acc_bit_width, ## Forget gate accumulator quant + forget_acc_scale, + forget_acc_zero_point, + forget_acc_bit_width, ## Cell gate accumulator quant + cell_acc_scale, + cell_acc_zero_point, + cell_acc_bit_width, ## Output gate accumulator quant + output_acc_scale, + output_acc_zero_point, + output_acc_bit_width, ## Input gate sigmoid quant + input_sigmoid_scale, + input_sigmoid_zero_point, + input_sigmoid_bit_width, ## Forget gate sigmoid quant + forget_sigmoid_scale, + forget_sigmoid_zero_point, + forget_sigmoid_bit_width, ## Cell gate tanh quant + cell_tanh_scale, + cell_tanh_zero_point, + cell_tanh_bit_width, ## Output gate sigmoid quant + output_sigmoid_scale, + output_sigmoid_zero_point, + output_sigmoid_bit_width, ## Hidden state tanh quant + hidden_state_tanh_scale, + hidden_state_tanh_zero_point, + hidden_state_tanh_bit_width, + # Attributes + batch_first_i=batch_first, + reverse_input_i=reverse_input, + cifg_i=cifg, + output_narrow_i=output_narrow_range, + output_signed_i=output_signed, + output_rounding_mode_s=output_rounding_mode, + cell_state_narrow_i=cell_state_narrow_range, + cell_state_signed_i=cell_state_signed, + cell_state_rounding_mode_s=cell_state_rounding_mode, + input_acc_narrow_i=input_acc_narrow_range, + 
input_acc_signed_i=input_acc_signed, + input_acc_rounding_mode_s=input_acc_rounding_mode, + forget_acc_narrow_i=forget_acc_narrow_range, + forget_acc_signed_i=forget_acc_signed, + forget_acc_rounding_mode_s=forget_acc_rounding_mode, + cell_acc_narrow_i=cell_acc_narrow_range, + cell_acc_signed_i=cell_acc_signed, + cell_acc_rounding_mode_s=cell_acc_rounding_mode, + output_acc_narrow_i=output_acc_narrow_range, + output_acc_signed_i=output_acc_signed, + output_acc_rounding_mode_s=output_acc_rounding_mode, + input_sigmoid_narrow_i=input_sigmoid_narrow_range, + input_sigmoid_signed_i=input_sigmoid_signed, + input_sigmoid_rounding_mode_s=input_sigmoid_rounding_mode, + forget_sigmoid_narrow_i=forget_sigmoid_narrow_range, + forget_sigmoid_signed_i=forget_sigmoid_signed, + forget_sigmoid_rounding_mode_s=forget_sigmoid_rounding_mode, + cell_tanh_narrow_i=cell_tanh_narrow_range, + cell_tanh_signed_i=cell_tanh_signed, + cell_tanh_rounding_mode_s=cell_tanh_rounding_mode, + output_sigmoid_narrow_range_i=output_sigmoid_narrow_range, + output_sigmoid_signed_i=output_sigmoid_signed, + output_sigmoid_rounding_mode_s=output_sigmoid_rounding_mode, + hidden_state_tanh_narrow_i=hidden_state_tanh_narrow_range, + hidden_state_tanh_signed_i=hidden_state_tanh_signed, + hidden_state_tanh_rounding_mode_s=hidden_state_tanh_rounding_mode, + # PyTorch requires to specify the number of outputs manually + outputs=3) + + + @staticmethod + def forward( + ctx, # args and kwargs passed from _QuantLSTMLayer + quant_input, + quant_hidden_state, + quant_cell_state, + quant_weight_ii, + quant_weight_if, + quant_weight_ic, + quant_weight_io, + quant_weight_hi, + quant_weight_hf, + quant_weight_hc, + quant_weight_ho, + quant_bias_input, + quant_bias_forget, + quant_bias_cell, + quant_bias_output, # Symbolic kwargs passed from BrevitasQuantLSTMLayerHandler + batch_first, + reverse_input, + cifg, # Output quant + output_scale, + output_zero_point, + output_bit_width, + output_narrow_range, + output_signed, + output_rounding_mode, # Cell state quant + cell_state_scale, + cell_state_zero_point, + cell_state_bit_width, + cell_state_narrow_range, + cell_state_signed, + cell_state_rounding_mode, # Input gate accumulator quant + input_acc_scale, + input_acc_zero_point, + input_acc_bit_width, + input_acc_narrow_range, + input_acc_signed, + input_acc_rounding_mode, # Forget gate accumulator quant + forget_acc_scale, + forget_acc_zero_point, + forget_acc_bit_width, + forget_acc_narrow_range, + forget_acc_signed, + forget_acc_rounding_mode, # Cell gate accumulator quant + cell_acc_scale, + cell_acc_zero_point, + cell_acc_bit_width, + cell_acc_narrow_range, + cell_acc_signed, + cell_acc_rounding_mode, # Output gate accumulator quant + output_acc_scale, + output_acc_zero_point, + output_acc_bit_width, + output_acc_narrow_range, + output_acc_signed, + output_acc_rounding_mode, # Input gate sigmoid quant + input_sigmoid_scale, + input_sigmoid_zero_point, + input_sigmoid_bit_width, + input_sigmoid_narrow_range, + input_sigmoid_signed, + input_sigmoid_rounding_mode, # Forget gate sigmoid quant + forget_sigmoid_scale, + forget_sigmoid_zero_point, + forget_sigmoid_bit_width, + forget_sigmoid_narrow_range, + forget_sigmoid_signed, + forget_sigmoid_rounding_mode, # Cell gate tanh quant + cell_tanh_scale, + cell_tanh_zero_point, + cell_tanh_bit_width, + cell_tanh_narrow_range, + cell_tanh_signed, + cell_tanh_rounding_mode, # Output gate sigmoid quant + output_sigmoid_scale, + output_sigmoid_zero_point, + output_sigmoid_bit_width, + 
output_sigmoid_narrow_range, + output_sigmoid_signed, + output_sigmoid_rounding_mode, # Hidden state tanh quant + hidden_state_tanh_scale, + hidden_state_tanh_zero_point, + hidden_state_tanh_bit_width, + hidden_state_tanh_narrow_range, + hidden_state_tanh_signed, + hidden_state_tanh_rounding_mode): + # Tp simplify things, here we are returning the outputs + # as if they were already concatenated. Scale/zp/bw are avoided too. + # This preserves output shapes but not values. + # See _QuantLSTMCell for the actual implementation. + quant_outputs = torch.zeros( + quant_input.size(0), + quant_input.size(1), + quant_hidden_state.size(1), + device=quant_hidden_state.device) + return quant_outputs, quant_hidden_state, quant_cell_state diff --git a/notebooks/4_quant_lstm_helper/handler.py b/notebooks/4_quant_lstm_helper/handler.py new file mode 100644 index 00000000..948eb647 --- /dev/null +++ b/notebooks/4_quant_lstm_helper/handler.py @@ -0,0 +1,140 @@ +# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +from abc import ABC +from copy import copy + +import torch +from torch import Tensor + +from brevitas.export.common.handler.base import QuantAxisMixin +from brevitas.export.common.handler.qcdq import DQMixin +from brevitas.export.common.handler.qcdq import QCDQActQuantProxyHandlerMixin +from brevitas.export.common.handler.qcdq import QCDQBiasQuantProxyHandlerMixin +from brevitas.export.common.handler.qcdq import QCDQDecoupledWeightQuantProxyHandlerMixin +from brevitas.export.common.handler.qcdq import QCDQMixin +from brevitas.export.common.handler.qcdq import QCDQTruncQuantProxyHandlerMixin +from brevitas.export.common.handler.qcdq import QCDQWeightQuantProxyHandlerMixin +from brevitas.export.common.handler.qcdq import ZeroPointHandlerMixin +from brevitas.export.onnx.handler import ONNXBaseHandler +from brevitas.export.onnx.handler import QuantLSTMLayerHandler + +from ..function import DequantizeLinearFn +from ..function import IntClipFn +from ..function import QuantizeLinearFn +from ..function import BrevitasQuantLSTMCellFn + + +class StdDQONNXMixin(DQMixin, ABC): + + def dequantize_fn(self, x, scale, zero_point, axis): + return DequantizeLinearFn.apply(x, scale, zero_point, axis) + + @property + def flatten_dequantize_params(self): + return True + + @property + def itemize_quantize_scalar_params(self): + return False + + +class StdQCDQONNXMixin(QCDQMixin, StdDQONNXMixin, ABC): + + @property + def clip_over_integers(self): + return True + + @classmethod + def int8_dtype(cls): + return torch.int8 + + @classmethod + def uint8_dtype(cls): + return torch.uint8 + + @classmethod + def int32_dtype(cls): + return torch.int32 + + def validate(self, module): + self.validate_8b_bit_width(module.bit_width(), le_then=True) + assert module.bit_width() > 1., 'Binary quant not supported' + assert module.rounding_mode.upper() == 'ROUND', 'Only round to nearest even supported' + + def quantize_fn(self, x, scale, zero_point, dtype, axis): + return QuantizeLinearFn.apply(x, scale, zero_point, dtype, axis) + + def clip_fn(self, x, min_val, max_val): + return IntClipFn.apply(x, min_val, max_val) + + +class StdQCDQONNXWeightQuantProxyHandler(StdQCDQONNXMixin, + QCDQWeightQuantProxyHandlerMixin, + ONNXBaseHandler): + pass + + +class StdQCDQONNXDecoupledWeightQuantProxyHandler(StdQCDQONNXMixin, + QCDQDecoupledWeightQuantProxyHandlerMixin, + ONNXBaseHandler): + pass + + +class StdQCDQONNXActQuantProxyHandler(StdQCDQONNXMixin, + QCDQActQuantProxyHandlerMixin, + 
ONNXBaseHandler): + pass + + +class StdQCDQONNXBiasQuantProxyHandler(StdDQONNXMixin, + QCDQBiasQuantProxyHandlerMixin, + ONNXBaseHandler): + pass + + +class StdQCDQONNXTruncQuantProxyHandler(StdQCDQONNXMixin, + QCDQTruncQuantProxyHandlerMixin, + ONNXBaseHandler): + pass + + +class StdQCDQONNXQuantLSTMLayerHandler(QuantLSTMLayerHandler): + + def quantized_cell_symbolic_execution( + self, + quant_input, + quant_hidden_state, + quant_cell_state, + quant_weight_ii, + quant_weight_if, + quant_weight_ic, + quant_weight_io, + quant_weight_hi, + quant_weight_hf, + quant_weight_hc, + quant_weight_ho, + quant_bias_input, + quant_bias_forget, + quant_bias_cell, + quant_bias_output): + return BrevitasQuantLSTMCellFn.apply( + quant_input, + quant_hidden_state, + quant_cell_state, + quant_weight_ii, + quant_weight_if, + quant_weight_ic, + quant_weight_io, + quant_weight_hi, + quant_weight_hf, + quant_weight_hc, + quant_weight_ho, + quant_bias_input, + quant_bias_forget, + quant_bias_cell, + quant_bias_output, + *self.symbolic_kwargs.values()) + # raise RuntimeError( + # "Quantized LSTM cell is not supported for ONNX QCDQ " + # "(weights only quantization is). Use export_qonnx.") From 7ebbeac73c5d253a4b76639945554aacd241d13f Mon Sep 17 00:00:00 2001 From: Christoph Berganski Date: Thu, 25 Apr 2024 16:19:13 +0200 Subject: [PATCH 08/28] Add unit test for SortCommutativeInputsInitializerLast transformation --- ...ort_commutative_inputs_initializer_last.py | 78 +++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 tests/transformation/test_sort_commutative_inputs_initializer_last.py diff --git a/tests/transformation/test_sort_commutative_inputs_initializer_last.py b/tests/transformation/test_sort_commutative_inputs_initializer_last.py new file mode 100644 index 00000000..134cb89e --- /dev/null +++ b/tests/transformation/test_sort_commutative_inputs_initializer_last.py @@ -0,0 +1,78 @@ +# Set pytest parameters +import pytest +# Numpy for handling simulation of tensor operations +import numpy as np +# Helper for creating ONNX nodes +from onnx import TensorProto +from onnx import helper as oh +# QONNX wrapper of ONNX model graphs +from qonnx.core.modelwrapper import ModelWrapper +# QONNX utility for creating models from ONNX graphs +from qonnx.util.basic import qonnx_make_model +# Execute QONNX model graphs +from qonnx.core.onnx_exec import execute_onnx +# Graph transformation to be tested: Sorts the input list of commutative +# operations to have all dynamic inputs first followed by all initializer inputs +from qonnx.transformation.general import SortCommutativeInputsInitializerLast + + +# Specify how many inputs the test should cover +@pytest.mark.parametrize("num_inputs", [4, 5, 6]) +# Specify which inputs should be turned into initializers +@pytest.mark.parametrize( + "initializers", [[], [0], [1], [0, 1], [0, 3], [0, 1, 2, 3]] +) +# Tests the SortCommutativeInputsInitializerLast transformation +def test_sort_commutative_inputs_initializer_last(num_inputs, initializers): + # Generate the input tensor names + inputs = [f"in{i}" for i in range(num_inputs)] + # We will use the Sum ONNX operation to test this behavior, as it allows for + # arbitrary many inputs + node = oh.make_node( + op_type="Sum", inputs=inputs, outputs=["out"], name="Sum" + ) + # Create value infos for all input and the output tensor + inputs = [ + oh.make_tensor_value_info(i, TensorProto.FLOAT, (16,)) for i in inputs + ] + out = oh.make_tensor_value_info("out", TensorProto.FLOAT, (16,)) + # Make a graph comprising the Sum 
node and value infos for all inputs and + # the output + graph = oh.make_graph([node], inputs=inputs, outputs=[out], name="Sum") + # Wrap the graph in an QONNX model wrapper + model = ModelWrapper(qonnx_make_model(graph, producer_name="qonnx-tests")) + # Prepare the execution context + context = { + f"in{i}": np.random.rand(16) for i in range(num_inputs) + } + # Make sure all inputs are of type float32 + context = {key: value.astype(np.float32) for key, value in context.items()} + # Turn selected inputs into initializers + for i in initializers: + model.set_initializer(f"in{i}", context[f"in{i}"]) + + # Execute the ONNX model before transforming + out_expected = execute_onnx(model, context)["out"] + # Apply the transformation to be tested + # Note: No cleanup, as the tested transformation is part of the cleanup, and + # we want to test this in isolation + model = model.transform( + SortCommutativeInputsInitializerLast(), cleanup=False + ) + # Execute the ONNX model after transforming + out_produced = execute_onnx(model, context)["out"] + + # Start with no initializer input seen so far + seen_initializer = False + # Verify that no "dynamic" input follows an initializer input + for i in model.graph.node[0].input: + # Keep track of when an initializer has been seen + if model.get_initializer(i) is not None: + seen_initializer = True + # If there has already been an initializer, this input must be an + # initializer as well + assert not seen_initializer or model.get_initializer(i) is not None, \ + "Non-initializer input following initializer after sorting" + + # Outputs before and after must match + assert np.allclose(out_produced, out_expected) From 38df9fbe3b3dce582eae26b02d1a9b5ee91ebbad Mon Sep 17 00:00:00 2001 From: Christoph Berganski Date: Thu, 25 Apr 2024 16:26:22 +0200 Subject: [PATCH 09/28] Address some linting issues --- src/qonnx/transformation/general.py | 1 - ...ort_commutative_inputs_initializer_last.py | 27 ++++++++++++++----- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/src/qonnx/transformation/general.py b/src/qonnx/transformation/general.py index b5ed0fca..d69cee5a 100644 --- a/src/qonnx/transformation/general.py +++ b/src/qonnx/transformation/general.py @@ -405,7 +405,6 @@ def apply(self, model): # noqa # Length of sorted and original input list must match assert len(inputs) == len(node.input) # Reassigned inputs from sorted categories - # Note: ONNX does not allow direct assignment to node.input for i, name in enumerate(inputs): # The graph has been modified if any input is reordered if node.input[i] != name: diff --git a/tests/transformation/test_sort_commutative_inputs_initializer_last.py b/tests/transformation/test_sort_commutative_inputs_initializer_last.py index 134cb89e..1cd1eb72 100644 --- a/tests/transformation/test_sort_commutative_inputs_initializer_last.py +++ b/tests/transformation/test_sort_commutative_inputs_initializer_last.py @@ -1,26 +1,34 @@ # Set pytest parameters import pytest + # Numpy for handling simulation of tensor operations import numpy as np + # Helper for creating ONNX nodes from onnx import TensorProto from onnx import helper as oh + # QONNX wrapper of ONNX model graphs from qonnx.core.modelwrapper import ModelWrapper -# QONNX utility for creating models from ONNX graphs -from qonnx.util.basic import qonnx_make_model + # Execute QONNX model graphs from qonnx.core.onnx_exec import execute_onnx + # Graph transformation to be tested: Sorts the input list of commutative # operations to have all dynamic inputs first followed by 
all initializer inputs from qonnx.transformation.general import SortCommutativeInputsInitializerLast +# QONNX utility for creating models from ONNX graphs +from qonnx.util.basic import qonnx_make_model + # Specify how many inputs the test should cover @pytest.mark.parametrize("num_inputs", [4, 5, 6]) # Specify which inputs should be turned into initializers @pytest.mark.parametrize( + # fmt: off "initializers", [[], [0], [1], [0, 1], [0, 3], [0, 1, 2, 3]] + # fmt: on ) # Tests the SortCommutativeInputsInitializerLast transformation def test_sort_commutative_inputs_initializer_last(num_inputs, initializers): @@ -29,11 +37,15 @@ def test_sort_commutative_inputs_initializer_last(num_inputs, initializers): # We will use the Sum ONNX operation to test this behavior, as it allows for # arbitrary many inputs node = oh.make_node( + # fmt: off op_type="Sum", inputs=inputs, outputs=["out"], name="Sum" + # fmt: on ) # Create value infos for all input and the output tensor inputs = [ + # fmt: off oh.make_tensor_value_info(i, TensorProto.FLOAT, (16,)) for i in inputs + # fmt: on ] out = oh.make_tensor_value_info("out", TensorProto.FLOAT, (16,)) # Make a graph comprising the Sum node and value infos for all inputs and @@ -42,9 +54,7 @@ def test_sort_commutative_inputs_initializer_last(num_inputs, initializers): # Wrap the graph in an QONNX model wrapper model = ModelWrapper(qonnx_make_model(graph, producer_name="qonnx-tests")) # Prepare the execution context - context = { - f"in{i}": np.random.rand(16) for i in range(num_inputs) - } + context = {f"in{i}": np.random.rand(16) for i in range(num_inputs)} # Make sure all inputs are of type float32 context = {key: value.astype(np.float32) for key, value in context.items()} # Turn selected inputs into initializers @@ -57,7 +67,9 @@ def test_sort_commutative_inputs_initializer_last(num_inputs, initializers): # Note: No cleanup, as the tested transformation is part of the cleanup, and # we want to test this in isolation model = model.transform( + # fmt: off SortCommutativeInputsInitializerLast(), cleanup=False + # fmt: on ) # Execute the ONNX model after transforming out_produced = execute_onnx(model, context)["out"] @@ -71,8 +83,9 @@ def test_sort_commutative_inputs_initializer_last(num_inputs, initializers): seen_initializer = True # If there has already been an initializer, this input must be an # initializer as well - assert not seen_initializer or model.get_initializer(i) is not None, \ - "Non-initializer input following initializer after sorting" + assert ( + not seen_initializer or model.get_initializer(i) is not None + ), "Non-initializer input following initializer after sorting" # Outputs before and after must match assert np.allclose(out_produced, out_expected) From 57d0d9d6a8b7f61e68fea32581e0aec3031a3293 Mon Sep 17 00:00:00 2001 From: Tim Paine <3105306+timkpaine@users.noreply.github.com> Date: Sat, 11 May 2024 16:00:04 -0400 Subject: [PATCH 10/28] Remove some commented debug code --- src/qonnx/core/onnx_exec.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/qonnx/core/onnx_exec.py b/src/qonnx/core/onnx_exec.py index a5be9dee..a8f4774c 100644 --- a/src/qonnx/core/onnx_exec.py +++ b/src/qonnx/core/onnx_exec.py @@ -208,7 +208,6 @@ def execute_onnx_and_make_model(model, input_dict): new_model.set_initializer(i, execution_context[i]) for vi in new_model.graph.value_info: new_model.graph.output.append(vi) - # import pdb; pdb.set_trace() return new_model From 1b2774c0635476add3ce05bc50f4849480282bdb Mon Sep 17 00:00:00 2001 From: makoeppel Date: 
Mon, 17 Jun 2024 09:28:36 +0200 Subject: [PATCH 11/28] refactor LowerConvsToMatMul class, increase rtol in test_conv_lowering_convmnist() --- README.md | 1 + docs/index.rst | 3 + .../transformation/lower_convs_to_matmul.py | 292 ++++++++---------- tests/transformation/test_conv_lowering.py | 2 +- 4 files changed, 139 insertions(+), 159 deletions(-) diff --git a/README.md b/README.md index dd9b6c66..69c28b3c 100644 --- a/README.md +++ b/README.md @@ -125,6 +125,7 @@ source venv/bin/activate pip install -e .[qkeras,testing] ``` +### Test suite Run entire test suite, parallelized across CPU cores: ``` pytest -n auto --verbose diff --git a/docs/index.rst b/docs/index.rst index f07ba086..53b9c159 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -63,6 +63,9 @@ Install in editable mode in a venv: pip install -e .[testing, docs, notebooks] +Test suite +++++++++++ + Run entire test suite, parallelized across CPU cores: :: diff --git a/src/qonnx/transformation/lower_convs_to_matmul.py b/src/qonnx/transformation/lower_convs_to_matmul.py index bf95d537..c5964cf4 100644 --- a/src/qonnx/transformation/lower_convs_to_matmul.py +++ b/src/qonnx/transformation/lower_convs_to_matmul.py @@ -42,167 +42,143 @@ class LowerConvsToMatMul(Transformation): def apply(self, model): model = model.transform(ExtractBiasFromConv()) graph = model.graph - node_ind = 0 graph_modified = False - for n in graph.node: - node_ind += 1 - if n.op_type == "Conv": - if len(n.input) == 3: - warnings.warn("Found Conv node with bias, skipping") - continue - cnv_input = n.input[0] - cnv_output = n.output[0] - idt = model.get_tensor_datatype(cnv_input) - odt = model.get_tensor_datatype(cnv_output) - # extract conv parameters - k = get_by_name(n.attribute, "kernel_shape").ints - k_h = k[0] - k_w = k[1] - stride_h = get_by_name(n.attribute, "strides").ints[0] - stride_w = get_by_name(n.attribute, "strides").ints[1] - group = get_by_name(n.attribute, "group").i - weight_name = n.input[1] - W_conv = model.get_initializer(weight_name) - ifm_ch = model.get_tensor_shape(n.input[0])[1] # assume NCHW - ofm_ch = model.get_tensor_shape(n.output[0])[1] # assume NCHW - ifm_dim_h = model.get_tensor_shape(n.input[0])[2] # assume NCHW - ifm_dim_w = model.get_tensor_shape(n.input[0])[3] - ofm_dim_h = model.get_tensor_shape(n.output[0])[2] # assume NCHW - ofm_dim_w = model.get_tensor_shape(n.output[0])[3] - dilation_attr = get_by_name(n.attribute, "dilations") - if dilation_attr is not None: - dilation = dilation_attr.ints - else: - dilation = [1, 1] # default value - # handle both auto_pad and explicit padding - auto_pad = get_by_name(n.attribute, "auto_pad") - if auto_pad is not None: - # find equivalent specified padding - auto_pad = auto_pad.s.decode("utf-8") - if auto_pad == "NOTSET": - # use specified padding - pad = get_by_name(n.attribute, "pads").ints - else: - pad = auto_pad_to_explicit_padding( - auto_pad, - ifm_dim_h, - ifm_dim_w, - k_h, - k_w, - stride_h, - stride_w, - len(model.get_tensor_shape(n.input[0])) - 2, - ) - else: - # use specified padding - pad = get_by_name(n.attribute, "pads").ints - - # If len(pad) == 2, assume no padding for other dimension - if len(pad) == 2: # only one dimension should be padded - assert ifm_dim_h == 1 or ifm_dim_w == 1, "Padding is assumed to be 1D, image is 2D" - - # if depthwise conv create sparse matrix and variable "dw" - # to store as attribute in Im2Col that indicates that the created + for node_ind, node in enumerate(graph.node, start=1): + if node.op_type != "Conv": + continue + + if 
len(node.input) == 3: + warnings.warn("Found Conv node with bias, skipping") + continue + + # extract parameters of node + (cnv_input, cnv_output, cnv_input_datatype, cnv_output_datatype, + k_h, k_w, stride_h, stride_w, group, weight_name, W_conv, ifm_ch, + ofm_ch, ifm_dim_h, ifm_dim_w, ofm_dim_h, ofm_dim_w, dilation, pad) =\ + self.extract_conv_params(model, node) + + # if depthwise conv create sparse matrix and variable "dw" + # to store as attribute in Im2Col that indicates that the created + # Im2Col node belongs to a depthwise convolution + dw = False + if group == ifm_ch and ofm_ch == ifm_ch: + W_sparse = np.zeros((ofm_ch, ifm_ch, k_h, k_w)) # (OFM, IFM, k_H, k_W) + for ch in range(ifm_ch): + W_sparse[ch][ch] = W_conv[ch][0] # W_conv = [OFM, IFM, k_H, k_W] + W_conv = W_sparse.astype(np.float32) + # we need to store information of the + # sparsity of the weight matrix. For this + # we use the sparsity annotation of the + # weight tensor + sparsity = {"dw": {"kernel_shape": [k_h, k_w]}} + model.set_tensor_sparsity(weight_name, sparsity) + # additionally create variable "dw" to store + # as attribute in Im2Col that indicates that the created # Im2Col node belongs to a depthwise convolution - dw = False - if group == ifm_ch and ofm_ch == ifm_ch: - W_sparse = np.zeros((ofm_ch, ifm_ch, k_h, k_w)) # (OFM, IFM, k_H, k_W) - for ch in range(ifm_ch): - W_sparse[ch][ch] = W_conv[ch][0] # W_conv = [OFM, IFM, k_H, k_W] - W_conv = W_sparse.astype(np.float32) - # we need to store information of the - # sparsity of the weight matrix. For this - # we use the sparsity annotation of the - # weight tensor - sparsity = {"dw": {"kernel_shape": [k_h, k_w]}} - model.set_tensor_sparsity(weight_name, sparsity) - # additionally create variable "dw" to store - # as attribute in Im2Col that indicates that the created - # Im2Col node belongs to a depthwise convolution - dw = True - - # reuse conv weights for new matmul weights - # conv weights are [OFM][IFM][k][k] - # first convert to [OFM][k][k][IFM] (to remain compatible with - # finn-hlslib and how it does im2col/sliding window) - W_matmul = W_conv.transpose(0, 2, 3, 1) # W_conv = [OFM, IFM, k_H, k_W] - # reshape into [OFM][k*k*IFM] matrix - W_matmul = W_matmul.reshape(ofm_ch, ifm_ch * k_h * k_w) - # transpose to get ONNX-compatible [k*k*IFM][OFM] matrix - W_matmul = W_matmul.T - model.set_initializer(weight_name, W_matmul) - - # create new intermediate values - inp_trans_out = helper.make_tensor_value_info( - model.make_new_valueinfo_name(), - TensorProto.FLOAT, - (1, ifm_dim_h, ifm_dim_w, ifm_ch), # NHWC + dw = True + + # reuse conv weights for new matmul weights + # conv weights are [OFM][IFM][k][k] + # first convert to [OFM][k_h][k_w][IFM] (to remain compatible with + # finn-hlslib and how it does im2col/sliding window) + W_matmul = W_conv.transpose(0, 2, 3, 1) # W_conv = [OFM, IFM, k_H, k_W] + # reshape into [OFM][k_h*k_w*IFM] matrix + W_matmul = W_matmul.reshape(ofm_ch, ifm_ch * k_h * k_w) + # transpose to get ONNX-compatible [k_h*k_w*IFM][OFM] matrix + W_matmul = W_matmul.T + model.set_initializer(weight_name, W_matmul) + + # create new intermediate values + inp_trans_out = helper.make_tensor_value_info( + model.make_new_valueinfo_name(), + TensorProto.FLOAT, + (1, ifm_dim_h, ifm_dim_w, ifm_ch), # NHWC + ) + graph.value_info.append(inp_trans_out) + inp_trans_out = inp_trans_out.name + model.set_tensor_datatype(inp_trans_out, cnv_input_datatype) + + # k_h=k_w==1: pointwise convolution, thus no im2col needed + need_im2col = any(p != 0 for p in pad) or k_h 
!= 1 or k_w != 1 or stride_h != 1 or stride_w != 1 + + # create new intermediate values + matmul_out = helper.make_tensor_value_info( + model.make_new_valueinfo_name(), TensorProto.FLOAT, (1, ofm_dim_h, ofm_dim_w, ofm_ch) + ) + graph.value_info.append(matmul_out) + matmul_out = matmul_out.name + model.set_tensor_datatype(matmul_out, cnv_output_datatype) + + # create new nodes + # NCHW -> NHWC + inp_trans_node = helper.make_node("Transpose", [cnv_input], [inp_trans_out], perm=[0, 2, 3, 1]) + nodes_to_insert = [inp_trans_node] + + if need_im2col: + im2col_out = helper.make_tensor_value_info( + model.make_new_valueinfo_name(), TensorProto.FLOAT, (1, ofm_dim_h, ofm_dim_w, ifm_ch * k_h * k_w) ) - graph.value_info.append(inp_trans_out) - inp_trans_out = inp_trans_out.name - model.set_tensor_datatype(inp_trans_out, idt) - - need_im2col = True - if all(p == 0 for p in pad): - padding = 0 - - # k_h=k_w==1: pointwise convolution, thus no im2col needed - if k_h == 1 and k_w == 1 and padding == 0 and stride_h == 1 and stride_w == 1: - need_im2col = False - - if need_im2col: - im2col_out = helper.make_tensor_value_info( - model.make_new_valueinfo_name(), - TensorProto.FLOAT, - (1, ofm_dim_h, ofm_dim_w, ifm_ch * k_h * k_w), - ) - graph.value_info.append(im2col_out) - im2col_out = im2col_out.name - model.set_tensor_datatype(im2col_out, idt) - - matmul_out = helper.make_tensor_value_info( - model.make_new_valueinfo_name(), - TensorProto.FLOAT, - (1, ofm_dim_h, ofm_dim_w, ofm_ch), + graph.value_info.append(im2col_out) + im2col_out = im2col_out.name + model.set_tensor_datatype(im2col_out, cnv_input_datatype) + im2col_node = helper.make_node( + "Im2Col", [inp_trans_out], [im2col_out], domain="qonnx.custom_op.general", + stride=[stride_h, stride_w], kernel_size=[k_h, k_w], pad_amount=pad, + input_shape="(1,{},{},{})".format(ifm_dim_h, ifm_dim_w, ifm_ch), depthwise=dw, dilations=dilation ) - graph.value_info.append(matmul_out) - matmul_out = matmul_out.name - model.set_tensor_datatype(matmul_out, odt) - - # create new nodes - # NCHW -> NHWC - inp_trans_node = helper.make_node("Transpose", [cnv_input], [inp_trans_out], perm=[0, 2, 3, 1]) - # lower input tensor - matmul_input = inp_trans_out - if need_im2col: - matmul_input = im2col_out - im2col_node = helper.make_node( - "Im2Col", - [inp_trans_out], - [im2col_out], - domain="qonnx.custom_op.general", - stride=[stride_h, stride_w], - kernel_size=[k_h, k_w], - pad_amount=pad, - input_shape="(1,{},{},{})".format(ifm_dim_h, ifm_dim_w, ifm_ch), - depthwise=dw, - dilations=dilation, - ) - - # do matmul - matmul_node = helper.make_node("MatMul", [matmul_input, weight_name], [matmul_out]) - # NHWC -> NCHW - out_trans_node = helper.make_node("Transpose", [matmul_out], [cnv_output], perm=[0, 3, 1, 2]) - # insert nodes where the conv is to preserve topological ordering - graph.node.insert(node_ind, inp_trans_node) - if need_im2col: - graph.node.insert(node_ind + 1, im2col_node) - graph.node.insert(node_ind + 2, matmul_node) - graph.node.insert(node_ind + 3, out_trans_node) - else: - graph.node.insert(node_ind + 1, matmul_node) - graph.node.insert(node_ind + 2, out_trans_node) - # remove old nodes - graph.node.remove(n) + nodes_to_insert.append(im2col_node) + + matmul_input = im2col_out if need_im2col else inp_trans_out + # do matmul + matmul_node = helper.make_node("MatMul", [matmul_input, weight_name], [matmul_out]) + # NHWC -> NCHW + out_trans_node = helper.make_node("Transpose", [matmul_out], [cnv_output], perm=[0, 3, 1, 2]) + + nodes_to_insert.extend([matmul_node, 
out_trans_node]) + + # insert nodes where the conv is to preserve topological ordering + for i, insert_node in enumerate(nodes_to_insert): + graph.node.insert(node_ind + i, insert_node) + graph.node.remove(node) return (model, graph_modified) + + def extract_conv_params(self, model, node): + + cnv_input = node.input[0] + cnv_output = node.output[0] + cnv_input_datatype = model.get_tensor_datatype(cnv_input) + cnv_output_datatype = model.get_tensor_datatype(cnv_output) + k_h = get_by_name(node.attribute, "kernel_shape").ints[0] + k_w = get_by_name(node.attribute, "kernel_shape").ints[1] + stride_h = get_by_name(node.attribute, "strides").ints[0] + stride_w = get_by_name(node.attribute, "strides").ints[1] + group = get_by_name(node.attribute, "group").i + weight_name = node.input[1] + W_conv = model.get_initializer(weight_name) + ifm_ch = model.get_tensor_shape(cnv_input)[1] # assume NCHW + ofm_ch = model.get_tensor_shape(cnv_output)[1] # assume NCHW + ifm_dim_h = model.get_tensor_shape(cnv_input)[2] # assume NCHW + ifm_dim_w = model.get_tensor_shape(cnv_input)[3] # assume NCHW + ofm_dim_h = model.get_tensor_shape(cnv_output)[2] # assume NCHW + ofm_dim_w = model.get_tensor_shape(cnv_output)[3] # assume NCHW + dilation_attr = get_by_name(node.attribute, "dilations") + dilation = dilation_attr.ints if dilation_attr is not None else [1, 1] # default value + auto_pad = get_by_name(node.attribute, "auto_pad") + if auto_pad is not None: + auto_pad = auto_pad.s.decode("utf-8") + if auto_pad == "NOTSET": + pad = get_by_name(node.attribute, "pads").ints + else: + pad = auto_pad_to_explicit_padding( + auto_pad, ifm_dim_h, ifm_dim_w, k_h, k_w, stride_h, stride_w, len(model.get_tensor_shape(cnv_input)) - 2 + ) + else: + pad = get_by_name(node.attribute, "pads").ints + + if len(pad) == 2: # only one dimension should be padded + assert ifm_dim_h == 1 or ifm_dim_w == 1, "Padding is assumed to be 1D, image is 2D" + + return (cnv_input, cnv_output, cnv_input_datatype, cnv_output_datatype, k_h, k_w, stride_h, + stride_w, group, weight_name, W_conv, ifm_ch, ofm_ch, ifm_dim_h, ifm_dim_w, ofm_dim_h, + ofm_dim_w, dilation, pad) diff --git a/tests/transformation/test_conv_lowering.py b/tests/transformation/test_conv_lowering.py index 78da6213..044da1b2 100644 --- a/tests/transformation/test_conv_lowering.py +++ b/tests/transformation/test_conv_lowering.py @@ -65,7 +65,7 @@ def test_conv_lowering_convmnist(): model = model.transform(InferShapes()) output_dict_p = oxe.execute_onnx(model, input_dict) produced = output_dict_p[output_name] - assert np.isclose(produced, expected).all() + assert np.isclose(produced, expected, rtol=1.e-4).all() def run_conv_lowering_test(idt, k_h, k_w, ifm_dim_h, ifm_dim_w, ifm_ch, stride, padding, dilations, dw, bias): From 75f8f8c887f613f8f41bdfd84d29997a6db5b8dd Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Fri, 16 Aug 2024 15:05:27 +0200 Subject: [PATCH 12/28] run pre-commit on all files --- docs/license.rst | 4 +- notebooks/4_quant_lstm_helper/function.py | 401 +++++++++--------- notebooks/4_quant_lstm_helper/handler.py | 97 ++--- .../transformation/lower_convs_to_matmul.py | 63 ++- tests/transformation/test_conv_lowering.py | 2 +- 5 files changed, 293 insertions(+), 274 deletions(-) diff --git a/docs/license.rst b/docs/license.rst index e647e180..a5103f77 100644 --- a/docs/license.rst +++ b/docs/license.rst @@ -1,7 +1,7 @@ .. _license: -======= +======== License -======= +======== .. 
include:: ../LICENSE diff --git a/notebooks/4_quant_lstm_helper/function.py b/notebooks/4_quant_lstm_helper/function.py index 6ba2e9dd..935bf78a 100644 --- a/notebooks/4_quant_lstm_helper/function.py +++ b/notebooks/4_quant_lstm_helper/function.py @@ -2,26 +2,24 @@ # SPDX-License-Identifier: BSD-3-Clause import torch -from torch.autograd import Function - from brevitas.export.onnx import onnx_export_opset +from torch.autograd import Function AXIS_OPSET = 13 DOMAIN_STRING = "onnx.brevitas" class DequantizeLinearFn(Function): - @staticmethod def symbolic(g, x, input_scale, input_zero_point, input_axis): opset_version = onnx_export_opset() if input_axis is not None and opset_version < AXIS_OPSET: - raise RuntimeError('ONNX Opset 13 is required for per-channel quantization') + raise RuntimeError("ONNX Opset 13 is required for per-channel quantization") elif input_axis is not None and opset_version >= AXIS_OPSET: - ret = g.op('DequantizeLinear', x, input_scale, input_zero_point, axis_i=input_axis) + ret = g.op("DequantizeLinear", x, input_scale, input_zero_point, axis_i=input_axis) else: - ret = g.op('DequantizeLinear', x, input_scale, input_zero_point) + ret = g.op("DequantizeLinear", x, input_scale, input_zero_point) return ret @staticmethod @@ -30,10 +28,9 @@ def forward(ctx, int_x, input_scale, input_zero_point, input_axis): class IntClipFn(Function): - @staticmethod def symbolic(g, int_x, min_int_val, max_int_val): - ret = g.op('Clip', int_x, min_int_val, max_int_val) + ret = g.op("Clip", int_x, min_int_val, max_int_val) return ret @staticmethod @@ -42,116 +39,115 @@ def forward(ctx, int_x, min_int_val, max_int_val): class QuantizeLinearFn(Function): - @staticmethod def symbolic(g, x, output_scale, ouput_zero_point, output_dtype, output_axis): opset_version = onnx_export_opset() if output_axis is not None and opset_version < AXIS_OPSET: - raise RuntimeError('ONNX Opset 13 is required for per-channel quantization') + raise RuntimeError("ONNX Opset 13 is required for per-channel quantization") elif output_axis is not None and opset_version >= AXIS_OPSET: - ret = g.op('QuantizeLinear', x, output_scale, ouput_zero_point, axis_i=output_axis) + ret = g.op("QuantizeLinear", x, output_scale, ouput_zero_point, axis_i=output_axis) else: - ret = g.op('QuantizeLinear', x, output_scale, ouput_zero_point) + ret = g.op("QuantizeLinear", x, output_scale, ouput_zero_point) return ret @staticmethod def forward(ctx, x, output_scale, ouput_zero_point, output_dtype, output_axis): return x.type(output_dtype) -class BrevitasQuantLSTMCellFn(Function): - +class BrevitasQuantLSTMCellFn(Function): @staticmethod def symbolic( - g, # args and kwargs passed from _QuantLSTMLayer - quant_input, - quant_hidden_state, - quant_cell_state, - quant_weight_ii, - quant_weight_if, - quant_weight_ic, - quant_weight_io, - quant_weight_hi, - quant_weight_hf, - quant_weight_hc, - quant_weight_ho, - quant_bias_input, - quant_bias_forget, - quant_bias_cell, - quant_bias_output, # Symbolic kwargs passed from BrevitasQuantLSTMLayerHandler - batch_first, - reverse_input, - cifg, # Output quant - output_scale, - output_zero_point, - output_bit_width, - output_narrow_range, - output_signed, - output_rounding_mode, # Cell state quant - cell_state_scale, - cell_state_zero_point, - cell_state_bit_width, - cell_state_narrow_range, - cell_state_signed, - cell_state_rounding_mode, # Input gate accumulator quant - input_acc_scale, - input_acc_zero_point, - input_acc_bit_width, - input_acc_narrow_range, - input_acc_signed, - 
input_acc_rounding_mode, # Forget gate accumulator quant - forget_acc_scale, - forget_acc_zero_point, - forget_acc_bit_width, - forget_acc_narrow_range, - forget_acc_signed, - forget_acc_rounding_mode, # Cell gate accumulator quant - cell_acc_scale, - cell_acc_zero_point, - cell_acc_bit_width, - cell_acc_narrow_range, - cell_acc_signed, - cell_acc_rounding_mode, # Output gate accumulator quant - output_acc_scale, - output_acc_zero_point, - output_acc_bit_width, - output_acc_narrow_range, - output_acc_signed, - output_acc_rounding_mode, # Input gate sigmoid quant - input_sigmoid_scale, - input_sigmoid_zero_point, - input_sigmoid_bit_width, - input_sigmoid_narrow_range, - input_sigmoid_signed, - input_sigmoid_rounding_mode, # Forget gate sigmoid quant - forget_sigmoid_scale, - forget_sigmoid_zero_point, - forget_sigmoid_bit_width, - forget_sigmoid_narrow_range, - forget_sigmoid_signed, - forget_sigmoid_rounding_mode, # Cell gate tanh quant - cell_tanh_scale, - cell_tanh_zero_point, - cell_tanh_bit_width, - cell_tanh_narrow_range, - cell_tanh_signed, - cell_tanh_rounding_mode, # Output gate sigmoid quant - output_sigmoid_scale, - output_sigmoid_zero_point, - output_sigmoid_bit_width, - output_sigmoid_narrow_range, - output_sigmoid_signed, - output_sigmoid_rounding_mode, # Hidden state tanh quant - hidden_state_tanh_scale, - hidden_state_tanh_zero_point, - hidden_state_tanh_bit_width, - hidden_state_tanh_narrow_range, - hidden_state_tanh_signed, - hidden_state_tanh_rounding_mode): + g, # args and kwargs passed from _QuantLSTMLayer + quant_input, + quant_hidden_state, + quant_cell_state, + quant_weight_ii, + quant_weight_if, + quant_weight_ic, + quant_weight_io, + quant_weight_hi, + quant_weight_hf, + quant_weight_hc, + quant_weight_ho, + quant_bias_input, + quant_bias_forget, + quant_bias_cell, + quant_bias_output, # Symbolic kwargs passed from BrevitasQuantLSTMLayerHandler + batch_first, + reverse_input, + cifg, # Output quant + output_scale, + output_zero_point, + output_bit_width, + output_narrow_range, + output_signed, + output_rounding_mode, # Cell state quant + cell_state_scale, + cell_state_zero_point, + cell_state_bit_width, + cell_state_narrow_range, + cell_state_signed, + cell_state_rounding_mode, # Input gate accumulator quant + input_acc_scale, + input_acc_zero_point, + input_acc_bit_width, + input_acc_narrow_range, + input_acc_signed, + input_acc_rounding_mode, # Forget gate accumulator quant + forget_acc_scale, + forget_acc_zero_point, + forget_acc_bit_width, + forget_acc_narrow_range, + forget_acc_signed, + forget_acc_rounding_mode, # Cell gate accumulator quant + cell_acc_scale, + cell_acc_zero_point, + cell_acc_bit_width, + cell_acc_narrow_range, + cell_acc_signed, + cell_acc_rounding_mode, # Output gate accumulator quant + output_acc_scale, + output_acc_zero_point, + output_acc_bit_width, + output_acc_narrow_range, + output_acc_signed, + output_acc_rounding_mode, # Input gate sigmoid quant + input_sigmoid_scale, + input_sigmoid_zero_point, + input_sigmoid_bit_width, + input_sigmoid_narrow_range, + input_sigmoid_signed, + input_sigmoid_rounding_mode, # Forget gate sigmoid quant + forget_sigmoid_scale, + forget_sigmoid_zero_point, + forget_sigmoid_bit_width, + forget_sigmoid_narrow_range, + forget_sigmoid_signed, + forget_sigmoid_rounding_mode, # Cell gate tanh quant + cell_tanh_scale, + cell_tanh_zero_point, + cell_tanh_bit_width, + cell_tanh_narrow_range, + cell_tanh_signed, + cell_tanh_rounding_mode, # Output gate sigmoid quant + output_sigmoid_scale, + 
output_sigmoid_zero_point, + output_sigmoid_bit_width, + output_sigmoid_narrow_range, + output_sigmoid_signed, + output_sigmoid_rounding_mode, # Hidden state tanh quant + hidden_state_tanh_scale, + hidden_state_tanh_zero_point, + hidden_state_tanh_bit_width, + hidden_state_tanh_narrow_range, + hidden_state_tanh_signed, + hidden_state_tanh_rounding_mode, + ): return g.op( - f'{DOMAIN_STRING}::QuantLSTMCell', # Tensors - ## Input values + f"{DOMAIN_STRING}::QuantLSTMCell", # Tensors + # Input values quant_input, quant_hidden_state, quant_cell_state, @@ -166,37 +162,37 @@ def symbolic( quant_bias_input, quant_bias_forget, quant_bias_cell, - quant_bias_output, ## Output quant + quant_bias_output, # Output quant output_scale, output_zero_point, - output_bit_width, ## Cell state quant + output_bit_width, # Cell state quant cell_state_scale, cell_state_zero_point, - cell_state_bit_width, ## Input gate accumulator quant + cell_state_bit_width, # Input gate accumulator quant input_acc_scale, input_acc_zero_point, - input_acc_bit_width, ## Forget gate accumulator quant + input_acc_bit_width, # Forget gate accumulator quant forget_acc_scale, forget_acc_zero_point, - forget_acc_bit_width, ## Cell gate accumulator quant + forget_acc_bit_width, # Cell gate accumulator quant cell_acc_scale, cell_acc_zero_point, - cell_acc_bit_width, ## Output gate accumulator quant + cell_acc_bit_width, # Output gate accumulator quant output_acc_scale, output_acc_zero_point, - output_acc_bit_width, ## Input gate sigmoid quant + output_acc_bit_width, # Input gate sigmoid quant input_sigmoid_scale, input_sigmoid_zero_point, - input_sigmoid_bit_width, ## Forget gate sigmoid quant + input_sigmoid_bit_width, # Forget gate sigmoid quant forget_sigmoid_scale, forget_sigmoid_zero_point, - forget_sigmoid_bit_width, ## Cell gate tanh quant + forget_sigmoid_bit_width, # Cell gate tanh quant cell_tanh_scale, cell_tanh_zero_point, - cell_tanh_bit_width, ## Output gate sigmoid quant + cell_tanh_bit_width, # Output gate sigmoid quant output_sigmoid_scale, output_sigmoid_zero_point, - output_sigmoid_bit_width, ## Hidden state tanh quant + output_sigmoid_bit_width, # Hidden state tanh quant hidden_state_tanh_scale, hidden_state_tanh_zero_point, hidden_state_tanh_bit_width, @@ -238,103 +234,102 @@ def symbolic( hidden_state_tanh_signed_i=hidden_state_tanh_signed, hidden_state_tanh_rounding_mode_s=hidden_state_tanh_rounding_mode, # PyTorch requires to specify the number of outputs manually - outputs=3) - + outputs=3, + ) @staticmethod def forward( - ctx, # args and kwargs passed from _QuantLSTMLayer - quant_input, - quant_hidden_state, - quant_cell_state, - quant_weight_ii, - quant_weight_if, - quant_weight_ic, - quant_weight_io, - quant_weight_hi, - quant_weight_hf, - quant_weight_hc, - quant_weight_ho, - quant_bias_input, - quant_bias_forget, - quant_bias_cell, - quant_bias_output, # Symbolic kwargs passed from BrevitasQuantLSTMLayerHandler - batch_first, - reverse_input, - cifg, # Output quant - output_scale, - output_zero_point, - output_bit_width, - output_narrow_range, - output_signed, - output_rounding_mode, # Cell state quant - cell_state_scale, - cell_state_zero_point, - cell_state_bit_width, - cell_state_narrow_range, - cell_state_signed, - cell_state_rounding_mode, # Input gate accumulator quant - input_acc_scale, - input_acc_zero_point, - input_acc_bit_width, - input_acc_narrow_range, - input_acc_signed, - input_acc_rounding_mode, # Forget gate accumulator quant - forget_acc_scale, - forget_acc_zero_point, - 
forget_acc_bit_width, - forget_acc_narrow_range, - forget_acc_signed, - forget_acc_rounding_mode, # Cell gate accumulator quant - cell_acc_scale, - cell_acc_zero_point, - cell_acc_bit_width, - cell_acc_narrow_range, - cell_acc_signed, - cell_acc_rounding_mode, # Output gate accumulator quant - output_acc_scale, - output_acc_zero_point, - output_acc_bit_width, - output_acc_narrow_range, - output_acc_signed, - output_acc_rounding_mode, # Input gate sigmoid quant - input_sigmoid_scale, - input_sigmoid_zero_point, - input_sigmoid_bit_width, - input_sigmoid_narrow_range, - input_sigmoid_signed, - input_sigmoid_rounding_mode, # Forget gate sigmoid quant - forget_sigmoid_scale, - forget_sigmoid_zero_point, - forget_sigmoid_bit_width, - forget_sigmoid_narrow_range, - forget_sigmoid_signed, - forget_sigmoid_rounding_mode, # Cell gate tanh quant - cell_tanh_scale, - cell_tanh_zero_point, - cell_tanh_bit_width, - cell_tanh_narrow_range, - cell_tanh_signed, - cell_tanh_rounding_mode, # Output gate sigmoid quant - output_sigmoid_scale, - output_sigmoid_zero_point, - output_sigmoid_bit_width, - output_sigmoid_narrow_range, - output_sigmoid_signed, - output_sigmoid_rounding_mode, # Hidden state tanh quant - hidden_state_tanh_scale, - hidden_state_tanh_zero_point, - hidden_state_tanh_bit_width, - hidden_state_tanh_narrow_range, - hidden_state_tanh_signed, - hidden_state_tanh_rounding_mode): + ctx, # args and kwargs passed from _QuantLSTMLayer + quant_input, + quant_hidden_state, + quant_cell_state, + quant_weight_ii, + quant_weight_if, + quant_weight_ic, + quant_weight_io, + quant_weight_hi, + quant_weight_hf, + quant_weight_hc, + quant_weight_ho, + quant_bias_input, + quant_bias_forget, + quant_bias_cell, + quant_bias_output, # Symbolic kwargs passed from BrevitasQuantLSTMLayerHandler + batch_first, + reverse_input, + cifg, # Output quant + output_scale, + output_zero_point, + output_bit_width, + output_narrow_range, + output_signed, + output_rounding_mode, # Cell state quant + cell_state_scale, + cell_state_zero_point, + cell_state_bit_width, + cell_state_narrow_range, + cell_state_signed, + cell_state_rounding_mode, # Input gate accumulator quant + input_acc_scale, + input_acc_zero_point, + input_acc_bit_width, + input_acc_narrow_range, + input_acc_signed, + input_acc_rounding_mode, # Forget gate accumulator quant + forget_acc_scale, + forget_acc_zero_point, + forget_acc_bit_width, + forget_acc_narrow_range, + forget_acc_signed, + forget_acc_rounding_mode, # Cell gate accumulator quant + cell_acc_scale, + cell_acc_zero_point, + cell_acc_bit_width, + cell_acc_narrow_range, + cell_acc_signed, + cell_acc_rounding_mode, # Output gate accumulator quant + output_acc_scale, + output_acc_zero_point, + output_acc_bit_width, + output_acc_narrow_range, + output_acc_signed, + output_acc_rounding_mode, # Input gate sigmoid quant + input_sigmoid_scale, + input_sigmoid_zero_point, + input_sigmoid_bit_width, + input_sigmoid_narrow_range, + input_sigmoid_signed, + input_sigmoid_rounding_mode, # Forget gate sigmoid quant + forget_sigmoid_scale, + forget_sigmoid_zero_point, + forget_sigmoid_bit_width, + forget_sigmoid_narrow_range, + forget_sigmoid_signed, + forget_sigmoid_rounding_mode, # Cell gate tanh quant + cell_tanh_scale, + cell_tanh_zero_point, + cell_tanh_bit_width, + cell_tanh_narrow_range, + cell_tanh_signed, + cell_tanh_rounding_mode, # Output gate sigmoid quant + output_sigmoid_scale, + output_sigmoid_zero_point, + output_sigmoid_bit_width, + output_sigmoid_narrow_range, + output_sigmoid_signed, + 
output_sigmoid_rounding_mode,  # Hidden state tanh quant
+        hidden_state_tanh_scale,
+        hidden_state_tanh_zero_point,
+        hidden_state_tanh_bit_width,
+        hidden_state_tanh_narrow_range,
+        hidden_state_tanh_signed,
+        hidden_state_tanh_rounding_mode,
+    ):
        # To simplify things, here we are returning the outputs
        # as if they were already concatenated. Scale/zp/bw are avoided too.
        # This preserves output shapes but not values.
        # See _QuantLSTMCell for the actual implementation.
        quant_outputs = torch.zeros(
-            quant_input.size(0),
-            quant_input.size(1),
-            quant_hidden_state.size(1),
-            device=quant_hidden_state.device)
+            quant_input.size(0), quant_input.size(1), quant_hidden_state.size(1), device=quant_hidden_state.device
+        )
        return quant_outputs, quant_hidden_state, quant_cell_state

diff --git a/notebooks/4_quant_lstm_helper/handler.py b/notebooks/4_quant_lstm_helper/handler.py
index 948eb647..71cbdeb1 100644
--- a/notebooks/4_quant_lstm_helper/handler.py
+++ b/notebooks/4_quant_lstm_helper/handler.py
@@ -1,32 +1,23 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause

+import torch
 from abc import ABC
-from copy import copy
+from brevitas.export.common.handler.qcdq import (
+    DQMixin,
+    QCDQActQuantProxyHandlerMixin,
+    QCDQBiasQuantProxyHandlerMixin,
+    QCDQDecoupledWeightQuantProxyHandlerMixin,
+    QCDQMixin,
+    QCDQTruncQuantProxyHandlerMixin,
+    QCDQWeightQuantProxyHandlerMixin,
+)
+from brevitas.export.onnx.handler import ONNXBaseHandler, QuantLSTMLayerHandler

-import torch
-from torch import Tensor
-
-from brevitas.export.common.handler.base import QuantAxisMixin
-from brevitas.export.common.handler.qcdq import DQMixin
-from brevitas.export.common.handler.qcdq import QCDQActQuantProxyHandlerMixin
-from brevitas.export.common.handler.qcdq import QCDQBiasQuantProxyHandlerMixin
-from brevitas.export.common.handler.qcdq import QCDQDecoupledWeightQuantProxyHandlerMixin
-from brevitas.export.common.handler.qcdq import QCDQMixin
-from brevitas.export.common.handler.qcdq import QCDQTruncQuantProxyHandlerMixin
-from brevitas.export.common.handler.qcdq import QCDQWeightQuantProxyHandlerMixin
-from brevitas.export.common.handler.base import ZeroPointHandlerMixin
-from brevitas.export.onnx.handler import ONNXBaseHandler
-from brevitas.export.onnx.handler import QuantLSTMLayerHandler
-
-from ..function import DequantizeLinearFn
-from ..function import IntClipFn
-from ..function import QuantizeLinearFn
-from ..function import BrevitasQuantLSTMCellFn
+from ..function import BrevitasQuantLSTMCellFn, DequantizeLinearFn, IntClipFn, QuantizeLinearFn


 class StdDQONNXMixin(DQMixin, ABC):
-
     def dequantize_fn(self, x, scale, zero_point, axis):
         return DequantizeLinearFn.apply(x, scale, zero_point, axis)

@@ -40,7 +31,6 @@ def itemize_quantize_scalar_params(self):


 class StdQCDQONNXMixin(QCDQMixin, StdDQONNXMixin, ABC):
-
     @property
     def clip_over_integers(self):
         return True
@@ -59,8 +49,8 @@ def int32_dtype(cls):

     def validate(self, module):
         self.validate_8b_bit_width(module.bit_width(), le_then=True)
-        assert module.bit_width() > 1., 'Binary quant not supported'
-        assert module.rounding_mode.upper() == 'ROUND', 'Only round to nearest even supported'
+        assert module.bit_width() > 1.0, "Binary quant not supported"
+        assert module.rounding_mode.upper() == "ROUND", "Only round to nearest even supported"

     def quantize_fn(self, x, scale, zero_point, dtype, axis):
         return QuantizeLinearFn.apply(x, scale, zero_point, dtype, axis)

@@ -69,55 +59,47 @@ def clip_fn(self, x, min_val,
max_val): return IntClipFn.apply(x, min_val, max_val) -class StdQCDQONNXWeightQuantProxyHandler(StdQCDQONNXMixin, - QCDQWeightQuantProxyHandlerMixin, - ONNXBaseHandler): +class StdQCDQONNXWeightQuantProxyHandler(StdQCDQONNXMixin, QCDQWeightQuantProxyHandlerMixin, ONNXBaseHandler): pass -class StdQCDQONNXDecoupledWeightQuantProxyHandler(StdQCDQONNXMixin, - QCDQDecoupledWeightQuantProxyHandlerMixin, - ONNXBaseHandler): +class StdQCDQONNXDecoupledWeightQuantProxyHandler( + StdQCDQONNXMixin, QCDQDecoupledWeightQuantProxyHandlerMixin, ONNXBaseHandler +): pass -class StdQCDQONNXActQuantProxyHandler(StdQCDQONNXMixin, - QCDQActQuantProxyHandlerMixin, - ONNXBaseHandler): +class StdQCDQONNXActQuantProxyHandler(StdQCDQONNXMixin, QCDQActQuantProxyHandlerMixin, ONNXBaseHandler): pass -class StdQCDQONNXBiasQuantProxyHandler(StdDQONNXMixin, - QCDQBiasQuantProxyHandlerMixin, - ONNXBaseHandler): +class StdQCDQONNXBiasQuantProxyHandler(StdDQONNXMixin, QCDQBiasQuantProxyHandlerMixin, ONNXBaseHandler): pass -class StdQCDQONNXTruncQuantProxyHandler(StdQCDQONNXMixin, - QCDQTruncQuantProxyHandlerMixin, - ONNXBaseHandler): +class StdQCDQONNXTruncQuantProxyHandler(StdQCDQONNXMixin, QCDQTruncQuantProxyHandlerMixin, ONNXBaseHandler): pass class StdQCDQONNXQuantLSTMLayerHandler(QuantLSTMLayerHandler): - def quantized_cell_symbolic_execution( - self, - quant_input, - quant_hidden_state, - quant_cell_state, - quant_weight_ii, - quant_weight_if, - quant_weight_ic, - quant_weight_io, - quant_weight_hi, - quant_weight_hf, - quant_weight_hc, - quant_weight_ho, - quant_bias_input, - quant_bias_forget, - quant_bias_cell, - quant_bias_output): + self, + quant_input, + quant_hidden_state, + quant_cell_state, + quant_weight_ii, + quant_weight_if, + quant_weight_ic, + quant_weight_io, + quant_weight_hi, + quant_weight_hf, + quant_weight_hc, + quant_weight_ho, + quant_bias_input, + quant_bias_forget, + quant_bias_cell, + quant_bias_output, + ): return BrevitasQuantLSTMCellFn.apply( quant_input, quant_hidden_state, @@ -134,7 +116,8 @@ def quantized_cell_symbolic_execution( quant_bias_forget, quant_bias_cell, quant_bias_output, - *self.symbolic_kwargs.values()) + *self.symbolic_kwargs.values() + ) # raise RuntimeError( # "Quantized LSTM cell is not supported for ONNX QCDQ " # "(weights only quantization is). 
Use export_qonnx.") diff --git a/src/qonnx/transformation/lower_convs_to_matmul.py b/src/qonnx/transformation/lower_convs_to_matmul.py index c5964cf4..49700cd7 100644 --- a/src/qonnx/transformation/lower_convs_to_matmul.py +++ b/src/qonnx/transformation/lower_convs_to_matmul.py @@ -52,10 +52,27 @@ def apply(self, model): continue # extract parameters of node - (cnv_input, cnv_output, cnv_input_datatype, cnv_output_datatype, - k_h, k_w, stride_h, stride_w, group, weight_name, W_conv, ifm_ch, - ofm_ch, ifm_dim_h, ifm_dim_w, ofm_dim_h, ofm_dim_w, dilation, pad) =\ - self.extract_conv_params(model, node) + ( + cnv_input, + cnv_output, + cnv_input_datatype, + cnv_output_datatype, + k_h, + k_w, + stride_h, + stride_w, + group, + weight_name, + W_conv, + ifm_ch, + ofm_ch, + ifm_dim_h, + ifm_dim_w, + ofm_dim_h, + ofm_dim_w, + dilation, + pad, + ) = self.extract_conv_params(model, node) # if depthwise conv create sparse matrix and variable "dw" # to store as attribute in Im2Col that indicates that the created @@ -122,9 +139,16 @@ def apply(self, model): im2col_out = im2col_out.name model.set_tensor_datatype(im2col_out, cnv_input_datatype) im2col_node = helper.make_node( - "Im2Col", [inp_trans_out], [im2col_out], domain="qonnx.custom_op.general", - stride=[stride_h, stride_w], kernel_size=[k_h, k_w], pad_amount=pad, - input_shape="(1,{},{},{})".format(ifm_dim_h, ifm_dim_w, ifm_ch), depthwise=dw, dilations=dilation + "Im2Col", + [inp_trans_out], + [im2col_out], + domain="qonnx.custom_op.general", + stride=[stride_h, stride_w], + kernel_size=[k_h, k_w], + pad_amount=pad, + input_shape="(1,{},{},{})".format(ifm_dim_h, ifm_dim_w, ifm_ch), + depthwise=dw, + dilations=dilation, ) nodes_to_insert.append(im2col_node) @@ -144,7 +168,6 @@ def apply(self, model): return (model, graph_modified) def extract_conv_params(self, model, node): - cnv_input = node.input[0] cnv_output = node.output[0] cnv_input_datatype = model.get_tensor_datatype(cnv_input) @@ -179,6 +202,24 @@ def extract_conv_params(self, model, node): if len(pad) == 2: # only one dimension should be padded assert ifm_dim_h == 1 or ifm_dim_w == 1, "Padding is assumed to be 1D, image is 2D" - return (cnv_input, cnv_output, cnv_input_datatype, cnv_output_datatype, k_h, k_w, stride_h, - stride_w, group, weight_name, W_conv, ifm_ch, ofm_ch, ifm_dim_h, ifm_dim_w, ofm_dim_h, - ofm_dim_w, dilation, pad) + return ( + cnv_input, + cnv_output, + cnv_input_datatype, + cnv_output_datatype, + k_h, + k_w, + stride_h, + stride_w, + group, + weight_name, + W_conv, + ifm_ch, + ofm_ch, + ifm_dim_h, + ifm_dim_w, + ofm_dim_h, + ofm_dim_w, + dilation, + pad, + ) diff --git a/tests/transformation/test_conv_lowering.py b/tests/transformation/test_conv_lowering.py index 044da1b2..788d6993 100644 --- a/tests/transformation/test_conv_lowering.py +++ b/tests/transformation/test_conv_lowering.py @@ -65,7 +65,7 @@ def test_conv_lowering_convmnist(): model = model.transform(InferShapes()) output_dict_p = oxe.execute_onnx(model, input_dict) produced = output_dict_p[output_name] - assert np.isclose(produced, expected, rtol=1.e-4).all() + assert np.isclose(produced, expected, rtol=1.0e-4).all() def run_conv_lowering_test(idt, k_h, k_w, ifm_dim_h, ifm_dim_w, ifm_ch, stride, padding, dilations, dw, bias): From 8f6661524c57fc4f54fe7758399428a3f1268624 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Fri, 16 Aug 2024 15:37:52 +0200 Subject: [PATCH 13/28] [LowerConv] skip convs with non-initialized weights --- src/qonnx/transformation/lower_convs_to_matmul.py | 4 ++++ 1 file changed, 
4 insertions(+) diff --git a/src/qonnx/transformation/lower_convs_to_matmul.py b/src/qonnx/transformation/lower_convs_to_matmul.py index 49700cd7..59ddbce6 100644 --- a/src/qonnx/transformation/lower_convs_to_matmul.py +++ b/src/qonnx/transformation/lower_convs_to_matmul.py @@ -51,6 +51,10 @@ def apply(self, model): warnings.warn("Found Conv node with bias, skipping") continue + if model.get_initializer(node.input[1]) is None: + warnings.warn("Found Conv node with non-initialized weight, skipping") + continue + # extract parameters of node ( cnv_input, From a92093c32268eae06e09aa0da65a56ddb4bee217 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Fri, 16 Aug 2024 15:38:33 +0200 Subject: [PATCH 14/28] [Test] add (failing) quant weight conv testcase for lowering --- tests/transformation/test_conv_lowering.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/transformation/test_conv_lowering.py b/tests/transformation/test_conv_lowering.py index 788d6993..eea53c55 100644 --- a/tests/transformation/test_conv_lowering.py +++ b/tests/transformation/test_conv_lowering.py @@ -43,6 +43,14 @@ from qonnx.transformation.infer_shapes import InferShapes from qonnx.transformation.lower_convs_to_matmul import LowerConvsToMatMul from qonnx.util.basic import gen_finn_dt_tensor, qonnx_make_model +from qonnx.util.test import download_model + + +def test_conv_lowering_quant_weights(): + model_name = "FINN-CNV_W2A2" + model = download_model(model_name, return_modelwrapper=True, do_cleanup=True) + model = model.transform(LowerConvsToMatMul()) + assert model.get_nodes_by_op_type("Conv") == [] def test_conv_lowering_convmnist(): From 5e5bb5523137df04632b73edf7f7828cabc84ded Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Fri, 16 Aug 2024 16:16:48 +0200 Subject: [PATCH 15/28] [LowerConv] support lowering Conv with Quant node on weights --- .../transformation/lower_convs_to_matmul.py | 21 ++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/src/qonnx/transformation/lower_convs_to_matmul.py b/src/qonnx/transformation/lower_convs_to_matmul.py index 59ddbce6..30ed85ca 100644 --- a/src/qonnx/transformation/lower_convs_to_matmul.py +++ b/src/qonnx/transformation/lower_convs_to_matmul.py @@ -51,10 +51,6 @@ def apply(self, model): warnings.warn("Found Conv node with bias, skipping") continue - if model.get_initializer(node.input[1]) is None: - warnings.warn("Found Conv node with non-initialized weight, skipping") - continue - # extract parameters of node ( cnv_input, @@ -67,6 +63,7 @@ def apply(self, model): stride_w, group, weight_name, + conv_weight_inp_name, W_conv, ifm_ch, ofm_ch, @@ -78,6 +75,10 @@ def apply(self, model): pad, ) = self.extract_conv_params(model, node) + if W_conv is None: + warnings.warn("Found Conv node with non-initialized weight, skipping") + continue + # if depthwise conv create sparse matrix and variable "dw" # to store as attribute in Im2Col that indicates that the created # Im2Col node belongs to a depthwise convolution @@ -108,6 +109,8 @@ def apply(self, model): # transpose to get ONNX-compatible [k_h*k_w*IFM][OFM] matrix W_matmul = W_matmul.T model.set_initializer(weight_name, W_matmul) + if weight_name != conv_weight_inp_name: + model.set_tensor_shape(conv_weight_inp_name, W_matmul.shape) # create new intermediate values inp_trans_out = helper.make_tensor_value_info( @@ -158,7 +161,7 @@ def apply(self, model): matmul_input = im2col_out if need_im2col else inp_trans_out # do matmul - matmul_node = helper.make_node("MatMul", [matmul_input, 
weight_name], [matmul_out]) + matmul_node = helper.make_node("MatMul", [matmul_input, conv_weight_inp_name], [matmul_out]) # NHWC -> NCHW out_trans_node = helper.make_node("Transpose", [matmul_out], [cnv_output], perm=[0, 3, 1, 2]) @@ -182,7 +185,14 @@ def extract_conv_params(self, model, node): stride_w = get_by_name(node.attribute, "strides").ints[1] group = get_by_name(node.attribute, "group").i weight_name = node.input[1] + conv_weight_inp_name = node.input[1] W_conv = model.get_initializer(weight_name) + if W_conv is None: + # check to see if there is an immediate quantizer node feeding the weight input + w_producer = model.find_producer(weight_name) + if not (w_producer is None) and w_producer.op_type == "Quant": + W_conv = model.get_initializer(w_producer.input[0]) + weight_name = w_producer.input[0] ifm_ch = model.get_tensor_shape(cnv_input)[1] # assume NCHW ofm_ch = model.get_tensor_shape(cnv_output)[1] # assume NCHW ifm_dim_h = model.get_tensor_shape(cnv_input)[2] # assume NCHW @@ -217,6 +227,7 @@ def extract_conv_params(self, model, node): stride_w, group, weight_name, + conv_weight_inp_name, W_conv, ifm_ch, ofm_ch, From c54f142e2c3ec6b9a8c9deffa5072f56cdff5f1b Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Fri, 16 Aug 2024 16:17:15 +0200 Subject: [PATCH 16/28] [Test] extend quant weight conv testcase, now passing --- tests/transformation/test_conv_lowering.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/transformation/test_conv_lowering.py b/tests/transformation/test_conv_lowering.py index eea53c55..b2003a77 100644 --- a/tests/transformation/test_conv_lowering.py +++ b/tests/transformation/test_conv_lowering.py @@ -43,7 +43,7 @@ from qonnx.transformation.infer_shapes import InferShapes from qonnx.transformation.lower_convs_to_matmul import LowerConvsToMatMul from qonnx.util.basic import gen_finn_dt_tensor, qonnx_make_model -from qonnx.util.test import download_model +from qonnx.util.test import download_model, get_golden_in_and_output def test_conv_lowering_quant_weights(): @@ -51,6 +51,11 @@ def test_conv_lowering_quant_weights(): model = download_model(model_name, return_modelwrapper=True, do_cleanup=True) model = model.transform(LowerConvsToMatMul()) assert model.get_nodes_by_op_type("Conv") == [] + input_t, golden_t = get_golden_in_and_output(model_name) + input_dict = {model.graph.input[0].name: input_t} + prod_dict = oxe.execute_onnx(model, input_dict) + prod_t = prod_dict[model.graph.output[0].name] + assert (prod_t == golden_t).all() def test_conv_lowering_convmnist(): From 8d1ee1d8fed3e93f5f5e7b6a7e4a260577350e90 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Fri, 16 Aug 2024 16:42:46 +0200 Subject: [PATCH 17/28] [LowerConv] support reshaping quant conv weight scales --- .../transformation/lower_convs_to_matmul.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/qonnx/transformation/lower_convs_to_matmul.py b/src/qonnx/transformation/lower_convs_to_matmul.py index 30ed85ca..89c08eae 100644 --- a/src/qonnx/transformation/lower_convs_to_matmul.py +++ b/src/qonnx/transformation/lower_convs_to_matmul.py @@ -64,6 +64,7 @@ def apply(self, model): group, weight_name, conv_weight_inp_name, + conv_weight_q_scale_name, W_conv, ifm_ch, ofm_ch, @@ -110,7 +111,19 @@ def apply(self, model): W_matmul = W_matmul.T model.set_initializer(weight_name, W_matmul) if weight_name != conv_weight_inp_name: + # required for convs with quantized weights model.set_tensor_shape(conv_weight_inp_name, W_matmul.shape) + if 
conv_weight_q_scale_name is not None: + # required for convs with quantized weights + scale_weight_q = model.get_initializer(conv_weight_q_scale_name) + # scale shape is originally [OFM, IFM, k_H, k_W] + # transpose into [OFM, k_H, k_W, IFM] + scale_weight_q = scale_weight_q.transpose(0, 2, 3, 1) + # reshape into [OFM][k_h*k_w*IFM] matrix + scale_weight_q = scale_weight_q.reshape(ofm_ch, -1) + # transpose to be shape-compatible with weight matrix + scale_weight_q = scale_weight_q.T + model.set_initializer(conv_weight_q_scale_name, scale_weight_q) # create new intermediate values inp_trans_out = helper.make_tensor_value_info( @@ -186,6 +199,7 @@ def extract_conv_params(self, model, node): group = get_by_name(node.attribute, "group").i weight_name = node.input[1] conv_weight_inp_name = node.input[1] + conv_weight_q_scale_name = None W_conv = model.get_initializer(weight_name) if W_conv is None: # check to see if there is an immediate quantizer node feeding the weight input @@ -193,6 +207,7 @@ def extract_conv_params(self, model, node): if not (w_producer is None) and w_producer.op_type == "Quant": W_conv = model.get_initializer(w_producer.input[0]) weight_name = w_producer.input[0] + conv_weight_q_scale_name = w_producer.input[1] ifm_ch = model.get_tensor_shape(cnv_input)[1] # assume NCHW ofm_ch = model.get_tensor_shape(cnv_output)[1] # assume NCHW ifm_dim_h = model.get_tensor_shape(cnv_input)[2] # assume NCHW @@ -228,6 +243,7 @@ def extract_conv_params(self, model, node): group, weight_name, conv_weight_inp_name, + conv_weight_q_scale_name, W_conv, ifm_ch, ofm_ch, From 55c2f6faa70f98583aa38175a92ecd88aec960f1 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Fri, 16 Aug 2024 16:43:31 +0200 Subject: [PATCH 18/28] [Test] add MNv1 for quant conv lowering test --- tests/transformation/test_conv_lowering.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/transformation/test_conv_lowering.py b/tests/transformation/test_conv_lowering.py index b2003a77..091619e3 100644 --- a/tests/transformation/test_conv_lowering.py +++ b/tests/transformation/test_conv_lowering.py @@ -46,8 +46,8 @@ from qonnx.util.test import download_model, get_golden_in_and_output -def test_conv_lowering_quant_weights(): - model_name = "FINN-CNV_W2A2" +@pytest.mark.parametrize("model_name", ["FINN-CNV_W2A2", "MobileNetv1-w4a4"]) +def test_conv_lowering_quant_weights(model_name): model = download_model(model_name, return_modelwrapper=True, do_cleanup=True) model = model.transform(LowerConvsToMatMul()) assert model.get_nodes_by_op_type("Conv") == [] From a3451c5ef64f8eac40b0fc21247f0f781a907d20 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Sun, 18 Aug 2024 22:31:15 +0200 Subject: [PATCH 19/28] [Test] use np.isclose instead of equals for test condition --- tests/transformation/test_conv_lowering.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/transformation/test_conv_lowering.py b/tests/transformation/test_conv_lowering.py index 091619e3..c4470e93 100644 --- a/tests/transformation/test_conv_lowering.py +++ b/tests/transformation/test_conv_lowering.py @@ -55,7 +55,7 @@ def test_conv_lowering_quant_weights(model_name): input_dict = {model.graph.input[0].name: input_t} prod_dict = oxe.execute_onnx(model, input_dict) prod_t = prod_dict[model.graph.output[0].name] - assert (prod_t == golden_t).all() + assert np.isclose(prod_t, golden_t).all() def test_conv_lowering_convmnist(): From 100bfdef896c9ca7c31f8a3b681beb66d109f43d Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: 
Thu, 22 Aug 2024 09:50:33 +0200 Subject: [PATCH 20/28] [Util] break out test input generation function & allow seed setting --- src/qonnx/util/test.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/qonnx/util/test.py b/src/qonnx/util/test.py index f18e437e..ff0fcb15 100644 --- a/src/qonnx/util/test.py +++ b/src/qonnx/util/test.py @@ -145,15 +145,20 @@ def qonnx_download_model(): clize.run(download_model) -def get_golden_in_and_output(test_model): - model = download_model(test_model, do_cleanup=True, return_modelwrapper=True) - rng = np.random.RandomState(42) +def get_random_input(test_model, seed=42): + rng = np.random.RandomState(seed) input_shape = test_model_details[test_model]["input_shape"] (low, high) = test_model_details[test_model]["input_range"] size = np.prod(np.asarray(input_shape)) input_tensor = rng.uniform(low=low, high=high, size=size) input_tensor = input_tensor.astype(np.float32) input_tensor = input_tensor.reshape(input_shape) + return input_tensor + + +def get_golden_in_and_output(test_model, seed=42): + model = download_model(test_model, do_cleanup=True, return_modelwrapper=True) + input_tensor = get_random_input(test_model, seed=seed) input_dict = {model.graph.input[0].name: input_tensor} golden_output_dict = oxe.execute_onnx(model, input_dict) golden_result = golden_output_dict[model.graph.output[0].name] From 032681c5137848531a6c26ee7f05a0b2a8241d68 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Thu, 22 Aug 2024 09:51:08 +0200 Subject: [PATCH 21/28] [Lower] fix quant scale conversion, adjust seed random input generated with seed=42 was causing a major difference in Conv_13_out0 for no apparent reason (probably float / numerical related) --- .../transformation/lower_convs_to_matmul.py | 19 +++++++++++-------- tests/transformation/test_conv_lowering.py | 6 +++--- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/src/qonnx/transformation/lower_convs_to_matmul.py b/src/qonnx/transformation/lower_convs_to_matmul.py index 89c08eae..81f0b713 100644 --- a/src/qonnx/transformation/lower_convs_to_matmul.py +++ b/src/qonnx/transformation/lower_convs_to_matmul.py @@ -86,6 +86,8 @@ def apply(self, model): dw = False if group == ifm_ch and ofm_ch == ifm_ch: W_sparse = np.zeros((ofm_ch, ifm_ch, k_h, k_w)) # (OFM, IFM, k_H, k_W) + # TODO: if the convolution is quantized with a non-zero zeropoint we + # should be using the zeropoint value here instead of np.zeros for ch in range(ifm_ch): W_sparse[ch][ch] = W_conv[ch][0] # W_conv = [OFM, IFM, k_H, k_W] W_conv = W_sparse.astype(np.float32) @@ -116,14 +118,15 @@ def apply(self, model): if conv_weight_q_scale_name is not None: # required for convs with quantized weights scale_weight_q = model.get_initializer(conv_weight_q_scale_name) - # scale shape is originally [OFM, IFM, k_H, k_W] - # transpose into [OFM, k_H, k_W, IFM] - scale_weight_q = scale_weight_q.transpose(0, 2, 3, 1) - # reshape into [OFM][k_h*k_w*IFM] matrix - scale_weight_q = scale_weight_q.reshape(ofm_ch, -1) - # transpose to be shape-compatible with weight matrix - scale_weight_q = scale_weight_q.T - model.set_initializer(conv_weight_q_scale_name, scale_weight_q) + if scale_weight_q.ndim > 0: + # scale shape is originally [OFM, IFM, k_H, k_W] + # transpose into [OFM, k_H, k_W, IFM] + scale_weight_q = scale_weight_q.transpose(0, 2, 3, 1) + # reshape into [OFM][k_h*k_w*IFM] matrix + scale_weight_q = scale_weight_q.reshape(ofm_ch, -1) + # transpose to be shape-compatible with weight matrix + scale_weight_q = 
scale_weight_q.T + model.set_initializer(conv_weight_q_scale_name, scale_weight_q) # create new intermediate values inp_trans_out = helper.make_tensor_value_info( diff --git a/tests/transformation/test_conv_lowering.py b/tests/transformation/test_conv_lowering.py index c4470e93..0da57ea3 100644 --- a/tests/transformation/test_conv_lowering.py +++ b/tests/transformation/test_conv_lowering.py @@ -49,13 +49,13 @@ @pytest.mark.parametrize("model_name", ["FINN-CNV_W2A2", "MobileNetv1-w4a4"]) def test_conv_lowering_quant_weights(model_name): model = download_model(model_name, return_modelwrapper=True, do_cleanup=True) + input_t, golden_t = get_golden_in_and_output(model_name, seed=0) + input_dict = {model.graph.input[0].name: input_t} model = model.transform(LowerConvsToMatMul()) assert model.get_nodes_by_op_type("Conv") == [] - input_t, golden_t = get_golden_in_and_output(model_name) - input_dict = {model.graph.input[0].name: input_t} prod_dict = oxe.execute_onnx(model, input_dict) prod_t = prod_dict[model.graph.output[0].name] - assert np.isclose(prod_t, golden_t).all() + assert np.isclose(golden_t, prod_t, atol=1e-04).all() def test_conv_lowering_convmnist(): From 3c870d698bb6ae9d18ee5d8f875ad9dcd95c3a9c Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Mon, 9 Sep 2024 11:38:59 +0300 Subject: [PATCH 22/28] [Util] add accumulator-aware quantized (A2Q) CIFAR-10 models --- src/qonnx/util/test.py | 71 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) diff --git a/src/qonnx/util/test.py b/src/qonnx/util/test.py index ff0fcb15..84f83134 100644 --- a/src/qonnx/util/test.py +++ b/src/qonnx/util/test.py @@ -37,6 +37,76 @@ # utility functions to fetch models and data for # testing various qonnx transformations +a2q_rn18_preproc_mean = np.asarray([0.491, 0.482, 0.447], dtype=np.float32) +a2q_rn18_preproc_std = np.asarray([0.247, 0.243, 0.262], dtype=np.float32) +a2q_rn18_int_range = (0, 255) +a2q_rn18_iscale = 1 / 255 +a2q_rn18_rmin = (a2q_rn18_int_range[0] * a2q_rn18_iscale - a2q_rn18_preproc_mean) / a2q_rn18_preproc_std +a2q_rn18_rmax = (a2q_rn18_int_range[1] * a2q_rn18_iscale - a2q_rn18_preproc_mean) / a2q_rn18_preproc_std +a2q_rn18_scale = (1 / a2q_rn18_preproc_std) * a2q_rn18_iscale +a2q_rn18_bias = -a2q_rn18_preproc_mean * a2q_rn18_preproc_std +a2q_rn18_common = { + "input_shape": (1, 3, 32, 32), + "input_range": (a2q_rn18_rmin, a2q_rn18_rmax), + "int_range": a2q_rn18_int_range, + "scale": a2q_rn18_scale, + "bias": a2q_rn18_bias, +} +a2q_rn18_urlbase = "https://github.com/fastmachinelearning/qonnx_model_zoo/releases/download/a2q-20240905/" + +a2q_model_details = { + "rn18_w4a4_a2q_16b": { + "description": "4-bit ResNet-18 on CIFAR-10, A2Q 16-bit accumulators", + "url": a2q_rn18_urlbase + "quant_resnet18_w4a4_a2q_16b-d4bfa990.onnx", + **a2q_rn18_common, + }, + "rn18_w4a4_a2q_15b": { + "description": "4-bit ResNet-18 on CIFAR-10, A2Q 15-bit accumulators", + "url": a2q_rn18_urlbase + "quant_resnet18_w4a4_a2q_15b-eeca8ac2.onnx", + **a2q_rn18_common, + }, + "rn18_w4a4_a2q_14b": { + "description": "4-bit ResNet-18 on CIFAR-10, A2Q 14-bit accumulators", + "url": a2q_rn18_urlbase + "quant_resnet18_w4a4_a2q_14b-563cf426.onnx", + **a2q_rn18_common, + }, + "rn18_w4a4_a2q_13b": { + "description": "4-bit ResNet-18 on CIFAR-10, A2Q 13-bit accumulators", + "url": a2q_rn18_urlbase + "quant_resnet18_w4a4_a2q_13b-d3cae293.onnx", + **a2q_rn18_common, + }, + "rn18_w4a4_a2q_12b": { + "description": "4-bit ResNet-18 on CIFAR-10, A2Q 12-bit accumulators", + "url": a2q_rn18_urlbase + 
"quant_resnet18_w4a4_a2q_12b-fb3a0f8a.onnx", + **a2q_rn18_common, + }, + "rn18_w4a4_a2q_plus_16b": { + "description": "4-bit ResNet-18 on CIFAR-10, A2Q+ 16-bit accumulators", + "url": a2q_rn18_urlbase + "quant_resnet18_w4a4_a2q_plus_16b-09e47feb.onnx", + **a2q_rn18_common, + }, + "rn18_w4a4_a2q_plus_15b": { + "description": "4-bit ResNet-18 on CIFAR-10, A2Q+ 15-bit accumulators", + "url": a2q_rn18_urlbase + "quant_resnet18_w4a4_a2q_plus_15b-10e7bc83.onnx", + **a2q_rn18_common, + }, + "rn18_w4a4_a2q_plus_14b": { + "description": "4-bit ResNet-18 on CIFAR-10, A2Q+ 14-bit accumulators", + "url": a2q_rn18_urlbase + "quant_resnet18_w4a4_a2q_plus_14b-8db8c78c.onnx", + **a2q_rn18_common, + }, + "rn18_w4a4_a2q_plus_13b": { + "description": "4-bit ResNet-18 on CIFAR-10, A2Q+ 13-bit accumulators", + "url": a2q_rn18_urlbase + "quant_resnet18_w4a4_a2q_plus_13b-f57b05ce.onnx", + **a2q_rn18_common, + }, + "rn18_w4a4_a2q_plus_12b": { + "description": "4-bit ResNet-18 on CIFAR-10, A2Q+ 12-bit accumulators", + "url": a2q_rn18_urlbase + "quant_resnet18_w4a4_a2q_plus_12b-1e2aca29.onnx", + **a2q_rn18_common, + }, +} + test_model_details = { "FINN-CNV_W2A2": { "description": "2-bit VGG-10-like CNN on CIFAR-10", @@ -116,6 +186,7 @@ "input_shape": (1, 3, 224, 224), "input_range": (0, 1), }, + **a2q_model_details, } From ee7464f4c8d68a01acec3617fb3758828867c30b Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Mon, 9 Sep 2024 12:22:09 +0300 Subject: [PATCH 23/28] [Test] correctly handle multi-channel input ranges in change_batchsize --- tests/transformation/test_change_batchsize.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/transformation/test_change_batchsize.py b/tests/transformation/test_change_batchsize.py index 08d7c20f..e6c76da1 100644 --- a/tests/transformation/test_change_batchsize.py +++ b/tests/transformation/test_change_batchsize.py @@ -45,6 +45,11 @@ def test_change_batchsize(test_model): batch_size = 10 old_ishape = test_details["input_shape"] imin, imax = test_details["input_range"] + # some models spec per-channel ranges, be conservative for those + if isinstance(imin, np.ndarray): + imin = imin.max() + if isinstance(imax, np.ndarray): + imax = imax.min() model = download_model(test_model=test_model, do_cleanup=True, return_modelwrapper=True) iname = model.graph.input[0].name oname = model.graph.output[0].name From 8694a6de703e432dcedd67a4e15e88d914da8c08 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Mon, 9 Sep 2024 12:26:09 +0300 Subject: [PATCH 24/28] [Util] handle per-channel ranges in get_random_input --- src/qonnx/util/test.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/qonnx/util/test.py b/src/qonnx/util/test.py index 84f83134..47b4378f 100644 --- a/src/qonnx/util/test.py +++ b/src/qonnx/util/test.py @@ -220,6 +220,11 @@ def get_random_input(test_model, seed=42): rng = np.random.RandomState(seed) input_shape = test_model_details[test_model]["input_shape"] (low, high) = test_model_details[test_model]["input_range"] + # some models spec per-channel ranges, be conservative for those + if isinstance(low, np.ndarray): + low = low.max() + if isinstance(high, np.ndarray): + high = high.min() size = np.prod(np.asarray(input_shape)) input_tensor = rng.uniform(low=low, high=high, size=size) input_tensor = input_tensor.astype(np.float32) From 8bad7e71806d6c611c68fe00ac6007b076b08b5f Mon Sep 17 00:00:00 2001 From: jvreca Date: Thu, 22 Aug 2024 17:06:23 +0200 Subject: [PATCH 25/28] Added Identity node to the removal list --- src/qonnx/transformation/remove.py | 4 
++++
 1 file changed, 4 insertions(+)

diff --git a/src/qonnx/transformation/remove.py b/src/qonnx/transformation/remove.py
index 980e80c1..0f7f38f7 100644
--- a/src/qonnx/transformation/remove.py
+++ b/src/qonnx/transformation/remove.py
@@ -138,5 +138,9 @@ def apply(self, model):
                 remove_node_and_rewire(model, n)
                 graph_modified = True
                 break
+            elif n.op_type == "Identity":
+                remove_node_and_rewire(model, n)
+                graph_modified = True
+                break
         model = model.transform(InferShapes())
         return (model, graph_modified)

From 71ee78062ebdb5ae58dfbcc644d97d07dff3beb1 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu
Date: Thu, 12 Sep 2024 10:42:11 +0300
Subject: [PATCH 26/28] [Test] add Identity op case to test_remove_identity_ops

---
 tests/transformation/test_remove_identity_ops.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/tests/transformation/test_remove_identity_ops.py b/tests/transformation/test_remove_identity_ops.py
index ed34ffe6..d9e92c73 100644
--- a/tests/transformation/test_remove_identity_ops.py
+++ b/tests/transformation/test_remove_identity_ops.py
@@ -51,25 +51,30 @@ def insert_identity_op(model, op, as_first_node, approx):
         val = np.asarray([zero_val], dtype=np.float32)
     elif op in ["Mul", "Div"]:
         val = np.asarray([one_val], dtype=np.float32)
+    elif op in ["Identity"]:
+        val = None
     else:
         return

     graph = model.graph
+    if val is None:
+        inplist = ["inp" if as_first_node else "div_out"]
+    else:
+        model.set_initializer("value", val)
+        inplist = ["inp" if as_first_node else "div_out", "value"]
+    identity_node = helper.make_node(op, inplist, ["ident_out"])
     if as_first_node:
-        identity_node = helper.make_node(op, ["inp", "value"], ["ident_out"])
         graph.node.insert(0, identity_node)
         graph.node[1].input[0] = "ident_out"
     else:
-        identity_node = helper.make_node(op, ["div_out", "value"], ["ident_out"])
         graph.node.insert(3, identity_node)
         graph.node[-1].input[0] = "ident_out"
-    model.set_initializer("value", val)
     return model


 # identity operations to be inserted
-@pytest.mark.parametrize("op", ["Add", "Sub", "Mul", "Div"])
+@pytest.mark.parametrize("op", ["Add", "Sub", "Mul", "Div", "Identity"])
 @pytest.mark.parametrize("approx", [False, True])
 @pytest.mark.parametrize("as_first_node", [False, True])
 def test_remove_identity_ops(op, as_first_node, approx):

From 0a4d5c5315082582d3a646e9504fe129b4ff0fd6 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu
Date: Thu, 12 Sep 2024 12:01:18 +0300
Subject: [PATCH 27/28] [ModelWrapper] add top-level I/O checks to fork/join detection

---
 src/qonnx/core/modelwrapper.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/src/qonnx/core/modelwrapper.py b/src/qonnx/core/modelwrapper.py
index b95c6a33..779bb8f2 100644
--- a/src/qonnx/core/modelwrapper.py
+++ b/src/qonnx/core/modelwrapper.py
@@ -429,14 +429,24 @@ def is_fork_node(self, node):
         """Checks if the given node is a fork, that is, the node has multiple direct successors"""
         direct_successors = self.find_direct_successors(node)
-        is_fork = False if direct_successors is None else (len(direct_successors) > 1)
+        # if the node output is also wired to a top-level output, it is still
+        # a fork with only 1 direct successor
+        if node.output[0] in [x.name for x in self.graph.output]:
+            is_fork = False if direct_successors is None else (len(direct_successors) > 0)
+        else:
+            is_fork = False if direct_successors is None else (len(direct_successors) > 1)
         return is_fork

     def is_join_node(self, node):
         """Checks if the given node is a join, that is, the node has multiple direct
predecessors""" direct_predecessors = self.find_direct_predecessors(node) - is_join = False if direct_predecessors is None else (len(direct_predecessors) > 1) + # if the node input is also wired to a top-level input, it is still + # a fork with only 1 direct predecessor + if node.input[0] in [x.name for x in self.graph.input]: + is_join = False if direct_predecessors is None else (len(direct_predecessors) > 0) + else: + is_join = False if direct_predecessors is None else (len(direct_predecessors) > 1) return is_join def get_all_tensor_names(self): From 2d0934111ad24928aa3a613f7262b835a0d135c3 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Thu, 12 Sep 2024 12:01:54 +0300 Subject: [PATCH 28/28] [Test] add fork cases to RemoveIdentityOps test --- .../transformation/test_remove_identity_ops.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/tests/transformation/test_remove_identity_ops.py b/tests/transformation/test_remove_identity_ops.py index d9e92c73..cfe01a82 100644 --- a/tests/transformation/test_remove_identity_ops.py +++ b/tests/transformation/test_remove_identity_ops.py @@ -77,7 +77,8 @@ def insert_identity_op(model, op, as_first_node, approx): @pytest.mark.parametrize("op", ["Add", "Sub", "Mul", "Div", "Identity"]) @pytest.mark.parametrize("approx", [False, True]) @pytest.mark.parametrize("as_first_node", [False, True]) -def test_remove_identity_ops(op, as_first_node, approx): +@pytest.mark.parametrize("fork_before_id", [False, True]) +def test_remove_identity_ops(op, as_first_node, approx, fork_before_id): # set up onnx model inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, 4, 1, 1]) mul = helper.make_tensor_value_info("mul", TensorProto.FLOAT, []) @@ -114,14 +115,16 @@ def test_remove_identity_ops(op, as_first_node, approx): model = model.transform(InferShapes()) model = model.transform(InferDataTypes()) idict = {"inp": inp_values} - odict = oxe.execute_onnx(model, idict) - out_before = odict["outp"] + odict_before = oxe.execute_onnx(model, idict) num_of_nodes_before = len(model.graph.node) - + if fork_before_id and not as_first_node: + divout_vi = model.get_tensor_valueinfo("div_out") + model.graph.output.append(divout_vi) + model.graph.value_info.remove(divout_vi) model = model.transform(RemoveIdentityOps()) num_of_nodes_after = len(model.graph.node) assert num_of_nodes_before - 1 == num_of_nodes_after - odict = oxe.execute_onnx(model, idict) - out_after = odict["outp"] - assert np.isclose(out_before, out_after, atol=1e-3).all() + odict_after = oxe.execute_onnx(model, idict) + outputs_same = [np.isclose(odict_before[tname], odict_after[tname], atol=1e-3).all() for tname in odict_before.keys()] + assert all(outputs_same)