Use faster operations on packed-quantized, add tests #211

Open: wants to merge 4 commits into main. Changes from all commits are shown below.

@@ -136,8 +136,20 @@ def pack_to_int32(value: torch.Tensor, num_bits: int) -> torch.Tensor:
"""
Packs a tensor of quantized weights stored in int8 into int32s with padding

+Pseudocode:
+1. Shift by 2^(num_bits - 1) to convert to unsigned, e.g. with num_bits=8:
+   [1, 2] -> [129, 130]
+2. Pad to fill a 32-bit word:
+   [129, 130] -> [129, 130, 0, 0]
+3. Convert to binary and align in order, first element least significant:
+   [129, 130, 0, 0] -> 00000000 00000000 10000010 10000001
+4. Convert the aligned binary to a single number:
+   00000000000000001000001010000001 -> 33409
+5. Convert back to uint32:
+   33409 -> 33409
+
:param value: tensor to pack
-:param num_bits: number of bits used to store underlying data
+:param num_bits: number of bits used to store underlying data, must be at least 1
:returns: packed int32 tensor
"""
if value.dtype is not torch.int8:
@@ -146,19 +158,22 @@ def pack_to_int32(value: torch.Tensor, num_bits: int) -> torch.Tensor:
if num_bits > 8:
raise ValueError("Packing is only supported for less than 8 bits")

+if num_bits < 1:
+raise ValueError(f"num_bits must be at least 1, got {num_bits}")
+
# convert to unsigned for packing
-offset = pow(2, num_bits) // 2
+offset = 1 << (num_bits - 1)
value = (value + offset).to(torch.uint8)
value = value.cpu().numpy().astype(np.uint32)
pack_factor = 32 // num_bits

# pad input tensor and initialize packed output
packed_size = math.ceil(value.shape[1] / pack_factor)
-packed = np.zeros((value.shape[0], packed_size), dtype=np.uint32)
-padding = packed.shape[1] * pack_factor - value.shape[1]
+padding = packed_size * pack_factor - value.shape[1]
value = np.pad(value, pad_width=[(0, 0), (0, padding)], constant_values=0)

# pack values
+packed = np.zeros((value.shape[0], packed_size), dtype=np.uint32)
for i in range(pack_factor):
packed |= value[:, i::pack_factor] << num_bits * i

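A note on the strided OR above: each 32-bit output word collects pack_factor inputs, and group element i is shifted to bit offset num_bits * i, so the first element lands in the least significant bits. Below is a minimal standalone numpy sketch of the same idea, using the toy values from the docstring; it is an illustration, not the library function itself.

import numpy as np

# Toy reproduction of the packing loop: four 8-bit values per uint32 word,
# first value in the least significant byte.
num_bits = 8
pack_factor = 32 // num_bits  # 4 values per 32-bit word

# Already offset to unsigned and padded to a full word: [129, 130, 0, 0]
value = np.array([[129, 130, 0, 0]], dtype=np.uint32)

packed = np.zeros((1, 1), dtype=np.uint32)
for i in range(pack_factor):
    # The strided slice picks every pack_factor-th column; the shift places
    # group element i at bit offset num_bits * i before OR-ing it in.
    packed |= value[:, i::pack_factor] << num_bits * i

assert packed[0, 0] == 33409  # (130 << 8) | 129, as in the docstring example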
@@ -172,7 +187,9 @@ def unpack_from_int32(
) -> torch.Tensor:
"""
Unpacks a tensor of packed int32 weights into individual int8s, maintaining the
-original their bit range
+original bit range.
+
+Returns tensors in int8.

:param value: tensor to unpack
:param num_bits: number of bits to unpack each data point into
@@ -190,7 +207,7 @@
pack_factor = 32 // num_bits

# unpack
-mask = pow(2, num_bits) - 1
+mask = (1 << num_bits) - 1
unpacked = torch.zeros(
(value.shape[0], value.shape[1] * pack_factor),
device=value.device,
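To make the docstring's five pseudocode steps concrete, here is a short hand-worked sketch for num_bits=8. The import path is hypothetical and commented out; only the manual bit arithmetic below it actually runs.

import torch

# Hypothetical import; adjust to wherever these helpers live in the repo:
# from compressed_tensors.utils import pack_to_int32, unpack_from_int32

values = torch.tensor([[1, 2]], dtype=torch.int8)

# Step 1: shift by 2^(num_bits - 1) = 128 to convert to unsigned.
# Widen to int16 first so the addition cannot wrap around in int8.
unsigned = values.to(torch.int16) + (1 << 7)           # [[129, 130]]

# Steps 2-4: pad to a full 32-bit word, then place the values in order
# with the first element in the least significant byte.
word = 0
for i, v in enumerate(unsigned[0].tolist() + [0, 0]):  # [129, 130, 0, 0]
    word |= v << (8 * i)
assert word == 33409  # matches step 4 of the pseudocode

# The round trip with the real functions should then satisfy:
# packed = pack_to_int32(values, num_bits=8)           # tensor([[33409]], int32)
# restored = unpack_from_int32(packed, 8, values.shape)
# assert torch.equal(restored, values)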
tests/test_compressors/quantized_compressors/test_pack_quant.py: 162 additions, 0 deletions
@@ -250,3 +250,165 @@ def test_actorder_reload_match(actorder, tmp_path, mock_per_group_calibration):
assert torch.equal(fake_quant_dummy, reconstructed_dense["dummy.weight"])

shutil.rmtree(tmp_path)


@pytest.mark.parametrize(
"num_bits,values,expected_values",
[
(
4,
torch.tensor([[1]]),
torch.tensor([[9]], dtype=torch.int32),
),
(
8,
torch.tensor([[1]]),
torch.tensor([[129]], dtype=torch.int32),
),
# 0000 0000 0000 0000 1100 1011 1010 1001
(4, torch.tensor([[1, 2, 3, 4]]), torch.tensor([[52137]], dtype=torch.int32)),
# 0111 0110 0101 0100 0011 0010 0001 0000
(
4,
torch.tensor([[-8, -7, -6, -5, -4, -3, -2, -1]]),
torch.tensor([[1985229328]], dtype=torch.int32),
),
# 10000100 10000011 10000010 10000001
(
8,
torch.tensor([[1, 2, 3, 4]]),
torch.tensor([[-2071756159]], dtype=torch.int32),
),
# 00000011 00000010 00000001 00000000
(
8,
torch.tensor([[-128, -127, -126, -125]]),
torch.tensor([[50462976]], dtype=torch.int32),
),
(
4,
torch.tensor([[-8, -7, -6, -5, -4, -3, -2, -1, 1, 2, 3, 4]]),
torch.tensor([[1985229328, 52137]], dtype=torch.int32),
),
(
4,
torch.tensor(
[
[-8, -7, -6, -5, -4, -3, -2, -1, 1, 2, 3, 4, -8, -8, -8, -8],
[1, 2, 3, 4, -8, -8, -8, -8, -8, -7, -6, -5, -4, -3, -2, -1],
]
),
torch.tensor([[1985229328, 52137], [52137, 1985229328]], dtype=torch.int32),
),
(
8,
torch.tensor(
[
[1, 2, 3, 4],
[-128, -127, -126, -125],
]
),
torch.tensor([[-2071756159], [50462976]], dtype=torch.int32),
),
(
8,
torch.tensor(
[
[1, 2, 3, 4, -128, -127, -126, -125],
[-128, -127, -126, -125, 1, 2, 3, 4],
]
),
torch.tensor(
[[-2071756159, 50462976], [50462976, -2071756159]], dtype=torch.int32
),
),
],
)
def test_pack_to_int32(num_bits, values, expected_values):
values = values.to(torch.int8)
packed_values = pack_to_int32(values, num_bits)
assert torch.equal(packed_values, expected_values)
assert packed_values.dtype == expected_values.dtype


@pytest.mark.parametrize(
"num_bits,values,expected_tensor",
[
(
4,
torch.tensor([[9]], dtype=torch.int32),
torch.tensor([[1]], dtype=torch.int8),
),
(
8,
torch.tensor([[129]], dtype=torch.int32),
torch.tensor([[1]], dtype=torch.int8),
),
(
4,
torch.tensor([[52137]], dtype=torch.int32),
torch.tensor([[1, 2, 3, 4]], dtype=torch.int8),
),
(
4,
torch.tensor([[1985229328]], dtype=torch.int32),
torch.tensor([[-8, -7, -6, -5, -4, -3, -2, -1]], dtype=torch.int8),
),
(
8,
torch.tensor([[-2071756159]], dtype=torch.int32),
torch.tensor([[1, 2, 3, 4]], dtype=torch.int8),
),
(
8,
torch.tensor([[50462976]], dtype=torch.int32),
torch.tensor([[-128, -127, -126, -125]], dtype=torch.int8),
),
(
4,
torch.tensor([[1985229328, 52137]], dtype=torch.int32),
torch.tensor(
[[-8, -7, -6, -5, -4, -3, -2, -1, 1, 2, 3, 4]], dtype=torch.int8
),
),
(
4,
torch.tensor([[1985229328, 52137], [52137, 1985229328]], dtype=torch.int32),
torch.tensor(
[
[-8, -7, -6, -5, -4, -3, -2, -1, 1, 2, 3, 4, -8, -8, -8, -8],
[1, 2, 3, 4, -8, -8, -8, -8, -8, -7, -6, -5, -4, -3, -2, -1],
],
dtype=torch.int8,
),
),
(
8,
torch.tensor([[-2071756159], [50462976]], dtype=torch.int32),
torch.tensor(
[
[1, 2, 3, 4],
[-128, -127, -126, -125],
],
dtype=torch.int8,
),
),
(
8,
torch.tensor(
[[-2071756159, 50462976], [50462976, -2071756159]], dtype=torch.int32
),
torch.tensor(
[
[1, 2, 3, 4, -128, -127, -126, -125],
[-128, -127, -126, -125, 1, 2, 3, 4],
],
dtype=torch.int8,
),
),
],
)
def test_unpack_from_int32(num_bits, values, expected_tensor):
unpacked_tensor = unpack_from_int32(values, num_bits, expected_tensor.shape)
assert torch.equal(unpacked_tensor, expected_tensor)
assert unpacked_tensor.dtype == expected_tensor.dtype
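
The expected constants in these tables can be checked by hand from the bit-layout comments. A short pure-Python sketch verifying two of the entries above:

# 4-bit case: [1, 2, 3, 4] offsets by 2^3 = 8 to the nibbles [9, 10, 11, 12];
# packed little-endian that is 0b1100_1011_1010_1001.
word = 0
for i, v in enumerate([1, 2, 3, 4]):
    word |= (v + 8) << (4 * i)
assert word == 0b1100_1011_1010_1001 == 52137

# 8-bit case: [1, 2, 3, 4] offsets by 2^7 = 128 to [129, 130, 131, 132].
word = 0
for i, v in enumerate([1, 2, 3, 4]):
    word |= (v + 128) << (8 * i)

# The top bit of the uint32 pattern is set, so reinterpreting it as a signed
# int32 yields the negative constant used in the test table.
signed = word - (1 << 32) if word >= (1 << 31) else word
assert signed == -2071756159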