Atomic GEMM and FP8 Reduce Scatter (#449)
* Initial commit

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Repro for RS output mismatch with Single GEMM + Split pipelined RS

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* minor changes for AG->GEMM pipelined overlap

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Add Atomic GEMM cuBLAS API attributes and initial implementation of AG->Atomic GEMM

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* AtomicGemm+RS functional with workaround

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* add amax update to layernorm_linear for FP8 unit test accuracy

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Enable reducescatter2_userbuff_strided variants

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Bug fix

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* AG+AtomicGemm overlap functional but GEMM doesn't overlap with comm

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Add userbuffers_sendrecv kernel variants

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* TransformerLayer API changes to enable AtomicGemm+RS overlap

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Code cleanup

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Code cleanup2

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* [UB] AllGather Atomic GEMM overlap using userbuffer_sendrecv kernels

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Code cleanup + bug fix for multiatomic sendrecv kernel

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* cleanup

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Bug fixes

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* [UB] Add shuffling for better AG AtomicGEMM overlap

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Bug fix for AG AtomicGemm overlap

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Bug fix for multiAtomicAG and singleAtomicAG

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Use chunk_i+1 as recv_chunk for multiatomic_AG with shuffling

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Launch AtomicGEMM after first-chunk AG

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Rebase to main

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Add FP8 ReduceScatter kernels, AtomicGEMM+FP8 RS not functional

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Revert "Add FP8 ReduceScatter kernels, AtomicGEMM+FP8 RS not functional"

This reverts commit 80a47a7.

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Add support for NVLS-MC and FP8 Reduce Scatter

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Bug fix

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Atomic and Multiatomic FP8 RS functional

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Remove debug print

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* UB comm initialization hang fix

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Code cleanup

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Create new GEMM API for Atomic GEMM

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* CI ready

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* more fixes

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* license

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Bug fix

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Revert NVLS-MC

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Check cu* versions for running atomic gemms

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* lint

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* fixes

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Cleanup

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Add experimental warning

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Better wording

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Add warning to c api

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Fix wording

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

---------

Signed-off-by: Vasudevan Rengasamy <[email protected]>
Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
vasunvidia and ksivaman authored Oct 5, 2023
1 parent be67f21 commit 958e188
Showing 17 changed files with 3,619 additions and 702 deletions.
4 changes: 2 additions & 2 deletions tests/pytorch/test_onnx_export.py
@@ -506,7 +506,7 @@ def forward(self, inp, weight):
self.fp8_tensor_weight,
self.weights_type)

ret = fp8_gemm(
ret, _ = fp8_gemm(
weight_fp8,
self.meta_weight.scale_inv,
self.fp8_tensor_weight,
@@ -1323,7 +1323,7 @@ def forward(self, inp, weight):
self.fp8_tensor_weight,
self.weights_type)

ret = fp8_gemm(
ret, _ = fp8_gemm(
weight_fp8,
self.meta_weight.scale_inv,
self.fp8_tensor_weight,
108 changes: 107 additions & 1 deletion transformer_engine/common/gemm/cublaslt_gemm.cu
@@ -7,6 +7,7 @@
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/logging.h>
#include <transformer_engine/gemm.h>
#include <cuda.h>
#include <cublasLt.h>
#include <cublas_v2.h>
#include "../common.h"
@@ -50,6 +51,10 @@ void cublas_gemm(const Tensor *inputA,
bool accumulate,
bool use_split_accumulator,
int math_sm_count,
int m_split,
int n_split,
bool gemm_producer,
const Tensor *inputCounter,
cudaStream_t stream
) {
void *A = inputA->data.dptr;
@@ -63,6 +68,10 @@
void *bias_ptr = inputBias->data.dptr;
const bool bias = bias_ptr != nullptr;
void *pre_gelu_out = outputPreGelu->data.dptr;
void *counter = nullptr;
if (inputCounter != nullptr) {
counter = inputCounter->data.dptr;
}
const bool gelu = pre_gelu_out != nullptr;
const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) ||
is_fp8_dtype(inputB->data.dtype);
@@ -223,6 +232,27 @@
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_EPILOGUE,
&epilogue, sizeof(epilogue)));
#if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205
if (counter != nullptr) {
if (m_split == 0) m_split=1;
if (n_split == 0) n_split=1;
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_ROWS,
&m_split, sizeof(m_split)));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_COLS,
&n_split, sizeof(n_split)));
if (gemm_producer) {
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_OUT_COUNTERS_POINTER,
&counter, sizeof(counter)));
} else {
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_IN_COUNTERS_POINTER,
&counter, sizeof(counter)));
}
}
#endif

NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceCreate(&preference));
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
@@ -254,7 +284,6 @@
workspaceSize,
stream)); /* stream */


NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceDestroy(preference));
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Ddesc));
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Cdesc));
@@ -320,5 +349,82 @@ void nvte_cublas_gemm(const NVTETensor A,
wspace->data.shape[0],
accumulate, use_split_accumulator,
math_sm_count,
0,
0,
false,
nullptr,
stream);
}

void nvte_cublas_atomic_gemm(const NVTETensor A,
const NVTETensor B,
NVTETensor D,
const NVTETensor bias,
NVTETensor pre_gelu_out,
bool transa,
bool transb,
bool grad,
NVTETensor workspace,
bool accumulate,
bool use_split_accumulator,
int math_sm_count,
int m_split,
int n_split,
bool gemm_producer,
const NVTETensor counter,
cudaStream_t stream) {
NVTE_API_CALL(nvte_cublas_atomic_gemm);

int cudart_version;
NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&cudart_version));
NVTE_CHECK(cudart_version >= 12020, "Cuda version 12.2 is required for atomic gemm.");
NVTE_CHECK(cublasLtGetVersion() >= 120205, "Cublas version 12.2.5 is required for atomic gemm.");

using namespace transformer_engine;
const Tensor *inputA = reinterpret_cast<const Tensor*>(A);
const Tensor *inputB = reinterpret_cast<const Tensor*>(B);
Tensor *outputD = reinterpret_cast<Tensor*>(D);
const Tensor *biasTensor = reinterpret_cast<const Tensor*>(bias);
Tensor *outputGelu = reinterpret_cast<Tensor*>(pre_gelu_out);
const Tensor *inputCounter = reinterpret_cast<const Tensor*>(counter);
Tensor *wspace = reinterpret_cast<Tensor*>(workspace);

const int m = transa ? inputA->data.shape[0] : inputA->data.shape[1];
const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0];
const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0];
int lda, ldb, ldd;
if (transa && !transb) { // TN
lda = k;
ldb = k;
ldd = m;
} else if (!transa && !transb) { // NN
lda = m;
ldb = k;
ldd = m;
} else if (!transa && transb) { // NT
lda = m;
ldb = n;
ldd = m;
} else { // TT
NVTE_ERROR("TT layout not allowed.");
}

cublas_gemm(inputA,
inputB,
outputD,
biasTensor,
outputGelu,
m, n, k,
lda, ldb, ldd,
(transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
(transb) ? CUBLAS_OP_T : CUBLAS_OP_N,
grad, wspace->data.dptr,
wspace->data.shape[0],
accumulate, use_split_accumulator,
math_sm_count,
m_split,
n_split,
gemm_producer,
inputCounter,
stream);
}
46 changes: 46 additions & 0 deletions transformer_engine/common/include/transformer_engine/gemm.h
@@ -54,6 +54,52 @@ void nvte_cublas_gemm(const NVTETensor A,
cudaStream_t stream
);

/*! \brief Compute matrix multiplication of 2 matrices with chunking and atomic counters.
*
* \warning Cublas atomic gemm uses a beta API and is not tested for all use cases.
*
* Computes:
* - `D = AB` if both `bias` and `pre_gelu_out` are empty tensors
* - `D = AB + bias` if `pre_gelu_out` is empty and `bias` is not empty
* - `D = GELU(AB + bias)` if both `bias` and `pre_gelu_out` are not empty tensors
*
* \param[in] A The A matrix.
* \param[in] B The B matrix.
* \param[in,out] D Output matrix.
* \param[in] bias Bias tensor.
* \param[in,out] pre_gelu_out Output matrix before GELU activation.
* \param[in] transa Whether A matrix is transposed.
* \param[in] transb Whether B matrix is transposed.
* \param[in] grad Whether this operation is part of the
* gradient computation.
* \param[out] workspace Workspace tensor.
* \param[in] accumulate Whether to accumulate the result into the D matrix.
* \param[in] use_split_accumulator Whether to use split accumulator in the FP8 GEMM.
* \param[in] math_sm_count Number of GPU SMs to use (default=0: use cuBLAS heuristics)
* \param[in] m_split Number of chunks/splits along m-dimension for Atomic GEMM.
* \param[in] n_split Number of chunks/splits along n-dimension for Atomic GEMM.
* \param[in] gemm_producer Whether Atomic GEMM is the producer or consumer.
* \param[in,out] counter counter[chunk_i]=0 indicates chunk_i has been produced.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_cublas_atomic_gemm(const NVTETensor A,
const NVTETensor B,
NVTETensor D,
const NVTETensor bias,
NVTETensor pre_gelu_out,
bool transa,
bool transb,
bool grad,
NVTETensor workspace,
bool accumulate,
bool use_split_accumulator,
int math_sm_count,
int m_split,
int n_split,
bool gemm_producer,
const NVTETensor counter,
cudaStream_t stream
);
#ifdef __cplusplus
} // extern "C"
#endif
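
The counter convention documented above can be illustrated with a small, self-contained PyTorch sketch. This is a conceptual illustration rather than Transformer Engine code: the non-zero initial value of the flags and the consumer-side check are assumptions based only on the statement that counter[chunk_i]=0 marks chunk_i as produced.

import torch

# One flag per output chunk of D; 0 means "this chunk has been produced".
# Real counters live in GPU memory and are cleared by the producer GEMM
# kernel itself; the initial value and the helpers below are illustrative only.
m_split = 4
counter = torch.ones(m_split, dtype=torch.int32)

def producer_finished(chunk_i: int) -> None:
    # Stand-in for the producer GEMM completing one chunk of D.
    counter[chunk_i] = 0

def chunk_ready(chunk_i: int) -> bool:
    # Stand-in for the consumer checking whether a chunk is usable.
    return counter[chunk_i].item() == 0

producer_finished(0)
assert chunk_ready(0) and not chunk_ready(1)
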
6 changes: 6 additions & 0 deletions transformer_engine/pytorch/attention.py
@@ -2262,6 +2262,8 @@ def __init__(
ub_bulk_dgrad: bool = False,
ub_split_rs: bool = False,
ub_split_ag: bool = False,
ub_atomic_gemm_rs: bool = False,
ub_atomic_gemm_ag: bool = False,
bias: bool = True,
normalization: str = "LayerNorm",
device: Union[torch.device, str] = "cuda",
@@ -2342,6 +2344,7 @@ def __init__(
ub_bulk_dgrad=ub_bulk_dgrad,
ub_split_ag=ub_split_ag,
normalization=normalization,
ub_atomic_gemm_ag=ub_atomic_gemm_ag,
**common_gemm_kwargs,
)
else:
@@ -2372,6 +2375,7 @@
ub_bulk_dgrad=ub_bulk_dgrad,
ub_split_ag=ub_split_ag,
normalization=normalization,
ub_atomic_gemm_ag=ub_atomic_gemm_ag,
**common_gemm_kwargs,
)
else:
@@ -2418,6 +2422,8 @@ def __init__(
parallel_mode="row" if set_parallel_mode else None,
ub_split_rs=ub_split_rs,
ub_split_ag=ub_split_ag,
ub_atomic_gemm_rs=ub_atomic_gemm_rs,
ub_atomic_gemm_ag=ub_atomic_gemm_ag,
**common_gemm_kwargs,
)

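
The two new flags threaded through the attention constructor above are user-facing keyword arguments. A hypothetical construction sketch follows, assuming the diff hunks belong to the MultiheadAttention constructor; the sizes are placeholders, and the userbuffers / tensor-parallel initialization these overlap modes require is assumed to have happened elsewhere and is not shown.

import transformer_engine.pytorch as te

# Hypothetical usage of the new keyword arguments added by this commit.
# Enabling them only makes sense with tensor parallelism and userbuffers
# communication overlap initialized beforehand (setup not shown); the
# hidden size and head count are arbitrary placeholders.
mha = te.MultiheadAttention(
    hidden_size=4096,
    num_attention_heads=32,
    ub_atomic_gemm_ag=True,   # AllGather -> Atomic GEMM overlap
    ub_atomic_gemm_rs=True,   # Atomic GEMM -> ReduceScatter overlap
)
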
26 changes: 22 additions & 4 deletions transformer_engine/pytorch/cpp_extensions/gemm.py
@@ -91,22 +91,40 @@ def fp8_gemm(
assert ub is not None, 'ub object is None!'
if ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_AG:
fn = ub.bulk_overlap
args = tuple(args + (1,))
extra_output_tensor = (
empty_tensor if extra_output_tensor is None else extra_output_tensor
)
args = tuple(args + (1, extra_output_tensor,))
elif ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_RS:
fn = ub.bulk_overlap
args = tuple(args + (0,))
extra_output_tensor = (
empty_tensor if extra_output_tensor is None else extra_output_tensor
)
args = tuple(args + (0, extra_output_tensor,))
elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG:
fn = ub.split_overlap_ag
extra_output_tensor = (
empty_tensor if extra_output_tensor is None else extra_output_tensor
)
args = tuple(args + (extra_output_tensor,))
elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_AG:
fn = ub.atomic_gemm_overlap_ag
extra_output_tensor = (
empty_tensor if extra_output_tensor is None else extra_output_tensor
)
args = tuple(args + (extra_output_tensor,))
elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS:
fn = ub.split_overlap_rs
assert (
extra_output_tensor is not None
), 'SPLIT_PIPELINED_RS requires extra output tensor'
args = tuple(args + (True, extra_output_tensor,))
elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_RS:
fn = ub.atomic_gemm_overlap_rs
assert (
extra_output_tensor is not None
), 'ATOMIC_GEMM_RS requires extra output tensor'
args = tuple(args + (True, extra_output_tensor,))
_ = fn(*args)

return out, gelu_input
@@ -195,10 +213,10 @@ def gemm(
assert ub is not None, 'ub object is None!'
if ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_AG:
fn = ub.bulk_overlap
args = tuple(args + (1,))
args = tuple(args + (1, empty_tensor))
elif ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_RS:
fn = ub.bulk_overlap
args = tuple(args + (0,))
args = tuple(args + (0, empty_tensor))
elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG:
fn = ub.split_overlap_ag
extra_output_tensor = (
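
Every branch of the dispatch above extends the GEMM argument tuple with an extra output tensor, substituting an empty placeholder for the bulk and all-gather paths and insisting on a caller-provided tensor for the reduce-scatter paths, including the new ATOMIC_GEMM_RS. A standalone sketch of that pattern follows, using plain strings for the algorithm names; nothing below is a Transformer Engine API.

def extend_ub_args(args, algo, extra_output_tensor, empty_tensor):
    # Mirror of the argument-extension logic in fp8_gemm above.
    needs_extra = algo in ("SPLIT_PIPELINED_RS", "ATOMIC_GEMM_RS")
    if needs_extra and extra_output_tensor is None:
        raise ValueError(f"{algo} requires extra output tensor")
    extra = empty_tensor if extra_output_tensor is None else extra_output_tensor
    if algo in ("BULK_OVERLAP_AG", "BULK_OVERLAP_RS"):
        return args + (1 if algo == "BULK_OVERLAP_AG" else 0, extra)
    if algo in ("SPLIT_PIPELINED_AG", "ATOMIC_GEMM_AG"):
        return args + (extra,)
    if needs_extra:
        return args + (True, extra_output_tensor)
    return args

# The all-gather atomic path tolerates a missing extra output tensor ...
print(extend_ub_args(("A", "B"), "ATOMIC_GEMM_AG", None, "EMPTY"))
# ... while the reduce-scatter atomic path refuses to run without one.
try:
    extend_ub_args(("A", "B"), "ATOMIC_GEMM_RS", None, "EMPTY")
except ValueError as err:
    print(err)
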
(Diff for the remaining 12 changed files is not shown.)
