diff --git a/src/brevitas/export/onnx/standard/function.py b/src/brevitas/export/onnx/standard/function.py
index 1a1de5b15..30e7cf3f7 100644
--- a/src/brevitas/export/onnx/standard/function.py
+++ b/src/brevitas/export/onnx/standard/function.py
@@ -1,10 +1,40 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
+import torch
 from torch.autograd import Function
+from torch.onnx.symbolic_helper import _get_tensor_sizes
 
 from brevitas.export.onnx import onnx_export_opset
 
+
+class MatMulNBitsFn(Function):
+
+    @staticmethod
+    def symbolic(g, x, int_weights, scales, zero_points, K, N, bits, block_size):
+        ret = g.op(
+            'com.microsoft::MatMulNBits',
+            x,
+            int_weights,
+            scales,
+            zero_points,
+            K_i=K,
+            N_i=N,
+            bits_i=bits,
+            block_size_i=block_size)
+        output_size = _get_tensor_sizes(x)
+        output_size[-1] = N
+        ret.setType(x.type().with_sizes(output_size))
+        return ret
+
+    @staticmethod
+    def forward(ctx, x, int_weights, scales, zero_points, K, N, bits, block_size):
+        shape = x.shape
+        out_shape = list(shape)
+        out_shape[-1] = N
+        return torch.empty(out_shape)
+
+
 AXIS_OPSET = 13
diff --git a/src/brevitas_examples/llm/llm_quant/export.py b/src/brevitas_examples/llm/llm_quant/export.py
index 2d99f17e9..068e75266 100644
--- a/src/brevitas_examples/llm/llm_quant/export.py
+++ b/src/brevitas_examples/llm/llm_quant/export.py
@@ -6,12 +6,12 @@
 from abc import ABC
 from abc import abstractmethod
 from contextlib import contextmanager
-import math
 import warnings
 
 import numpy as np
 import torch
 from torch.nn import Module
+from torch.onnx import register_custom_op_symbolic
 
 from brevitas.export.common.handler.base import BaseHandler
 from brevitas.export.manager import _set_layer_export_handler
@@ -19,6 +19,8 @@
 from brevitas.export.manager import _set_proxy_export_handler
 from brevitas.export.manager import _set_proxy_export_mode
 from brevitas.export.manager import BaseManager
+from brevitas.export.onnx.handler import ONNXBaseHandler
+from brevitas.export.onnx.standard.function import MatMulNBitsFn
 from brevitas.nn import QuantLinear
 from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector
 
@@ -47,8 +49,11 @@ def export_scale(self, proxy_module, bit_width):
         scaling_impl = self.scaling_impl(proxy_module)
         int_scaling_impl = proxy_module.tensor_quant.int_scaling_impl
         int_threshold = int_scaling_impl(bit_width)
-        threshold = scaling_impl.wrapped_scaling_impl.stats_scaling_impl(
-            scaling_impl.wrapped_scaling_impl.parameter_list_stats())
+        if hasattr(scaling_impl, 'wrapped_scaling_impl'):
+            threshold = scaling_impl.wrapped_scaling_impl.stats_scaling_impl(
+                scaling_impl.wrapped_scaling_impl.parameter_list_stats())
+        else:
+            threshold = scaling_impl.stats_scaling_impl(scaling_impl.parameter_list_stats())
         return threshold / int_threshold
 
     def export_zero_point(self, proxy_module, scale, bit_width):
@@ -243,3 +248,124 @@ def replace_call_fn_target(graph_model, src, target):
             node.target = target
     graph_model.graph.lint()
     graph_model.recompile()
+
+
+class ONNXLinearWeightBlockQuantHandlerFwd(ONNXBaseHandler, WeightBlockQuantHandlerBase):
+    handled_layer = QuantLinear
+
+    def __init__(self):
+        super(ONNXLinearWeightBlockQuantHandlerFwd, self).__init__()
+        self.group_size = None
+        register_custom_op_symbolic('::MatMulNBitsFn', MatMulNBitsFn.symbolic, 1)
+
+    def pack_int_weights(self, bit_width, int_weights, zero_point):
+        assert int_weights.dtype in [torch.uint8], "Packing requires (u)int8 input."
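+        # Pack 8 // bit_width quantized values per byte along the input (K) rows,
+        # then regroup into the (N, k_blocks, blob_size) layout consumed by
+        # com.microsoft::MatMulNBits; 8-bit weights are returned unpacked.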
+        zero_point = zero_point.to(torch.uint8).flatten()
+        rows, cols = int_weights.shape
+        block_size = self.group_size
+        blob_size = block_size // 2
+        k_blocks = (rows + block_size - 1) // block_size
+        padded_rows = k_blocks * block_size
+        pad_len = padded_rows - rows
+        if pad_len > 0:
+            int_weights = torch.nn.functional.pad(int_weights, (0, 0, 0, pad_len))
+        if bit_width == 8:
+            return int_weights, zero_point
+        elif bit_width == 4 or bit_width == 2:
+            packed_int_weights = torch.zeros((k_blocks * blob_size, cols),
+                                             device=int_weights.device,
+                                             dtype=torch.uint8)
+            packed_zp = torch.zeros((zero_point.shape[0] + 1) // 2,
+                                    device=int_weights.device,
+                                    dtype=torch.uint8)
+            i = 0
+            for column in range(packed_int_weights.shape[0]):
+                # Compared to the reference below we don't transpose the matrix and we pack into 8b data rather than 32b
+                # https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/05781593c818d4dc8adc2d32c975e83d17d2b9a8/quant/quant_linear.py#L346
+                for j in range(i, i + (8 // bit_width)):
+                    shift_factor = (bit_width * (j - i))
+                    packed_int_weights[column, :] |= int_weights[j, :] << shift_factor
+                i += 8 // bit_width
+            packed_int_weights = packed_int_weights.t()
+            packed_int_weights = packed_int_weights.reshape(-1, k_blocks, blob_size)
+            i = 0
+            for column in range(packed_zp.shape[0]):
+                # Same packing scheme as above, applied to the per-block zero-points
+                for j in range(i, i + (8 // bit_width)):
+                    shift_factor = (bit_width * (j - i))
+                    packed_zp[column] |= zero_point[j] << shift_factor
+                i += 8 // bit_width
+            return packed_int_weights, packed_zp
+        else:
+            raise RuntimeError("Only 4 and 8 bit quantization export is supported at the moment")
+
+        # # pack 3b values into 3 bytes, 5b values into 5 bytes, 6b values into 4 bytes
+        # elif bit_width == 3 or bit_width == 5 or bit_width == 6:
+        #     padding = (int_weights.shape[1] * bit_width) % 8
+        #     if padding > 0:
+        #         warnings.warn(
+        #             f"Weight tensor does not divide by {bit_width}, zero-padding columns by {padding}."
+        #         )
+        #     packed_int_weights = torch.zeros(
+        #         (int_weights.shape[0], (int_weights.shape[1] * bit_width + padding) // 8),
+        #         device=int_weights.device,
+        #         dtype=int_weights.dtype)
+
+        #     def lcm(x, y):
+        #         from fractions import gcd
+        #         return x * y // gcd(x, y)
+
+        #     num_packed_bits = lcm(bit_width, 8)
+        #     num_packed_bytes = num_packed_bits // 8
+        #     num_packed_elems = num_packed_bits // bit_width
+
+        #     i = 0
+        #     for column in range(0, packed_int_weights.shape[1], num_packed_bytes):
+        #         # cast to uint8 since it's the only dtype supported by unpackbits
+        #         # the bit-wise representation of int8 values isn't affected
+        #         bits_to_unpack = int_weights[:, i:i + num_packed_elems].numpy().astype(np.uint8)
+        #         unpacked_bits = np.unpackbits(bits_to_unpack, axis=1)
+        #         unpacked_bits = unpacked_bits.reshape(unpacked_bits.shape[0], -1, 8)
+        #         unpacked_bits = unpacked_bits[:, :, -bit_width:]
+        #         unpacked_bits = unpacked_bits.reshape(unpacked_bits.shape[0], -1)
+        #         packed_bits = np.packbits(unpacked_bits, axis=1)
+        #         packed_int_weights[:, column:column +
+        #                            num_packed_bytes] |= torch.from_numpy(packed_bits)
+        #         i += num_packed_elems
+        #     return packed_int_weights
+        # else:
+        #     raise ValueError(f"Bit width {bit_width} not supported.")
+
+    def prepare_for_export(self, module):
+        self.bit_width = self.bit_width_impl(module.weight_quant)()
+        assert self.bit_width <= 8., "Only 8b or lower is supported."
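+        # Cache scale, zero-point, group size and the packed integer weights here,
+        # so symbolic_execution only has to emit the MatMulNBits node at export time.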
+        quant_weight = module.quant_weight()
+        self.bias = module.bias
+        self.scale = self.export_scale(module.weight_quant, self.bit_width)
+        if (quant_weight.zero_point != 0.).any():
+            self.zero_point = self.export_zero_point(
+                module.weight_quant, self.scale, self.bit_width)
+        else:
+            # if there is no zero-point, export zeroes in the shape of scale
+            self.zero_point = torch.zeros_like(self.scale)
+        self.group_size = module.weight_quant.quant_injector.block_size
+        self.bit_width = int(self.bit_width.cpu().item())
+        self.int_weight, self.zero_point = self.pack_int_weights(
+            self.bit_width, quant_weight.int().t().detach(), self.zero_point)
+        self.weight_shape = module.weight.shape
+
+    def symbolic_execution(self, x):
+        int_weights = self.int_weight
+        scale = self.scale
+        bit_width = self.bit_width
+        N, K = self.weight_shape
+        out = MatMulNBitsFn.apply(
+            x, int_weights, scale.flatten(), self.zero_point, K, N, bit_width, self.group_size)
+        return out
+
+
+def export_packed_onnx(model, input, export_path):
+    export_class = block_quant_layer_level_manager(
+        export_handlers=[ONNXLinearWeightBlockQuantHandlerFwd])
+
+    with torch.inference_mode(), brevitas_layer_export_mode(model, export_class):
+        torch.onnx.export(model, input, export_path)
diff --git a/src/brevitas_examples/llm/llm_quant/quantize.py b/src/brevitas_examples/llm/llm_quant/quantize.py
index 647dd35a2..1a61711c2 100644
--- a/src/brevitas_examples/llm/llm_quant/quantize.py
+++ b/src/brevitas_examples/llm/llm_quant/quantize.py
@@ -156,7 +156,7 @@ def quantize_model(
         weight_quant_format = 'float'
     else:
         weight_float_format = {}
-    if re.compile(r'e[1-8]m[1-8]').match(input_quant_format):
+    if input_quant_format is not None and re.compile(r'e[1-8]m[1-8]').match(input_quant_format):
         input_float_format = {
             'exponent_bit_width': int(input_quant_format[1]),
             'mantissa_bit_width': int(input_quant_format[3])}
diff --git a/src/brevitas_examples/llm/llm_quant/sharded_mlir_group_export.py b/src/brevitas_examples/llm/llm_quant/sharded_mlir_group_export.py
index a234c86d0..471359e22 100644
--- a/src/brevitas_examples/llm/llm_quant/sharded_mlir_group_export.py
+++ b/src/brevitas_examples/llm/llm_quant/sharded_mlir_group_export.py
@@ -40,6 +40,7 @@
 from pathlib import Path
 import re
 from typing import List
+import warnings
 
 import torch
 from torch._decomp import get_decompositions
@@ -49,6 +50,9 @@
 from brevitas.backport.fx._symbolic_trace import wrap
 from brevitas.backport.fx.experimental.proxy_tensor import make_fx
+from brevitas.export.onnx.handler import ONNXBaseHandler
+from brevitas.export.onnx.standard.function import MatMulNBitsFn
+from brevitas.nn.quant_linear import QuantLinear
 from brevitas_examples.llm.llm_quant.export import block_quant_layer_level_manager
 from brevitas_examples.llm.llm_quant.export import BlockQuantProxyLevelManager
 from brevitas_examples.llm.llm_quant.export import brevitas_layer_export_mode
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index 5f640dcda..f46facf55 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -18,6 +18,7 @@
 from brevitas_examples.llm.llm_quant.equalize import apply_act_equalization
 from brevitas_examples.llm.llm_quant.equalize import apply_weight_equalization
 from brevitas_examples.llm.llm_quant.eval import model_eval
+from brevitas_examples.llm.llm_quant.export import export_packed_onnx
 from brevitas_examples.llm.llm_quant.gptq import apply_gptq
 from brevitas_examples.llm.llm_quant.ln_affine_merge import apply_layernorm_affine_merge
 from brevitas_examples.llm.llm_quant.prepare_for_quantize import replace_mha_with_quantizable_layers
@@ -169,6 +170,7 @@ def __call__(self, value):
     choices=[
         None,
         'onnx_qcdq',
+        'packed_onnx',
         'torch_qcdq',
         'sharded_torchmlir_group_weight',
         'sharded_packed_torchmlir_group_weight'],
@@ -189,6 +191,8 @@ def model_export(model, ref_input, args):
         from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import \
             sharded_weight_group_export
         sharded_weight_group_export(model, no_custom_packed_export=False)
+    elif args.export_target == 'packed_onnx':
+        export_packed_onnx(model, ref_input, export_path=f"{args.model.replace('/', '-')}.onnx")
     elif args.export_target == 'onnx_qcdq':
         export_onnx_qcdq(model, ref_input, export_path=f"{args.model.replace('/', '-')}.onnx")
     elif args.export_target == 'torch_qcdq':