  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+#include "tensorrt_llm/kernels/moe_utils.cuh"
 #include "tensorrt_llm/kernels/preQuantScaleKernel.h"
 
 namespace tensorrt_llm
@@ -41,7 +42,7 @@ struct Vec2Type<__nv_bfloat16>
 
 template <typename T_in, typename T_out, int kProcessRows, typename AccessType>
 __global__ void apply_per_channel_scale(T_out* smoothed_act, T_in const* act, T_in const* per_channel_scale, int rows,
-    int cols, int64_t const* num_valid_tokens_ptr)
+    int cols, int64_t const* num_valid_tokens_ptr, int64_t* expert_first_token_offset, int const num_experts_per_node)
 {
     static constexpr int kElems = sizeof(AccessType) / sizeof(T_in);
     T_in scale[kElems], act_vec[kElems];
@@ -53,11 +54,19 @@ __global__ void apply_per_channel_scale(T_out* smoothed_act, T_in const* act, T_
         return;
     act += row_offset * kProcessRows * cols;
     smoothed_act += row_offset * kProcessRows * cols;
-    *reinterpret_cast<AccessType*>(scale) = reinterpret_cast<AccessType const*>(per_channel_scale)[col_offset];
 #pragma unroll
     for (int i = 0; i < kProcessRows; ++i)
     {
         *reinterpret_cast<AccessType*>(act_vec) = reinterpret_cast<AccessType const*>(act + i * cols)[col_offset];
+        int expert = 0;
+        if (expert_first_token_offset != nullptr)
+        {
+            expert = findTotalEltsLessThanTarget(
+                         expert_first_token_offset, num_experts_per_node, (int64_t) row_offset * kProcessRows + i + 1)
+                - 1;
+        }
+        *reinterpret_cast<AccessType*>(scale)
+            = reinterpret_cast<AccessType const*>(per_channel_scale)[expert * cols / kElems + col_offset];
         if constexpr ((std::is_same_v<T_in, half>
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
             || std::is_same_v<T_in, __nv_bfloat16>
@@ -98,13 +107,14 @@ __global__ void apply_per_channel_scale(T_out* smoothed_act, T_in const* act, T_
 
 template <typename T_in, typename T_out, int kProcessRows, typename AccessType = float4>
 void apply_per_channel_scale_kernel_launcher_(T_out* smoothed_act, T_in const* act, T_in const* per_channel_scale,
-    int rows, int cols, int64_t const* num_valid_tokens_ptr = nullptr, cudaStream_t stream = 0)
+    int rows, int cols, int64_t const* num_valid_tokens_ptr = nullptr, cudaStream_t stream = 0,
+    int64_t* expert_first_token_offset = nullptr, int const num_experts_per_node = 0)
 {
     static constexpr int kElems = sizeof(AccessType) / sizeof(T_in);
     dim3 block(128);
     dim3 grid((rows + kProcessRows - 1) / kProcessRows, (cols / kElems + block.x - 1) / block.x);
-    apply_per_channel_scale<T_in, T_out, kProcessRows, AccessType>
-        <<<grid, block, 0, stream>>>(smoothed_act, act, per_channel_scale, rows, cols, num_valid_tokens_ptr);
+    apply_per_channel_scale<T_in, T_out, kProcessRows, AccessType><<<grid, block, 0, stream>>>(smoothed_act, act,
+        per_channel_scale, rows, cols, num_valid_tokens_ptr, expert_first_token_offset, num_experts_per_node);
 }
 
 template <typename T_in, typename T_out>
@@ -134,6 +144,34 @@ void apply_per_channel_scale_kernel_launcher(T_out* smoothed_act, T_in const* ac
     }
 }
 
+template <typename T_in, typename T_out>
+void apply_per_channel_scale_per_expert_kernel_launcher(T_out* smoothed_act, T_in const* act,
+    T_in const* per_channel_scale, int rows, int cols, int64_t* expert_first_token_offset,
+    int const num_experts_per_node, int64_t const* num_valid_tokens_ptr, cudaStream_t stream)
+{
+    uint64_t elems = static_cast<uint64_t>(rows) * static_cast<uint64_t>(cols);
+    if (elems < 2048 * 2048)
+    {
+        apply_per_channel_scale_kernel_launcher_<T_in, T_out, 1, float4>(smoothed_act, act, per_channel_scale, rows,
+            cols, num_valid_tokens_ptr, stream, expert_first_token_offset, num_experts_per_node);
+    }
+    else if (elems < 4096 * 4096)
+    {
+        apply_per_channel_scale_kernel_launcher_<T_in, T_out, 4, float4>(smoothed_act, act, per_channel_scale, rows,
+            cols, num_valid_tokens_ptr, stream, expert_first_token_offset, num_experts_per_node);
+    }
+    else if (elems < 8192 * 8192)
+    {
+        apply_per_channel_scale_kernel_launcher_<T_in, T_out, 8, float4>(smoothed_act, act, per_channel_scale, rows,
+            cols, num_valid_tokens_ptr, stream, expert_first_token_offset, num_experts_per_node);
+    }
+    else
+    {
+        apply_per_channel_scale_kernel_launcher_<T_in, T_out, 16, float4>(smoothed_act, act, per_channel_scale, rows,
+            cols, num_valid_tokens_ptr, stream, expert_first_token_offset, num_experts_per_node);
+    }
+}
+
 #define INSTANTIATE_PREQUANT_SCALE(T_in, T_out) \
     template void apply_per_channel_scale_kernel_launcher<T_in, T_out>(T_out * smoothed_act, const T_in* act, \
         const T_in* per_channel_scale, int rows, int cols, int64_t const* num_valid_tokens_ptr, cudaStream_t stream)
@@ -150,5 +188,22 @@ INSTANTIATE_PREQUANT_SCALE(__nv_bfloat16, __nv_fp8_e4m3);
 #endif
 #endif
 
+#define INSTANTIATE_PREQUANT_SCALE_PER_EXPERT(T_in, T_out) \
+    template void apply_per_channel_scale_per_expert_kernel_launcher<T_in, T_out>(T_out * smoothed_act, \
+        const T_in* act, const T_in* per_channel_scale, int rows, int cols, int64_t* expert_first_token_offset, \
+        int const num_experts_per_node, int64_t const* num_valid_tokens_ptr, cudaStream_t stream)
+
+INSTANTIATE_PREQUANT_SCALE_PER_EXPERT(half, half);
+#if defined(ENABLE_FP8)
+INSTANTIATE_PREQUANT_SCALE_PER_EXPERT(half, __nv_fp8_e4m3);
+#endif
+
+#if defined(ENABLE_BF16)
+INSTANTIATE_PREQUANT_SCALE_PER_EXPERT(__nv_bfloat16, __nv_bfloat16);
+#if defined(ENABLE_FP8)
+INSTANTIATE_PREQUANT_SCALE_PER_EXPERT(__nv_bfloat16, __nv_fp8_e4m3);
+#endif
+#endif
+
 } // namespace kernels
 } // namespace tensorrt_llm
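
Note on the new per-expert scale lookup: each row of the activation buffer now selects the scale row of the expert that owns it, found by searching expert_first_token_offset with findTotalEltsLessThanTarget, and per_channel_scale is indexed as a [num_experts_per_node, cols] table (the AccessType row stride is cols / kElems). The snippet below is a minimal host-side sketch of that lookup, assuming expert_first_token_offset is the exclusive prefix sum of tokens per expert and that findTotalEltsLessThanTarget returns the number of sorted entries strictly less than the target (inferred from its name and the trailing "- 1"); count_less_than and the example offsets are hypothetical.

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Host-side stand-in for the assumed semantics of findTotalEltsLessThanTarget:
// the number of entries in a sorted array that are strictly less than target.
int64_t count_less_than(int64_t const* sorted, int n, int64_t target)
{
    return std::lower_bound(sorted, sorted + n, target) - sorted;
}

int main()
{
    // Hypothetical layout for 3 experts: expert 0 owns rows [0, 2),
    // expert 1 owns rows [2, 5), expert 2 owns rows [5, 6). The array is the
    // exclusive prefix sum of tokens per expert.
    int64_t expert_first_token_offset[] = {0, 2, 5, 6};
    int const num_experts_per_node = 3;

    for (int64_t row = 0; row < 6; ++row)
    {
        // Same expression shape as the kernel: search for (row + 1), then subtract 1,
        // so a row that starts an expert's range still maps to that expert.
        int expert = static_cast<int>(
            count_less_than(expert_first_token_offset, num_experts_per_node, row + 1) - 1);
        // The kernel would then read the scale vector at per_channel_scale[expert * cols + col].
        std::printf("row %lld -> expert %d\n", static_cast<long long>(row), expert);
    }
    return 0;
}

The new apply_per_channel_scale_per_expert_kernel_launcher only picks kProcessRows (1, 4, 8, or 16) from the total element count; the per-expert behavior itself lives entirely in the indexing shown above.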