diff --git a/bitsandbytes/_ops.py b/bitsandbytes/_ops.py
index a260852f5..2a1d7aac3 100644
--- a/bitsandbytes/_ops.py
+++ b/bitsandbytes/_ops.py
@@ -225,7 +225,8 @@ def _(
     n = A.numel()
     blocks = -(n // -blocksize)
 
-    absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
+    dtype = torch.float32 if torch.cuda.is_available() else A.dtype
+    absmax = torch.empty((blocks,), device=A.device, dtype=dtype)
     out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage)
 
     return out, absmax
@@ -268,7 +269,8 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor
     torch._check_is_size(blocksize)
     n = A.numel()
     blocks = -(n // -blocksize)
-    absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
+    dtype = torch.float32 if torch.cuda.is_available() else A.dtype
+    absmax = torch.empty((blocks,), device=A.device, dtype=dtype)
     out = torch.empty_like(A, dtype=torch.uint8)
 
     return out, absmax
diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py
index 5f009ea40..1727bcb46 100644
--- a/bitsandbytes/backends/cpu/ops.py
+++ b/bitsandbytes/backends/cpu/ops.py
@@ -49,7 +49,7 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor
     rem = n % blocksize
     has_rem = rem > 0
     blocks = n // blocksize + has_rem
-    absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
+    absmax = torch.zeros((blocks,), device=A.device, dtype=A.dtype)
     A_reshaped = A.reshape(n)
     A_com = A_reshaped[: n - rem]
     A_com_reshaped = A_com.reshape(n // blocksize, blocksize)
diff --git a/bitsandbytes/backends/default/ops.py b/bitsandbytes/backends/default/ops.py
index ce5926979..48d30ced4 100644
--- a/bitsandbytes/backends/default/ops.py
+++ b/bitsandbytes/backends/default/ops.py
@@ -154,7 +154,7 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor
     rem = n % blocksize
     has_rem = rem > 0
     blocks = n // blocksize + has_rem
-    absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
+    absmax = torch.zeros((blocks,), device=A.device, dtype=A.dtype)
     A_reshaped = A.reshape(n)
     A_com = A_reshaped[: n - rem]
     A_com_reshaped = A_com.reshape(n // blocksize, blocksize)
@@ -204,7 +204,7 @@ def _(
     full_blocks = n // blocksize
     rem = n % blocksize
     blocks = full_blocks + 1 if rem else full_blocks
-    absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
+    absmax = torch.zeros((blocks,), device=A.device, dtype=A.dtype)
     A_flattened = A.reshape(n)
 
     # Scale full blocks of the tensor to [-1, 1]
@@ -229,7 +229,7 @@ def _(
     if quant_storage != torch.uint8:
         packed = packed.squeeze().view(quant_storage).unsqueeze(1)
 
-    return packed, absmax.float()
+    return packed, absmax
 
 
 @register_kernel("bitsandbytes::dequantize_4bit", "default")
diff --git a/bitsandbytes/backends/xpu/ops.py b/bitsandbytes/backends/xpu/ops.py
index 999116c97..57a2d39e9 100755
--- a/bitsandbytes/backends/xpu/ops.py
+++ b/bitsandbytes/backends/xpu/ops.py
@@ -44,9 +44,9 @@ def _(
     # void cdequantize_blockwise_fp32(
     #     float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream)
     if dtype == torch.float16:
-        ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_fp16(code, A, absmax, out, blocksize, A.numel())
+        ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_fp16(code, A, absmax.float(), out, blocksize, A.numel())
     elif dtype == torch.bfloat16:
-        ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_bf16(code, A, absmax, out, blocksize, A.numel())
+        ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_bf16(code, A, absmax.float(), out, blocksize, A.numel())
     elif dtype == torch.float32:
         ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_fp32(code, A, absmax, out, blocksize, A.numel())
     else:
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 6893752c9..b8b6bbb4f 100755
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -759,8 +759,6 @@ def dequantize_blockwise(
     if quant_state.nested:
         absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
         absmax += quant_state.offset
-        if absmax.dtype != torch.float32:
-            absmax = absmax.float()
 
     if out is not None:
         torch.ops.bitsandbytes.dequantize_blockwise.out(
@@ -1034,8 +1032,6 @@ def dequantize_4bit(
     if quant_state.nested:
         absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
         absmax += quant_state.offset
-        if absmax.dtype != torch.float32:
-            absmax = absmax.float()
 
     # IPEX format is different, we need extra process.
     if getattr(quant_state, "ipex", False) and quant_state.quant_type == "nf4":
@@ -1079,8 +1075,6 @@ def quantize(
         code = code.to(A.device)
 
     absmax = torch.abs(A).max()
-    if absmax.dtype != torch.float32:
-        absmax = absmax.float()
     inp = A / absmax
     out = quantize_no_absmax(inp, code, out)
     return out, (absmax, code)
@@ -2326,11 +2320,8 @@ def _enable_ipex_fusion(linear: torch.nn.Module, x: torch.Tensor):
     quant_state = linear.weight.quant_state
 
     if quant_state.nested:
-        absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
+        absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2).to(x.dtype)
         absmax += quant_state.offset
-        if absmax.dtype != torch.float32:
-            absmax = absmax.float()
-
         quant_state.absmax = absmax
         quant_state.nested = False
         delattr(quant_state, "state2")
diff --git a/tests/test_functional.py b/tests/test_functional.py
index 2e2e898cc..2301c69da 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -104,8 +104,6 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize,
             pytest.skip("Not a typical use case.")
         if blocksize != 256:
             pytest.skip("Only blocksize 256 is used in CPU/XPU")
-        if dtype != torch.float32:
-            pytest.skip("Only float32 is used in CPU/XPU")
 
         diffs = []
         reldiffs = []
@@ -137,11 +135,10 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize,
         abserr = sum(diffs) / len(diffs)
         relerr = sum(reldiffs) / len(reldiffs)
         if signed:
-            threshold_abserr = 0.0036 if device in ("cpu", "xpu") and (F.ipex_cpu or F.ipex_xpu) else 0.0035
             assert abserr < 0.0036
             assert relerr < 0.015
         else:
-            assert abserr < 0.00175 if device in ("cpu", "xpu") and (F.ipex_cpu or F.ipex_xpu) else 0.0023
+            assert abserr < 0.0023
             assert relerr < 0.012
 
         assert A2.dtype == dtype
diff --git a/tests/test_ops.py b/tests/test_ops.py
index 60c47a250..f9793d1e0 100644
--- a/tests/test_ops.py
+++ b/tests/test_ops.py
@@ -105,9 +105,6 @@ class TestInt8BlockwiseQuantOps:
     @pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
     def test_quantize_blockwise(self, device, dtype, blocksize):
         if device == "cpu":
-            if dtype != torch.float32:
-                pytest.skip("CPU implementation is only available for float32")
-
             if blocksize != 256:
                 pytest.skip("CPU implementation is slow; only test blocksize=256")
 
@@ -120,7 +117,6 @@ def test_quantize_blockwise(self, device, dtype, blocksize):
         assert out.device == A.device
 
         assert absmax.device == A.device
-        assert absmax.dtype == torch.float32
 
         opcheck(torch.ops.bitsandbytes.quantize_blockwise, (A, code, blocksize))
 
@@ -128,9 +124,6 @@ def test_quantize_blockwise(self, device, dtype, blocksize):
     @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
     @pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
     def test_dequantize_blockwise(self, device, dtype, blocksize):
-        if device == "cpu" and dtype != torch.float32:
-            pytest.skip("CPU implementation is only available for float32")
-
         A = torch.randint(0, 255, (1024, 1024), dtype=torch.uint8, device=device)
         code = bitsandbytes.functional.create_dynamic_map().to(device, dtype=torch.float32)
 
@@ -166,7 +159,6 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize
         assert out.dtype == storage_dtype
 
         assert absmax.device == A.device
-        assert absmax.dtype == torch.float32
 
         if storage_dtype != torch.uint8:
            pytest.xfail("opcheck fails for storage_dtype != torch.uint8")
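
Usage note (illustrative sketch, not part of the patch): the snippet below shows the intended effect of keeping absmax in the input dtype on the CPU path. It assumes a machine without a CUDA device (the fake-tensor registrations in _ops.py still pick float32 when CUDA is available) and uses only the public bitsandbytes.functional API; the exact absmax dtype observed depends on the active backend.

import torch
import bitsandbytes.functional as F

# Quantize a bfloat16 tensor on CPU with the blockwise int8 scheme.
A = torch.randn(4096, dtype=torch.bfloat16, device="cpu")
quantized, state = F.quantize_blockwise(A, blocksize=256)

# With this change the CPU backend allocates absmax with A.dtype instead of
# always upcasting to float32 (expected: torch.bfloat16 on a CUDA-less machine).
print(state.absmax.dtype)

# Round trip; the dequantized tensor keeps the original dtype, as before.
A2 = F.dequantize_blockwise(quantized, state)
print(A2.dtype)  # torch.bfloat16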