
[BUG] Vectorized Bias Add with AtomicAdd may lead to unknown bugs #271

Closed
LeiWang1999 opened this issue Dec 16, 2024 · 3 comments

@LeiWang1999 (Contributor)

  #pragma unroll
  for (int i_10 = 0; i_10 < 4; ++i_10) {
    __syncthreads();
    uint4 __1;
      uint4 v_ = *(uint4*)(((half_t*)buf_dyn_shmem) + (((i_10 * 1024) + (((int)threadIdx.x) * 8)) + 3072));
      uint4 v__1 = *(uint4*)(Bias + (((((int)blockIdx.x) * 64) + ((((int)threadIdx.x) >> 5) * 16)) + ((((int)threadIdx.x) & 1) * 8)));
      ((half2*)(&(__1.x)))->x = (((half2*)(&(v_.x)))->x+((half2*)(&(v__1.x)))->x);
      ((half2*)(&(__1.x)))->y = (((half2*)(&(v_.x)))->y+((half2*)(&(v__1.x)))->y);
      ((half2*)(&(__1.y)))->x = (((half2*)(&(v_.y)))->x+((half2*)(&(v__1.y)))->x);
      ((half2*)(&(__1.y)))->y = (((half2*)(&(v_.y)))->y+((half2*)(&(v__1.y)))->y);
      ((half2*)(&(__1.z)))->x = (((half2*)(&(v_.z)))->x+((half2*)(&(v__1.z)))->x);
      ((half2*)(&(__1.z)))->y = (((half2*)(&(v_.z)))->y+((half2*)(&(v__1.z)))->y);
      ((half2*)(&(__1.w)))->x = (((half2*)(&(v_.w)))->x+((half2*)(&(v__1.w)))->x);
      ((half2*)(&(__1.w)))->y = (((half2*)(&(v_.w)))->y+((half2*)(&(v__1.w)))->y);
    *(uint4*)(((half_t*)buf_dyn_shmem) + (((i_10 * 1024) + (((int)threadIdx.x) * 8)) + 3072)) = __1;
  }
  __syncthreads();
  #pragma unroll
  for (int i_11 = 0; i_11 < 16; ++i_11) {
    atomicAddx2((&(C[(((((((int)blockIdx.y) * 65536) + (i_11 * 4096)) + ((((int)threadIdx.x) >> 5) * 1024)) + (((int)blockIdx.x) * 64)) + ((((int)threadIdx.x) & 31) * 2))])), (&(((half_t*)buf_dyn_shmem)[(((((((i_11 >> 2) * 1024) + (((((int)threadIdx.x) & 31) >> 3) * 256)) + ((i_11 & 3) * 64)) + ((((int)threadIdx.x) >> 5) * 16)) + ((((int)threadIdx.x) & 7) * 2)) + 3072)])));
  }

The generated kernel above, which does a vectorized bias add in shared memory followed by an atomicAdd into C, has correctness issues; the version generated without atomicAdd is correct:

  for (int i_14 = 0; i_14 < 4; ++i_14) {
    __syncthreads();
    uint4 __1;
      uint4 v_ = *(uint4*)(((half_t*)buf_dyn_shmem) + (((((i_14 * 1024) + (((((int)threadIdx.x) & 7) >> 1) * 256)) + ((((int)threadIdx.x) >> 3) * 16)) + ((((int)threadIdx.x) & 1) * 8)) + 3072));
      uint4 v__1 = *(uint4*)(Bias + ((((int)blockIdx.x) * 64) + ((((int)threadIdx.x) & 7) * 8)));
      ((half2*)(&(__1.x)))->x = (((half2*)(&(v_.x)))->x+((half2*)(&(v__1.x)))->x);
      ((half2*)(&(__1.x)))->y = (((half2*)(&(v_.x)))->y+((half2*)(&(v__1.x)))->y);
      ((half2*)(&(__1.y)))->x = (((half2*)(&(v_.y)))->x+((half2*)(&(v__1.y)))->x);
      ((half2*)(&(__1.y)))->y = (((half2*)(&(v_.y)))->y+((half2*)(&(v__1.y)))->y);
      ((half2*)(&(__1.z)))->x = (((half2*)(&(v_.z)))->x+((half2*)(&(v__1.z)))->x);
      ((half2*)(&(__1.z)))->y = (((half2*)(&(v_.z)))->y+((half2*)(&(v__1.z)))->y);
      ((half2*)(&(__1.w)))->x = (((half2*)(&(v_.w)))->x+((half2*)(&(v__1.w)))->x);
      ((half2*)(&(__1.w)))->y = (((half2*)(&(v_.w)))->y+((half2*)(&(v__1.w)))->y);
    *(uint4*)(((half_t*)buf_dyn_shmem) + (((((i_14 * 1024) + (((((int)threadIdx.x) & 7) >> 1) * 256)) + ((((int)threadIdx.x) >> 3) * 16)) + ((((int)threadIdx.x) & 1) * 8)) + 3072)) = __1;
  }
  __syncthreads();
  #pragma unroll
  for (int i_15 = 0; i_15 < 4; ++i_15) {
    *(uint4*)(C + (((((((int)blockIdx.y) * 65536) + (i_15 * 16384)) + ((((int)threadIdx.x) >> 3) * 1024)) + (((int)blockIdx.x) * 64)) + ((((int)threadIdx.x) & 7) * 8))) = *(uint4*)(((half_t*)buf_dyn_shmem) + (((((i_15 * 1024) + (((((int)threadIdx.x) & 7) >> 1) * 256)) + ((((int)threadIdx.x) >> 3) * 16)) + ((((int)threadIdx.x) & 1) * 8)) + 3072));
  }

Currently we disable atomicAdd when a bias is present, to work around this situation.
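For illustration only, here is a minimal sketch of why the bias ends up duplicated. It is a simplified float kernel, not the generated BitBLAS code (which works on half2 via atomicAddx2); the kernel name, the partial buffer, and the launch layout are hypothetical. With split-k, gridDim.z blocks each accumulate their partial result into C with atomicAdd, and if every one of them also adds the bias, C receives the bias split_k_factor times.

// Hypothetical sketch: each k-split block adds the bias before its atomicAdd,
// so C[row, col] ends up with bias[col] * gridDim.z instead of bias[col] once.
__global__ void splitk_bias_bug(const float* __restrict__ partial,  // [SPLIT_K, M, N] partial GEMM results
                                const float* __restrict__ bias,     // [N]
                                float* __restrict__ C,              // [M, N], zero-initialized
                                int M, int N) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= M || col >= N) return;

    float v = partial[(blockIdx.z * M + row) * N + col];

    // BUG: every z-block (blockIdx.z = 0 .. gridDim.z - 1) adds the bias.
    v += bias[col];

    atomicAdd(&C[row * N + col], v);
}

The workaround described above simply avoids this code path: whenever a bias is present, the split-k/atomicAdd epilogue is not used.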

Reproduce:

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import bitblas
import bitblas.testing
from bitblas import Linear as BitBLASLinear
import torch
import time
import numpy as np
import torch.nn as nn

torch.manual_seed(0)
bitblas.set_log_level("DEBUG")


def correctness_consistent(m, in_features, out_features, bias):
    linear_torch = (nn.Linear(in_features, out_features, bias=bias).to(torch.float16).cuda())
    linear_bitblas = BitBLASLinear(
        in_features,
        out_features,
        bias=bias,
        A_dtype="float16",
        W_dtype="float16",
        accum_dtype="float16",
        out_dtype="float16",
        opt_M=m,
    ).cuda()

    with torch.no_grad():
        linear_bitblas.load_and_transform_weight(linear_torch.weight.clone())
        if bias:
            linear_bitblas.bias = nn.Parameter(linear_torch.bias.clone())

    with torch.no_grad():
        if not isinstance(m, int):
            # When m is a list, average m
            m = sum(m) // len(m)
        input_data = torch.randn(m, in_features, dtype=torch.float16).cuda()
        output_torch = linear_torch(input_data)
        output_bitblas = linear_bitblas(input_data)
    print(output_torch)
    print(output_bitblas)
    bitblas.testing.torch_assert_close(output_torch, output_bitblas, rtol=1e-1, atol=1e-2)


def test_correctness_consistent():
    correctness_consistent(1, 1024, 1024, False)
    correctness_consistent(1, 1024, 1024, True)
    correctness_consistent(1024, 1024, 1024, True)
    correctness_consistent([1, 1024], 1024, 1024, True)


def correctness_weight_only_dequantize(
    m,
    in_features,
    out_features,
    bias,
    W_dtype,
    group_size,
    with_scaling,
    with_zeros,
    zeros_mode,
):
    import numpy as np
    from bitblas.quantization.utils import general_compress
    from bitblas.cache import global_operator_cache

    global_operator_cache.clear()
    linear_bitblas = BitBLASLinear(
        in_features,
        out_features,
        bias=bias,
        A_dtype="float16",
        W_dtype=W_dtype,
        accum_dtype="float16",
        out_dtype="float16",
        group_size=group_size,
        with_scaling=with_scaling,
        with_zeros=with_zeros,
        opt_M=m,
    ).cuda()
    if not isinstance(m, int):
        # average m
        m = sum(m) // len(m)
    input_shape = (m, in_features)
    weight_shape = (out_features, in_features)
    output_shape = (m, out_features)
    inputs = []
    inputs.append(torch.rand(input_shape, dtype=torch.float16).cuda() - 0.5)
    source_format, bit = (
        linear_bitblas.bitblas_matmul.source_format,
        linear_bitblas.bitblas_matmul.bit,
    )

    maxq = 2**(bit - 1)
    zeros = maxq
    if source_format == "uint":
        inputs.append(torch.randint(0, maxq, weight_shape, dtype=torch.int8).cuda())
    elif source_format == "int":
        inputs.append(torch.randint(-maxq, maxq, weight_shape, dtype=torch.int8).cuda())
    else:
        raise NotImplementedError

    inputs.append(torch.rand(output_shape, dtype=torch.float16).cuda())

    intweight = inputs[1]
    intweight = intweight.cpu().to(torch.int8)
    if source_format == "int":
        intweight = intweight + maxq
    if with_zeros:
        inputs[1] = inputs[1] - zeros
    bias_tensor = torch.rand((output_shape[-1],), dtype=torch.float16).cuda()
    ref_result = torch.matmul(inputs[0], (inputs[1].t()).to(torch.float16))
    if bias:
        ref_result = ref_result + bias_tensor

    with torch.no_grad():
        permuted_inputs = []
        permuted_inputs.append(inputs[0])
        if linear_bitblas.bitblas_matmul.weight_transform is not None:
            permuted_inputs.append(
                linear_bitblas.bitblas_matmul.weight_transform(intweight.cpu()).cuda())
        else:
            permuted_inputs.append(inputs[1])
        linear_bitblas.qweight.data = permuted_inputs[-1].clone()
        if with_scaling:
            if group_size == -1:
                group_size = in_features
            permuted_inputs.append(
                torch.ones([out_features, in_features // group_size], dtype=torch.float16).cuda())
            linear_bitblas.scales.data = permuted_inputs[-1].clone()
        if with_zeros:
            if zeros_mode == "original":
                permuted_inputs.append(
                    torch.ones([out_features, in_features // group_size],
                               dtype=torch.float16).cuda() * zeros)
            elif zeros_mode == "rescale":
                original_zeros = (
                    torch.ones([out_features, in_features // group_size],
                               dtype=torch.float16).cuda() * zeros)
                scaled_zeros = original_zeros * permuted_inputs[-1]
                permuted_inputs.append(scaled_zeros)
            elif zeros_mode == "quantized":
                original_zeros = (
                    torch.ones([in_features // group_size, out_features], dtype=torch.int8).cuda() *
                    zeros)
                qzeros = general_compress(
                    original_zeros.cpu().numpy(), source_bits=bit, storage_dtype=np.int8)
                permuted_inputs.append(torch.from_numpy(qzeros).cuda())
            else:
                raise NotImplementedError
            linear_bitblas.zeros.data = permuted_inputs[-1].clone()
        if bias:
            permuted_inputs.append(bias_tensor)
            linear_bitblas.bias.data = bias_tensor.clone()

    with torch.no_grad():
        output_bitblas = linear_bitblas(inputs[0])

    rtol = 1e0
    atol = 1e0
    if zeros_mode == "original":
        rtol = 1e2
        atol = 1e2
    print(output_bitblas)
    print(ref_result)
    torch.testing.assert_close(output_bitblas, ref_result, rtol=rtol, atol=atol)


def test_correctness_weight_only_dequantize():
    correctness_weight_only_dequantize(1, 1024, 1024, False, "uint4", -1, False, False, None)
    correctness_weight_only_dequantize(1, 1024, 1024, False, "uint4", -1, False, False, None)
    correctness_weight_only_dequantize(1024, 1024, 1024, True, "uint4", -1, False, False, None)
    correctness_weight_only_dequantize(1, 1024, 1024, True, "uint2", -1, True, False, None)
    correctness_weight_only_dequantize(1, 1024, 1024, True, "uint2", 128, True, True, "original")
    correctness_weight_only_dequantize(1024, 1024, 1024, True, "uint2", 128, True, True, "original")
    correctness_weight_only_dequantize(1, 1024, 1024, True, "uint2", 128, True, True, "rescale")


def profile(model, input_data):
    model = model.cuda()
    model.eval()

    def get_runtime(num_repeats=1):
        tic = time.time()
        for _ in range(num_repeats):
            _ = model(input_data)
        torch.cuda.synchronize()
        return (time.time() - tic) * 1000 / num_repeats

    with torch.no_grad():
        # print("Warming up ...")
        st = time.time()
        while time.time() - st < 1.0:
            get_runtime()  # warmup
        warmup_runtime = get_runtime()
        num_repeats = max(1, int(1000 / warmup_runtime))
        times = get_runtime(num_repeats)
    return np.mean(times)


if __name__ == "__main__":
    # bitblas.testing.main()
    correctness_weight_only_dequantize(1024, 1024, 1024, True, "uint4", -1, False, False, None)
@LeiWang1999 (Contributor, Author)

A cleaner script to help reproduce:

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from bitblas import tvm as tvm
import bitblas.testing
from tvm import tl
from bitblas.ops.general_matmul.tilelang.dense.matmul_tensorcore import (
    MatmulBlockScheduler,
    MatmulFineGrainScheduler,
    MatmulWeightPropagationScheduler,
)

from bitblas.ops.general_matmul.tilelang.dequantize import (
    MatmulDequantizeScheduler,
    MatmulDequantizeFineGrainedScheduler,
    MatmulDequantizeWeightPropagationScheduler,
    MatmulINT4DequantizeFineGrainedScheduler,
    MatmulINT4DequantizeWeightPropagationScheduler,
)

from bitblas.ops.general_matmul.tilelang.dense.matmul_tensorcore import (
    MatmulINT4FineGrainScheduler,
    MatmulINT4WeightPropagationScheduler,
)

import torch
import torch.backends

torch.manual_seed(0)

verbose = False


def assert_matmul_fine_grained_dequant_with_default_correctness(
    M,
    N,
    K,
    trans_A=False,
    trans_B=True,
    in_dtype="float16",
    out_dtype="float16",
    accum_dtype="float16",
    bit=4,
    storage_dtype="int8",
    source_format="uint",
    with_scaling=False,
    with_zeros=False,
    group_size=-1,
    fast_decoding=False,
    zeros_mode="original",
    with_bias=False,
    split_k_factor=1,
):
    import numpy as np
    from bitblas.quantization import general_compress, interleave_weight

    matmul = MatmulDequantizeFineGrainedScheduler(
        M=M,
        N=N,
        K=K,
        trans_A=trans_A,
        trans_B=trans_B,
        in_dtype=in_dtype,
        out_dtype=out_dtype,
        accum_dtype=accum_dtype,
        num_bits=bit,
        storage_dtype=storage_dtype,
        source_format=source_format,
        with_scaling=with_scaling,
        with_zeros=with_zeros,
        group_size=group_size,
        fast_decoding=fast_decoding,
        zeros_mode=zeros_mode,
        with_bias=with_bias,
    ).apply_config(
        block_row_warps=2,
        block_col_warps=2,
        warp_row_tiles=32,
        warp_col_tiles=32,
        chunk=32,
        num_stages=0,
        enable_rasterization=False,
        split_k_factor=split_k_factor,
    )

    mod, params = tl.lower(matmul)
    src_code = mod.imported_modules[0].get_source()
    # src_code is the generated cuda source
    assert src_code is not None
    input_shape = (M, K)
    weight_shape = (N, K)
    output_shape = (M, N)
    bias_shape = (N, )
    inputs = []
    inputs.append(torch.rand(input_shape, dtype=torch.float16).cuda() - 0.5)
    maxq = 2**(bit - 1)
    zeros = maxq
    if source_format == "uint":
        inputs.append(torch.randint(0, maxq, weight_shape, dtype=torch.int8).cuda())
    elif source_format == "int":
        inputs.append(torch.randint(-maxq, maxq, weight_shape, dtype=torch.int8).cuda())
    else:
        raise NotImplementedError
    bias = torch.ones(bias_shape, dtype=torch.float16).cuda()

    inputs.append(torch.rand(output_shape, dtype=torch.float16).cuda())

    intweight = inputs[1]
    intweight = intweight.cpu().to(torch.int8)
    if source_format == "int":
        intweight = intweight + maxq
    if with_zeros:
        inputs[1] = inputs[1] - zeros

    permuted_inputs = []
    permuted_inputs.append(inputs[0])
    qw = general_compress(intweight.cpu().numpy(), source_bits=bit, storage_dtype=np.int8)
    # lop3 transformation
    if fast_decoding:
        qw = interleave_weight(qw, bit, target_dtype=in_dtype)
    permuted_inputs.append(torch.from_numpy(qw).cuda())
    if with_scaling:
        if group_size == -1:
            group_size = K
        permuted_inputs.append(torch.ones([N, K // group_size], dtype=torch.float16).cuda())
    if with_zeros:
        if zeros_mode == "original":
            permuted_inputs.append(torch.randn((N, K // group_size), dtype=torch.float16).cuda())
        elif zeros_mode == "rescale":
            original_zeros = (torch.ones([N, K // group_size], dtype=torch.float16).cuda() * zeros)
            scaled_zeros = original_zeros * permuted_inputs[-1]
            permuted_inputs.append(scaled_zeros)
        elif zeros_mode == "quantized":
            original_zeros = (torch.ones([K // group_size, N], dtype=torch.int8).cuda() * zeros)
            qzeros = general_compress(
                original_zeros.cpu().numpy(), source_bits=bit, storage_dtype=np.int8)
            permuted_inputs.append(torch.from_numpy(qzeros).cuda())
        else:
            raise NotImplementedError
    
    if with_bias:
        permuted_inputs.append(bias)

    permuted_inputs.append(inputs[2])

    mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer)

    mod(*permuted_inputs)

    print(permuted_inputs[-1])

    args = [inputs[0]]
    b = inputs[1]
    if with_scaling:
        scale = permuted_inputs[2]
        rescale_b = torch.empty_like(b, dtype=torch.float16)
        for i in range(N):
            for j in range(K):
                if with_zeros:
                    if zeros_mode == "original":
                        rescale_b[i, j] = (b[i, j] - zeros) * scale[i, j // group_size]
                    elif zeros_mode == "rescale":
                        rescale_b[i, j] = (b[i, j] * scale[i, j // group_size] + zeros)
                    else:
                        raise NotImplementedError
                else:
                    rescale_b[i, j] = b[i, j] * scale[i, j // group_size]
        args.append(rescale_b.t().cuda())
    else:
        args.append(b.t().cuda().to(torch.float16))

    ref_result = torch.matmul(*args)
    if with_bias:
        ref_result = ref_result + bias
    
    print(ref_result)
    bitblas.testing.torch_assert_close(permuted_inputs[-1], ref_result, rtol=1e-1, atol=1e-1)


def test_matmul_fine_grained_dequant_with_default():
    assert_matmul_fine_grained_dequant_with_default_correctness(
        1024, 1024, 1024, source_format="uint", bit=4)
    assert_matmul_fine_grained_dequant_with_default_correctness(
        1024, 1024, 1024, source_format="uint", bit=2)
    assert_matmul_fine_grained_dequant_with_default_correctness(
        1024, 1024, 1024, source_format="uint", bit=4, with_scaling=True)
    assert_matmul_fine_grained_dequant_with_default_correctness(
        1024,
        1024,
        1024,
        source_format="uint",
        bit=4,
        with_scaling=True,
        with_zeros=True,
    )
    assert_matmul_fine_grained_dequant_with_default_correctness(
        1024, 1024, 1024, source_format="uint", bit=4, fast_decoding=True)
    assert_matmul_fine_grained_dequant_with_default_correctness(
        1024,
        1024,
        1024,
        source_format="uint",
        bit=4,
        with_scaling=True,
        fast_decoding=True,
    )
    assert_matmul_fine_grained_dequant_with_default_correctness(
        1024,
        1024,
        1024,
        source_format="uint",
        bit=4,
        with_scaling=True,
        with_zeros=True,
        fast_decoding=True,
    )



if __name__ == "__main__":
    # non-splitk + non-bias
    assert_matmul_fine_grained_dequant_with_default_correctness(
        128, 128, 128, source_format="int", bit=4, with_bias=False, split_k_factor=1)
    # non-splitk + bias
    assert_matmul_fine_grained_dequant_with_default_correctness(
        128, 128, 128, source_format="int", bit=4, with_bias=True, split_k_factor=1)
    # atomicAdd + non-bias
    assert_matmul_fine_grained_dequant_with_default_correctness(
        128, 128, 128, source_format="int", bit=4, with_bias=False, split_k_factor=2)
    # atomicAdd + bias
    assert_matmul_fine_grained_dequant_with_default_correctness(
        1024, 1024, 1024, source_format="uint", bit=4, with_bias=True, split_k_factor=2)

@LeiWang1999 (Contributor, Author)

Interesting bug, and it is now resolved: in the split-k implementation, blockIdx.z indexes the k dimension, so the bias add must be performed by only one block along the z dimension; otherwise the bias is added multiple times (once per k split).
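A minimal sketch of that fix, under the same simplifying assumptions as the earlier sketch (plain float kernel, hypothetical names, not the actual change in the PR): the bias is contributed only by the blocks with blockIdx.z == 0, so after all k-split partials are atomically accumulated, C contains the bias exactly once.

__global__ void splitk_bias_fixed(const float* __restrict__ partial,  // [SPLIT_K, M, N] partial GEMM results
                                  const float* __restrict__ bias,     // [N]
                                  float* __restrict__ C,              // [M, N], zero-initialized
                                  int M, int N) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= M || col >= N) return;

    float v = partial[(blockIdx.z * M + row) * N + col];

    // Only one z-block adds the bias; the others contribute only their
    // partial GEMM results.
    if (blockIdx.z == 0) {
        v += bias[col];
    }

    atomicAdd(&C[row * N + col], v);
}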

@LeiWang1999 (Contributor, Author)

Closed as resolved in PR #270.
