[BUG] Vectorized Bias Add with AtomicAdd may lead to unknown bugs #271
Vectorized bias add combined with atomicAdd (the split-k path) has correctness issues, while without atomicAdd the result is correct. Currently we disable atomicAdd when bias is enabled to skip this situation.

Reproduce:

Comments
More clean script to help reproduce:

```python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from bitblas import tvm as tvm
import bitblas.testing
from tvm import tl
from bitblas.ops.general_matmul.tilelang.dense.matmul_tensorcore import (
    MatmulBlockScheduler,
    MatmulFineGrainScheduler,
    MatmulWeightPropagationScheduler,
)
from bitblas.ops.general_matmul.tilelang.dequantize import (
    MatmulDequantizeScheduler,
    MatmulDequantizeFineGrainedScheduler,
    MatmulDequantizeWeightPropagationScheduler,
    MatmulINT4DequantizeFineGrainedScheduler,
    MatmulINT4DequantizeWeightPropagationScheduler,
)
from bitblas.ops.general_matmul.tilelang.dense.matmul_tensorcore import (
    MatmulINT4FineGrainScheduler,
    MatmulINT4WeightPropagationScheduler,
)
import torch
import torch.backends

torch.manual_seed(0)

verbose = False


def assert_matmul_fine_grained_dequant_with_default_correctness(
    M,
    N,
    K,
    trans_A=False,
    trans_B=True,
    in_dtype="float16",
    out_dtype="float16",
    accum_dtype="float16",
    bit=4,
    storage_dtype="int8",
    source_format="uint",
    with_scaling=False,
    with_zeros=False,
    group_size=-1,
    fast_decoding=False,
    zeros_mode="original",
    with_bias=False,
    split_k_factor=1,
):
    import numpy as np
    from bitblas.quantization import general_compress, interleave_weight

    # Build the fine-grained dequantize matmul scheduler and apply a fixed tile configuration.
    matmul = MatmulDequantizeFineGrainedScheduler(
        M=M,
        N=N,
        K=K,
        trans_A=trans_A,
        trans_B=trans_B,
        in_dtype=in_dtype,
        out_dtype=out_dtype,
        accum_dtype=accum_dtype,
        num_bits=bit,
        storage_dtype=storage_dtype,
        source_format=source_format,
        with_scaling=with_scaling,
        with_zeros=with_zeros,
        group_size=group_size,
        fast_decoding=fast_decoding,
        zeros_mode=zeros_mode,
        with_bias=with_bias,
    ).apply_config(
        block_row_warps=2,
        block_col_warps=2,
        warp_row_tiles=32,
        warp_col_tiles=32,
        chunk=32,
        num_stages=0,
        enable_rasterization=False,
        split_k_factor=split_k_factor,
    )

    mod, params = tl.lower(matmul)
    src_code = mod.imported_modules[0].get_source()
    # src_code is the generated cuda source
    assert src_code is not None

    input_shape = (M, K)
    weight_shape = (N, K)
    output_shape = (M, N)
    bias_shape = (N,)

    # Prepare the random activation, quantized weight, bias and output buffers.
    inputs = []
    inputs.append(torch.rand(input_shape, dtype=torch.float16).cuda() - 0.5)
    maxq = 2**(bit - 1)
    zeros = maxq
    if source_format == "uint":
        inputs.append(torch.randint(0, maxq, weight_shape, dtype=torch.int8).cuda())
    elif source_format == "int":
        inputs.append(torch.randint(-maxq, maxq, weight_shape, dtype=torch.int8).cuda())
    else:
        raise NotImplementedError

    bias = torch.ones(bias_shape, dtype=torch.float16).cuda()

    inputs.append(torch.rand(output_shape, dtype=torch.float16).cuda())

    intweight = inputs[1]
    intweight = intweight.cpu().to(torch.int8)
    if source_format == "int":
        intweight = intweight + maxq
    if with_zeros:
        inputs[1] = inputs[1] - zeros

    permuted_inputs = []
    permuted_inputs.append(inputs[0])
    qw = general_compress(intweight.cpu().numpy(), source_bits=bit, storage_dtype=np.int8)
    # lop3 transformation
    if fast_decoding:
        qw = interleave_weight(qw, bit, target_dtype=in_dtype)
    permuted_inputs.append(torch.from_numpy(qw).cuda())
    if with_scaling:
        if group_size == -1:
            group_size = K
        permuted_inputs.append(torch.ones([N, K // group_size], dtype=torch.float16).cuda())
    if with_zeros:
        if zeros_mode == "original":
            permuted_inputs.append(torch.randn((N, K // group_size), dtype=torch.float16).cuda())
        elif zeros_mode == "rescale":
            original_zeros = (torch.ones([N, K // group_size], dtype=torch.float16).cuda() * zeros)
            scaled_zeros = original_zeros * permuted_inputs[-1]
            permuted_inputs.append(scaled_zeros)
        elif zeros_mode == "quantized":
            original_zeros = (torch.ones([K // group_size, N], dtype=torch.int8).cuda() * zeros)
            qzeros = general_compress(
                original_zeros.cpu().numpy(), source_bits=bit, storage_dtype=np.int8)
            permuted_inputs.append(torch.from_numpy(qzeros).cuda())
        else:
            raise NotImplementedError
    if with_bias:
        permuted_inputs.append(bias)
    permuted_inputs.append(inputs[2])

    # Run the lowered kernel; the last entry of permuted_inputs is the output buffer.
    mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer)

    mod(*permuted_inputs)

    print(permuted_inputs[-1])

    # Compute the reference result with torch.matmul on the dequantized weight.
    args = [inputs[0]]
    b = inputs[1]
    if with_scaling:
        scale = permuted_inputs[2]
        rescale_b = torch.empty_like(b, dtype=torch.float16)
        for i in range(N):
            for j in range(K):
                if with_zeros:
                    if zeros_mode == "original":
                        rescale_b[i, j] = (b[i, j] - zeros) * scale[i, j // group_size]
                    elif zeros_mode == "rescale":
                        rescale_b[i, j] = (b[i, j] * scale[i, j // group_size] + zeros)
                    else:
                        raise NotImplementedError
                else:
                    rescale_b[i, j] = b[i, j] * scale[i, j // group_size]
        args.append(rescale_b.t().cuda())
    else:
        args.append(b.t().cuda().to(torch.float16))

    ref_result = torch.matmul(*args)
    if with_bias:
        ref_result = ref_result + bias
    print(ref_result)
    bitblas.testing.torch_assert_close(permuted_inputs[-1], ref_result, rtol=1e-1, atol=1e-1)


def test_matmul_fine_grained_dequant_with_default():
    assert_matmul_fine_grained_dequant_with_default_correctness(
        1024, 1024, 1024, source_format="uint", bit=4)
    assert_matmul_fine_grained_dequant_with_default_correctness(
        1024, 1024, 1024, source_format="uint", bit=2)
    assert_matmul_fine_grained_dequant_with_default_correctness(
        1024, 1024, 1024, source_format="uint", bit=4, with_scaling=True)
    assert_matmul_fine_grained_dequant_with_default_correctness(
        1024,
        1024,
        1024,
        source_format="uint",
        bit=4,
        with_scaling=True,
        with_zeros=True,
    )
    assert_matmul_fine_grained_dequant_with_default_correctness(
        1024, 1024, 1024, source_format="uint", bit=4, fast_decoding=True)
    assert_matmul_fine_grained_dequant_with_default_correctness(
        1024,
        1024,
        1024,
        source_format="uint",
        bit=4,
        with_scaling=True,
        fast_decoding=True,
    )
    assert_matmul_fine_grained_dequant_with_default_correctness(
        1024,
        1024,
        1024,
        source_format="uint",
        bit=4,
        with_scaling=True,
        with_zeros=True,
        fast_decoding=True,
    )


if __name__ == "__main__":
    # non-splitk + non-bias
    assert_matmul_fine_grained_dequant_with_default_correctness(
        128, 128, 128, source_format="int", bit=4, with_bias=False, split_k_factor=1)
    # non-splitk + bias
    assert_matmul_fine_grained_dequant_with_default_correctness(
        128, 128, 128, source_format="int", bit=4, with_bias=True, split_k_factor=1)
    # atomicAdd + non-bias
    assert_matmul_fine_grained_dequant_with_default_correctness(
        128, 128, 128, source_format="int", bit=4, with_bias=False, split_k_factor=2)
    # atomicAdd + bias
    assert_matmul_fine_grained_dequant_with_default_correctness(
        1024, 1024, 1024, source_format="uint", bit=4, with_bias=True, split_k_factor=2)
```
Interesting bug, and now resolved: blockIdx.z represents the k-dimension in the split-k implementation, so the bias add must be done in only one block along blockIdx.z; otherwise the bias is added multiple times.
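To make the failure mode concrete, here is a minimal sketch in plain PyTorch (not BitBLAS code; the tensor names and the loop emulating split-k blocks are illustrative assumptions). It emulates a split-k kernel that accumulates partial sums into the output as atomicAdd would: if every k-split block also adds the bias, the output ends up with the bias added split_k_factor times, whereas restricting the bias add to a single block along the k (blockIdx.z) dimension matches the reference.

```python
import torch

torch.manual_seed(0)

M, N, K, split_k_factor = 4, 4, 8, 2
A = torch.randn(M, K)
B = torch.randn(K, N)
bias = torch.ones(N)

# Reference: bias is applied exactly once, after the full reduction over K.
ref = A @ B + bias

# Buggy emulation: every k-split block fuses the bias into its partial result
# before accumulating into C, so the bias is added split_k_factor times.
C_buggy = torch.zeros(M, N)
for kz in range(split_k_factor):  # kz plays the role of blockIdx.z
    ks = slice(kz * K // split_k_factor, (kz + 1) * K // split_k_factor)
    C_buggy += A[:, ks] @ B[ks, :] + bias

# Fixed emulation: only the block with kz == 0 contributes the bias.
C_fixed = torch.zeros(M, N)
for kz in range(split_k_factor):
    ks = slice(kz * K // split_k_factor, (kz + 1) * K // split_k_factor)
    C_fixed += A[:, ks] @ B[ks, :] + (bias if kz == 0 else 0)

print(torch.allclose(C_buggy, ref))             # False: off by (split_k_factor - 1) * bias
print(torch.allclose(C_fixed, ref, atol=1e-5))  # True
```

This mirrors the fix described above: in the generated split-k kernel, the bias add has to execute for only a single blockIdx.z value.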
Closed as resolved in PR #270.