Feat (llm): export to MatMulNBits
Giuseppe5 committed Oct 20, 2023
1 parent b051309 commit 7cd9d17
Showing 5 changed files with 168 additions and 4 deletions.
30 changes: 30 additions & 0 deletions src/brevitas/export/onnx/standard/function.py
@@ -1,10 +1,40 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

import torch
from torch.autograd import Function
from torch.onnx.symbolic_helper import _get_tensor_sizes

from brevitas.export.onnx import onnx_export_opset


class MatMulNBitsFn(Function):

    @staticmethod
    def symbolic(g, x, int_weights, scales, zero_points, K, N, bits, block_size):
        ret = g.op(
            'com.microsoft::MatMulNBits',
            x,
            int_weights,
            scales,
            zero_points,
            K_i=K,
            N_i=N,
            bits_i=bits,
            block_size_i=block_size)
        output_size = _get_tensor_sizes(x)
        output_size[-1] = N
        ret.setType(x.type().with_sizes(output_size))
        return ret

    @staticmethod
    def forward(g, x, int_weights, scales, zero_points, K, N, bits, block_size):
        shape = x.shape
        out_shape = list(shape)
        out_shape[-1] = N
        return torch.empty(out_shape)


AXIS_OPSET = 13


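The symbolic above maps a blockwise-quantized matmul onto the com.microsoft::MatMulNBits contrib op, while the forward only materializes an empty tensor of the correct output shape so tracing can proceed. Below is a minimal sketch of how the function is meant to be traced; TinyPackedLinear, the buffer shapes and the output file name are illustrative assumptions rather than code from this commit, and the sketch relies on the exporter picking up the Function's symbolic staticmethod (export.py additionally registers it via register_custom_op_symbolic, and behavior can vary across PyTorch versions).

# Minimal sketch (not from this commit): a toy module holding already-packed
# buffers and forwarding them to MatMulNBitsFn.apply, so torch.onnx.export
# emits a com.microsoft::MatMulNBits node with K_i, N_i, bits_i, block_size_i.
import torch

from brevitas.export.onnx.standard.function import MatMulNBitsFn

K, N, bits, block_size = 64, 32, 4, 32
k_blocks = (K + block_size - 1) // block_size  # weight blocks along the K dim
blob_size = block_size // 2                    # bytes per block at 4 bits


class TinyPackedLinear(torch.nn.Module):  # hypothetical module, for illustration only

    def __init__(self):
        super().__init__()
        self.register_buffer('qweight', torch.zeros(N, k_blocks, blob_size, dtype=torch.uint8))
        self.register_buffer('scales', torch.ones(N * k_blocks))
        self.register_buffer('zeros', torch.zeros((N * k_blocks + 1) // 2, dtype=torch.uint8))

    def forward(self, x):
        return MatMulNBitsFn.apply(
            x, self.qweight, self.scales, self.zeros, K, N, bits, block_size)


torch.onnx.export(TinyPackedLinear(), torch.randn(1, K), 'tiny_matmulnbits.onnx')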
132 changes: 129 additions & 3 deletions src/brevitas_examples/llm/llm_quant/export.py
@@ -6,19 +6,21 @@
from abc import ABC
from abc import abstractmethod
from contextlib import contextmanager
import math
import warnings

import numpy as np
import torch
from torch.nn import Module
from torch.onnx import register_custom_op_symbolic

from brevitas.export.common.handler.base import BaseHandler
from brevitas.export.manager import _set_layer_export_handler
from brevitas.export.manager import _set_layer_export_mode
from brevitas.export.manager import _set_proxy_export_handler
from brevitas.export.manager import _set_proxy_export_mode
from brevitas.export.manager import BaseManager
from brevitas.export.onnx.handler import ONNXBaseHandler
from brevitas.export.onnx.standard.function import MatMulNBitsFn
from brevitas.nn import QuantLinear
from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector

@@ -47,8 +49,11 @@ def export_scale(self, proxy_module, bit_width):
        scaling_impl = self.scaling_impl(proxy_module)
        int_scaling_impl = proxy_module.tensor_quant.int_scaling_impl
        int_threshold = int_scaling_impl(bit_width)
-        threshold = scaling_impl.wrapped_scaling_impl.stats_scaling_impl(
-            scaling_impl.wrapped_scaling_impl.parameter_list_stats())
+        if hasattr(scaling_impl, 'wrapped_scaling_impl'):
+            threshold = scaling_impl.wrapped_scaling_impl.stats_scaling_impl(
+                scaling_impl.wrapped_scaling_impl.parameter_list_stats())
+        else:
+            threshold = scaling_impl.stats_scaling_impl(scaling_impl.parameter_list_stats())
        return threshold / int_threshold

    def export_zero_point(self, proxy_module, scale, bit_width):
@@ -243,3 +248,124 @@ def replace_call_fn_target(graph_model, src, target):
            node.target = target
    graph_model.graph.lint()
    graph_model.recompile()


class ONNXLinearWeightBlockQuantHandlerFwd(ONNXBaseHandler, WeightBlockQuantHandlerBase):
    handled_layer = QuantLinear

    def __init__(self):
        super(ONNXLinearWeightBlockQuantHandlerFwd, self).__init__()
        self.group_size = None
        register_custom_op_symbolic('::MatMulNBitsFn', MatMulNBitsFn.symbolic, 1)

    def pack_int_weights(self, bit_width, int_weights, zero_point):
        assert int_weights.dtype in [torch.uint8], "Packing requires uint8 input."
        zero_point = zero_point.to(torch.uint8).flatten()
        rows, cols = int_weights.shape
        block_size = self.group_size
        blob_size = block_size // 2
        k_blocks = (rows + block_size - 1) // block_size
        padded_rows = k_blocks * block_size
        pad_len = padded_rows - rows
        if pad_len > 0:
            int_weights = torch.nn.functional.pad(int_weights, (0, 0, 0, pad_len))
        if bit_width == 8:
            return int_weights
        elif bit_width == 4 or bit_width == 2:
            packed_int_weights = torch.zeros((k_blocks * blob_size, cols),
                                             device=int_weights.device,
                                             dtype=torch.uint8)
            packed_zp = torch.zeros((zero_point.shape[0] + 1) // 2,
                                    device=int_weights.device,
                                    dtype=torch.uint8)
            i = 0
            for column in range(packed_int_weights.shape[0]):
                # Compared to the reference below we don't transpose the matrix and we pack into 8b data rather than 32b
                # https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/05781593c818d4dc8adc2d32c975e83d17d2b9a8/quant/quant_linear.py#L346
                for j in range(i, i + (8 // bit_width)):
                    shift_factor = (bit_width * (j - i))
                    packed_int_weights[column, :] |= int_weights[j, :] << shift_factor
                i += 8 // bit_width
            packed_int_weights = packed_int_weights.t()
            packed_int_weights = packed_int_weights.reshape(-1, k_blocks, blob_size)
            i = 0
            for column in range(packed_zp.shape[0]):
                # Compared to the reference below we don't transpose the matrix and we pack into 8b data rather than 32b
                # https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/05781593c818d4dc8adc2d32c975e83d17d2b9a8/quant/quant_linear.py#L346
                for j in range(i, i + (8 // bit_width)):
                    shift_factor = (bit_width * (j - i))
                    packed_zp[column] |= zero_point[j] << shift_factor
                i += 8 // bit_width
            return packed_int_weights, packed_zp
        else:
            raise RuntimeError("Only 4 and 8 bit quantization export is supported at the moment")

# # pack 3b values into 3 bytes, 5b values into 5 bytes, 6b values into 4 bytes
# elif bit_width == 3 or bit_width == 5 or bit_width == 6:
# padding = (int_weights.shape[1] * bit_width) % 8
# if padding > 0:
# warnings.warn(
# f"Weight tensor does not divide by {bit_width}, zero-padding columns by {padding}."
# )
# packed_int_weights = torch.zeros(
# (int_weights.shape[0], (int_weights.shape[1] * bit_width + padding) // 8),
# device=int_weights.device,
# dtype=int_weights.dtype)

# def lcm(x, y):
# from fractions import gcd
# return x * y // gcd(x, y)

# num_packed_bits = lcm(bit_width, 8)
# num_packed_bytes = num_packed_bits // 8
# num_packed_elems = num_packed_bits // bit_width

# i = 0
# for column in range(0, packed_int_weights.shape[1], num_packed_bytes):
# # cast to uint8 since it's the only dtype supported by unpackbits
# # the bit-wise representation of int8 values isn't affected
# bits_to_unpack = int_weights[:, i:i + num_packed_elems].numpy().astype(np.uint8)
# unpacked_bits = np.unpackbits(bits_to_unpack, axis=1)
# unpacked_bits = unpacked_bits.reshape(unpacked_bits.shape[0], -1, 8)
# unpacked_bits = unpacked_bits[:, :, -bit_width:]
# unpacked_bits = unpacked_bits.reshape(unpacked_bits.shape[0], -1)
# packed_bits = np.packbits(unpacked_bits, axis=1)
# packed_int_weights[:, column:column +
# num_packed_bytes] |= torch.from_numpy(packed_bits)
# i += num_packed_elems
# return packed_int_weights
# else:
# raise ValueError(f"Bit width {bit_width} not supported.")

    def prepare_for_export(self, module):
        self.bit_width = self.bit_width_impl(module.weight_quant)()
        assert self.bit_width <= 8., "Only 8b or lower is supported."
        quant_weight = module.quant_weight()
        self.bias = module.bias
        self.scale = self.export_scale(module.weight_quant, self.bit_width)
        if (quant_weight.zero_point != 0.).any():
            self.zero_point = self.export_zero_point(
                module.weight_quant, self.scale, self.bit_width)
        else:
            # if there is no zero-point, export zeroes in the shape of scale
            self.zero_point = torch.zeros_like(self.scale)
        self.group_size = module.weight_quant.quant_injector.block_size
        self.bit_width = int(self.bit_width.cpu().item())
        self.int_weight, self.zero_point = self.pack_int_weights(
            self.bit_width, quant_weight.int().t().detach(), self.zero_point)
        self.weight_shape = module.weight.shape

    def symbolic_execution(self, x):
        int_weights = self.int_weight
        scale = self.scale
        bit_width = self.bit_width
        N, K = self.weight_shape
        out = MatMulNBitsFn.apply(
            x, int_weights, scale.flatten(), self.zero_point, K, N, bit_width, self.group_size)
        return out


def export_packed_onnx(model, input, export_path):
    export_class = block_quant_layer_level_manager(
        export_handlers=[ONNXLinearWeightBlockQuantHandlerFwd])
    with torch.inference_mode(), brevitas_layer_export_mode(model, export_class):
        torch.onnx.export(model, input, export_path)
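For reference, the inner loop of pack_int_weights folds 8 // bit_width consecutive rows of the unsigned integer weight tensor into one uint8 row, placing each value at a shift of bit_width * (j - i) bits; the zero-points are packed the same way. A tiny standalone illustration for bit_width == 4 (values chosen arbitrarily, not taken from the commit):

# Standalone illustration of the nibble packing used above for bit_width == 4:
# each output byte holds two consecutive 4-bit weights, the first in the low
# nibble and the second shifted into the high nibble.
import torch

bit_width = 4
int_weights = torch.tensor([[1, 2], [15, 7]], dtype=torch.uint8)  # 2 rows, 2 columns
packed = torch.zeros((1, 2), dtype=torch.uint8)
for j in range(8 // bit_width):  # j = 0, 1
    packed[0, :] |= int_weights[j, :] << (bit_width * j)
print(packed)  # tensor([[241, 114]], dtype=torch.uint8): 241 == 1 | (15 << 4), 114 == 2 | (7 << 4)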
2 changes: 1 addition & 1 deletion src/brevitas_examples/llm/llm_quant/quantize.py
@@ -156,7 +156,7 @@ def quantize_model(
        weight_quant_format = 'float'
    else:
        weight_float_format = {}
-    if re.compile(r'e[1-8]m[1-8]').match(input_quant_format):
+    if input_quant_format is not None and re.compile(r'e[1-8]m[1-8]').match(input_quant_format):
        input_float_format = {
            'exponent_bit_width': int(input_quant_format[1]),
            'mantissa_bit_width': int(input_quant_format[3])}
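The added guard matters because re matching requires a string: when input_quant_format is None (no input quantization format requested), the old condition raised a TypeError before reaching the float-format branch. A quick standalone check:

# Standalone illustration of why the None check is needed before matching.
import re

pattern = re.compile(r'e[1-8]m[1-8]')
print(pattern.match('e4m3'))  # <re.Match object; span=(0, 4), match='e4m3'>
value = None
print(value is not None and pattern.match(value))  # False; match() never runs
# Without the guard: pattern.match(None) -> TypeError: expected string or bytes-like object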
@@ -40,6 +40,7 @@
from pathlib import Path
import re
from typing import List
import warnings

import torch
from torch._decomp import get_decompositions
@@ -49,6 +50,9 @@

from brevitas.backport.fx._symbolic_trace import wrap
from brevitas.backport.fx.experimental.proxy_tensor import make_fx
from brevitas.export.onnx.handler import ONNXBaseHandler
from brevitas.export.onnx.standard.function import MatMulNBitsFn
from brevitas.nn.quant_linear import QuantLinear
from brevitas_examples.llm.llm_quant.export import block_quant_layer_level_manager
from brevitas_examples.llm.llm_quant.export import BlockQuantProxyLevelManager
from brevitas_examples.llm.llm_quant.export import brevitas_layer_export_mode
4 changes: 4 additions & 0 deletions src/brevitas_examples/llm/main.py
@@ -18,6 +18,7 @@
from brevitas_examples.llm.llm_quant.equalize import apply_act_equalization
from brevitas_examples.llm.llm_quant.equalize import apply_weight_equalization
from brevitas_examples.llm.llm_quant.eval import model_eval
from brevitas_examples.llm.llm_quant.export import export_packed_onnx
from brevitas_examples.llm.llm_quant.gptq import apply_gptq
from brevitas_examples.llm.llm_quant.ln_affine_merge import apply_layernorm_affine_merge
from brevitas_examples.llm.llm_quant.prepare_for_quantize import replace_mha_with_quantizable_layers
@@ -169,6 +170,7 @@ def __call__(self, value):
    choices=[
        None,
        'onnx_qcdq',
        'packed_onnx',
        'torch_qcdq',
        'sharded_torchmlir_group_weight',
        'sharded_packed_torchmlir_group_weight'],
@@ -189,6 +191,8 @@ def model_export(model, ref_input, args):
        from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import \
            sharded_weight_group_export
        sharded_weight_group_export(model, no_custom_packed_export=False)
    elif args.export_target == 'packed_onnx':
        export_packed_onnx(model, ref_input, export_path=f"{args.model.replace('/', '-')}.onnx")
    elif args.export_target == 'onnx_qcdq':
        export_onnx_qcdq(model, ref_input, export_path=f"{args.model.replace('/', '-')}.onnx")
    elif args.export_target == 'torch_qcdq':
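With the new packed_onnx target selected, model_export calls export_packed_onnx, which writes an ONNX file containing com.microsoft::MatMulNBits nodes for the quantized linear layers. A quick post-export sanity check is sketched below; it assumes the onnx package is installed, and the file name (derived from args.model in main.py) is illustrative only.

# Sketch of a post-export sanity check (assumes the onnx package is installed;
# the file name is hypothetical and depends on the args.model value).
import onnx

exported = onnx.load('facebook-opt-125m.onnx')  # hypothetical output of export_packed_onnx
ops = {(node.domain, node.op_type) for node in exported.graph.node}
assert ('com.microsoft', 'MatMulNBits') in ops, 'No packed MatMulNBits nodes found'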
