diff --git a/notebooks/4_quant_lstm.ipynb b/notebooks/4_quant_lstm.ipynb new file mode 100644 index 00000000..bc2b5e2e --- /dev/null +++ b/notebooks/4_quant_lstm.ipynb @@ -0,0 +1,933 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# QuantLSTM - ONNX (QCDQ) representation" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This notebook is divided into `six` parts:\n", + "\n", + "
Part 0 : Package Installations.\n", + "
\n", + "
Part 1 : Introduction to LSTMs.\n", + "
\n", + "
Part 2 : Model creation with brevitas QuantLSTM layer. \n", + "
\n", + "
Part 3 : Build ONNX model representing the LSTM computation used to process a single input with `QCDQ quantization` (weights/inputs/activations) \n", + "
\n", + "
Part 4 : Integration of the QCDQ-LSTM graph with the `SCAN` operator. This operator allows cyclic computations (required for state updates in recurrent neural networks) that are currently not supported in ONNX.\n", + "
\n", + "
Part 5 : Functional verification of the `QCDQ-LSTM` model with brevitas `QuantLSTM` model output." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Package Installations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#Required package installations, This cell only needs to be executed once at the start\n", + "!pip install torch==1.13.1\n", + "!pip install brevitas==0.9.1\n", + "!pip install onnx==1.13.0\n", + "!pip install onnxoptimizer==0.3.13\n", + "!pip install onnxruntime==1.11.1\n", + "!pip install netron==7.2.5\n", + "!pip install qonnx==0.2.0\n", + "!pip install IPython\n", + "!pip install ipykernel\n", + "!ipython kernel install --user --name=venv\n", + "\n", + "#The below location can change depending on your installation of the 'venv' virtual environment\n", + "!cp ./4_quant_lstm_helper/function.py ../venv/lib/python3.8/site-packages/brevitas/export/onnx/standard/\n", + "!cp ./4_quant_lstm_helper/handler.py ../venv/lib/python3.8/site-packages/brevitas/export/onnx/standard/qcdq/\n", + "\n", + "#NOTE : Make sure to chnage the kernel to from \"Python 3\" to \"venv\" before running the below commands" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Introduction to LSTM's " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`LSTM’s (Long Short-Term Memory)` are sequential neural networks that are capable of learning long term dependencies especially in sequence prediction problems. They are deployed in machine translation, speech recognition, image captioning and especially used for time-series analysis applications.\n", + "\n", + "LSTM's have `feedback connections`, unlike conventional feed-forward neural networks (where the compute path goes only in the forward direction). This makes them capable of processing time-series data like vide streams or analyzing network traffic patterns.\n", + "Such feedback connections though also make their hardware implementations compiliacted as they require state updates unlike feed-forward neural networks.\n", + "
\n", + "
\n", + "The LSTM compute requires the following six compute equations:\n", + "$$\n", + " f_t = \\sigma (W_f * x_t + U_f * H_{t-1} + b_f) \n", + "$$\n", + "$$\n", + " i_t = \\sigma (W_i * x_t + U_i * H_{t-1} + b_i)\n", + "$$\n", + "$$\n", + " \\tilde{C_t} = tanh(W_c * x_t + U_c * H_{t-1} + b_c)\n", + "$$\n", + "$$\n", + " o_t = \\sigma (W_o * x_t + U_o * H_{t-1} + b_o)\n", + "$$\n", + "$$\n", + " C_t = f_t \\odot C_{t-1} + i_t \\odot \\tilde{C_t}\n", + "$$\n", + "$$\n", + " H_t = tanh(C_t) \\odot o_t \n", + "$$\n", + "\n", + "The first four equations represent the `gate computations`.\n", + "We compute the `cell state` and the `hidden state` in the last two equations respectively. \n", + "These two states are then fed back into the LSTM cell for the computation of the next input." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# QuantLSTM model creation" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In the 2nd part of the notebook, we will create a single layer `QuantLSTM` model in brevitas. We will evaluate with a given set of inputs. We then export this model to `QONNX` so that the same parameters (weights/biases/scales) can be extracted and used in the `QCDQ-LSTM` implementation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# We import the required libraries to execute different functions in the notebook.\n", + "# The first four imports are required to build the QuantLSTM model in brevitas. \n", + "# The model created will then be exported to QONNX and it's parameters used in the QCDQ implementation.\n", + "\n", + "import torch\n", + "from torch import nn\n", + "from brevitas.nn import QuantLSTM\n", + "from brevitas.export import export_onnx_qcdq\n", + "\n", + "#We need the onnx and onnx helper nodes to build the onnx graph for the LSTM compute.\n", + "import onnx\n", + "from onnx import numpy_helper\n", + "from onnx.helper import make_tensor_value_info, make_node, make_graph, make_model, make_tensor\n", + "#onnxruntime will be used to execute our onnx model.\n", + "import onnxruntime as rt \n", + "from qonnx.util.basic import qonnx_make_model\n", + "#numpy allows us to manipulate outputs from the brevitas and the ONNX model\n", + "import numpy as np \n", + "# Netron visualization tool will help us view interactable graphs\n", + "import netron" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# In this block of code we will create the QuantLSTM model using the brevitas layer\n", + "torch.manual_seed(0) #Setting the manual seeds to 0 for consistency in outputs.\n", + "\n", + "# Initializing attributes that can be changed accordingly depending on users requirements.\n", + "\n", + "num_inputs = 25 #Defining the number of inputs \n", + "num_features_brevitas = 10 #This attribute defines number of features each input comprises of\n", + "num_hidden_cells_brevitas = 20 #This attribute defines the number of hidden cells in the QuantLSTM layer\n", + "\n", + "# Creating a sequential model\n", + "\n", + "model_lstm = nn.Sequential( \n", + " QuantLSTM(input_size = num_features_brevitas, hidden_size = num_hidden_cells_brevitas, bias_quant=None) \n", + " ) #No other feature described here implies quantization of inputs/parametersers/activations to 8-bits.\n", + "model_lstm.eval() #Setting the model to eval mode to make sure all the parameters and scales are frozen and not updated on runtime.\n", + "export_path = './quant_lstm_quantization_qcdq.onnx' #Setting export path for the model\n", + "export_onnx_qcdq(model_lstm,(torch.randn(num_inputs, 1, num_features_brevitas)), opset_version=14, export_path=export_path) #Exporting the model to QCDQ representation. \n", + "\n", + "# Creating a test input to execute the above created model\n", + "\n", + "in_qcdq_node = np.empty([num_inputs,1,num_features_brevitas],dtype=np.float32).reshape([num_inputs,1,num_features_brevitas])\n", + "in_qcdq_node.fill(0.8) #Using the fill function to fill the numpy array with a value of 0.8\n", + "test_input = torch.from_numpy(in_qcdq_node) #Converting the array to a torch tensor\n", + "brevitas_output = model_lstm(test_input) #Executing the model with the set input\n", + "brevitas_output = brevitas_output[0].detach().numpy()\n", + "print(brevitas_output)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`Abbreviations` : Short-forms defined in the next code block can be referenced here for definitions.\n", + "\n", + "* Wi = \"Weight matrix for the input gate\" (Similarily for the other three gates)\n", + "* Ui = \"Recurrence matrix for the input gate\" (Similarily for the other three gates)\n", + "* bi = \"Bias for the input gate\" (Similarily for the other three gates)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# In this block of code we store all the parameters (weight matrices, recurrence matrices, biases, scales and zero-points) that we will need to import in the QCDQ implementation.\n", + "# Importing the exported quantized model from brevitas\n", + "brevitas_lstm_export = onnx.load(\"./quant_lstm_quantization_qcdq.onnx\")\n", + "parameters = brevitas_lstm_export.graph.initializer #Extracting all the parameters from the loaded graph\n", + "\n", + "# In this loop we will be printing all the parameters to correctly import the parameters values to the right variables\n", + "for i in range(len(parameters)):\n", + " w = numpy_helper.to_array(parameters[i])\n", + " print (brevitas_lstm_export.graph.initializer[i].name)\n", + " print(w.shape)\n", + " print(w,',',i)\n", + " print(\"-------------------------\")\n", + " \n", + "# Storing the extracted parameters (weights/biases/scales) to the right variables depending on the order in which they are exported. \n", + "# The abbreviation described in the above block will help in understanding what each variable denotes\n", + "\n", + "bi_val = numpy_helper.to_array(parameters[0])\n", + "Wi_val = numpy_helper.to_array(parameters[1])\n", + "Ui_val = numpy_helper.to_array(parameters[2])\n", + "bf_val = numpy_helper.to_array(parameters[3])\n", + "Wf_val = numpy_helper.to_array(parameters[4])\n", + "Uf_val = numpy_helper.to_array(parameters[5])\n", + "bc_val = numpy_helper.to_array(parameters[6])\n", + "Wc_val = numpy_helper.to_array(parameters[7])\n", + "Uc_val = numpy_helper.to_array(parameters[8])\n", + "bo_val = numpy_helper.to_array(parameters[9])\n", + "Wo_val = numpy_helper.to_array(parameters[10])\n", + "Uo_val = numpy_helper.to_array(parameters[11])\n", + "# Scalar values can either be int or float\n", + "inp_scale_val = float(numpy_helper.to_array(parameters[12])) \n", + "w1_scale_val = float(numpy_helper.to_array(parameters[15]))\n", + "w2_scale_val = float(numpy_helper.to_array(parameters[18]))\n", + "w3_scale_val = float(numpy_helper.to_array(parameters[19]))\n", + "w4_scale_val = float(numpy_helper.to_array(parameters[20]))\n", + "eq_scale_val_1 = float(numpy_helper.to_array(parameters[12]))\n", + "eq_scale_val_2 = float(numpy_helper.to_array(parameters[22]))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# LSTM ONNX model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In the 3rd part of the notebook, we will construct the `QCDQ-LSTM` model with standard ONNX operators. After loading all the parameters in the above block we can now start building our ONNX model with QCDQ quantization to represent the LSTM computations described in part-1.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Setting parameters : Matching the input output lengths exported from brevitas\n", + "num_features = 10\n", + "num_hidden_cells = 20\n", + "activation_bit_width = 8\n", + "\n", + "# The below two parameters are for the 'Clip' operation. \n", + "# Clip node parameters\n", + "max_clip_val = (2 ** (activation_bit_width -1) - 1)\n", + "min_clip_val = -(2 ** (activation_bit_width -1) - 1)\n", + "\n", + "# Zero-point datatype decides the datatype of the output tensor for the quantization operations hence we defined two. One for signed and other for unsigned.\n", + "# Zero point values for quantization\n", + "zero_point_signed_val = 0\n", + "zero_point_unsigned_val = 0" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`Abbreviations` : These describe different short-forms used in the next two blocks.\n", + "\n", + "* ql = \"QuantizeLinear\"\n", + "* dql = \"DequantizeLinear\"\n", + "* clp = \"Clip\"\n", + "* id = \"Identity\"\n", + "* matmul = \"Matrix Multiplication\"\n", + "* el_mul = \"Elementwise Multiplication\"\n", + "* sig = \"Sigmoid\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We start defining the model by defining the `inputs` and `outputs` defined as value_info tensors in ONNX.\n", + "For LSTMs we need three inputs : `inputs`, `previous hidden state` and `previous cell state`. \n", + "We get three outputs : `hidden_state`, `cell_state` and `concatenated_hidden_states`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Defining the inputs 'value info' tensors for the compute graph.\n", + "hidden_state = make_tensor_value_info(\"h_t-1\",onnx.TensorProto.FLOAT, [num_hidden_cells,1])\n", + "cell_state = make_tensor_value_info(\"c_t-1\", onnx.TensorProto.FLOAT, [num_hidden_cells,1])\n", + "inputs = make_tensor_value_info(\"inp\",onnx.TensorProto.FLOAT, [num_features,1])\n", + "\n", + "#Output value info tensor definitions\n", + "out_hidden_state = make_tensor_value_info(\"h_t\", onnx.TensorProto.FLOAT, [num_hidden_cells,1])\n", + "out_cell_state = make_tensor_value_info(\"c_t\", onnx.TensorProto.FLOAT, [num_hidden_cells,1])\n", + "out_hidden_state_concat = make_tensor_value_info(\"h_t_concat\", onnx.TensorProto.FLOAT, [num_hidden_cells,1])#maybe this will have one more dimension" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Once we have defined the inputs and outputs, we will now start defining the operations in the LSTM compute graph.\n", + "# We start by quantizing the input with the standard QDQ operation which is 8-bit quantization. \n", + "# Note: For quantization to lower bit-width's we can use the clip node.\n", + "\n", + "# Input quantization\n", + "ql_input = make_node(\"QuantizeLinear\", inputs=[\"inp\",\"inp_scale\",\"zero_point_signed\"], outputs=[\"ql_input_out\"],name=\"ql_input\")\n", + "id_0_input = make_node(\"Identity\", inputs=[\"ql_input_out\"], outputs=[\"first_input_out\"], name=\"id_0_input\")\n", + "dql_input = make_node(\"DequantizeLinear\", inputs=[\"ql_input_out\", 'inp_scale', \"zero_point_signed\"], outputs=[\"dql_input_out\"],name=\"dql_input\")\n", + "\n", + "# Quantization of the four weight matrices showing QCDQ quantization\n", + "ql_w1 = make_node(\"QuantizeLinear\", inputs=[\"W_f\",\"scale_f\",\"zero_point_signed\"], outputs=[\"ql_wf_out\"], name=\"ql_w1\")\n", + "clp_w1 = make_node(\"Clip\", inputs=[\"ql_wf_out\",\"min\",\"max\"], outputs=[\"clp_wf\"], name=\"clp_w1\")\n", + "dql_w1 = make_node(\"DequantizeLinear\", inputs=[\"clp_wf\",\"scale_f\",\"zero_point_signed\"], outputs=[\"dql_wf_out\"], name=\"dql_w1\")\n", + "\n", + "ql_w2 = make_node(\"QuantizeLinear\", inputs=[\"W_i\",\"scale_i\",\"zero_point_signed\"], outputs=[\"ql_wi_out\"], name=\"ql_w2\")\n", + "clp_w2 = make_node(\"Clip\", inputs=[\"ql_wi_out\",\"min\",\"max\"], outputs=[\"clp_wi\"], name=\"clp_w2\")\n", + "dql_w2 = make_node(\"DequantizeLinear\", inputs=[\"clp_wi\",\"scale_i\",\"zero_point_signed\"], outputs=[\"dql_wi_out\"], name=\"dql_w2\")\n", + "\n", + "ql_w3 = make_node(\"QuantizeLinear\", inputs=[\"W_c\",\"scale_c\",\"zero_point_signed\"], outputs=[\"ql_wc_out\"], name=\"ql_w3\")\n", + "clp_w3 = make_node(\"Clip\", inputs=[\"ql_wc_out\",\"min\",\"max\"], outputs=[\"clp_wc\"], name=\"clp_w3\")\n", + "dql_w3 = make_node(\"DequantizeLinear\", inputs=[\"clp_wc\",\"scale_c\",\"zero_point_signed\"], outputs=[\"dql_wc_out\"], name=\"dql_w3\")\n", + "\n", + "ql_w4 = make_node(\"QuantizeLinear\", inputs=[\"W_o\",\"scale_o\",\"zero_point_signed\"], outputs=[\"ql_wo_out\"], name=\"ql_w4\")\n", + "clp_w4 = make_node(\"Clip\", inputs=[\"ql_wo_out\",\"min\",\"max\"], outputs=[\"clp_wo\"], name=\"clp_w4\")\n", + "dql_w4 = make_node(\"DequantizeLinear\", inputs=[\"clp_wo\",\"scale_o\",\"zero_point_signed\"], outputs=[\"dql_wo_out\"], name=\"dql_w4\")\n", + "\n", + "# Quantizations for the four recurrence weight matrices showing QCDQ quantization\n", + "ql_u1 = make_node(\"QuantizeLinear\", inputs=[\"U_f\",\"scale_f\",\"zero_point_signed\"], outputs=[\"ql_uf_out\"], name=\"ql_u1\")\n", + "clp_u1 = make_node(\"Clip\", inputs=[\"ql_uf_out\",\"min\",\"max\"], outputs=[\"clp_uf\"], name=\"clp_u1\")\n", + "dql_u1 = make_node(\"DequantizeLinear\", inputs=[\"clp_uf\",\"scale_f\",\"zero_point_signed\"], outputs=[\"dql_uf_out\"], name=\"dql_u1\")\n", + "\n", + "ql_u2 = make_node(\"QuantizeLinear\", inputs=[\"U_i\",\"scale_i\",\"zero_point_signed\"], outputs=[\"ql_ui_out\"], name=\"ql_u2\")\n", + "clp_u2 = make_node(\"Clip\", inputs=[\"ql_ui_out\",\"min\",\"max\"], outputs=[\"clp_ui\"], name=\"clp_u2\")\n", + "dql_u2 = make_node(\"DequantizeLinear\", inputs=[\"clp_ui\",\"scale_i\",\"zero_point_signed\"], outputs=[\"dql_ui_out\"], name=\"dql_u2\")\n", + "\n", + "ql_u3 = make_node(\"QuantizeLinear\", inputs=[\"U_c\",\"scale_c\",\"zero_point_signed\"], outputs=[\"ql_uc_out\"], name=\"ql_u3\")\n", + "clp_u3 = make_node(\"Clip\", inputs=[\"ql_uc_out\",\"min\",\"max\"], outputs=[\"clp_uc\"], name=\"clp_u3\")\n", + "dql_u3 = make_node(\"DequantizeLinear\", inputs=[\"clp_uc\",\"scale_c\",\"zero_point_signed\"], outputs=[\"dql_uc_out\"], name=\"dql_u3\")\n", + "\n", + "ql_u4 = make_node(\"QuantizeLinear\", inputs=[\"U_o\",\"scale_o\",\"zero_point_signed\"], outputs=[\"ql_uo_out\"], name=\"ql_u4\")\n", + "clp_u4 = make_node(\"Clip\", inputs=[\"ql_uo_out\",\"min\",\"max\"], outputs=[\"clp_uo\"], name=\"clp_u4\")\n", + "dql_u4 = make_node(\"DequantizeLinear\", inputs=[\"clp_uo\",\"scale_o\",\"zero_point_signed\"], outputs=[\"dql_uo_out\"], name=\"dql_u4\")\n", + "\n", + "# Once we have quantized the weights and inputs we can now start defining the operations for the 6 LSTM equations.\n", + "# The first four gate equations have a very similar compute structure. We define the first four gate computations in this order : Forget, Input, Output, Cell \n", + "\n", + "# 1st Equation : Forget gate\n", + "matmul_1_e1 = make_node(\"MatMul\", inputs=[\"dql_wf_out\",\"dql_input_out\"], outputs=[\"out_m1_e1\"], name=\"matmul_1_e1\")\n", + "matmul_2_e1 = make_node(\"MatMul\", inputs=[\"dql_uf_out\",\"h_t-1\"], outputs=[\"out_m2_e1\"],name=\"matmul_2_e1\")\n", + "add_1_e1 = make_node(\"Add\", inputs=[\"out_m1_e1\",\"out_m2_e1\"], outputs=[\"out_add1_e1\"],name=\"add_1_e1\")\n", + "add_2_e1 = make_node(\"Add\", inputs=[\"out_add1_e1\",\"b_f\"], outputs=[\"f_t_ba\"],name=\"add_2_e1\")\n", + "ql_1_e1 = make_node(\"QuantizeLinear\", inputs=[\"f_t_ba\",\"scale_3\",\"zero_point_signed\"], outputs=[\"f_t_ql1\"],name=\"ql_1_e1\")\n", + "dql_1_e1 = make_node(\"DequantizeLinear\", inputs=[\"f_t_ql1\", \"scale_4\", \"zero_point_signed\"], outputs=[\"f_t_dql1\"], name=\"dql_1_e1\")\n", + "sig_f_e1 = make_node(\"Sigmoid\", inputs=[\"f_t_dql1\"], outputs=[\"f_t\"],name=\"sig_f_e1\")\n", + "ql_2_e1 = make_node(\"QuantizeLinear\", inputs=[\"f_t\",\"scale_4\",\"zero_point_unsigned\"], outputs=[\"f_t_ql2\"],name=\"ql_2_e1\")\n", + "dql_2_e1 = make_node(\"DequantizeLinear\", inputs=[\"f_t_ql2\", \"scale_4\", \"zero_point_unsigned\"], outputs=[\"f_t_dql2\"], name=\"dql_2_e1\")\n", + "\n", + "# 2nd Equation : Input gate\n", + "matmul_1_e2 = make_node(\"MatMul\", inputs=[\"dql_wi_out\",\"dql_input_out\"], outputs=[\"out_m1_e2\"], name=\"matmul_1_e2\")\n", + "matmul_2_e2 = make_node(\"MatMul\", inputs=[\"dql_ui_out\",\"h_t-1\"], outputs=[\"out_m2_e2\"],name=\"matmul_2_e2\")\n", + "add_1_e2 = make_node(\"Add\", inputs=[\"out_m1_e2\",\"out_m2_e2\"], outputs=[\"out_add1_e2\"],name=\"add_1_e2\")\n", + "add_2_e2 = make_node(\"Add\", inputs=[\"out_add1_e2\",\"b_i\"], outputs=[\"i_t_ba\"],name=\"add_2_e2\")\n", + "ql_1_e2 = make_node(\"QuantizeLinear\", inputs=[\"i_t_ba\",\"scale_1\",\"zero_point_signed\"], outputs=[\"i_t_ql1\"],name=\"ql_1_e2\")\n", + "dql_1_e2 = make_node(\"DequantizeLinear\", inputs=[\"i_t_ql1\",\"scale_1\", \"zero_point_signed\"], outputs=[\"i_t_dql1\"], name=\"dql_1_e2\")\n", + "sig_i_e2 = make_node(\"Sigmoid\", inputs=[\"i_t_dql1\"], outputs=[\"i_t\"],name=\"sig_i_e2\")\n", + "ql_2_e2 = make_node(\"QuantizeLinear\", inputs=[\"i_t\",\"scale_2\",\"zero_point_unsigned\"], outputs=[\"i_t_ql2\"],name=\"ql_2_e2\")\n", + "dql_2_e2 = make_node(\"DequantizeLinear\", inputs=[\"i_t_ql2\", \"scale_2\", \"zero_point_unsigned\"], outputs=[\"i_t_dql2\"], name=\"dql_2_e2\")\n", + "\n", + "# 3rd Equation : Output gate\n", + "matmul_1_e3 = make_node(\"MatMul\", inputs=[\"dql_wo_out\",\"dql_input_out\"], outputs=[\"out_m1_e3\"], name=\"matmul_1_e3\")\n", + "matmul_2_e3 = make_node(\"MatMul\", inputs=[\"dql_uo_out\",\"h_t-1\"], outputs=[\"out_m2_e3\"],name=\"matmul_2_e3\")\n", + "add_1_e3 = make_node(\"Add\", inputs=[\"out_m1_e3\",\"out_m2_e3\"], outputs=[\"out_add1_e3\"],name=\"add_1_e3\")\n", + "add_2_e3 = make_node(\"Add\", inputs=[\"out_add1_e3\",\"b_o\"], outputs=[\"o_t_ba\"],name=\"add_2_e3\" )\n", + "ql_1_e3 = make_node(\"QuantizeLinear\", inputs=[\"o_t_ba\",\"scale_7\",\"zero_point_signed\"], outputs=[\"o_t_ql1\"],name=\"ql_1_e3\")\n", + "dql_1_e3 = make_node(\"DequantizeLinear\", inputs=[\"o_t_ql1\",\"scale_7\", \"zero_point_signed\"], outputs=[\"o_t_dql1\"], name=\"dql_1_e3\")\n", + "sig_o_e3 = make_node(\"Sigmoid\", inputs=[\"o_t_dql1\"], outputs=[\"o_t\"],name=\"sig_o_e3\")\n", + "ql_2_e3 = make_node(\"QuantizeLinear\", inputs=[\"o_t\",\"scale_8\",\"zero_point_unsigned\"], outputs=[\"o_t_ql2\"],name=\"ql_2_e3\")\n", + "dql_2_e3 = make_node(\"DequantizeLinear\", inputs=[\"o_t_ql2\", \"scale_8\", \"zero_point_unsigned\"], outputs=[\"o_t_dql2\"], name=\"dql_2_e3\")\n", + "\n", + "# 4th Equation : Cell gate\n", + "matmul_1_e4 = make_node(\"MatMul\", inputs=[\"dql_wc_out\",\"dql_input_out\"], outputs=[\"out_m1_e4\"], name=\"matmul_1_e4\")\n", + "matmul_2_e4 = make_node(\"MatMul\", inputs=[\"dql_uc_out\",\"h_t-1\"], outputs=[\"out_m2_e4\"],name=\"matmul_2_e4\")\n", + "add_1_e4 = make_node(\"Add\", inputs=[\"out_m1_e4\",\"out_m2_e4\"], outputs=[\"out_add1_e4\"],name=\"add_1_e4\")\n", + "add_2_e4 = make_node(\"Add\", inputs=[\"out_add1_e4\",\"b_c\"], outputs=[\"c_t_ba\"],name=\"add_2_e4\")\n", + "ql_1_e4 = make_node(\"QuantizeLinear\", inputs=[\"c_t_ba\",\"scale_5\",\"zero_point_signed\"], outputs=[\"c_t_ql1\"],name=\"ql_1_e4\")\n", + "dql_1_e4 = make_node(\"DequantizeLinear\", inputs=[\"c_t_ql1\",\"scale_5\", \"zero_point_signed\"], outputs=[\"c_t_dql1\"], name=\"dql_1_e4\")\n", + "tanh_c_e4 = make_node(\"Tanh\", inputs=[\"c_t_dql1\"], outputs=[\"c_t_partial\"],name=\"tanh_c_e4\")\n", + "ql_2_e4 = make_node(\"QuantizeLinear\", inputs=[\"c_t_partial\",\"scale_6\",\"zero_point_signed\"], outputs=[\"c_t_ql2\"],name=\"ql_2_e4\")\n", + "dql_2_e4 = make_node(\"DequantizeLinear\", inputs=[\"c_t_ql2\", \"scale_6\", \"zero_point_signed\"], outputs=[\"c_t_dql2\"], name=\"dql_2_e4\")\n", + "\n", + "# Once we have the first four gate computations we can procedd with the computation of the cell_state and the hidden_state in the 5th and the 6th equations.\n", + "# 5th Equation : Cell state compute\n", + "el_mul_1_e5 = make_node(\"Mul\", inputs=[\"f_t_dql2\",\"c_t-1\"], outputs=[\"out_el_mul1_e5\"],name=\"el_mul_1_e5\")\n", + "ql_1_e5 = make_node(\"QuantizeLinear\", inputs=[\"out_el_mul1_e5\",\"scale_9\",\"zero_point_signed\"], outputs=[\"fifth_ql1\"],name=\"ql_1_e5\")\n", + "dql_1_e5 = make_node(\"DequantizeLinear\", inputs=[\"fifth_ql1\",\"scale_9\", \"zero_point_signed\"], outputs=[\"fifth_dql1\"], name=\"dql_1_e5\")\n", + "el_mul_2_e5 = make_node(\"Mul\", inputs=[\"i_t_dql2\",\"c_t_dql2\"], outputs=[\"out_el_mul2_e5\"], name=\"el_mul_2_e5\") \n", + "ql_2_e5 = make_node(\"QuantizeLinear\", inputs=[\"out_el_mul2_e5\",\"scale_9\",\"zero_point_signed\"], outputs=[\"fifth_ql2\"],name=\"ql_2_e5\")\n", + "dql_2_e5 = make_node(\"DequantizeLinear\", inputs=[\"fifth_ql2\",\"scale_9\", \"zero_point_signed\"], outputs=[\"fifth_dql2\"], name=\"dql_2_e5\")\n", + "add_1_e5 = make_node(\"Add\", inputs=[\"fifth_dql1\",\"fifth_dql2\"], outputs=[\"c_t\"], name=\"add_1_e5\") #-----------------> The first output is computed here.\n", + "ql_3_e5 = make_node(\"QuantizeLinear\", inputs=[\"c_t\",\"scale_9\",\"zero_point_signed\"], outputs=[\"h_t_ql\"], name=\"ql_3_e5\")\n", + "dql_3_e5 = make_node(\"DequantizeLinear\", inputs=[\"h_t_ql\",\"scale_9\",\"zero_point_signed\"], outputs=[\"h_t_dql\"], name=\"dql_3_e5\")\n", + "\n", + "# 6th Equation : Hidden state compute\n", + "tanh_node_e6 = make_node(\"Tanh\", inputs=[\"h_t_dql\"], outputs=[\"out_tanh_e6\"], name=\"tanh_node_e6\") \n", + "ql_1_e6 = make_node(\"QuantizeLinear\", inputs=[\"out_tanh_e6\",\"scale_10\",\"zero_point_signed\"], outputs=[\"sixth_ql1\"], name=\"ql_1_e6\")\n", + "dql_1_e6 = make_node(\"DequantizeLinear\", inputs=[\"sixth_ql1\",\"scale_10\",\"zero_point_signed\"], outputs=[\"sixth_dql1\"], name=\"dql_1_e6\")\n", + "el_mul_1_e6 = make_node(\"Mul\", inputs=[\"sixth_dql1\",\"o_t_dql2\"], outputs=[\"h_t_inter\"], name=\"el_mul_1_e6\")#h_t_inter\n", + "ql_2_e6 = make_node(\"QuantizeLinear\", inputs=[\"h_t_inter\",\"scale_11\",\"zero_point_signed\"], outputs=[\"sixth_ql2\"], name=\"ql_2_e6\")\n", + "dql_2_e6 = make_node(\"DequantizeLinear\", inputs=[\"sixth_ql2\",\"scale_11\",\"zero_point_signed\"], outputs=[\"h_t\"], name=\"dql_2_e6\") #-----------------> The second output is computed here.\n", + "id_1_e6 = make_node(\"Identity\", inputs=[\"h_t\"], outputs=[\"h_t_concat\"], name=\"id_1_e6\") #-----------------> The third output is computed here." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After defining the above operations we now connect them and create a graph with the help of onnx.helper `make_graph` utility function" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "lstm_body = make_graph(\n", + " nodes=[\n", + " ql_input,\n", + " dql_input, \n", + " ql_w1,\n", + " clp_w1, \n", + " dql_w1,\n", + " ql_w2,\n", + " clp_w2, \n", + " dql_w2,\n", + " ql_w3,\n", + " clp_w3, \n", + " dql_w3,\n", + " ql_w4,\n", + " clp_w4, \n", + " dql_w4,\n", + " ql_u1,\n", + " clp_u1, \n", + " dql_u1,\n", + " ql_u2,\n", + " clp_u2,\n", + " dql_u2, \n", + " ql_u3,\n", + " clp_u3,\n", + " dql_u3, \n", + " ql_u4,\n", + " clp_u4,\n", + " dql_u4, \n", + " matmul_1_e1,\n", + " matmul_2_e1, \n", + " add_1_e1, \n", + " add_2_e1,\n", + " ql_1_e1,\n", + " dql_1_e1,\n", + " sig_f_e1,\n", + " ql_2_e1, \n", + " dql_2_e1, \n", + " matmul_1_e2,\n", + " matmul_2_e2, \n", + " add_1_e2, \n", + " add_2_e2,\n", + " ql_1_e2,\n", + " dql_1_e2,\n", + " sig_i_e2,\n", + " ql_2_e2, \n", + " dql_2_e2, \n", + " matmul_1_e3,\n", + " matmul_2_e3, \n", + " add_1_e3, \n", + " add_2_e3,\n", + " ql_1_e3,\n", + " dql_1_e3,\n", + " sig_o_e3,\n", + " ql_2_e3, \n", + " dql_2_e3, \n", + " matmul_1_e4,\n", + " matmul_2_e4, \n", + " add_1_e4, \n", + " add_2_e4,\n", + " ql_1_e4,\n", + " dql_1_e4,\n", + " tanh_c_e4,\n", + " ql_2_e4, \n", + " dql_2_e4, \n", + " el_mul_1_e5,\n", + " ql_1_e5, \n", + " dql_1_e5,\n", + " el_mul_2_e5,\n", + " ql_2_e5,\n", + " dql_2_e5,\n", + " add_1_e5,\n", + " ql_3_e5, \n", + " dql_3_e5,\n", + " tanh_node_e6,\n", + " ql_1_e6, \n", + " dql_1_e6,\n", + " el_mul_1_e6,\n", + " ql_2_e6,\n", + " dql_2_e6, \n", + " id_1_e6\n", + " ],\n", + " name = \"qcdq-lsmt-body\",\n", + " inputs=[hidden_state,cell_state,inputs], #The order in which the inputs are defined here should match the input order when the scan node is defined.\n", + " outputs = [out_hidden_state, out_cell_state, out_hidden_state_concat],\n", + " value_info=[\n", + " make_tensor_value_info(\"ql_input_out\",onnx.TensorProto.INT8, [num_features,1]),\n", + " make_tensor_value_info(\"dql_input_out\",onnx.TensorProto.FLOAT, [num_features,1]),\n", + " make_tensor_value_info(\"out_m1_e1\",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"out_m2_e1\",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"out_add1_e1\",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"f_t_ba\",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"f_t_ql1\",onnx.TensorProto.INT8, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"f_t_dql1\", onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"f_t_ql2\",onnx.TensorProto.UINT8, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"f_t_dql2\", onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"out_m1_e2\",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"out_m2_e2\",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"out_add1_e2\",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"i_t_ba\",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"i_t_ql1\",onnx.TensorProto.INT8, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"i_t_dql1\", onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"i_t_ql2\",onnx.TensorProto.UINT8, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"i_t_dql2\", onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"out_m1_e3\",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"out_m2_e3\",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"out_add1_e3\",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"o_t_ba\",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"o_t_ql1\",onnx.TensorProto.INT8, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"o_t_dql1\", onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"o_t_ql2\",onnx.TensorProto.UINT8, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"o_t_dql2\", onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"out_m1_e4\",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"out_m2_e4\",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"out_add1_e4\",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"c_t_ba\",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"c_t_ql1\",onnx.TensorProto.INT8, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"c_t_dql1\", onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"c_t_ql2\",onnx.TensorProto.INT8, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"c_t_dql2\", onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"f_t\",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"i_t\",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"o_t\",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"c_t_partial\",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"out_el_mul1_e5\",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"out_el_mul2_e5\",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"fifth_ql1\",onnx.TensorProto.INT8, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"fifth_dql1\", onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"fifth_ql2\",onnx.TensorProto.INT8, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"fifth_dql2\", onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"h_t_ql\",onnx.TensorProto.INT8, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"h_t_dql\", onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"out_tanh_e6\",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"sixth_ql1\",onnx.TensorProto.INT8, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"sixth_dql1\", onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"sixth_ql2\",onnx.TensorProto.INT8, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"h_t_inter\", onnx.TensorProto.FLOAT, [num_hidden_cells,1]),\n", + " make_tensor_value_info(\"ql_wf_out\", onnx.TensorProto.INT8, [num_hidden_cells,num_features]),\n", + " make_tensor_value_info(\"dql_wf_out\",onnx.TensorProto.FLOAT, [num_hidden_cells,num_features]),\n", + " make_tensor_value_info(\"ql_wi_out\", onnx.TensorProto.INT8, [num_hidden_cells,num_features]),\n", + " make_tensor_value_info(\"dql_wi_out\",onnx.TensorProto.FLOAT, [num_hidden_cells,num_features]),\n", + " make_tensor_value_info(\"ql_wc_out\", onnx.TensorProto.INT8, [num_hidden_cells,num_features]),\n", + " make_tensor_value_info(\"dql_wc_out\",onnx.TensorProto.FLOAT, [num_hidden_cells,num_features]),\n", + " make_tensor_value_info(\"ql_wo_out\", onnx.TensorProto.INT8, [num_hidden_cells,num_features]),\n", + " make_tensor_value_info(\"dql_wo_out\",onnx.TensorProto.FLOAT, [num_hidden_cells,num_features]),\n", + " make_tensor_value_info(\"ql_uf_out\",onnx.TensorProto.INT8, [num_hidden_cells,num_hidden_cells]),\n", + " make_tensor_value_info(\"dql_uf_out\",onnx.TensorProto.FLOAT, [num_hidden_cells,num_hidden_cells]),\n", + " make_tensor_value_info(\"ql_ui_out\",onnx.TensorProto.INT8, [num_hidden_cells,num_hidden_cells]),\n", + " make_tensor_value_info(\"dql_ui_out\",onnx.TensorProto.FLOAT, [num_hidden_cells,num_hidden_cells]),\n", + " make_tensor_value_info(\"ql_uc_out\",onnx.TensorProto.INT8, [num_hidden_cells,num_hidden_cells]),\n", + " make_tensor_value_info(\"dql_uc_out\",onnx.TensorProto.FLOAT, [num_hidden_cells,num_hidden_cells]),\n", + " make_tensor_value_info(\"ql_uo_out\",onnx.TensorProto.INT8, [num_hidden_cells,num_hidden_cells]),\n", + " make_tensor_value_info(\"dql_uo_out\",onnx.TensorProto.FLOAT, [num_hidden_cells,num_hidden_cells]),\n", + " make_tensor_value_info(\"clp_wf\",onnx.TensorProto.INT8, [num_hidden_cells,num_features]),\n", + " make_tensor_value_info(\"clp_wi\",onnx.TensorProto.INT8, [num_hidden_cells,num_features]),\n", + " make_tensor_value_info(\"clp_wc\",onnx.TensorProto.INT8, [num_hidden_cells,num_features]),\n", + " make_tensor_value_info(\"clp_wo\",onnx.TensorProto.INT8, [num_hidden_cells,num_features]),\n", + " make_tensor_value_info(\"clp_uf\",onnx.TensorProto.INT8, [num_hidden_cells,num_hidden_cells]), \n", + " make_tensor_value_info(\"clp_ui\",onnx.TensorProto.INT8, [num_hidden_cells,num_hidden_cells]),\n", + " make_tensor_value_info(\"clp_uc\",onnx.TensorProto.INT8, [num_hidden_cells,num_hidden_cells]),\n", + " make_tensor_value_info(\"clp_uo\",onnx.TensorProto.INT8, [num_hidden_cells,num_hidden_cells]),\n", + " ],\n", + " initializer=[\n", + " # Initializing the weight and recurrecne matrices\n", + " make_tensor('W_f',onnx.TensorProto.FLOAT, [num_hidden_cells,num_features], (Wf_val)),\n", + " make_tensor('U_f',onnx.TensorProto.FLOAT, [num_hidden_cells,num_hidden_cells], (Uf_val)),\n", + " make_tensor('b_f',onnx.TensorProto.FLOAT, [num_hidden_cells,1], (bf_val)),\n", + " make_tensor('W_i',onnx.TensorProto.FLOAT, [num_hidden_cells,num_features], (Wi_val)),\n", + " make_tensor('U_i',onnx.TensorProto.FLOAT, [num_hidden_cells,num_hidden_cells], (Ui_val)),\n", + " make_tensor('b_i',onnx.TensorProto.FLOAT, [num_hidden_cells,1], (bi_val)),\n", + " make_tensor('W_o',onnx.TensorProto.FLOAT, [num_hidden_cells,num_features], (Wo_val)),\n", + " make_tensor('U_o',onnx.TensorProto.FLOAT, [num_hidden_cells,num_hidden_cells], (Uo_val)),\n", + " make_tensor('b_o',onnx.TensorProto.FLOAT, [num_hidden_cells,1], (bo_val)),\n", + " make_tensor('W_c',onnx.TensorProto.FLOAT, [num_hidden_cells,num_features], (Wc_val)),\n", + " make_tensor('U_c',onnx.TensorProto.FLOAT, [num_hidden_cells,num_hidden_cells], (Uc_val)),\n", + " make_tensor('b_c',onnx.TensorProto.FLOAT, [num_hidden_cells,1], (bc_val)),\n", + " # Input scale value\n", + " make_tensor('inp_scale',onnx.TensorProto.FLOAT, [],[inp_scale_val]),\n", + " # Scale weight values\n", + " make_tensor('scale_i',onnx.TensorProto.FLOAT, [],[w1_scale_val]),\n", + " make_tensor('scale_c',onnx.TensorProto.FLOAT, [],[w2_scale_val]),\n", + " make_tensor('scale_o',onnx.TensorProto.FLOAT, [],[w3_scale_val]),\n", + " make_tensor('scale_f',onnx.TensorProto.FLOAT, [],[w4_scale_val]),\n", + " # Scale values for the six equations\n", + " make_tensor('scale_1',onnx.TensorProto.FLOAT, [],[eq_scale_val_1]),\n", + " make_tensor('scale_2',onnx.TensorProto.FLOAT, [],[eq_scale_val_1]), \n", + " make_tensor('scale_3',onnx.TensorProto.FLOAT, [],[eq_scale_val_1]),\n", + " make_tensor('scale_test',onnx.TensorProto.FLOAT, [],[eq_scale_val_1]),\n", + " make_tensor('scale_4',onnx.TensorProto.FLOAT, [],[eq_scale_val_1]),\n", + " make_tensor('scale_5',onnx.TensorProto.FLOAT, [],[eq_scale_val_1]),\n", + " make_tensor('scale_6',onnx.TensorProto.FLOAT, [],[eq_scale_val_1]),\n", + " make_tensor('scale_7',onnx.TensorProto.FLOAT, [],[eq_scale_val_2]), \n", + " make_tensor('scale_8',onnx.TensorProto.FLOAT, [],[eq_scale_val_2]),\n", + " make_tensor('scale_9',onnx.TensorProto.FLOAT, [],[eq_scale_val_1]),\n", + " make_tensor('scale_10',onnx.TensorProto.FLOAT, [],[eq_scale_val_2]),\n", + " make_tensor('scale_11',onnx.TensorProto.FLOAT, [],[eq_scale_val_1]),\n", + " # Scales for zero-points : Zero-point datatype defines the dataype of the output for that quantization\n", + " make_tensor('zero_point_signed',onnx.TensorProto.INT8,[],[zero_point_signed_val]),\n", + " make_tensor('zero_point_unsigned',onnx.TensorProto.UINT8,[],[zero_point_unsigned_val]),\n", + " # Introducing scalars for the clip operators.\n", + " make_tensor('min', onnx.TensorProto.INT8, [], [min_clip_val]),\n", + " make_tensor('max', onnx.TensorProto.INT8, [], [max_clip_val]),\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The above created graph can now be converted into a qonnx model with the `qonnx_make_model` utility. We save the model with `onnx.save` utility and then view it in Netron with the help of `showInNetron` utility. \n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "lstm_model = qonnx_make_model(lstm_body, producer_name=\"QuantizeLSTM_scan\")\n", + "onnx.save(lstm_model, './lstm_full_graph.onnx')\n", + "netron.start('./lstm_full_graph.onnx')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this block of code we execute the onnx graph to check that it can execute without any errors. We perform it's functional verification in the later part of the notebook." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Before the model can be executed, it'd opset version needs to be set to a minimum of '14' to accomodate clip nodes with INT8 and UINT8 input. Otherwise ONNX cannot create an execution session and we get errors.\n", + "lstm_model.opset_import[0].version = 14\n", + "\n", + "# Creating the inference session here for the updated model here\n", + "sess = rt.InferenceSession(lstm_model.SerializeToString())\n", + "\n", + "# Defining dummy inputs and the model parameters for dummy execution\n", + "X_inp = np.empty([num_features,1],dtype=np.float32).reshape([num_features,1])\n", + "X_inp.fill(0.8)\n", + "hidden_state_input = np.zeros((num_hidden_cells, 1)).astype(np.float32)\n", + "cell_state_input = np.zeros((num_hidden_cells, 1)).astype(np.float32)\n", + "\n", + "# Assigning the above defined values to the input dictionary of the ONNX model.\n", + "input_dict = {}\n", + "input_dict[\"inp\"] = X_inp\n", + "input_dict[\"h_t-1\"] = hidden_state_input\n", + "input_dict[\"c_t-1\"] = cell_state_input \n", + "\n", + "# Setting up the inference session and executing the onnx model here.\n", + "sess = rt.InferenceSession(lstm_model.SerializeToString())\n", + "output = sess.run(None, input_dict)\n", + "print(output)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# SCAN Operation Integration" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Introduction to ONNX Scan operation\n", + "Observations regarding the `Scan` operator in ONNX:\n", + "\n", + "1. `Scan` can be used to iterate over one or more scan input tensors constructing zero or more scan output tensors. It combines ideas from general recurrences, functional programming cnostructs such as scan, fold, map and zip.\n", + "2. The attribute `body` in the node must be a graph specifying the computation to be performed in every iteration.\n", + "3. Input is the current values of the `state variables` and the current `iterated element` of the scan input. Returns values of the `state variables` and the `scan output element tensors`. (Can be greater than 1)\n", + "4. The values of the scan output tensors are concatenated over all the iterations to produce the scan output values of the scan construct.\n", + "5. The properties that make a scan node unique and different from a normal compute node are:\n", + "* Allows update of state variable after each input computation; to be used in the processing of the next input.\n", + "* It needs to scan your inputs row by row or column by column; then keep computing the output with the updated hidden state for every input; while storing all the intermediate outputs in the form of hidden states.\n", + "\n", + "More information regarding this op can be found in these links:\n", + "\n", + "* https://github.com/onnx/onnx/blob/main/docs/Operators.md#Scan\n", + "* https://onnx.ai/onnx/intro/python.html#scan" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `Scan` operation is essentially a container operator which will consume the LSTM graph that we created above in it's body.\n", + "To create it, we need to define separate input and output value info tensors just for the Scan operator. We will then follow the same steps as the `QCDQ-LSTM` graph creation to convert the above graph into an executable ONNX model.\n", + "

