Commit

bugfix: fix the misaligned address bug of norm kernels for certain shapes (#636)

This PR fixes issue #634, which was introduced by #592.
To use 16-byte vectorized reads/writes, we must ensure the address is
aligned to 16 bytes. When `num_warps` is not a multiple of 4
(4 * sizeof(float) = 16), the address `smem + num_warps` may not be
16-byte aligned.

We fix this by shifting the start offset of the vectorized reads/writes to
`smem + ceil_div(num_warps, 4) * 4`, which forces the alignment.
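
To make the arithmetic concrete, here is a small host-side sketch (not part of the patch, values are illustrative) that checks 16-byte alignment of the old and new offsets; `ceil_div` mirrors the helper used in the kernel, and it assumes, as the fix itself does, that the dynamic shared memory base is 16-byte aligned:

```cpp
#include <cstdint>
#include <cstdio>
#include <initializer_list>

// Mirrors the ceil_div helper used in the kernel below.
constexpr uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

int main() {
  // The dynamic shared memory base (extern __shared__ float smem[]) is assumed
  // 16-byte aligned, so whether smem + offset is 16-byte aligned depends only
  // on the offset. Warp counts below are illustrative; 14 is an example of a
  // count that is not a multiple of 4.
  for (uint32_t num_warps : {4u, 8u, 14u, 16u}) {
    unsigned old_bytes = num_warps * sizeof(float);                   // offset before the fix
    unsigned new_bytes = ceil_div(num_warps, 4) * 4 * sizeof(float);  // padded offset after the fix
    std::printf("num_warps=%2u  smem + num_warps -> %3u bytes (%s)  padded -> %3u bytes (aligned)\n",
                num_warps, old_bytes,
                old_bytes % 16 == 0 ? "aligned" : "misaligned", new_bytes);
  }
  return 0;
}
```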

cc @ovowei @Abatom
yzh119 authored Nov 25, 2024
1 parent ae501ed commit db9c48d
Showing 2 changed files with 9 additions and 7 deletions.
8 changes: 5 additions & 3 deletions include/flashinfer/norm.cuh
@@ -127,6 +127,7 @@ __global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ res
const uint32_t num_threads = num_warps * warp_size;
const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads);
extern __shared__ float smem[];
+ float* smem_x = smem + ceil_div(num_warps, 4) * 4;

float sum_sq = 0.f;

@@ -151,7 +152,7 @@ __global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ res
}
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
residual_vec.store(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
- x_vec.store(smem + num_warps + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+ x_vec.store(smem_x + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
}
}

@@ -185,7 +186,7 @@ __global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ res
x_vec.fill(0.f);
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
- x_vec.load(smem + num_warps + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+ x_vec.load(smem_x + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
}
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; j++) {
@@ -247,7 +248,8 @@ cudaError_t GemmaFusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batc
const uint32_t num_warps = ceil_div(block_size, 32);
dim3 nblks(batch_size);
dim3 nthrs(32, num_warps);
- const uint32_t smem_size = (num_warps + d) * sizeof(float);
+ // NOTE(Zihao): use ceil_div(num_warps, 4) * 4 for address alignment to 16 bytes
+ const uint32_t smem_size = (ceil_div(num_warps, 4) * 4 + d) * sizeof(float);
float weight_bias = 1.f;
void* args[] = {&input, &residual, &weight, &d, &weight_bias, &eps};

8 changes: 4 additions & 4 deletions tests/test_norm.py
@@ -65,7 +65,7 @@ def fused_add_rms_norm(x, residual, weight, eps):


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
- @pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 4096, 8192])
+ @pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192])
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("specify_out", [True, False])
def test_norm(batch_size, hidden_size, dtype, specify_out):
@@ -83,7 +83,7 @@ def test_norm(batch_size, hidden_size, dtype, specify_out):


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
- @pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 4096, 8192])
+ @pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_fused_add_rmsnorm(batch_size, hidden_size, dtype):
eps = 1e-6
@@ -105,7 +105,7 @@ def test_fused_add_rmsnorm(batch_size, hidden_size, dtype):


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
- @pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 4096, 8192])
+ @pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192])
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("specify_out", [True, False])
def test_gemma_norm(batch_size, hidden_size, dtype, specify_out):
@@ -123,7 +123,7 @@ def test_gemma_norm(batch_size, hidden_size, dtype, specify_out):


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
- @pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 4096, 8192])
+ @pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_gemma_fused_add_rmsnorm(batch_size, hidden_size, dtype):
eps = 1e-6
