Add SYCL Kernels for XPU backend #1679

Open · wants to merge 49 commits into base: main

Commits (49)
dd7b173 Add SYCL Kernels for XPU backend (xiaolil1, Jun 15, 2025)
df93cdd Merge pull request #1 from xiaolil1/jiqing (xiaolil1, Jun 16, 2025)
872aa02 fix transpose (jiqing-feng, Jun 16, 2025)
04437a3 fix log and format (jiqing-feng, Jun 16, 2025)
d585bea revert cpu changes (jiqing-feng, Jun 16, 2025)
1781611 clean ipex_xpu (jiqing-feng, Jun 16, 2025)
c982781 clean ipex import (jiqing-feng, Jun 16, 2025)
a4c5f8c fix ipex cpu import (jiqing-feng, Jun 16, 2025)
4f076bb fix typo (jiqing-feng, Jun 16, 2025)
76d7178 fix comments (jiqing-feng, Jun 16, 2025)
b31ea62 Merge pull request #2 from xiaolil1/jiqing (xiaolil1, Jun 16, 2025)
452aa84 refine gemv_4bit kernel (xiaolil1, Jun 17, 2025)
e8ac8b5 Merge branch 'main' into main (jiqing-feng, Jun 17, 2025)
8620a95 enable FP4 for dequant_4bit and gemv_4bit (xiaolil1, Jun 17, 2025)
00f064b refine FP4 dequantization performance (xiaolil1, Jun 17, 2025)
d60750f remove check for better performance (jiqing-feng, Jun 17, 2025)
59f2aa8 Merge pull request #3 from xiaolil1/jiqing (xiaolil1, Jun 17, 2025)
aad358f fix doc (jiqing-feng, Jun 17, 2025)
45e4451 Merge pull request #4 from xiaolil1/jiqing (xiaolil1, Jun 17, 2025)
1e21ee9 clean code (xiaolil1, Jun 18, 2025)
4e7f5c1 Merge branch 'main' into main (xiaolil1, Jun 18, 2025)
1601652 fix tests (jiqing-feng, Jun 18, 2025)
1cc25ff rm comments (jiqing-feng, Jun 18, 2025)
c44f38e Merge pull request #5 from xiaolil1/jiqing (xiaolil1, Jun 18, 2025)
9f283bd fix memory issue (xiaolil1, Jun 20, 2025)
9897eae fix ut failure (xiaolil1, Jun 20, 2025)
411a276 adjust threshold (jiqing-feng, Jun 20, 2025)
b6a3524 fix xpu check (jiqing-feng, Jun 20, 2025)
1c4f478 change test_functional check (jiqing-feng, Jun 20, 2025)
e5cf821 fix test_module (jiqing-feng, Jun 20, 2025)
502fe83 Merge pull request #6 from xiaolil1/jiqing (xiaolil1, Jun 20, 2025)
8b54381 fix device check (jiqing-feng, Jun 23, 2025)
1e0f661 Merge pull request #7 from xiaolil1/jiqing_test (jiqing-feng, Jun 23, 2025)
99698d2 fix tests (jiqing-feng, Jun 23, 2025)
b88236a Merge pull request #8 from xiaolil1/jiqing (jiqing-feng, Jun 23, 2025)
56c48bc Merge branch 'main' into main (jiqing-feng, Jun 24, 2025)
302413e Merge branch 'main' into main (jiqing-feng, Jun 25, 2025)
685962c Enable Windows build and refine code (xiaolil1, Jun 27, 2025)
7842f9d Merge branch 'main' into main (jiqing-feng, Jun 30, 2025)
041b442 Merge branch 'main' into main (jiqing-feng, Jul 1, 2025)
aa0cf92 Merge branch 'main' into main (jiqing-feng, Jul 2, 2025)
b3db4bf fix xpu log (jiqing-feng, Jul 2, 2025)
d66f93d Merge pull request #9 from xiaolil1/jiqing (xiaolil1, Jul 2, 2025)
5bf3159 remove ipex entirely (jiqing-feng, Jul 3, 2025)
005a63c fix cpu int8 CB (jiqing-feng, Jul 3, 2025)
683f37c Merge pull request #10 from xiaolil1/jiqing (xiaolil1, Jul 3, 2025)
223d7d7 fix lint (jiqing-feng, Jul 3, 2025)
dc75ad8 Merge pull request #11 from xiaolil1/jiqing (xiaolil1, Jul 3, 2025)
883d693 fix logs (#12) (jiqing-feng, Jul 4, 2025)
19 changes: 1 addition & 18 deletions .github/workflows/tests.yml
@@ -162,7 +162,7 @@ jobs:
- name: Run tests
run: pytest --durations=100

test-cpu-ipex:
test-cpu-intel:
if: github.repository == 'bitsandbytes-foundation/bitsandbytes'
needs: build-cpu
runs-on: banb-aws-general-8-plus-use1-public-80
@@ -186,7 +186,6 @@ jobs:
- name: Install dependencies
run: |
pip install torch==2.7.1 --index-url https://download.pytorch.org/whl/cpu
pip install intel_extension_for_pytorch==2.7.0 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/cpu/us/
pip install -e ".[test]"
pip install pytest-cov

@@ -196,9 +195,6 @@ jobs:
- name: Show environment information
run: python -m torch.utils.collect_env

- name: IPEX smoke test
run: python -c "import torch; import intel_extension_for_pytorch as ipex; print(torch.__version__); print(ipex.__version__);"

- name: Run tests
run: pytest --durations=100

@@ -286,15 +282,6 @@ jobs:
fail-fast: false
matrix:
torch_version: ["2.7.1"] #["2.6.0", "2.7.1"]
ipex: [false]
# ipex: [true, false]
# include:
# - torch_version: "2.6.0"
# ipex: true
# ipex_version: "2.6.10+xpu"
# - torch_version: "2.7.1"
# ipex: true
# ipex_version: "2.7.10+xpu"
runs-on:
group: bandb-itac-bmsprpvc1550-8-1gpu
env:
@@ -330,10 +317,6 @@ jobs:
- name: Install PyTorch
run: pip install torch==${{ matrix.torch_version }} --index-url https://download.pytorch.org/whl/xpu

- name: Install IPEX
if: matrix.ipex == true
run: pip install intel_extension_for_pytorch==${{ matrix.ipex_version }} --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/

- name: Install dependencies
run: |
pip install -e ".[test]"
31 changes: 29 additions & 2 deletions CMakeLists.txt
@@ -28,11 +28,12 @@ set(CUDA_FILES csrc/ops.cu csrc/kernels.cu)
set(HIP_FILES csrc/ops.hip csrc/kernels.hip)
set(MPS_FILES csrc/mps_ops.mm)
set(METAL_FILES csrc/mps_kernels.metal)
set(XPU_FILES csrc/xpu_ops.cpp csrc/xpu_kernels.cpp)
# C++ sources are always included
list(APPEND SRC_FILES ${CPP_FILES})

set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, hip, mps)")
set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda hip mps)
set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, hip, mps, xpu)")
set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda hip mps xpu)
option(PTXAS_VERBOSE "Pass through -v flag to PTX Assembler" OFF)

if(APPLE)
@@ -64,10 +65,18 @@ elseif(${COMPUTE_BACKEND} STREQUAL "mps")
set(BUILD_CUDA OFF)
set(BUILD_HIP OFF)
set(BUILD_MPS ON)
elseif(${COMPUTE_BACKEND} STREQUAL "xpu")
if(APPLE)
message(FATAL_ERROR "XPU is not supported on macOS" )
endif()
set(BUILD_CUDA OFF)
set(BUILD_MPS OFF)
set(BUILD_XPU ON)
else()
set(BUILD_CUDA OFF)
set(BUILD_HIP OFF)
set(BUILD_MPS OFF)
set(BUILD_XPU OFF)
endif()


@@ -217,6 +226,15 @@ elseif(BUILD_MPS)
COMMENT "Compiling Metal kernels"
VERBATIM)
add_custom_target(metallib DEPENDS "bitsandbytes/bitsandbytes.metallib")
elseif(BUILD_XPU)
list(APPEND SRC_FILES ${XPU_FILES})
string(APPEND BNB_OUTPUT_NAME "_xpu")
add_compile_definitions(BUILD_XPU)
set(CMAKE_C_COMPILER icx)
set(CMAKE_CXX_COMPILER icpx)
if(WIN32)
set(CMAKE_CXX_COMPILER icx)
endif()
else()
string(APPEND BNB_OUTPUT_NAME "_cpu")
set(GPU_SOURCES)
@@ -285,6 +303,15 @@ if(BUILD_MPS)
add_dependencies(bitsandbytes metallib)
target_link_libraries(bitsandbytes objc "-framework Foundation" "-framework Metal" "-framework MetalPerformanceShaders" "-framework MetalPerformanceShadersGraph")
endif()
if(BUILD_XPU)
set(SYCL_LINK_FLAGS "-fsycl;--offload-compress;-fsycl-targets=spir64_gen,spir64;-Xs;-device pvc,xe-lpg,ats-m150 -options ' -cl-intel-enable-auto-large-GRF-mode -cl-poison-unsupported-fp64-kernels -cl-intel-greater-than-4GB-buffer-required'")
set(SYCL_COMPILE_FLAGS "-fsycl;-fhonor-nans;-fhonor-infinities;-fno-associative-math;-fno-approx-func;-fno-sycl-instrument-device-code;--offload-compress;-fsycl-targets=spir64_gen,spir64;")

set_property(TARGET bitsandbytes PROPERTY CXX_STANDARD 20)
target_compile_options(bitsandbytes PRIVATE ${SYCL_COMPILE_FLAGS})
target_link_options(bitsandbytes PRIVATE ${SYCL_LINK_FLAGS})

endif()

if(WIN32)
set_target_properties(bitsandbytes PROPERTIES PREFIX "lib")
21 changes: 0 additions & 21 deletions bitsandbytes/_ops.py
@@ -4,8 +4,6 @@

import torch

from .cextension import ipex_cpu, ipex_xpu

_IS_TORCH_GTE_24 = False

if hasattr(torch.library, "register_fake"):
@@ -329,22 +327,3 @@ def _(
)
torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}")


if ipex_cpu or ipex_xpu:
# Register the dequantize_nf4_ipex implementation
torch.library.define(
"bitsandbytes::dequantize_nf4_ipex",
"(Tensor A, Tensor absmax, int blocksize, int[] shape, ScalarType dtype) -> Tensor",
)

@register_fake("bitsandbytes::dequantize_nf4_ipex")
def _(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
shape: Sequence[int],
dtype: torch.dtype,
) -> torch.Tensor:
torch._check_is_size(blocksize)
return torch.empty(shape, dtype=dtype, device=A.device)
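Note: the removed block above relied on torch.library.define together with register_fake, the same custom-op mechanism the remaining ops in _ops.py use. A minimal sketch of that pattern follows, assuming torch >= 2.4; the op name `mylib::scale` and its schema are hypothetical and only for illustration:

```python
import torch

# Declare a custom op schema (hypothetical example op, not a bitsandbytes op).
torch.library.define("mylib::scale", "(Tensor x, float alpha) -> Tensor")

# The "fake" implementation only describes output shape/dtype for tracing and
# torch.compile; it performs no real computation.
@torch.library.register_fake("mylib::scale")
def _(x: torch.Tensor, alpha: float) -> torch.Tensor:
    return torch.empty_like(x)

# A concrete CPU implementation, dispatched via torch.ops.mylib.scale(x, 2.0).
@torch.library.impl("mylib::scale", "CPU")
def _(x: torch.Tensor, alpha: float) -> torch.Tensor:
    return x * alpha
```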
20 changes: 3 additions & 17 deletions bitsandbytes/autograd/_functions.py
@@ -8,7 +8,6 @@
from typing_extensions import deprecated

import bitsandbytes.functional as F
from bitsandbytes.functional import ipex_cpu, ipex_xpu

# The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov:
# https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py
@@ -320,8 +319,6 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):

CB = state.CB.data.to(A.dtype).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
output = torch.nn.functional.linear(A, CB, bias)
# to pass the test: tests/test_modules.py::test_linear8bitlt_no_fp16_weights[2.0-xpu]
state.idx = False
ctx.state = state
ctx.dtype_A = A.dtype
ctx.grad_shape = A.shape
@@ -425,9 +422,9 @@ def matmul(
if threshold > 0.0:
state.threshold = threshold
# MatMul8bitLt is slower because there is no fast 8-bit quant/dequant kernel on CPU/XPU
if state.is_training:
if (A.device.type == "cpu" and ipex_cpu) or (A.device.type == "xpu" and ipex_xpu):
return MatMul8bitFp.apply(A, B, out, bias, state)
if state.is_training and A.device.type in ("cpu", "xpu"):
return MatMul8bitFp.apply(A, B, out, bias, state)

return MatMul8bitLt.apply(A, B, out, bias, state)


@@ -440,17 +437,6 @@ def matmul_4bit(
):
assert quant_state is not None

if A.device.type in ("cpu", "xpu") and A.requires_grad == False:
if getattr(quant_state, "ipex", False):
# IPEX CPU will change weight to 4D so don't need transpose
B = B.t() if B.dim() == 2 else B
out = F.gemv_4bit(A, B, out, state=quant_state)
if bias is not None:
out += bias
return out
else:
return MatMul4Bit.apply(A, B, out, bias, quant_state)

if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu":
if A.shape[-1] % quant_state.blocksize != 0:
warn(
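For reference, a self-contained sketch of the dequantize-then-matmul fallback that MatMul8bitFp takes on CPU/XPU (the forward hunk above shows the same scaling of CB by SCB). Shapes and values here are made up, and the scale construction only mimics how SCB relates to CB:

```python
import torch

A = torch.randn(4, 8, dtype=torch.bfloat16)  # activations
W = torch.randn(16, 8)                        # original fp32 weight
SCB = W.abs().max(dim=1).values               # per-row absmax scales
CB = torch.round(W / SCB.unsqueeze(1) * 127).to(torch.int8)  # int8 weight rows

# Dequantize CB with SCB and run the matmul in A's dtype, mirroring
# CB.to(A.dtype).mul_(SCB.unsqueeze(1).mul(1.0 / 127.0)) in the hunk above.
W_dq = CB.to(A.dtype) * (SCB.unsqueeze(1) * (1.0 / 127.0)).to(A.dtype)
out = torch.nn.functional.linear(A, W_dq)
print(out.shape)  # torch.Size([4, 16])
```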
171 changes: 76 additions & 95 deletions bitsandbytes/backends/cpu/ops.py
@@ -1,13 +1,14 @@
from collections.abc import Sequence
import ctypes as ct
import logging

import torch

from bitsandbytes.functional import get_ptr

from ..._ops import register_kernel
from ...cextension import lib
from ..utils import ipex_cpu
from ...cextension import ErrorHandlerMockBNBNativeLibrary, lib

logger = logging.getLogger(__name__)

# torch._int_mm for s8@s8->s32 is supported on CPU from torch 2.4+.
# However, we can overflow if we use this without AVX512_VNNI support.
@@ -24,97 +25,77 @@ def _(A: torch.Tensor, B: torch.Tensor):
).reshape(*A.shape[:-1], B.shape[0])


@register_kernel("bitsandbytes::quantize_blockwise", "cpu")
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
torch._check_is_size(blocksize)

n = A.numel()

# Only FP32 has a C++ kernel
if A.dtype == torch.float32:
blocks = -(n // -blocksize)

absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
out = torch.empty_like(A, dtype=torch.uint8)

lib.cquantize_blockwise_cpu_fp32(
get_ptr(code),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_longlong(blocksize),
ct.c_longlong(n),
)
else:
rem = n % blocksize
has_rem = rem > 0
blocks = n // blocksize + has_rem
absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
A_reshaped = A.reshape(n)
A_com = A_reshaped[: n - rem]
A_com_reshaped = A_com.reshape(n // blocksize, blocksize)
absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0]
scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1)
scaled_A = scaled_A.reshape(-1)
if has_rem:
absmax[-1] = torch.abs(A_reshaped[n - rem :]).max()
scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1)
scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0)

diff = torch.abs(scaled_A.unsqueeze(-1) - code.to(scaled_A.device))
out = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device).reshape(A.shape)

return out, absmax


@register_kernel("bitsandbytes::dequantize_blockwise", "cpu")
def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
torch._check_is_size(blocksize)
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")

# Only FP32 has a C++ kernel
if dtype == torch.float32:
out = torch.empty_like(A, dtype=dtype)

lib.cdequantize_blockwise_cpu_fp32(
get_ptr(code),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_longlong(blocksize),
ct.c_longlong(A.numel()),
)
else:
out = code[A.reshape(-1).int()]
blocks = out.shape[-1] // blocksize
res = out.shape[-1] % blocksize
if res != 0:
out = torch.nn.functional.pad(out, (0, blocksize - res), mode="constant", value=0)
out = (out.view(-1, blocksize) * absmax.view(-1, 1)).to(dtype).reshape(-1)
out = out[: blocks * blocksize + res]
out = out.reshape(A.shape)

return out


if ipex_cpu:
from bitsandbytes.utils import _reverse_4bit_compress_format

@register_kernel("bitsandbytes::dequantize_nf4_ipex", "cpu")
if not isinstance(lib, ErrorHandlerMockBNBNativeLibrary):

@register_kernel("bitsandbytes::quantize_blockwise", "cpu")
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
torch._check_is_size(blocksize)

n = A.numel()

# Only FP32 has a C++ kernel
if A.dtype == torch.float32:
blocks = -(n // -blocksize)

absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
out = torch.empty_like(A, dtype=torch.uint8)

lib.cquantize_blockwise_cpu_fp32(
get_ptr(code),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_longlong(blocksize),
ct.c_longlong(n),
)
else:
rem = n % blocksize
has_rem = rem > 0
blocks = n // blocksize + has_rem
absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
A_reshaped = A.reshape(n)
A_com = A_reshaped[: n - rem]
A_com_reshaped = A_com.reshape(n // blocksize, blocksize)
absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0]
scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1)
scaled_A = scaled_A.reshape(-1)
if has_rem:
absmax[-1] = torch.abs(A_reshaped[n - rem :]).max()
scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1)
scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0)

diff = torch.abs(scaled_A.unsqueeze(-1) - code.to(scaled_A.device))
out = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device).reshape(A.shape)

return out, absmax

@register_kernel("bitsandbytes::dequantize_blockwise", "cpu")
def _(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
shape: Sequence[int],
dtype: torch.dtype,
A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype
) -> torch.Tensor:
ipex_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight(A, "nf4", shape, 2)
A = _reverse_4bit_compress_format(ipex_weight.reshape(-1)).reshape(1, -1)
return torch.ops.bitsandbytes.dequantize_4bit.default(
A,
absmax,
blocksize,
"nf4",
shape,
dtype,
)
torch._check_is_size(blocksize)
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")

# Only FP32 has a C++ kernel
if dtype == torch.float32:
out = torch.empty_like(A, dtype=dtype)

lib.cdequantize_blockwise_cpu_fp32(
get_ptr(code),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_longlong(blocksize),
ct.c_longlong(A.numel()),
)
else:
out = code[A.reshape(-1).int()]
blocks = out.shape[-1] // blocksize
res = out.shape[-1] % blocksize
if res != 0:
out = torch.nn.functional.pad(out, (0, blocksize - res), mode="constant", value=0)
out = (out.view(-1, blocksize) * absmax.view(-1, 1)).to(dtype).reshape(-1)
out = out[: blocks * blocksize + res]
out = out.reshape(A.shape)

return out
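A quick round-trip through the blockwise CPU ops registered above, as a hedged sketch: the op names and argument order follow the register_kernel signatures in this file, and create_dynamic_map is assumed to be the usual way to build the 8-bit code table:

```python
import torch
from bitsandbytes.functional import create_dynamic_map

A = torch.randn(4096, dtype=torch.float32)  # fp32 takes the C++ kernel path
code = create_dynamic_map().to(A.device)    # 256-entry dynamic quantization table
blocksize = 256

packed, absmax = torch.ops.bitsandbytes.quantize_blockwise(A, code, blocksize)
restored = torch.ops.bitsandbytes.dequantize_blockwise(packed, absmax, code, blocksize, torch.float32)

# Blockwise quantization is lossy, so expect a small reconstruction error.
print((A - restored).abs().max())
```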