\n", + "We start by defining the input and output value info tensors for the `scan_graph` creation. These tensors act as the wrapper to the previously defined graph.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Inputs\n", + "scan_input = make_tensor_value_info(\"scan_input\",onnx.TensorProto.FLOAT, [None,num_features,1])#X ; scan input. Here None defines the varibale number of inputs that can be supplied for input processing.\n", + "scan_hidden_state = make_tensor_value_info(\"scan_hidden_state\",onnx.TensorProto.FLOAT, [num_hidden_cells,1])# h_t-1\n", + "scan_cell_state = make_tensor_value_info(\"scan_cell_state\",onnx.TensorProto.FLOAT, [num_hidden_cells,1])# c_t-1\n", + "\n", + "# Outputs\n", + "scan_out_hidden_state = make_tensor_value_info(\"scan_out_hidden_state\", onnx.TensorProto.FLOAT, [num_hidden_cells,1])#h_t\n", + "scan_out_cell_state = make_tensor_value_info(\"scan_out_cell_state\", onnx.TensorProto.FLOAT, [num_hidden_cells,1])#c_t\n", + "scan_out_hidden_state_concat = make_tensor_value_info(\"scan_out_hidden_state_concat\", onnx.TensorProto.FLOAT, [None,num_hidden_cells,1])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We will now create the scan operator here now utilizing the `make_node` utility from ONNX.\n", + "Note, in the body of the operation we have included the `lstm_body` graph we created in the above steps." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "scan_node_lstm = make_node(\n", + " \"Scan\", \n", + " inputs=[\"scan_hidden_state\",\"scan_cell_state\",\"scan_input\"], \n", + " outputs=[\"scan_out_hidden_state\",\"scan_out_cell_state\",\"scan_out_hidden_state_concat\"], \n", + " num_scan_inputs=1,\n", + " body=lstm_body, domain=''\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can now define the graph for the scan operator utilizing the `make_graph` utility." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "scan_lstm_node_graph = make_graph(\n", + " nodes = [scan_node_lstm],\n", + " name=\"lstm-scan-node\",\n", + " inputs=[scan_hidden_state,scan_cell_state,scan_input],#h_t-1, c_t-1, X\n", + " outputs=[scan_out_hidden_state,scan_out_cell_state,scan_out_hidden_state_concat]#h_t,c_t,h_t_concat\n", + ")\n", + "\n", + "# Creating the model from the above created graph and saving it.\n", + "lstm_scan_node_model = qonnx_make_model(scan_lstm_node_graph, producer_name=\"scan-lstm\")\n", + "onnx.save(lstm_scan_node_model, './lstm_scan_node_model.onnx')\n", + "netron.start('./lstm_scan_node_model.onnx')\n", + "\n", + "#Checking the model for any errors\n", + "onnx.checker.check_model(lstm_scan_node_model)\n", + "print(lstm_scan_node_model.graph.value_info)\n", + "\n", + "#Conversion to version 14 of onnx to accomodate clip nodes as done for the LSTM graph also.\n", + "lstm_scan_node_model.opset_import[0].version = 14" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now that we have the SCAN based quantized LSTM model ready, we can now go forward and test it with the same sets of inputs we used for the testing of the brevitas model.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Defining the values of the varibales to test the execution of the scan model\n", + "num_inputs = 25\n", + "\n", + "#Initializing the initial values of the hidden state and the cell state. \n", + "# Also assigning the same input as the one used for the brevitas execution.\n", + "\n", + "hidden_state_inp = np.zeros((num_hidden_cells, 1)).astype(np.float32)#'h_t-1'\n", + "cell_state_inp = np.zeros((num_hidden_cells, 1)).astype(np.float32)#'c_t-1'\n", + "scan_inp = np.empty([num_inputs,num_features,1],dtype=np.float32).reshape([num_inputs,num_features,1])\n", + "scan_inp.fill(0.8)\n", + "\n", + "# Assigning the defined input values to the input dictionary of the scan model\n", + "input_dict = {}\n", + "input_dict[\"scan_hidden_state\"] = hidden_state_inp\n", + "input_dict[\"scan_cell_state\"] = cell_state_inp\n", + "input_dict[\"scan_input\"] = scan_inp\n", + "\n", + "# We can now set up the inference session and execute the scan onnx model here. \n", + "# The execution session gives some warnings which can be ignored.\n", + "\n", + "sess = rt.InferenceSession(lstm_scan_node_model.SerializeToString())\n", + "scan_output = sess.run(None, input_dict)\n", + "print('Final Hidden State',scan_output[0])\n", + "print(\"------------------------\")\n", + "print('Final Cell State',scan_output[1])\n", + "print(\"------------------------\")\n", + "print('All Hidden States',scan_output[2])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Functional Verification" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In the final part of the notebook, we compare the output of the 8-bit quantized `(QCDQ)-LSTM` implementation with the `QuantLSTM` brevitas model.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# We first match the shape of both the outputs to perform the functional verification correctly\n", + "\n", + "print('Brevitas Output shape : ', brevitas_output.shape)\n", + "all_hidden_states = np.array(scan_output[2])\n", + "all_hidden_states = all_hidden_states.reshape([num_inputs,1,num_hidden_cells])\n", + "print('SCAN-QCDQ-LSTM output shape :', all_hidden_states.shape)\n", + "print('-----------------------------------')\n", + "print('Brevitas Output = ',brevitas_output)\n", + "print('-----------------------------------')\n", + "print('SCAN-QCDQ-LSTM output',all_hidden_states)\n", + "print('-----------------------------------')\n", + "\n", + "# Comparison between the 'Scan-LSTM output' and the brevitas 'QuantLSTM' ouptut\n", + "# Since the outputs from both models are floating-point, to get a better understanding of the differences we scale the outputs to INT8 precision and then compare their differences.\n", + "# The scale used to do that is the last scale of the LSTM graph.\n", + "\n", + "scale = inp_scale_val #The scale value is equal to the value of the inp_scale_val\n", + "all_hidden_states = np.array(scan_output[2])\n", + "all_hidden_states = all_hidden_states.reshape([num_inputs,1,num_hidden_cells])\n", + "all_hidden_state_diff = (all_hidden_states - brevitas_output)\n", + "print(all_hidden_state_diff/scale)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note the difference in outputs increases as we progress with processing the inputs. The first two outputs are very close to one another, but as we get the outputs for more inputs we see for some values differ from the brevitas output by a considerable amount.\n", + "This behaviour can be attributed to some values being slightly different in the first few outputs (which are not visible) which eventually cause an increase in differences between both values as more inputs are processed." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "venv" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/4_quant_lstm_helper/function.py b/notebooks/4_quant_lstm_helper/function.py new file mode 100644 index 00000000..6ba2e9dd --- /dev/null +++ b/notebooks/4_quant_lstm_helper/function.py @@ -0,0 +1,340 @@ +# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +import torch +from torch.autograd import Function + +from brevitas.export.onnx import onnx_export_opset + +AXIS_OPSET = 13 +DOMAIN_STRING = "onnx.brevitas" + + +class DequantizeLinearFn(Function): + + @staticmethod + def symbolic(g, x, input_scale, input_zero_point, input_axis): + opset_version = onnx_export_opset() + + if input_axis is not None and opset_version < AXIS_OPSET: + raise RuntimeError('ONNX Opset 13 is required for per-channel quantization') + elif input_axis is not None and opset_version >= AXIS_OPSET: + ret = g.op('DequantizeLinear', x, input_scale, input_zero_point, axis_i=input_axis) + else: + ret = g.op('DequantizeLinear', x, input_scale, input_zero_point) + return ret + + @staticmethod + def forward(ctx, int_x, input_scale, input_zero_point, input_axis): + return int_x.float() + + +class IntClipFn(Function): + + @staticmethod + def symbolic(g, int_x, min_int_val, max_int_val): + ret = g.op('Clip', int_x, min_int_val, max_int_val) + return ret + + @staticmethod + def forward(ctx, int_x, min_int_val, max_int_val): + return int_x + + +class QuantizeLinearFn(Function): + + @staticmethod + def symbolic(g, x, output_scale, ouput_zero_point, output_dtype, output_axis): + opset_version = onnx_export_opset() + + if output_axis is not None and opset_version < AXIS_OPSET: + raise RuntimeError('ONNX Opset 13 is required for per-channel quantization') + elif output_axis is not None and opset_version >= AXIS_OPSET: + ret = g.op('QuantizeLinear', x, output_scale, ouput_zero_point, axis_i=output_axis) + else: + ret = g.op('QuantizeLinear', x, output_scale, ouput_zero_point) + return ret + + @staticmethod + def forward(ctx, x, output_scale, ouput_zero_point, output_dtype, output_axis): + return x.type(output_dtype) + +class BrevitasQuantLSTMCellFn(Function): + + + @staticmethod + def symbolic( + g, # args and kwargs passed from _QuantLSTMLayer + quant_input, + quant_hidden_state, + quant_cell_state, + quant_weight_ii, + quant_weight_if, + quant_weight_ic, + quant_weight_io, + quant_weight_hi, + quant_weight_hf, + quant_weight_hc, + quant_weight_ho, + quant_bias_input, + quant_bias_forget, + quant_bias_cell, + quant_bias_output, # Symbolic kwargs passed from BrevitasQuantLSTMLayerHandler + batch_first, + reverse_input, + cifg, # Output quant + output_scale, + output_zero_point, + output_bit_width, + output_narrow_range, + output_signed, + output_rounding_mode, # Cell state quant + cell_state_scale, + cell_state_zero_point, + cell_state_bit_width, + cell_state_narrow_range, + cell_state_signed, + cell_state_rounding_mode, # Input gate accumulator quant + input_acc_scale, + input_acc_zero_point, + input_acc_bit_width, + input_acc_narrow_range, + input_acc_signed, + input_acc_rounding_mode, # Forget gate accumulator quant + forget_acc_scale, + forget_acc_zero_point, + forget_acc_bit_width, + forget_acc_narrow_range, + forget_acc_signed, + forget_acc_rounding_mode, # Cell gate accumulator quant + cell_acc_scale, + cell_acc_zero_point, + cell_acc_bit_width, + cell_acc_narrow_range, + cell_acc_signed, + cell_acc_rounding_mode, # Output gate accumulator quant + output_acc_scale, + output_acc_zero_point, + output_acc_bit_width, + output_acc_narrow_range, + output_acc_signed, + output_acc_rounding_mode, # Input gate sigmoid quant + input_sigmoid_scale, + input_sigmoid_zero_point, + input_sigmoid_bit_width, + input_sigmoid_narrow_range, + input_sigmoid_signed, + input_sigmoid_rounding_mode, # Forget gate sigmoid quant + forget_sigmoid_scale, + forget_sigmoid_zero_point, + forget_sigmoid_bit_width, + forget_sigmoid_narrow_range, + forget_sigmoid_signed, + forget_sigmoid_rounding_mode, # Cell gate tanh quant + cell_tanh_scale, + cell_tanh_zero_point, + cell_tanh_bit_width, + cell_tanh_narrow_range, + cell_tanh_signed, + cell_tanh_rounding_mode, # Output gate sigmoid quant + output_sigmoid_scale, + output_sigmoid_zero_point, + output_sigmoid_bit_width, + output_sigmoid_narrow_range, + output_sigmoid_signed, + output_sigmoid_rounding_mode, # Hidden state tanh quant + hidden_state_tanh_scale, + hidden_state_tanh_zero_point, + hidden_state_tanh_bit_width, + hidden_state_tanh_narrow_range, + hidden_state_tanh_signed, + hidden_state_tanh_rounding_mode): + return g.op( + f'{DOMAIN_STRING}::QuantLSTMCell', # Tensors + ## Input values + quant_input, + quant_hidden_state, + quant_cell_state, + quant_weight_ii, + quant_weight_if, + quant_weight_ic, + quant_weight_io, + quant_weight_hi, + quant_weight_hf, + quant_weight_hc, + quant_weight_ho, + quant_bias_input, + quant_bias_forget, + quant_bias_cell, + quant_bias_output, ## Output quant + output_scale, + output_zero_point, + output_bit_width, ## Cell state quant + cell_state_scale, + cell_state_zero_point, + cell_state_bit_width, ## Input gate accumulator quant + input_acc_scale, + input_acc_zero_point, + input_acc_bit_width, ## Forget gate accumulator quant + forget_acc_scale, + forget_acc_zero_point, + forget_acc_bit_width, ## Cell gate accumulator quant + cell_acc_scale, + cell_acc_zero_point, + cell_acc_bit_width, ## Output gate accumulator quant + output_acc_scale, + output_acc_zero_point, + output_acc_bit_width, ## Input gate sigmoid quant + input_sigmoid_scale, + input_sigmoid_zero_point, + input_sigmoid_bit_width, ## Forget gate sigmoid quant + forget_sigmoid_scale, + forget_sigmoid_zero_point, + forget_sigmoid_bit_width, ## Cell gate tanh quant + cell_tanh_scale, + cell_tanh_zero_point, + cell_tanh_bit_width, ## Output gate sigmoid quant + output_sigmoid_scale, + output_sigmoid_zero_point, + output_sigmoid_bit_width, ## Hidden state tanh quant + hidden_state_tanh_scale, + hidden_state_tanh_zero_point, + hidden_state_tanh_bit_width, + # Attributes + batch_first_i=batch_first, + reverse_input_i=reverse_input, + cifg_i=cifg, + output_narrow_i=output_narrow_range, + output_signed_i=output_signed, + output_rounding_mode_s=output_rounding_mode, + cell_state_narrow_i=cell_state_narrow_range, + cell_state_signed_i=cell_state_signed, + cell_state_rounding_mode_s=cell_state_rounding_mode, + input_acc_narrow_i=input_acc_narrow_range, + input_acc_signed_i=input_acc_signed, + input_acc_rounding_mode_s=input_acc_rounding_mode, + forget_acc_narrow_i=forget_acc_narrow_range, + forget_acc_signed_i=forget_acc_signed, + forget_acc_rounding_mode_s=forget_acc_rounding_mode, + cell_acc_narrow_i=cell_acc_narrow_range, + cell_acc_signed_i=cell_acc_signed, + cell_acc_rounding_mode_s=cell_acc_rounding_mode, + output_acc_narrow_i=output_acc_narrow_range, + output_acc_signed_i=output_acc_signed, + output_acc_rounding_mode_s=output_acc_rounding_mode, + input_sigmoid_narrow_i=input_sigmoid_narrow_range, + input_sigmoid_signed_i=input_sigmoid_signed, + input_sigmoid_rounding_mode_s=input_sigmoid_rounding_mode, + forget_sigmoid_narrow_i=forget_sigmoid_narrow_range, + forget_sigmoid_signed_i=forget_sigmoid_signed, + forget_sigmoid_rounding_mode_s=forget_sigmoid_rounding_mode, + cell_tanh_narrow_i=cell_tanh_narrow_range, + cell_tanh_signed_i=cell_tanh_signed, + cell_tanh_rounding_mode_s=cell_tanh_rounding_mode, + output_sigmoid_narrow_range_i=output_sigmoid_narrow_range, + output_sigmoid_signed_i=output_sigmoid_signed, + output_sigmoid_rounding_mode_s=output_sigmoid_rounding_mode, + hidden_state_tanh_narrow_i=hidden_state_tanh_narrow_range, + hidden_state_tanh_signed_i=hidden_state_tanh_signed, + hidden_state_tanh_rounding_mode_s=hidden_state_tanh_rounding_mode, + # PyTorch requires to specify the number of outputs manually + outputs=3) + + + @staticmethod + def forward( + ctx, # args and kwargs passed from _QuantLSTMLayer + quant_input, + quant_hidden_state, + quant_cell_state, + quant_weight_ii, + quant_weight_if, + quant_weight_ic, + quant_weight_io, + quant_weight_hi, + quant_weight_hf, + quant_weight_hc, + quant_weight_ho, + quant_bias_input, + quant_bias_forget, + quant_bias_cell, + quant_bias_output, # Symbolic kwargs passed from BrevitasQuantLSTMLayerHandler + batch_first, + reverse_input, + cifg, # Output quant + output_scale, + output_zero_point, + output_bit_width, + output_narrow_range, + output_signed, + output_rounding_mode, # Cell state quant + cell_state_scale, + cell_state_zero_point, + cell_state_bit_width, + cell_state_narrow_range, + cell_state_signed, + cell_state_rounding_mode, # Input gate accumulator quant + input_acc_scale, + input_acc_zero_point, + input_acc_bit_width, + input_acc_narrow_range, + input_acc_signed, + input_acc_rounding_mode, # Forget gate accumulator quant + forget_acc_scale, + forget_acc_zero_point, + forget_acc_bit_width, + forget_acc_narrow_range, + forget_acc_signed, + forget_acc_rounding_mode, # Cell gate accumulator quant + cell_acc_scale, + cell_acc_zero_point, + cell_acc_bit_width, + cell_acc_narrow_range, + cell_acc_signed, + cell_acc_rounding_mode, # Output gate accumulator quant + output_acc_scale, + output_acc_zero_point, + output_acc_bit_width, + output_acc_narrow_range, + output_acc_signed, + output_acc_rounding_mode, # Input gate sigmoid quant + input_sigmoid_scale, + input_sigmoid_zero_point, + input_sigmoid_bit_width, + input_sigmoid_narrow_range, + input_sigmoid_signed, + input_sigmoid_rounding_mode, # Forget gate sigmoid quant + forget_sigmoid_scale, + forget_sigmoid_zero_point, + forget_sigmoid_bit_width, + forget_sigmoid_narrow_range, + forget_sigmoid_signed, + forget_sigmoid_rounding_mode, # Cell gate tanh quant + cell_tanh_scale, + cell_tanh_zero_point, + cell_tanh_bit_width, + cell_tanh_narrow_range, + cell_tanh_signed, + cell_tanh_rounding_mode, # Output gate sigmoid quant + output_sigmoid_scale, + output_sigmoid_zero_point, + output_sigmoid_bit_width, + output_sigmoid_narrow_range, + output_sigmoid_signed, + output_sigmoid_rounding_mode, # Hidden state tanh quant + hidden_state_tanh_scale, + hidden_state_tanh_zero_point, + hidden_state_tanh_bit_width, + hidden_state_tanh_narrow_range, + hidden_state_tanh_signed, + hidden_state_tanh_rounding_mode): + # Tp simplify things, here we are returning the outputs + # as if they were already concatenated. Scale/zp/bw are avoided too. + # This preserves output shapes but not values. + # See _QuantLSTMCell for the actual implementation. + quant_outputs = torch.zeros( + quant_input.size(0), + quant_input.size(1), + quant_hidden_state.size(1), + device=quant_hidden_state.device) + return quant_outputs, quant_hidden_state, quant_cell_state diff --git a/notebooks/4_quant_lstm_helper/handler.py b/notebooks/4_quant_lstm_helper/handler.py new file mode 100644 index 00000000..948eb647 --- /dev/null +++ b/notebooks/4_quant_lstm_helper/handler.py @@ -0,0 +1,140 @@ +# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +from abc import ABC +from copy import copy + +import torch +from torch import Tensor + +from brevitas.export.common.handler.base import QuantAxisMixin +from brevitas.export.common.handler.qcdq import DQMixin +from brevitas.export.common.handler.qcdq import QCDQActQuantProxyHandlerMixin +from brevitas.export.common.handler.qcdq import QCDQBiasQuantProxyHandlerMixin +from brevitas.export.common.handler.qcdq import QCDQDecoupledWeightQuantProxyHandlerMixin +from brevitas.export.common.handler.qcdq import QCDQMixin +from brevitas.export.common.handler.qcdq import QCDQTruncQuantProxyHandlerMixin +from brevitas.export.common.handler.qcdq import QCDQWeightQuantProxyHandlerMixin +from brevitas.export.common.handler.qcdq import ZeroPointHandlerMixin +from brevitas.export.onnx.handler import ONNXBaseHandler +from brevitas.export.onnx.handler import QuantLSTMLayerHandler + +from ..function import DequantizeLinearFn +from ..function import IntClipFn +from ..function import QuantizeLinearFn +from ..function import BrevitasQuantLSTMCellFn + + +class StdDQONNXMixin(DQMixin, ABC): + + def dequantize_fn(self, x, scale, zero_point, axis): + return DequantizeLinearFn.apply(x, scale, zero_point, axis) + + @property + def flatten_dequantize_params(self): + return True + + @property + def itemize_quantize_scalar_params(self): + return False + + +class StdQCDQONNXMixin(QCDQMixin, StdDQONNXMixin, ABC): + + @property + def clip_over_integers(self): + return True + + @classmethod + def int8_dtype(cls): + return torch.int8 + + @classmethod + def uint8_dtype(cls): + return torch.uint8 + + @classmethod + def int32_dtype(cls): + return torch.int32 + + def validate(self, module): + self.validate_8b_bit_width(module.bit_width(), le_then=True) + assert module.bit_width() > 1., 'Binary quant not supported' + assert module.rounding_mode.upper() == 'ROUND', 'Only round to nearest even supported' + + def quantize_fn(self, x, scale, zero_point, dtype, axis): + return QuantizeLinearFn.apply(x, scale, zero_point, dtype, axis) + + def clip_fn(self, x, min_val, max_val): + return IntClipFn.apply(x, min_val, max_val) + + +class StdQCDQONNXWeightQuantProxyHandler(StdQCDQONNXMixin, + QCDQWeightQuantProxyHandlerMixin, + ONNXBaseHandler): + pass + + +class StdQCDQONNXDecoupledWeightQuantProxyHandler(StdQCDQONNXMixin, + QCDQDecoupledWeightQuantProxyHandlerMixin, + ONNXBaseHandler): + pass + + +class StdQCDQONNXActQuantProxyHandler(StdQCDQONNXMixin, + QCDQActQuantProxyHandlerMixin, + ONNXBaseHandler): + pass + + +class StdQCDQONNXBiasQuantProxyHandler(StdDQONNXMixin, + QCDQBiasQuantProxyHandlerMixin, + ONNXBaseHandler): + pass + + +class StdQCDQONNXTruncQuantProxyHandler(StdQCDQONNXMixin, + QCDQTruncQuantProxyHandlerMixin, + ONNXBaseHandler): + pass + + +class StdQCDQONNXQuantLSTMLayerHandler(QuantLSTMLayerHandler): + + def quantized_cell_symbolic_execution( + self, + quant_input, + quant_hidden_state, + quant_cell_state, + quant_weight_ii, + quant_weight_if, + quant_weight_ic, + quant_weight_io, + quant_weight_hi, + quant_weight_hf, + quant_weight_hc, + quant_weight_ho, + quant_bias_input, + quant_bias_forget, + quant_bias_cell, + quant_bias_output): + return BrevitasQuantLSTMCellFn.apply( + quant_input, + quant_hidden_state, + quant_cell_state, + quant_weight_ii, + quant_weight_if, + quant_weight_ic, + quant_weight_io, + quant_weight_hi, + quant_weight_hf, + quant_weight_hc, + quant_weight_ho, + quant_bias_input, + quant_bias_forget, + quant_bias_cell, + quant_bias_output, + *self.symbolic_kwargs.values()) + # raise RuntimeError( + # "Quantized LSTM cell is not supported for ONNX QCDQ " + # "(weights only quantization is). Use export_qonnx.")