Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] Fix correctness issue for float16xuint1 with fast dequantize #277

Merged
merged 2 commits into from
Dec 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bitblas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,4 +153,4 @@ def remove_tvm_path(path):
from .ops.general_flashatten import FlashAttenConfig, FlashAtten # noqa: F401
from .module import Linear # noqa: F401

__version__ = "0.1.0"
from .version import __version__ # noqa: F401
65 changes: 41 additions & 24 deletions bitblas/gpu/intrin/lop3.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,54 +685,71 @@
"""

decode_i1_to_f16 = """
template <typename T1, typename T2>
__device__ void decode_i1u_to_f16(T1 *_i1s, T2 *B_local_decode, const int N = 8)
/*
Kind 0: original
Kind 1: rescale
Kind 2: quantized
# documents for zeros_mode:
# original: target = (dequantize_weight - zero_point) * scale
# rescale: target = dequantize_weight * scale - zero_point
# quantized: target = (dequantize_weight - dequantize_zeros) * scale
# Notice: only support "original" and "rescale" now
zeros_mode: Literal["original", "rescale", "quantized"] = "original"
*/
template <typename T1, typename T2, bool isSigned = false, bool withScaling = false, bool withZeros = false, int ZerosKind = 1>
__device__ void decode_i1b_to_f16(T1 *_i1s, T2 *B_local_decode, const int N = 8, half *scale = nullptr, half *zeros = nullptr)
{
uint *h = reinterpret_cast<uint *>(B_local_decode);

static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint BOTTOM_MASK = 0x00010001;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
static constexpr uint MEDIAN_NUM = 0x64006400;
static constexpr uint MEDIAN_NUM = isSigned ? 0x64006400 : 0x64006400;
static constexpr uint TRANSFORM_SUBTRACT = 0xbc00bc00; // for signed int 2x - 1
// interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0}
// only decode e7,e5,e3,e1,e8,e6,e4,e2,e0
int8_t const i1s_i16 = *reinterpret_cast<int8_t *>(_i1s);
int i1s = (i1s_i16 & 0x0f);
i1s |= ((i1s_i16 & 0xf0) << 12);
#pragma unroll
// decode 2 elems at one time.
for (int i = 0; i < (N / 2); i++)
{

asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n"
: "=r"(h[i])
: "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM));
if constexpr (isSigned)
{
asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(h[i]));
asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(TRANSFORM_SUBTRACT));
}
if constexpr (withZeros && ZerosKind == 0)
{
asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*zeros, *zeros)));
}
if constexpr (withScaling)
{
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*scale, *scale)), "r"(0));
}
if constexpr (withZeros && ZerosKind == 1)
{
asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*zeros, *zeros)));
}
}
}

template <typename T1, typename T2>
__device__ void decode_i1s_to_f16(T1 *_i1s, T2 *B_local_decode, const int N = 8)
{
uint *h = reinterpret_cast<uint *>(B_local_decode);

static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint BOTTOM_MASK = 0x00010001;
static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
static constexpr uint MEDIAN_NUM = 0x64006400;
static constexpr uint TRANSFORM_SUBTRACT = 0xbc00bc00; // for signed int 2x - 1
decode_i1b_to_f16<T1, T2, true>(_i1s, B_local_decode, N);
}

