optimization for separate scale_inv of weights and single output
Signed-off-by: Xin Yao <[email protected]>
yaox12 committed Aug 26, 2024
1 parent c6aff8b commit 117aa98
Showing 5 changed files with 203 additions and 11 deletions.
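
To see what the "separate scale_inv" part of the change buys, here is a minimal sketch with placeholder tensors (the names below are illustrative, not part of the commit): the existing fp8_grouped_gemm interface takes one packed scale_inv tensor, so the caller concatenates the per-weight tensors on every pass, while the new single-output interface accepts the list directly.

import torch

# Placeholder stand-ins for the per-weight _scale_inv tensors in GroupedLinear.
scale_invs = [torch.ones(1) for _ in range(4)]

# Old call path: fp8_grouped_gemm expects one packed tensor, so every
# forward/backward pass pays for a torch.cat (allocation + copy).
packed_scale_inv = torch.cat(scale_invs)

# New call path: fp8_grouped_gemm_single_output takes the list as-is; the C++
# side reads A_scale_inverse[i] per GEMM, so no packing is needed.
separate_scale_inv = scale_invs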
91 changes: 90 additions & 1 deletion transformer_engine/pytorch/cpp_extensions/gemm.py
@@ -11,7 +11,13 @@
from ..utils import assert_dim_for_fp8_exec


__all__ = ["gemm", "fp8_gemm", "grouped_gemm", "fp8_grouped_gemm"]
__all__ = [
"gemm",
"fp8_gemm",
"grouped_gemm",
"fp8_grouped_gemm",
"fp8_grouped_gemm_single_output",
]


@functools.lru_cache(maxsize=None)
@@ -458,3 +464,86 @@ def fp8_grouped_gemm(
)

return out, gelu_input


def fp8_grouped_gemm_single_output(
A: List[torch.Tensor],
A_scale_inv: List[torch.Tensor],
A_fp8_tensor_offset: int,
A_dtype: tex.DType,
B: List[torch.Tensor],
B_scale_inv: torch.Tensor,
B_fp8_tensor_offset: int,
B_dtype: tex.DType,
m_splits: List[int],
out: torch.Tensor,
out_dtype: torch.dtype,
workspaces: List[torch.Tensor],
out_offset: Optional[int] = None,
fp8_meta_tensor: tex.FP8TensorMeta = None,
gelu: bool = False,
accumulate: bool = False,
bias: Optional[List[torch.Tensor]] = None,
use_bias: bool = False,
use_split_accumulator: bool = False,
D_dtype: Optional[tex.DType] = None,
) -> Tuple[Union[List[torch.Tensor], None], ...]:
"""
TN-layout grouped GEMM with two lists of FP8 inputs and a single contiguous output that is
implicitly split by m_splits.
This method assumes the scale_inv of A is a list of tensors.
Used to compute the output (fwd) and dgrad (bwd).
"""

num_gemms = len(A)
empty_tensor = _empty_tensor()
empty_tensors = [empty_tensor] * num_gemms
if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]:
assert fp8_meta_tensor is not None and out_offset is not None
for a, b in zip(A, B):
assert_dim_for_fp8_exec(a)
assert_dim_for_fp8_exec(b)
assert A[0].dtype == torch.uint8
assert B[0].dtype == torch.uint8

# Use bfloat16 as default bias_dtype
bias_dtype = torch.bfloat16 if bias is None else bias[0].dtype
if gelu:
gelu_input = [
torch.empty((m, A[0].size(0)), dtype=bias_dtype, device="cuda")
for m in m_splits
]
else:
gelu_input = empty_tensors
bias_dtype = TE_DType[bias_dtype]

out_dtype = TE_DType[out.dtype] if D_dtype is None else D_dtype

torch.ops.tex_ts.te_grouped_gemm_single_output_ts(
A,
A_scale_inv,
A_fp8_tensor_offset,
A_dtype,
True, # transa
B,
B_scale_inv,
B_fp8_tensor_offset,
B_dtype,
False, # transb
m_splits,
out,
0 if out_offset is None else out_offset,
empty_tensor if out_offset is None else fp8_meta_tensor.scale,
out_dtype,
empty_tensor if out_offset is None else fp8_meta_tensor.amax_history,
bias if use_bias else empty_tensors,
bias_dtype,
gelu_input, # this is pre_gelu_out
False, # grad
workspaces,
workspaces[0].shape[0],
accumulate,
use_split_accumulator,
)

return out, gelu_input
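
For reference, a rough CPU-only sketch of what this wrapper computes, using plain float tensors and matmul in place of the FP8 cuBLAS path; sizes and variable names are assumed for illustration only.

import torch

# Hypothetical shapes: each GEMM i multiplies an (m_splits[i], k) input by an
# (n, k) weight and writes an (m_splits[i], n) slice of one contiguous output.
m_splits = [4, 2, 6]
n, k = 16, 32
A = [torch.randn(n, k) for _ in m_splits]   # weights, one (n, k) matrix per GEMM
B = [torch.randn(m, k) for m in m_splits]   # inputs, row counts given by m_splits
out = torch.empty(sum(m_splits), n)         # single contiguous output

# Each GEMM computes B[i] @ A[i].T (transa=True, transb=False) and writes its
# rows into the matching slice of the shared output.
for o, a, b in zip(torch.split(out, m_splits), A, B):
    o.copy_(b @ a.t())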
10 changes: 10 additions & 0 deletions transformer_engine/pytorch/csrc/extensions.h
@@ -159,6 +159,16 @@ void te_grouped_gemm(std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int
std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, int math_sm_count);

void te_grouped_gemm_single_output(
std::vector<at::Tensor> A, std::vector<at::Tensor> A_scale_inverse, int A_offset,
transformer_engine::DType A_type, bool transa, std::vector<at::Tensor> B,
at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type, bool transb,
std::vector<int64_t> m_splits, at::Tensor D, int D_offset, at::Tensor D_scale,
transformer_engine::DType D_type, at::Tensor D_amax, std::vector<at::Tensor> bias,
transformer_engine::DType bias_type, std::vector<at::Tensor> pre_gelu_out, bool grad,
std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, int math_sm_count);

/***************************************************************************************************
* Transpose
**************************************************************************************************/
60 changes: 60 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/gemm.cu
@@ -146,3 +146,63 @@ void te_grouped_gemm(std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int
te_workspace.data(), accumulate, use_split_accumulator,
math_sm_count, at::cuda::getCurrentCUDAStream());
}


void te_grouped_gemm_single_output(
std::vector<at::Tensor> A, std::vector<at::Tensor> A_scale_inverse, int A_offset,
transformer_engine::DType A_type, bool transa, std::vector<at::Tensor> B,
at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type, bool transb,
std::vector<int64_t> m_splits, at::Tensor D, int D_offset, at::Tensor D_scale,
transformer_engine::DType D_type, at::Tensor D_amax, std::vector<at::Tensor> bias,
transformer_engine::DType bias_type, std::vector<at::Tensor> pre_gelu_out, bool grad,
std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, int math_sm_count) {
using namespace transformer_engine;
std::vector<NVTETensor> te_A, te_B, te_D, te_bias, te_pre_gelu_out, te_workspace;
std::vector<transformer_engine::TensorWrapper> tensor_wrappers;
auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector<size_t>& shape,
transformer_engine::DType dtype, void* amax_dptr,
void* scale_dptr, void* scale_inv_dptr) -> NVTETensor {
tensor_wrappers.emplace_back(
makeTransformerEngineTensor(dptr, shape, dtype, amax_dptr, scale_dptr, scale_inv_dptr));
return tensor_wrappers.back().data();
};
void* d_i_ptr = reinterpret_cast<void*>(D.data_ptr());
for (size_t i = 0; i < A.size(); i++) {
if (m_splits[i] == 0)
continue;
te_A.emplace_back(make_tensor(
A[i].data_ptr(), {static_cast<size_t>(A[i].size(0)), static_cast<size_t>(A[i].size(1))},
A_type, nullptr, nullptr, getDataPtr(A_scale_inverse[i], A_offset)));
te_B.emplace_back(make_tensor(
B[i].data_ptr(), {static_cast<size_t>(B[i].size(0)), static_cast<size_t>(B[i].size(1))},
B_type, nullptr, nullptr, getDataPtr(B_scale_inverse, B_offset + i)));
te_D.emplace_back(make_tensor(
d_i_ptr, {static_cast<size_t>(m_splits[i]), static_cast<size_t>(A[i].size(0))},
D_type, getDataPtr(D_amax, D_offset + i), getDataPtr(D_scale, D_offset + i), nullptr));
te_bias.emplace_back(make_tensor(bias[i].data_ptr(), {static_cast<size_t>(bias[i].size(0))},
bias_type, nullptr, nullptr, nullptr));

const auto gelu_shape = pre_gelu_out[i].data_ptr() == nullptr
? std::vector<size_t>{static_cast<size_t>(pre_gelu_out[i].size(0))}
: std::vector<size_t>{static_cast<size_t>(pre_gelu_out[i].size(0)),
static_cast<size_t>(pre_gelu_out[i].size(1))};
te_pre_gelu_out.emplace_back(make_tensor(
pre_gelu_out[i].data_ptr(), gelu_shape,
GetTransformerEngineDType(pre_gelu_out[i].scalar_type()), nullptr, nullptr, nullptr));
// Move the D pointer to the next split.
char* char_ptr = reinterpret_cast<char*>(d_i_ptr);
char_ptr += m_splits[i] * A[i].size(0) * D.element_size();
d_i_ptr = reinterpret_cast<void*>(char_ptr);
}
for (size_t i = 0; i < workspace.size(); i++) {
te_workspace.emplace_back(make_tensor(workspace[i].data_ptr(), {workspaceSize}, DType::kByte,
nullptr, nullptr, nullptr));
}

// For now, we only have a multi-stream cuBLAS backend.
nvte_multi_stream_cublas_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(),
te_pre_gelu_out.data(), te_A.size(), transa, transb, grad,
te_workspace.data(), accumulate, use_split_accumulator,
math_sm_count, at::cuda::getCurrentCUDAStream());
}
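
A small self-contained check (with assumed sizes) that the byte-offset walk over d_i_ptr above partitions the contiguous output exactly as torch.split(out, m_splits) would on the Python side:

import torch

m_splits = [4, 2, 6]
n = 8
out = torch.empty(sum(m_splits), n, dtype=torch.bfloat16)

# Advance a byte offset by m_splits[i] * n * element_size per split, mirroring
# the pointer arithmetic in the loop, and compare against torch.split views.
offset_bytes = 0
for view in torch.split(out, m_splits):
    assert view.data_ptr() == out.data_ptr() + offset_bytes
    offset_bytes += view.size(0) * n * out.element_size()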
34 changes: 34 additions & 0 deletions transformer_engine/pytorch/csrc/ts_fp8_op.cpp
@@ -305,6 +305,39 @@ std::vector<at::Tensor> te_grouped_gemm_ts(
return D;
}

at::Tensor te_grouped_gemm_single_output_ts(
std::vector<at::Tensor> A, std::vector<at::Tensor> A_scale_inverse, int64_t A_offset, int64_t A_type,
int64_t transa, std::vector<at::Tensor> B, at::Tensor B_scale_inverse, int64_t B_offset,
int64_t B_type, int64_t transb, std::vector<int64_t> m_splits, at::Tensor D, int64_t D_offset, at::Tensor D_scale,
int64_t D_type, at::Tensor D_amax, std::vector<at::Tensor> bias, int64_t bias_type,
std::vector<at::Tensor> pre_gelu_out, int64_t grad, std::vector<at::Tensor> workspace,
int64_t workspaceSize, int64_t accumulate, int64_t use_split_accumulator) {
// cast inputs to types accepted by te_gemm
transformer_engine::DType A_type_arg = reverse_map_dtype(A_type);
bool transa_arg = static_cast<bool>(transa);
transformer_engine::DType B_type_arg = reverse_map_dtype(B_type);
bool transb_arg = static_cast<bool>(transb);
transformer_engine::DType D_type_arg = reverse_map_dtype(D_type);
transformer_engine::DType bias_type_arg = reverse_map_dtype(bias_type);
bool grad_arg = static_cast<bool>(grad);
size_t workspaceSize_arg = static_cast<size_t>(workspaceSize);
bool accumulate_arg = static_cast<bool>(accumulate);
bool use_split_accumulator_arg = static_cast<bool>(use_split_accumulator);

// Set an external SM Margin to all the GEMMs.
// This comes in handy when DP is overlapped with GEMMs

const int device_id = at::cuda::current_device();
const int sm_count = transformer_engine::cuda::sm_count(device_id);
int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count);

te_grouped_gemm_single_output(A, A_scale_inverse, A_offset, A_type_arg, transa_arg, B, B_scale_inverse,
B_offset, B_type_arg, transb_arg, m_splits, D, D_offset, D_scale, D_type_arg, D_amax, bias,
bias_type_arg, pre_gelu_out, grad_arg, workspace, workspaceSize_arg,
accumulate_arg, use_split_accumulator_arg, num_math_sms);
return D;
}
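
The SM-margin logic above reduces to simple arithmetic; a sketch with assumed numbers follows (treating a result of 0 as "no explicit SM cap" is an assumption about the backend, not something stated in this diff).

import os

# Hypothetical values: a GPU with 132 SMs and NVTE_EXT_MARGIN_SM=8 leaves
# 132 - 8 = 124 SMs for the grouped GEMMs. When the variable is unset, the
# default equals sm_count, so the computed margin is 0.
sm_count = 132
margin = int(os.getenv("NVTE_EXT_MARGIN_SM", sm_count))
num_math_sms = sm_count - margin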

at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input, const at::Tensor &weight,
const at::Tensor &bias, double eps, at::Tensor scale,
at::Tensor amax, at::Tensor scale_inv, int64_t fp8_tensor,
@@ -371,6 +404,7 @@ TORCH_LIBRARY(tex_ts, m) {
m.def("srelu_ts", &srelu_ts);
m.def("te_gemm_ts", &te_gemm_ts);
m.def("te_grouped_gemm_ts", &te_grouped_gemm_ts);
m.def("te_grouped_gemm_single_output_ts", &te_grouped_gemm_single_output_ts);
m.def("layernorm_fwd_fp8_inf_ts", &layernorm_fwd_fp8_inf_ts);
m.def("layernorm_fwd_inf_ts", &layernorm_fwd_inf_ts);
m.def("rmsnorm_fwd_fp8_inf_ts", &rmsnorm_fwd_fp8_inf_ts);
19 changes: 9 additions & 10 deletions transformer_engine/pytorch/module/grouped_linear.py
@@ -36,6 +36,7 @@
fp8_cast_transpose_bgrad_fused,
fp8_multi_cast_transpose_fused,
fp8_grouped_gemm,
fp8_grouped_gemm_single_output,
grouped_gemm,
)
from ..constants import GemmParallelModes, dist_group_type
@@ -168,18 +169,17 @@ def forward(
device=inputmats[0].device,
)

_ = fp8_grouped_gemm(
_ = fp8_grouped_gemm_single_output(
[w._data for w in weights_fp8],
torch.cat(
[w._scale_inv for w in weights_fp8]
), # avoiding torch.cat requires another interface,
[w._scale_inv for w in weights_fp8],
0, # weight offset is 0 for the newly created _scale_inv
fp8_dtype_forward,
inputmats,
inputmat_scale_inv,
0,
fp8_dtype_forward,
torch.split(out, m_splits),
m_splits,
out,
activation_dtype,
get_multi_stream_cublas_workspace(),
bias=biases,
@@ -359,18 +359,17 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
dtype=ctx.activation_dtype,
device=grad_output.device,
)
fp8_grouped_gemm(
fp8_grouped_gemm_single_output(
[w.transpose_2d() for w in weights_fp8],
torch.cat(
[w._scale_inv for w in weights_fp8]
), # avoiding torch.cat requires another interface
[w._scale_inv for w in weights_fp8],
0, # weight offset is 0 for the newly created _scale_inv
weights_fp8[0]._fp8_dtype,
grad_output_c,
ctx.fp8_meta["scaling_bwd"].scale_inv,
_GRAD_OUTPUT,
fp8_dtype_backward,
torch.split(dgrad, ctx.m_splits),
ctx.m_splits,
dgrad,
ctx.activation_dtype,
get_multi_stream_cublas_workspace(),
use_split_accumulator=_2X_ACC_DGRAD,
