Skip to content

Enable fp16/bf16 absmax #1672

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions bitsandbytes/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,8 @@ def _(

n = A.numel()
blocks = -(n // -blocksize)
absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
dtype = torch.float32 if torch.cuda.is_available() else A.dtype
absmax = torch.empty((blocks,), device=A.device, dtype=dtype)
out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage)
return out, absmax

Expand Down Expand Up @@ -268,7 +269,8 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor
torch._check_is_size(blocksize)
n = A.numel()
blocks = -(n // -blocksize)
absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
dtype = torch.float32 if torch.cuda.is_available() else A.dtype
absmax = torch.empty((blocks,), device=A.device, dtype=dtype)
out = torch.empty_like(A, dtype=torch.uint8)
return out, absmax

Expand Down
2 changes: 1 addition & 1 deletion bitsandbytes/backends/cpu/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor
rem = n % blocksize
has_rem = rem > 0
blocks = n // blocksize + has_rem
absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
absmax = torch.zeros((blocks,), device=A.device, dtype=A.dtype)
A_reshaped = A.reshape(n)
A_com = A_reshaped[: n - rem]
A_com_reshaped = A_com.reshape(n // blocksize, blocksize)
Expand Down
6 changes: 3 additions & 3 deletions bitsandbytes/backends/default/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor
rem = n % blocksize
has_rem = rem > 0
blocks = n // blocksize + has_rem
absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
absmax = torch.zeros((blocks,), device=A.device, dtype=A.dtype)
A_reshaped = A.reshape(n)
A_com = A_reshaped[: n - rem]
A_com_reshaped = A_com.reshape(n // blocksize, blocksize)
Expand Down Expand Up @@ -204,7 +204,7 @@ def _(
full_blocks = n // blocksize
rem = n % blocksize
blocks = full_blocks + 1 if rem else full_blocks
absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
absmax = torch.zeros((blocks,), device=A.device, dtype=A.dtype)
A_flattened = A.reshape(n)

# Scale full blocks of the tensor to [-1, 1]
Expand All @@ -229,7 +229,7 @@ def _(
if quant_storage != torch.uint8:
packed = packed.squeeze().view(quant_storage).unsqueeze(1)

return packed, absmax.float()
return packed, absmax


@register_kernel("bitsandbytes::dequantize_4bit", "default")
Expand Down
4 changes: 2 additions & 2 deletions bitsandbytes/backends/xpu/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ def _(
# void cdequantize_blockwise_fp32(
# float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream)
if dtype == torch.float16:
ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_fp16(code, A, absmax, out, blocksize, A.numel())
ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_fp16(code, A, absmax.float(), out, blocksize, A.numel())
elif dtype == torch.bfloat16:
ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_bf16(code, A, absmax, out, blocksize, A.numel())
ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_bf16(code, A, absmax.float(), out, blocksize, A.numel())
elif dtype == torch.float32:
ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_fp32(code, A, absmax, out, blocksize, A.numel())
else:
Expand Down
11 changes: 1 addition & 10 deletions bitsandbytes/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,8 +759,6 @@ def dequantize_blockwise(
if quant_state.nested:
absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
absmax += quant_state.offset
if absmax.dtype != torch.float32:
absmax = absmax.float()

if out is not None:
torch.ops.bitsandbytes.dequantize_blockwise.out(
Expand Down Expand Up @@ -1034,8 +1032,6 @@ def dequantize_4bit(
if quant_state.nested:
absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
absmax += quant_state.offset
if absmax.dtype != torch.float32:
absmax = absmax.float()

# IPEX format is different, we need extra process.
if getattr(quant_state, "ipex", False) and quant_state.quant_type == "nf4":
Expand Down Expand Up @@ -1079,8 +1075,6 @@ def quantize(
code = code.to(A.device)

absmax = torch.abs(A).max()
if absmax.dtype != torch.float32:
absmax = absmax.float()
inp = A / absmax
out = quantize_no_absmax(inp, code, out)
return out, (absmax, code)
Expand Down Expand Up @@ -2326,11 +2320,8 @@ def _enable_ipex_fusion(linear: torch.nn.Module, x: torch.Tensor):
quant_state = linear.weight.quant_state

if quant_state.nested:
absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2).to(x.dtype)
absmax += quant_state.offset
if absmax.dtype != torch.float32:
absmax = absmax.float()

quant_state.absmax = absmax
quant_state.nested = False
delattr(quant_state, "state2")
Expand Down
5 changes: 1 addition & 4 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,6 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize,
pytest.skip("Not a typical use case.")
if blocksize != 256:
pytest.skip("Only blocksize 256 is used in CPU/XPU")
if dtype != torch.float32:
pytest.skip("Only float32 is used in CPU/XPU")

diffs = []
reldiffs = []
Expand Down Expand Up @@ -137,11 +135,10 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize,
abserr = sum(diffs) / len(diffs)
relerr = sum(reldiffs) / len(reldiffs)
if signed:
threshold_abserr = 0.0036 if device in ("cpu", "xpu") and (F.ipex_cpu or F.ipex_xpu) else 0.0035
assert abserr < 0.0036
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is because threshold_abserr is not used.

assert relerr < 0.015
else:
assert abserr < 0.00175 if device in ("cpu", "xpu") and (F.ipex_cpu or F.ipex_xpu) else 0.0023
assert abserr < 0.0023
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have no reason to have a tighter threshold for ipex, otherwise the half-precision check cannot pass.

assert relerr < 0.012
assert A2.dtype == dtype

Expand Down
8 changes: 0 additions & 8 deletions tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,6 @@ class TestInt8BlockwiseQuantOps:
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
def test_quantize_blockwise(self, device, dtype, blocksize):
if device == "cpu":
if dtype != torch.float32:
pytest.skip("CPU implementation is only available for float32")

if blocksize != 256:
pytest.skip("CPU implementation is slow; only test blocksize=256")

Expand All @@ -120,17 +117,13 @@ def test_quantize_blockwise(self, device, dtype, blocksize):
assert out.device == A.device

assert absmax.device == A.device
assert absmax.dtype == torch.float32

opcheck(torch.ops.bitsandbytes.quantize_blockwise, (A, code, blocksize))

@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
def test_dequantize_blockwise(self, device, dtype, blocksize):
if device == "cpu" and dtype != torch.float32:
pytest.skip("CPU implementation is only available for float32")

A = torch.randint(0, 255, (1024, 1024), dtype=torch.uint8, device=device)
code = bitsandbytes.functional.create_dynamic_map().to(device, dtype=torch.float32)

Expand Down Expand Up @@ -166,7 +159,6 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize
assert out.dtype == storage_dtype

assert absmax.device == A.device
assert absmax.dtype == torch.float32

if storage_dtype != torch.uint8:
pytest.xfail("opcheck fails for storage_dtype != torch.uint8")
Expand Down