[PyTorch] Propagate fp8 scale-inverse modification to GroupedLinear #1128

Merged · 13 commits · Sep 9, 2024
130 changes: 87 additions & 43 deletions transformer_engine/pytorch/cpp_extensions/gemm.py
@@ -11,7 +11,12 @@
from ..utils import assert_dim_for_fp8_exec


__all__ = ["gemm", "fp8_gemm", "grouped_gemm", "fp8_grouped_gemm"]
__all__ = [
"gemm",
"fp8_gemm",
"grouped_gemm",
"fp8_grouped_gemm",
]


@functools.lru_cache(maxsize=None)
@@ -380,16 +385,17 @@ def grouped_gemm(

def fp8_grouped_gemm(
A: List[torch.Tensor],
A_scale_inv: torch.Tensor,
A_scale_inv: Union[torch.Tensor, List[torch.Tensor]],
A_fp8_tensor_offset: int,
A_dtype: tex.DType,
B: List[torch.Tensor],
B_scale_inv: torch.Tensor,
B_fp8_tensor_offset: int,
B_dtype: tex.DType,
out: List[torch.Tensor],
out: Union[torch.Tensor, List[torch.Tensor]],
out_dtype: torch.dtype,
workspaces: List[torch.Tensor],
m_splits: Optional[List[int]] = None,
out_offset: Optional[int] = None,
fp8_meta_tensor: tex.FP8TensorMeta = None,
gelu: bool = False,
@@ -398,14 +404,21 @@ def fp8_grouped_gemm(
use_bias: bool = False,
use_split_accumulator: bool = False,
D_dtype: Optional[tex.DType] = None,
) -> Tuple[Union[List[torch.Tensor], None], ...]:
) -> Tuple[Union[List[torch.Tensor], torch.Tensor, None], ...]:
"""
TN layout Grouped GEMM with fp8 inputs.
This method assumes the scale/scale_inv/amax of A/B/out are contiguous in the meta tensor.
scale: [ ...A_scale... | ...B_scale... | ...out_scale...]
scale_inv: [ ...A_scale_inv... | ...B_scale_inv... | ...out_scale_inv...]
amax: [ ...A_amax... | ...B_amax... | ...out_amax...]
This function accepts two combinations of inputs:
1. A_scale_inv is a list of tensors, out is a single tensor, and m_splits is not None.
This is used for the calculation of output (fwd) and dgrad (bwd).
2. A_scale_inv is a single tensor, out is a list of tensors. This is used for the
calculation of wgrad.
"""
if isinstance(A_scale_inv, list):
assert isinstance(out, torch.Tensor) and m_splits is not None
elif isinstance(A_scale_inv, torch.Tensor):
assert isinstance(out, (list, tuple))
else:
raise ValueError("A_scale_inv should be a list of tensors or a single tensor.")

num_gemms = len(A)
empty_tensor = _empty_tensor()
@@ -420,41 +433,72 @@

# Use bfloat16 as default bias_dtype
bias_dtype = torch.bfloat16 if bias is None else bias[0].dtype
if gelu:
gelu_input = [
torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format)
for o in out
]
else:
gelu_input = empty_tensors
bias_dtype = TE_DType[bias_dtype]

out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype

torch.ops.tex_ts.te_grouped_gemm_ts(
A,
A_scale_inv,
A_fp8_tensor_offset,
A_dtype,
True, # transa
B,
B_scale_inv,
B_fp8_tensor_offset,
B_dtype,
False, # transb
out,
0 if out_offset is None else out_offset,
empty_tensor if out_offset is None else fp8_meta_tensor.scale,
out_dtype,
empty_tensor if out_offset is None else fp8_meta_tensor.amax_history,
bias if use_bias else empty_tensors,
bias_dtype,
gelu_input, # this is pre_gelu_out
False, # grad
workspaces,
workspaces[0].shape[0],
accumulate,
use_split_accumulator,
)
gelu_input = empty_tensors

if not isinstance(out, torch.Tensor):
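# Case 2 (wgrad): out is a list of per-GEMM output tensors.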
if gelu:
gelu_input = [
torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format)
for o in out
]
out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype

torch.ops.tex_ts.te_grouped_gemm_ts(
A,
A_scale_inv,
A_fp8_tensor_offset,
A_dtype,
True, # transa
B,
B_scale_inv,
B_fp8_tensor_offset,
B_dtype,
False, # transb
out,
0 if out_offset is None else out_offset,
empty_tensor if out_offset is None else fp8_meta_tensor.scale,
out_dtype,
empty_tensor if out_offset is None else fp8_meta_tensor.amax_history,
bias if use_bias else empty_tensors,
bias_dtype,
gelu_input, # this is pre_gelu_out
False, # grad
workspaces,
workspaces[0].shape[0],
accumulate,
use_split_accumulator,
)
else:
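# Case 1 (fwd/dgrad): out is a single contiguous tensor partitioned by m_splits.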
if gelu:
gelu_input = [torch.empty((m, A[0].size(0)), dtype=bias_dtype) for m in m_splits]
out_dtype = TE_DType[out.dtype] if D_dtype is None else D_dtype

torch.ops.tex_ts.te_grouped_gemm_single_output_ts(
A,
A_scale_inv,
A_fp8_tensor_offset,
A_dtype,
True, # transa
B,
B_scale_inv,
B_fp8_tensor_offset,
B_dtype,
False, # transb
m_splits,
out,
0 if out_offset is None else out_offset,
empty_tensor if out_offset is None else fp8_meta_tensor.scale,
out_dtype,
empty_tensor if out_offset is None else fp8_meta_tensor.amax_history,
bias if use_bias else empty_tensors,
bias_dtype,
gelu_input, # this is pre_gelu_out
False, # grad
workspaces,
workspaces[0].shape[0],
accumulate,
use_split_accumulator,
)

return out, gelu_input
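
The single-output path above streams every group's result into one contiguous tensor, advancing the destination pointer by m_splits[i] rows per group (see te_grouped_gemm_single_output below). As a minimal pure-PyTorch sketch of that layout, with no TE calls and illustrative sizes only, each out.split(m_splits) block equals B[i] @ A[i].T in the TN convention:

# Pure-PyTorch sketch of the single-output grouped GEMM layout (illustrative only).
import torch

torch.manual_seed(0)
n, k = 8, 16                               # illustrative sizes
m_splits = [4, 0, 6]                       # zero-sized groups are skipped
A = [torch.randn(n, k) for _ in m_splits]  # per-group "weights", shape [n, k]
B = [torch.randn(m, k) for m in m_splits]  # per-group inputs, shape [m, k]

out = torch.empty(sum(m_splits), n)        # single contiguous output tensor
for out_i, a, b in zip(out.split(m_splits), A, B):
    if out_i.numel() == 0:
        continue                           # mirrors the m_splits[i] == 0 check in the CUDA extension
    torch.matmul(b, a.t(), out=out_i)      # D_i = B_i @ A_i^T (transa=True, transb=False)

# Reference computed group by group matches the packed single-output tensor.
ref = torch.cat([b @ a.t() for a, b in zip(A, B)])
assert torch.allclose(out, ref)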
3 changes: 2 additions & 1 deletion transformer_engine/pytorch/cpp_extensions/transpose.py
@@ -175,14 +175,15 @@ def fp8_multi_cast_transpose_fused(
amax_indices: List[int],
scale_inv_indices: List[int],
otype: tex.DType,
scale_inv: Optional[torch.Tensor] = None,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
"""Cast + Transpose with FP8 output"""

return tex.fused_multi_cast_transpose_alloc(
input_list,
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
scale_inv if scale_inv is not None else fp8_meta_tensor.scale_inv,
scale_indices,
amax_indices,
scale_inv_indices,
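
The new optional scale_inv argument lets the caller choose which buffer the fused cast-transpose writes scale-inverse values into, instead of always updating fp8_meta_tensor.scale_inv; this is presumably how the scale-inverse modification in this PR reaches buffers owned by GroupedLinear. A minimal plain-PyTorch sketch of that buffer selection (the helper, names, and shapes are illustrative, not the TE API):

# Illustrative sketch of the scale_inv destination selection, not the TE kernel.
import torch

def write_scale_inv(scale, scale_inv_indices, meta_scale_inv, scale_inv=None):
    """Write 1/scale into the caller-selected destination buffer."""
    dst = scale_inv if scale_inv is not None else meta_scale_inv
    dst[scale_inv_indices] = scale.reciprocal()
    return dst

meta_scale_inv = torch.ones(8)   # stand-in for fp8_meta_tensor.scale_inv
custom = torch.ones(8)           # stand-in for a buffer owned by the caller
scale = torch.tensor([2.0, 4.0])
write_scale_inv(scale, [3, 4], meta_scale_inv, scale_inv=custom)
print(custom)                    # indices 3 and 4 now hold 0.5 and 0.25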
10 changes: 10 additions & 0 deletions transformer_engine/pytorch/csrc/extensions.h
@@ -159,6 +159,16 @@ void te_grouped_gemm(std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int
std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, int math_sm_count);

void te_grouped_gemm_single_output(
std::vector<at::Tensor> A, std::vector<at::Tensor> A_scale_inverse, int A_offset,
transformer_engine::DType A_type, bool transa, std::vector<at::Tensor> B,
at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type, bool transb,
std::vector<int64_t> m_splits, at::Tensor D, int D_offset, at::Tensor D_scale,
transformer_engine::DType D_type, at::Tensor D_amax, std::vector<at::Tensor> bias,
transformer_engine::DType bias_type, std::vector<at::Tensor> pre_gelu_out, bool grad,
std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, int math_sm_count);

/***************************************************************************************************
* Transpose
**************************************************************************************************/
61 changes: 61 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/gemm.cu
@@ -151,3 +151,64 @@ void te_grouped_gemm(std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int
te_workspace.data(), accumulate, use_split_accumulator,
math_sm_count, at::cuda::getCurrentCUDAStream());
}

void te_grouped_gemm_single_output(
std::vector<at::Tensor> A, std::vector<at::Tensor> A_scale_inverse, int A_offset,
transformer_engine::DType A_type, bool transa, std::vector<at::Tensor> B,
at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type, bool transb,
std::vector<int64_t> m_splits, at::Tensor D, int D_offset, at::Tensor D_scale,
transformer_engine::DType D_type, at::Tensor D_amax, std::vector<at::Tensor> bias,
transformer_engine::DType bias_type, std::vector<at::Tensor> pre_gelu_out, bool grad,
std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, int math_sm_count) {
using namespace transformer_engine;
std::vector<NVTETensor> te_A, te_B, te_D, te_bias, te_pre_gelu_out, te_workspace;
std::vector<transformer_engine::TensorWrapper> tensor_wrappers;
auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector<size_t>& shape,
transformer_engine::DType dtype, void* amax_dptr,
void* scale_dptr, void* scale_inv_dptr) -> NVTETensor {
tensor_wrappers.emplace_back(
makeTransformerEngineTensor(dptr, shape, dtype, amax_dptr, scale_dptr, scale_inv_dptr));
return tensor_wrappers.back().data();
};
NVTE_CHECK(D.is_contiguous(), "D must be contiguous.");
void* d_i_ptr = reinterpret_cast<void*>(D.data_ptr());
for (size_t i = 0; i < A.size(); i++) {
if (m_splits[i] == 0) continue;
NVTE_CHECK(A[i].is_contiguous(), "A[", i, "] must be contiguous.");
NVTE_CHECK(B[i].is_contiguous(), "B[", i, "] must be contiguous.");
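// A's scale_inv is a separate tensor per GEMM (A_scale_inverse[i]); B's scale_inv and
// D's scale/amax index into shared meta tensors at offset + i.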
te_A.emplace_back(make_tensor(
A[i].data_ptr(), {static_cast<size_t>(A[i].size(0)), static_cast<size_t>(A[i].size(1))},
A_type, nullptr, nullptr, getDataPtr(A_scale_inverse[i], A_offset)));
te_B.emplace_back(make_tensor(
B[i].data_ptr(), {static_cast<size_t>(B[i].size(0)), static_cast<size_t>(B[i].size(1))},
B_type, nullptr, nullptr, getDataPtr(B_scale_inverse, B_offset + i)));
te_D.emplace_back(make_tensor(
d_i_ptr, {static_cast<size_t>(m_splits[i]), static_cast<size_t>(A[i].size(0))}, D_type,
getDataPtr(D_amax, D_offset + i), getDataPtr(D_scale, D_offset + i), nullptr));
te_bias.emplace_back(make_tensor(bias[i].data_ptr(), {static_cast<size_t>(bias[i].size(0))},
bias_type, nullptr, nullptr, nullptr));

const auto gelu_shape = pre_gelu_out[i].data_ptr() == nullptr
? std::vector<size_t>{static_cast<size_t>(pre_gelu_out[i].size(0))}
: std::vector<size_t>{static_cast<size_t>(pre_gelu_out[i].size(0)),
static_cast<size_t>(pre_gelu_out[i].size(1))};
te_pre_gelu_out.emplace_back(make_tensor(
pre_gelu_out[i].data_ptr(), gelu_shape,
GetTransformerEngineDType(pre_gelu_out[i].scalar_type()), nullptr, nullptr, nullptr));
// Move the D pointer to the next split.
char* char_ptr = reinterpret_cast<char*>(d_i_ptr);
char_ptr += m_splits[i] * A[i].size(0) * D.element_size();
d_i_ptr = reinterpret_cast<void*>(char_ptr);
}
for (size_t i = 0; i < workspace.size(); i++) {
te_workspace.emplace_back(make_tensor(workspace[i].data_ptr(), {workspaceSize}, DType::kByte,
nullptr, nullptr, nullptr));
}

// For now, we only have the multi-stream cuBLAS backend.
nvte_multi_stream_cublas_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(),
te_pre_gelu_out.data(), te_A.size(), transa, transb, grad,
te_workspace.data(), accumulate, use_split_accumulator,
math_sm_count, at::cuda::getCurrentCUDAStream());
}
36 changes: 36 additions & 0 deletions transformer_engine/pytorch/csrc/ts_fp8_op.cpp
@@ -305,6 +305,41 @@ std::vector<at::Tensor> te_grouped_gemm_ts(
return D;
}

at::Tensor te_grouped_gemm_single_output_ts(
std::vector<at::Tensor> A, std::vector<at::Tensor> A_scale_inverse, int64_t A_offset,
int64_t A_type, int64_t transa, std::vector<at::Tensor> B, at::Tensor B_scale_inverse,
int64_t B_offset, int64_t B_type, int64_t transb, std::vector<int64_t> m_splits, at::Tensor D,
int64_t D_offset, at::Tensor D_scale, int64_t D_type, at::Tensor D_amax,
std::vector<at::Tensor> bias, int64_t bias_type, std::vector<at::Tensor> pre_gelu_out,
int64_t grad, std::vector<at::Tensor> workspace, int64_t workspaceSize, int64_t accumulate,
int64_t use_split_accumulator) {
// cast inputs to types accepted by te_gemm
transformer_engine::DType A_type_arg = reverse_map_dtype(A_type);
bool transa_arg = static_cast<bool>(transa);
transformer_engine::DType B_type_arg = reverse_map_dtype(B_type);
bool transb_arg = static_cast<bool>(transb);
transformer_engine::DType D_type_arg = reverse_map_dtype(D_type);
transformer_engine::DType bias_type_arg = reverse_map_dtype(bias_type);
bool grad_arg = static_cast<bool>(grad);
size_t workspaceSize_arg = static_cast<size_t>(workspaceSize);
bool accumulate_arg = static_cast<bool>(accumulate);
bool use_split_accumulator_arg = static_cast<bool>(use_split_accumulator);

// Set an external SM margin for all the GEMMs.
// This comes in handy when data-parallel (DP) communication is overlapped with GEMMs.

const int device_id = at::cuda::current_device();
const int sm_count = transformer_engine::cuda::sm_count(device_id);
int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count);

te_grouped_gemm_single_output(A, A_scale_inverse, A_offset, A_type_arg, transa_arg, B,
B_scale_inverse, B_offset, B_type_arg, transb_arg, m_splits, D,
D_offset, D_scale, D_type_arg, D_amax, bias, bias_type_arg,
pre_gelu_out, grad_arg, workspace, workspaceSize_arg,
accumulate_arg, use_split_accumulator_arg, num_math_sms);
return D;
}

at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input, const at::Tensor &weight,
const at::Tensor &bias, double eps, at::Tensor scale,
at::Tensor amax, at::Tensor scale_inv, int64_t fp8_tensor,
@@ -371,6 +406,7 @@ TORCH_LIBRARY(tex_ts, m) {
m.def("srelu_ts", &srelu_ts);
m.def("te_gemm_ts", &te_gemm_ts);
m.def("te_grouped_gemm_ts", &te_grouped_gemm_ts);
m.def("te_grouped_gemm_single_output_ts", &te_grouped_gemm_single_output_ts);
m.def("layernorm_fwd_fp8_inf_ts", &layernorm_fwd_fp8_inf_ts);
m.def("layernorm_fwd_inf_ts", &layernorm_fwd_inf_ts);
m.def("rmsnorm_fwd_fp8_inf_ts", &rmsnorm_fwd_fp8_inf_ts);