From e1da07b801ba524b9b51d48b45deb1fec29ef45b Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 28 Mar 2024 13:29:21 +0100 Subject: [PATCH] Feat!: remove quant metadata from quantlayer (#883) Breaking change: The interface to access quant metadata has changed and now everything is directly delegated to the underlying proxies. --- ...1_quant_tensor_quant_conv2d_overview.ipynb | 50 ++--- notebooks/03_anatomy_of_a_quantizer.ipynb | 16 +- notebooks/Brevitas_TVMCon2021.ipynb | 14 +- src/brevitas/export/common/handler/base.py | 21 -- src/brevitas/export/manager.py | 39 ++-- src/brevitas/export/onnx/manager.py | 8 +- src/brevitas/graph/gpfq.py | 4 +- src/brevitas/graph/gptq.py | 2 +- src/brevitas/graph/quantize.py | 4 +- src/brevitas/graph/quantize_impl.py | 4 +- src/brevitas/nn/mixin/act.py | 77 -------- src/brevitas/nn/mixin/base.py | 96 +--------- src/brevitas/nn/mixin/parameter.py | 125 +----------- src/brevitas/nn/quant_layer.py | 179 ++---------------- src/brevitas/proxy/parameter_quant.py | 73 +++++-- src/brevitas/proxy/runtime_quant.py | 76 +++++--- src/brevitas/utils/quant_utils.py | 27 +++ .../super_resolution/utils/evaluate.py | 4 +- tests/brevitas/graph/test_calibration.py | 2 +- tests/brevitas/nn/test_linear.py | 2 +- tests/brevitas/nn/test_wbiol.py | 40 ++-- tests/brevitas/proxy/test_act_scaling.py | 2 +- tests/brevitas/proxy/test_proxy.py | 4 +- tests/brevitas/proxy/test_weight_scaling.py | 10 +- 24 files changed, 241 insertions(+), 638 deletions(-) diff --git a/notebooks/01_quant_tensor_quant_conv2d_overview.ipynb b/notebooks/01_quant_tensor_quant_conv2d_overview.ipynb index 47a4ed48b..d2982f14b 100644 --- a/notebooks/01_quant_tensor_quant_conv2d_overview.ipynb +++ b/notebooks/01_quant_tensor_quant_conv2d_overview.ipynb @@ -142,10 +142,10 @@ } ], "source": [ - "print(f'Is weight quant enabled: {default_quant_conv.is_weight_quant_enabled}')\n", - "print(f'Is bias quant enabled: {default_quant_conv.is_bias_quant_enabled}')\n", - "print(f'Is input quant enabled: {default_quant_conv.is_input_quant_enabled}')\n", - "print(f'Is output quant enabled: {default_quant_conv.is_output_quant_enabled}')" + "print(f'Is weight quant enabled: {default_quant_conv.weight_quant.is_quant_enabled}')\n", + "print(f'Is bias quant enabled: {default_quant_conv.bias_quant.is_quant_enabled}')\n", + "print(f'Is input quant enabled: {default_quant_conv.input_quant.is_quant_enabled}')\n", + "print(f'Is output quant enabled: {default_quant_conv.output_quant.is_quant_enabled}')" ] }, { @@ -157,19 +157,11 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 12, "metadata": { "scrolled": true }, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/scratch/fabian/miniforge3/envs/torchgpu/lib/python3.11/site-packages/torch/_tensor.py:1362: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. 
(Triggered internally at /opt/conda/conda-bld/pytorch_1699449183005/work/c10/core/TensorImpl.h:1900.)\n", - " return super().rename(names)\n" - ] - }, { "data": { "text/plain": [ @@ -327,9 +319,9 @@ } ], "source": [ - "int_weight = default_quant_conv.int_weight()\n", - "zero_point = default_quant_conv.quant_weight_zero_point()\n", - "scale = default_quant_conv.quant_weight_scale()\n", + "int_weight = default_quant_conv.quant_weight().int()\n", + "zero_point = default_quant_conv.weight_quant.zero_point()\n", + "scale = default_quant_conv.weight_quant.scale()\n", "quant_weight_manually = (int_weight - zero_point) * scale\n", "\n", "assert_with_message(default_quant_conv.quant_weight().value.isclose(quant_weight_manually).all().item())" @@ -878,11 +870,11 @@ "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[23], line 6\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mbrevitas\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mquant\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mscaled_int\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m Int8Bias\n\u001b[1;32m 3\u001b[0m bias_quant_conv \u001b[38;5;241m=\u001b[39m QuantConv2d(\n\u001b[1;32m 4\u001b[0m in_channels\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m, out_channels\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m3\u001b[39m, kernel_size\u001b[38;5;241m=\u001b[39m(\u001b[38;5;241m3\u001b[39m,\u001b[38;5;241m3\u001b[39m), bias\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 5\u001b[0m bias_quant\u001b[38;5;241m=\u001b[39mInt8Bias, return_quant_tensor\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[0;32m----> 6\u001b[0m bias_quant_conv(torch\u001b[38;5;241m.\u001b[39mrandn(\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m2\u001b[39m, \u001b[38;5;241m5\u001b[39m, \u001b[38;5;241m5\u001b[39m))\n", - "File \u001b[0;32m/scratch/fabian/miniforge3/envs/torchgpu/lib/python3.11/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n", - "File \u001b[0;32m/scratch/fabian/miniforge3/envs/torchgpu/lib/python3.11/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", - "File \u001b[0;32m/scratch/fabian/brevitas/src/brevitas/nn/quant_conv.py:198\u001b[0m, in \u001b[0;36mQuantConv2d.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Union[Tensor, QuantTensor]) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Union[Tensor, QuantTensor]:\n\u001b[0;32m--> 198\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mforward_impl(\u001b[38;5;28minput\u001b[39m)\n", - "File \u001b[0;32m/scratch/fabian/brevitas/src/brevitas/nn/quant_layer.py:318\u001b[0m, in \u001b[0;36mQuantWeightBiasInputOutputLayer.forward_impl\u001b[0;34m(self, inp)\u001b[0m\n\u001b[1;32m 314\u001b[0m compute_output_quant_tensor \u001b[38;5;241m=\u001b[39m \u001b[38;5;28misinstance\u001b[39m(quant_input, QuantTensor) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\n\u001b[1;32m 315\u001b[0m quant_weight, QuantTensor)\n\u001b[1;32m 316\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (compute_output_quant_tensor \u001b[38;5;129;01mor\u001b[39;00m\n\u001b[1;32m 317\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mis_output_quant_enabled) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mreturn_quant_tensor:\n\u001b[0;32m--> 318\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mQuantLayer is not correctly configured\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 320\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(quant_input, QuantTensor) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(quant_weight, QuantTensor):\n\u001b[1;32m 321\u001b[0m output_bit_width \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmax_acc_bit_width(quant_input\u001b[38;5;241m.\u001b[39mbit_width, quant_weight\u001b[38;5;241m.\u001b[39mbit_width)\n", + "Cell \u001b[0;32mIn[41], line 6\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mbrevitas\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mquant\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mscaled_int\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m Int8Bias\n\u001b[1;32m 3\u001b[0m bias_quant_conv \u001b[38;5;241m=\u001b[39m QuantConv2d(\n\u001b[1;32m 4\u001b[0m 
in_channels\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m, out_channels\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m3\u001b[39m, kernel_size\u001b[38;5;241m=\u001b[39m(\u001b[38;5;241m3\u001b[39m,\u001b[38;5;241m3\u001b[39m), bias\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 5\u001b[0m bias_quant\u001b[38;5;241m=\u001b[39mInt8Bias, return_quant_tensor\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[0;32m----> 6\u001b[0m \u001b[43mbias_quant_conv\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrandn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m5\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m5\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1518\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1519\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 
1523\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_conv.py:198\u001b[0m, in \u001b[0;36mQuantConv2d.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Union[Tensor, QuantTensor]) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Union[Tensor, QuantTensor]:\n\u001b[0;32m--> 198\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_layer.py:173\u001b[0m, in \u001b[0;36mQuantWeightBiasInputOutputLayer.forward_impl\u001b[0;34m(self, inp)\u001b[0m\n\u001b[1;32m 169\u001b[0m compute_output_quant_tensor \u001b[38;5;241m=\u001b[39m \u001b[38;5;28misinstance\u001b[39m(quant_input, QuantTensor) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\n\u001b[1;32m 170\u001b[0m quant_weight, QuantTensor)\n\u001b[1;32m 171\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (compute_output_quant_tensor \u001b[38;5;129;01mor\u001b[39;00m\n\u001b[1;32m 172\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mis_output_quant_enabled) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mreturn_quant_tensor:\n\u001b[0;32m--> 173\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mQuantLayer is not correctly configured\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 175\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(quant_input, QuantTensor) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(quant_weight, QuantTensor):\n\u001b[1;32m 176\u001b[0m output_bit_width \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmax_acc_bit_width(quant_input\u001b[38;5;241m.\u001b[39mbit_width, quant_weight\u001b[38;5;241m.\u001b[39mbit_width)\n", "\u001b[0;31mRuntimeError\u001b[0m: QuantLayer is not correctly configured" ] } @@ -1029,14 +1021,14 @@ "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[27], line 4\u001b[0m\n\u001b[1;32m 1\u001b[0m output_bias_quant_conv \u001b[38;5;241m=\u001b[39m QuantConv2d(\n\u001b[1;32m 2\u001b[0m in_channels\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m, out_channels\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m3\u001b[39m, kernel_size\u001b[38;5;241m=\u001b[39m(\u001b[38;5;241m3\u001b[39m,\u001b[38;5;241m3\u001b[39m), bias\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 3\u001b[0m output_quant\u001b[38;5;241m=\u001b[39mInt8ActPerTensorFloat, bias_quant\u001b[38;5;241m=\u001b[39mInt8Bias, return_quant_tensor\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[0;32m----> 4\u001b[0m output_bias_quant_conv(torch\u001b[38;5;241m.\u001b[39mrandn(\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m2\u001b[39m, \u001b[38;5;241m5\u001b[39m, \u001b[38;5;241m5\u001b[39m))\n", - "File 
\u001b[0;32m/scratch/fabian/miniforge3/envs/torchgpu/lib/python3.11/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n", - "File \u001b[0;32m/scratch/fabian/miniforge3/envs/torchgpu/lib/python3.11/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", - "File \u001b[0;32m/scratch/fabian/brevitas/src/brevitas/nn/quant_conv.py:198\u001b[0m, in \u001b[0;36mQuantConv2d.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Union[Tensor, QuantTensor]) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Union[Tensor, QuantTensor]:\n\u001b[0;32m--> 198\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mforward_impl(\u001b[38;5;28minput\u001b[39m)\n", - "File \u001b[0;32m/scratch/fabian/brevitas/src/brevitas/nn/quant_layer.py:326\u001b[0m, in \u001b[0;36mQuantWeightBiasInputOutputLayer.forward_impl\u001b[0;34m(self, inp)\u001b[0m\n\u001b[1;32m 323\u001b[0m output_signed \u001b[38;5;241m=\u001b[39m quant_input\u001b[38;5;241m.\u001b[39msigned \u001b[38;5;129;01mor\u001b[39;00m quant_weight\u001b[38;5;241m.\u001b[39msigned\n\u001b[1;32m 325\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbias \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m 
\u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 326\u001b[0m quant_bias \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbias_quant(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbias, output_scale)\n\u001b[1;32m 327\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcache_inference_quant_bias \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(quant_bias,\n\u001b[1;32m 328\u001b[0m QuantTensor):\n\u001b[1;32m 330\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_cached_bias \u001b[38;5;241m=\u001b[39m _CachedIO(quant_bias\u001b[38;5;241m.\u001b[39mdetach(), metadata_only\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n", - "File \u001b[0;32m/scratch/fabian/miniforge3/envs/torchgpu/lib/python3.11/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n", - "File \u001b[0;32m/scratch/fabian/miniforge3/envs/torchgpu/lib/python3.11/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", - "File \u001b[0;32m/scratch/fabian/brevitas/src/brevitas/proxy/parameter_quant.py:193\u001b[0m, in \u001b[0;36mBiasQuantProxyFromInjector.forward\u001b[0;34m(self, x, input_scale)\u001b[0m\n\u001b[1;32m 191\u001b[0m impl \u001b[38;5;241m=\u001b[39m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mexport_handler \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mexport_mode \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtensor_quant\n\u001b[1;32m 192\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrequires_input_scale \u001b[38;5;129;01mand\u001b[39;00m input_scale \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 193\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mInput scale required\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 195\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrequires_input_scale:\n\u001b[1;32m 196\u001b[0m input_scale \u001b[38;5;241m=\u001b[39m input_scale\u001b[38;5;241m.\u001b[39mview(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n", + "Cell \u001b[0;32mIn[45], line 4\u001b[0m\n\u001b[1;32m 1\u001b[0m output_bias_quant_conv \u001b[38;5;241m=\u001b[39m QuantConv2d(\n\u001b[1;32m 2\u001b[0m in_channels\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m, out_channels\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m3\u001b[39m, kernel_size\u001b[38;5;241m=\u001b[39m(\u001b[38;5;241m3\u001b[39m,\u001b[38;5;241m3\u001b[39m), bias\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 3\u001b[0m output_quant\u001b[38;5;241m=\u001b[39mInt8ActPerTensorFloat, bias_quant\u001b[38;5;241m=\u001b[39mInt8Bias, return_quant_tensor\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[0;32m----> 4\u001b[0m \u001b[43moutput_bias_quant_conv\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrandn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m5\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m5\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1516\u001b[0m 
\u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1518\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1519\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1523\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_conv.py:198\u001b[0m, in \u001b[0;36mQuantConv2d.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Union[Tensor, QuantTensor]) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Union[Tensor, QuantTensor]:\n\u001b[0;32m--> 198\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_layer.py:181\u001b[0m, in \u001b[0;36mQuantWeightBiasInputOutputLayer.forward_impl\u001b[0;34m(self, inp)\u001b[0m\n\u001b[1;32m 178\u001b[0m output_signed \u001b[38;5;241m=\u001b[39m quant_input\u001b[38;5;241m.\u001b[39msigned \u001b[38;5;129;01mor\u001b[39;00m quant_weight\u001b[38;5;241m.\u001b[39msigned\n\u001b[1;32m 180\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbias \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 181\u001b[0m quant_bias \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias_quant\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moutput_scale\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 182\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcache_inference_quant_bias \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(quant_bias,\n\u001b[1;32m 183\u001b[0m QuantTensor):\n\u001b[1;32m 184\u001b[0m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbias_quant\u001b[38;5;241m.\u001b[39m_cached_bias \u001b[38;5;241m=\u001b[39m _CachedIO(quant_bias\u001b[38;5;241m.\u001b[39mdetach(), metadata_only\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n", + "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1518\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1519\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1523\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/proxy/parameter_quant.py:240\u001b[0m, in \u001b[0;36mBiasQuantProxyFromInjector.forward\u001b[0;34m(self, x, input_scale)\u001b[0m\n\u001b[1;32m 238\u001b[0m impl \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mexport_handler \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mexport_mode \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtensor_quant\n\u001b[1;32m 239\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrequires_input_scale \u001b[38;5;129;01mand\u001b[39;00m input_scale \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 240\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mInput scale required\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 242\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrequires_input_scale:\n\u001b[1;32m 243\u001b[0m input_scale \u001b[38;5;241m=\u001b[39m input_scale\u001b[38;5;241m.\u001b[39mview(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n", "\u001b[0;31mRuntimeError\u001b[0m: Input scale required" ] } diff --git a/notebooks/03_anatomy_of_a_quantizer.ipynb b/notebooks/03_anatomy_of_a_quantizer.ipynb index 8a919dbe8..c4c9295e7 100644 --- a/notebooks/03_anatomy_of_a_quantizer.ipynb +++ b/notebooks/03_anatomy_of_a_quantizer.ipynb @@ -775,7 +775,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Indeed we can verify that `quant_weight_scale()` is equal to `weight.abs().max()`:" + "Indeed we can verify that `weight_quant.scale()` is equal to `weight.abs().max()`:" ] }, { @@ -792,7 +792,7 @@ } ], "source": [ - "assert_with_message((param_from_max_quant_conv.quant_weight_scale() == param_from_max_quant_conv.weight.abs().max()).item())" + "assert_with_message((param_from_max_quant_conv.weight_quant.scale() == param_from_max_quant_conv.weight.abs().max()).item())" ] }, { @@ -1024,7 +1024,7 @@ } ], "source": [ - "assert_with_message((quant_conv1.quant_weight_scale() == quant_conv2.quant_weight_scale()).item())" + "assert_with_message((quant_conv1.weight_quant.scale() == quant_conv2.weight_quant.scale()).item())" ] }, { @@ -1059,9 +1059,9 @@ " return module.weight.abs().mean()\n", " \n", "quant_conv1 = QuantConv2d(3, 2, (3, 3), weight_quant=SharedParamFromMeanWeightQuantizer)\n", - "old_quant_conv1_scale = quant_conv1.quant_weight_scale()\n", + "old_quant_conv1_scale = quant_conv1.weight_quant.scale()\n", "quant_conv2 = QuantConv2d(3, 2, (3, 3), weight_quant=quant_conv1.weight_quant)\n", - "new_quant_conv1_scale = quant_conv1.quant_weight_scale()\n", + "new_quant_conv1_scale = quant_conv1.weight_quant.scale()\n", "\n", "assert_with_message(not (old_quant_conv1_scale == new_quant_conv1_scale).item())" ] @@ -1080,7 +1080,7 @@ } ], "source": [ - "assert_with_message((new_quant_conv1_scale == quant_conv2.quant_weight_scale()).item())" + "assert_with_message((new_quant_conv1_scale == quant_conv2.weight_quant.scale()).item())" ] }, { @@ -1134,7 +1134,7 @@ "quant_conv_w_init = QuantConv2d(3, 2, (3, 3), weight_quant=ParamFromMaxWeightQuantizer)\n", "torch.nn.init.uniform_(quant_conv_w_init.weight)\n", "\n", - "assert_with_message(not (quant_conv_w_init.weight.abs().max() == quant_conv_w_init.quant_weight_scale()).item())" + "assert_with_message(not (quant_conv_w_init.weight.abs().max() == quant_conv_w_init.weight_quant.scale()).item())" ] }, { @@ -1160,7 +1160,7 @@ "source": [ "quant_conv_w_init.weight_quant.init_tensor_quant()\n", "\n", - "assert_with_message((quant_conv_w_init.weight.abs().max() == quant_conv_w_init.quant_weight_scale()).item())" + "assert_with_message((quant_conv_w_init.weight.abs().max() == quant_conv_w_init.weight_quant.scale()).item())" ] }, { diff --git a/notebooks/Brevitas_TVMCon2021.ipynb b/notebooks/Brevitas_TVMCon2021.ipynb index a6f24ea72..fc11303aa 100644 --- a/notebooks/Brevitas_TVMCon2021.ipynb 
+++ b/notebooks/Brevitas_TVMCon2021.ipynb @@ -631,11 +631,11 @@ "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[13], line 8\u001b[0m\n\u001b[1;32m 5\u001b[0m float_input \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mrandn(\u001b[38;5;241m3\u001b[39m, \u001b[38;5;241m2\u001b[39m)\n\u001b[1;32m 6\u001b[0m quant_linear \u001b[38;5;241m=\u001b[39m QuantLinear(\u001b[38;5;241m2\u001b[39m, \u001b[38;5;241m4\u001b[39m, bias\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, bias_quant\u001b[38;5;241m=\u001b[39mInt16Bias, return_quant_tensor\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[0;32m----> 8\u001b[0m quant_output \u001b[38;5;241m=\u001b[39m quant_linear(float_input)\n", - "File \u001b[0;32m/scratch/fabian/miniforge3/envs/torchgpu/lib/python3.11/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n", - "File \u001b[0;32m/scratch/fabian/miniforge3/envs/torchgpu/lib/python3.11/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", - "File \u001b[0;32m/scratch/fabian/brevitas/src/brevitas/nn/quant_linear.py:66\u001b[0m, in \u001b[0;36mQuantLinear.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 65\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, 
\u001b[38;5;28minput\u001b[39m: Union[Tensor, QuantTensor]) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Union[Tensor, QuantTensor]:\n\u001b[0;32m---> 66\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mforward_impl(\u001b[38;5;28minput\u001b[39m)\n", - "File \u001b[0;32m/scratch/fabian/brevitas/src/brevitas/nn/quant_layer.py:318\u001b[0m, in \u001b[0;36mQuantWeightBiasInputOutputLayer.forward_impl\u001b[0;34m(self, inp)\u001b[0m\n\u001b[1;32m 314\u001b[0m compute_output_quant_tensor \u001b[38;5;241m=\u001b[39m \u001b[38;5;28misinstance\u001b[39m(quant_input, QuantTensor) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\n\u001b[1;32m 315\u001b[0m quant_weight, QuantTensor)\n\u001b[1;32m 316\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (compute_output_quant_tensor \u001b[38;5;129;01mor\u001b[39;00m\n\u001b[1;32m 317\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mis_output_quant_enabled) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mreturn_quant_tensor:\n\u001b[0;32m--> 318\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mQuantLayer is not correctly configured\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 320\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(quant_input, QuantTensor) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(quant_weight, QuantTensor):\n\u001b[1;32m 321\u001b[0m output_bit_width \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmax_acc_bit_width(quant_input\u001b[38;5;241m.\u001b[39mbit_width, quant_weight\u001b[38;5;241m.\u001b[39mbit_width)\n", + "Cell \u001b[0;32mIn[13], line 8\u001b[0m\n\u001b[1;32m 5\u001b[0m float_input \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mrandn(\u001b[38;5;241m3\u001b[39m, \u001b[38;5;241m2\u001b[39m)\n\u001b[1;32m 6\u001b[0m quant_linear \u001b[38;5;241m=\u001b[39m QuantLinear(\u001b[38;5;241m2\u001b[39m, \u001b[38;5;241m4\u001b[39m, bias\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, bias_quant\u001b[38;5;241m=\u001b[39mInt16Bias, return_quant_tensor\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[0;32m----> 8\u001b[0m quant_output \u001b[38;5;241m=\u001b[39m \u001b[43mquant_linear\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfloat_input\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File 
\u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1518\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1519\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1523\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_linear.py:66\u001b[0m, in \u001b[0;36mQuantLinear.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 65\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Union[Tensor, QuantTensor]) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Union[Tensor, QuantTensor]:\n\u001b[0;32m---> 66\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_layer.py:173\u001b[0m, in \u001b[0;36mQuantWeightBiasInputOutputLayer.forward_impl\u001b[0;34m(self, inp)\u001b[0m\n\u001b[1;32m 169\u001b[0m compute_output_quant_tensor \u001b[38;5;241m=\u001b[39m \u001b[38;5;28misinstance\u001b[39m(quant_input, QuantTensor) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\n\u001b[1;32m 170\u001b[0m quant_weight, QuantTensor)\n\u001b[1;32m 171\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (compute_output_quant_tensor \u001b[38;5;129;01mor\u001b[39;00m\n\u001b[1;32m 172\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mis_output_quant_enabled) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mreturn_quant_tensor:\n\u001b[0;32m--> 173\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mQuantLayer is not correctly configured\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 175\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m 
\u001b[38;5;28misinstance\u001b[39m(quant_input, QuantTensor) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(quant_weight, QuantTensor):\n\u001b[1;32m 176\u001b[0m output_bit_width \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmax_acc_bit_width(quant_input\u001b[38;5;241m.\u001b[39mbit_width, quant_weight\u001b[38;5;241m.\u001b[39mbit_width)\n", "\u001b[0;31mRuntimeError\u001b[0m: QuantLayer is not correctly configured" ] } @@ -1432,8 +1432,8 @@ "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[29], line 12\u001b[0m\n\u001b[1;32m 5\u001b[0m float_linear \u001b[38;5;241m=\u001b[39m nn\u001b[38;5;241m.\u001b[39mLinear(\u001b[38;5;241m2\u001b[39m, \u001b[38;5;241m4\u001b[39m, bias\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m 6\u001b[0m quant_linear \u001b[38;5;241m=\u001b[39m QuantLinear(\n\u001b[1;32m 7\u001b[0m \u001b[38;5;241m2\u001b[39m, \u001b[38;5;241m4\u001b[39m, \n\u001b[1;32m 8\u001b[0m input_quant\u001b[38;5;241m=\u001b[39mLearnedIntActPerTensorFloat,\n\u001b[1;32m 9\u001b[0m weight_quant\u001b[38;5;241m=\u001b[39mLearnedIntWeightPerChannelFloat, \n\u001b[1;32m 10\u001b[0m return_quant_tensor\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, bias\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[0;32m---> 12\u001b[0m quant_linear\u001b[38;5;241m.\u001b[39mload_state_dict(float_linear\u001b[38;5;241m.\u001b[39mstate_dict())\n", - "File \u001b[0;32m/scratch/fabian/miniforge3/envs/torchgpu/lib/python3.11/site-packages/torch/nn/modules/module.py:2152\u001b[0m, in \u001b[0;36mModule.load_state_dict\u001b[0;34m(self, state_dict, strict, assign)\u001b[0m\n\u001b[1;32m 2147\u001b[0m error_msgs\u001b[38;5;241m.\u001b[39minsert(\n\u001b[1;32m 2148\u001b[0m \u001b[38;5;241m0\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mMissing key(s) in state_dict: \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m. 
\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mformat(\n\u001b[1;32m 2149\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mk\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m k \u001b[38;5;129;01min\u001b[39;00m missing_keys)))\n\u001b[1;32m 2151\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(error_msgs) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[0;32m-> 2152\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mError(s) in loading state_dict for \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m:\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\t\u001b[39;00m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mformat(\n\u001b[1;32m 2153\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\t\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(error_msgs)))\n\u001b[1;32m 2154\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _IncompatibleKeys(missing_keys, unexpected_keys)\n", + "Cell \u001b[0;32mIn[29], line 12\u001b[0m\n\u001b[1;32m 5\u001b[0m float_linear \u001b[38;5;241m=\u001b[39m nn\u001b[38;5;241m.\u001b[39mLinear(\u001b[38;5;241m2\u001b[39m, \u001b[38;5;241m4\u001b[39m, bias\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m 6\u001b[0m quant_linear \u001b[38;5;241m=\u001b[39m QuantLinear(\n\u001b[1;32m 7\u001b[0m \u001b[38;5;241m2\u001b[39m, \u001b[38;5;241m4\u001b[39m, \n\u001b[1;32m 8\u001b[0m input_quant\u001b[38;5;241m=\u001b[39mLearnedIntActPerTensorFloat,\n\u001b[1;32m 9\u001b[0m weight_quant\u001b[38;5;241m=\u001b[39mLearnedIntWeightPerChannelFloat, \n\u001b[1;32m 10\u001b[0m return_quant_tensor\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, bias\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[0;32m---> 12\u001b[0m \u001b[43mquant_linear\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload_state_dict\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfloat_linear\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstate_dict\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/module.py:2153\u001b[0m, in \u001b[0;36mModule.load_state_dict\u001b[0;34m(self, state_dict, strict, assign)\u001b[0m\n\u001b[1;32m 2148\u001b[0m error_msgs\u001b[38;5;241m.\u001b[39minsert(\n\u001b[1;32m 2149\u001b[0m \u001b[38;5;241m0\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mMissing key(s) in state_dict: \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m. 
\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mformat(\n\u001b[1;32m 2150\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mk\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m k \u001b[38;5;129;01min\u001b[39;00m missing_keys)))\n\u001b[1;32m 2152\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(error_msgs) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[0;32m-> 2153\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mError(s) in loading state_dict for \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m:\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\t\u001b[39;00m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mformat(\n\u001b[1;32m 2154\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\t\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(error_msgs)))\n\u001b[1;32m 2155\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _IncompatibleKeys(missing_keys, unexpected_keys)\n", "\u001b[0;31mRuntimeError\u001b[0m: Error(s) in loading state_dict for QuantLinear:\n\tMissing key(s) in state_dict: \"input_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value\", \"input_quant.fused_activation_quant_proxy.tensor_quant.msb_clamp_bit_width_impl.bit_width_offset\", \"weight_quant.tensor_quant.scaling_impl.value\", \"weight_quant.tensor_quant.msb_clamp_bit_width_impl.bit_width_offset\". 
" ] } diff --git a/src/brevitas/export/common/handler/base.py b/src/brevitas/export/common/handler/base.py index f7706f450..6136a4cdc 100644 --- a/src/brevitas/export/common/handler/base.py +++ b/src/brevitas/export/common/handler/base.py @@ -127,24 +127,3 @@ def zero_point_with_dtype(cls, signed, bit_width, zero_point): return zero_point.type(torch.int8) else: return zero_point.type(torch.int32) - - @classmethod - def quant_input_zero_point(cls, module): - signed = module.is_quant_input_signed - zero_point = module.quant_input_zero_point() - bit_width = module.quant_input_bit_width() - return cls.zero_point_with_dtype(signed, bit_width, zero_point) - - @classmethod - def quant_weight_zero_point(cls, module): - signed = module.is_quant_weight_signed - zero_point = module.quant_weight_zero_point() - bit_width = module.quant_weight_bit_width() - return cls.zero_point_with_dtype(signed, bit_width, zero_point) - - @classmethod - def quant_output_zero_point(cls, module): - signed = module.is_quant_output_signed - zero_point = module.quant_output_zero_point() - bit_width = module.quant_output_bit_width() - return cls.zero_point_with_dtype(signed, bit_width, zero_point) diff --git a/src/brevitas/export/manager.py b/src/brevitas/export/manager.py index 8ad069642..2805c6174 100644 --- a/src/brevitas/export/manager.py +++ b/src/brevitas/export/manager.py @@ -14,13 +14,13 @@ from torch.nn import Module from brevitas import config -from brevitas.nn.mixin.base import _CachedIO from brevitas.nn.mixin.base import QuantLayerMixin from brevitas.nn.mixin.base import QuantRecurrentLayerMixin from brevitas.proxy.quant_proxy import QuantProxyProtocol from brevitas.quant_tensor import QuantTensor from brevitas.utils.jit_utils import clear_class_registry from brevitas.utils.python_utils import patch +from brevitas.utils.quant_utils import _CachedIO class _JitTraceExportWrapper(nn.Module): @@ -64,18 +64,11 @@ def _override_bias_caching_mode(m: Module, enabled: bool): m.cache_inference_quant_bias = enabled -def _override_inp_caching_mode(m: Module, enabled: bool): - if hasattr(m, 'cache_inference_quant_inp'): - if not hasattr(m, "cache_inference_quant_inp_backup"): - m.cache_inference_quant_inp_backup = m.cache_inference_quant_inp - m.cache_inference_quant_inp = enabled - - -def _override_out_caching_mode(m: Module, enabled: bool): - if hasattr(m, 'cache_inference_quant_out'): - if not hasattr(m, "cache_inference_quant_out_backup"): - m.cache_inference_quant_out_backup = m.cache_inference_quant_out - m.cache_inference_quant_out = enabled +def _override_act_caching_mode(m: Module, enabled: bool): + if hasattr(m, 'cache_inference_quant_act'): + if not hasattr(m, "cache_inference_quant_act_backup"): + m.cache_inference_quant_act_backup = m.cache_inference_quant_act + m.cache_inference_quant_act = enabled def _restore_quant_metadata_caching_mode(m: Module): @@ -90,16 +83,10 @@ def _restore_bias_caching_mode(m: Module): del m.cache_inference_quant_bias_backup -def _restore_inp_caching_mode(m: Module): - if hasattr(m, "cache_inference_quant_inp_backup"): - m.cache_inference_quant_inp = m.cache_inference_quant_inp_backup - del m.cache_inference_quant_inp_backup - - -def _restore_out_caching_mode(m: Module): - if hasattr(m, "cache_inference_quant_out_backup"): - m.cache_inference_quant_out = m.cache_inference_quant_out_backup - del m.cache_inference_quant_out_backup +def _restore_act_caching_mode(m: Module): + if hasattr(m, "cache_inference_quant_act_backup"): + m.cache_inference_quant_act = 
m.cache_inference_quant_act_backup + del m.cache_inference_quant_act_backup def _set_recurrent_layer_export_mode(model: Module, enabled: bool): @@ -202,14 +189,12 @@ def _cache_inp_out(cls, module, *args, **kwargs): # force enable caching module.apply(lambda m: _override_quant_metadata_caching_mode(m, enabled=True)) module.apply(lambda m: _override_bias_caching_mode(m, enabled=True)) - module.apply(lambda m: _override_inp_caching_mode(m, enabled=True)) - module.apply(lambda m: _override_out_caching_mode(m, enabled=True)) + module.apply(lambda m: _override_act_caching_mode(m, enabled=True)) _ = module.forward(*args, **kwargs) # Restore previous caching properties module.apply(lambda m: _restore_quant_metadata_caching_mode(m)) module.apply(lambda m: _restore_bias_caching_mode(m)) - module.apply(lambda m: _restore_inp_caching_mode(m)) - module.apply(lambda m: _restore_out_caching_mode(m)) + module.apply(lambda m: _restore_act_caching_mode(m)) @classmethod def jit_inference_trace( diff --git a/src/brevitas/export/onnx/manager.py b/src/brevitas/export/onnx/manager.py index 56d5bf753..1bacb461e 100644 --- a/src/brevitas/export/onnx/manager.py +++ b/src/brevitas/export/onnx/manager.py @@ -24,8 +24,8 @@ from brevitas import torch_version from brevitas.quant_tensor import QuantTensor -from ..manager import _override_inp_caching_mode -from ..manager import _restore_inp_caching_mode +from ..manager import _override_act_caching_mode +from ..manager import _restore_act_caching_mode from ..manager import BaseManager from ..manager import ExportContext @@ -120,7 +120,7 @@ def export_onnx( # enable export mode, this triggers collecting export values into handlers cls.set_export_mode(module, enabled=True) # temporarily disable input caching to avoid collectives empty debug values - module.apply(lambda m: _override_inp_caching_mode(m, enabled=False)) + module.apply(lambda m: _override_act_caching_mode(m, enabled=False)) # perform export pass if export_path is not None: export_target = export_path @@ -130,7 +130,7 @@ def export_onnx( torch.onnx.export(module, args, export_target, **onnx_export_kwargs) # restore the model to previous properties - module.apply(lambda m: _restore_inp_caching_mode(m)) + module.apply(lambda m: _restore_act_caching_mode(m)) cls.set_export_mode(module, enabled=False) module.train(training_state) diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py index 73a08727b..8ba6d2912 100644 --- a/src/brevitas/graph/gpfq.py +++ b/src/brevitas/graph/gpfq.py @@ -236,7 +236,7 @@ def update_batch(self, module, input, current_layer): raise StopFwdException def single_layer_update(self): - assert not self.layer.weight_quant_requires_quant_input, "Error: GPFQ does not support weight quantizers that require quantized inputs." + assert not self.layer.weight_quant.requires_quant_input, "Error: GPFQ does not support weight quantizers that require quantized inputs." 
weight = self.layer.weight.data dev = weight.device dtype = weight.dtype @@ -360,7 +360,7 @@ def single_layer_update(self): input_is_signed = self.quant_input.signed T = get_upper_bound_on_l1_norm( torch.tensor(self.accumulator_bit_width), input_bit_width, input_is_signed) - s = self.layer.quant_weight_scale() + s = self.layer.weight_quant.scale() if s.ndim > 1: s = s.view(self.groups, -1) # [Groups, OC/Groups] diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gptq.py index 158e93628..31d31433b 100644 --- a/src/brevitas/graph/gptq.py +++ b/src/brevitas/graph/gptq.py @@ -194,7 +194,7 @@ def update_batch(self, module, input, current_layer): raise StopFwdException def single_layer_update(self, percdamp=.01): - assert not self.layer.weight_quant_requires_quant_input, "Error: GPTQ does not support weight quantizers that require quantized inputs." + assert not self.layer.weight_quant.requires_quant_input, "Error: GPTQ does not support weight quantizers that require quantized inputs." if hasattr(self.layer, 'allocate_params'): self.layer.allocate_params(self.layer) weight = self.layer.weight.data diff --git a/src/brevitas/graph/quantize.py b/src/brevitas/graph/quantize.py index b1b94b5da..c45e833d5 100644 --- a/src/brevitas/graph/quantize.py +++ b/src/brevitas/graph/quantize.py @@ -237,10 +237,10 @@ def align_input_quant( # If it is a QuantIdentity already, simply modify tensor_quant or the scaling implementations # based on whether we need to align the sign or not if isinstance(module, qnn.QuantIdentity): - if align_sign or module.is_quant_act_signed == shared_quant_identity.is_quant_act_signed: + if align_sign or module.input_quant.is_signed == shared_quant_identity.input_quant.is_signed: return shared_quant_identity else: - assert not module.is_quant_act_signed and shared_quant_identity.is_quant_act_signed + assert not module.input_quant.is_signed and shared_quant_identity.input_quant.is_signed quant_module_class, quant_module_kwargs = quant_identity_map['unsigned'] return ( quant_module_class, diff --git a/src/brevitas/graph/quantize_impl.py b/src/brevitas/graph/quantize_impl.py index e76ba2dea..42696efac 100644 --- a/src/brevitas/graph/quantize_impl.py +++ b/src/brevitas/graph/quantize_impl.py @@ -80,8 +80,8 @@ def are_inputs_unsigned(model, node, is_unsigned_list, quant_act_map, unsigned_a elif isinstance(inp_module, tuple(SIGN_PRESERVING_MODULES)): are_inputs_unsigned( model, inp_node, is_unsigned_list, quant_act_map, unsigned_act_tuple) - elif hasattr(inp_module, 'is_quant_act_signed'): - is_unsigned_list.append(not inp_module.is_quant_act_signed) + elif hasattr(inp_module, 'input_quant'): + is_unsigned_list.append(not inp_module.input_quant.is_signed) else: is_unsigned_list.append(False) elif inp_node.op == 'call_function': diff --git a/src/brevitas/nn/mixin/act.py b/src/brevitas/nn/mixin/act.py index d00b6f35e..eaa94c929 100644 --- a/src/brevitas/nn/mixin/act.py +++ b/src/brevitas/nn/mixin/act.py @@ -4,13 +4,11 @@ from abc import ABCMeta from abc import abstractmethod from typing import Optional, Type, Union -from warnings import warn from torch.nn import Module from brevitas.inject import ExtendedInjector from brevitas.inject import Injector -from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector from brevitas.proxy.runtime_quant import ActQuantProxyProtocol from brevitas.quant import NoneActQuant @@ -34,31 +32,6 @@ def __init__(self, input_quant: Optional[ActQuantType], **kwargs): input_passthrough_act=True, **kwargs) - @property - def 
is_input_quant_enabled(self): - return self.input_quant.is_quant_enabled - - @property - def is_quant_input_narrow_range(self): # TODO make abstract once narrow range can be cached - return self.input_quant.is_narrow_range - - @property - @abstractmethod - def is_quant_input_signed(self): - pass - - @abstractmethod - def quant_input_scale(self): - pass - - @abstractmethod - def quant_input_zero_point(self): - pass - - @abstractmethod - def quant_input_bit_width(self): - pass - class QuantOutputMixin(QuantProxyMixin): __metaclass__ = ABCMeta @@ -75,31 +48,6 @@ def __init__(self, output_quant: Optional[ActQuantType], **kwargs): output_passthrough_act=True, **kwargs) - @property - def is_output_quant_enabled(self): - return self.output_quant.is_quant_enabled - - @property - def is_quant_output_narrow_range(self): # TODO make abstract once narrow range can be cached - return self.output_quant.is_narrow_range - - @property - @abstractmethod - def is_quant_output_signed(self): - pass - - @abstractmethod - def quant_output_scale(self): - pass - - @abstractmethod - def quant_output_zero_point(self): - pass - - @abstractmethod - def quant_output_bit_width(self): - pass - class QuantNonLinearActMixin(QuantProxyMixin): __metaclass__ = ABCMeta @@ -124,28 +72,3 @@ def __init__( none_quant_injector=NoneActQuant, **prefixed_kwargs, **kwargs) - - @property - def is_act_quant_enabled(self): - return self.act_quant.is_quant_enabled - - @property - def is_quant_act_narrow_range(self): # TODO make abstract once narrow range can be cached - return self.act_quant.is_narrow_range - - @property - @abstractmethod - def is_quant_act_signed(self): - pass - - @abstractmethod - def quant_act_scale(self): - pass - - @abstractmethod - def quant_act_zero_point(self): - pass - - @abstractmethod - def quant_act_bit_width(self): - pass diff --git a/src/brevitas/nn/mixin/base.py b/src/brevitas/nn/mixin/base.py index 55a5c8150..8327bc156 100644 --- a/src/brevitas/nn/mixin/base.py +++ b/src/brevitas/nn/mixin/base.py @@ -24,32 +24,6 @@ from .utils import filter_kwargs -class _CachedIO: - - def __init__(self, quant_tensor: QuantTensor, metadata_only: bool): - self.shape = quant_tensor.value.shape - if metadata_only: - self.quant_tensor = quant_tensor.set(value=None) - else: - self.quant_tensor = quant_tensor - - @property - def scale(self): - return self.quant_tensor.scale - - @property - def zero_point(self): - return self.quant_tensor.zero_point - - @property - def bit_width(self): - return self.quant_tensor.bit_width - - @property - def signed(self): - return self.quant_tensor.signed - - class QuantProxyMixin(object): __metaclass__ = ABCMeta @@ -82,79 +56,19 @@ def __init__( class QuantLayerMixin(ExportMixin): __metaclass__ = ABCMeta - def __init__( - self, - return_quant_tensor: bool, - cache_inference_quant_inp: bool = False, - cache_inference_quant_out: bool = False, - cache_quant_io_metadata_only: bool = True): + def __init__(self, return_quant_tensor: bool): ExportMixin.__init__(self) self.accept_quant_tensor = True self.return_quant_tensor = return_quant_tensor - self.cache_inference_quant_inp = cache_inference_quant_inp - self.cache_inference_quant_out = cache_inference_quant_out - self.cache_quant_io_metadata_only = cache_quant_io_metadata_only - self._cached_inp = None - self._cached_out = None @property @abstractmethod def channelwise_separable(self) -> bool: pass - @property - def is_quant_input_signed(self) -> Optional[bool]: # tri-valued logic output - if self._cached_inp is not None: - return 
self._cached_inp.signed - else: - return None - def _set_global_is_quant_layer(self, value): config._IS_INSIDE_QUANT_LAYER = value - def quant_input_scale(self): - if self._cached_inp is not None: - return self._cached_inp.scale - else: - return None - - def quant_input_zero_point(self): - if self._cached_inp is not None: - return self._cached_inp.zero_point - else: - return None - - def quant_input_bit_width(self): - if self._cached_inp is not None: - return self._cached_inp.bit_width - else: - return None - - @property - def is_quant_output_signed(self) -> Optional[bool]: # tri-valued logic output - if self._cached_out is not None: - return self._cached_out.signed - else: - return None - - def quant_output_scale(self): - if self._cached_out is not None: - return self._cached_out.scale - else: - return None - - def quant_output_zero_point(self): - if self._cached_out is not None: - return self._cached_out.zero_point - else: - return None - - def quant_output_bit_width(self): - if self._cached_out is not None: - return self._cached_out.bit_width - else: - return None - def unpack_input(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: self._set_global_is_quant_layer(True) # Hack to recognize a QuantTensor that has decayed to a tuple @@ -162,11 +76,6 @@ def unpack_input(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe if (torch._C._get_tracing_state() is not None and isinstance(inp, tuple) and len(inp) == len(QuantTensor._fields) and all([isinstance(t, Tensor) for t in inp])): inp = QuantTensor(*inp) - if isinstance(inp, QuantTensor): - # don't cache values during export pass - if not self.training and not self._export_mode and self.cache_inference_quant_inp: - cached_inp = _CachedIO(inp.detach(), self.cache_quant_io_metadata_only) - self._cached_inp = cached_inp if not torch._C._get_tracing_state(): if isinstance(inp, QuantTensor): inp = inp.set(value=inp.value.rename(None)) @@ -175,9 +84,6 @@ def unpack_input(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe return inp def pack_output(self, quant_output: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: - if not self.training and self.cache_inference_quant_out and isinstance(quant_output, - QuantTensor): - self._cached_out = _CachedIO(quant_output.detach(), self.cache_quant_io_metadata_only) self._set_global_is_quant_layer(False) if self.return_quant_tensor: assert isinstance(quant_output, QuantTensor) diff --git a/src/brevitas/nn/mixin/parameter.py b/src/brevitas/nn/mixin/parameter.py index a752c35ec..11225a039 100644 --- a/src/brevitas/nn/mixin/parameter.py +++ b/src/brevitas/nn/mixin/parameter.py @@ -4,13 +4,10 @@ from abc import ABCMeta from abc import abstractmethod from typing import List, Optional, Tuple, Type, Union -from warnings import warn from brevitas.inject import ExtendedInjector from brevitas.inject import Injector -from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector from brevitas.proxy.parameter_quant import BiasQuantProxyProtocol -from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector from brevitas.proxy.parameter_quant import WeightQuantProxyProtocol from brevitas.quant import NoneBiasQuant from brevitas.quant import NoneWeightQuant @@ -41,22 +38,6 @@ def __init__(self, weight_quant: Optional[WeightQuantType], **kwargs): def output_channel_dim(self) -> int: pass - @property - def is_weight_quant_enabled(self): - return self.weight_quant.is_quant_enabled - - @property - def is_quant_weight_narrow_range(self): - return 
self.weight_quant.is_narrow_range - - @property - def is_quant_weight_signed(self): - return self.weight_quant.is_signed - - @property - def weight_quant_requires_quant_input(self): - return self.weight_quant.requires_quant_input - def quant_weight( self, quant_input: Optional[QuantTensor] = None, @@ -84,21 +65,8 @@ def quant_weight( slice(*s) if s is not None else slice(s) for s in subtensor_slice_list) else: weight_slice_tuple = slice(None) - if self.weight_quant_requires_quant_input: - if self.is_weight_quant_enabled: - if quant_input is None: - input_bit_width = self.quant_input_bit_width() - input_is_signed = self.is_quant_input_signed - else: - input_bit_width = quant_input.bit_width - input_is_signed = quant_input.signed - assert input_bit_width is not None, "Input bit-width needs to be specified." - assert input_is_signed is not None, "Input sign needs to be specified." - else: - input_bit_width = None - input_is_signed = None - out = self.weight_quant( - weights_to_quantize[weight_slice_tuple], input_bit_width, input_is_signed) + if self.weight_quant.requires_quant_input: + out = self.weight_quant(weights_to_quantize[weight_slice_tuple], quant_input) else: out = self.weight_quant(weights_to_quantize[weight_slice_tuple]) if subtensor_slice_list is not None: @@ -109,21 +77,6 @@ def quant_weight( m.subtensor_slice_list = [None] return out - def int_weight(self, float_datatype=False): - return self.quant_weight().int(float_datatype) - - def quant_weight_scale(self): - scale = self.quant_weight().scale - return scale - - def quant_weight_zero_point(self): - scale = self.quant_weight().zero_point - return scale - - def quant_weight_bit_width(self): - bit_width = self.quant_weight().bit_width - return bit_width - def register_parameter(self, name, value): super(QuantWeightMixin, self).register_parameter(name, value) if hasattr(self, 'weight_quant') and name == 'weight': @@ -147,85 +100,13 @@ def __init__( kwargs_prefix='bias_', proxy_prefix='bias_', **kwargs) - self.cache_inference_quant_bias = cache_inference_bias - self._cached_bias = None - - @property - def is_bias_quant_enabled(self): - return self.bias_quant.is_quant_enabled - - @property - def is_quant_bias_narrow_range(self): - if self.bias is None: - return None - return self.bias_quant.is_narrow_range - - @property - def is_quant_bias_signed(self): - if self.bias is None or not self.is_bias_quant_enabled: - return None - return self.bias_quant.is_signed - - def int_bias(self, float_datatype=False): - if self.bias is None or not self.is_bias_quant_enabled: - return None - quant_bias = self.quant_bias() - return quant_bias.int(float_datatype=float_datatype) def quant_bias(self): if self.bias is None: return None - scale = self.quant_bias_scale() - quant_bias = self.bias_quant(self.bias, scale) + quant_bias = self.bias_quant(self.bias) return quant_bias - def quant_bias_scale(self): - if self.bias is None or not self.is_bias_quant_enabled: - return None - if not self.bias_quant.requires_input_scale: - return self.bias_quant(self.bias).scale - else: - if self._cached_bias is None: - raise RuntimeError( - "No quant bias cache found, set cache_inference_quant_bias=True and run an " - "inference pass first") - if self.training: - warn("Cached quant bias scale is being used in training mode.") - return self._cached_bias.scale - - def quant_bias_zero_point(self): - if self.bias is None: - return None - - if not self.bias_quant.requires_input_scale: - bias_quant = self.bias_quant(self.bias) - if isinstance(bias_quant, QuantTensor): - 
return bias_quant.zero_point - else: - return None - else: - if self._cached_bias is None: - raise RuntimeError( - "No quant bias cache found, set cache_inference_quant_bias=True and run an " - "inference pass first") - if self.training: - warn("Cached quant bias zero-point is being used in training mode.") - return self._cached_bias.bit_width - - def quant_bias_bit_width(self): - if self.bias is None or not self.is_bias_quant_enabled: - return None - if not self.bias_quant.requires_input_scale: - return self.bias_quant(self.bias).bit_width - else: - if self._cached_bias is None: - raise RuntimeError( - "No quant bias cache found, set cache_inference_quant_bias=True and run an " - "inference pass first") - if self.training: - warn("Cached quant bias bit-width is being used in training mode.") - return self._cached_bias.bit_width - def register_parameter(self, name, value): super(QuantBiasMixin, self).register_parameter(name, value) if hasattr(self, 'bias_quant') and name == 'bias': diff --git a/src/brevitas/nn/quant_layer.py b/src/brevitas/nn/quant_layer.py index 5cf29602f..8cd2e10a7 100644 --- a/src/brevitas/nn/quant_layer.py +++ b/src/brevitas/nn/quant_layer.py @@ -3,18 +3,16 @@ from abc import ABCMeta from abc import abstractmethod -from typing import Callable, Optional, Type, Union +from typing import Optional, Type, Union import torch from torch import Tensor from torch.nn import Module -from torch.nn import Parameter from brevitas.quant_tensor import _unpack_quant_tensor from brevitas.quant_tensor import QuantTensor from .mixin import * -from .mixin.base import _CachedIO from .utils import compute_channel_view_shape from .utils import merge_bn from .utils import rename_state_dict_by_prefix @@ -42,94 +40,7 @@ def channelwise_separable(self) -> bool: @property def requires_export_handler(self): - return self.is_input_quant_enabled or self.is_act_quant_enabled - - @property - def is_quant_input_signed(self) -> Optional[bool]: # tri-valued logic output - if self.is_input_quant_enabled: - return self.input_quant.is_signed - elif self._cached_inp is not None: - return self._cached_inp.signed - else: - return None - - @property - def is_quant_act_signed(self) -> Optional[bool]: # tri-valued logic output - if self.is_act_quant_enabled: - return self.act_quant.is_signed - elif self._cached_out is not None: - return self._cached_out.signed - else: - return None - - @property - def is_output_quant_enabled(self): - return self.is_act_quant_enabled - - @property - def is_quant_output_narrow_range(self): - return self.is_quant_act_narrow_range - - @property - def is_quant_output_signed(self): # overrides from QuantLayerMixin - return self.is_quant_act_signed - - def quant_input_scale(self): - if self.is_input_quant_enabled: - return self.input_quant.scale() - elif self._cached_inp is not None: - return self._cached_inp.scale - else: - return None - - def quant_act_scale(self): - if self.is_act_quant_enabled: - return self.act_quant.scale() - elif self._cached_out is not None: - return self._cached_out.scale - else: - return None - - def quant_output_scale(self): # overrides from QuantLayerMixin - return self.quant_act_scale() - - def quant_input_zero_point(self): - if self.is_input_quant_enabled: - return self.input_quant.zero_point() - elif self._cached_inp is not None: - return self._cached_inp.zero_point - else: - return None - - def quant_act_zero_point(self): - if self.is_act_quant_enabled: - return self.act_quant.zero_point() - elif self._cached_out is not None: - return 
self._cached_out.zero_point - else: - return None - - def quant_output_zero_point(self): # overrides from QuantLayerMixin - return self.quant_act_zero_point() - - def quant_input_bit_width(self): - if self.is_input_quant_enabled: - return self.input_quant.bit_width() - elif self._cached_inp is not None: - return self._cached_inp.bit_width - else: - return None - - def quant_act_bit_width(self): - if self.is_act_quant_enabled: - return self.act_quant.bit_width() - elif self._cached_out is not None: - return self._cached_out.bit_width - else: - return None - - def quant_output_bit_width(self): # overrides from QuantLayerMixin - return self.quant_act_bit_width() + return self.input_quant.is_quant_enabled or self.act_quant.is_quant_enabled def forward(self, input: Union[Tensor, QuantTensor]): input = self.unpack_input(input) @@ -167,82 +78,16 @@ def __init__( QuantOutputMixin.__init__(self, output_quant, **kwargs) # we have to account for quantization being enabled through kwargs if tie_input_output_quant: - if self.is_input_quant_enabled and self.is_output_quant_enabled: + if self.input_quant.is_quant_enabled and self.act_quant.is_quant_enabled: raise RuntimeError("Enable only input or output quant with tie_input_output=True") - if self.is_input_quant_enabled: + if self.input_quant.is_quant_enabled: self.output_quant = self.input_quant - if self.is_output_quant_enabled: + if self.act_quant.is_quant_enabled: self.input_quant = self.output_quant @property def requires_export_handler(self): - return self.is_input_quant_enabled or self.is_output_quant_enabled - - @property - def is_quant_input_signed(self) -> Optional[bool]: # tri-valued logic output - if self.is_input_quant_enabled: - return self.input_quant.is_signed - elif self._cached_inp is not None: - return self._cached_inp.signed - else: - return None - - @property - def is_quant_output_signed(self) -> Optional[bool]: # tri-valued logic output: - if self.is_output_quant_enabled: - return self.output_quant.is_signed - elif self._cached_out is not None: - return self._cached_out.signed - else: - return None - - def quant_input_scale(self): - if self.is_input_quant_enabled: - return self.input_quant.scale() - elif self._cached_inp is not None: - return self._cached_inp.scale - else: - return None - - def quant_output_scale(self): - if self.is_output_quant_enabled: - return self.output_quant.scale() - elif self._cached_out is not None: - return self._cached_out.scale - else: - return None - - def quant_input_zero_point(self): - if self.is_input_quant_enabled: - return self.input_quant.zero_point() - elif self._cached_inp is not None: - return self._cached_inp.zero_point - else: - return None - - def quant_output_zero_point(self): - if self.is_output_quant_enabled: - return self.output_quant.zero_point() - elif self._cached_out is not None: - return self._cached_out.zero_point - else: - return None - - def quant_input_bit_width(self): - if self.is_input_quant_enabled: - return self.input_quant.bit_width() - elif self._cached_inp is not None: - return self._cached_inp.bit_width - else: - return None - - def quant_output_bit_width(self): - if self.is_output_quant_enabled: - return self.output_quant.bit_width() - elif self._cached_out is not None: - return self._cached_out.bit_width - else: - return None + return self.input_quant.is_quant_enabled or self.act_quant.is_quant_enabled class QuantWeightBiasInputOutputLayer(QuantBiasMixin, QuantWeightMixin, QuantInputOutputLayer): @@ -284,8 +129,8 @@ def quant_output_scale_impl( @property def 
requires_export_handler(self): return ( - self.is_input_quant_enabled or self.is_weight_quant_enabled or - self.is_bias_quant_enabled or self.is_output_quant_enabled) + self.input_quant.is_quant_enabled or self.weight_quant.is_quant_enabled or + self.bias_quant.is_quant_enabled or self.output_quant.is_quant_enabled) @property def per_elem_ops(self): # optional, so concrete impl + error if not overridden @@ -309,12 +154,13 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe return out quant_input = self.input_quant(inp) + quant_weight = self.quant_weight(quant_input) compute_output_quant_tensor = isinstance(quant_input, QuantTensor) and isinstance( quant_weight, QuantTensor) if not (compute_output_quant_tensor or - self.is_output_quant_enabled) and self.return_quant_tensor: + self.output_quant.is_quant_enabled) and self.return_quant_tensor: raise RuntimeError("QuantLayer is not correctly configured") if isinstance(quant_input, QuantTensor) and isinstance(quant_weight, QuantTensor): @@ -324,10 +170,7 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe if self.bias is not None: quant_bias = self.bias_quant(self.bias, output_scale) - if not self.training and self.cache_inference_quant_bias and isinstance(quant_bias, - QuantTensor): - self._cached_bias = _CachedIO(quant_bias.detach(), metadata_only=False) output_tensor = self.inner_forward_impl( _unpack_quant_tensor(quant_input), _unpack_quant_tensor(quant_weight), @@ -351,7 +194,7 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe output_tensor = self.inner_forward_impl( _unpack_quant_tensor(quant_input), _unpack_quant_tensor(quant_weight), None) - if not self.is_output_quant_enabled and self.return_quant_tensor: + if not self.output_quant.is_quant_enabled and self.return_quant_tensor: if compute_output_quant_tensor: if (quant_input.zero_point != 0.0).any() or (quant_weight.zero_point != 0.0).any(): raise RuntimeError( diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py index 79ad7e9ec..893ff5e30 100644 --- a/src/brevitas/proxy/parameter_quant.py +++ b/src/brevitas/proxy/parameter_quant.py @@ -4,15 +4,19 @@ from abc import ABCMeta from abc import abstractmethod from typing import Optional, Union +from warnings import warn import torch from torch import Tensor +import torch.nn as nn from typing_extensions import Protocol from typing_extensions import runtime_checkable from brevitas import config from brevitas.function import max_int +from brevitas.inject import BaseInjector as Injector from brevitas.quant_tensor import QuantTensor +from brevitas.utils.quant_utils import _CachedIO from .quant_proxy import QuantProxyFromInjector from .quant_proxy import QuantProxyProtocol @@ -135,28 +139,42 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]: class DecoupledWeightQuantWithInputProxyFromInjector(DecoupledWeightQuantProxyFromInjector): + def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None: + super().__init__(quant_layer, quant_injector) + # Necessary for export + self._cached_act = None + self.cache_inference_quant_act = False + self.cache_quant_io_metadata_only = True + @property def requires_quant_input(self): return True - def scale(self): - raise NotImplementedError - - def zero_point(self): - raise NotImplementedError - - def bit_width(self): - return self.tensor_quant.msb_clamp_bit_width_impl() - def pre_scale(self): raise NotImplementedError def pre_zero_point(self): 
raise NotImplementedError - def forward(self, x: torch.Tensor, input_bit_width: torch.Tensor, - input_is_signed: bool) -> Union[Tensor, QuantTensor]: + def forward( + self, + x: torch.Tensor, + quant_input: Optional[Union[Tensor, QuantTensor]] = None) -> Union[Tensor, QuantTensor]: + if isinstance(quant_input, + QuantTensor) and not self.training and self.cache_inference_quant_act: + cached_inp = _CachedIO(quant_input.detach(), self.cache_quant_io_metadata_only) + self._cached_act = cached_inp + if self.is_quant_enabled: + if quant_input is None: + assert self._cached_act is not None, "No cached quant input found. Enable caching and perform a forward pass" + quant_input = self._cached_act + else: + assert isinstance(quant_input, QuantTensor), "Input must be quantized" + + input_bit_width = quant_input.bit_width + input_is_signed = quant_input.signed + impl = self.export_handler if self.export_mode else self.tensor_quant out, scale, zero_point, bit_width, pre_scale, pre_zero_point = impl(x, input_bit_width, input_is_signed) return QuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training) @@ -166,6 +184,11 @@ def forward(self, x: torch.Tensor, input_bit_width: torch.Tensor, class BiasQuantProxyFromInjector(ParameterQuantProxyFromInjector, BiasQuantProxyProtocol): + def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None: + super().__init__(quant_layer, quant_injector) + self._cached_bias = None + self.cache_inference_quant_bias = False + @property def tracked_parameter_list(self): return [m.bias for m in self.tracked_module_list if m.bias is not None] @@ -177,9 +200,22 @@ def requires_input_scale(self) -> bool: else: return False + def get_cached(self, attr): + if self._cached_bias is None: + warn( + "No quant bias cache found, set cache_inference_quant_bias=True and run an " + "inference pass first") + return None + if self.training: + warn("Cached quant bias scale is being used in training mode.") + return getattr(self._cached_bias, attr) + def scale(self): - if self.requires_input_scale or not self.is_quant_enabled: + if not self.is_quant_enabled: return None + if self.requires_input_scale: + cache = self.get_cached('scale') + return cache zhs = self._zero_hw_sentinel() scale = self.__call__(self.tracked_parameter_list[0], zhs).scale return scale @@ -201,10 +237,13 @@ def bit_width(self): def forward(self, x: Tensor, input_scale: Optional[Tensor] = None) -> Union[Tensor, QuantTensor]: + out = x if self.is_quant_enabled: impl = self.export_handler if self.export_mode else self.tensor_quant if self.requires_input_scale and input_scale is None: - raise RuntimeError("Input scale required") + input_scale = self.scale() + if input_scale is None: + raise RuntimeError("Input scale required") if self.requires_input_scale: input_scale = input_scale.view(-1) @@ -212,6 +251,10 @@ def forward(self, else: out, out_scale, out_zp, out_bit_width = impl(x) - return QuantTensor(out, out_scale, out_zp, out_bit_width, self.is_signed, self.training) + out = QuantTensor(out, out_scale, out_zp, out_bit_width, self.is_signed, self.training) else: - return x + out = x + if isinstance(out, QuantTensor) and not self.training and self.cache_inference_quant_bias: + cached_bias = _CachedIO(out.detach(), metadata_only=False) + self._cached_bias = cached_bias + return out diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index fe7b29daf..4dd8417a9 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ 
-11,6 +11,7 @@ import brevitas from brevitas.quant_tensor import QuantTensor +from brevitas.utils.quant_utils import _CachedIO from .quant_proxy import QuantProxyFromInjector from .quant_proxy import QuantProxyProtocol @@ -89,11 +90,20 @@ def __init__(self, quant_layer, quant_injector): QuantProxyFromInjector.__init__(self, quant_layer, quant_injector) ActQuantProxyProtocol.__init__(self) self.is_passthrough_act = _is_passthrough_act(quant_injector) + self._cached_act = None + self.cache_inference_quant_act = False + self.cache_quant_io_metadata_only = True @property def is_quant_enabled(self): return self._is_quant_enabled and not self.disable_quant + @property + def is_signed(self): + if self._cached_act is not None: + return self._cached_act.signed + return super().is_signed + @is_quant_enabled.setter def is_quant_enabled(self, is_quant_enabled): self._is_quant_enabled = is_quant_enabled @@ -118,36 +128,46 @@ def init_tensor_quant(self): self.fused_activation_quant_proxy = None def scale(self, force_eval=True): - if not self.is_quant_enabled: + if self.is_quant_enabled: + current_status = self.training + if force_eval: + self.eval() + out = self.__call__(self._zero_hw_sentinel()) + self.train(current_status) + return out.scale + elif self._cached_act is not None: + return self._cached_act.scale + elif self._cached_act is None: return None - current_status = self.training - if force_eval: - self.eval() - scale = self.__call__(self._zero_hw_sentinel()).scale - self.train(current_status) - return scale def zero_point(self, force_eval=True): - if not self.is_quant_enabled: + if self.is_quant_enabled: + current_status = self.training + if force_eval: + self.eval() + out = self.__call__(self._zero_hw_sentinel()) + self.train(current_status) + return out.zero_point + elif self._cached_act is not None: + return self._cached_act.zero_point + elif self._cached_act is None: return None - current_status = self.training - if force_eval: - self.eval() - zero_point = self.__call__(self._zero_hw_sentinel()).zero_point - self.train(current_status) - return zero_point def bit_width(self, force_eval=True): - if not self.is_quant_enabled: + if self.is_quant_enabled: + current_status = self.training + if force_eval: + self.eval() + out = self.__call__(self._zero_hw_sentinel()) + self.train(current_status) + return out.bit_width + elif self._cached_act is not None: + return self._cached_act.bit_width + elif self._cached_act is None: return None - current_status = self.training - if force_eval: - self.eval() - bit_width = self.__call__(self._zero_hw_sentinel()).bit_width - self.train(current_status) - return bit_width def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: + out = x if self.fused_activation_quant_proxy is not None: y = x if isinstance(y, QuantTensor): @@ -163,22 +183,26 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: # If y is an empty QuantTensor, we need to check if this is a passthrough proxy, # otherwise return a simple Tensor if isinstance(y, tuple) and not any(map(lambda f: f is None, y)): - return QuantTensor(*y, signed=self.is_signed, training=self.training) + out = QuantTensor(*y, signed=self.is_signed, training=self.training) elif self.is_passthrough_act: # preserve scale/zp/bit/sign even without output quant if isinstance(y, tuple): y = y[0] if isinstance(x, QuantTensor): - return QuantTensor( + out = QuantTensor( y, x.scale, x.zero_point, x.bit_width, x.signed, self.training) else: - return y + out = y else: if 
isinstance(y, tuple): y = y[0] - return y + out = y else: # If fused activation quant proxy is not enabled, return the input - return x + out = x + if not self.training and self.cache_inference_quant_act and isinstance(out, QuantTensor): + cached_out = _CachedIO(out.detach(), self.cache_quant_io_metadata_only) + self._cached_act = cached_out + return out class DynamicActQuantProxyFromInjector(ActQuantProxyFromInjector): diff --git a/src/brevitas/utils/quant_utils.py b/src/brevitas/utils/quant_utils.py index 5afed36c0..5b8bf648f 100644 --- a/src/brevitas/utils/quant_utils.py +++ b/src/brevitas/utils/quant_utils.py @@ -5,6 +5,33 @@ from brevitas.core.function_wrapper import * from brevitas.core.quant import RescalingIntQuant from brevitas.inject.enum import FloatToIntImplType +from brevitas.quant_tensor import QuantTensor + + +class _CachedIO: + + def __init__(self, quant_tensor: QuantTensor, metadata_only: bool): + self.shape = quant_tensor.value.shape + if metadata_only: + self.quant_tensor = quant_tensor.set(value=None) + else: + self.quant_tensor = quant_tensor + + @property + def scale(self): + return self.quant_tensor.scale + + @property + def zero_point(self): + return self.quant_tensor.zero_point + + @property + def bit_width(self): + return self.quant_tensor.bit_width + + @property + def signed(self): + return self.quant_tensor.signed def has_learned_weight_bit_width(module): diff --git a/src/brevitas_examples/super_resolution/utils/evaluate.py b/src/brevitas_examples/super_resolution/utils/evaluate.py index 2eb9e3627..51c33a86c 100644 --- a/src/brevitas_examples/super_resolution/utils/evaluate.py +++ b/src/brevitas_examples/super_resolution/utils/evaluate.py @@ -55,8 +55,8 @@ def _calc_min_acc_bit_width(module: QuantWBIOL) -> Tensor: assert isinstance(module, qnn.QuantConv2d), "Error: function only support QuantConv2d." 
# bit-width and sign need to come from the quant tensor of the preceding layer if no io_quant - input_bit_width = module.quant_input_bit_width() - input_is_signed = float(module.is_quant_input_signed) + input_bit_width = module.input_quant.bit_width() + input_is_signed = float(module.input_quant.is_signed) # the tensor quantizer requires a QuantTensor with specified bit-width and sign quant_weight = module.quant_weight() diff --git a/tests/brevitas/graph/test_calibration.py b/tests/brevitas/graph/test_calibration.py index 0b6303a8b..1775c68d6 100644 --- a/tests/brevitas/graph/test_calibration.py +++ b/tests/brevitas/graph/test_calibration.py @@ -59,7 +59,7 @@ def forward(self, x): model(inp) expected_scale = reference_implementation_scale_factors_po2(inp) - scale = model.act.quant_act_scale() + scale = model.act.act_quant.scale() assert torch.allclose(expected_scale, scale) diff --git a/tests/brevitas/nn/test_linear.py b/tests/brevitas/nn/test_linear.py index 8fd2d6e04..62799281b 100644 --- a/tests/brevitas/nn/test_linear.py +++ b/tests/brevitas/nn/test_linear.py @@ -36,7 +36,7 @@ def test_module_init_scale_impl_type_override(self): in_features=INPUT_FEATURES, bias=True, weight_scaling_impl_type='HE') - assert mod.quant_weight_scale() + assert mod.weight_quant.scale() class TestQuantLinearFwd: diff --git a/tests/brevitas/nn/test_wbiol.py b/tests/brevitas/nn/test_wbiol.py index 58b9a86ca..941fb03b0 100644 --- a/tests/brevitas/nn/test_wbiol.py +++ b/tests/brevitas/nn/test_wbiol.py @@ -76,23 +76,23 @@ def default_weight_tensor_quant(default_wbiol_layer): def test_default_wbiol_input_quant_enabled(default_wbiol_layer: QuantWBIOL): - assert not default_wbiol_layer.is_input_quant_enabled + assert not default_wbiol_layer.input_quant.is_quant_enabled def test_default_wbiol_output_quant_enabled(default_wbiol_layer: QuantWBIOL): - assert not default_wbiol_layer.is_output_quant_enabled + assert not default_wbiol_layer.output_quant.is_quant_enabled def test_default_wbiol_bias_quant_enabled(default_wbiol_layer: QuantWBIOL): - assert not default_wbiol_layer.is_bias_quant_enabled + assert not default_wbiol_layer.bias_quant.is_quant_enabled def test_default_wbiol_weight_quant_enabled(default_wbiol_layer: QuantWBIOL): - assert default_wbiol_layer.is_weight_quant_enabled + assert default_wbiol_layer.weight_quant.is_quant_enabled def test_default_wbiol_weight_bit_width_enabled(default_wbiol_layer: QuantWBIOL): - assert default_wbiol_layer.quant_weight_bit_width() == torch.tensor(8.) + assert default_wbiol_layer.weight_quant.bit_width() == torch.tensor(8.) 
def test_default_wbiol_return_quant(default_wbiol_layer: QuantWBIOL): @@ -100,63 +100,63 @@ def test_default_wbiol_return_quant(default_wbiol_layer: QuantWBIOL): def test_default_wbiol_quant_bias_signed(default_wbiol_layer: QuantWBIOL): - assert default_wbiol_layer.is_quant_bias_signed is None + assert default_wbiol_layer.bias_quant.is_signed is None def test_default_wbiol_quant_weight_signed(default_wbiol_layer: QuantWBIOL): - assert default_wbiol_layer.is_quant_weight_signed + assert default_wbiol_layer.weight_quant.is_signed def test_default_wbiol_quant_bias_narrow_range(default_wbiol_layer: QuantWBIOL): - assert default_wbiol_layer.is_quant_bias_narrow_range is None + assert default_wbiol_layer.bias_quant.is_narrow_range is None def test_default_wbiol_quant_weight_narrow_range(default_wbiol_layer: QuantWBIOL): - assert default_wbiol_layer.is_quant_weight_narrow_range + assert default_wbiol_layer.weight_quant.is_narrow_range def test_default_wbiol_quant_input_signed(default_wbiol_layer: QuantWBIOL): - assert default_wbiol_layer.is_quant_input_signed is None + assert default_wbiol_layer.input_quant.is_signed is None def test_default_wbiol_quant_output_signed(default_wbiol_layer: QuantWBIOL): - assert default_wbiol_layer.is_quant_output_signed is None + assert default_wbiol_layer.output_quant.is_signed is None def test_default_wbiol_quant_input_narrow_range(default_wbiol_layer: QuantWBIOL): - assert default_wbiol_layer.is_quant_input_narrow_range is None + assert default_wbiol_layer.input_quant.is_narrow_range is None def test_default_wbiol_quant_output_narrow_range(default_wbiol_layer: QuantWBIOL): - assert default_wbiol_layer.is_quant_output_narrow_range is None + assert default_wbiol_layer.output_quant.is_narrow_range is None def test_default_wbiol_quant_input_zero_point(default_wbiol_layer: QuantWBIOL): - assert default_wbiol_layer.quant_input_zero_point() is None + assert default_wbiol_layer.input_quant.zero_point() is None def test_default_wbiol_quant_output_zero_point(default_wbiol_layer: QuantWBIOL): - assert default_wbiol_layer.quant_output_zero_point() is None + assert default_wbiol_layer.output_quant.zero_point() is None def test_default_wbiol_quant_weight_zero_point(default_wbiol_layer: QuantWBIOL): - assert default_wbiol_layer.quant_weight_zero_point() == torch.tensor(0.) + assert default_wbiol_layer.weight_quant.zero_point() == torch.tensor(0.) 
def test_default_wbiol_quant_bias_zero_point(default_wbiol_layer: QuantWBIOL): - assert default_wbiol_layer.quant_bias_zero_point() is None + assert default_wbiol_layer.bias_quant.zero_point() is None def test_default_wbiol_quant_input_scale(default_wbiol_layer: QuantWBIOL): - assert default_wbiol_layer.quant_input_scale() is None + assert default_wbiol_layer.input_quant.scale() is None def test_default_wbiol_quant_output_scale(default_wbiol_layer: QuantWBIOL): - assert default_wbiol_layer.quant_output_scale() is None + assert default_wbiol_layer.output_quant.scale() is None def test_default_wbiol_quant_bias_scale(default_wbiol_layer: QuantWBIOL): - assert default_wbiol_layer.quant_bias_scale() is None + assert default_wbiol_layer.bias_quant.scale() is None def test_default_wbiol_weight_quant_proxy(default_wbiol_layer: QuantWBIOL): diff --git a/tests/brevitas/proxy/test_act_scaling.py b/tests/brevitas/proxy/test_act_scaling.py index 17cee628c..3b4537610 100644 --- a/tests/brevitas/proxy/test_act_scaling.py +++ b/tests/brevitas/proxy/test_act_scaling.py @@ -41,7 +41,7 @@ def test_scaling_stats_to_parameter(self): stats_act.eval() param_act.eval() - assert (torch.allclose(stats_act.quant_act_scale(), param_act.quant_act_scale())) + assert (torch.allclose(stats_act.act_quant.scale(), param_act.act_quant.scale())) def test_scaling_parameter_grad(self): stats_act = QuantReLU( diff --git a/tests/brevitas/proxy/test_proxy.py b/tests/brevitas/proxy/test_proxy.py index 08b525a71..28c3eed9e 100644 --- a/tests/brevitas/proxy/test_proxy.py +++ b/tests/brevitas/proxy/test_proxy.py @@ -45,9 +45,9 @@ def test_weight_decoupled_proxy(self): def test_weight_decoupled_with_input_proxy(self): model = QuantLinear(10, 5, weight_quant=Int8AccumulatorAwareWeightQuant) - with pytest.raises(NotImplementedError): + with pytest.raises(AssertionError): model.weight_quant.scale() - with pytest.raises(NotImplementedError): + with pytest.raises(AssertionError): model.weight_quant.zero_point() with pytest.raises(NotImplementedError): diff --git a/tests/brevitas/proxy/test_weight_scaling.py b/tests/brevitas/proxy/test_weight_scaling.py index 49a7f20fe..4649592e2 100644 --- a/tests/brevitas/proxy/test_weight_scaling.py +++ b/tests/brevitas/proxy/test_weight_scaling.py @@ -17,10 +17,10 @@ def test_parameter_from_stats_update(): weight_quant_type='binary', weight_scaling_impl_type='parameter_from_stats') l_max = linear.weight.abs().max() - old_scale = q_linear.quant_weight_scale() + old_scale = q_linear.weight_quant.scale() old_ql_max = q_linear.weight.abs().max() q_linear.load_state_dict(linear.state_dict()) - new_scale = q_linear.quant_weight_scale() + new_scale = q_linear.weight_quant.scale() new_ql_max = q_linear.weight.abs().max() assert old_scale == old_ql_max assert new_scale == l_max @@ -42,10 +42,10 @@ def test_parameter_from_stats_state_dict(): weight_quant_type='binary', weight_scaling_impl_type='parameter', weight_scaling_init=0.001) - q_linear1_old_scale = q_linear1.quant_weight_scale() + q_linear1_old_scale = q_linear1.weight_quant.scale() q_linear1.load_state_dict(q_linear2.state_dict()) - q_linear1_new_scale = q_linear1.quant_weight_scale() - q_linear2_scale = q_linear2.quant_weight_scale() + q_linear1_new_scale = q_linear1.weight_quant.scale() + q_linear2_scale = q_linear2.weight_quant.scale() assert q_linear1_old_scale != q_linear2_scale assert q_linear1_old_scale != q_linear1_new_scale assert q_linear1_new_scale == q_linear2_scale
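
Migration sketch: the breaking change above moves every quant metadata query from the layer onto its proxies. The snippet below is a minimal illustration assembled only from calls that appear in the hunks above, assuming a default-configured QuantConv2d (weight quant enabled; input, output and bias quant disabled); it is not part of the patch itself.

    from brevitas.nn import QuantConv2d

    conv = QuantConv2d(in_channels=2, out_channels=3, kernel_size=(3, 3), bias=True)

    # Old: conv.is_weight_quant_enabled, conv.quant_weight_scale(),
    #      conv.quant_weight_zero_point(), conv.quant_weight_bit_width(), conv.int_weight()
    # New: metadata is delegated to the weight_quant proxy
    assert conv.weight_quant.is_quant_enabled
    scale = conv.weight_quant.scale()
    zero_point = conv.weight_quant.zero_point()
    bit_width = conv.weight_quant.bit_width()
    int_weight = conv.quant_weight().int()

    # Input/output metadata now lives on input_quant/output_quant and returns
    # None while the corresponding quantizer is disabled and nothing is cached
    assert conv.input_quant.scale() is None
    assert conv.output_quant.bit_width() is None

    # Caching flags move to the proxies as well: the layer-level
    # cache_inference_quant_inp/out are replaced by cache_inference_quant_act
    # on the activation proxies, and cache_inference_quant_bias now sits on
    # the bias proxy
    conv.input_quant.cache_inference_quant_act = True
    conv.bias_quant.cache_inference_quant_bias = True

With quantization disabled on a proxy, a _CachedIO entry is only recorded when a QuantTensor actually flows through it outside training, matching the caching guards added in runtime_quant.py and parameter_quant.py above.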