
Commit

Small cleanup
Giuseppe5 committed Oct 24, 2023
1 parent 89795e6 commit 58c22f3
Showing 1 changed file with 32 additions and 69 deletions.
101 changes: 32 additions & 69 deletions src/brevitas_examples/llm/llm_quant/export.py
@@ -259,83 +259,46 @@ def __init__(self):
register_custom_op_symbolic('::MatMulNBitsFn', MatMulNBitsFn.symbolic, 1)

def pack_int_weights(self, bit_width, int_weights, zero_point):
assert int_weights.dtype in [torch.uint8], "Packing requires (u)int8 input."
zero_point = zero_point.to(torch.uint8).flatten()
assert int_weights.dtype in [torch.uint8, torch.int8], "Packing requires (u)int8 input."
assert bit_width == 4, "Only 4 bit quantization export is supported at the moment"

is_symmetric = torch.sum(zero_point) == 0
zero_point = zero_point.to(torch.uint8)
rows, cols = int_weights.shape
block_size = self.group_size
blob_size = block_size // 2
k_blocks = (rows + block_size - 1) // block_size
padded_rows = k_blocks * block_size
pad_len = padded_rows - rows
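# rows are padded up to a multiple of block_size so that every block fills a whole blob of blob_size bytes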

# ONNX operator assumes implicit zp of 8 (largest negative number in Po2)
# If we are in a "symmetric" quantized scenario, we need to add this implicit zero point
# Otherwise it has already been added during the conversion to integer
zp = 0 if not int_weights.dtype == torch.int8 else 8
int_weights += zp
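# e.g. a symmetric int8 value of -3 in [-8, 7] maps to 5 in the unsigned nibble range [0, 15]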
if pad_len > 0:
int_weights = torch.nn.functional.pad(int_weights, (0, 0, 0, pad_len))
if bit_width == 8:
return int_weights
elif bit_width == 4 or bit_width == 2:
packed_int_weights = torch.zeros((k_blocks * blob_size, cols),
device=int_weights.device,
dtype=torch.uint8)
packed_zp = torch.zeros((zero_point.shape[0] + 1) // 2,
device=int_weights.device,
dtype=torch.uint8)
i = 0
for column in range(packed_int_weights.shape[0]):
# Compared to the reference below we don't transpose the matrix and we pack into 8b data rather than 32b
# https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/05781593c818d4dc8adc2d32c975e83d17d2b9a8/quant/quant_linear.py#L346
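# e.g. with bit_width=4 two consecutive rows share one byte: int_weights[j] | (int_weights[j + 1] << 4)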
for j in range(i, i + (8 // bit_width)):
shift_factor = (bit_width * (j - i))
packed_int_weights[column, :] |= int_weights[j, :] << shift_factor
i += 8 // bit_width
packed_int_weights = packed_int_weights.t()
packed_int_weights = packed_int_weights.reshape(-1, k_blocks, blob_size)
i = 0
for column in range(packed_zp.shape[0]):
# Compared to the reference below we don't transpose the matrix and we pack into 8b data rather than 32b
# https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/05781593c818d4dc8adc2d32c975e83d17d2b9a8/quant/quant_linear.py#L346
for j in range(i, i + (8 // bit_width)):
shift_factor = (bit_width * (j - i))
packed_zp[column] |= zero_point[j] << shift_factor
i += 8 // bit_width
return packed_int_weights, packed_zp
else:
raise RuntimeError("Only 4 and 8 bit quantization export is supported at the moment")

# # pack 3b values into 3 bytes, 5b values into 5 bytes, 6b values into 4 bytes
# elif bit_width == 3 or bit_width == 5 or bit_width == 6:
# padding = (int_weights.shape[1] * bit_width) % 8
# if padding > 0:
# warnings.warn(
# f"Weight tensor does not divide by {bit_width}, zero-padding columns by {padding}."
# )
# packed_int_weights = torch.zeros(
# (int_weights.shape[0], (int_weights.shape[1] * bit_width + padding) // 8),
# device=int_weights.device,
# dtype=int_weights.dtype)

# def lcm(x, y):
# from fractions import gcd
# return x * y // gcd(x, y)

# num_packed_bits = lcm(bit_width, 8)
# num_packed_bytes = num_packed_bits // 8
# num_packed_elems = num_packed_bits // bit_width

# i = 0
# for column in range(0, packed_int_weights.shape[1], num_packed_bytes):
# # cast to uint8 since it's the only dtype supported by unpackbits
# # the bit-wise representation of int8 values isn't affected
# bits_to_unpack = int_weights[:, i:i + num_packed_elems].numpy().astype(np.uint8)
# unpacked_bits = np.unpackbits(bits_to_unpack, axis=1)
# unpacked_bits = unpacked_bits.reshape(unpacked_bits.shape[0], -1, 8)
# unpacked_bits = unpacked_bits[:, :, -bit_width:]
# unpacked_bits = unpacked_bits.reshape(unpacked_bits.shape[0], -1)
# packed_bits = np.packbits(unpacked_bits, axis=1)
# packed_int_weights[:, column:column +
# num_packed_bytes] |= torch.from_numpy(packed_bits)
# i += num_packed_elems
# return packed_int_weights
# else:
# raise ValueError(f"Bit width {bit_width} not supported.")
packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8")
rows, cols = int_weights.shape
int_weights = int_weights.t()
for n in range(cols):
for k_id in range(0, rows, block_size):
blk_int0 = (int_weights[n, k_id:k_id + block_size:2].numpy()).astype("uint8")
blk_int1 = (int_weights[n, k_id + 1:k_id + block_size:2].numpy()).astype("uint8")
packed[n, k_id // block_size] = np.bitwise_or(blk_int0, np.left_shift(blk_int1, 4))
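# packed has shape (cols, k_blocks, blob_size): each byte holds two consecutive 4-bit values along K, low nibble first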

zero_point = zero_point.to(torch.uint8).flatten()
base_zp = 136 if is_symmetric else 0
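# 136 == 0x88, i.e. the implicit zero point of 8 packed into both nibbles of each byte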
packed_zp = base_zp * torch.ones(
(zero_point.shape[0] + 1) // 2, device=int_weights.device, dtype=torch.uint8)

i = 0
for column in range(packed_zp.shape[0]):
for j in range(i, i + (8 // bit_width)):
shift_factor = (bit_width * (j - i))
packed_zp[column] |= zero_point[j] << shift_factor
i += 8 // bit_width
return torch.tensor(packed), packed_zp

def prepare_for_export(self, module):
self.bit_width = self.bit_width_impl(module.weight_quant)()
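
Note: the following is not part of the commit. It is a minimal round-trip sketch of the nibble packing introduced above, assuming group_size = 4 and a toy 8x1 weight matrix already converted to unsigned 4-bit values; the names w, wt and unpacked are illustrative only.

import numpy as np
import torch

group_size = 4  # assumed toy block size; the exporter uses self.group_size
w = torch.tensor([[1], [2], [3], [4], [5], [6], [7], [0]], dtype=torch.uint8)  # (rows, cols) of 4-bit values
rows, cols = w.shape
k_blocks = (rows + group_size - 1) // group_size
blob_size = group_size // 2

packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8")
wt = w.t()
for n in range(cols):
    for k_id in range(0, rows, group_size):
        even = wt[n, k_id:k_id + group_size:2].numpy().astype("uint8")
        odd = wt[n, k_id + 1:k_id + group_size:2].numpy().astype("uint8")
        packed[n, k_id // group_size] = np.bitwise_or(even, np.left_shift(odd, 4))

# unpack: the low nibble holds the even-indexed value, the high nibble the odd-indexed one
unpacked = np.zeros((cols, rows), dtype="uint8")
unpacked[:, 0::2] = (packed & 0x0F).reshape(cols, -1)
unpacked[:, 1::2] = (packed >> 4).reshape(cols, -1)
assert np.array_equal(unpacked, w.t().numpy())  # round-trip recovers the original values

Packing two 4-bit values per byte halves the weight storage and matches the blob_size = block_size // 2 bookkeeping in the exported handler.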
