From b577710246156696631ce43b8991695ec3f44bef Mon Sep 17 00:00:00 2001 From: Atream <80757050+Atream@users.noreply.github.com> Date: Wed, 4 Dec 2024 17:36:12 +0800 Subject: [PATCH] bugfix: fix smem_size in FusedAddRMSNorm which was missed in PR #636 (#646) Fix smem_size in FusedAddRMSNorm which was missed in #636 Fix issue #645 --- include/flashinfer/norm.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/flashinfer/norm.cuh b/include/flashinfer/norm.cuh index 6872261c..ee807ab0 100644 --- a/include/flashinfer/norm.cuh +++ b/include/flashinfer/norm.cuh @@ -207,7 +207,7 @@ cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_siz const uint32_t num_warps = ceil_div(block_size, 32); dim3 nblks(batch_size); dim3 nthrs(32, num_warps); - const uint32_t smem_size = (num_warps + d) * sizeof(float); + const uint32_t smem_size = (ceil_div(num_warps, 4) * 4 + d) * sizeof(float); float weight_bias = 0.f; void* args[] = {&input, &residual, &weight, &d, &weight_bias, &eps};