diff --git a/include/flashinfer/norm.cuh b/include/flashinfer/norm.cuh
index 04c5b686..6872261c 100644
--- a/include/flashinfer/norm.cuh
+++ b/include/flashinfer/norm.cuh
@@ -127,6 +127,7 @@ __global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ res
   const uint32_t num_threads = num_warps * warp_size;
   const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads);
   extern __shared__ float smem[];
+  float* smem_x = smem + ceil_div(num_warps, 4) * 4;

   float sum_sq = 0.f;

@@ -151,7 +152,7 @@ __global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ res
     }
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
       residual_vec.store(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-      x_vec.store(smem + num_warps + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      x_vec.store(smem_x + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
     }
   }

@@ -185,7 +186,7 @@ __global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ res
     x_vec.fill(0.f);
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
       weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-      x_vec.load(smem + num_warps + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      x_vec.load(smem_x + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
     }
 #pragma unroll
     for (uint32_t j = 0; j < VEC_SIZE; j++) {
@@ -247,7 +248,8 @@ cudaError_t GemmaFusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batc
   const uint32_t num_warps = ceil_div(block_size, 32);
   dim3 nblks(batch_size);
   dim3 nthrs(32, num_warps);
-  const uint32_t smem_size = (num_warps + d) * sizeof(float);
+  // NOTE(Zihao): use ceil_div(num_warps, 4) * 4 for address alignment to 16 bytes
+  const uint32_t smem_size = (ceil_div(num_warps, 4) * 4 + d) * sizeof(float);
   float weight_bias = 1.f;
   void* args[] = {&input, &residual, &weight, &d, &weight_bias, &eps};

diff --git a/tests/test_norm.py b/tests/test_norm.py
index e60127a0..8827f5c8 100644
--- a/tests/test_norm.py
+++ b/tests/test_norm.py
@@ -65,7 +65,7 @@ def fused_add_rms_norm(x, residual, weight, eps):


 @pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
-@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 4096, 8192])
+@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192])
 @pytest.mark.parametrize("dtype", [torch.float16])
 @pytest.mark.parametrize("specify_out", [True, False])
 def test_norm(batch_size, hidden_size, dtype, specify_out):
@@ -83,7 +83,7 @@ def test_norm(batch_size, hidden_size, dtype, specify_out):


 @pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
-@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 4096, 8192])
+@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192])
 @pytest.mark.parametrize("dtype", [torch.float16])
 def test_fused_add_rmsnorm(batch_size, hidden_size, dtype):
     eps = 1e-6
@@ -105,7 +105,7 @@ def test_fused_add_rmsnorm(batch_size, hidden_size, dtype):


 @pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
-@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 4096, 8192])
+@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192])
 @pytest.mark.parametrize("dtype", [torch.float16])
 @pytest.mark.parametrize("specify_out", [True, False])
 def test_gemma_norm(batch_size, hidden_size, dtype, specify_out):
@@ -123,7 +123,7 @@ def test_gemma_norm(batch_size, hidden_size, dtype, specify_out):


 @pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
-@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 4096, 8192])
+@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192])
 @pytest.mark.parametrize("dtype", [torch.float16])
 def test_gemma_fused_add_rmsnorm(batch_size, hidden_size, dtype):
     eps = 1e-6
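
Note on the alignment change (standalone sketch, not part of the patch): the kernel stores vec_t<float, VEC_SIZE> vectors into dynamic shared memory right after the per-warp reduction buffer. With the old offset of num_warps floats, that region is 16-byte aligned only when num_warps is a multiple of 4; rounding the offset up with ceil_div(num_warps, 4) * 4 restores the alignment. The host-side sketch below illustrates the arithmetic, assuming a launch with num_warps = 14, which is one plausible configuration for the newly added hidden_size = 3584 test case; ceil_div is redefined locally so the snippet compiles on its own.

// Standalone sketch, not part of the patch: shows why the smem_x offset is
// rounded up to a multiple of 4 floats (16 bytes). num_warps = 14 is an
// assumed example value.
#include <cstdint>
#include <cstdio>

constexpr uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

int main() {
  const uint32_t num_warps = 14;  // e.g. a block of 448 threads (14 warps)

  // Old layout: the vectorized region started num_warps floats into smem.
  // 14 floats = 56 bytes, which is not a multiple of 16, so 16-byte vector
  // stores through smem + num_warps would be misaligned.
  const uint32_t old_offset_bytes = num_warps * sizeof(float);

  // New layout: pad the reduction buffer to a multiple of 4 floats, so the
  // smem_x region is 16-byte aligned whenever smem itself is.
  const uint32_t new_offset_bytes = ceil_div(num_warps, 4) * 4 * sizeof(float);

  std::printf("old offset: %u bytes, 16-byte aligned: %s\n", old_offset_bytes,
              old_offset_bytes % 16 == 0 ? "yes" : "no");
  std::printf("new offset: %u bytes, 16-byte aligned: %s\n", new_offset_bytes,
              new_offset_bytes % 16 == 0 ? "yes" : "no");
  return 0;
}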