bugfix: fix smem_size in FusedAddRMSNorm which is missed in PR #636 (#646)

Fix smem_size in FusedAddRMSNorm, which was missed in #636.
Fix issue #645
Atream authored Dec 4, 2024
1 parent 6819a0f commit b577710
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion include/flashinfer/norm.cuh
@@ -207,7 +207,7 @@ cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_siz
   const uint32_t num_warps = ceil_div(block_size, 32);
   dim3 nblks(batch_size);
   dim3 nthrs(32, num_warps);
-  const uint32_t smem_size = (num_warps + d) * sizeof(float);
+  const uint32_t smem_size = (ceil_div(num_warps, 4) * 4 + d) * sizeof(float);
   float weight_bias = 0.f;
   void* args[] = {&input, &residual, &weight, &d, &weight_bias, &eps};
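
The arithmetic behind the one-line change can be sketched on the host. This is an illustrative sketch, not the kernel's code: it assumes the kernel places a `d`-element float buffer after the per-warp scratch floats, at an offset rounded up to a multiple of 4 floats (16 bytes). That layout is not visible in this diff. `ceil_div` below is a stand-in for the helper the diff calls, and `smem_size_old`, `smem_size_new`, `x_offset_floats`, and `check_layout` are hypothetical names introduced only for this example.

```cpp
#include <cassert>
#include <cstdint>

// Stand-in for the ceil_div helper called in the diff (assumed semantics:
// integer division rounded up).
constexpr uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

// Shared-memory size in bytes as computed before the fix.
constexpr uint32_t smem_size_old(uint32_t num_warps, uint32_t d) {
  return static_cast<uint32_t>((num_warps + d) * sizeof(float));
}

// Shared-memory size in bytes as computed after the fix: the per-warp
// region is padded up to a multiple of 4 floats.
constexpr uint32_t smem_size_new(uint32_t num_warps, uint32_t d) {
  return static_cast<uint32_t>((ceil_div(num_warps, 4) * 4 + d) * sizeof(float));
}

// ASSUMPTION: offset (in floats) at which the d-element buffer starts,
// if the kernel pads the per-warp region the same way.
constexpr uint32_t x_offset_floats(uint32_t num_warps) {
  return ceil_div(num_warps, 4) * 4;
}

// Checks that, under the assumed layout, the buffer start is 16-byte
// aligned and the new size covers the whole buffer.
inline void check_layout(uint32_t num_warps, uint32_t d) {
  const uint32_t offset_bytes =
      static_cast<uint32_t>(x_offset_floats(num_warps) * sizeof(float));
  assert(offset_bytes % 16 == 0);
  assert(smem_size_new(num_warps, d) >= offset_bytes + d * sizeof(float));
}

// With 3 warps and d = 128, the old formula requests 4 bytes fewer than
// the padded layout needs; with 4 warps the two formulas agree.
static_assert(smem_size_old(3, 128) == 524, "old size, 3 warps");
static_assert(smem_size_new(3, 128) == 528, "new size, 3 warps");
static_assert(smem_size_old(4, 128) == smem_size_new(4, 128), "no padding needed");
```

Under this assumption the two formulas differ only when `num_warps` is not a multiple of 4, which would explain why the undersized allocation could go unnoticed for many values of `d`.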
