A new vector quantization method with binary codes, in PyTorch.
```bash
pip install bitcodes-pytorch
```
```python
import torch
from bitcodes_pytorch import Bitcodes
bitcodes = Bitcodes(
    features=8,      # Number of features per vector
    num_bits=4,      # Number of bits per vector
    temperature=10,  # Gumbel softmax training temperature
)
# Set to eval during inference to make deterministic
bitcodes.eval()
x = torch.randn(1, 6, 8)
# Computes y, the quantized version of x, and the bitcodes
y, bits = bitcodes(x)
"""
y.shape = torch.Size([1, 6, 8])
bits = tensor([[
[0, 0, 0, 0],
[1, 0, 1, 1],
[1, 0, 0, 1],
[1, 0, 0, 0],
[0, 1, 1, 1],
[0, 0, 1, 0]
]])
"""
y_decoded = bitcodes.from_bits(bits)
assert torch.allclose(y, y_decoded) # Assert passes in eval mode!
```
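In train mode the bit assignment is sampled with the Gumbel softmax, so it is stochastic but differentiable. The sketch below shows one way the module could be trained end-to-end; the reconstruction loss and optimizer are illustrative assumptions, not part of the library's API.

```python
import torch
import torch.nn.functional as F
from bitcodes_pytorch import Bitcodes

bitcodes = Bitcodes(features=8, num_bits=4, temperature=10)
optimizer = torch.optim.Adam(bitcodes.parameters(), lr=1e-3)

x = torch.randn(1, 6, 8)
y, bits = bitcodes(x)       # stochastic in train mode (Gumbel softmax)
loss = F.mse_loss(y, x)     # illustrative reconstruction loss, not the library's own
loss.backward()
optimizer.step()
optimizer.zero_grad()
```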
```python
from bitcodes_pytorch import to_decimal, to_bits
indices = to_decimal(bits)
# tensor([[ 0, 11, 9, 8, 7, 2]])
bits = to_bits(indices, num_bits=4)
"""
bits = tensor([[
[0, 0, 0, 0],
[1, 0, 1, 1],
[1, 0, 0, 1],
[1, 0, 0, 0],
[0, 1, 1, 1],
[0, 0, 1, 0]
]])
"""
Current vector quantization methods (e.g. VQ-VAE, RQ-VAE) either use a single large codebook or multiple smaller codebooks whose entries are used as residuals. Residuals allow for an exponential increase in the number of possible combinations while keeping the total number of codebook items reasonably small, since many codebook elements are overlapped: with v residual codebooks of m elements each, only v*m vectors are stored, but m^v combinations can be represented.

Here we use num_bits codebooks with only 2 elements each, so that every vector is described by a binary code and can take one of 2^num_bits values, where num_bits can be freely chosen. The residuals are overlapped to get the output instead of quantizing the difference at each step; this removes the residual loop and makes it possible to quantize with a large num_bits.
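As a rough illustration of the overlapping idea (a minimal sketch with made-up names, not the library's actual internals), each bit can be thought of as selecting one of two learned vectors, and the selected vectors are summed to form the quantized output:

```python
import torch

num_bits, features = 4, 8
# One pair of candidate vectors per bit: storage is 2 * num_bits vectors,
# yet the sums can take up to 2 ** num_bits distinct values.
codebook = torch.randn(num_bits, 2, features)

bits = torch.tensor([1, 0, 1, 1])                  # one binary code
selected = codebook[torch.arange(num_bits), bits]  # (num_bits, features)
y = selected.sum(dim=0)                            # overlapped (summed) output, shape (features,)
```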
Another nice property of bitcodes is that we can choose how to convert the bit matrix to integer indices in different ways after training (e.g. converting one or two rows at a time to decimal).
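For example, merging the bits of two adjacent vectors and converting each 8-bit group to a single integer could look like the sketch below. This assumes to_decimal accepts any (..., num_bits) bit tensor with the same most-significant-bit-first ordering as in the example above.

```python
import torch
from bitcodes_pytorch import to_decimal

bits = torch.tensor([[
    [0, 0, 0, 0],
    [1, 0, 1, 1],
    [1, 0, 0, 1],
    [1, 0, 0, 0],
    [0, 1, 1, 1],
    [0, 0, 1, 0],
]])
paired = bits.reshape(1, 3, 8)  # merge pairs of rows: (1, 6, 4) -> (1, 3, 8)
indices = to_decimal(paired)    # one integer in [0, 255] per pair of vectors
# Expected: tensor([[ 11, 152, 114]]) under the bit ordering assumed above
```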