From 067917cfb05a5c6be17bb44cedbea7fcc8ac8599 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 24 Dec 2024 07:58:13 +0000 Subject: [PATCH 1/2] float16xuint1 fix --- bitblas/gpu/intrin/lop3.py | 65 ++++++++++++------- .../ops/lop3_permutate/lop3_permutate_impl.py | 2 +- 2 files changed, 42 insertions(+), 25 deletions(-) diff --git a/bitblas/gpu/intrin/lop3.py b/bitblas/gpu/intrin/lop3.py index 48accfba2..a906ebee6 100644 --- a/bitblas/gpu/intrin/lop3.py +++ b/bitblas/gpu/intrin/lop3.py @@ -685,15 +685,29 @@ """ decode_i1_to_f16 = """ -template -__device__ void decode_i1u_to_f16(T1 *_i1s, T2 *B_local_decode, const int N = 8) +/* +Kind 0: original +Kind 1: rescale +Kind 2: quantized +# documents for zeros_mode: +# original: target = (dequantize_weight - zero_point) * scale +# rescale: target = dequantize_weight * scale - zero_point +# quantized: target = (dequantize_weight - dequantize_zeros) * scale +# Notice: only support "original" and "rescale" now +zeros_mode: Literal["original", "rescale", "quantized"] = "original" +*/ +template +__device__ void decode_i1b_to_f16(T1 *_i1s, T2 *B_local_decode, const int N = 8, half *scale = nullptr, half *zeros = nullptr) { uint *h = reinterpret_cast(B_local_decode); static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; static constexpr uint BOTTOM_MASK = 0x00010001; static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; - static constexpr uint MEDIAN_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = isSigned ? 0x64006400 : 0x64006400; + static constexpr uint TRANSFORM_SUBTRACT = 0xbc00bc00; // for signed int 2x - 1 + // interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0} + // only decode e7,e5,e3,e1,e8,e6,e4,e2,e0 int8_t const i1s_i16 = *reinterpret_cast(_i1s); int i1s = (i1s_i16 & 0x0f); i1s |= ((i1s_i16 & 0xf0) << 12); @@ -701,38 +715,41 @@ // decode 2 elems at one time. 
for (int i = 0; i < (N / 2); i++) { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" : "=r"(h[i]) : "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + if constexpr (isSigned) + { + asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(h[i])); + asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(TRANSFORM_SUBTRACT)); + } + if constexpr (withZeros && ZerosKind == 0) + { + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*zeros, *zeros))); + } + if constexpr (withScaling) + { + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*scale, *scale)), "r"(0)); + } + if constexpr (withZeros && ZerosKind == 1) + { + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*zeros, *zeros))); + } } } template __device__ void decode_i1s_to_f16(T1 *_i1s, T2 *B_local_decode, const int N = 8) { - uint *h = reinterpret_cast(B_local_decode); - - static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; - static constexpr uint BOTTOM_MASK = 0x00010001; - static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; - static constexpr uint MEDIAN_NUM = 0x64006400; - static constexpr uint TRANSFORM_SUBTRACT = 0xbc00bc00; // for signed int 2x - 1 + decode_i1b_to_f16(_i1s, B_local_decode, N); +} - int8_t const i1s_i16 = *reinterpret_cast(_i1s); - int i1s = (i1s_i16 & 0x0f); - i1s |= ((i1s_i16 & 0xf0) << 12); -#pragma unroll - // decode 2 elems at one time. 
- for (int i = 0; i < (N / 2); i++) - { - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" - : "=r"(h[i]) - : "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); - asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); - asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(h[i])); - asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(TRANSFORM_SUBTRACT)); - } +template +__device__ void decode_i1u_to_f16(T1 *_i1u, T2 *B_local_decode, const int N = 8) +{ + decode_i1b_to_f16(_i1u, B_local_decode, N); } """ diff --git a/bitblas/ops/lop3_permutate/lop3_permutate_impl.py b/bitblas/ops/lop3_permutate/lop3_permutate_impl.py index 94ddd13c6..c5d240e69 100644 --- a/bitblas/ops/lop3_permutate/lop3_permutate_impl.py +++ b/bitblas/ops/lop3_permutate/lop3_permutate_impl.py @@ -80,7 +80,7 @@ def interleave_weight_f16_1b(A: T.Buffer((N, QK), storage_dtype), B: T.Buffer((N B_tmp_4[v0, v1] = ((B[v0, v1] & T.uint32(0x0000F000)) >> 12) << 24 B_tmp_5[v0, v1] = ((B[v0, v1] & T.uint32(0x000F0000)) >> 16) << 8 B_tmp_6[v0, v1] = ((B[v0, v1] & T.uint32(0x00F00000)) >> 20) << 12 - B_tmp_7[v0, v1] = ((B[v0, v1] & T.uint32(0x00F00000)) >> 24) << 20 + B_tmp_7[v0, v1] = ((B[v0, v1] & T.uint32(0x0F000000)) >> 24) << 20 B[v0, v1] = ( B_tmp_1[v0, v1] | B_tmp_2[v0, v1] From 46753d27efbf12f61ae160e9b212a4e6f9cd17ea Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 24 Dec 2024 08:00:48 +0000 Subject: [PATCH 2/2] enhance setup --- bitblas/__init__.py | 2 +- bitblas/version.py | 21 +++++++++++++++++++++ setup.py | 45 ++++++++++++++++++++++++++++++++++++++------- 3 files changed, 60 insertions(+), 8 deletions(-) create mode 100644 bitblas/version.py diff --git a/bitblas/__init__.py b/bitblas/__init__.py index 9ab4554e2..32ff07132 100644 --- a/bitblas/__init__.py +++ b/bitblas/__init__.py @@ -153,4 +153,4 @@ def remove_tvm_path(path): from .ops.general_flashatten import FlashAttenConfig, FlashAtten # noqa: F401 
from .module import Linear # noqa: F401 -__version__ = "0.1.0" +from .version import __version__ # noqa: F401 diff --git a/bitblas/version.py b/bitblas/version.py new file mode 100644 index 000000000..145025aa9 --- /dev/null +++ b/bitblas/version.py @@ -0,0 +1,21 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import os + +# Get the absolute path of the current Python script's directory +current_dir = os.path.dirname(os.path.abspath(__file__)) + +# Get the absolute path of the project root directory (one level above the current directory) +project_root_dir = os.path.abspath(os.path.join(current_dir, "..")) + +# Define the path to the VERSION file located in the project root directory +version_file_path = os.path.join(project_root_dir, "VERSION") + +# Read and store the version information from the VERSION file +# Use 'strip()' to remove any leading/trailing whitespace or newline characters +with open(version_file_path, "r") as version_file: + __version__ = version_file.read().strip() + +# Define the public API for the module +__all__ = ["__version__"] diff --git a/setup.py b/setup.py index bfc6b3830..eefd4316d 100644 --- a/setup.py +++ b/setup.py @@ -40,16 +40,18 @@ def get_requirements() -> List[str]: return requirements -def find_version(filepath: str) -> str: +def find_version(version_file_path: str) -> str: """Extract version information from the given filepath. 
Adapted from https://github.com/ray-project/ray/blob/0b190ee1160eeca9796bc091e07eaebf4c85b511/python/setup.py """ - with open(filepath) as fp: - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", fp.read(), re.M) - if version_match: - return version_match.group(1) - raise RuntimeError("Unable to find version string.") + # Read and store the version information from the VERSION file + # Use 'strip()' to remove any leading/trailing whitespace or newline characters + if not os.path.exists(version_file_path): + raise FileNotFoundError(f"Version file not found at {version_file_path}") + with open(version_file_path, "r") as version_file: + version = version_file.read().strip() + return version def get_nvcc_cuda_version(): @@ -65,7 +67,7 @@ def get_nvcc_cuda_version(): def get_bitblas_version(with_cuda=True, with_system_info=True) -> str: - version = find_version(get_path("bitblas", "__init__.py")) + version = find_version(get_path(".", "VERSION")) local_version_parts = [] if with_system_info: local_version_parts.append(get_system_info().replace("-", ".")) @@ -267,6 +269,35 @@ def run(self): if not os.path.exists(target_dir): os.makedirs(target_dir) shutil.copy2(source_dir, target_dir) + # copy composable kernel to the package directory + CK_PREBUILD_ITEMS = [ + "3rdparty/composable_kernel", + ] + for item in CK_PREBUILD_ITEMS: + source_dir = os.path.join(ROOT_DIR, item) + target_dir = os.path.join(self.build_lib, PACKAGE_NAME, item) + if os.path.isdir(source_dir): + self.mkpath(target_dir) + distutils.dir_util.copy_tree(source_dir, target_dir) + else: + target_dir = os.path.dirname(target_dir) + if not os.path.exists(target_dir): + os.makedirs(target_dir) + shutil.copy2(source_dir, target_dir) + + # copy config files (VERSION, README.md, LICENSE) to the package directory + CONFIG_ITEMS = ["VERSION", "README.md", "LICENSE"] + for item in CONFIG_ITEMS: + source_dir = os.path.join(ROOT_DIR, item) + target_dir = os.path.join(self.build_lib, PACKAGE_NAME, item) + if
os.path.isdir(source_dir): + self.mkpath(target_dir) + distutils.dir_util.copy_tree(source_dir, target_dir) + else: + target_dir = os.path.dirname(target_dir) + if not os.path.exists(target_dir): + os.makedirs(target_dir) + shutil.copy2(source_dir, target_dir) class BitBLASSdistCommand(sdist):