Commit 60d0142

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 2a542e2 commit 60d0142

9 files changed: +1151 -1545 lines changed

transformer_engine/common/fused_router/fused_aux_loss.cu

Lines changed: 123 additions & 197 deletions
The single hunk (@@ -8,229 +8,155 @@) is formatting-only: clang-format rewraps arguments and normalizes whitespace, with no functional change. The file after the fix:
```cuda
namespace transformer_engine {

template <typename DType, typename IndexType>
__global__ void fused_aux_loss_forward_kernel(DType* probs, IndexType* tokens_per_expert,
                                              int num_tokens, int num_experts, int topk,
                                              float coeff, DType* aux_loss, float* Const_buf) {
  int warp_num = blockDim.x / kThreadsPerWarp;
  int warp_id = threadIdx.x / kThreadsPerWarp;
  int lane_id = threadIdx.x % kThreadsPerWarp;
  extern __shared__ DType aggregated_probs_per_expert[];
  // Clear the shmem
  for (int i = threadIdx.x; i < num_experts; i += blockDim.x) {
    aggregated_probs_per_expert[i] = 0;
  }
  __syncthreads();

  /**
   * Section: Reduce the probs to the aggregated_probs_per_expert
   */
  // Loop: for all positions in each row
  for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
    DType tmp = 0;
    // Loop: for all rows that this warp is responsible for
    for (int j = warp_id; j < num_tokens; j += warp_num) {
      tmp += probs[j * num_experts + i];
    }
    atomicAdd(&aggregated_probs_per_expert[i], tmp);
  }
  __syncthreads();

  /**
   * Section: aggregated_probs_per_expert * tokens_per_expert
   * In-place update on shmem
   */
  for (int i = threadIdx.x; i < num_experts; i += blockDim.x) {
    aggregated_probs_per_expert[i] *= tokens_per_expert[i];
  }
  __syncthreads();

  if (warp_id == 0) {
    /**
     * Section: Reduce to get the sum of aggregated_probs_per_expert
     */
    DType intermediate_result =
        warp_reduce_on_shmem(aggregated_probs_per_expert, num_experts, sum, lane_id);
    __syncwarp();

    if (lane_id == 0) {
      /**
       * Section: Compute the aux_loss
       */
      float C_coeff = (num_experts * coeff) / topk / num_tokens / num_tokens;
      aux_loss[0] = intermediate_result * C_coeff;
      Const_buf[0] = C_coeff;
    }
  }
}
```
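Reading the math off the kernel: with $T$ = `num_tokens`, $E$ = `num_experts`, $p_{j,i}$ = `probs[j * num_experts + i]`, and $c_i$ = `tokens_per_expert[i]`, the forward pass evaluates

$$
\mathrm{aux\_loss} \;=\; C \sum_{i=1}^{E} c_i \sum_{j=1}^{T} p_{j,i},
\qquad
C \;=\; \frac{E \cdot \mathrm{coeff}}{\mathrm{topk} \cdot T^{2}},
$$

and stashes $C$ (`C_coeff`) in `Const_buf[0]` so the backward pass does not have to recompute it.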

```cuda
template <typename DType, typename IndexType>
void fused_aux_loss_forward_kernel_launcher(DType* probs, IndexType* tokens_per_expert,
                                            int num_tokens, int num_experts, int topk, float coeff,
                                            DType* aux_loss, float* Const_buf,
                                            cudaStream_t stream) {
  // Meta data for the kernel
  size_t shared_memory_size = sizeof(DType) * num_experts * 2;
  // Use Only 1 block/1024 threads to avoid the grid sync
  int grid_size = 1;
  int block_size = 1024;
  fused_aux_loss_forward_kernel<DType, IndexType>
      <<<grid_size, block_size, shared_memory_size, stream>>>(
          probs, tokens_per_expert, num_tokens, num_experts, topk, coeff, aux_loss, Const_buf);
}
```
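As a sanity check on the reduction, here is a minimal host-side sketch of the same forward math, the kind of reference a unit test might compare the kernel against. `aux_loss_forward_ref` is a hypothetical helper (not part of this file), assuming float probs and int32 token counts; the division order for `C_coeff` mirrors the kernel's:

```cpp
// Hypothetical CPU reference for the forward kernel's math (float/int32 only).
#include <cstdint>
#include <vector>

float aux_loss_forward_ref(const std::vector<float>& probs,  // [num_tokens, num_experts] row-major
                           const std::vector<int32_t>& tokens_per_expert,  // [num_experts]
                           int num_tokens, int num_experts, int topk, float coeff) {
  float acc = 0.f;
  for (int i = 0; i < num_experts; ++i) {
    // Column sum over tokens: the shared-memory aggregation step in the kernel.
    float col_sum = 0.f;
    for (int j = 0; j < num_tokens; ++j) col_sum += probs[j * num_experts + i];
    acc += col_sum * static_cast<float>(tokens_per_expert[i]);
  }
  const float C_coeff = (num_experts * coeff) / topk / num_tokens / num_tokens;
  return acc * C_coeff;
}
```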

```cuda
void fused_aux_loss_forward(Tensor* probs, Tensor* tokens_per_expert, int num_tokens,
                            int num_experts, int topk, float coeff, Tensor* aux_loss,
                            Tensor* Const_buf, cudaStream_t stream) {
  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
      probs->data.dtype, DType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_INDEX(
          tokens_per_expert->data.dtype, IndexType,
          fused_aux_loss_forward_kernel_launcher<DType, IndexType>(
              reinterpret_cast<DType*>(probs->data.dptr),
              reinterpret_cast<IndexType*>(tokens_per_expert->data.dptr), num_tokens, num_experts,
              topk, coeff, reinterpret_cast<DType*>(aux_loss->data.dptr),
              reinterpret_cast<float*>(Const_buf->data.dptr), stream);););
}
```
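The nested TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT / TRANSFORMER_ENGINE_TYPE_SWITCH_INDEX macros map the runtime dtypes onto compile-time template arguments, which is why the call ends in the stacked `);` closers. A self-contained sketch of that dispatch pattern, with illustrative names rather than Transformer Engine's actual macro definitions:

```cpp
// Illustrative dtype-dispatch pattern; names here are hypothetical, not TE's real macros.
#include <cstdio>
#include <stdexcept>

enum class DTypeTag { kFloat32, kFloat64 };

#define TYPE_SWITCH(tag, T, ...)                                        \
  switch (tag) {                                                        \
    case DTypeTag::kFloat32: { using T = float; __VA_ARGS__; } break;   \
    case DTypeTag::kFloat64: { using T = double; __VA_ARGS__; } break;  \
    default: throw std::runtime_error("unsupported dtype");             \
  }

template <typename T>
void kernel_launcher() { std::printf("sizeof(T) = %zu\n", sizeof(T)); }

int main() {
  DTypeTag runtime_tag = DTypeTag::kFloat32;
  // The body is instantiated once per case, with T bound to a concrete type.
  TYPE_SWITCH(runtime_tag, T, kernel_launcher<T>());  // prints: sizeof(T) = 4
  return 0;
}
```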

```cuda
template <typename DType, typename IndexType>
__global__ void fused_aux_loss_backward_kernel(float* Const_buf, IndexType* tokens_per_expert,
                                               int num_tokens, int num_experts,
                                               DType* grad_aux_loss, DType* grad_probs) {
  int global_warp_num = gridDim.x * blockDim.x / kThreadsPerWarp;
  int global_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / kThreadsPerWarp;
  int lane_id = threadIdx.x % kThreadsPerWarp;

  // Loop: for all positions in each row
  for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
    DType C_coeff = Const_buf[0];
    IndexType tokens_per_expert_i = tokens_per_expert[i];
    DType grad_aux_loss_value = grad_aux_loss[0];
    // Loop: for all rows
    for (int j = global_warp_id; j < num_tokens; j += global_warp_num) {
      grad_probs[j * num_experts + i] = C_coeff * tokens_per_expert_i * grad_aux_loss_value;
    }
  }
}
```
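Since the loss is linear in each probability, the backward pass follows directly from the forward formula:

$$
\frac{\partial\,\mathrm{aux\_loss}}{\partial p_{j,i}} \;=\; C\,c_i,
\qquad\text{so}\qquad
\mathrm{grad\_probs}[j,i] \;=\; C\,c_i \cdot \mathrm{grad\_aux\_loss},
$$

which is exactly the per-element write in the inner loop; the cached `Const_buf[0]` supplies $C$ without recomputation.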

```cuda
template <typename DType, typename IndexType>
void fused_aux_loss_backward_kernel_launcher(float* Const_buf, IndexType* tokens_per_expert,
                                             int num_tokens, int num_experts, DType* grad_aux_loss,
                                             DType* grad_probs, cudaStream_t stream) {
  // Meta data for the kernel
  int block_size = 256;
  int grid_size = (num_tokens + block_size - 1) / block_size;
  fused_aux_loss_backward_kernel<DType, IndexType><<<grid_size, block_size, 0, stream>>>(
      Const_buf, tokens_per_expert, num_tokens, num_experts, grad_aux_loss, grad_probs);
}
```
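Unlike the forward pass, which pins everything to one block, the backward launch scales the grid with the token count. A worked example of the launch arithmetic (the numbers are illustrative, not from the source):

```cpp
// Worked example of the backward launch math; values are illustrative.
#include <cstdio>

int main() {
  const int kThreadsPerWarp = 32;
  int num_tokens = 1000, block_size = 256;
  int grid_size = (num_tokens + block_size - 1) / block_size;      // ceil(1000/256) = 4
  int global_warp_num = grid_size * block_size / kThreadsPerWarp;  // 4 * 256 / 32 = 32
  // Each global warp strides over rows j = warp_id, warp_id + 32, ... < num_tokens.
  std::printf("grid = %d blocks, %d global warps\n", grid_size, global_warp_num);
  return 0;
}
```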

```cuda
void fused_aux_loss_backward(Tensor* Const_buf, Tensor* tokens_per_expert, int num_tokens,
                             int num_experts, Tensor* grad_aux_loss, Tensor* grad_probs,
                             cudaStream_t stream) {
  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
      grad_aux_loss->data.dtype, DType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_INDEX(
          tokens_per_expert->data.dtype, IndexType,
          fused_aux_loss_backward_kernel_launcher<DType, IndexType>(
              reinterpret_cast<float*>(Const_buf->data.dptr),
              reinterpret_cast<IndexType*>(tokens_per_expert->data.dptr), num_tokens, num_experts,
              reinterpret_cast<DType*>(grad_aux_loss->data.dptr),
              reinterpret_cast<DType*>(grad_probs->data.dptr), stream);););
}
```

```cuda
}  // namespace transformer_engine

void nvte_fused_aux_loss_forward(const NVTETensor probs, const NVTETensor tokens_per_expert,
                                 int num_tokens, int num_experts, int topk, float coeff,
                                 NVTETensor aux_loss, NVTETensor Const_buf, cudaStream_t stream) {
  NVTE_API_CALL(nvte_fused_aux_loss_forward);
  using namespace transformer_engine;
  fused_aux_loss_forward(convertNVTETensorCheck(probs), convertNVTETensorCheck(tokens_per_expert),
                         num_tokens, num_experts, topk, coeff, convertNVTETensorCheck(aux_loss),
                         convertNVTETensorCheck(Const_buf), stream);
}

void nvte_fused_aux_loss_backward(const NVTETensor Const_buf, const NVTETensor tokens_per_expert,
                                  int num_tokens, int num_experts, NVTETensor grad_aux_loss,
                                  NVTETensor grad_probs, cudaStream_t stream) {
  NVTE_API_CALL(nvte_fused_aux_loss_backward);
  using namespace transformer_engine;
  fused_aux_loss_backward(convertNVTETensorCheck(Const_buf),
                          convertNVTETensorCheck(tokens_per_expert), num_tokens, num_experts,
                          convertNVTETensorCheck(grad_aux_loss), convertNVTETensorCheck(grad_probs),
                          stream);
}
```
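A hedged sketch of how a caller would pair the two public entry points. Only the two `nvte_fused_aux_loss_*` signatures come from this file; the wrapper function name is illustrative, tensor construction is omitted, and the include path for the fused-router declarations is an assumption:

```cpp
// Sketch of a caller; run_fused_aux_loss is illustrative, not a TE API.
// NVTETensor comes from <transformer_engine/transformer_engine.h>; the
// fused-router header path below is assumed, not confirmed by this diff.
#include <cuda_runtime.h>
#include <transformer_engine/transformer_engine.h>

void run_fused_aux_loss(NVTETensor probs, NVTETensor tokens_per_expert,
                        NVTETensor aux_loss, NVTETensor Const_buf,
                        NVTETensor grad_aux_loss, NVTETensor grad_probs,
                        int num_tokens, int num_experts, int topk, float coeff,
                        cudaStream_t stream) {
  // Forward: writes the scalar loss and caches C_coeff in Const_buf.
  nvte_fused_aux_loss_forward(probs, tokens_per_expert, num_tokens, num_experts,
                              topk, coeff, aux_loss, Const_buf, stream);
  // Backward: expands grad_aux_loss into per-element grads via the cached C_coeff.
  nvte_fused_aux_loss_backward(Const_buf, tokens_per_expert, num_tokens, num_experts,
                               grad_aux_loss, grad_probs, stream);
}
```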
