diff --git a/csrc/core/math.hpp b/csrc/core/math.hpp
index bd5241c5703fc..6e79c94e52518 100644
--- a/csrc/core/math.hpp
+++ b/csrc/core/math.hpp
@@ -5,14 +5,18 @@ static inline constexpr auto div_ceil(A a, B b) {
   return (a + b - 1) / b;
 }
 
-// Compute the next multiple of a that is greater than or equal to b
-template <typename A, typename B>
-static inline constexpr auto next_multiple_of(A a, B b) {
-  return div_ceil(b, a) * a;
+// Round a down to the previous multiple of b (a itself if already a multiple).
+// Argument order is (value, multiple). The caller must ensure b is non-zero
+template <typename T>
+inline constexpr T round_to_previous_multiple_of(T a, T b)
+{
+  return a % b == 0 ? a : (a / b) * b;
 }
 
-// Compute the largest multiple of a that is less than or equal to b
-template <typename A, typename B>
-static inline constexpr auto prev_multiple_of(A a, B b) {
-  return (b / a) * a;
+// Round a up to the next multiple of b (a itself if already a multiple).
+// Argument order is (value, multiple). The caller must ensure b is non-zero
+template <typename T>
+inline constexpr T round_to_next_multiple_of(T a, T b)
+{
+  return a % b == 0 ? a : ((a / b) + 1) * b;
 }
\ No newline at end of file
diff --git a/csrc/quantization/activation_kernels.cu b/csrc/quantization/activation_kernels.cu
index 91aa966ce7739..024331fa4e64e 100644
--- a/csrc/quantization/activation_kernels.cu
+++ b/csrc/quantization/activation_kernels.cu
@@ -36,7 +36,7 @@ __global__ void act_and_mul_quant_kernel(
 
   const int32_t tgt_elems_per_block = div_ceil(d, blocks_per_token);
   const int32_t elems_per_block =
-      next_multiple_of(elems_per_128bit_load, tgt_elems_per_block);
+      round_to_next_multiple_of(tgt_elems_per_block, elems_per_128bit_load);
   const int32_t block_start = blockIdx.y * elems_per_block;
   int32_t block_end = block_start + elems_per_block;
   block_end = block_end > d ? d : block_end;
@@ -47,7 +47,7 @@ __global__ void act_and_mul_quant_kernel(
 
   // 128-bit vectorized code
   const int32_t vec_loop_end =
-      prev_multiple_of(elems_per_128bit_load, block_end);
+      round_to_previous_multiple_of(block_end, elems_per_128bit_load);
   const int32_t vec_end_idx = vec_loop_end / elems_per_128bit_load;
   const int32_t vec_start_idx = block_start / elems_per_128bit_load;
 