Skip to content

Commit

Permalink
refactor math.hpp
Browse files Browse the repository at this point in the history
Signed-off-by: Sage Moore <[email protected]>
  • Loading branch information
SageMoore committed Dec 16, 2024
1 parent 58111a9 commit 9a18085
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 10 deletions.
20 changes: 12 additions & 8 deletions csrc/core/math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,18 @@ static inline constexpr auto div_ceil(A a, B b) {
return (a + b - 1) / b;
}

// Compute the next multiple of a that is greater than or equal to b
template <typename A, typename B>
static inline constexpr auto next_multiple_of(A a, B b) {
return div_ceil(b, a) * a;
// Round a down to the previous multiple of b, i.e. return the largest multiple
// of b that is less than or equal to a. Note the argument order: the value
// being rounded comes first, the multiple second. The caller is responsible
// for making sure that b is non-zero.
//
// Integer division truncates toward zero, so a plain (a / b) * b would round a
// negative a *up*. The sign of the remainder is inspected instead so that the
// result is <= a for every sign combination of a and b.
template <typename T>
inline constexpr T round_to_previous_multiple_of(T a, T b)
{
  const T rem = a % b;  // zero, or carries the sign of a
  if (rem == 0) return a;       // already a multiple of b
  if (rem > 0) return a - rem;  // positive a: dropping the remainder rounds down
  // Negative a: (a - rem) is the multiple just above a, so step down by |b|.
  return b > 0 ? (a - rem) - b : (a - rem) + b;
}

// Compute the largest multiple of a that is less than or equal to b
template <typename A, typename B>
static inline constexpr auto prev_multiple_of(A a, B b) {
return (b / a) * a;
// Round a up to the next multiple of b, i.e. return the smallest multiple of b
// that is greater than or equal to a. Note the argument order: the value being
// rounded comes first, the multiple second. The caller is responsible for
// making sure that b is non-zero.
//
// Integer division truncates toward zero, so ((a / b) + 1) * b overshoots for
// a negative a and can undershoot for a negative b. The sign of the remainder
// is inspected instead so that the result is >= a for every sign combination
// of a and b.
template <typename T>
inline constexpr T round_to_next_multiple_of(T a, T b)
{
  const T rem = a % b;  // zero, or carries the sign of a
  if (rem == 0) return a;  // already a multiple of b
  // Positive a: (a - rem) is the multiple just below a, so step up by |b|.
  if (rem > 0) return b > 0 ? (a - rem) + b : (a - rem) - b;
  // Negative a: dropping the (negative) remainder already rounds up.
  return a - rem;
}
4 changes: 2 additions & 2 deletions csrc/quantization/activation_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ __global__ void act_and_mul_quant_kernel(

const int32_t tgt_elems_per_block = div_ceil(d, blocks_per_token);
const int32_t elems_per_block =
next_multiple_of(elems_per_128bit_load, tgt_elems_per_block);
round_to_next_multiple_of(tgt_elems_per_block, elems_per_128bit_load);
const int32_t block_start = blockIdx.y * elems_per_block;
int32_t block_end = block_start + elems_per_block;
block_end = block_end > d ? d : block_end;
Expand All @@ -47,7 +47,7 @@ __global__ void act_and_mul_quant_kernel(

// 128-bit vectorized code
const int32_t vec_loop_end =
prev_multiple_of(elems_per_128bit_load, block_end);
round_to_previous_multiple_of(elems_per_128bit_load, block_end);
const int32_t vec_end_idx = vec_loop_end / elems_per_128bit_load;
const int32_t vec_start_idx = block_start / elems_per_128bit_load;

Expand Down

0 comments on commit 9a18085

Please sign in to comment.