int8_t const i1s_i16 = *reinterpret_cast<int8_t *>(_i1s);
int i1s = (i1s_i16 & 0x0f);
i1s |= ((i1s_i16 & 0xf0) << 12);
#pragma unroll
// decode 2 elems at one time.
for (int i = 0; i < (N / 2); i++)
{
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n"
: "=r"(h[i])
: "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM));
asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(h[i]));
asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(TRANSFORM_SUBTRACT));
}
template <typename T1, typename T2>
__device__ void decode_i1u_to_f16(T1 *_i1u, T2 *B_local_decode, const int N = 8)
{
decode_i1b_to_f16<T1, T2, false>(_i1u, B_local_decode, N);
}
"""

Expand Down
2 changes: 1 addition & 1 deletion bitblas/ops/lop3_permutate/lop3_permutate_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def interleave_weight_f16_1b(A: T.Buffer((N, QK), storage_dtype), B: T.Buffer((N
B_tmp_4[v0, v1] = ((B[v0, v1] & T.uint32(0x0000F000)) >> 12) << 24
B_tmp_5[v0, v1] = ((B[v0, v1] & T.uint32(0x000F0000)) >> 16) << 8
B_tmp_6[v0, v1] = ((B[v0, v1] & T.uint32(0x00F00000)) >> 20) << 12
B_tmp_7[v0, v1] = ((B[v0, v1] & T.uint32(0x00F00000)) >> 24) << 20
B_tmp_7[v0, v1] = ((B[v0, v1] & T.uint32(0x0F000000)) >> 24) << 20
B[v0, v1] = (
B_tmp_1[v0, v1]
| B_tmp_2[v0, v1]
Expand Down
21 changes: 21 additions & 0 deletions bitblas/version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import os

# Get the absolute path of the current Python script's directory
current_dir = os.path.dirname(os.path.abspath(__file__))

# Get the absolute path of the project root directory (one level above the current directory)
project_root_dir = os.path.abspath(os.path.join(current_dir, ".."))

# Define the path to the VERSION file located in the project root directory
version_file_path = os.path.join(project_root_dir, "VERSION")

# Read and store the version information from the VERSION file
# Use 'strip()' to remove any leading/trailing whitespace or newline characters
with open(version_file_path, "r") as version_file:
__version__ = version_file.read().strip()

# Define the public API for the module
__all__ = ["__version__"]
45 changes: 38 additions & 7 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,18 @@ def get_requirements() -> List[str]:
return requirements


def find_version(filepath: str) -> str:
def find_version(version_file_path: str) -> str:
"""Extract version information from the given filepath.

Adapted from https://github.com/ray-project/ray/blob/0b190ee1160eeca9796bc091e07eaebf4c85b511/python/setup.py
"""
with open(filepath) as fp:
version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", fp.read(), re.M)
if version_match:
return version_match.group(1)
raise RuntimeError("Unable to find version string.")
# Read and store the version information from the VERSION file
# Use 'strip()' to remove any leading/trailing whitespace or newline characters
if not os.path.exists(version_file_path):
raise FileNotFoundError(f"Version file not found at {version_file_path}")
with open(version_file_path, "r") as version_file:
version = version_file.read().strip()
return version


def get_nvcc_cuda_version():
Expand All @@ -65,7 +67,7 @@ def get_nvcc_cuda_version():


def get_bitblas_version(with_cuda=True, with_system_info=True) -> str:
version = find_version(get_path("bitblas", "__init__.py"))
version = find_version(get_path(".", "VERSION"))
local_version_parts = []
if with_system_info:
local_version_parts.append(get_system_info().replace("-", "."))
Expand Down Expand Up @@ -267,6 +269,35 @@ def run(self):
if not os.path.exists(target_dir):
os.makedirs(target_dir)
shutil.copy2(source_dir, target_dir)
# copy compoable kernel to the package directory
CK_PREBUILD_ITEMS = [
"3rdparty/composable_kernel",
]
for item in CK_PREBUILD_ITEMS:
source_dir = os.path.join(ROOT_DIR, item)
target_dir = os.path.join(self.build_lib, PACKAGE_NAME, item)
if os.path.isdir(source_dir):
self.mkpath(target_dir)
distutils.dir_util.copy_tree(source_dir, target_dir)
else:
target_dir = os.path.dirname(target_dir)
if not os.path.exists(target_dir):
os.makedirs(target_dir)
shutil.copy2(source_dir, target_dir)

# copy compoable kernel to the package directory
CONFIG_ITEMS = ["VERSION", "README.md", "LICENSE"]
for item in CONFIG_ITEMS:
source_dir = os.path.join(ROOT_DIR, item)
target_dir = os.path.join(self.build_lib, PACKAGE_NAME, item)
if os.path.isdir(source_dir):
self.mkpath(target_dir)
distutils.dir_util.copy_tree(source_dir, target_dir)
else:
target_dir = os.path.dirname(target_dir)
if not os.path.exists(target_dir):
os.makedirs(target_dir)
shutil.copy2(source_dir, target_dir)


class BitBLASSdistCommand(sdist):
Expand Down
Loading