From f1655b2c68565842449393ba34cca6f182d0fdf6 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 20 Aug 2024 17:53:40 +0200 Subject: [PATCH] Feat: support for groupwise (MX) quantization (#971) --- notebooks/Brevitas_TVMCon2021.ipynb | 170 ++++----- notebooks/minifloat_mx_tutorial.ipynb | 285 +++++++++++++++ src/brevitas/core/function_wrapper/shape.py | 19 + src/brevitas/core/scaling/__init__.py | 2 + src/brevitas/core/scaling/runtime.py | 35 ++ .../export/onnx/standard/qcdq/handler.py | 3 +- src/brevitas/graph/quantize.py | 2 - src/brevitas/inject/enum.py | 9 + src/brevitas/nn/hadamard_classifier.py | 6 +- src/brevitas/nn/mixin/base.py | 5 +- src/brevitas/proxy/float_parameter_quant.py | 24 +- src/brevitas/proxy/float_runtime_quant.py | 61 +++- .../proxy/groupwise_float_parameter_quant.py | 45 +++ .../proxy/groupwise_float_runtime_quant.py | 85 +++++ .../proxy/groupwise_int_parameter_quant.py | 40 +++ .../proxy/groupwise_int_runtime_quant.py | 75 ++++ src/brevitas/quant/base.py | 38 +- src/brevitas/quant/experimental/float.py | 29 +- src/brevitas/quant/experimental/float_base.py | 21 ++ .../quant/experimental/float_quant_fnuz.py | 182 ++++++---- .../quant/experimental/float_quant_ocp.py | 203 +++++++---- .../quant/experimental/mx_quant_ocp.py | 137 ++++++++ src/brevitas/quant/solver/act.py | 7 +- src/brevitas/quant/solver/common.py | 35 +- src/brevitas/quant/solver/parameter.py | 36 +- src/brevitas/quant/solver/weight.py | 13 +- src/brevitas/quant_tensor/__init__.py | 2 + .../quant_tensor/base_quant_tensor.py | 27 ++ .../quant_tensor/float_quant_tensor.py | 21 +- .../quant_tensor/float_torch_handler.py | 6 +- .../groupwise_float_quant_tensor.py | 330 ++++++++++++++++++ .../groupwise_int_quant_tensor.py | 312 +++++++++++++++++ src/brevitas/utils/quant_utils.py | 90 +++++ .../common/generative/quant_blocks.py | 101 ------ .../common/generative/quantize.py | 196 +++++------ .../common/generative/quantizers.py | 104 ++---- src/brevitas_examples/common/parse_utils.py | 2 +- src/brevitas_examples/llm/README.md | 97 ++++- src/brevitas_examples/llm/llm_quant/eval.py | 1 + src/brevitas_examples/llm/main.py | 17 +- .../stable_diffusion/README.md | 33 +- .../stable_diffusion/main.py | 24 +- 42 files changed, 2266 insertions(+), 664 deletions(-) create mode 100644 notebooks/minifloat_mx_tutorial.ipynb create mode 100644 src/brevitas/proxy/groupwise_float_parameter_quant.py create mode 100644 src/brevitas/proxy/groupwise_float_runtime_quant.py create mode 100644 src/brevitas/proxy/groupwise_int_parameter_quant.py create mode 100644 src/brevitas/proxy/groupwise_int_runtime_quant.py create mode 100644 src/brevitas/quant/experimental/mx_quant_ocp.py create mode 100644 src/brevitas/quant_tensor/groupwise_float_quant_tensor.py create mode 100644 src/brevitas/quant_tensor/groupwise_int_quant_tensor.py diff --git a/notebooks/Brevitas_TVMCon2021.ipynb b/notebooks/Brevitas_TVMCon2021.ipynb index fc11303aa..20ce30701 100644 --- a/notebooks/Brevitas_TVMCon2021.ipynb +++ b/notebooks/Brevitas_TVMCon2021.ipynb @@ -122,7 +122,7 @@ " [-0.0140, 0.5607]], requires_grad=True) \n", "\n", "Quantized weight QuantTensor:\n", - " QuantTensor(value=tensor([[-0.0046, 0.3803],\n", + " IntQuantTensor(value=tensor([[-0.0046, 0.3803],\n", " [-0.5820, -0.5224],\n", " [-0.2704, 0.1879],\n", " [-0.0137, 0.5591]], grad_fn=), scale=0.004582525696605444, zero_point=0.0, bit_width=8.0, signed_t=True, training_t=True) \n", @@ -208,8 +208,10 @@ "name": "stderr", "output_type": "stream", "text": [ - "/scratch/fabian/miniforge3/envs/torchgpu/lib/python3.11/site-packages/torch/_tensor.py:1362: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /opt/conda/conda-bld/pytorch_1699449183005/work/c10/core/TensorImpl.h:1900.)\n", - " return super().rename(names)\n" + "/home/giuseppe/miniconda3/envs/brevitas_dev/lib/python3.11/site-packages/torch/_tensor.py:1419: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /opt/conda/conda-bld/pytorch_1712608853099/work/c10/core/TensorImpl.h:1921.)\n", + " return super().rename(names)\n", + "/home/giuseppe/Documents/git/dev_brevitas/src/brevitas/nn/quant_linear.py:69: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at /opt/conda/conda-bld/pytorch_1712608853099/work/torch/csrc/utils/python_arg_parser.cpp:294.)\n", + " output_tensor = linear(x, quant_weight, quant_bias)\n" ] } ], @@ -250,7 +252,7 @@ "output_type": "stream", "text": [ "Weight QuantTensor:\n", - " QuantTensor(value=tensor([[-0.0078, 0.3828],\n", + " IntQuantTensor(value=tensor([[-0.0078, 0.3828],\n", " [-0.5781, -0.5234],\n", " [-0.2734, 0.1875],\n", " [-0.0156, 0.5625]], grad_fn=), scale=0.0078125, zero_point=0.0, bit_width=8.0, signed_t=True, training_t=True)\n", @@ -289,7 +291,7 @@ "output_type": "stream", "text": [ "Weight QuantTensor:\n", - " QuantTensor(value=tensor([[-0.1000, 0.1000],\n", + " IntQuantTensor(value=tensor([[-0.1000, 0.1000],\n", " [-0.1000, -0.1000],\n", " [-0.1000, 0.1000],\n", " [-0.1000, 0.1000]], grad_fn=), scale=0.10000000149011612, zero_point=0.0, bit_width=1.0, signed_t=True, training_t=True)\n" @@ -422,9 +424,9 @@ "output_type": "stream", "text": [ "Quant output:\n", - " QuantTensor(value=tensor([[-0.9109, -0.4609, 0.3135, -0.6523],\n", + " IntQuantTensor(value=tensor([[-0.9109, -0.4609, 0.3135, -0.6523],\n", " [ 1.2089, 0.6524, -0.3752, 0.8697],\n", - " [ 1.3893, 0.2816, -0.9011, 0.9521]], grad_fn=), scale=tensor([[9.0542e-05]], grad_fn=), zero_point=0.0, bit_width=17.0, signed_t=True, training_t=True)\n" + " [ 1.3893, 0.2816, -0.9011, 0.9521]], grad_fn=), scale=tensor([[9.0542e-05]], grad_fn=), zero_point=tensor([0.]), bit_width=17.0, signed_t=True, training_t=True)\n" ] } ], @@ -465,14 +467,14 @@ " [-1.0845, -1.3986]]) \n", "\n", "Quant input:\n", - " QuantTensor(value=tensor([[ 1.5490, -0.2894],\n", + " IntQuantTensor(value=tensor([[ 1.5490, -0.2894],\n", " [-2.1788, 0.5617],\n", " [-1.0894, -1.3958]], grad_fn=), scale=0.017021792009472847, zero_point=0.0, bit_width=8.0, signed_t=True, training_t=True) \n", "\n", "Quant output:\n", - " QuantTensor(value=tensor([[-0.9109, -0.4609, 0.3135, -0.6523],\n", + " IntQuantTensor(value=tensor([[-0.9109, -0.4609, 0.3135, -0.6523],\n", " [ 1.2089, 0.6524, -0.3752, 0.8697],\n", - " [ 1.3893, 0.2816, -0.9011, 0.9521]], grad_fn=), scale=tensor([[9.0542e-05]], grad_fn=), zero_point=0.0, bit_width=17.0, signed_t=True, training_t=True)\n" + " [ 1.3893, 0.2816, -0.9011, 0.9521]], grad_fn=), scale=tensor([[9.0542e-05]], grad_fn=), zero_point=tensor([0.]), bit_width=17.0, signed_t=True, training_t=True)\n" ] } ], @@ -524,7 +526,7 @@ " [-1.0845, -1.3986]]) \n", "\n", "Quant output:\n", - " QuantTensor(value=tensor([[1.5410, 0.0000],\n", + " IntQuantTensor(value=tensor([[1.5410, 0.0000],\n", " [0.0000, 0.5681],\n", " [0.0000, 0.0000]], grad_fn=), scale=0.006043121684342623, zero_point=0.0, bit_width=8.0, signed_t=False, training_t=True)\n" ] @@ -568,11 +570,11 @@ " [-1.0845, -1.3986]]) \n", "\n", "Quant output after QuantIdentity:\n", - " QuantTensor(value=tensor([[ 1.5490, -0.2894],\n", + " IntQuantTensor(value=tensor([[ 1.5490, -0.2894],\n", " [-2.1788, 0.5617],\n", " [-1.0894, -1.3958]], grad_fn=), scale=0.017021792009472847, zero_point=0.0, bit_width=8.0, signed_t=True, training_t=True)\n", "Quant output after QuantReLU:\n", - " QuantTensor(value=tensor([[1.5490, 0.0000],\n", + " IntQuantTensor(value=tensor([[1.5490, 0.0000],\n", " [0.0000, 0.5588],\n", " [0.0000, 0.0000]], grad_fn=), scale=0.006074443459510803, zero_point=0.0, bit_width=8.0, signed_t=False, training_t=True)\n" ] @@ -632,10 +634,10 @@ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[13], line 8\u001b[0m\n\u001b[1;32m 5\u001b[0m float_input \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mrandn(\u001b[38;5;241m3\u001b[39m, \u001b[38;5;241m2\u001b[39m)\n\u001b[1;32m 6\u001b[0m quant_linear \u001b[38;5;241m=\u001b[39m QuantLinear(\u001b[38;5;241m2\u001b[39m, \u001b[38;5;241m4\u001b[39m, bias\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, bias_quant\u001b[38;5;241m=\u001b[39mInt16Bias, return_quant_tensor\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[0;32m----> 8\u001b[0m quant_output \u001b[38;5;241m=\u001b[39m \u001b[43mquant_linear\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfloat_input\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1518\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1519\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1523\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", - "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_linear.py:66\u001b[0m, in \u001b[0;36mQuantLinear.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 65\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Union[Tensor, QuantTensor]) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Union[Tensor, QuantTensor]:\n\u001b[0;32m---> 66\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_layer.py:173\u001b[0m, in \u001b[0;36mQuantWeightBiasInputOutputLayer.forward_impl\u001b[0;34m(self, inp)\u001b[0m\n\u001b[1;32m 169\u001b[0m compute_output_quant_tensor \u001b[38;5;241m=\u001b[39m \u001b[38;5;28misinstance\u001b[39m(quant_input, QuantTensor) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\n\u001b[1;32m 170\u001b[0m quant_weight, QuantTensor)\n\u001b[1;32m 171\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (compute_output_quant_tensor \u001b[38;5;129;01mor\u001b[39;00m\n\u001b[1;32m 172\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mis_output_quant_enabled) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mreturn_quant_tensor:\n\u001b[0;32m--> 173\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mQuantLayer is not correctly configured\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 175\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(quant_input, QuantTensor) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(quant_weight, QuantTensor):\n\u001b[1;32m 176\u001b[0m output_bit_width \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmax_acc_bit_width(quant_input\u001b[38;5;241m.\u001b[39mbit_width, quant_weight\u001b[38;5;241m.\u001b[39mbit_width)\n", + "File \u001b[0;32m~/miniconda3/envs/brevitas_dev/lib/python3.11/site-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/miniconda3/envs/brevitas_dev/lib/python3.11/site-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "File \u001b[0;32m~/Documents/git/dev_brevitas/src/brevitas/nn/quant_linear.py:66\u001b[0m, in \u001b[0;36mQuantLinear.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 65\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Union[Tensor, QuantTensor]) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Union[Tensor, QuantTensor]:\n\u001b[0;32m---> 66\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/Documents/git/dev_brevitas/src/brevitas/nn/quant_layer.py:152\u001b[0m, in \u001b[0;36mQuantWeightBiasInputOutputLayer.forward_impl\u001b[0;34m(self, inp)\u001b[0m\n\u001b[1;32m 148\u001b[0m compute_output_quant_tensor \u001b[38;5;241m=\u001b[39m \u001b[38;5;28misinstance\u001b[39m(quant_input, QuantTensor) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\n\u001b[1;32m 149\u001b[0m quant_weight, QuantTensor)\n\u001b[1;32m 150\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (compute_output_quant_tensor \u001b[38;5;129;01mor\u001b[39;00m\n\u001b[1;32m 151\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_quant\u001b[38;5;241m.\u001b[39mis_quant_enabled) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mreturn_quant_tensor:\n\u001b[0;32m--> 152\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mQuantLayer is not correctly configured\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 154\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbias \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 155\u001b[0m quant_bias \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbias_quant(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbias, quant_input, quant_weight)\n", "\u001b[0;31mRuntimeError\u001b[0m: QuantLayer is not correctly configured" ] } @@ -666,9 +668,9 @@ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[-0.6541, 0.1263, 0.1680, -0.1231],\n", + "IntQuantTensor(value=tensor([[-0.6541, 0.1263, 0.1680, -0.1231],\n", " [ 1.4658, 1.2395, -0.5207, 1.3989],\n", - " [ 1.6461, 0.8687, -1.0466, 1.4813]], grad_fn=), scale=tensor([[9.0542e-05]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(18.), signed_t=tensor(True), training_t=tensor(True))" + " [ 1.6461, 0.8687, -1.0466, 1.4813]], grad_fn=), scale=tensor([[9.0542e-05]], grad_fn=), zero_point=tensor([0.]), bit_width=tensor(18.), signed_t=tensor(True), training_t=tensor(True))" ] }, "execution_count": 14, @@ -725,15 +727,15 @@ "output_type": "stream", "text": [ "Eval mode add quant inputs:\n", - " QuantTensor(value=tensor([[ 1.5335, -0.2875],\n", + " IntQuantTensor(value=tensor([[ 1.5335, -0.2875],\n", " [-2.0447, 0.5751],\n", " [-1.0863, -1.4057]]), scale=0.015974320471286774, zero_point=0.0, bit_width=8.0, signed_t=True, training_t=False) \n", - " QuantTensor(value=tensor([[ 0.3994, 0.8307],\n", + " IntQuantTensor(value=tensor([[ 0.3994, 0.8307],\n", " [-0.7188, -0.3994],\n", " [-0.5910, 0.1757]]), scale=0.015974320471286774, zero_point=0.0, bit_width=8.0, signed_t=True, training_t=False) \n", "\n", "Eval mode add quant output:\n", - " QuantTensor(value=tensor([[ 1.9329, 0.5431],\n", + " IntQuantTensor(value=tensor([[ 1.9329, 0.5431],\n", " [-2.7636, 0.1757],\n", " [-1.6773, -1.2300]]), scale=0.015974320471286774, zero_point=0.0, bit_width=9.0, signed_t=True, training_t=False)\n" ] @@ -791,7 +793,7 @@ "output_type": "stream", "text": [ "Quant input:\n", - " QuantTensor(value=tensor([[[-1.1218, -1.1580, -0.2533, -0.4343],\n", + " IntQuantTensor(value=tensor([[[-1.1218, -1.1580, -0.2533, -0.4343],\n", " [ 0.8504, 0.6876, -0.3076, -2.1170]],\n", "\n", " [[ 0.4704, -0.1628, 1.4475, 0.2714],\n", @@ -801,7 +803,7 @@ " [ 0.6152, -0.4162, -0.8323, -2.3160]]], grad_fn=), scale=0.018094077706336975, zero_point=0.0, bit_width=8.0, signed_t=True, training_t=True) \n", "\n", "Quant output:\n", - " QuantTensor(value=tensor([[[-1.1218, -0.2533],\n", + " IntQuantTensor(value=tensor([[[-1.1218, -0.2533],\n", " [ 0.8504, -0.3076]],\n", "\n", " [[ 0.4704, 1.4475],\n", @@ -844,7 +846,7 @@ "output_type": "stream", "text": [ "Quant input:\n", - " QuantTensor(value=tensor([[[-1.1218, -1.1580, -0.2533, -0.4343],\n", + " IntQuantTensor(value=tensor([[[-1.1218, -1.1580, -0.2533, -0.4343],\n", " [ 0.8504, 0.6876, -0.3076, -2.1170]],\n", "\n", " [[ 0.4704, -0.1628, 1.4475, 0.2714],\n", @@ -868,7 +870,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "/tmp/ipykernel_88576/661358273.py:7: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at /opt/conda/conda-bld/pytorch_1699449183005/work/torch/csrc/utils/python_arg_parser.cpp:368.)\n", + "/tmp/ipykernel_751241/661358273.py:7: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at /opt/conda/conda-bld/pytorch_1712608853099/work/torch/csrc/utils/python_arg_parser.cpp:294.)\n", " quant_output = torch.tanh(quant_input)\n" ] } @@ -905,14 +907,14 @@ "output_type": "stream", "text": [ "Eval mode concat quant inputs:\n", - " QuantTensor(value=tensor([[ 1.5335, -0.2875],\n", + " IntQuantTensor(value=tensor([[ 1.5335, -0.2875],\n", " [-2.0447, 0.5751],\n", - " [-1.0863, -1.4057]]), scale=0.015974320471286774, zero_point=0.0, bit_width=8.0, signed_t=True, training_t=False) QuantTensor(value=tensor([[ 0.3994, 0.8307],\n", + " [-1.0863, -1.4057]]), scale=0.015974320471286774, zero_point=0.0, bit_width=8.0, signed_t=True, training_t=False) IntQuantTensor(value=tensor([[ 0.3994, 0.8307],\n", " [-0.7188, -0.3994],\n", " [-0.5910, 0.1757]]), scale=0.015974320471286774, zero_point=0.0, bit_width=8.0, signed_t=True, training_t=False) \n", "\n", "Eval mode concat quant output:\n", - " QuantTensor(value=tensor([[ 1.5335, -0.2875, 0.3994, 0.8307],\n", + " IntQuantTensor(value=tensor([[ 1.5335, -0.2875, 0.3994, 0.8307],\n", " [-2.0447, 0.5751, -0.7188, -0.3994],\n", " [-1.0863, -1.4057, -0.5910, 0.1757]]), scale=0.015974320471286774, zero_point=0.0, bit_width=8.0, signed_t=True, training_t=False)\n" ] @@ -921,9 +923,9 @@ "name": "stderr", "output_type": "stream", "text": [ - "/tmp/ipykernel_88576/3932472163.py:8: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at /opt/conda/conda-bld/pytorch_1699449183005/work/torch/csrc/utils/python_arg_parser.cpp:368.)\n", + "/tmp/ipykernel_751241/3932472163.py:8: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at /opt/conda/conda-bld/pytorch_1712608853099/work/torch/csrc/utils/python_arg_parser.cpp:294.)\n", " train_mode_cat = torch.cat([quant_identity(float_inp1), quant_identity(float_inp2)], dim=1)\n", - "/tmp/ipykernel_88576/3932472163.py:14: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at /opt/conda/conda-bld/pytorch_1699449183005/work/torch/csrc/utils/python_arg_parser.cpp:368.)\n", + "/tmp/ipykernel_751241/3932472163.py:14: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at /opt/conda/conda-bld/pytorch_1712608853099/work/torch/csrc/utils/python_arg_parser.cpp:294.)\n", " eval_mode_cat = torch.cat([eval_quant_inp1, eval_quant_inp2], dim=1)\n" ] } @@ -978,7 +980,7 @@ "output_type": "stream", "text": [ "Weight QuantTensor:\n", - " QuantTensor(value=tensor([[-0.0000, 0.3880],\n", + " IntQuantTensor(value=tensor([[-0.0000, 0.3880],\n", " [-0.5820, -0.5044],\n", " [-0.2716, 0.1940],\n", " [-0.0000, 0.5432]], grad_fn=), scale=0.03879871591925621, zero_point=0.0, bit_width=5.0, signed_t=True, training_t=True)\n" @@ -1012,13 +1014,10 @@ "output_type": "stream", "text": [ "Weight QuantTensor:\n", - " QuantTensor(value=tensor([[-0.0000, 0.3793],\n", + " IntQuantTensor(value=tensor([[-0.0000, 0.3880],\n", " [-0.5820, -0.5044],\n", - " [-0.2723, 0.1816],\n", - " [-0.0000, 0.5607]], grad_fn=), scale=tensor([[0.0253],\n", - " [0.0388],\n", - " [0.0182],\n", - " [0.0374]], grad_fn=), zero_point=0.0, bit_width=5.0, signed_t=True, training_t=True)\n" + " [-0.2716, 0.1940],\n", + " [-0.0000, 0.5432]], grad_fn=), scale=0.03879871591925621, zero_point=0.0, bit_width=5.0, signed_t=True, training_t=True)\n" ] } ], @@ -1049,7 +1048,7 @@ "output_type": "stream", "text": [ "QuantTensor:\n", - " QuantTensor(value=tensor([[ 1.6341, -0.5447],\n", + " IntQuantTensor(value=tensor([[ 1.6341, -0.5447],\n", " [-2.1788, 0.5447],\n", " [-1.0894, -1.6341]], grad_fn=), scale=0.5446973443031311, zero_point=0.0, bit_width=3.0, signed_t=True, training_t=True)\n" ] @@ -1080,7 +1079,7 @@ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[1.5294, 0.0000],\n", + "IntQuantTensor(value=tensor([[1.5294, 0.0000],\n", " [0.0000, 0.5647],\n", " [0.0000, 0.0000]], grad_fn=), scale=tensor(0.0235, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(False), training_t=tensor(True))" ] @@ -1123,9 +1122,16 @@ " [-1.3986, 0.4033, 0.8380, -0.7193, -0.4033]]]) \n", "\n", "Per-channel quant output:\n", - " QuantTensor(value=tensor([[[ 0.8616, -0.7012, 0.4503],\n", - " [-1.1285, -0.4937, -0.1901]]], grad_fn=), scale=tensor([[[0.0021],\n", - " [0.0013]]], grad_fn=), zero_point=0.0, bit_width=17.0, signed_t=True, training_t=True)\n" + " IntQuantTensor(value=tensor([[[ 0.8616, -0.7012, 0.4503],\n", + " [-1.2235, -0.4359, -0.0473]]], grad_fn=), scale=tensor([[[0.0021]]], grad_fn=), zero_point=tensor([0.]), bit_width=17.0, signed_t=True, training_t=True)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/giuseppe/miniconda3/envs/brevitas_dev/lib/python3.11/site-packages/torch/nn/modules/conv.py:306: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at /opt/conda/conda-bld/pytorch_1712608853099/work/torch/csrc/utils/python_arg_parser.cpp:294.)\n", + " return F.conv1d(input, weight, bias, self.stride,\n" ] } ], @@ -1181,15 +1187,15 @@ " [-1.3986, 0.4033, 0.8380, -0.7193, -0.4033]]]) \n", "\n", "Per-channel quant output:\n", - " QuantTensor(value=tensor([[[ 0.8616, -0.7012, 0.4503],\n", - " [-1.1285, -0.4937, -0.1901]]], grad_fn=), scale=tensor([[[0.0021],\n", - " [0.0013]]], grad_fn=), zero_point=0.0, bit_width=17.0, signed_t=True, training_t=True)\n" + " IntQuantTensor(value=tensor([[[ 0.8616, -0.7012, 0.4503],\n", + " [-1.2235, -0.4359, -0.0473]]], grad_fn=), scale=tensor([[[0.0021]]], grad_fn=), zero_point=tensor([0.]), bit_width=17.0, signed_t=True, training_t=True)\n" ] } ], "source": [ "torch.manual_seed(0)\n", "\n", + "from brevitas.inject.enum import ScalingPerOutputType\n", "from brevitas.nn import QuantConv1d\n", "\n", "BATCHES = 1\n", @@ -1325,7 +1331,7 @@ "output_type": "stream", "text": [ "Weight QuantTensor:\n", - " QuantTensor(value=tensor([[-0.0060, 0.3793],\n", + " IntQuantTensor(value=tensor([[-0.0060, 0.3793],\n", " [-0.5820, -0.5224],\n", " [-0.2723, 0.1887],\n", " [-0.0132, 0.5607]], grad_fn=), scale=tensor([[0.0030],\n", @@ -1338,10 +1344,9 @@ "source": [ "torch.manual_seed(0)\n", "\n", - "from brevitas.quant import Int8WeightPerTensorFloat\n", + "from brevitas.quant import Int8WeightPerChannelFloat\n", "\n", - "class LearnedIntWeightPerChannelFloat(Int8WeightPerTensorFloat):\n", - " scaling_per_output_channel = True\n", + "class LearnedIntWeightPerChannelFloat(Int8WeightPerChannelFloat):\n", " scaling_impl_type = ScalingImplType.PARAMETER_FROM_STATS\n", " restrict_scaling_type = RestrictValueType.LOG_FP\n", " bit_width_impl_type = BitWidthImplType.PARAMETER \n", @@ -1367,10 +1372,10 @@ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[-0.9109, -0.4588, 0.3119, -0.6530],\n", + "IntQuantTensor(value=tensor([[-0.9109, -0.4588, 0.3119, -0.6530],\n", " [ 1.2089, 0.6493, -0.3731, 0.8706],\n", " [ 1.3893, 0.2823, -0.8979, 0.9543]], grad_fn=), scale=tensor([[9.0542e-05, 3.9068e-05, 5.6866e-05, 6.4251e-05]],\n", - " grad_fn=), zero_point=tensor(0.), bit_width=tensor(17., grad_fn=), signed_t=tensor(True), training_t=tensor(True))" + " grad_fn=), zero_point=tensor([0.]), bit_width=tensor(17., grad_fn=), signed_t=tensor(True), training_t=tensor(True))" ] }, "execution_count": 28, @@ -1433,7 +1438,7 @@ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[29], line 12\u001b[0m\n\u001b[1;32m 5\u001b[0m float_linear \u001b[38;5;241m=\u001b[39m nn\u001b[38;5;241m.\u001b[39mLinear(\u001b[38;5;241m2\u001b[39m, \u001b[38;5;241m4\u001b[39m, bias\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m 6\u001b[0m quant_linear \u001b[38;5;241m=\u001b[39m QuantLinear(\n\u001b[1;32m 7\u001b[0m \u001b[38;5;241m2\u001b[39m, \u001b[38;5;241m4\u001b[39m, \n\u001b[1;32m 8\u001b[0m input_quant\u001b[38;5;241m=\u001b[39mLearnedIntActPerTensorFloat,\n\u001b[1;32m 9\u001b[0m weight_quant\u001b[38;5;241m=\u001b[39mLearnedIntWeightPerChannelFloat, \n\u001b[1;32m 10\u001b[0m return_quant_tensor\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, bias\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[0;32m---> 12\u001b[0m \u001b[43mquant_linear\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload_state_dict\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfloat_linear\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstate_dict\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/module.py:2153\u001b[0m, in \u001b[0;36mModule.load_state_dict\u001b[0;34m(self, state_dict, strict, assign)\u001b[0m\n\u001b[1;32m 2148\u001b[0m error_msgs\u001b[38;5;241m.\u001b[39minsert(\n\u001b[1;32m 2149\u001b[0m \u001b[38;5;241m0\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mMissing key(s) in state_dict: \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m. \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mformat(\n\u001b[1;32m 2150\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mk\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m k \u001b[38;5;129;01min\u001b[39;00m missing_keys)))\n\u001b[1;32m 2152\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(error_msgs) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[0;32m-> 2153\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mError(s) in loading state_dict for \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m:\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\t\u001b[39;00m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mformat(\n\u001b[1;32m 2154\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\t\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(error_msgs)))\n\u001b[1;32m 2155\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _IncompatibleKeys(missing_keys, unexpected_keys)\n", + "File \u001b[0;32m~/miniconda3/envs/brevitas_dev/lib/python3.11/site-packages/torch/nn/modules/module.py:2189\u001b[0m, in \u001b[0;36mModule.load_state_dict\u001b[0;34m(self, state_dict, strict, assign)\u001b[0m\n\u001b[1;32m 2184\u001b[0m error_msgs\u001b[38;5;241m.\u001b[39minsert(\n\u001b[1;32m 2185\u001b[0m \u001b[38;5;241m0\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mMissing key(s) in state_dict: \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m. \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mformat(\n\u001b[1;32m 2186\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mk\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m k \u001b[38;5;129;01min\u001b[39;00m missing_keys)))\n\u001b[1;32m 2188\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(error_msgs) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[0;32m-> 2189\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mError(s) in loading state_dict for \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m:\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\t\u001b[39;00m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mformat(\n\u001b[1;32m 2190\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\t\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(error_msgs)))\n\u001b[1;32m 2191\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _IncompatibleKeys(missing_keys, unexpected_keys)\n", "\u001b[0;31mRuntimeError\u001b[0m: Error(s) in loading state_dict for QuantLinear:\n\tMissing key(s) in state_dict: \"input_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value\", \"input_quant.fused_activation_quant_proxy.tensor_quant.msb_clamp_bit_width_impl.bit_width_offset\", \"weight_quant.tensor_quant.scaling_impl.value\", \"weight_quant.tensor_quant.msb_clamp_bit_width_impl.bit_width_offset\". " ] } @@ -1875,11 +1880,12 @@ "name": "stdout", "output_type": "stream", "text": [ - "Requirement already satisfied: netron in /scratch/fabian/miniforge3/envs/torchgpu/lib/python3.11/site-packages (7.4.5)\n", - "Requirement already satisfied: onnx in /scratch/fabian/miniforge3/envs/torchgpu/lib/python3.11/site-packages (1.15.0)\n", - "Requirement already satisfied: onnxoptimizer in /scratch/fabian/miniforge3/envs/torchgpu/lib/python3.11/site-packages (0.3.13)\n", - "Requirement already satisfied: numpy in /scratch/fabian/miniforge3/envs/torchgpu/lib/python3.11/site-packages (from onnx) (1.26.0)\n", - "Requirement already satisfied: protobuf>=3.20.2 in /scratch/fabian/miniforge3/envs/torchgpu/lib/python3.11/site-packages (from onnx) (3.20.3)\n" + "\u001b[33mDEPRECATION: Loading egg at /home/giuseppe/miniconda3/envs/brevitas_dev/lib/python3.11/site-packages/torchao-0.3.0-py3.11.egg is deprecated. pip 24.3 will enforce this behaviour change. A possible replacement is to use pip for package installation.. Discussion can be found at https://github.com/pypa/pip/issues/12330\u001b[0m\u001b[33m\n", + "\u001b[0mRequirement already satisfied: netron in /home/giuseppe/miniconda3/envs/brevitas_dev/lib/python3.11/site-packages (7.6.3)\n", + "Requirement already satisfied: onnx in /home/giuseppe/miniconda3/envs/brevitas_dev/lib/python3.11/site-packages (1.15.0)\n", + "Requirement already satisfied: onnxoptimizer in /home/giuseppe/miniconda3/envs/brevitas_dev/lib/python3.11/site-packages (0.3.13)\n", + "Requirement already satisfied: numpy in /home/giuseppe/miniconda3/envs/brevitas_dev/lib/python3.11/site-packages (from onnx) (1.26.4)\n", + "Requirement already satisfied: protobuf>=3.20.2 in /home/giuseppe/miniconda3/envs/brevitas_dev/lib/python3.11/site-packages (from onnx) (3.20.3)\n" ] } ], @@ -1914,9 +1920,18 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 39, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/giuseppe/miniconda3/envs/brevitas_dev/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], "source": [ "torch.manual_seed(0)\n", "\n", @@ -1937,7 +1952,7 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": 40, "metadata": { "tags": [ "skip-execution" @@ -1966,10 +1981,10 @@ " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 42, + "execution_count": 40, "metadata": {}, "output_type": "execute_result" } @@ -1987,7 +2002,7 @@ }, { "cell_type": "code", - "execution_count": 43, + "execution_count": 41, "metadata": {}, "outputs": [], "source": [ @@ -2001,7 +2016,7 @@ }, { "cell_type": "code", - "execution_count": 44, + "execution_count": 42, "metadata": { "tags": [ "skip-execution" @@ -2030,10 +2045,10 @@ " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 44, + "execution_count": 42, "metadata": {}, "output_type": "execute_result" } @@ -2087,24 +2102,9 @@ }, { "cell_type": "code", - "execution_count": 47, + "execution_count": 43, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/scratch/fabian/miniforge3/envs/torchgpu/lib/python3.11/site-packages/torch/overrides.py:110: UserWarning: 'has_cuda' is deprecated, please use 'torch.backends.cuda.is_built()'\n", - " torch.has_cuda,\n", - "/scratch/fabian/miniforge3/envs/torchgpu/lib/python3.11/site-packages/torch/overrides.py:111: UserWarning: 'has_cudnn' is deprecated, please use 'torch.backends.cudnn.is_available()'\n", - " torch.has_cudnn,\n", - "/scratch/fabian/miniforge3/envs/torchgpu/lib/python3.11/site-packages/torch/overrides.py:117: UserWarning: 'has_mps' is deprecated, please use 'torch.backends.mps.is_built()'\n", - " torch.has_mps,\n", - "/scratch/fabian/miniforge3/envs/torchgpu/lib/python3.11/site-packages/torch/overrides.py:118: UserWarning: 'has_mkldnn' is deprecated, please use 'torch.backends.mkldnn.is_available()'\n", - " torch.has_mkldnn,\n" - ] - } - ], + "outputs": [], "source": [ "from brevitas.graph.calibrate import bias_correction_mode\n", "from brevitas.graph.calibrate import calibration_mode\n", @@ -2143,7 +2143,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.5" + "version": "3.11.9" }, "vscode": { "interpreter": { diff --git a/notebooks/minifloat_mx_tutorial.ipynb b/notebooks/minifloat_mx_tutorial.ipynb new file mode 100644 index 000000000..e764fd05c --- /dev/null +++ b/notebooks/minifloat_mx_tutorial.ipynb @@ -0,0 +1,285 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Minifloat and Groupwise quantization" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This notebook shows some practical use cases for minifloat and groupwise quantization.\n", + "\n", + "Brevitas supports a wide combination of float quantization, including the OCP and FNUZ FP8 standard.\n", + "It is possible to define any combination of exponent/mantissa bitwidth, as well as exponent bias.\n", + "\n", + "Similarly, MX quantization is supported as general groupwise quantization on top of integer/minifloat datatype.\n", + "This allows to any general groupwise quantization, including MXInt and MXFloat standards.\n", + "\n", + "This tutorial shows how to instantiate and use some of the most interesting quantizers for minifloat and groupwise quantization" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Minifloat (FP8 and lower)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Brevitas offers some pre-defined quantizers for minifloat quantization, including OCP and FNUZ standards, which can be further customized according to the specific use case.\n", + "The general naming structure for the quantizers is the following:\n", + "\n", + "`Fp\\\\Weight\\Float`\n", + "\n", + "Where `Bitwidth` can be either empty or `e4m3`/`e5m2`, `Standard` can be empty or `OCP`/`FNUZ`, `Scaling` can be empty or `PerTensor`/`PerChannel`.\n", + "\n", + "If `Bitwidth` is empty, the user must set it with kwargs or by subclassing the quantizers. Once the bitwidth is defined, the correct values for inf/nan are automatically defined based on the `Standard`.\n", + "If a non-valid OCP bitwidth is set (e.g., e6m1), then no inf/nan values will be selected and the corresponding quantizer is not standard-compliant.\n", + "\n", + "`Standard` allows to pick among the two main FP8 standard options; moreover, if not specified, Brevitas offers the possibility of doing minifloat quantization without necessarily reserving values for inf/nan representation.\n", + "This allows to use the maximum available range, since often in quantization, values that exceed the quantization range saturate to maximum rather than going to inf/nan.\n", + "FNUZ quantizers need to have `saturating=True`.\n", + "\n", + "The `Scaling` options defines whether the quantization is _scaled_ or _unscaled_.\n", + "In the unscaled case, the scale factor for quantization is fixed to one, otherwise it can be set using any of the methods that Brevitas includes (e.g., statistics, learned, etc.)\n", + "\n", + "\n", + "Please keep in mind that not all combinations of the above options might be pre-defined and this serves mostly as indications of what Brevitas supports.\n", + "It is possible, following the same structure of the available quantizers, to define new ones that fit your needs.\n", + "\n", + "\n", + "Similar considerations can be extended for activation quantization." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/giuseppe/miniconda3/envs/brevitas_dev/lib/python3.11/site-packages/torch/nn/modules/conv.py:456: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at /opt/conda/conda-bld/pytorch_1712608853099/work/torch/csrc/utils/python_arg_parser.cpp:294.)\n", + " return F.conv2d(input, weight, bias, self.stride,\n" + ] + } + ], + "source": [ + "from brevitas.quant.experimental.float_base import Fp8e4m3Mixin\n", + "from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Weight\n", + "from brevitas.quant.experimental.float_quant_ocp import FpOCPWeightPerTensorFloat, FpOCPActPerTensorFloat\n", + "from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Act\n", + "import brevitas.nn as qnn\n", + "import torch.nn as nn\n", + "import torch\n", + "from brevitas.quant_tensor import FloatQuantTensor\n", + "\n", + "class OCPFP8Weight(FpOCPWeightPerTensorFloat, Fp8e4m3Mixin):\n", + " pass\n", + "\n", + "\n", + "class OCPFP8Act(FpOCPActPerTensorFloat, Fp8e4m3Mixin):\n", + " pass\n", + "\n", + "\n", + "class FP8Model(nn.Module):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.conv = qnn.QuantConv2d(32, 64, 3, weight_quant=OCPFP8Weight, input_quant=OCPFP8Act)\n", + " \n", + " def forward(self, x):\n", + " return self.conv(x)\n", + "\n", + "ocp_fp8_model = FP8Model()\n", + "x = torch.randn(1, 32, 8, 8)\n", + "ocp_fp8_model.eval()\n", + "o = ocp_fp8_model(x)\n", + "\n", + "intermediate_input = ocp_fp8_model.conv.input_quant(x)\n", + "assert isinstance(intermediate_input, FloatQuantTensor)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Groupwise quantization (MXInt/MXFloat)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Groupwise quantization is built on top of integer/minifloat quantization, with special considerations to accomodate for the groupwise scaling.\n", + "\n", + "Compared to Int/Float QuantTensor, the main difference of their groupwise equivalent is that value, scale, and zero_point are not direct attributes anymore but properties. The new attributes are value_, scale_, and zero_point_.\n", + "\n", + "The reason for this is shaping. When quantizing a tensor with shapes [O, I], where O is output channel and I is input channel, with groupsize k, groupwise quantization is normally represented as follow:\n", + "\n", + "- Tensor with shapes [O, k, I/k]\n", + "- Scales with shapes [O, k, 1]\n", + "- Zero point same as scale\n", + "\n", + "The alternative to this representation is to have all three tensors with shapes [O,I], with a massive increase in memory utilization, especially with QAT + gradients.\n", + "\n", + "The underscored attributes will have the compressed shapes, while the properties (non-underscored naming) will dynamically compute the expanded version of the property. This means:\n", + "```python\n", + "quant_tensor.scale_.shape\n", + "# This will print [O, k, 1]\n", + "quant_tensor.scale.shape\n", + "# This will print [O, I]\n", + "```\n", + "\n", + "With respect to pre-defined quantizers, Brevitas offers several Groupwise and MX options.\n", + "The main difference between the two is that MX is restricted to group_size=32 and the scale factor must be a power-of-2.\n", + "The user can override these settings but the corresponding output won't be MX compliant.\n", + "\n", + "Another difference is that MXFloat relies on the OCP format as underlying data type, while generic groupwise float relies on the non-standard minifloat representation explained above.\n", + "\n", + "Finally, the general groupwise scaling relies on float scales." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from brevitas.quant_tensor import GroupwiseFloatQuantTensor\n", + "\n", + "\n", + "class MXFloat8Weight(MXFloat8e4m3Weight, Fp8e4m3Mixin):\n", + " # The group dimension for the weights it is automatically identified based on the layer type\n", + " # If a new layer type is used, it can be manually specified\n", + " pass\n", + "\n", + "class MXFloat8Act(MXFloat8e4m3Act, Fp8e4m3Mixin):\n", + " # It is necessary to specify the group dimension for the activation quantization\n", + " group_dim = 1\n", + "\n", + "\n", + "class MXModel(nn.Module):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.conv = qnn.QuantConv2d(32, 64, 3, weight_quant=MXFloat8Weight, input_quant=MXFloat8Act)\n", + " \n", + " def forward(self, x):\n", + " return self.conv(x)\n", + "\n", + "mx_model = MXModel()\n", + "x = torch.randn(1, 32, 8, 8)\n", + "mx_model.eval()\n", + "o = mx_model(x)\n", + "\n", + "intermediate_input = mx_model.conv.input_quant(x)\n", + "assert isinstance(intermediate_input, GroupwiseFloatQuantTensor)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3WeightMSE\n", + "from brevitas.quant_tensor import GroupwiseFloatQuantTensor\n", + "\n", + "\n", + "class MXFloat8Weight(MXFloat8e4m3WeightMSE, Fp8e4m3Mixin):\n", + " # The group dimension for the weights it is automatically identified based on the layer type\n", + " # If a new layer type is used, it can be manually specified\n", + " pass\n", + "\n", + "class MXFloat8Act(MXFloat8e4m3Act, Fp8e4m3Mixin):\n", + " # It is necessary to specify the group dimension for the activation quantization\n", + " group_dim = 1\n", + "\n", + "\n", + "class MXModel(nn.Module):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.conv = qnn.QuantConv2d(32, 64, 3, weight_quant=MXFloat8Weight, input_quant=MXFloat8Act)\n", + " \n", + " def forward(self, x):\n", + " return self.conv(x)\n", + "\n", + "mx_model = MXModel()\n", + "x = torch.randn(1, 32, 8, 8)\n", + "mx_model.eval()\n", + "o = mx_model(x)\n", + "\n", + "intermediate_input = mx_model.conv.input_quant(x)\n", + "assert isinstance(intermediate_input, GroupwiseFloatQuantTensor)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "from brevitas.quant_tensor import GroupwiseIntQuantTensor\n", + "from brevitas.quant.experimental.mx_quant_ocp import MXInt8Weight\n", + "from brevitas.quant.experimental.mx_quant_ocp import MXInt8Act\n", + "import torch.nn as nn\n", + "import brevitas.nn as qnn\n", + "import torch\n", + "\n", + "class MXFloat8Weight(MXInt8Weight):\n", + " # The group dimension for the weights it is automatically identified based on the layer type\n", + " # If a new layer type is used, it can be manually specified\n", + " bit_width = 8\n", + "\n", + "class MXFloat8Act(MXInt8Act):\n", + " # It is necessary to specify the group dimension for the activation quantization\n", + " group_dim = 1\n", + " bit_width = 8\n", + "\n", + "class MXModel(nn.Module):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.conv = qnn.QuantConv2d(32, 64, 3, weight_quant=MXFloat8Weight, input_quant=MXFloat8Act)\n", + " \n", + " def forward(self, x):\n", + " return self.conv(x)\n", + "\n", + "mx_model = MXModel()\n", + "x = torch.randn(1, 32, 8, 8)\n", + "mx_model.eval()\n", + "o = mx_model(x)\n", + "\n", + "intermediate_input = mx_model.conv.input_quant(x)\n", + "assert isinstance(intermediate_input, GroupwiseIntQuantTensor)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "brevitas_dev", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/brevitas/core/function_wrapper/shape.py b/src/brevitas/core/function_wrapper/shape.py index 692ae76da..cdef81b3e 100644 --- a/src/brevitas/core/function_wrapper/shape.py +++ b/src/brevitas/core/function_wrapper/shape.py @@ -153,6 +153,24 @@ def forward(self, x: torch.Tensor): return y.reshape(shape) +class OverSubChannelBlockView(brevitas.jit.ScriptModule): + __constants__ = ['expanded_scaling_shape'] + + def __init__(self, expanded_scaling_shape, permute_dims: Optional[Tuple[int, ...]]) -> None: + super(OverSubChannelBlockView, self).__init__() + self.expanded_scaling_shape = expanded_scaling_shape + if permute_dims is not None: + self.permute_impl = PermuteDims(permute_dims) + else: + self.permute_impl = torch.nn.Identity() + + @brevitas.jit.script_method + def forward(self, x: torch.Tensor): + y = self.permute_impl(x) + y = y.view(self.expanded_scaling_shape) + return y + + class StatsInputViewShapeImpl(object): """ Enum-like object to collect pointers to variants of ScriptModules that perform a view on a tensor. @@ -163,3 +181,4 @@ class StatsInputViewShapeImpl(object): OVER_BATCH_OVER_TENSOR = OverBatchOverTensorView OVER_BATCH_OVER_OUTPUT_CHANNELS = OverBatchOverOutputChannelView OVER_OUTPUT_FEATURES = OverOutputFeaturesView + OVER_SUBCHANNEL_BLOCK = OverSubChannelBlockView diff --git a/src/brevitas/core/scaling/__init__.py b/src/brevitas/core/scaling/__init__.py index 1be86be55..c21cb4b27 100644 --- a/src/brevitas/core/scaling/__init__.py +++ b/src/brevitas/core/scaling/__init__.py @@ -1,6 +1,8 @@ from brevitas.inject.enum import ScalingImplType +from brevitas.inject.enum import ScalingPerOutputType assert ScalingImplType +assert ScalingPerOutputType from brevitas.core.stats import SCALAR_SHAPE diff --git a/src/brevitas/core/scaling/runtime.py b/src/brevitas/core/scaling/runtime.py index f1d108068..23707344f 100644 --- a/src/brevitas/core/scaling/runtime.py +++ b/src/brevitas/core/scaling/runtime.py @@ -158,3 +158,38 @@ def _load_from_state_dict( missing_keys.remove(affine_weight_key) if config.IGNORE_MISSING_KEYS and affine_bias_key in missing_keys: missing_keys.remove(affine_bias_key) + + +class RuntimeDynamicGroupStatsScaling(brevitas.jit.ScriptModule): + + def __init__( + self, + group_size: int, + group_dim: int, + scaling_stats_impl: torch.nn.Module, + scaling_min_val: Optional[float], + restrict_scaling_impl: Optional[torch.nn.Module]) -> None: + super(RuntimeDynamicGroupStatsScaling, self).__init__() + self.group_size = group_size + self.group_dim = group_dim + self.scaling_stats_impl = scaling_stats_impl + self.scaling_min_val = scaling_min_val + self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl) + + @brevitas.jit.script_method + def group_scaling_reshape(self, stats_input): + tensor_shape = stats_input.shape + tensor_shape_list = list(tensor_shape) + tensor_shape_list[self.group_dim] = int(tensor_shape_list[self.group_dim] / self.group_size) + block_dim = self.group_dim + 1 if self.group_dim != -1 else -1 + tensor_shape_list.insert(block_dim, self.group_size) + stats_input = stats_input.view(tensor_shape_list) + return stats_input + + @brevitas.jit.script_method + def forward(self, stats_input) -> torch.Tensor: + stats_input_reshaped = self.group_scaling_reshape(stats_input) + out = self.scaling_stats_impl(stats_input_reshaped) + # Scaling min val + out = self.restrict_clamp_scaling(out) + return out diff --git a/src/brevitas/export/onnx/standard/qcdq/handler.py b/src/brevitas/export/onnx/standard/qcdq/handler.py index 97af39049..e13ac73d1 100644 --- a/src/brevitas/export/onnx/standard/qcdq/handler.py +++ b/src/brevitas/export/onnx/standard/qcdq/handler.py @@ -23,6 +23,7 @@ from brevitas.export.common.handler.qcdq import QMixin from brevitas.export.onnx.handler import ONNXBaseHandler from brevitas.export.onnx.handler import QuantLSTMLayerHandler +from brevitas.inject.enum import ScalingPerOutputType from ..function import CastFn from ..function import DequantizeLinearFn @@ -133,7 +134,7 @@ def validate(self, module): # Below 8b quantization is not supported. self.validate_8b_bit_width(module.bit_width(), le_then=False) # Only per tensor quantization is supported - assert not module.quant_injector.scaling_per_output_channel, "Only per tensor scaling supported" + assert module.quant_injector.scaling_per_output == ScalingPerOutputType.TENSOR, "Only per tensor scaling supported" def quantize_fn(self, x, dtype): return DynamicQuantizeLinearFn.apply(x, dtype) diff --git a/src/brevitas/graph/quantize.py b/src/brevitas/graph/quantize.py index c45e833d5..ee035b9bd 100644 --- a/src/brevitas/graph/quantize.py +++ b/src/brevitas/graph/quantize.py @@ -13,7 +13,6 @@ from brevitas.graph.fixed_point import CollapseConsecutiveConcats from brevitas.graph.fixed_point import MergeBatchNorm from brevitas.graph.fixed_point import MoveSplitBatchNormBeforeCat -from brevitas.graph.per_input import AdaptiveAvgPoolToAvgPool from brevitas.graph.quantize_impl import act_handler from brevitas.graph.quantize_impl import add_output_quant_handler from brevitas.graph.quantize_impl import inp_placeholder_handler @@ -25,7 +24,6 @@ from brevitas.graph.standardize import MeanMethodToAdaptiveAvgPool2d from brevitas.graph.standardize import RemoveStochasticModules from brevitas.graph.standardize import TorchFunctionalToModule -from brevitas.nn import quant_layer import brevitas.nn as qnn from brevitas.quant import Int8ActPerTensorFloat from brevitas.quant import Int8ActPerTensorFloatMinMaxInit diff --git a/src/brevitas/inject/enum.py b/src/brevitas/inject/enum.py index 95ea05355..129a55252 100644 --- a/src/brevitas/inject/enum.py +++ b/src/brevitas/inject/enum.py @@ -67,6 +67,15 @@ class ScalingImplType(AutoName): PARAMETER_FROM_STATS = auto() +class ScalingPerOutputType(AutoName): + """ + + """ + TENSOR = auto() + CHANNEL = auto() + GROUP = auto() + + class StatsOp(AutoName): """ diff --git a/src/brevitas/nn/hadamard_classifier.py b/src/brevitas/nn/hadamard_classifier.py index e78163321..38e82da48 100644 --- a/src/brevitas/nn/hadamard_classifier.py +++ b/src/brevitas/nn/hadamard_classifier.py @@ -14,7 +14,7 @@ from brevitas.function.ops import max_int from brevitas.function.ops_ste import ceil_ste -from brevitas.quant_tensor import QuantTensor +from brevitas.quant_tensor import IntQuantTensor from .mixin.base import QuantLayerMixin @@ -49,14 +49,14 @@ def forward(self, inp): out = inp.value / norm out = nn.functional.linear(out, self.proj[:self.out_channels, :self.in_channels]) out = -self.scale * out - if isinstance(inp, QuantTensor): + if isinstance(inp, IntQuantTensor): output_scale = inp.scale * self.scale / norm output_bit_width = self.max_output_bit_width(inp.bit_width) if (self.return_quant_tensor and inp.zero_point != 0.0).any(): raise RuntimeError("Computing zero point of output accumulator not supported yet.") else: output_zp = inp.zero_point - out = QuantTensor( + out = IntQuantTensor( value=out, scale=output_scale, zero_point=output_zp, diff --git a/src/brevitas/nn/mixin/base.py b/src/brevitas/nn/mixin/base.py index bbdd77ac7..167852508 100644 --- a/src/brevitas/nn/mixin/base.py +++ b/src/brevitas/nn/mixin/base.py @@ -21,6 +21,8 @@ from brevitas.quant_tensor import FloatQuantTensor from brevitas.quant_tensor import IntQuantTensor from brevitas.quant_tensor import QuantTensor +from brevitas.quant_tensor.groupwise_float_quant_tensor import GroupwiseFloatQuantTensor +from brevitas.quant_tensor.groupwise_int_quant_tensor import GroupwiseIntQuantTensor from .utils import filter_kwargs @@ -71,7 +73,8 @@ def _set_global_is_quant_layer(self, value): config._IS_INSIDE_QUANT_LAYER = value def get_quant_tensor_class(self, inp: Union[Tensor, QuantTensor]): - quant_tensor_classes = [IntQuantTensor, FloatQuantTensor] + quant_tensor_classes = [ + IntQuantTensor, FloatQuantTensor, GroupwiseIntQuantTensor, GroupwiseFloatQuantTensor] for qt_class in quant_tensor_classes: if len(inp) == len(qt_class._fields): return qt_class diff --git a/src/brevitas/proxy/float_parameter_quant.py b/src/brevitas/proxy/float_parameter_quant.py index b59a37696..68038fa20 100644 --- a/src/brevitas/proxy/float_parameter_quant.py +++ b/src/brevitas/proxy/float_parameter_quant.py @@ -12,7 +12,7 @@ from brevitas.utils.quant_utils import _CachedIOFloat -class WeightFloatQuantProxyFromInjector(WeightQuantProxyFromInjectorBase): +class WeightFloatQuantProxyFromInjectorBase(WeightQuantProxyFromInjectorBase): def scale(self): if not self.is_quant_enabled: @@ -104,6 +104,28 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, FloatQuantTensor]: return x +class WeightFloatQuantProxyFromInjector(WeightFloatQuantProxyFromInjectorBase): + + def forward(self, x: torch.Tensor) -> Union[Tensor, FloatQuantTensor]: + if self.is_quant_enabled: + impl = self.export_handler if self.export_mode else self.tensor_quant + out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = impl(x) + return FloatQuantTensor( + out, + scale, + zero_point, + exponent_bit_width, + mantissa_bit_width, + exponent_bias, + saturating, + inf_values, + nan_values, + self.is_signed, + self.training) + else: # quantization disabled + return x + + class BiasFloatQuantProxyFromInjector(BiasQuantProxyFromInjectorBase): def scale(self): diff --git a/src/brevitas/proxy/float_runtime_quant.py b/src/brevitas/proxy/float_runtime_quant.py index 546fa8f8a..021aefd12 100644 --- a/src/brevitas/proxy/float_runtime_quant.py +++ b/src/brevitas/proxy/float_runtime_quant.py @@ -12,7 +12,7 @@ from brevitas.utils.quant_utils import _CachedIOFloat -class ActFloatQuantProxyFromInjector(ActQuantProxyFromInjectorBase): +class ActFloatQuantProxyFromInjectorBase(ActQuantProxyFromInjectorBase): def scale(self, force_eval=True): return self.retrieve_attribute('scale', force_eval) @@ -87,14 +87,67 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, FloatQuantTens y, x.scale, x.zero_point, - x.mantissa_bit_width, x.exponent_bit_width, + x.mantissa_bit_width, x.exponent_bias, + x.saturating, + x.inf_values, + x.nan_values, x.signed, - self.training, + self.training) + else: + out = y + else: + if isinstance(y, tuple): + y = y[0] + out = y + else: + # If fused activation quant proxy is not enabled, return the input + out = x + if not self.training and self.cache_inference_quant_act and isinstance(out, + FloatQuantTensor): + cached_out = _CachedIOFloat(out.detach(), self.cache_quant_io_metadata_only) + self._cached_act = cached_out + return out + + +class ActFloatQuantProxyFromInjector(ActFloatQuantProxyFromInjectorBase): + + def forward(self, x: Union[Tensor, FloatQuantTensor]) -> Union[Tensor, FloatQuantTensor]: + out = x + if self.fused_activation_quant_proxy is not None: + y = x + if isinstance(y, FloatQuantTensor): + y = y.value + + if self.export_mode: + y = self.fused_activation_quant_proxy.activation_impl(y) + y = self.export_handler(y) + elif not self.is_quant_enabled: + y = self.fused_activation_quant_proxy.activation_impl(y) + else: + y = self.fused_activation_quant_proxy(y) + # If y is an empty FloatQuantTensor, we need to check if this is a passthrough proxy, + # otherwise return a simple Tensor + # We exclude the last two values (inf_values and nan_values) + if isinstance(y, tuple) and not any(map(lambda f: f is None, y[:-2])): + out = FloatQuantTensor(*y, signed=self.is_signed, training=self.training) + elif self.is_passthrough_act: # preserve scale/zp/bit/sign even without output quant + if isinstance(y, tuple): + y = y[0] + if isinstance(x, FloatQuantTensor): + out = FloatQuantTensor( + y, + x.scale, + x.zero_point, + x.mantissa_bit_width, + x.exponent_bit_width, + x.exponent_bias, x.saturating, x.inf_values, - x.nan_values) + x.nan_values, + x.signed, + self.training) else: out = y else: diff --git a/src/brevitas/proxy/groupwise_float_parameter_quant.py b/src/brevitas/proxy/groupwise_float_parameter_quant.py new file mode 100644 index 000000000..cd38d9906 --- /dev/null +++ b/src/brevitas/proxy/groupwise_float_parameter_quant.py @@ -0,0 +1,45 @@ +from typing import Union + +import torch +from torch import Tensor + +from brevitas.proxy.float_parameter_quant import WeightFloatQuantProxyFromInjectorBase +from brevitas.quant_tensor import GroupwiseFloatQuantTensor + + +class GroupwiseWeightFloatQuantProxyFromInjector(WeightFloatQuantProxyFromInjectorBase): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # TODO: Is this always generated? + self.view_impl = self.quant_injector.scaling_stats_input_view_shape_impl + + @property + def group_dim(self): + return self.quant_injector.group_dim + + @property + def group_size(self): + return self.quant_injector.group_size + + def forward(self, x: torch.Tensor) -> Union[Tensor, GroupwiseFloatQuantTensor]: + if self.is_quant_enabled: + impl = self.export_handler if self.export_mode else self.tensor_quant + x = self.view_impl(x) + out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = impl(x) + return GroupwiseFloatQuantTensor( + out, + scale, + zero_point, + self.group_size, + self.group_dim, + exponent_bit_width, + mantissa_bit_width, + exponent_bias, + saturating, + inf_values, + nan_values, + self.is_signed, + self.training) + else: # quantization disabled + return x diff --git a/src/brevitas/proxy/groupwise_float_runtime_quant.py b/src/brevitas/proxy/groupwise_float_runtime_quant.py new file mode 100644 index 000000000..4ab182d20 --- /dev/null +++ b/src/brevitas/proxy/groupwise_float_runtime_quant.py @@ -0,0 +1,85 @@ +from typing import Union + +from torch import Tensor + +from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjectorBase +from brevitas.quant_tensor import GroupwiseFloatQuantTensor +from brevitas.quant_tensor import QuantTensor +from brevitas.utils.quant_utils import _CachedIOGroupwiseFloat + + +class GroupwiseActFloatQuantProxyFromInjector(ActFloatQuantProxyFromInjectorBase): + + @property + def group_dim(self): + return self.quant_injector.group_dim + + @property + def group_size(self): + return self.quant_injector.group_size + + def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, GroupwiseFloatQuantTensor]: + out = x + if self.fused_activation_quant_proxy is not None: + y = x + if isinstance(y, QuantTensor): + y = y.value + + if self.export_mode: + y = self.fused_activation_quant_proxy.activation_impl(y) + y = self.export_handler(y) + elif not self.is_quant_enabled: + y = self.fused_activation_quant_proxy.activation_impl(y) + else: + y = self.fused_activation_quant_proxy(y) + # If y is an empty GroupwiseFloatQuantTensor, we need to check if this is a passthrough proxy, + # otherwise return a simple Tensor + # We exclude the last two values (inf_values and nan_values) + if isinstance(y, tuple) and not any(map(lambda f: f is None, y[:-2])): + value, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = y + out = GroupwiseFloatQuantTensor( + value, + scale, + zero_point, + self.group_size, + self.group_dim, + exponent_bit_width, + mantissa_bit_width, + exponent_bias, + saturating, + inf_values, + nan_values, + signed=self.is_signed, + training=self.training) + elif self.is_passthrough_act: # preserve scale/zp/bit/sign even without output quant + if isinstance(y, tuple): + y = y[0] + if isinstance(x, GroupwiseFloatQuantTensor): + out = GroupwiseFloatQuantTensor( + y, + x.scale, + x.zero_point, + self.group_size, + self.group_dim, + x.exponent_bit_width, + x.mantissa_bit_width, + x.exponent_bias, + x.saturating, + x.inf_values, + x.nan_values, + x.signed, + self.training) + else: + out = y + else: + if isinstance(y, tuple): + y = y[0] + out = y + else: + # If fused activation quant proxy is not enabled, return the input + out = x + if not self.training and self.cache_inference_quant_act and isinstance( + out, GroupwiseFloatQuantTensor): + cached_out = _CachedIOGroupwiseFloat(out.detach(), self.cache_quant_io_metadata_only) + self._cached_act = cached_out + return out diff --git a/src/brevitas/proxy/groupwise_int_parameter_quant.py b/src/brevitas/proxy/groupwise_int_parameter_quant.py new file mode 100644 index 000000000..035ee9729 --- /dev/null +++ b/src/brevitas/proxy/groupwise_int_parameter_quant.py @@ -0,0 +1,40 @@ +from typing import Optional, Union + +import torch +from torch import Tensor + +from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector +from brevitas.quant_tensor import GroupwiseIntQuantTensor + + +class GroupwiseWeightQuantProxyFromInjector(WeightQuantProxyFromInjector): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # TODO: Is this always generated? + self.view_impl = self.quant_injector.scaling_stats_input_view_shape_impl + + @property + def group_dim(self): + return self.quant_injector.group_dim + + @property + def group_size(self): + return self.quant_injector.group_size + + def forward(self, x: torch.Tensor) -> Union[Tensor, GroupwiseIntQuantTensor]: + if self.is_quant_enabled: + impl = self.export_handler if self.export_mode else self.tensor_quant + x = self.view_impl(x) + out, scale, zero_point, bit_width = impl(x) + return GroupwiseIntQuantTensor( + out, + scale, + zero_point, + self.group_size, + self.group_dim, + bit_width, + self.is_signed, + self.training) + else: # quantization disabled + return x diff --git a/src/brevitas/proxy/groupwise_int_runtime_quant.py b/src/brevitas/proxy/groupwise_int_runtime_quant.py new file mode 100644 index 000000000..e9788e89b --- /dev/null +++ b/src/brevitas/proxy/groupwise_int_runtime_quant.py @@ -0,0 +1,75 @@ +from typing import Union + +from torch import Tensor + +from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector +from brevitas.quant_tensor import GroupwiseIntQuantTensor +from brevitas.quant_tensor import QuantTensor +from brevitas.utils.quant_utils import _CachedIOGroupwiseInt + + +class GroupwiseActQuantProxyFromInjector(ActQuantProxyFromInjector): + + @property + def group_dim(self): + return self.quant_injector.group_dim + + @property + def group_size(self): + return self.quant_injector.group_size + + def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, GroupwiseIntQuantTensor]: + out = x + if self.fused_activation_quant_proxy is not None: + y = x + if isinstance(y, QuantTensor): + y = y.value + + if self.export_mode: + y = self.fused_activation_quant_proxy.activation_impl(y) + y = self.export_handler(y) + elif not self.is_quant_enabled: + y = self.fused_activation_quant_proxy.activation_impl(y) + else: + y = self.fused_activation_quant_proxy(y) + # If y is an empty GroupwiseIntQuantTensor, we need to check if this is a passthrough proxy, + # otherwise return a simple Tensor + # We exclude the last two values (inf_values and nan_values) + if isinstance(y, tuple) and not any(map(lambda f: f is None, y[:-2])): + value, scale, zero_point, bit_width, = y + out = GroupwiseIntQuantTensor( + value, + scale, + zero_point, + self.group_size, + self.group_dim, + bit_width, + signed=self.is_signed, + training=self.training) + elif self.is_passthrough_act: # preserve scale/zp/bit/sign even without output quant + if isinstance(y, tuple): + y = y[0] + if isinstance(x, GroupwiseIntQuantTensor): + out = GroupwiseIntQuantTensor( + y, + x.scale, + x.zero_point, + self.group_size, + self.group_dim, + x.bit_width, + x.signed, + self.training) + else: + out = y + else: + if isinstance(y, tuple): + y = y[0] + out = y + else: + # If fused activation quant proxy is not enabled, return the input + out = x + if not self.training and self.cache_inference_quant_act and isinstance( + out, GroupwiseIntQuantTensor): + cached_out = _CachedIOGroupwiseInt(out.detach(), self.cache_quant_io_metadata_only) + self._cached_act = cached_out + return out diff --git a/src/brevitas/quant/base.py b/src/brevitas/quant/base.py index 3a5a2ec53..18351a05b 100644 --- a/src/brevitas/quant/base.py +++ b/src/brevitas/quant/base.py @@ -48,9 +48,11 @@ from brevitas.inject.enum import QuantType from brevitas.inject.enum import RestrictValueType from brevitas.inject.enum import ScalingImplType +from brevitas.inject.enum import ScalingPerOutputType from brevitas.inject.enum import StatsOp from brevitas.proxy import DecoupledWeightQuantProxyFromInjector from brevitas.proxy import DecoupledWeightQuantWithInputProxyFromInjector +from brevitas.quant.solver.common import SolveStatsReduceDimFromEnum from brevitas.quant.solver.parameter import SolveParameterScalingShape from brevitas.quant.solver.weight import SolveWeightScalingPerOutputChannelShapeFromModule from brevitas.quant.solver.weight import SolveWeightScalingStatsInputDimsFromModule @@ -227,7 +229,7 @@ class ShiftedParamFromPercentileUintQuant(ExtendedInjector): class PerChannelFloatScaling8bit(ExtendedInjector): """ """ - scaling_per_output_channel = True + scaling_per_output_type = ScalingPerOutputType.CHANNEL restrict_scaling_type = RestrictValueType.FP bit_width = 8 @@ -235,7 +237,7 @@ class PerChannelFloatScaling8bit(ExtendedInjector): class PerTensorFloatScaling8bit(ExtendedInjector): """ """ - scaling_per_output_channel = False + scaling_per_output_type = ScalingPerOutputType.TENSOR restrict_scaling_type = RestrictValueType.FP bit_width = 8 @@ -243,7 +245,7 @@ class PerTensorFloatScaling8bit(ExtendedInjector): class PerChannelPoTScaling8bit(ExtendedInjector): """ """ - scaling_per_output_channel = True + scaling_per_output_type = ScalingPerOutputType.CHANNEL restrict_scaling_type = RestrictValueType.FP bit_width = 8 @@ -251,7 +253,7 @@ class PerChannelPoTScaling8bit(ExtendedInjector): class PerTensorPoTScaling8bit(ExtendedInjector): """ """ - scaling_per_output_channel = False + scaling_per_output_type = ScalingPerOutputType.TENSOR restrict_scaling_type = RestrictValueType.POWER_OF_TWO bit_width = 8 restrict_value_float_to_int_impl = CeilSte @@ -262,7 +264,7 @@ class SignedBinaryClampedConst(ExtendedInjector): scaling_impl_type = ScalingImplType.CONST restrict_scaling_type = RestrictValueType.FP float_to_int_impl_type = FloatToIntImplType.ROUND - scaling_per_output_channel = False + scaling_per_output_type = ScalingPerOutputType.TENSOR narrow_range = True signed = True @@ -270,7 +272,7 @@ class SignedBinaryClampedConst(ExtendedInjector): class PerTensorConstScaling2bit(ExtendedInjector): scaling_impl_type = ScalingImplType.CONST restrict_scaling_type = RestrictValueType.FP - scaling_per_output_channel = False + scaling_per_output_type = ScalingPerOutputType.TENSOR bit_width = 2 @@ -300,7 +302,8 @@ class WeightPerTensorFloatDecoupledL2Param(SolveWeightScalingStatsInputDimsFromM signed = True -class WeightPerChannelFloatDecoupled(SolveWeightScalingStatsInputDimsFromModule, +class WeightPerChannelFloatDecoupled(SolveStatsReduceDimFromEnum, + SolveWeightScalingStatsInputDimsFromModule, SolveWeightScalingPerOutputChannelShapeFromModule, SolveParameterScalingShape): """ @@ -324,10 +327,11 @@ class WeightPerChannelFloatDecoupled(SolveWeightScalingStatsInputDimsFromModule, signed = True scaling_stats_input_view_shape_impl = OverOutputChannelView stats_reduce_dim = SCALING_STATS_REDUCE_DIM - scaling_per_output_channel = True + scaling_per_output_type = ScalingPerOutputType.CHANNEL -class WeightNormPerChannelFloatDecoupled(SolveWeightScalingStatsInputDimsFromModule, +class WeightNormPerChannelFloatDecoupled(SolveStatsReduceDimFromEnum, + SolveWeightScalingStatsInputDimsFromModule, SolveWeightScalingPerOutputChannelShapeFromModule, SolveParameterScalingShape): """Experimental narrow per-channel weight normalization-based signed integer quantizer @@ -360,6 +364,7 @@ def scaling_init(scaling_init_impl, bit_width): pre_scaling_impl = ParameterPreScalingWeightNorm restrict_pre_scaling_impl = LogFloatRestrictValue normalize_stats_impl = L2Norm + scaling_per_output_type = ScalingPerOutputType.CHANNEL pre_scaling_shape = this.scaling_shape # TODO: decouple pre_scaling_shape from scaling_shape int_scaling_impl = SingleArgStatelessBuffer(1.) zero_point_impl = ZeroZeroPoint @@ -369,7 +374,6 @@ def scaling_init(scaling_init_impl, bit_width): signed = True scaling_stats_input_view_shape_impl = OverOutputChannelView stats_reduce_dim = SCALING_STATS_REDUCE_DIM - scaling_per_output_channel = True scaling_min_val = 1e-10 pre_scaling_min_val = 1e-10 @@ -420,17 +424,19 @@ class AccumulatorAwareZeroCenterWeightQuant(AccumulatorAwareWeightQuant): class MSESubInjectorBase(ExtendedInjector): @value - def inner_stats_input_view_shape_impl(per_channel): - if per_channel: + def inner_stats_input_view_shape_impl(scaling_per_output): + if scaling_per_output == ScalingPerOutputType.CHANNEL: return StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS - else: + elif scaling_per_output == ScalingPerOutputType.TENSOR: return StatsInputViewShapeImpl.OVER_TENSOR + elif scaling_per_output == ScalingPerOutputType.GROUP: + raise RuntimeError("Not implemented yet") permute_dims = (this << 1).permute_dims class MSESymmetricScaleSubInjector(MSESubInjectorBase): - per_channel = (this << 1).scaling_per_output_channel + scaling_per_output = (this << 1).scaling_per_output proxy_module = (this << 1).proxy_module mse_init_op = AbsMax stats_impl = MSE @@ -440,7 +446,7 @@ class MSESymmetricScaleSubInjector(MSESubInjectorBase): class MSEAsymmetricScaleSubInjector(MSESubInjectorBase): - per_channel = (this << 1).scaling_per_output_channel + scaling_per_output = (this << 1).scaling_per_output proxy_module = (this << 1).proxy_module mse_init_op = AbsMinMax stats_impl = MSE @@ -451,7 +457,7 @@ class MSEAsymmetricScaleSubInjector(MSESubInjectorBase): class MSEZeroPointSubInjector(MSESubInjectorBase): # zp is per channel when scaling is per channel - per_channel = (this << 1).scaling_per_output_channel + scaling_per_output = (this << 1).scaling_per_output proxy_module = (this << 1).proxy_module mse_init_op = NegativeMinOrZero mse_search_method = 'grid' diff --git a/src/brevitas/quant/experimental/float.py b/src/brevitas/quant/experimental/float.py index 8e86936c2..a083d9b0c 100644 --- a/src/brevitas/quant/experimental/float.py +++ b/src/brevitas/quant/experimental/float.py @@ -1,6 +1,7 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause +from brevitas.inject.enum import ScalingPerOutputType from brevitas.quant.base import MSESymmetricScale from brevitas.quant.experimental.float_base import FloatActBase from brevitas.quant.experimental.float_base import FloatWeightBase @@ -42,49 +43,49 @@ class Fp8e4m3WeightPerTensorFloat(Fp8e4m3Mixin, ScaledFloatWeightBase): """ FP8 signed E3M4 weight quantizer with per-tensor absmax-based scaling. """ - scaling_per_output_channel = False + scaling_per_output_type = ScalingPerOutputType.TENSOR class Fp8e5m2WeightPerTensorFloat(Fp8e5m2Mixin, ScaledFloatWeightBase): """ FP8 signed E5M2 weight quantizer with per-tensor absmax-based scaling. """ - scaling_per_output_channel = False + scaling_per_output_type = ScalingPerOutputType.TENSOR class Fp8e4m3ActPerTensorFloat(Fp8e4m3Mixin, ScaledFloatActBase): """ FP8 signed E4M3 activation quantizer with per-tensor static percentile-based scaling. """ - scaling_per_output_channel = False + scaling_per_output_type = ScalingPerOutputType.TENSOR class Fp8e5m2ActPerTensorFloat(Fp8e5m2Mixin, ScaledFloatActBase): """ FP8 signed E5M2 activation quantizer with per-tensor static percentile-based scaling. """ - scaling_per_output_channel = False + scaling_per_output_type = ScalingPerOutputType.TENSOR class Fp8e4m3WeightPerChannelFloat(Fp8e4m3Mixin, ScaledFloatWeightBase): """ FP8 signed E3M4 weight quantizer with per-channel absmax-based scaling. """ - scaling_per_output_channel = True + scaling_per_output_type = ScalingPerOutputType.CHANNEL class Fp8e5m2WeightPerChannelFloat(Fp8e5m2Mixin, ScaledFloatWeightBase): """ FP8 signed E5M2 weight quantizer with per-channel absmax-based scaling. """ - scaling_per_output_channel = True + scaling_per_output_type = ScalingPerOutputType.CHANNEL class Fp8e4m3ActPerChannelFloat2d(Fp8e4m3Mixin, ScaledFloatActBase): """ FP8 signed E4M3 activation quantizer with per-channel static percentile-based scaling. """ - scaling_per_output_channel = True + scaling_per_output_type = ScalingPerOutputType.CHANNEL scaling_stats_permute_dims = (1, 0, 2, 3) @@ -92,7 +93,7 @@ class Fp8e5m2ActPerChannelFloat2d(Fp8e5m2Mixin, ScaledFloatActBase): """ FP8 signed E5M2 activation quantizer with per-channel static percentile-based scaling. """ - scaling_per_output_channel = True + scaling_per_output_type = ScalingPerOutputType.CHANNEL scaling_stats_permute_dims = (1, 0, 2, 3) @@ -100,21 +101,21 @@ class Fp8e4m3ActPerTensorFloatMSE(Fp8e4m3Mixin, MSESymmetricScale, ScaledFloatAc """ FP8 signed E4M3 activation quantizer with per-tensor static MSE-based scaling. """ - scaling_per_output_channel = False + scaling_per_output_type = ScalingPerOutputType.TENSOR class Fp8e5m2ActPerTensorFloatMSE(Fp8e5m2Mixin, MSESymmetricScale, ScaledFloatActBase): """ FP8 signed E5M2 activation quantizer with per-tensor static MSE-based scaling. """ - scaling_per_output_channel = False + scaling_per_output_type = ScalingPerOutputType.TENSOR class Fp8e4m3ActPerChannelFloat2dMSE(Fp8e4m3Mixin, MSESymmetricScale, ScaledFloatActBase): """ FP8 signed E4M3 activation quantizer with per-channel static MSE-based scaling. """ - scaling_per_output_channel = True + scaling_per_output_type = ScalingPerOutputType.CHANNEL scaling_stats_permute_dims = (1, 0, 2, 3) @@ -122,7 +123,7 @@ class Fp8e5m2ActPerChannelFloat2dMSE(Fp8e5m2Mixin, MSESymmetricScale, ScaledFloa """ FP8 signed E5M2 activation quantizer with per-channel static MSE-based scaling. """ - scaling_per_output_channel = True + scaling_per_output_type = ScalingPerOutputType.CHANNEL scaling_stats_permute_dims = (1, 0, 2, 3) @@ -130,11 +131,11 @@ class Fp8e4m3WeightPerChannelFloatMSE(Fp8e4m3Mixin, MSESymmetricScale, ScaledFlo """ FP8 signed E3M4 weight quantizer with per-channel MSE-based scaling. """ - scaling_per_output_channel = True + scaling_per_output_type = ScalingPerOutputType.CHANNEL class Fp8e4m3WeightPerTensorFloatMSE(Fp8e4m3Mixin, MSESymmetricScale, ScaledFloatWeightBase): """ FP8 signed E3M4 weight quantizer with per-tensor MSE-based scaling. """ - scaling_per_output_channel = False + scaling_per_output_type = ScalingPerOutputType.TENSOR diff --git a/src/brevitas/quant/experimental/float_base.py b/src/brevitas/quant/experimental/float_base.py index 1b7191037..e17791841 100644 --- a/src/brevitas/quant/experimental/float_base.py +++ b/src/brevitas/quant/experimental/float_base.py @@ -62,3 +62,24 @@ class Fp8e5m2Mixin(ExtendedInjector): exponent_bit_width = 5 mantissa_bit_width = 2 saturating = True + + +class Fp6e3m2Mixin(ExtendedInjector): + bit_width = 6 + exponent_bit_width = 3 + mantissa_bit_width = 2 + saturating = True + + +class Fp6e2m3Mixin(ExtendedInjector): + bit_width = 6 + exponent_bit_width = 2 + mantissa_bit_width = 3 + saturating = True + + +class Fp4e2m1Mixin(ExtendedInjector): + bit_width = 4 + exponent_bit_width = 2 + mantissa_bit_width = 1 + saturating = True diff --git a/src/brevitas/quant/experimental/float_quant_fnuz.py b/src/brevitas/quant/experimental/float_quant_fnuz.py index 7d7035cb6..354a9ff21 100644 --- a/src/brevitas/quant/experimental/float_quant_fnuz.py +++ b/src/brevitas/quant/experimental/float_quant_fnuz.py @@ -3,6 +3,8 @@ from dependencies import value +from brevitas.inject import ExtendedInjector +from brevitas.inject.enum import ScalingPerOutputType from brevitas.quant.base import MSESymmetricScale from brevitas.quant.experimental.float_base import FloatActBase from brevitas.quant.experimental.float_base import FloatWeightBase @@ -12,152 +14,196 @@ from brevitas.quant.experimental.float_base import ScaledFloatWeightBase -class Fp8e4m3FNUZMixin(Fp8e4m3Mixin): - nan_values = None - inf_values = None +class FpFNUZMixin(ExtendedInjector): + saturating = True @value def exponent_bias(exponent_bit_width): return 2 ** (exponent_bit_width - 1) -class Fp8e5m2FNUZMixin(Fp8e5m2Mixin): - nan_values = None - inf_values = None - - @value - def exponent_bias(exponent_bit_width): - return 2 ** (exponent_bit_width - 1) +class FpFNUZWeight(FpFNUZMixin, FloatWeightBase): + """ + FNUZ FP8 signed weight quantizer. + """ + pass -class Fp8e4m3FNUZWeight(Fp8e4m3FNUZMixin, FloatWeightBase): +class FpFNUZAct(FpFNUZMixin, FloatActBase): """ - FP8 signed E3M4 weight quantizer. + FP8 signed activation quantizer. """ pass -class Fp8e5m2FNUZWeight(Fp8e5m2FNUZMixin, FloatWeightBase): +class FpFNUZWeightPerTensorFloat(FpFNUZMixin, ScaledFloatWeightBase): """ - FP8 signed E5M2 weight quantizer. + FP8 signed weight quantizer with per-tensor absmax-based scaling. """ - pass + scaling_per_output_type = ScalingPerOutputType.TENSOR + + +class FpFNUZActPerTensorFloat(FpFNUZMixin, ScaledFloatActBase): + """ + FP8 signed activation quantizer with per-tensor static percentile-based scaling. + """ + scaling_per_output_type = ScalingPerOutputType.TENSOR + + +class FpFNUZWeightPerChannelFloat(FpFNUZMixin, ScaledFloatWeightBase): + """ + FP8 signed weight quantizer with per-channel absmax-based scaling. + """ + scaling_per_output_type = ScalingPerOutputType.CHANNEL + + +class FpFNUZActPerChannelFloat2d(FpFNUZMixin, ScaledFloatActBase): + """ + FP8 signed activation quantizer with per-channel static percentile-based scaling. + """ + scaling_per_output_type = ScalingPerOutputType.CHANNEL + scaling_stats_permute_dims = (1, 0, 2, 3) + + +class FpFNUZActPerTensorFloatMSE(FpFNUZMixin, MSESymmetricScale, ScaledFloatActBase): + """ + FP8 signed activation quantizer with per-tensor static MSE-based scaling. + """ + scaling_per_output_type = ScalingPerOutputType.TENSOR + + +class FpFNUZActPerChannelFloat2dMSE(FpFNUZMixin, MSESymmetricScale, ScaledFloatActBase): + """ + FP8 signed activation quantizer with per-channel static MSE-based scaling. + """ + scaling_per_output_type = ScalingPerOutputType.CHANNEL + scaling_stats_permute_dims = (1, 0, 2, 3) + + +class FpFNUZWeightPerChannelFloatMSE(FpFNUZMixin, MSESymmetricScale, ScaledFloatWeightBase): + """ + FP8 signed weight quantizer with per-channel MSE-based scaling. + """ + scaling_per_output_type = ScalingPerOutputType.CHANNEL -class Fp8e4m3FNUZAct(Fp8e4m3FNUZMixin, FloatActBase): +class FpFNUZWeightPerTensorFloatMSE(FpFNUZMixin, MSESymmetricScale, ScaledFloatWeightBase): """ - FP8 signed E4M3 activation quantizer. + FP8 signed weight quantizer with per-tensor MSE-based scaling. + """ + scaling_per_output_type = ScalingPerOutputType.TENSOR + + +## Predefined FP8 Quantizers + + +class Fp8e4m3FNUZWeight(FpFNUZWeight, Fp8e4m3Mixin): + """ + FNUZ FP8 E4M3 signed weight quantizer. """ pass -class Fp8e5m2FNUZAct(Fp8e5m2FNUZMixin, FloatActBase): +class Fp8e4m3FNUZAct(FpFNUZAct, Fp8e4m3Mixin): """ - FP8 signed E5M2 activation quantizer. + FNUZ FP8 E4M3 signed act quantizer. """ pass -class Fp8e4m3FNUZWeightPerTensorFloat(Fp8e4m3FNUZMixin, ScaledFloatWeightBase): +class Fp8e4m3FNUZWeightPerTensorFloat(FpFNUZWeightPerTensorFloat, Fp8e4m3Mixin): """ - FP8 signed E3M4 weight quantizer with per-tensor absmax-based scaling. + FNUZ FP8 E4M3 per-tensor scaled signed weight quantizer. """ - scaling_per_output_channel = False + pass -class Fp8e5m2FNUZWeightPerTensorFloat(Fp8e5m2FNUZMixin, ScaledFloatWeightBase): +class Fp8e4m3FNUZWeightPerChannelFloat(FpFNUZWeightPerChannelFloat, Fp8e4m3Mixin): """ - FP8 signed E5M2 weight quantizer with per-tensor absmax-based scaling. + FNUZ FP8 E4M3 per-channel scaled signed weight quantizer. """ - scaling_per_output_channel = False + pass -class Fp8e4m3FNUZActPerTensorFloat(Fp8e4m3FNUZMixin, ScaledFloatActBase): +class Fp8e4m3FNUZActPerTensorFloat(FpFNUZActPerTensorFloat, Fp8e4m3Mixin): """ - FP8 signed E4M3 activation quantizer with per-tensor static percentile-based scaling. + FNUZ FP8 E4M3 scaled signed act quantizer. """ - scaling_per_output_channel = False + pass -class Fp8e5m2FNUZActPerTensorFloat(Fp8e5m2FNUZMixin, ScaledFloatActBase): +class Fp8e4m3FNUZActPerTensorFloatMSE(FpFNUZActPerTensorFloatMSE, Fp8e4m3Mixin): """ - FP8 signed E5M2 activation quantizer with per-tensor static percentile-based scaling. + FNUZ FP8 E4M3 MSE-based scaled signed act quantizer. """ - scaling_per_output_channel = False + pass -class Fp8e4m3FNUZWeightPerChannelFloat(Fp8e4m3FNUZMixin, ScaledFloatWeightBase): +class Fp8e4m3FNUZWeightPerTensorFloatMSE(FpFNUZWeightPerTensorFloatMSE, Fp8e4m3Mixin): """ - FP8 signed E3M4 weight quantizer with per-channel absmax-based scaling. + FNUZ FP8 E4M3 MSE-based per-tensor scaled signed weight quantizer. """ - scaling_per_output_channel = True + pass -class Fp8e5m2FNUZWeightPerChannelFloat(Fp8e5m2FNUZMixin, ScaledFloatWeightBase): +class Fp8e4m3FNUZWeightPerChannelFloatMSE(FpFNUZWeightPerChannelFloatMSE, Fp8e4m3Mixin): """ - FP8 signed E5M2 weight quantizer with per-channel absmax-based scaling. + FNUZ FP8 E4M3 MSE-based per-channel scaled signed weight quantizer. """ - scaling_per_output_channel = True + pass -class Fp8e4m3FNUZActPerChannelFloat2d(Fp8e4m3FNUZMixin, ScaledFloatActBase): +class Fp8e5m2FNUZWeight(FpFNUZWeight, Fp8e5m2Mixin): """ - FP8 signed E4M3 activation quantizer with per-channel static percentile-based scaling. + FNUZ FP8 e5m2 signed weight quantizer. """ - scaling_per_output_channel = True - scaling_stats_permute_dims = (1, 0, 2, 3) + pass -class Fp8e5m2FNUZActPerChannelFloat2d(Fp8e5m2FNUZMixin, ScaledFloatActBase): +class Fp8e5m2FNUZAct(FpFNUZAct, Fp8e5m2Mixin): """ - FP8 signed E5M2 activation quantizer with per-channel static percentile-based scaling. + FNUZ FP8 e5m2 signed act quantizer. """ - scaling_per_output_channel = True - scaling_stats_permute_dims = (1, 0, 2, 3) + pass -class Fp8e4m3FNUZActPerTensorFloatMSE(Fp8e4m3FNUZMixin, MSESymmetricScale, ScaledFloatActBase): +class Fp8e5m2FNUZWeightPerTensorFloat(FpFNUZWeightPerTensorFloat, Fp8e5m2Mixin): """ - FP8 signed E4M3 activation quantizer with per-tensor static MSE-based scaling. + FNUZ FP8 e5m2 per-tensor scaled signed weight quantizer. """ - scaling_per_output_channel = False + pass -class Fp8e5m2FNUZActPerTensorFloatMSE(Fp8e5m2FNUZMixin, MSESymmetricScale, ScaledFloatActBase): +class Fp8e5m2FNUZWeightPerChannelFloat(FpFNUZWeightPerChannelFloat, Fp8e5m2Mixin): """ - FP8 signed E5M2 activation quantizer with per-tensor static MSE-based scaling. + FNUZ FP8 e5m2 per-channel scaled signed weight quantizer. """ - scaling_per_output_channel = False + pass -class Fp8e4m3FNUZActPerChannelFloat2dMSE(Fp8e4m3FNUZMixin, MSESymmetricScale, ScaledFloatActBase): +class Fp8e5m2FNUZActPerTensorFloat(FpFNUZActPerTensorFloat, Fp8e5m2Mixin): """ - FP8 signed E4M3 activation quantizer with per-channel static MSE-based scaling. + FNUZ FP8 e5m2 scaled signed act quantizer. """ - scaling_per_output_channel = True - scaling_stats_permute_dims = (1, 0, 2, 3) + pass -class Fp8e5m2FNUZActPerChannelFloat2dMSE(Fp8e5m2FNUZMixin, MSESymmetricScale, ScaledFloatActBase): +class Fp8e5m2FNUZActPerTensorFloatMSE(FpFNUZActPerTensorFloatMSE, Fp8e5m2Mixin): """ - FP8 signed E5M2 activation quantizer with per-channel static MSE-based scaling. + FNUZ FP8 e5m2 MSE-based scaled signed act quantizer. """ - scaling_per_output_channel = True - scaling_stats_permute_dims = (1, 0, 2, 3) + pass -class Fp8e4m3FNUZWeightPerChannelFloatMSE(Fp8e4m3FNUZMixin, - MSESymmetricScale, - ScaledFloatWeightBase): +class Fp8e5m2FNUZWeightPerTensorFloatMSE(FpFNUZWeightPerTensorFloatMSE, Fp8e5m2Mixin): """ - FP8 signed E3M4 weight quantizer with per-channel MSE-based scaling. + FNUZ FP8 e5m2 MSE-based per-tensor scaled signed weight quantizer. """ - scaling_per_output_channel = True + pass -class Fp8e4m3FNUZWeightPerTensorFloatMSE(Fp8e4m3FNUZMixin, MSESymmetricScale, - ScaledFloatWeightBase): +class Fp8e5m2FNUZWeightPerChannelFloatMSE(FpFNUZWeightPerChannelFloatMSE, Fp8e5m2Mixin): """ - FP8 signed E3M4 weight quantizer with per-tensor MSE-based scaling. + FNUZ FP8 e5m2 MSE-based per-channel scaled signed weight quantizer. """ - scaling_per_output_channel = False + pass diff --git a/src/brevitas/quant/experimental/float_quant_ocp.py b/src/brevitas/quant/experimental/float_quant_ocp.py index f2b148482..9b336cc91 100644 --- a/src/brevitas/quant/experimental/float_quant_ocp.py +++ b/src/brevitas/quant/experimental/float_quant_ocp.py @@ -3,6 +3,8 @@ from dependencies import value +from brevitas.inject import ExtendedInjector +from brevitas.inject.enum import ScalingPerOutputType from brevitas.quant.base import MSESymmetricScale from brevitas.quant.experimental.float_base import FloatActBase from brevitas.quant.experimental.float_base import FloatWeightBase @@ -13,26 +15,28 @@ from brevitas.utils.float_quant_utils import get_max_available_float -class Fp8e4m3OCPMixin(Fp8e4m3Mixin): - nan_values = (('111',)) - inf_values = None +class FpOCPMixin(ExtendedInjector): + saturating = True @value - def max_available_float( - exponent_bit_width, mantissa_bit_width, exponent_bias, nan_values, inf_values, - saturating): - return get_max_available_float( - exponent_bit_width, - mantissa_bit_width, - exponent_bias, - nan_values, - inf_values, - saturating) - + def inf_values(bit_width, mantissa_bit_width, exponent_bit_width): + if bit_width == 8: + if mantissa_bit_width == 3 and exponent_bit_width == 4: + return None + if mantissa_bit_width == 2 and exponent_bit_width == 5: + return (('00',)) + else: + return None -class Fp8e5m2OCPMixin(Fp8e5m2Mixin): - nan_values = ('01', '11', '10') - inf_values = (('00',)) + @value + def nan_values(bit_width, mantissa_bit_width, exponent_bit_width): + if bit_width == 8: + if mantissa_bit_width == 3 and exponent_bit_width == 4: + return (('111',)) + if mantissa_bit_width == 2 and exponent_bit_width == 5: + return ('01', '11', '10') + else: + return None @value def max_available_float( @@ -47,131 +51,188 @@ def max_available_float( saturating) -class Fp8e4m3OCPWeight(Fp8e4m3OCPMixin, FloatWeightBase): +class FpOCPWeight(FpOCPMixin, FloatWeightBase): """ - FP8 signed E3M4 weight quantizer. + OCP FP signed weight quantizer. """ pass -class Fp8e5m2OCPWeight(Fp8e5m2OCPMixin, FloatWeightBase): +class FpOCPAct(FpOCPMixin, FloatActBase): """ - FP8 signed E5M2 weight quantizer. + OCP FP signed activation quantizer. """ pass -class Fp8e4m3OCPAct(Fp8e4m3OCPMixin, FloatActBase): +class FpOCPWeightPerTensorFloat(FpOCPMixin, ScaledFloatWeightBase): + """ + OCP FP signed E3M4 weight quantizer with per-tensor absmax-based scaling. + """ + scaling_per_output_type = ScalingPerOutputType.TENSOR + + +class FpOCPActPerTensorFloat(FpOCPMixin, ScaledFloatActBase): + """ + OCP FP signed activation quantizer with per-tensor static percentile-based scaling. + """ + scaling_per_output_type = ScalingPerOutputType.TENSOR + + +class FpOCPWeightPerChannelFloat(FpOCPMixin, ScaledFloatWeightBase): + """ + OCP FP signed E3M4 weight quantizer with per-channel absmax-based scaling. + """ + scaling_per_output_type = ScalingPerOutputType.CHANNEL + + +class FpOCPActPerChannelFloat2d(FpOCPMixin, ScaledFloatActBase): + """ + OCP FP signed activation quantizer with per-channel static percentile-based scaling. + """ + scaling_per_output_type = ScalingPerOutputType.CHANNEL + scaling_stats_permute_dims = (1, 0, 2, 3) + + +class FpOCPActPerTensorFloatMSE(FpOCPMixin, MSESymmetricScale, ScaledFloatActBase): + """ + OCP FP signed activation quantizer with per-tensor static MSE-based scaling. + """ + scaling_per_output_type = ScalingPerOutputType.TENSOR + + +class FpOCPActPerChannelFloat2dMSE(FpOCPMixin, MSESymmetricScale, ScaledFloatActBase): """ - FP8 signed E4M3 activation quantizer. + OCP FP signed activation quantizer with per-channel static MSE-based scaling. + """ + scaling_per_output_type = ScalingPerOutputType.CHANNEL + scaling_stats_permute_dims = (1, 0, 2, 3) + + +class FpOCPWeightPerChannelFloatMSE(FpOCPMixin, MSESymmetricScale, ScaledFloatWeightBase): + """ + OCP FP signed E3M4 weight quantizer with per-channel MSE-based scaling. + """ + scaling_per_output_type = ScalingPerOutputType.CHANNEL + + +class FpOCPWeightPerTensorFloatMSE(FpOCPMixin, MSESymmetricScale, ScaledFloatWeightBase): + """ + OCP FP signed E3M4 weight quantizer with per-tensor MSE-based scaling. + """ + scaling_per_output_type = ScalingPerOutputType.TENSOR + + +## Predefined FP8 Quantizers + + +class Fp8e4m3OCPWeight(FpOCPWeight, Fp8e4m3Mixin): + """ + OCP FP8 E4M3 signed weight quantizer. """ pass -class Fp8e5m2OCPAct(Fp8e5m2OCPMixin, FloatActBase): +class Fp8e4m3OCPAct(FpOCPAct, Fp8e4m3Mixin): """ - FP8 signed E5M2 activation quantizer. + OCP FP8 E4M3 signed act quantizer. """ pass -class Fp8e4m3OCPWeightPerTensorFloat(Fp8e4m3OCPMixin, ScaledFloatWeightBase): +class Fp8e4m3OCPWeightPerTensorFloat(FpOCPWeightPerTensorFloat, Fp8e4m3Mixin): """ - FP8 signed E3M4 weight quantizer with per-tensor absmax-based scaling. + OCP FP8 E4M3 per-tensor scaled signed weight quantizer. """ - scaling_per_output_channel = False + pass -class Fp8e5m2OCPWeightPerTensorFloat(Fp8e5m2OCPMixin, ScaledFloatWeightBase): +class Fp8e4m3OCPWeightPerChannelFloat(FpOCPWeightPerChannelFloat, Fp8e4m3Mixin): """ - FP8 signed E5M2 weight quantizer with per-tensor absmax-based scaling. + OCP FP8 E4M3 per-channel scaled signed weight quantizer. """ - scaling_per_output_channel = False + pass -class Fp8e4m3OCPActPerTensorFloat(Fp8e4m3OCPMixin, ScaledFloatActBase): +class Fp8e4m3OCPActPerTensorFloat(FpOCPActPerTensorFloat, Fp8e4m3Mixin): """ - FP8 signed E4M3 activation quantizer with per-tensor static percentile-based scaling. + OCP FP8 E4M3 scaled signed act quantizer. """ - scaling_per_output_channel = False + pass -class Fp8e5m2OCPActPerTensorFloat(Fp8e5m2OCPMixin, ScaledFloatActBase): +class Fp8e4m3OCPActPerTensorFloatMSE(FpOCPActPerTensorFloatMSE, Fp8e4m3Mixin): """ - FP8 signed E5M2 activation quantizer with per-tensor static percentile-based scaling. + OCP FP8 E4M3 MSE-based scaled signed act quantizer. """ - scaling_per_output_channel = False + pass -class Fp8e4m3OCPWeightPerChannelFloat(Fp8e4m3OCPMixin, ScaledFloatWeightBase): +class Fp8e4m3OCPWeightPerTensorFloatMSE(FpOCPWeightPerTensorFloatMSE, Fp8e4m3Mixin): """ - FP8 signed E3M4 weight quantizer with per-channel absmax-based scaling. + OCP FP8 E4M3 MSE-based per-tensor scaled signed weight quantizer. """ - scaling_per_output_channel = True + pass -class Fp8e5m2OCPWeightPerChannelFloat(Fp8e5m2OCPMixin, ScaledFloatWeightBase): +class Fp8e4m3OCPWeightPerChannelFloatMSE(FpOCPWeightPerChannelFloatMSE, Fp8e4m3Mixin): """ - FP8 signed E5M2 weight quantizer with per-channel absmax-based scaling. + OCP FP8 E4M3 MSE-based per-channel scaled signed weight quantizer. """ - scaling_per_output_channel = True + pass -class Fp8e4m3OCPActPerChannelFloat2d(Fp8e4m3OCPMixin, ScaledFloatActBase): +class Fp8e5m2OCPWeight(FpOCPWeight, Fp8e5m2Mixin): """ - FP8 signed E4M3 activation quantizer with per-channel static percentile-based scaling. + OCP FP8 e5m2 signed weight quantizer. """ - scaling_per_output_channel = True - scaling_stats_permute_dims = (1, 0, 2, 3) + pass -class Fp8e5m2OCPActPerChannelFloat2d(Fp8e5m2OCPMixin, ScaledFloatActBase): +class Fp8e5m2OCPAct(FpOCPAct, Fp8e5m2Mixin): """ - FP8 signed E5M2 activation quantizer with per-channel static percentile-based scaling. + OCP FP8 e5m2 signed act quantizer. """ - scaling_per_output_channel = True - scaling_stats_permute_dims = (1, 0, 2, 3) + pass -class Fp8e4m3OCPActPerTensorFloatMSE(Fp8e4m3OCPMixin, MSESymmetricScale, ScaledFloatActBase): +class Fp8e5m2OCPWeightPerTensorFloat(FpOCPWeightPerTensorFloat, Fp8e5m2Mixin): """ - FP8 signed E4M3 activation quantizer with per-tensor static MSE-based scaling. + OCP FP8 e5m2 per-tensor scaled signed weight quantizer. """ - scaling_per_output_channel = False + pass -class Fp8e5m2OCPActPerTensorFloatMSE(Fp8e5m2OCPMixin, MSESymmetricScale, ScaledFloatActBase): +class Fp8e5m2OCPWeightPerChannelFloat(FpOCPWeightPerChannelFloat, Fp8e5m2Mixin): """ - FP8 signed E5M2 activation quantizer with per-tensor static MSE-based scaling. + OCP FP8 e5m2 per-channel scaled signed weight quantizer. """ - scaling_per_output_channel = False + pass -class Fp8e4m3OCPActPerChannelFloat2dMSE(Fp8e4m3OCPMixin, MSESymmetricScale, ScaledFloatActBase): +class Fp8e5m2OCPActPerTensorFloat(FpOCPActPerTensorFloat, Fp8e5m2Mixin): """ - FP8 signed E4M3 activation quantizer with per-channel static MSE-based scaling. + OCP FP8 e5m2 scaled signed act quantizer. """ - scaling_per_output_channel = True - scaling_stats_permute_dims = (1, 0, 2, 3) + pass -class Fp8e5m2OCPActPerChannelFloat2dMSE(Fp8e5m2OCPMixin, MSESymmetricScale, ScaledFloatActBase): +class Fp8e5m2OCPActPerTensorFloatMSE(FpOCPActPerTensorFloatMSE, Fp8e5m2Mixin): """ - FP8 signed E5M2 activation quantizer with per-channel static MSE-based scaling. + OCP FP8 e5m2 MSE-based scaled signed act quantizer. """ - scaling_per_output_channel = True - scaling_stats_permute_dims = (1, 0, 2, 3) + pass -class Fp8e4m3OCPWeightPerChannelFloatMSE(Fp8e4m3OCPMixin, MSESymmetricScale, ScaledFloatWeightBase): +class Fp8e5m2OCPWeightPerTensorFloatMSE(FpOCPWeightPerTensorFloatMSE, Fp8e5m2Mixin): """ - FP8 signed E3M4 weight quantizer with per-channel MSE-based scaling. + OCP FP8 e5m2 MSE-based per-tensor scaled signed weight quantizer. """ - scaling_per_output_channel = True + pass -class Fp8e4m3OCPWeightPerTensorFloatMSE(Fp8e4m3OCPMixin, MSESymmetricScale, ScaledFloatWeightBase): +class Fp8e5m2OCPWeightPerChannelFloatMSE(FpOCPWeightPerChannelFloatMSE, Fp8e5m2Mixin): """ - FP8 signed E3M4 weight quantizer with per-tensor MSE-based scaling. + OCP FP8 e5m2 MSE-based per-channel scaled signed weight quantizer. """ - scaling_per_output_channel = False + pass diff --git a/src/brevitas/quant/experimental/mx_quant_ocp.py b/src/brevitas/quant/experimental/mx_quant_ocp.py new file mode 100644 index 000000000..6e61d078c --- /dev/null +++ b/src/brevitas/quant/experimental/mx_quant_ocp.py @@ -0,0 +1,137 @@ +# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +from dependencies import value + +from brevitas.core.function_wrapper.ops_ste import CeilSte +from brevitas.core.scaling.runtime import RuntimeDynamicGroupStatsScaling +from brevitas.inject import ExtendedInjector +from brevitas.inject.enum import RestrictValueType +from brevitas.inject.enum import ScalingPerOutputType +from brevitas.proxy.groupwise_float_parameter_quant import \ + GroupwiseWeightFloatQuantProxyFromInjector +from brevitas.proxy.groupwise_float_runtime_quant import GroupwiseActFloatQuantProxyFromInjector +from brevitas.proxy.groupwise_int_parameter_quant import GroupwiseWeightQuantProxyFromInjector +from brevitas.proxy.groupwise_int_runtime_quant import GroupwiseActQuantProxyFromInjector +from brevitas.quant.base import IntQuant +from brevitas.quant.base import MaxStatsScaling +from brevitas.quant.base import MinMaxStatsScaling +from brevitas.quant.base import MSEAsymmetricScale +from brevitas.quant.base import MSESymmetricScale +from brevitas.quant.base import ShiftedMinUintQuant +from brevitas.quant.experimental.float_base import ScaledFloatActBase +from brevitas.quant.experimental.float_base import ScaledFloatWeightBase +from brevitas.quant.experimental.float_quant_ocp import FpOCPAct +from brevitas.quant.experimental.float_quant_ocp import FpOCPWeight +from brevitas.quant.solver.act import ActQuantSolver +from brevitas.quant.solver.weight import WeightQuantSolver + + +class GroupwiseWeightFloatProxyMixin(ExtendedInjector): + proxy_class = GroupwiseWeightFloatQuantProxyFromInjector + + +class GroupwiseActFloatProxyMixin(ExtendedInjector): + proxy_class = GroupwiseActFloatQuantProxyFromInjector + + +class GroupwiseWeightProxyMixin(ExtendedInjector): + proxy_class = GroupwiseWeightQuantProxyFromInjector + + +class GroupwiseActProxyMixin(ExtendedInjector): + proxy_class = GroupwiseActQuantProxyFromInjector + + +class MXWeightMixin(ExtendedInjector): + group_size = 32 + restrict_scaling_type = RestrictValueType.POWER_OF_TWO + restrict_value_float_to_int_impl = CeilSte + scaling_per_output_type = ScalingPerOutputType.GROUP + + +class MXActMixin(ExtendedInjector): + group_size = 32 + restrict_scaling_type = RestrictValueType.POWER_OF_TWO + restrict_value_float_to_int_impl = CeilSte + scaling_impl = RuntimeDynamicGroupStatsScaling + scaling_per_output_type = ScalingPerOutputType.GROUP + + @value + def stats_reduce_dim(group_dim): + # If group_dim = -1, we need a workaround to avoid selecting wrong dim + if group_dim == -1: + return -1 + else: + return group_dim + 1 + + +class MXFloat8e4m3Weight(MXWeightMixin, + GroupwiseWeightFloatProxyMixin, + FpOCPWeight, + ScaledFloatWeightBase): + """ + MX Float signed weight quantizer. + """ + bit_width = 8 + exponent_bit_width = 4 + mantissa_bit_width = 3 + + +class MXFloat8e4m3Act(MXActMixin, GroupwiseActFloatProxyMixin, FpOCPAct, ScaledFloatActBase): + """ + MX Float signed activation quantizer. + """ + bit_width = 8 + exponent_bit_width = 4 + mantissa_bit_width = 3 + + +class MXFloat8e4m3WeightMSE(MXFloat8e4m3Weight, MSESymmetricScale): + """ + MX Float signed weight quantizer with per-channel MSE-based scaling. + """ + pass + + +class MXInt8Weight(MXWeightMixin, + GroupwiseWeightProxyMixin, + IntQuant, + MaxStatsScaling, + WeightQuantSolver): + """ + MX Int signed weight quantizer. + """ + bit_width = 8 + + +class ShiftedMXUInt8Weight(MXWeightMixin, + GroupwiseWeightProxyMixin, + ShiftedMinUintQuant, + MinMaxStatsScaling, + WeightQuantSolver): + """ + MX Int signed weight quantizer. + """ + bit_width = 8 + + +class MXInt8Act(MXActMixin, GroupwiseActProxyMixin, IntQuant, MaxStatsScaling, ActQuantSolver): + """ + MX Int signed activation quantizer. + """ + bit_width = 8 + + +class MXInt8WeightMSE(MXInt8Weight, MSESymmetricScale): + """ + MX Int signed weight quantizer with per-channel MSE-based scaling. + """ + pass + + +class ShiftedMXUInt8WeightMSE(ShiftedMXUInt8Weight, MSEAsymmetricScale): + """ + MX Int signed weight quantizer with per-channel MSE-based scaling. + """ + pass diff --git a/src/brevitas/quant/solver/act.py b/src/brevitas/quant/solver/act.py index 2f6c219c5..3149e75b9 100644 --- a/src/brevitas/quant/solver/act.py +++ b/src/brevitas/quant/solver/act.py @@ -18,6 +18,7 @@ from brevitas.inject import value from brevitas.inject.enum import QuantType from brevitas.inject.enum import ScalingImplType +from brevitas.inject.enum import ScalingPerOutputType from brevitas.proxy import ActQuantProxyFromInjector from brevitas.proxy.utils import ConvertRuntimeStatsToParameter from brevitas.quant.solver.common import * @@ -100,12 +101,12 @@ def min_val(signed): class SolveActScalingShape(ExtendedInjector): @value - def scaling_shape(scaling_per_output_channel): + def scaling_shape(scaling_per_output): # this pattern of returning this.something allows to resolve scaling_output_channel_shape # only when scaling_per_output_channel is True - if scaling_per_output_channel: + if scaling_per_output == ScalingPerOutputType.CHANNEL: return this.per_channel_broadcastable_shape - else: + elif scaling_per_output == ScalingPerOutputType.TENSOR: return SCALAR_SHAPE diff --git a/src/brevitas/quant/solver/common.py b/src/brevitas/quant/solver/common.py index 2e5e1e982..61eccc90b 100644 --- a/src/brevitas/quant/solver/common.py +++ b/src/brevitas/quant/solver/common.py @@ -12,6 +12,7 @@ from brevitas.core.restrict_val import * from brevitas.core.scaling import * from brevitas.core.scaling import ScalingImplType +from brevitas.core.scaling import ScalingPerOutputType from brevitas.core.stats import * from brevitas.inject import ExtendedInjector from brevitas.inject import value @@ -171,21 +172,43 @@ def int_scaling_impl(restrict_scaling_type): class SolveStatsReduceDimFromEnum(ExtendedInjector): @value - def stats_reduce_dim(scaling_stats_op, scaling_per_output_channel): - if scaling_stats_op == StatsOp.MAX_AVE or scaling_per_output_channel: + def stats_reduce_dim(scaling_stats_op, scaling_per_output): + if scaling_per_output == ScalingPerOutputType.CHANNEL or scaling_stats_op == StatsOp.MAX_AVE: return SCALING_STATS_REDUCE_DIM - else: + elif scaling_per_output == ScalingPerOutputType.TENSOR: return None + elif scaling_per_output == ScalingPerOutputType.GROUP: + return SCALING_STATS_REDUCE_DIM + 1 + + @value + def keepdim(scaling_per_output): + if scaling_per_output == ScalingPerOutputType.GROUP: + return True + else: + return False + + # Retrocompatibility. Priority is given to scaling_per_output_channel binary flag. + # We might want to check for discrepancies between the two and raise an error. + @value + def scaling_per_output(scaling_per_output_type=None, scaling_per_output_channel=None): + if scaling_per_output_channel is not None: + return ScalingPerOutputType.CHANNEL if scaling_per_output_channel else ScalingPerOutputType.TENSOR + elif scaling_per_output_type is not None: + return scaling_per_output_type + else: + raise RuntimeError("Specify scaling_per_output_type or scaling_per_output_channel") class SolveScalingStatsInputViewShapeImplFromEnum(ExtendedInjector): @value - def scaling_stats_input_view_shape_impl(scaling_per_output_channel, scaling_stats_op): - if scaling_per_output_channel or scaling_stats_op == StatsOp.MAX_AVE: + def scaling_stats_input_view_shape_impl(scaling_stats_op, scaling_per_output): + if scaling_per_output == ScalingPerOutputType.CHANNEL or scaling_stats_op == StatsOp.MAX_AVE: return StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS - else: + elif scaling_per_output == ScalingPerOutputType.TENSOR: return StatsInputViewShapeImpl.OVER_TENSOR + elif scaling_per_output == ScalingPerOutputType.GROUP: + return StatsInputViewShapeImpl.OVER_SUBCHANNEL_BLOCK @value def permute_dims(scaling_stats_permute_dims): diff --git a/src/brevitas/quant/solver/parameter.py b/src/brevitas/quant/solver/parameter.py index d924704c7..d8c655efa 100644 --- a/src/brevitas/quant/solver/parameter.py +++ b/src/brevitas/quant/solver/parameter.py @@ -14,6 +14,7 @@ from brevitas.core.function_wrapper import TensorClampSte from brevitas.core.scaling import * from brevitas.core.scaling import ScalingImplType +from brevitas.core.scaling import ScalingPerOutputType from brevitas.inject import ExtendedInjector from brevitas.quant.solver.common import * @@ -108,10 +109,33 @@ def scaling_impl(scaling_impl_type): class SolveParameterScalingShape(ExtendedInjector): @value - def scaling_shape(scaling_per_output_channel): - # this pattern of returning this.something allows to resolve scaling_output_channel_shape - # only when scaling_per_output_channel is True - if scaling_per_output_channel: - return this.scaling_per_output_channel_shape - else: + def scaling_shape(module, group_size=None, scaling_per_output=None): + if scaling_per_output == ScalingPerOutputType.TENSOR: return SCALAR_SHAPE + elif scaling_per_output == ScalingPerOutputType.CHANNEL: + return this.scaling_per_output_channel_shape + elif scaling_per_output == ScalingPerOutputType.GROUP: + assert group_size is not None, "Per Group scaling requires group size" + size = list(module.weight.shape) + assert size[1] % group_size == 0, 'Input channel is not divisible by group size' + size[1] = size[1] // group_size + size.insert(2, 1) + return size + + @value + def reshaped_scaling_shape(module): + return module.weight.shape + + @value + def expanded_scaling_shape(module, group_size=None): + assert group_size is not None, "Per Group scaling requires group size" + size = list(module.weight.shape) + assert size[1] % group_size == 0, 'Input channel is not divisible by group size' + size[1] = size[1] // group_size + size.insert(2, group_size) + return size + + @value + def group_dim(module, group_size=None): + if group_size is not None: + return 1 diff --git a/src/brevitas/quant/solver/weight.py b/src/brevitas/quant/solver/weight.py index 57f7dd8b4..097f65443 100644 --- a/src/brevitas/quant/solver/weight.py +++ b/src/brevitas/quant/solver/weight.py @@ -3,6 +3,7 @@ from brevitas.core.quant import * from brevitas.core.quant import QuantType +from brevitas.core.scaling import ScalingPerOutputType from brevitas.inject import ExtendedInjector from brevitas.inject import this from brevitas.inject import value @@ -62,11 +63,11 @@ class SolveWeightScalingStatsInputDimsFromModule(ExtendedInjector): # such that output channels are dim 0 and the remaining features are dim 1, # along which we concatenate @value - def scaling_stats_input_concat_dim(scaling_per_output_channel): - if scaling_per_output_channel: - return 1 - else: + def scaling_stats_input_concat_dim(scaling_per_output): + if scaling_per_output == ScalingPerOutputType.TENSOR: return 0 + elif scaling_per_output == ScalingPerOutputType.CHANNEL: + return 1 @value def permute_dims(module, output_channel_dim): @@ -88,9 +89,9 @@ def output_channel_dim(module): return module.output_channel_dim -class WeightQuantSolver(SolveWeightScalingStatsInputDimsFromModule, +class WeightQuantSolver(SolveStatsReduceDimFromEnum, + SolveWeightScalingStatsInputDimsFromModule, SolveScalingStatsInputViewShapeImplFromEnum, - SolveStatsReduceDimFromEnum, SolveScalingStatsOpFromEnum, SolveBitWidthImplFromEnum, SolveTensorQuantFloatToIntImplFromEnum, diff --git a/src/brevitas/quant_tensor/__init__.py b/src/brevitas/quant_tensor/__init__.py index be7e10bc0..d13facba8 100644 --- a/src/brevitas/quant_tensor/__init__.py +++ b/src/brevitas/quant_tensor/__init__.py @@ -4,4 +4,6 @@ from .base_quant_tensor import * from .base_quant_tensor import _unpack_quant_tensor from .float_quant_tensor import * +from .groupwise_float_quant_tensor import * +from .groupwise_int_quant_tensor import * from .int_quant_tensor import * diff --git a/src/brevitas/quant_tensor/base_quant_tensor.py b/src/brevitas/quant_tensor/base_quant_tensor.py index 6239b8324..7b5dcd597 100644 --- a/src/brevitas/quant_tensor/base_quant_tensor.py +++ b/src/brevitas/quant_tensor/base_quant_tensor.py @@ -115,6 +115,33 @@ class FloatQuantTensorBase(NamedTuple): training_t: Tensor +class GroupwiseFloatQuantTensorBase(NamedTuple): + value_: Tensor + scale_: Tensor + zero_point_: Tensor + group_size: Tensor + group_dim: Tensor + exponent_bit_width: Tensor + mantissa_bit_width: Tensor + exponent_bias: Tensor + saturating_t: Tensor + inf_values: List[str] + nan_values: List[str] + signed_t: Tensor + training_t: Tensor + + +class GroupwisIntQuantTensorBase(NamedTuple): + value_: Tensor + scale_: Tensor + zero_point_: Tensor + group_size: Tensor + group_dim: Tensor + bit_width: Tensor + signed_t: Tensor + training_t: Tensor + + def _unpack_quant_tensor(input_data): if isinstance(input_data, QuantTensor): return input_data.value diff --git a/src/brevitas/quant_tensor/float_quant_tensor.py b/src/brevitas/quant_tensor/float_quant_tensor.py index 74f42dc94..cf4ba1420 100644 --- a/src/brevitas/quant_tensor/float_quant_tensor.py +++ b/src/brevitas/quant_tensor/float_quant_tensor.py @@ -256,11 +256,11 @@ def cat(tensors, dim, out=None): exponent_bit_width=output_exponent_bit_width, mantissa_bit_width=output_mantissa_bit_width, exponent_bias=output_exponent_bias, - signed=output_signed, - training=output_training, saturating=output_saturating, inf_values=output_inf_values, - nan_values=output_nan_values) + nan_values=output_nan_values, + signed=output_signed, + training=output_training) else: tensors = [_unpack_quant_tensor(qt) for qt in tensors] output_value = torch.cat(tensors, dim=dim) @@ -280,11 +280,11 @@ def __neg__(self): exponent_bit_width=self.exponent_bit_width, mantissa_bit_width=self.mantissa_bit_width, exponent_bias=self.exponent_bias, - signed=self.signed, - training=self.training, saturating=self.saturating, inf_values=self.inf_values, - nan_values=self.nan_values) + nan_values=self.nan_values, + signed=self.signed, + training=self.training) else: # TODO: implement raise NotImplementedError @@ -304,7 +304,7 @@ def __mul__(self, other): return output def __str__(self): - return f"FloatQuantTensor(value={self.value}, scale={self.scale}, zero_point={self.zero_point}, bit_width={self.bit_width}, signed_t={self.signed_t}, training_t={self.training_t})" + return f"FloatQuantTensor(value={self.value}, scale={self.scale}, zero_point={self.zero_point}, exponent_bit_width={self.exponent_bit_width}, mantissa_bit_width={self.mantissa_bit_width}, exponent_bias={self.exponent_bias}, inf_values={self.inf_values}, nan_values={self.nan_values}, signed_t={self.signed_t}, training_t={self.training_t})" def __truediv__(self, other): if isinstance(other, QuantTensor): @@ -325,10 +325,11 @@ def __abs__(self): zero_point=self.zero_point, exponent_bit_width=self.exponent_bit_width, mantissa_bit_width=self.mantissa_bit_width, - signed=False, - training=self.training, + exponent_bias=self.exponent_bias, saturating=self.saturating, inf_values=self.inf_values, - nan_values=self.nan_values) + nan_values=self.nan_values, + signed=False, + training=self.training) else: return self diff --git a/src/brevitas/quant_tensor/float_torch_handler.py b/src/brevitas/quant_tensor/float_torch_handler.py index 7fb4507c1..60401fc4a 100644 --- a/src/brevitas/quant_tensor/float_torch_handler.py +++ b/src/brevitas/quant_tensor/float_torch_handler.py @@ -92,11 +92,11 @@ def embedding_handler(input, quant_weight, *args, **kwargs): exponent_bit_width, mantissa_bit_width, exponent_bias, - signed, - training, saturating, inf_values, - nan_values) + nan_values, + signed, + training) return out diff --git a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py new file mode 100644 index 000000000..7d73bf7de --- /dev/null +++ b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py @@ -0,0 +1,330 @@ +# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +import torch + +from brevitas.quant_tensor import _unpack_quant_tensor +from brevitas.quant_tensor.base_quant_tensor import GroupwiseFloatQuantTensorBase +from brevitas.quant_tensor.base_quant_tensor import QuantTensor +from brevitas.utils.torch_utils import float_internal_scale + +from .float_torch_handler import FLOAT_QUANT_TENSOR_FN_HANDLER +from .torch_handler import QUANT_TENSOR_FN_HANDLER + +IS_VALID_ATOL = 2e-1 +BFLOAT16_IS_VALID_ATOL = 0.5 + + +class GroupwiseFloatQuantTensor(GroupwiseFloatQuantTensorBase, QuantTensor): + + def __new__( + cls, + value, + scale, + zero_point, + group_size, + group_dim, + exponent_bit_width, + mantissa_bit_width, + exponent_bias, + saturating, + inf_values, + nan_values, + signed, + training): + + if not isinstance(scale, torch.Tensor): + scale = torch.tensor(scale, dtype=torch.float) + if not isinstance(zero_point, torch.Tensor): + zero_point = torch.tensor(zero_point, dtype=torch.float) + if not isinstance(exponent_bit_width, torch.Tensor): + exponent_bit_width = torch.tensor(exponent_bit_width, dtype=torch.float) + if not isinstance(mantissa_bit_width, torch.Tensor): + mantissa_bit_width = torch.tensor(mantissa_bit_width, dtype=torch.float) + if not isinstance(exponent_bias, torch.Tensor): + exponent_bias = torch.tensor(exponent_bias, dtype=torch.float) + if not isinstance(saturating, torch.Tensor): + saturating = torch.tensor(saturating, dtype=torch.bool) + if not isinstance(signed, torch.Tensor): + signed = torch.tensor(signed, dtype=torch.bool) + if not isinstance(training, torch.Tensor): + training = torch.tensor(training, dtype=torch.bool) + quant_tensor = super().__new__( + cls, + value, + scale, + zero_point, + group_size, + group_dim, + exponent_bit_width, + mantissa_bit_width, + exponent_bias, + saturating, + inf_values, + nan_values, + signed, + training) + return quant_tensor + + @property + def signed(self): + return self.signed_t.item() + + @property + def training(self): + return self.training_t.item() + + @property + def saturating(self): + return self.saturating_t.item() + + def __torch_function__(self, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + if func in QUANT_TENSOR_FN_HANDLER: + return QUANT_TENSOR_FN_HANDLER[func](*args, **kwargs) + else: + args = _unpack_quant_tensor(args) + kwargs = _unpack_quant_tensor(kwargs) + return func(*args, **kwargs) + + def expand(self): + curr_shape = self.value_.shape + new_value = self.value_.flatten(self.group_dim, self.group_dim + 1) + if self.scale_.shape != (): + new_scale = self.scale_.expand(curr_shape).flatten(self.group_dim, self.group_dim + 1) + else: + new_scale = self.scale_ + if self.zero_point_.shape != (): + new_zp = self.zero_point_.expand(curr_shape).flatten(self.group_dim, self.group_dim + 1) + else: + new_zp = self.zero_point_ + + return new_value, new_scale, new_zp + + @staticmethod + def from_expanded(value, group_size, group_dim, compress=False): + size = list(value.shape) + assert size[group_dim] % group_size == 0, 'Input channel is not divisible by group size' + if compress: + size[group_dim] = 1 + else: + size[group_dim] = size[group_dim] // group_size + size.insert(group_dim + 1, group_size) + new_value = value.view(size) + return new_value + + @property + def tensor(self): + return self.value + + @property + def value(self): + new_value, new_scale, new_zp = self.expand() + return new_value + + @property + def scale(self): + new_value, new_scale, new_zp = self.expand() + return new_scale + + @property + def zero_point(self): + new_value, new_scale, new_zp = self.expand() + return new_zp + + @property + def _pre_round_float_value(self): + value, scale, zp = self.expand() + if self.scale.dtype == torch.bfloat16: + value = value.type(torch.float32) + scale = scale.type(torch.float32) + minifloat_value = value / scale + fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width + int_scale = float_internal_scale(self.value, self.mantissa_bit_width, fp_internal_scale) + minifloat_value = minifloat_value / int_scale + return minifloat_value + + @property + def is_valid(self): + with torch.no_grad(): + pre_round_minifloat_value = self._pre_round_float_value + rounded_minifloat_value = torch.round(pre_round_minifloat_value) + max_abs_diff = torch.max(torch.abs(pre_round_minifloat_value - rounded_minifloat_value)) + atol = BFLOAT16_IS_VALID_ATOL if self.value.dtype == torch.bfloat16 else IS_VALID_ATOL + is_minifloat = max_abs_diff < atol + # We are missing the checks about self being contained between max and min value + # given by mantissa, exponent, inf, nan, and saturating + return is_minifloat + + @property + def device(self): + value_device = self.value_.device + is_same_device = True + for t in [self.scale, + self.zero_point, + self.exponent_bit_width, + self.mantissa_bit_width, + self.exponent_bias]: + is_same_device &= value_device == t.device + if not is_same_device: + raise RuntimeError("Value and metadata are on different devices") + return value_device + + def minifloat(self, float_datatype=True): + # TODO: Check if OCP and cast to proper data-type if matching + assert float_datatype, "Minifloat quant returns only higher precision dtype" + + if self.is_valid: + fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width + int_scale = float_internal_scale(self.value, self.mantissa_bit_width, fp_internal_scale) + float_value = torch.round(self._pre_round_float_value) * int_scale + return float_value.type(self.scale.dtype) + else: + raise RuntimeError(f"FloatQuantTensor not valid.") + + @staticmethod + def check_input_type(tensor): + if not isinstance(tensor, GroupwiseFloatQuantTensor): + raise RuntimeError("Tensor is not a GroupwiseFloatQuantTensor") + + @staticmethod + def is_zero_zero_point(tensor): + GroupwiseFloatQuantTensor.check_input_type(tensor) + return (tensor.zero_point == 0.).all() + + def check_scaling_factors_same(self, other): + if self.training: + return True + if not torch.allclose(self.scale, other.scale): + raise RuntimeError("Scaling factors are different") + + def check_zero_points_same(self, other): + if self.training: + return True + if not torch.allclose(self.zero_point, other.zero_point): + raise RuntimeError("Zero points are different") + + def check_bit_width_same(self, other): + if not torch.allclose(self.exponent_bit_width, + other.exponent_bit_width) and not torch.allclose( + self.mantissa_bit_width, other.mantissa_bit_width): + raise RuntimeError("Bit widths are different") + + def check_exponent_bias(self, other): + if not torch.allclose(self.exponent_bias, other.exponent_bias): + raise RuntimeError("Bit widths are different") + + def check_inf_nan_same(self, other): + if not (set(self.inf_values) == set(other.inf_values)) and not (set(self.nan_values) == set( + other.nan_values)): + raise RuntimeError("Floating point representations are different") + + def check_sign_same(self, other): + if not self.signed == other.signed: + raise RuntimeError("Signs are different") + + def view(self, *args, **kwargs): + return self.value.view(*args, **kwargs) #self.set(value=self.value.view(*args, **kwargs)) + + def reshape(self, *args, **kwargs): + return self.value.reshape( + *args, **kwargs) # self.set(value=self.value.reshape(*args, **kwargs)) + + def flatten(self, *args, **kwargs): + return self.value.flatten( + *args, **kwargs) #self.set(value=self.value.flatten(*args, **kwargs)) + + def transpose(self, *args, **kwargs): + value = self.value.transpose(*args, **kwargs) + return value + + def permute(self, *args, **kwargs): + value = self.value.permute(*args, **kwargs) + return value + + # Reference: https://docs.python.org/3/reference/datamodel.html#emulating-numeric-types + + def __neg__(self): + neg_deq = -self.minifloat(float_datatype=True) + _, scale, zp = self.expand() + + neg_value = (-neg_deq - zp) * scale + # In case the dtype of self.minifloat is different from the one of the scale + neg_value = neg_value.type(scale.dtype) + neg_value = GroupwiseFloatQuantTensor.from_expanded( + neg_value, self.group_size, self.group_dim, compress=False) + scale = GroupwiseFloatQuantTensor.from_expanded( + scale, self.group_size, self.group_dim, compress=True) + if self.signed: + return GroupwiseFloatQuantTensor( + value=neg_value, + scale=scale, + zero_point=self.zero_point, + group_size=self.group_size, + group_dim=self.group_dim, + exponent_bit_width=self.exponent_bit_width, + mantissa_bit_width=self.mantissa_bit_width, + exponent_bias=self.exponent_bias, + saturating=self.saturating, + inf_values=self.inf_values, + nan_values=self.nan_values, + signed=self.signed, + training=self.training) + else: + # TODO: implement + raise NotImplementedError + + def __add__(self, other): + if isinstance(other, QuantTensor): + return self.value + other.value + else: + output = self.value + other + return output + + def __mul__(self, other): + if isinstance(other, QuantTensor): + return self.value * other.value + else: + output = self.value * other + return output + + def __str__(self): + return f"GroupwiseFloatQuantTensor(value={self.value}, scale={self.scale}, zero_point={self.zero_point}, group_size={self.group_size}, group_dim={self.group_dim}, exponent_bit_width={self.exponent_bit_width}, mantissa_bit_width={self.mantissa_bit_width}, exponent_bias={self.exponent_bias}, inf_values={self.inf_values}, nan_values={self.nan_values}, signed_t={self.signed_t}, training_t={self.training_t})" + + def __truediv__(self, other): + if isinstance(other, QuantTensor): + return self.value / other.value + else: + output = self.value / other + return output + + def __abs__(self): + if self.signed: + neg_deq = self.minifloat(float_datatype=True) + _, scale, zp = self.expand() + + # In case the dtype of self.minifloat is different from the one of the scale + abs_value = (neg_deq - zp) * scale + # In case the dtype of self.minifloat is different from the one of the scale + abs_value = abs_value.type(scale.dtype) + abs_value = GroupwiseFloatQuantTensor.from_expanded( + abs_value, self.group_size, self.group_dim, compress=False) + scale = GroupwiseFloatQuantTensor.from_expanded( + scale, self.group_size, self.group_dim, compress=True) + return GroupwiseFloatQuantTensor( + value=abs_value, + scale=self.scale, + zero_point=self.zero_point, + group_size=self.group_size, + group_dim=self.group_dim, + exponent_bit_width=self.exponent_bit_width, + mantissa_bit_width=self.mantissa_bit_width, + exponent_bias=self.exponent_bias, + saturating=self.saturating, + inf_values=self.inf_values, + nan_values=self.nan_values, + signed=False, + training=self.training) + else: + return self diff --git a/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py new file mode 100644 index 000000000..976e86130 --- /dev/null +++ b/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py @@ -0,0 +1,312 @@ +# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +import torch + +from brevitas.function.ops_ste import round_ste +from brevitas.quant_tensor import _unpack_quant_tensor +from brevitas.quant_tensor.base_quant_tensor import GroupwisIntQuantTensorBase +from brevitas.quant_tensor.base_quant_tensor import QuantTensor +from brevitas.utils.torch_utils import float_internal_scale + +from .int_torch_handler import INT_QUANT_TENSOR_FN_HANDLER +from .torch_handler import QUANT_TENSOR_FN_HANDLER + +IS_VALID_ATOL = 2e-1 +BFLOAT16_IS_VALID_ATOL = 0.5 + + +class GroupwiseIntQuantTensor(GroupwisIntQuantTensorBase, QuantTensor): + + def __new__(cls, value, scale, zero_point, group_size, group_dim, bit_width, signed, training): + + if not isinstance(scale, torch.Tensor): + scale = torch.tensor(scale, dtype=torch.float) + if not isinstance(zero_point, torch.Tensor): + zero_point = torch.tensor(zero_point, dtype=torch.float) + if not isinstance(bit_width, torch.Tensor): + bit_width = torch.tensor(bit_width, dtype=torch.float) + if not isinstance(signed, torch.Tensor): + signed = torch.tensor(signed, dtype=torch.bool) + if not isinstance(training, torch.Tensor): + training = torch.tensor(training, dtype=torch.bool) + quant_tensor = super().__new__( + cls, value, scale, zero_point, group_size, group_dim, bit_width, signed, training) + return quant_tensor + + @property + def signed(self): + return self.signed_t.item() + + @property + def training(self): + return self.training_t.item() + + @property + def saturating(self): + return self.saturating_t.item() + + def __torch_function__(self, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + if func in QUANT_TENSOR_FN_HANDLER: + return QUANT_TENSOR_FN_HANDLER[func](*args, **kwargs) + else: + args = _unpack_quant_tensor(args) + kwargs = _unpack_quant_tensor(kwargs) + return func(*args, **kwargs) + + def expand(self): + curr_shape = self.value_.shape + new_value = self.value_.flatten(self.group_dim, self.group_dim + 1) + if self.scale_.shape != (): + new_scale = self.scale_.expand(curr_shape).flatten(self.group_dim, self.group_dim + 1) + else: + new_scale = self.scale_ + if self.zero_point_.shape != (): + new_zp = self.zero_point_.expand(curr_shape).flatten(self.group_dim, self.group_dim + 1) + else: + new_zp = self.zero_point_ + + return new_value, new_scale, new_zp + + @staticmethod + def from_expanded(value, group_size, group_dim, compress=False): + size = list(value.shape) + assert size[group_dim] % group_size == 0, 'Input channel is not divisible by group size' + if compress: + size[group_dim] = 1 + else: + size[group_dim] = size[group_dim] // group_size + size.insert(group_dim + 1, group_size) + new_value = value.view(size) + return new_value + + @property + def tensor(self): + return self.value + + @property + def value(self): + new_value, new_scale, new_zp = self.expand() + return new_value + + @property + def scale(self): + new_value, new_scale, new_zp = self.expand() + return new_scale + + @property + def zero_point(self): + new_value, new_scale, new_zp = self.expand() + return new_zp + + @property + def _pre_round_int_value(self): + value = self.value + scale = self.scale + zero_point = self.zero_point + if self.scale.dtype == torch.bfloat16: + value = self.value.type(torch.float32) + scale = self.scale.type(torch.float32) + zero_point = self.zero_point.type(torch.float32) + int_value = value / scale + int_value = int_value + zero_point + return int_value + + @property + def is_valid(self): + with torch.no_grad(): + pre_round_int_value = self._pre_round_int_value + rounded_int_value = torch.round(pre_round_int_value) + max_abs_diff = torch.max(torch.abs(pre_round_int_value - rounded_int_value)) + atol = BFLOAT16_IS_VALID_ATOL if self.value.dtype == torch.bfloat16 else IS_VALID_ATOL + is_int = max_abs_diff < atol + if self.bit_width >= 2: + if self.signed: + is_upper_b = (2.0 ** (self.bit_width - 1) - 1 >= rounded_int_value).all() + is_lower_b = (-2.0 ** (self.bit_width - 1) <= rounded_int_value).all() + else: + is_upper_b = (2.0 ** self.bit_width - 1 >= rounded_int_value).all() + is_lower_b = (0. <= rounded_int_value).all() + return (is_int & is_upper_b & is_lower_b).item() + else: # binary case + unique_vals = rounded_int_value.unique( + sorted=False, return_counts=False, return_inverse=False) + is_binary = unique_vals.view(-1).size()[0] == 2 + is_signed = (unique_vals < 0.).any().item() + sign_match = is_signed == self.signed + return is_int.item() and is_binary and sign_match + + @property + def device(self): + value_device = self.value_.device + is_same_device = True + for t in [self.scale, + self.zero_point, + self.exponent_bit_width, + self.mantissa_bit_width, + self.exponent_bias]: + is_same_device &= value_device == t.device + if not is_same_device: + raise RuntimeError("Value and metadata are on different devices") + return value_device + + def int(self, float_datatype=False): + if self.is_valid: + int_value = round_ste(self._pre_round_int_value) + if float_datatype: + # Values at 8bit and lower can be represented exactly with float16 and bfloat16 + # otherwise (e.g. Int16 bias), we upscale to float32 + if self.bit_width <= 8.: + return int_value.type(self.scale.dtype) + else: + return int_value.type(torch.float32) + else: + if self.bit_width <= 8. and self.signed_t.item(): + return int_value.to(torch.int8) + elif self.bit_width <= 8. and not self.signed_t.item(): + return int_value.to(torch.uint8) + else: + return int_value.to(torch.int32) + else: + raise RuntimeError(f"GroupwiseIntQuantTensor not valid.") + + @staticmethod + def check_input_type(tensor): + if not isinstance(tensor, GroupwiseIntQuantTensor): + raise RuntimeError("Tensor is not a GroupwiseIntQuantTensor") + + @staticmethod + def is_zero_zero_point(tensor): + GroupwiseIntQuantTensor.check_input_type(tensor) + return (tensor.zero_point == 0.).all() + + def check_scaling_factors_same(self, other): + if self.training: + return True + if not torch.allclose(self.scale, other.scale): + raise RuntimeError("Scaling factors are different") + + def check_zero_points_same(self, other): + if self.training: + return True + if not torch.allclose(self.zero_point, other.zero_point): + raise RuntimeError("Zero points are different") + + def check_bit_width_same(self, other): + if not torch.allclose(self.exponent_bit_width, + other.exponent_bit_width) and not torch.allclose( + self.mantissa_bit_width, other.mantissa_bit_width): + raise RuntimeError("Bit widths are different") + + def check_exponent_bias(self, other): + if not torch.allclose(self.exponent_bias, other.exponent_bias): + raise RuntimeError("Bit widths are different") + + def check_inf_nan_same(self, other): + if not (set(self.inf_values) == set(other.inf_values)) and not (set(self.nan_values) == set( + other.nan_values)): + raise RuntimeError("Floating point representations are different") + + def check_sign_same(self, other): + if not self.signed == other.signed: + raise RuntimeError("Signs are different") + + def view(self, *args, **kwargs): + return self.value.view(*args, **kwargs) #self.set(value=self.value.view(*args, **kwargs)) + + def reshape(self, *args, **kwargs): + return self.value.reshape( + *args, **kwargs) # self.set(value=self.value.reshape(*args, **kwargs)) + + def flatten(self, *args, **kwargs): + return self.value.flatten( + *args, **kwargs) #self.set(value=self.value.flatten(*args, **kwargs)) + + def transpose(self, *args, **kwargs): + value = self.value.transpose(*args, **kwargs) + return value + + def permute(self, *args, **kwargs): + value = self.value.permute(*args, **kwargs) + return value + + # Reference: https://docs.python.org/3/reference/datamodel.html#emulating-numeric-types + + def __neg__(self): + neg_deq = -self.minifloat(float_datatype=True) + _, scale, zp = self.expand() + + neg_value = (-neg_deq - zp) * scale + # In case the dtype of self.minifloat is different from the one of the scale + neg_value = neg_value.type(scale.dtype) + neg_value = GroupwiseIntQuantTensor.from_expanded( + neg_value, self.group_size, self.group_dim, compress=False) + scale = GroupwiseIntQuantTensor.from_expanded( + scale, self.group_size, self.group_dim, compress=True) + if self.signed: + return GroupwiseIntQuantTensor( + value=neg_value, + scale=scale, + zero_point=self.zero_point, + group_size=self.group_size, + group_dim=self.group_dim, + bit_width=self.bit_width, + signed=self.signed, + training=self.training, + saturating=self.saturating) + else: + # TODO: implement + raise NotImplementedError + + def __add__(self, other): + if isinstance(other, QuantTensor): + return self.value + other.value + else: + output = self.value + other + return output + + def __mul__(self, other): + if isinstance(other, QuantTensor): + return self.value * other.value + else: + output = self.value * other + return output + + def __str__(self): + return f"GroupwiseIntQuantTensor(value={self.value}, scale={self.scale}, zero_point={self.zero_point}, group_size={self.group_size}, group_dim={self.group_dim}, bit_width={self.bit_width}, signed_t={self.signed_t}, training_t={self.training_t})" + + def __truediv__(self, other): + if isinstance(other, QuantTensor): + return self.value / other.value + else: + output = self.value / other + return output + + def __abs__(self): + if self.signed: + neg_deq = self.minifloat(float_datatype=True) + _, scale, zp = self.expand() + + # In case the dtype of self.minifloat is different from the one of the scale + abs_value = (neg_deq - zp) * scale + # In case the dtype of self.minifloat is different from the one of the scale + abs_value = abs_value.type(scale.dtype) + abs_value = GroupwiseIntQuantTensor.from_expanded( + abs_value, self.group_size, self.group_dim, compress=False) + scale = GroupwiseIntQuantTensor.from_expanded( + scale, self.group_size, self.group_dim, compress=True) + return GroupwiseIntQuantTensor( + value=abs_value, + scale=self.scale, + zero_point=self.zero_point, + group_size=self.group_size, + group_dim=self.group_dim, + bit_width=self.bit_width, + signed=False, + training=self.training, + saturating=self.saturating) + else: + return self diff --git a/src/brevitas/utils/quant_utils.py b/src/brevitas/utils/quant_utils.py index 8df77a99e..f160877a0 100644 --- a/src/brevitas/utils/quant_utils.py +++ b/src/brevitas/utils/quant_utils.py @@ -6,6 +6,8 @@ from brevitas.core.quant import RescalingIntQuant from brevitas.inject.enum import FloatToIntImplType from brevitas.quant_tensor import FloatQuantTensor +from brevitas.quant_tensor import GroupwiseFloatQuantTensor +from brevitas.quant_tensor import GroupwiseIntQuantTensor from brevitas.quant_tensor import IntQuantTensor @@ -81,6 +83,94 @@ def signed(self): return self.quant_tensor.signed +class _CachedIOGroupwiseFloat: + + def __init__(self, quant_tensor: GroupwiseFloatQuantTensor, metadata_only: bool): + self.shape = quant_tensor.value.shape + if metadata_only: + self.quant_tensor = quant_tensor.set(value=None) + else: + self.quant_tensor = quant_tensor + + @property + def scale(self): + return self.quant_tensor.scale + + @property + def zero_point(self): + return self.quant_tensor.zero_point + + @property + def exponent_bit_width(self): + return self.quant_tensor.exponent_bit_width + + @property + def mantissa_bit_width(self): + return self.quant_tensor.mantissa_bit_width + + @property + def exponent_bias(self): + return self.quant_tensor.exponent_bias + + @property + def saturating(self): + return self.quant_tensor.saturating + + @property + def inf_values(self): + return self.quant_tensor.inf_values + + @property + def nan_values(self): + return self.quant_tensor.nan_values + + @property + def signed(self): + return self.quant_tensor.signed + + @property + def group_size(self): + return self.quant_tensor.group_size + + @property + def group_dim(self): + return self.quant_tensor.group_dim + + +class _CachedIOGroupwiseInt: + + def __init__(self, quant_tensor: GroupwiseIntQuantTensor, metadata_only: bool): + self.shape = quant_tensor.value.shape + if metadata_only: + self.quant_tensor = quant_tensor.set(value=None) + else: + self.quant_tensor = quant_tensor + + @property + def scale(self): + return self.quant_tensor.scale + + @property + def zero_point(self): + return self.quant_tensor.zero_point + + @property + def bit_width(self): + return self.quant_tensor.bit_width + + @property + def signed(self): + return self.quant_tensor.signed + + @property + def group_size(self): + return self.quant_tensor.group_size + + @property + def group_dim(self): + return self.quant_tensor.group_dim + + def has_learned_weight_bit_width(module): from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector diff --git a/src/brevitas_examples/common/generative/quant_blocks.py b/src/brevitas_examples/common/generative/quant_blocks.py index 40516111f..18149578d 100644 --- a/src/brevitas_examples/common/generative/quant_blocks.py +++ b/src/brevitas_examples/common/generative/quant_blocks.py @@ -16,79 +16,6 @@ from brevitas.function.ops_ste import abs_binary_sign_grad -class OverSubChannelBlockView(brevitas.jit.ScriptModule): - __constants__ = ['scaling_input_shape'] - - def __init__(self, scaling_input_shape, permute_dims: Optional[Tuple[int, ...]]) -> None: - super(OverSubChannelBlockView, self).__init__() - self.scaling_input_shape = scaling_input_shape - if permute_dims is not None: - self.permute_impl = PermuteDims(permute_dims) - else: - self.permute_impl = nn.Identity() - - @brevitas.jit.script_method - def forward(self, x: torch.Tensor): - y = self.permute_impl(x) - y = y.view(self.scaling_input_shape) - return y - - -class ExpandReshapeScalingWrapper(brevitas.jit.ScriptModule): - __constants__ = ['expanded_scaling_shape', 'reshaped_scaling_shape'] - - def __init__(self, wrapped_scaling_impl, expanded_scaling_shape, reshaped_scaling_shape): - super(ExpandReshapeScalingWrapper, self).__init__() - self.wrapped_scaling_impl = wrapped_scaling_impl - self.expanded_scaling_shape = expanded_scaling_shape - self.reshaped_scaling_shape = reshaped_scaling_shape - self.slice_tensor = SliceTensor() - - @brevitas.jit.script_method - def forward(self, x): - scale = self.wrapped_scaling_impl(x) - scale = scale.expand(self.expanded_scaling_shape) - scale = scale.reshape(self.reshaped_scaling_shape) - # slice tensor when required by partial quantization - scale = self.slice_tensor(scale) - return scale - - -class ExpandReshapeZeroPointWrapper(brevitas.jit.ScriptModule): - __constants__ = ['expanded_zero_point_shape', 'reshaped_zero_point_shape'] - - def __init__( - self, wrapped_zero_point_impl, expanded_zero_point_shape, reshaped_zero_point_shape): - super(ExpandReshapeZeroPointWrapper, self).__init__() - self.wrapped_zero_point_impl = wrapped_zero_point_impl - self.expanded_zero_point_shape = expanded_zero_point_shape - self.reshaped_zero_point_shape = reshaped_zero_point_shape - self.slice_tensor = SliceTensor() - - def unexpanded_zero_point(self, unexpanded_scale, bit_width): - """ - This is used at export time. - """ - zero_point_stats = self.wrapped_zero_point_impl.parameter_list_stats() - zero_point = self.wrapped_zero_point_impl.scale_shift_zero_point( - -zero_point_stats, unexpanded_scale, bit_width) - return zero_point - - @brevitas.jit.script_method - def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor): - # We have to break into wrapped_zero_point_impl since we need to expand and reshape - # Before we call into scale_shift_zero_point - zero_point_stats = self.wrapped_zero_point_impl.parameter_list_stats() - zero_point_stats = zero_point_stats.expand(self.expanded_zero_point_shape).contiguous() - # contiguous() above is to avoid an unsafe_view below - zero_point_stats = zero_point_stats.reshape(self.reshaped_zero_point_shape) - # slice tensor when required by partial quantization - zero_point_stats = self.slice_tensor(zero_point_stats) - zero_point = self.wrapped_zero_point_impl.scale_shift_zero_point( - -zero_point_stats, scale, bit_width) - return zero_point - - # TODO: restore JIT compatibility class RuntimeDynamicStatsScaling(nn.Module): @@ -135,31 +62,3 @@ def forward(self, x, scale, bit_width) -> Tensor: x = abs_binary_sign_grad(x) x = self.scale_shift_zero_point(x, scale, bit_width) return x - - -class RuntimeDynamicGroupStatsScaling(brevitas.jit.ScriptModule): - - def __init__(self, group_size: int, group_dim: int, scaling_stats_impl: nn.Module) -> None: - super(RuntimeDynamicGroupStatsScaling, self).__init__() - self.group_size = group_size - self.group_dim = group_dim - self.scaling_stats_impl = scaling_stats_impl - - @brevitas.jit.script_method - def group_scaling_reshape(self, stats_input): - tensor_shape = stats_input.shape - tensor_shape_list = list(tensor_shape) - tensor_shape_list[self.group_dim] = int(tensor_shape_list[self.group_dim] / self.group_size) - block_dim = self.group_dim + 1 if self.group_dim != -1 else -1 - tensor_shape_list.insert(block_dim, self.group_size) - stats_input = stats_input.view(tensor_shape_list) - return stats_input - - @brevitas.jit.script_method - def forward(self, stats_input) -> Tensor: - stats_input_reshaped = self.group_scaling_reshape(stats_input) - out = self.scaling_stats_impl(stats_input_reshaped) - out = torch.clamp_min(out, min=torch.tensor(1e-6, device=out.device, dtype=out.dtype)) - out = out.expand(stats_input_reshaped.shape) - out = out.reshape(stats_input.shape) - return out diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py index a86de3b76..57670f6f6 100644 --- a/src/brevitas_examples/common/generative/quantize.py +++ b/src/brevitas_examples/common/generative/quantize.py @@ -8,10 +8,8 @@ from torch import nn from brevitas import nn as qnn -from brevitas.core.stats.stats_op import NegativeMinOrZero from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint from brevitas.graph.quantize import layerwise_quantize -from brevitas.inject.enum import StatsOp from brevitas.quant.experimental.float import Fp8e4m3Act from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat @@ -19,13 +17,17 @@ from brevitas.quant.experimental.float_quant_fnuz import Fp8e4m3FNUZActPerTensorFloat from brevitas.quant.experimental.float_quant_fnuz import Fp8e4m3FNUZWeightPerChannelFloat from brevitas.quant.experimental.float_quant_fnuz import Fp8e4m3FNUZWeightPerTensorFloat -from brevitas.quant.experimental.float_quant_fnuz import Fp8e5m2FNUZActPerTensorFloat from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloat from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerChannelFloat from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloat -from brevitas.quant.experimental.float_quant_ocp import Fp8e5m2OCPActPerTensorFloat -from brevitas.quant.experimental.float_quant_ocp import Fp8e5m2OCPWeightPerChannelFloat -from brevitas.quant.experimental.float_quant_ocp import Fp8e5m2OCPWeightPerTensorFloat +from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Act +from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Weight +from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3WeightMSE +from brevitas.quant.experimental.mx_quant_ocp import MXInt8Act +from brevitas.quant.experimental.mx_quant_ocp import MXInt8Weight +from brevitas.quant.experimental.mx_quant_ocp import MXInt8WeightMSE +from brevitas.quant.experimental.mx_quant_ocp import ShiftedMXUInt8Weight +from brevitas.quant.experimental.mx_quant_ocp import ShiftedMXUInt8WeightMSE from brevitas.quant.fixed_point import Int8ActPerTensorFixedPoint from brevitas.quant.fixed_point import Int8ActPerTensorFixedPointMSE from brevitas.quant.fixed_point import Int8WeightPerChannelFixedPoint @@ -46,7 +48,7 @@ from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloatMSE from brevitas_examples.common.generative.nn import LoRACompatibleQuantConv2d from brevitas_examples.common.generative.nn import LoRACompatibleQuantLinear -from brevitas_examples.common.generative.quantizers import Fp8e4m3DynamicOCPActPerTensorFloat +from brevitas_examples.common.generative.quantizers import Fp8e4m3DynamicActPerGroupFloat from brevitas_examples.common.generative.quantizers import Fp8e4m3WeightSymmetricGroupQuant from brevitas_examples.common.generative.quantizers import Int8DynamicActPerGroupFloat from brevitas_examples.common.generative.quantizers import Int8DynamicActPerRowFloat @@ -66,7 +68,7 @@ 'sym': Int8WeightPerChannelFloat, 'asym': ShiftedUint8WeightPerChannelFloat}, 'per_group': { 'sym': IntWeightSymmetricGroupQuant, - 'asym': ShiftedUintWeightAsymmetricGroupQuant},}, + 'asym': ShiftedUintWeightAsymmetricGroupQuant}}, 'mse': { 'per_tensor': { 'sym': Int8WeightPerTensorFloatMSE, @@ -79,12 +81,16 @@ 'per_tensor': { 'sym': Int8WeightPerTensorFixedPoint}, 'per_channel': { - 'sym': Int8WeightPerChannelFixedPoint},}, + 'sym': Int8WeightPerChannelFixedPoint}, + 'per_group': { + 'sym': MXInt8Weight, 'asym': ShiftedMXUInt8Weight}}, 'mse': { 'per_tensor': { 'sym': Int8WeightPerTensorFixedPointMSE}, 'per_channel': { - 'sym': Int8WeightPerChannelFixedPointMSE},},}}, + 'sym': Int8WeightPerChannelFixedPointMSE}}, + 'per_group': { + 'sym': MXInt8WeightMSE, 'asym': ShiftedMXUInt8WeightMSE}}}, 'float': { 'float_scale': { 'stats': { @@ -95,28 +101,26 @@ 'per_group': { 'sym': Fp8e4m3WeightSymmetricGroupQuant}}}}, 'float_ocp': { - 'e4m3': { - 'float_scale': { - 'stats': { - 'per_tensor': { - 'sym': Fp8e4m3OCPWeightPerTensorFloat}, - 'per_channel': { - 'sym': Fp8e4m3OCPWeightPerChannelFloat}}}}, - 'e5m2': { - 'float_scale': { - 'stats': { - 'per_tensor': { - 'sym': Fp8e5m2OCPWeightPerTensorFloat}, - 'per_channel': { - 'sym': Fp8e5m2OCPWeightPerChannelFloat}}}}}, + 'float_scale': { + 'stats': { + 'per_tensor': { + 'sym': Fp8e4m3OCPWeightPerTensorFloat}, + 'per_channel': { + 'sym': Fp8e4m3OCPWeightPerChannelFloat}}}, + 'po2_scale': { + 'stats': { + 'per_group': { + 'sym': MXFloat8e4m3Weight}}, + 'mse': { + 'per_group': { + 'sym': MXFloat8e4m3WeightMSE}}}}, 'float_fnuz': { - 'e4m3': { - 'float_scale': { - 'stats': { - 'per_tensor': { - 'sym': Fp8e4m3FNUZWeightPerTensorFloat}, - 'per_channel': { - 'sym': Fp8e4m3FNUZWeightPerChannelFloat}}}}}} + 'float_scale': { + 'stats': { + 'per_tensor': { + 'sym': Fp8e4m3FNUZWeightPerTensorFloat}, + 'per_channel': { + 'sym': Fp8e4m3FNUZWeightPerChannelFloat}}}}} INPUT_QUANT_MAP = { 'int': { @@ -124,18 +128,18 @@ 'float_scale': { 'stats': { 'per_tensor': { - 'sym': Int8ActPerTensorFloat, 'asym': ShiftedUint8ActPerTensorFloat},}, + 'sym': Int8ActPerTensorFloat, 'asym': ShiftedUint8ActPerTensorFloat}}, 'mse': { 'per_tensor': { - 'sym': Int8ActPerTensorFloatMSE, 'asym': ShiftedUint8ActPerTensorFloatMSE}}, - }, + 'sym': Int8ActPerTensorFloatMSE, + 'asym': ShiftedUint8ActPerTensorFloatMSE}}}, 'po2_scale': { 'stats': { 'per_tensor': { - 'sym': Int8ActPerTensorFixedPoint},}, + 'sym': Int8ActPerTensorFixedPoint}}, 'mse': { 'per_tensor': { - 'sym': Int8ActPerTensorFixedPointMSE},},}}, + 'sym': Int8ActPerTensorFixedPointMSE}}}}, 'dynamic': { 'float_scale': { 'stats': { @@ -146,45 +150,38 @@ 'sym': Int8DynamicActPerRowFloat, 'asym': ShiftedUint8DynamicActPerRowFloat}, 'per_group': { - 'sym': Int8DynamicActPerGroupFloat},}}}}, + 'sym': Int8DynamicActPerGroupFloat}}}, + 'po2_scale': { + 'stats': { + 'per_group': MXInt8Act}}}}, 'float': { 'static': { 'float_scale': { 'stats': { 'per_tensor': { - 'sym': Fp8e4m3ActPerTensorFloat},}}}, + 'sym': Fp8e4m3ActPerTensorFloat}}}}, + 'dynamic': { + 'float_scale': { + 'stats': { + 'per_group': Fp8e4m3DynamicActPerGroupFloat}}}, 'no_scale': { 'sym': Fp8e4m3Act,}}, 'float_ocp': { 'static': { - 'e4m3': { - 'float_scale': { - 'stats': { - 'per_tensor': { - 'sym': Fp8e4m3OCPActPerTensorFloat}}}}, - 'e5m2': { - 'float_scale': { - 'stats': { - 'per_tensor': { - 'sym': Fp8e5m2OCPActPerTensorFloat}}}}}}, + 'float_scale': { + 'stats': { + 'per_tensor': { + 'sym': Fp8e4m3OCPActPerTensorFloat}}}}, + 'dynamic': { + 'po2_scale': { + 'stats': { + 'per_group': MXFloat8e4m3Act}}}}, 'float_fnuz': { 'static': { - 'e4m3': { - 'float_scale': { - 'stats': { - 'per_tensor': { - 'sym': Fp8e4m3FNUZActPerTensorFloat}}}}, - 'e5m2': { - 'float_scale': { - 'stats': { - 'per_tensor': { - 'sym': Fp8e5m2FNUZActPerTensorFloat}}}}}, - 'dynamic': { - 'e4m3': { - 'float_scale': { - 'stats': { - 'per_tensor': { - 'sym': Fp8e4m3DynamicOCPActPerTensorFloat}}}}}}} + 'float_scale': { + 'stats': { + 'per_tensor': { + 'sym': Fp8e4m3FNUZActPerTensorFloat}}}}}} def generate_quantizers( @@ -206,8 +203,6 @@ def generate_quantizers( input_quant_granularity=None, input_group_size=None, quantize_input_zero_point=False, - use_ocp=False, - use_fnuz=False, device=None, weight_kwargs=None, input_kwargs=None): @@ -215,67 +210,36 @@ def generate_quantizers( Replace float layers with quant layers in the target model """ # Retrive base input and weight quantizers - std_float_weight_quant_format = None - std_float_input_format = None # match against custom float format - if re.compile(r'e[1-8]m[1-8]').match(weight_quant_format): + if re.compile(r'e[1-8]m[1-8]').findall(weight_quant_format): + format = re.compile(r'e[1-8]m[1-8]').findall(weight_quant_format)[0] + weight_quant_format = weight_quant_format.replace('_' + format, '') weight_float_format = { - 'exponent_bit_width': int(weight_quant_format[1]), - 'mantissa_bit_width': int(weight_quant_format[3])} - std_float_weight_quant_format = weight_quant_format - weight_quant_format = 'float' - if use_ocp: - weight_quant_format += '_ocp' - elif use_fnuz: - weight_quant_format += '_fnuz' + 'exponent_bit_width': int(format[1]), 'mantissa_bit_width': int(format[3])} else: weight_float_format = {} - if re.compile(r'e[1-8]m[1-8]').match(input_quant_format): + if re.compile(r'e[1-8]m[1-8]').findall(input_quant_format): + format = re.compile(r'e[1-8]m[1-8]').findall(input_quant_format)[0] + input_quant_format = input_quant_format.replace('_' + format, '') input_float_format = { - 'exponent_bit_width': int(input_quant_format[1]), - 'mantissa_bit_width': int(input_quant_format[3])} - std_float_input_format = input_quant_format - input_quant_format = 'float' - if use_ocp: - input_quant_format += '_ocp' - elif use_fnuz: - input_quant_format += '_fnuz' + 'exponent_bit_width': int(format[1]), 'mantissa_bit_width': int(format[3])} else: input_float_format = {} - if 'ocp' in weight_quant_format or 'fnuz' in weight_quant_format: - weight_quant = WEIGHT_QUANT_MAP[weight_quant_format][std_float_weight_quant_format][ - weight_scale_precision][weight_param_method][weight_quant_granularity][ - weight_quant_type] - else: - weight_quant = WEIGHT_QUANT_MAP[weight_quant_format][weight_scale_precision][ - weight_param_method][weight_quant_granularity][weight_quant_type] + weight_quant = WEIGHT_QUANT_MAP[weight_quant_format][weight_scale_precision][ + weight_param_method][weight_quant_granularity][weight_quant_type] if input_bit_width is not None and input_scale_type == 'no_scale': input_quant = sym_input_quant = linear_input_quant = INPUT_QUANT_MAP[input_quant_format][ input_scale_type][input_quant_type] elif input_bit_width is not None: - if 'ocp' in input_quant_format or 'fnuz' in input_quant_format: - input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][ - std_float_input_format][input_scale_precision][input_param_method][ - input_quant_granularity][input_quant_type] - # Some activations in MHA should always be symmetric - sym_input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][ - std_float_input_format][input_scale_precision][input_param_method][ - input_quant_granularity]['sym'] - linear_input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][ - std_float_input_format][input_scale_precision][input_param_method][ - input_quant_granularity][input_quant_type] - else: - input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][ - input_scale_precision][input_param_method][input_quant_granularity][ - input_quant_type] - # Some activations in MHA should always be symmetric - sym_input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][ - input_scale_precision][input_param_method][input_quant_granularity]['sym'] - linear_input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][ - input_scale_precision][input_param_method][input_quant_granularity][ - input_quant_type] + input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][input_scale_precision][ + input_param_method][input_quant_granularity][input_quant_type] + # Some activations in MHA should always be symmetric + sym_input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][ + input_scale_precision][input_param_method][input_quant_granularity]['sym'] + linear_input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][ + input_scale_precision][input_param_method][input_quant_granularity][input_quant_type] if input_kwargs is None: input_kwargs = dict() @@ -472,8 +436,6 @@ def quantize_model( input_group_size=None, quantize_input_zero_point=False, quantize_embedding=False, - use_ocp=False, - use_fnuz=False, device=None, weight_kwargs=None, input_kwargs=None): @@ -497,8 +459,6 @@ def quantize_model( input_quant_granularity, input_group_size, quantize_input_zero_point, - use_ocp, - use_fnuz, device, weight_kwargs, input_kwargs) diff --git a/src/brevitas_examples/common/generative/quantizers.py b/src/brevitas_examples/common/generative/quantizers.py index 445ea45b6..1f41e136a 100644 --- a/src/brevitas_examples/common/generative/quantizers.py +++ b/src/brevitas_examples/common/generative/quantizers.py @@ -7,105 +7,59 @@ from brevitas.core.function_wrapper.shape import OverOutputFeaturesView from brevitas.core.function_wrapper.shape import OverTensorView -from brevitas.core.scaling import ParameterFromStatsFromParameterScaling +from brevitas.core.scaling.runtime import RuntimeDynamicGroupStatsScaling from brevitas.core.stats import AbsMinMax from brevitas.core.stats import NegativeMinOrZero from brevitas.core.stats.stats_wrapper import SCALAR_SHAPE -from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint from brevitas.inject import ExtendedInjector from brevitas.inject import this from brevitas.inject import value +from brevitas.inject.enum import ScalingPerOutputType +from brevitas.proxy.groupwise_float_parameter_quant import \ + GroupwiseWeightFloatQuantProxyFromInjector +from brevitas.proxy.groupwise_float_runtime_quant import GroupwiseActFloatQuantProxyFromInjector +from brevitas.proxy.groupwise_int_parameter_quant import GroupwiseWeightQuantProxyFromInjector +from brevitas.proxy.groupwise_int_runtime_quant import GroupwiseActQuantProxyFromInjector from brevitas.proxy.runtime_quant import DynamicActQuantProxyFromInjector +from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloat from brevitas.quant.scaled_int import Int8ActPerTensorFloat -from brevitas.quant.scaled_int import Int8ActPerTensorFloatMSE from brevitas.quant.scaled_int import Int8WeightPerChannelFloat from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloat -from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloatMSE +from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloat from .quant_blocks import * -class WeightSymmetricGroupQuantMixin(ExtendedInjector): - - @value - def expanded_scaling_shape(module, group_size): - if isinstance(module, nn.Conv2d): - return module.weight.size(0), module.weight.size(1) // group_size, group_size, module.weight.size(2), module.weight.size(3) - elif isinstance(module, nn.Linear): - return module.weight.size(0), module.weight.size(1) // group_size, group_size - elif isinstance(module, nn.Embedding): - return module.weight.size(0), module.weight.size(1) // group_size, group_size - else: - raise RuntimeError("Module not supported.") - - @value - def scaling_shape(module, group_size): - if isinstance(module, nn.Conv2d): - return module.weight.size(0), module.weight.size(1) // group_size, 1, module.weight.size(2), module.weight.size(3) - elif isinstance(module, nn.Linear): - return module.weight.size(0), module.weight.size(1) // group_size, 1 - elif isinstance(module, nn.Embedding): - return module.weight.size(0), module.weight.size(1) // group_size, 1 - else: - raise RuntimeError("Module not supported.") - - @value - def reshaped_scaling_shape(module): - return module.weight.shape - - scaling_input_shape = this.expanded_scaling_shape - scaling_stats_input_view_shape_impl = OverSubChannelBlockView - scaling_impl = ExpandReshapeScalingWrapper - # scale is converted to a parameter right away - wrapped_scaling_impl = ParameterFromStatsFromParameterScaling - keepdim = True - stats_reduce_dim = 2 - # Set bit_width and block size externally - bit_width = None - group_size = None - - class DynamicActProxyMixin(ExtendedInjector): proxy_class = DynamicActQuantProxyFromInjector -class IntWeightSymmetricGroupQuant(WeightSymmetricGroupQuantMixin, Int8WeightPerChannelFloat): +class IntWeightSymmetricGroupQuant(Int8WeightPerChannelFloat): """ Block / group / vector signed symmetric int weight quantizer with float scales. We inherit from a per-channel quantizer to re-use some underlying machinery. """ - pass + proxy_class = GroupwiseWeightQuantProxyFromInjector + scaling_per_output_type = ScalingPerOutputType.GROUP -class Fp8e4m3WeightSymmetricGroupQuant(WeightSymmetricGroupQuantMixin, - Fp8e4m3WeightPerChannelFloat): +class Fp8e4m3WeightSymmetricGroupQuant(Fp8e4m3WeightPerChannelFloat): """ Block / group / vector signed symmetric e4m3 weight quantizer with float scales. We inherit from a per-channel quantizer to re-use some underlying machinery. """ - pass + proxy_class = GroupwiseWeightFloatQuantProxyFromInjector + scaling_per_output_type = ScalingPerOutputType.GROUP -class ShiftedUintWeightAsymmetricGroupQuant(IntWeightSymmetricGroupQuant): +class ShiftedUintWeightAsymmetricGroupQuant(ShiftedUint8WeightPerChannelFloat): """ Block / group / vector signed asymmetric weight quantizer with float scales and zero-points. """ - zero_point_input_shape = this.scaling_input_shape - reshaped_zero_point_shape = this.reshaped_scaling_shape - zero_point_shape = this.scaling_shape - expanded_zero_point_shape = this.expanded_scaling_shape - zero_point_stats_input_view_shape_impl = this.scaling_stats_input_view_shape_impl - zero_point_stats_input_concat_dim = 0 - zero_point_impl = ExpandReshapeZeroPointWrapper - zero_point_stats_impl = NegativeMinOrZero - scaling_stats_impl = AbsMinMax - keepdim = True - # zero-point is converted to a parameter right away - wrapped_zero_point_impl = ParameterFromStatsFromParameterZeroPoint - quantize_zero_point = False - signed = False + proxy_class = GroupwiseWeightQuantProxyFromInjector + scaling_per_output_type = ScalingPerOutputType.GROUP class Int8DynamicActPerTensorFloat(DynamicActProxyMixin, Int8ActPerTensorFloat): @@ -132,18 +86,10 @@ class Int8DynamicActPerGroupFloat(DynamicActProxyMixin, Int8ActPerTensorFloat): """ Symmetric quantizer with per group scale. """ + proxy_class = GroupwiseActQuantProxyFromInjector scaling_impl = RuntimeDynamicGroupStatsScaling - keepdim = True scaling_stats_op = 'min_max' - scaling_per_output_channel = True - - @value - def stats_reduce_dim(group_dim): - # If group_dim = -1, we need a workaround to avoid selecting wrong dim - if group_dim == -1: - return -1 - else: - return group_dim + 1 + scaling_per_output_type = ScalingPerOutputType.GROUP class ShiftedUint8DynamicActPerTensorFloat(DynamicActProxyMixin, ShiftedUint8ActPerTensorFloat): @@ -170,11 +116,11 @@ class ShiftedUint8DynamicActPerRowFloat(DynamicActProxyMixin, ShiftedUint8ActPer zero_point_stats_impl = NegativeMinOrZero -class Fp8e4m3DynamicOCPActPerTensorFloat(DynamicActProxyMixin, Fp8e4m3OCPActPerTensorFloat): +class Fp8e4m3DynamicActPerGroupFloat(DynamicActProxyMixin, Fp8e4m3ActPerTensorFloat): """ - Symmetric quantizer with per tensor dynamic scale. + Symmetric quantizer with per group scale. """ - scaling_impl = RuntimeDynamicStatsScaling - scaling_stats_input_view_shape_impl = OverTensorView + proxy_class = GroupwiseActFloatQuantProxyFromInjector + scaling_impl = RuntimeDynamicGroupStatsScaling + scaling_per_output_type = ScalingPerOutputType.GROUP scaling_stats_op = 'min_max' - dynamic_scaling_broadcastable_fn = lambda x, shape: x.view(SCALAR_SHAPE) diff --git a/src/brevitas_examples/common/parse_utils.py b/src/brevitas_examples/common/parse_utils.py index 0b13b69b8..186fdca60 100644 --- a/src/brevitas_examples/common/parse_utils.py +++ b/src/brevitas_examples/common/parse_utils.py @@ -13,7 +13,7 @@ def __init__(self, pattern): self._pattern = re.compile(pattern) def __call__(self, value): - if not self._pattern.match(value): + if not self._pattern.findall(value): raise argparse.ArgumentTypeError( "Argument has to match '{}'".format(self._pattern.pattern)) return value diff --git a/src/brevitas_examples/llm/README.md b/src/brevitas_examples/llm/README.md index 315e3b674..cdf708d17 100644 --- a/src/brevitas_examples/llm/README.md +++ b/src/brevitas_examples/llm/README.md @@ -14,50 +14,113 @@ Set the env variable BREVITAS_JIT=1 to speed up the quantization process. Currently unsupported whenever export is also toggled or with MSE based scales/zero-points. ```bash -usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES] [--seqlen SEQLEN] [--eval] [--weight-bit-width WEIGHT_BIT_WIDTH] [--weight-param-method {stats,mse}] - [--weight-scale-type {float32,po2}] [--weight-quant-type {sym,asym}] [--weight-quant-granularity {per_channel,per_tensor,per_group}] - [--weight-group-size WEIGHT_GROUP_SIZE] [--quantize-weight-zero-point] [--input-bit-width INPUT_BIT_WIDTH] [--input-param-method {stats,mse}] - [--input-scale-type {float32,po2}] [--input-quant-type {sym,asym}] [--input-quant-granularity {per_tensor}] [--quantize-input-zero-point] [--gptq] - [--act-calibration] [--bias-corr] [--act-equalization] +usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES] + [--seqlen SEQLEN] [--eval] [--dataset {wikitext2,c4}] + [--weight-bit-width WEIGHT_BIT_WIDTH] + [--weight-param-method {stats,mse}] + [--weight-scale-precision {float_scale,po2_scale}] + [--weight-quant-type {sym,asym}] + [--weight-quant-format WEIGHT_QUANT_FORMAT] + [--weight-quant-granularity {per_channel,per_tensor,per_group}] + [--weight-group-size WEIGHT_GROUP_SIZE] + [--quantize-weight-zero-point] + [--input-bit-width INPUT_BIT_WIDTH] + [--input-quant-format INPUT_QUANT_FORMAT] + [--input-param-method {stats,mse}] + [--input-scale-precision {float_scale,po2_scale}] + [--input-scale-type {static,dynamic,no_scale}] + [--input-quant-type {sym,asym}] + [--input-quant-granularity {per_tensor,per_row,per_group}] + [--input-group-size INPUT_GROUP_SIZE] + [--quantize-input-zero-point] [--quantize-last-layer] [--gptq] + [--act-calibration] [--bias-corr] [--ln-affine-merge] + [--no-quantize] [--no-float16] [--replace-mha] + [--weight-equalization] + [--act-equalization {None,layerwise,fx}] [--load-awq LOAD_AWQ] [--export-target {None,onnx_qcdq,torch_qcdq,sharded_torchmlir_group_weight,sharded_packed_torchmlir_group_weight}] + [--checkpoint-name CHECKPOINT_NAME] -optional arguments: +options: -h, --help show this help message and exit --model MODEL HF model name. Default: facebook/opt-125m. --seed SEED Seed for sampling the calibration data. Default: 0. --nsamples NSAMPLES Number of calibration data samples. Default: 128. --seqlen SEQLEN Sequence length. Default: 2048. - --eval Eval model PPL on C4. + --eval Eval model PPL on the chosen Dataset. + --dataset {wikitext2,c4} + Dataset to use for quantization (default: wikitext2) --weight-bit-width WEIGHT_BIT_WIDTH Weight bit width. Default: 8. --weight-param-method {stats,mse} How scales/zero-point are determined. Default: stats. - --weight-scale-type {float32,po2} + --weight-scale-precision {float_scale,po2_scale} Whether scale is a float value or a po2. Default: po2. --weight-quant-type {sym,asym} Weight quantization type. Default: asym. + --weight-quant-format WEIGHT_QUANT_FORMAT + Weight quantization type. Either int or eXmY, with + X+Y==weight_bit_width-1. It's possible to add + float_ocp_ or float_fnuz_ before the exponent/mantissa + bitwidth. Default: int. --weight-quant-granularity {per_channel,per_tensor,per_group} - Granularity for scales/zero-point of weights. Default: per_group. + Granularity for scales/zero-point of weights. Default: + per_group. --weight-group-size WEIGHT_GROUP_SIZE - Group size for per_group weight quantization. Default: 128. + Group size for per_group weight quantization. Default: + 128. --quantize-weight-zero-point Quantize weight zero-point. --input-bit-width INPUT_BIT_WIDTH - Input bit width. Default: None (disables input quantization). + Input bit width. Default: None (disables input + quantization). + --input-quant-format INPUT_QUANT_FORMAT + Input quantization type. Either int or eXmY, with + X+Y==weight_bit_width-1. It's possible to add + float_ocp_ or float_fnuz_ before the exponent/mantissa + bitwidth. Default: int. --input-param-method {stats,mse} - How scales/zero-point are determined. Default: stats. - --input-scale-type {float32,po2} - Whether input scale is a float value or a po2. Default: float32. + How scales/zero-point are determined. Default: stats + (percentile for static, absmax or minmax for dynamic). + --input-scale-precision {float_scale,po2_scale} + Whether input scale is a float value or a po2. + Default: float. + --input-scale-type {static,dynamic,no_scale} + Whether input scale is a static value or a dynamic + value. --input-quant-type {sym,asym} Input quantization type. Default: asym. - --input-quant-granularity {per_tensor} - Granularity for scales/zero-point of inputs. Default: per_tensor. + --input-quant-granularity {per_tensor,per_row,per_group} + Granularity for scales/zero-point of inputs. Default: + per_tensor. + --input-group-size INPUT_GROUP_SIZE + Group size for per_group input quantization. Default: + 64. --quantize-input-zero-point Quantize input zero-point. + --quantize-last-layer + Quantize last nn.Linear layer. --gptq Apply GPTQ. --act-calibration Apply activation calibration. --bias-corr Apply bias correction. - --act-equalization Apply activation equalization (SmoothQuant). + --ln-affine-merge Merge LN affine params. + --no-quantize Disable quantization. + --no-float16 Disable float16 as base datatype and switch to + float32. + --replace-mha Replace HuggingFace Attention with a quantizable + version + --weight-equalization + Apply weight equalization. Relevant to ReLU based + models (e.g. OPT). + --act-equalization {None,layerwise,fx} + Apply activation equalization (SmoothQuant). Layerwise + introduces standalone mul nodes,while fx merges them + whenever possible into previous tensors, which is + possible on ReLU based models (e.g. OPT). + --load-awq LOAD_AWQ Load the awq search results. --export-target {None,onnx_qcdq,torch_qcdq,sharded_torchmlir_group_weight,sharded_packed_torchmlir_group_weight} Model export. + --checkpoint-name CHECKPOINT_NAME + Filename to save checkpoint. If `None`, no checkpoint + is saved (default: None) + ``` diff --git a/src/brevitas_examples/llm/llm_quant/eval.py b/src/brevitas_examples/llm/llm_quant/eval.py index 0691e5cfa..27ef16c97 100644 --- a/src/brevitas_examples/llm/llm_quant/eval.py +++ b/src/brevitas_examples/llm/llm_quant/eval.py @@ -41,6 +41,7 @@ def model_eval(model, valenc, seqlen): shift_logits = lm_logits[:, :-1, :].contiguous() dev = shift_logits.device shift_labels = inps['input_ids'][:, 1:].to(dev) + shift_logits = shift_logits.to(dev) loss_fct = nn.CrossEntropyLoss() loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) neg_log_likelihood = loss.float() * seqlen diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index b8c19a9d7..05d84f647 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -81,7 +81,8 @@ '--weight-quant-format', type=quant_format_validator, default='int', - help='Weight quantization type. Either int or eXmY, with X+Y==weight_bit_width-1. Default: int.' + help= + 'Weight quantization type. Either int or eXmY, with X+Y==weight_bit_width-1. It\'s possible to add float_ocp_ or float_fnuz_ before the exponent/mantissa bitwidth. Default: int.' ) parser.add_argument( '--weight-quant-granularity', @@ -105,7 +106,9 @@ '--input-quant-format', type=quant_format_validator, default='int', - help='Input quantization type. Either int or eXmY, with X+Y==weight_bit_width-1. Default: int.') + help= + 'Input quantization type. Either int or eXmY, with X+Y==weight_bit_width-1. It\'s possible to add float_ocp_ or float_fnuz_ before the exponent/mantissa bitwidth. Default: int.' +) parser.add_argument( '--input-param-method', type=str, @@ -187,10 +190,6 @@ type=str, default=None, help="Filename to save checkpoint. If `None`, no checkpoint is saved (default: %(default)s)") -add_bool_arg( - parser, 'use-ocp', default=False, help='Use OCP format for float quantization. Default: False') -add_bool_arg( - parser, 'use-fnuz', default=True, help='Use FNUZ format for float quantization. Default: True') def set_seed(seed): @@ -253,9 +252,7 @@ def validate(args): assert args.quantize_weight_zero_point, "Quantized weight zero point required." if args.input_bit_width is not None and args.input_quant_type == 'asym': assert args.quantize_input_zero_point, "Quantized input zero point required." - if (args.input_bit_width and - (args.input_scale_type == 'static' or - (args.input_scale_type == 'dynamic' and args.input_quant_type == 'asym'))): + if args.input_bit_width and args.input_scale_type == 'static': assert args.act_calibration, "Static input quantization is being applied without activation calibration. Set --act-calibration." if (args.weight_equalization or args.act_equalization == 'fx'): if args.replace_mha: @@ -387,8 +384,6 @@ def main(): input_quant_granularity=args.input_quant_granularity, input_group_size=args.input_group_size, quantize_input_zero_point=args.quantize_input_zero_point, - use_ocp=args.use_ocp, - use_fnuz=args.use_fnuz, device=device) layer_map = generate_quant_maps( linear_input_quant=linear_input_quant, diff --git a/src/brevitas_examples/stable_diffusion/README.md b/src/brevitas_examples/stable_diffusion/README.md index a51a06df5..1cc5374e8 100644 --- a/src/brevitas_examples/stable_diffusion/README.md +++ b/src/brevitas_examples/stable_diffusion/README.md @@ -97,17 +97,15 @@ usage: main.py [-h] [-m MODEL] [-d DEVICE] [-b BATCH_SIZE] [--prompt PROMPT] [--quantize-input-zero-point | --no-quantize-input-zero-point] [--export-cpu-float32 | --no-export-cpu-float32] [--use-mlperf-inference | --no-use-mlperf-inference] - [--use-ocp | --no-use-ocp] [--use-fnuz | --no-use-fnuz] [--use-negative-prompts | --no-use-negative-prompts] - [--dry-run | --no-dry-run] - [--quantize-sdp-1 | --no-quantize-sdp-1] - [--quantize-sdp-2 | --no-quantize-sdp-2] + [--dry-run | --no-dry-run] [--quantize-sdp | --no-quantize-sdp] [--override-conv-quant-config | --no-override-conv-quant-config] [--vae-fp16-fix | --no-vae-fp16-fix] + [--share-qkv-quant | --no-share-qkv-quant] Stable Diffusion quantization -optional arguments: +options: -h, --help show this help message and exit -m MODEL, --model MODEL Path or name of the model. @@ -203,10 +201,14 @@ optional arguments: Input quantization type. Default: asym. --weight-quant-format WEIGHT_QUANT_FORMAT Weight quantization type. Either int or eXmY, with - X+Y==weight_bit_width-1. Default: int. + X+Y==weight_bit_width-1. It's possible to add + float_ocp_ or float_fnuz_ before the exponent/mantissa + bitwidth. Default: int. --input-quant-format INPUT_QUANT_FORMAT Input quantization type. Either int or eXmY, with - X+Y==input_bit_width-1. Default: int. + X+Y==input_bit_width-1. It's possible to add + float_ocp_ or float_fnuz_ before the exponent/mantissa + bitwidth. Default: int. --weight-quant-granularity {per_channel,per_tensor,per_group} Granularity for scales/zero-point of weights. Default: per_channel. @@ -242,14 +244,6 @@ optional arguments: --no-use-mlperf-inference Disable Evaluate FID score with MLPerf pipeline. Default: False - --use-ocp Enable Use OCP format for float quantization. Default: - True - --no-use-ocp Disable Use OCP format for float quantization. - Default: True - --use-fnuz Enable Use FNUZ format for float quantization. - Default: True - --no-use-fnuz Disable Use FNUZ format for float quantization. - Default: True --use-negative-prompts Enable Use negative prompts during generation/calibration. Default: Enabled @@ -260,10 +254,8 @@ optional arguments: calibration. Default: Disabled --no-dry-run Disable Generate a quantized model without any calibration. Default: Disabled - --quantize-sdp-1 Enable Quantize SDP. Default: Disabled - --no-quantize-sdp-1 Disable Quantize SDP. Default: Disabled - --quantize-sdp-2 Enable Quantize SDP. Default: Disabled - --no-quantize-sdp-2 Disable Quantize SDP. Default: Disabled + --quantize-sdp Enable Quantize SDP. Default: Disabled + --no-quantize-sdp Disable Quantize SDP. Default: Disabled --override-conv-quant-config Enable Quantize Convolutions in the same way as SDP (i.e., FP8). Default: Disabled @@ -274,4 +266,7 @@ optional arguments: Default: Disabled --no-vae-fp16-fix Disable Rescale the VAE to not go NaN with FP16. Default: Disabled + --share-qkv-quant Enable Share QKV/KV quantization. Default: Disabled + --no-share-qkv-quant Disable Share QKV/KV quantization. Default: Disabled + ``` diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index d370a67df..42d874b4f 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -336,8 +336,6 @@ def input_zp_stats_type(): input_param_method=args.input_param_method, input_quant_type=args.input_quant_type, input_quant_granularity=args.input_quant_granularity, - use_ocp=args.use_ocp, - use_fnuz=args.use_fnuz, input_kwargs=input_kwargs) layer_map = generate_quant_maps( @@ -366,7 +364,7 @@ def input_zp_stats_type(): dtype=dtype, device=args.device, weight_bit_width=weight_bit_width, - weight_quant_format='e4m3', + weight_quant_format='float_ocp_e4m3', weight_quant_type='sym', weight_param_method=args.weight_param_method, weight_scale_precision=args.weight_scale_precision, @@ -375,14 +373,12 @@ def input_zp_stats_type(): quantize_weight_zero_point=args.quantize_weight_zero_point, quantize_input_zero_point=args.quantize_input_zero_point, input_bit_width=args.linear_output_bit_width, - input_quant_format='e4m3', + input_quant_format='float_ocp_e4m3', input_scale_type=args.input_scale_type, input_scale_precision=args.input_scale_precision, input_param_method=args.input_param_method, input_quant_type='sym', input_quant_granularity=args.input_quant_granularity, - use_ocp=args.use_ocp, - use_fnuz=args.use_fnuz, input_kwargs=input_kwargs) # We generate all quantizers, but we are only interested in activation quantization for # the output of softmax and the output of QKV @@ -763,13 +759,15 @@ def input_zp_stats_type(): type=quant_format_validator, default='int', help= - 'Weight quantization type. Either int or eXmY, with X+Y==weight_bit_width-1. Default: int.') + 'Weight quantization type. Either int or eXmY, with X+Y==weight_bit_width-1. It\'s possible to add float_ocp_ or float_fnuz_ before the exponent/mantissa bitwidth. Default: int.' + ) parser.add_argument( '--input-quant-format', type=quant_format_validator, default='int', help= - 'Input quantization type. Either int or eXmY, with X+Y==input_bit_width-1. Default: int.') + 'Input quantization type. Either int or eXmY, with X+Y==input_bit_width-1. It\'s possible to add float_ocp_ or float_fnuz_ before the exponent/mantissa bitwidth. Default: int.' + ) parser.add_argument( '--weight-quant-granularity', type=str, @@ -815,16 +813,6 @@ def input_zp_stats_type(): 'use-mlperf-inference', default=False, help='Evaluate FID score with MLPerf pipeline. Default: False') - add_bool_arg( - parser, - 'use-ocp', - default=False, - help='Use OCP format for float quantization. Default: True') - add_bool_arg( - parser, - 'use-fnuz', - default=True, - help='Use FNUZ format for float quantization. Default: True') add_bool_arg( parser, 'use-negative-prompts',