[MoE][Common/PyTorch] Add permutation (#936)
* Add permutation functions

* Add permutation ops

* Remove the dependency on cutlass

* Move permutation.py out of module dir

* Rewrite the unit test and enable skipping if FP8 is unavailable

* Rename exposed C++ API and reorder its parameters + take NVTETensor as inputs

* Use Float8Tensor for FP8 input

* Move dtype to ctx

---------

Signed-off-by: Jiang Shao <[email protected]>
Co-authored-by: Qi Zhang <[email protected]>
Co-authored-by: Phuong Nguyen <[email protected]>
3 people authored Aug 22, 2024
1 parent 47caafb commit a335374
Showing 11 changed files with 1,394 additions and 1 deletion.
515 changes: 515 additions & 0 deletions tests/pytorch/test_permutation.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions transformer_engine/common/CMakeLists.txt
@@ -62,6 +62,7 @@ list(APPEND transformer_engine_SOURCES
layer_norm/ln_api.cpp
layer_norm/ln_bwd_semi_cuda_kernel.cu
layer_norm/ln_fwd_cuda_kernel.cu
+ permutation/permutation.cu
rmsnorm/rmsnorm_api.cpp
rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
rmsnorm/rmsnorm_fwd_cuda_kernel.cu
2 changes: 1 addition & 1 deletion transformer_engine/common/gemm/cublaslt_gemm.cu
@@ -255,7 +255,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
"Unable to find suitable cuBLAS GEMM algorithm");
NVTE_CHECK_CUBLAS(status);

- if (returnedResults == 0) throw std::runtime_error("Unable to find any suitable algorithms");
+ if (returnedResults == 0) NVTE_ERROR("Unable to find any suitable algorithms");

// D = alpha * (A * B) + beta * C
NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc,
21 changes: 21 additions & 0 deletions transformer_engine/common/include/transformer_engine/permutation.h
@@ -0,0 +1,21 @@
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

#ifndef TRANSFORMER_ENGINE_PERMUTATION_H_
#define TRANSFORMER_ENGINE_PERMUTATION_H_

#include "transformer_engine.h"

void nvte_permute(const NVTETensor input, NVTETensor output, const NVTETensor sorted_row_id,
NVTETensor row_id_map, const NVTETensor prob, NVTETensor prob_grad,
const NVTETensor input_fwd, const int num_rows, const int topK,
const int num_cols, const int num_out_tokens, cudaStream_t stream = nullptr);

void nvte_unpermute(const NVTETensor input, NVTETensor output, NVTETensor row_id_map,
const NVTETensor prob, const int num_rows, const int topK, const int num_cols,
cudaStream_t stream = nullptr);

#endif // TRANSFORMER_ENGINE_PERMUTATION_H_
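
For readers unfamiliar with MoE token permutation, the following is a minimal CPU sketch of the general idea behind a permute/unpermute pair: rows are grouped by the expert they are routed to, then merged back with per-slot probabilities. It is not the CUDA implementation added by this commit. The function names (`sort_rows_by_expert`, `permute_rows`, `unpermute_rows`) and the `expert_of` array are hypothetical, and the index layout (`token * topK + slot`) is an assumption chosen for illustration; only the parameter names `sorted_row_id`, `prob`, `num_rows`, `topK`, and `num_cols` are borrowed from the header above, and their exact semantics in `nvte_permute` / `nvte_unpermute` may differ.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Hypothetical CPU reference; the layout below is an assumption, not the
// layout used by the kernels in this commit.
// expert_of[t * topK + k] is the expert chosen for token t in routing slot k.
// Returns, for each permuted row, the flat routing index (t * topK + k) it
// came from, grouped by expert in stable order.
std::vector<int> sort_rows_by_expert(const std::vector<int> &expert_of) {
  std::vector<int> sorted_row_id(expert_of.size());
  std::iota(sorted_row_id.begin(), sorted_row_id.end(), 0);
  std::stable_sort(sorted_row_id.begin(), sorted_row_id.end(),
                   [&](int a, int b) { return expert_of[a] < expert_of[b]; });
  return sorted_row_id;
}

// Permute: copy each token's row once per routing slot, in expert order.
std::vector<float> permute_rows(const std::vector<float> &input,
                                const std::vector<int> &sorted_row_id,
                                int topK, int num_cols) {
  assert(topK > 0 && num_cols > 0);
  assert(input.size() * topK == sorted_row_id.size() * num_cols);
  std::vector<float> output(sorted_row_id.size() * num_cols);
  for (std::size_t dst = 0; dst < sorted_row_id.size(); ++dst) {
    const int src_token = sorted_row_id[dst] / topK;
    std::copy_n(input.begin() + src_token * num_cols, num_cols,
                output.begin() + dst * num_cols);
  }
  return output;
}

// Unpermute: accumulate each token's topK rows back into one row, weighting
// each by prob[t * topK + k].
std::vector<float> unpermute_rows(const std::vector<float> &permuted,
                                  const std::vector<int> &sorted_row_id,
                                  const std::vector<float> &prob,
                                  int num_rows, int topK, int num_cols) {
  assert(prob.size() == sorted_row_id.size());
  assert(permuted.size() == sorted_row_id.size() * num_cols);
  assert(sorted_row_id.size() == static_cast<std::size_t>(num_rows) * topK);
  std::vector<float> output(static_cast<std::size_t>(num_rows) * num_cols, 0.0f);
  for (std::size_t src = 0; src < sorted_row_id.size(); ++src) {
    const int flat = sorted_row_id[src];
    const int token = flat / topK;
    for (int c = 0; c < num_cols; ++c) {
      output[token * num_cols + c] += prob[flat] * permuted[src * num_cols + c];
    }
  }
  return output;
}
```

With three tokens, `topK = 2`, and routing probabilities that sum to one per token, calling `unpermute_rows` on the result of `permute_rows` returns the original rows, which is a convenient sanity check for any such pair.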