namespace transformer_engine {

- template <typename DType, typename IndexType>
- __global__ void fused_aux_loss_forward_kernel(
- DType* probs,
- IndexType* tokens_per_expert,
- int num_tokens,
- int num_experts,
- int topk,
- float coeff,
- DType* aux_loss,
- float * Const_buf
- ){
- int warp_num = blockDim.x / kThreadsPerWarp;
- int warp_id = threadIdx.x / kThreadsPerWarp;
- int lane_id = threadIdx.x % kThreadsPerWarp;
- extern __shared__ DType aggregated_probs_per_expert[];
- // Clear the shmem
- for (int i = threadIdx.x; i < num_experts; i += blockDim.x) {
- aggregated_probs_per_expert[i] = 0;
- }
- __syncthreads();
-
- /**
+ template <typename DType, typename IndexType>
+ __global__ void fused_aux_loss_forward_kernel(DType* probs, IndexType* tokens_per_expert,
+ int num_tokens, int num_experts, int topk,
+ float coeff, DType* aux_loss, float* Const_buf) {
+ int warp_num = blockDim.x / kThreadsPerWarp;
+ int warp_id = threadIdx.x / kThreadsPerWarp;
+ int lane_id = threadIdx.x % kThreadsPerWarp;
+ extern __shared__ DType aggregated_probs_per_expert[];
+ // Clear the shmem
+ for (int i = threadIdx.x; i < num_experts; i += blockDim.x) {
+ aggregated_probs_per_expert[i] = 0;
+ }
+ __syncthreads();
+
+ /**
* Section: Reduce the probs to the aggregated_probs_per_expert
*/
- // Loop: for all positions in each row
- for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
- DType tmp = 0;
- // Loop: for all rows that this warp is responsible for
- for (int j = warp_id; j < num_tokens; j += warp_num) {
- tmp += probs[j * num_experts + i];
- }
- atomicAdd(&aggregated_probs_per_expert[i], tmp);
+ // Loop: for all positions in each row
+ for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
+ DType tmp = 0;
+ // Loop: for all rows that this warp is responsible for
+ for (int j = warp_id; j < num_tokens; j += warp_num) {
+ tmp += probs[j * num_experts + i];
}
- __syncthreads();
+ atomicAdd(&aggregated_probs_per_expert[i], tmp);
+ }
+ __syncthreads();

- /**
+ /**
* Section: aggregated_probs_per_expert * tokens_per_expert
* In-place update on shmem
*/
- for (int i = threadIdx.x; i < num_experts; i += blockDim.x) {
- aggregated_probs_per_expert[i] *= tokens_per_expert[i];
- }
- __syncthreads();
+ for (int i = threadIdx.x; i < num_experts; i += blockDim.x) {
+ aggregated_probs_per_expert[i] *= tokens_per_expert[i];
+ }
+ __syncthreads();

- if (warp_id == 0) {
- /**
+ if (warp_id == 0) {
+ /**
* Section: Reduce to get the sum of aggregated_probs_per_expert
*/
- DType intermediate_result = warp_reduce_on_shmem(aggregated_probs_per_expert, num_experts, sum, lane_id);
- __syncwarp();
+ DType intermediate_result =
+ warp_reduce_on_shmem(aggregated_probs_per_expert, num_experts, sum, lane_id);
+ __syncwarp();

- if (lane_id == 0) {
- /**
+ if (lane_id == 0) {
+ /**
* Section: Compute the aux_loss
*/
- float C_coeff = (num_experts * coeff) / topk / num_tokens / num_tokens;
- aux_loss[0] = intermediate_result * C_coeff;
- Const_buf[0] = C_coeff;
- }
+ float C_coeff = (num_experts * coeff) / topk / num_tokens / num_tokens;
+ aux_loss[0] = intermediate_result * C_coeff;
+ Const_buf[0] = C_coeff;
}
+ }
}

- template <typename DType, typename IndexType>
- void fused_aux_loss_forward_kernel_launcher(
- DType* probs,
- IndexType* tokens_per_expert,
- int num_tokens,
- int num_experts,
- int topk,
- float coeff,
- DType* aux_loss,
- float * Const_buf,
- cudaStream_t stream
- ){
- // Meta data for the kernel
- size_t shared_memory_size = sizeof(DType) * num_experts * 2;
- // Use Only 1 block/1024 threads to avoid the grid sync
- int grid_size = 1;
- int block_size = 1024;
- fused_aux_loss_forward_kernel<DType, IndexType><<<grid_size, block_size, shared_memory_size, stream>>>(
- probs,
- tokens_per_expert,
- num_tokens,
- num_experts,
- topk,
- coeff,
- aux_loss,
- Const_buf
- );
+ template <typename DType, typename IndexType>
+ void fused_aux_loss_forward_kernel_launcher(DType* probs, IndexType* tokens_per_expert,
+ int num_tokens, int num_experts, int topk, float coeff,
+ DType* aux_loss, float* Const_buf,
+ cudaStream_t stream) {
+ // Meta data for the kernel
+ size_t shared_memory_size = sizeof(DType) * num_experts * 2;
+ // Use Only 1 block/1024 threads to avoid the grid sync
+ int grid_size = 1;
+ int block_size = 1024;
+ fused_aux_loss_forward_kernel<DType, IndexType>
+ <<<grid_size, block_size, shared_memory_size, stream>>>(
+ probs, tokens_per_expert, num_tokens, num_experts, topk, coeff, aux_loss, Const_buf);
}

- void fused_aux_loss_forward(
- Tensor * probs,
- Tensor * tokens_per_expert,
- int num_tokens,
- int num_experts,
- int topk,
- float coeff,
- Tensor * aux_loss,
- Tensor * Const_buf,
- cudaStream_t stream
- ){
- TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
- probs->data.dtype, DType,
- TRANSFORMER_ENGINE_TYPE_SWITCH_INDEX(
- tokens_per_expert->data.dtype, IndexType,
- fused_aux_loss_forward_kernel_launcher<DType, IndexType>(
- reinterpret_cast<DType *>(probs->data.dptr),
- reinterpret_cast<IndexType *>(tokens_per_expert->data.dptr),
- num_tokens,
- num_experts,
- topk,
- coeff,
- reinterpret_cast<DType *>(aux_loss->data.dptr),
- reinterpret_cast<float *>(Const_buf->data.dptr),
- stream
- );
- );
- );
+ void fused_aux_loss_forward(Tensor* probs, Tensor* tokens_per_expert, int num_tokens,
+ int num_experts, int topk, float coeff, Tensor* aux_loss,
+ Tensor* Const_buf, cudaStream_t stream) {
+ TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
+ probs->data.dtype, DType,
+ TRANSFORMER_ENGINE_TYPE_SWITCH_INDEX(
+ tokens_per_expert->data.dtype, IndexType,
+ fused_aux_loss_forward_kernel_launcher<DType, IndexType>(
+ reinterpret_cast<DType*>(probs->data.dptr),
+ reinterpret_cast<IndexType*>(tokens_per_expert->data.dptr), num_tokens, num_experts,
+ topk, coeff, reinterpret_cast<DType*>(aux_loss->data.dptr),
+ reinterpret_cast<float*>(Const_buf->data.dptr), stream);););
}

- template <typename DType, typename IndexType>
- __global__ void fused_aux_loss_backward_kernel(
- float * Const_buf,
- IndexType* tokens_per_expert,
- int num_tokens,
- int num_experts,
- DType* grad_aux_loss,
- DType* grad_probs
- ){
- int global_warp_num = gridDim.x * blockDim.x / kThreadsPerWarp;
- int global_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / kThreadsPerWarp;
- int lane_id = threadIdx.x % kThreadsPerWarp;
-
- // Loop: for all positions in each row
- for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
- DType C_coeff = Const_buf[0];
- IndexType tokens_per_expert_i = tokens_per_expert[i];
- DType grad_aux_loss_value = grad_aux_loss[0];
- // Loop: for all rows
- for (int j = global_warp_id; j < num_tokens; j += global_warp_num) {
- grad_probs[j * num_experts + i] = C_coeff * tokens_per_expert_i * grad_aux_loss_value;
- }
+ template <typename DType, typename IndexType>
+ __global__ void fused_aux_loss_backward_kernel(float* Const_buf, IndexType* tokens_per_expert,
+ int num_tokens, int num_experts,
+ DType* grad_aux_loss, DType* grad_probs) {
+ int global_warp_num = gridDim.x * blockDim.x / kThreadsPerWarp;
+ int global_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / kThreadsPerWarp;
+ int lane_id = threadIdx.x % kThreadsPerWarp;
+
+ // Loop: for all positions in each row
+ for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
+ DType C_coeff = Const_buf[0];
+ IndexType tokens_per_expert_i = tokens_per_expert[i];
+ DType grad_aux_loss_value = grad_aux_loss[0];
+ // Loop: for all rows
+ for (int j = global_warp_id; j < num_tokens; j += global_warp_num) {
+ grad_probs[j * num_experts + i] = C_coeff * tokens_per_expert_i * grad_aux_loss_value;
}
+ }
}

- template <typename DType, typename IndexType>
- void fused_aux_loss_backward_kernel_launcher(
- float * Const_buf,
- IndexType* tokens_per_expert,
- int num_tokens,
- int num_experts,
- DType* grad_aux_loss,
- DType* grad_probs,
- cudaStream_t stream
- ){
- // Meta data for the kernel
- int block_size = 256;
- int grid_size = (num_tokens + block_size - 1) / block_size;
- fused_aux_loss_backward_kernel<DType, IndexType><<<grid_size, block_size, 0, stream>>>(
- Const_buf,
- tokens_per_expert,
- num_tokens,
- num_experts,
- grad_aux_loss,
- grad_probs
- );
+ template <typename DType, typename IndexType>
+ void fused_aux_loss_backward_kernel_launcher(float* Const_buf, IndexType* tokens_per_expert,
+ int num_tokens, int num_experts, DType* grad_aux_loss,
+ DType* grad_probs, cudaStream_t stream) {
+ // Meta data for the kernel
+ int block_size = 256;
+ int grid_size = (num_tokens + block_size - 1) / block_size;
+ fused_aux_loss_backward_kernel<DType, IndexType><<<grid_size, block_size, 0, stream>>>(
+ Const_buf, tokens_per_expert, num_tokens, num_experts, grad_aux_loss, grad_probs);
}

- void fused_aux_loss_backward(
- Tensor * Const_buf,
- Tensor * tokens_per_expert,
- int num_tokens,
- int num_experts,
- Tensor * grad_aux_loss,
- Tensor * grad_probs,
- cudaStream_t stream
- ){
- TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
- grad_aux_loss->data.dtype, DType,
- TRANSFORMER_ENGINE_TYPE_SWITCH_INDEX(
- tokens_per_expert->data.dtype, IndexType,
- fused_aux_loss_backward_kernel_launcher<DType, IndexType>(
- reinterpret_cast<float *>(Const_buf->data.dptr),
- reinterpret_cast<IndexType *>(tokens_per_expert->data.dptr),
- num_tokens,
- num_experts,
- reinterpret_cast<DType *>(grad_aux_loss->data.dptr),
- reinterpret_cast<DType *>(grad_probs->data.dptr),
- stream
- );
- );
- );
+ void fused_aux_loss_backward(Tensor* Const_buf, Tensor* tokens_per_expert, int num_tokens,
+ int num_experts, Tensor* grad_aux_loss, Tensor* grad_probs,
+ cudaStream_t stream) {
+ TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
+ grad_aux_loss->data.dtype, DType,
+ TRANSFORMER_ENGINE_TYPE_SWITCH_INDEX(
+ tokens_per_expert->data.dtype, IndexType,
+ fused_aux_loss_backward_kernel_launcher<DType, IndexType>(
+ reinterpret_cast<float*>(Const_buf->data.dptr),
+ reinterpret_cast<IndexType*>(tokens_per_expert->data.dptr), num_tokens, num_experts,
+ reinterpret_cast<DType*>(grad_aux_loss->data.dptr),
+ reinterpret_cast<DType*>(grad_probs->data.dptr), stream);););
}

- } // namespace transformer_engine
-
- void nvte_fused_aux_loss_forward(const NVTETensor probs, const NVTETensor tokens_per_expert, int num_tokens, int num_experts, int topk, float coeff, NVTETensor aux_loss, NVTETensor Const_buf, cudaStream_t stream){
- NVTE_API_CALL(nvte_fused_aux_loss_forward);
- using namespace transformer_engine;
- fused_aux_loss_forward(
- convertNVTETensorCheck(probs),
- convertNVTETensorCheck(tokens_per_expert),
- num_tokens,
- num_experts,
- topk,
- coeff,
- convertNVTETensorCheck(aux_loss),
- convertNVTETensorCheck(Const_buf),
- stream
- );
+ }  // namespace transformer_engine
+
+ void nvte_fused_aux_loss_forward(const NVTETensor probs, const NVTETensor tokens_per_expert,
+ int num_tokens, int num_experts, int topk, float coeff,
+ NVTETensor aux_loss, NVTETensor Const_buf, cudaStream_t stream) {
+ NVTE_API_CALL(nvte_fused_aux_loss_forward);
+ using namespace transformer_engine;
+ fused_aux_loss_forward(convertNVTETensorCheck(probs), convertNVTETensorCheck(tokens_per_expert),
+ num_tokens, num_experts, topk, coeff, convertNVTETensorCheck(aux_loss),
+ convertNVTETensorCheck(Const_buf), stream);
}

- void nvte_fused_aux_loss_backward(const NVTETensor Const_buf, const NVTETensor tokens_per_expert, int num_tokens, int num_experts, NVTETensor grad_aux_loss, NVTETensor grad_probs, cudaStream_t stream){
- NVTE_API_CALL(nvte_fused_aux_loss_backward);
- using namespace transformer_engine;
- fused_aux_loss_backward(
- convertNVTETensorCheck(Const_buf),
- convertNVTETensorCheck(tokens_per_expert),
- num_tokens,
- num_experts,
- convertNVTETensorCheck(grad_aux_loss),
- convertNVTETensorCheck(grad_probs),
- stream
- );
- }
+ void nvte_fused_aux_loss_backward(const NVTETensor Const_buf, const NVTETensor tokens_per_expert,
+ int num_tokens, int num_experts, NVTETensor grad_aux_loss,
+ NVTETensor grad_probs, cudaStream_t stream) {
+ NVTE_API_CALL(nvte_fused_aux_loss_backward);
+ using namespace transformer_engine;
+ fused_aux_loss_backward(convertNVTETensorCheck(Const_buf),
+ convertNVTETensorCheck(tokens_per_expert), num_tokens, num_experts,
+ convertNVTETensorCheck(grad_aux_loss), convertNVTETensorCheck(grad_probs),
+ stream);
+ }
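For readers following the math rather than the kernel mechanics: the forward kernel computes aux_loss = C_coeff * sum_e((sum_t probs[t][e]) * tokens_per_expert[e]), with C_coeff = num_experts * coeff / (topk * num_tokens^2), and stores C_coeff in Const_buf so the backward kernel can reuse it; the backward kernel then writes grad_probs[t][e] = C_coeff * tokens_per_expert[e] * grad_aux_loss for every position. The sketch below is a minimal CPU reference of that arithmetic only, not part of this diff and not the Transformer Engine API; the helper names aux_loss_forward_ref / aux_loss_backward_ref and the plain float/int buffers are illustrative assumptions.

#include <vector>

// CPU reference for the fused MoE aux-loss forward math shown above.
// probs is row-major [num_tokens x num_experts]; returns the scalar loss and
// writes the reusable constant into *C_coeff_out, mirroring Const_buf[0].
float aux_loss_forward_ref(const std::vector<float>& probs,
                           const std::vector<int>& tokens_per_expert, int num_tokens,
                           int num_experts, int topk, float coeff, float* C_coeff_out) {
  float acc = 0.0f;
  for (int e = 0; e < num_experts; ++e) {
    float col_sum = 0.0f;  // aggregated_probs_per_expert[e]
    for (int t = 0; t < num_tokens; ++t) col_sum += probs[t * num_experts + e];
    acc += col_sum * tokens_per_expert[e];
  }
  float C_coeff = (num_experts * coeff) / topk / num_tokens / num_tokens;
  *C_coeff_out = C_coeff;
  return acc * C_coeff;
}

// CPU reference for the backward math: every entry in column e receives the
// same gradient C_coeff * tokens_per_expert[e] * grad_aux_loss.
void aux_loss_backward_ref(float C_coeff, const std::vector<int>& tokens_per_expert,
                           int num_tokens, int num_experts, float grad_aux_loss,
                           std::vector<float>* grad_probs) {
  for (int t = 0; t < num_tokens; ++t) {
    for (int e = 0; e < num_experts; ++e) {
      (*grad_probs)[t * num_experts + e] = C_coeff * tokens_per_expert[e] * grad_aux_loss;
    }
  }
}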