Commit baed11c

Author: tongliu
fix bug on seq aux loss

Signed-off-by: tongliu <[email protected]>
1 parent 752c351 commit baed11c

File tree: 7 files changed, +28 −24 lines changed
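The bug: the fused aux-loss path normalized the loss by `num_tokens`, the local row count of `probs`, which is wrong for sequence/global aux loss, where the normalizer must be the total token count. The commit threads a separate `total_num_tokens` through the kernel, the C API, the pybind layer, and the autograd function. A minimal sketch of the normalization constant before and after, using the names from the diffs below:

```python
def c_coeff(num_experts: int, coeff: float, topk: int, normalizer: int) -> float:
    # The fixed kernel passes total_num_tokens as the normalizer;
    # the buggy version passed num_tokens (the local row count of probs).
    return (num_experts * coeff) / topk / normalizer / normalizer
```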

tests/pytorch/test_fused_router.py

Lines changed: 3 additions & 3 deletions
@@ -1,7 +1,7 @@
 import torch
 import math
 from typing import Optional, Dict
-from transformer_engine.pytorch.router_func import (
+from transformer_engine.pytorch.router import (
     fused_topk_softmax_sigmoid,
     fused_compute_scores_for_aux_loss,
     fused_aux_loss,
@@ -199,7 +199,7 @@ def run_comparison(
         expert_bias=expert_bias_clone,
     )

-    assert torch.allclose(probs, probs_fused, atol=atol, rtol=rtol)
+    assert torch.allclose(probs, probs_fused, atol=atol, rtol=rtol), f"probs are not close: {probs} != {probs_fused}"
     assert torch.allclose(routing_map, routing_map_fused, atol=atol, rtol=rtol)

     # Fake the loss
@@ -342,7 +342,7 @@ def test_fused_aux_loss(dtype, num_tokens, num_experts, topk):
     aux_loss_fused = fused_aux_loss(
         probs=probs_clone,
         tokens_per_expert=tokens_per_expert,
-        num_tokens=num_tokens,
+        total_num_tokens=num_tokens,
         num_experts=num_experts,
         topk=topk,
         coeff=coeff,

transformer_engine/common/fused_router/fused_aux_loss.cu

Lines changed: 8 additions & 7 deletions
@@ -11,7 +11,7 @@ namespace transformer_engine {

 template <typename DataType, typename IndexType>
 __global__ void fused_aux_loss_forward_kernel(const DataType* probs,
-                                              const IndexType* tokens_per_expert, int num_tokens,
+                                              const IndexType* tokens_per_expert, int total_num_tokens, int num_tokens,
                                               int num_experts, int topk, float coeff,
                                               DataType* aux_loss, float* Const_buf) {
   int warp_num = blockDim.x / kThreadsPerWarp;
@@ -61,7 +61,7 @@ __global__ void fused_aux_loss_forward_kernel(const DataType* probs,
   /**
    * Section: Compute the aux_loss
    */
-  float C_coeff = (num_experts * coeff) / topk / num_tokens / num_tokens;
+  float C_coeff = (num_experts * coeff) / topk / total_num_tokens / total_num_tokens;
   aux_loss[0] = DataType(double(intermediate_result) * C_coeff);
   Const_buf[0] = C_coeff;
 }
@@ -70,7 +70,7 @@ __global__ void fused_aux_loss_forward_kernel(const DataType* probs,

 template <typename DataType, typename IndexType>
 void fused_aux_loss_forward_kernel_launcher(const DataType* probs,
-                                            const IndexType* tokens_per_expert, int num_tokens,
+                                            const IndexType* tokens_per_expert, int total_num_tokens, int num_tokens,
                                             int num_experts, int topk, float coeff,
                                             DataType* aux_loss, float* Const_buf,
                                             cudaStream_t stream) {
@@ -81,10 +81,10 @@ void fused_aux_loss_forward_kernel_launcher(const DataType* probs,
   int block_size = 1024;
   fused_aux_loss_forward_kernel<DataType, IndexType>
       <<<grid_size, block_size, shared_memory_size, stream>>>(
-          probs, tokens_per_expert, num_tokens, num_experts, topk, coeff, aux_loss, Const_buf);
+          probs, tokens_per_expert, total_num_tokens, num_tokens, num_experts, topk, coeff, aux_loss, Const_buf);
 }

-void fused_aux_loss_forward(const Tensor& probs, const Tensor& tokens_per_expert, int num_tokens,
+void fused_aux_loss_forward(const Tensor& probs, const Tensor& tokens_per_expert, int total_num_tokens, int num_tokens,
                             int num_experts, int topk, float coeff, Tensor& aux_loss,
                             Tensor& Const_buf, cudaStream_t stream) {
   TE_ROUTER_PROBS_TYPE_SWITCH_ALL(
@@ -93,7 +93,7 @@ void fused_aux_loss_forward(const Tensor& probs, const Tensor& tokens_per_expert
       tokens_per_expert.data.dtype, IndexType,
       fused_aux_loss_forward_kernel_launcher<DataType, IndexType>(
           reinterpret_cast<DataType*>(probs.data.dptr),
-          reinterpret_cast<IndexType*>(tokens_per_expert.data.dptr), num_tokens, num_experts,
+          reinterpret_cast<IndexType*>(tokens_per_expert.data.dptr), total_num_tokens, num_tokens, num_experts,
           topk, coeff, reinterpret_cast<DataType*>(aux_loss.data.dptr),
           reinterpret_cast<float*>(Const_buf.data.dptr), stream);););
 }
@@ -148,12 +148,13 @@ void fused_aux_loss_backward(const Tensor& Const_buf, const Tensor& tokens_per_e
 }  // namespace transformer_engine

 void nvte_fused_aux_loss_forward(const NVTETensor probs, const NVTETensor tokens_per_expert,
+                                 int total_num_tokens,
                                  int num_tokens, int num_experts, int topk, float coeff,
                                  NVTETensor aux_loss, NVTETensor Const_buf, cudaStream_t stream) {
   NVTE_API_CALL(nvte_fused_aux_loss_forward);
   using namespace transformer_engine;
   fused_aux_loss_forward(*convertNVTETensorCheck(probs), *convertNVTETensorCheck(tokens_per_expert),
-                         num_tokens, num_experts, topk, coeff, *convertNVTETensorCheck(aux_loss),
+                         total_num_tokens, num_tokens, num_experts, topk, coeff, *convertNVTETensorCheck(aux_loss),
                          *convertNVTETensorCheck(Const_buf), stream);
 }
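For readers checking the kernel math: an unfused Python reference of the forward computation. Only `C_coeff` is taken verbatim from the kernel above; the reduction into `intermediate_result` is not shown in this diff, so the assumption here is that it is the dot product of the per-expert probability sums with `tokens_per_expert`, as in the standard load-balancing loss:

```python
import torch

def aux_loss_forward_reference(probs, tokens_per_expert, total_num_tokens,
                               num_experts, topk, coeff):
    """Unfused sketch of fused_aux_loss_forward. probs: [num_tokens, num_experts]."""
    # Assumed reduction: per-expert sum of probabilities times that expert's token count.
    intermediate_result = (probs.sum(dim=0).double() * tokens_per_expert.double()).sum()
    # Coefficient exactly as in the fixed kernel: normalize by total_num_tokens.
    c_coeff = (num_experts * coeff) / topk / total_num_tokens / total_num_tokens
    aux_loss = (intermediate_result * c_coeff).to(probs.dtype)
    return aux_loss, c_coeff  # c_coeff corresponds to Const_buf, cached for the backward pass
```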

transformer_engine/common/include/transformer_engine/fused_router.h

Lines changed: 3 additions & 2 deletions
@@ -96,7 +96,8 @@ void nvte_fused_scores_for_aux_loss_backward(const NVTETensor intermediate_outpu
  *
  *  \param[in]     probs               Probabilities from the forward pass.
  *  \param[in]     tokens_per_expert   Number of tokens per expert.
- *  \param[in]     num_tokens          Number of total tokens.
+ *  \param[in]     total_num_tokens    Number of total tokens. Will be used in seq/global aux loss.
+ *  \param[in]     num_tokens          Number of tokens.
  *  \param[in]     num_experts         Number of experts.
  *  \param[in]     topk                Topk value.
  *  \param[in]     coeff               Coefficient.
@@ -105,7 +106,7 @@ void nvte_fused_scores_for_aux_loss_backward(const NVTETensor intermediate_outpu
  *  \param[in]     stream              CUDA stream used for the operation.
  */
 void nvte_fused_aux_loss_forward(const NVTETensor probs, const NVTETensor tokens_per_expert,
-                                 int num_tokens, int num_experts, int topk, float coeff,
+                                 int total_num_tokens, int num_tokens, int num_experts, int topk, float coeff,
                                  NVTETensor aux_loss, NVTETensor Const_buf, cudaStream_t stream);

/*! \brief Backward pass for auxiliary loss.
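To make the two counts concrete: `num_tokens` is the row count of `probs` (the Python wrapper below now derives it itself), while `total_num_tokens` is whatever token count the chosen aux-loss variant normalizes over. A loudly hypothetical illustration; the variant-to-count mapping is an assumption, not something this commit states:

```python
def pick_total_num_tokens(probs, variant, seq_len, global_batch_num_tokens):
    # num_tokens is always the local row count of probs.
    num_tokens = probs.size(0)
    # Assumption: seq-level aux loss normalizes per sequence, global aux loss
    # over the whole batch; both names here are illustrative only.
    total_num_tokens = seq_len if variant == "seq" else global_batch_num_tokens
    return num_tokens, total_num_tokens
```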

transformer_engine/pytorch/csrc/extensions.h

Lines changed: 2 additions & 2 deletions
@@ -36,8 +36,8 @@ at::Tensor fused_scores_for_aux_loss_bwd(int num_tokens, int num_experts,
                                          int topk, std::string score_function);

 std::tuple<at::Tensor, at::Tensor> fused_aux_loss_fwd(at::Tensor probs,
-                                                      at::Tensor tokens_per_expert, int num_tokens,
-                                                      int num_experts, int topk, float coeff);
+                                                      at::Tensor tokens_per_expert, int total_num_tokens,
+                                                      int num_tokens, int num_experts, int topk, float coeff);

 at::Tensor fused_aux_loss_bwd(at::Tensor Const_buf, at::Tensor tokens_per_expert, int num_tokens,
                               int num_experts, at::Tensor grad_aux_loss);

transformer_engine/pytorch/csrc/extensions/pybind.cpp

Lines changed: 1 addition & 1 deletion
@@ -272,7 +272,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("num_experts"), py::arg("intermediate_output"), py::arg("grad_scores"),
         py::arg("topk"), py::arg("score_function"), "Fused topk softmax bwd");
   m.def("fused_aux_loss_fwd", &transformer_engine::pytorch::fused_aux_loss_fwd, py::arg("probs"),
-        py::arg("tokens_per_expert"), py::arg("num_tokens"), py::arg("num_experts"),
+        py::arg("tokens_per_expert"), py::arg("total_num_tokens"), py::arg("num_tokens"), py::arg("num_experts"),
         py::arg("topk"), py::arg("coeff"), "Fused aux loss fwd");
   m.def("fused_aux_loss_bwd", &transformer_engine::pytorch::fused_aux_loss_bwd,
         py::arg("Const_buf"), py::arg("tokens_per_expert"), py::arg("num_tokens"),

transformer_engine/pytorch/csrc/extensions/router.cpp

Lines changed: 6 additions & 6 deletions
@@ -145,10 +145,10 @@ at::Tensor fused_scores_for_aux_loss_bwd(int num_tokens, int num_experts,
 }

 std::tuple<at::Tensor, at::Tensor> fused_aux_loss_fwd(at::Tensor probs,
-                                                      at::Tensor tokens_per_expert, int num_tokens,
-                                                      int num_experts, int topk, float coeff) {
+                                                      at::Tensor tokens_per_expert, int total_num_tokens,
+                                                      int num_tokens, int num_experts, int topk, float coeff) {
   TORCH_CHECK(topk > 0, "topk must be greater than 0");
-  TORCH_CHECK(num_tokens > 0, "num_tokens must be greater than 0");
+  TORCH_CHECK(total_num_tokens > 0, "total_num_tokens must be greater than 0");
   TORCH_CHECK(num_experts > 0, "num_experts must be greater than 0");

   // Create the output tensor
@@ -160,15 +160,15 @@ std::tuple<at::Tensor, at::Tensor> fused_aux_loss_fwd(at::Tensor probs,
   auto aux_loss_cu = makeTransformerEngineTensor(aux_loss);
   auto Const_buf_cu = makeTransformerEngineTensor(Const_buf);

-  nvte_fused_aux_loss_forward(probs_cu.data(), tokens_per_expert_cu.data(), num_tokens, num_experts,
+  nvte_fused_aux_loss_forward(probs_cu.data(), tokens_per_expert_cu.data(), total_num_tokens, num_tokens, num_experts,
                               topk, coeff, aux_loss_cu.data(), Const_buf_cu.data(),
                               at::cuda::getCurrentCUDAStream());

   return std::make_tuple(aux_loss, Const_buf);
 }

-at::Tensor fused_aux_loss_bwd(at::Tensor Const_buf, at::Tensor tokens_per_expert, int num_tokens,
-                              int num_experts, at::Tensor grad_aux_loss) {
+at::Tensor fused_aux_loss_bwd(at::Tensor Const_buf, at::Tensor tokens_per_expert,
+                              int num_tokens, int num_experts, at::Tensor grad_aux_loss) {
   // Create the output tensor
   at::Tensor grad_probs = at::empty({num_tokens, num_experts},
                                     at::dtype(grad_aux_loss.scalar_type()).device(at::kCUDA));
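A side note on why `fused_aux_loss_bwd` keeps only the local `num_tokens`: it fixes the shape of `grad_probs`, and the gradient needs `total_num_tokens` only through the cached coefficient. A hedged sketch derived from the forward formula above; the fused backward kernel's internals are not shown in this diff:

```python
import torch

def aux_loss_backward_reference(grad_aux_loss, Const_buf, tokens_per_expert,
                                num_tokens, num_experts):
    # From the forward formula, d(aux_loss)/d(probs[t, e]) = C_coeff * tokens_per_expert[e]
    # for every row t, so total_num_tokens enters only via Const_buf (= C_coeff).
    grad_row = grad_aux_loss * Const_buf * tokens_per_expert.double()  # [num_experts]
    return grad_row.expand(num_tokens, num_experts).to(grad_aux_loss.dtype)
```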

transformer_engine/pytorch/router_func.py renamed to transformer_engine/pytorch/router.py

Lines changed: 5 additions & 3 deletions
@@ -125,14 +125,16 @@ def forward(
         ctx,
         probs: torch.Tensor,
         tokens_per_expert: torch.Tensor,
-        num_tokens: int,
+        total_num_tokens: int,
         num_experts: int,
         topk: int,
         coeff: float,
     ):
+        num_tokens = probs.size(0)
         aux_loss, Const_buf = tex.fused_aux_loss_fwd(
             probs=probs,
             tokens_per_expert=tokens_per_expert,
+            total_num_tokens=total_num_tokens,
             num_tokens=num_tokens,
             num_experts=num_experts,
             topk=topk,
@@ -159,9 +161,9 @@ def backward(ctx, grad_aux_loss):
 def fused_aux_loss(
     probs: torch.Tensor,
     tokens_per_expert: torch.Tensor,
-    num_tokens: int,
+    total_num_tokens: int,
     num_experts: int,
     topk: int,
     coeff: float,
 ):
-    return FusedAuxLoss.apply(probs, tokens_per_expert, num_tokens, num_experts, topk, coeff)
+    return FusedAuxLoss.apply(probs, tokens_per_expert, total_num_tokens, num_experts, topk, coeff)
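End-to-end usage after this commit, mirroring the updated test above. The sizes and the way `tokens_per_expert` is filled are illustrative assumptions; in real code it comes from the routing map:

```python
import torch
from transformer_engine.pytorch.router import fused_aux_loss

num_tokens, num_experts, topk = 8, 4, 2  # illustrative sizes
probs = torch.rand(num_tokens, num_experts, device="cuda").softmax(dim=-1)
# Illustrative per-expert routed-token counts (normally derived from the routing map).
tokens_per_expert = torch.full((num_experts,), num_tokens * topk // num_experts,
                               dtype=torch.int32, device="cuda")

aux_loss = fused_aux_loss(
    probs=probs,
    tokens_per_expert=tokens_per_expert,
    total_num_tokens=num_tokens,  # single-rank case: local and total counts coincide
    num_experts=num_experts,
    topk=topk,
    coeff=1e-2,
)
```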
