
[PyTorch] Propagate fp8 scale-inverse modification to GroupedLinear #1128

Merged · 13 commits · Sep 9, 2024
31 changes: 23 additions & 8 deletions tests/pytorch/test_numerics.py
@@ -1266,12 +1266,15 @@ def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False
)
inp_hidden_states.retain_grad()

m = config.seq_len // 16
dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
dist.append(dist[-1]) # Manually add a zero
m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist)
m_splits = m_splits * 16
assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms
if num_gemms > 1:
m = config.seq_len // 16
dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
dist.append(dist[-1]) # Manually add a zero
m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist)
m_splits = m_splits * 16
assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms
else:
m_splits = torch.tensor([config.seq_len])

with fp8_autocast(enabled=fp8):
if isinstance(block, GroupedLinear):
@@ -1353,7 +1356,7 @@ def test_grouped_linear_accuracy(

@pytest.mark.parametrize("parallel_mode", ["column", "row"])
def test_grouped_linear_accuracy_parallel_mode(parallel_mode):
"""Split the tests to reduce CI time"""
"""Split the tests to save CI time"""
test_grouped_linear_accuracy(
dtype=torch.float32,
num_gemms=6,
@@ -1365,6 +1368,18 @@ def test_grouped_linear_accuracy_parallel_mode(parallel_mode):
)


def test_grouped_linear_accuracy_single_gemm():
"""Split the tests to save CI time"""
test_grouped_linear_accuracy(
dtype=torch.float32,
num_gemms=1,
bs=2,
model=list(model_configs.keys())[0],
fp8=True,
fp8_model_params=True,
)


def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False):

def _pad_tensor_for_fp8(hidden_states, tokens_per_expert):
@@ -2034,7 +2049,7 @@ def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate):

fp8_grouped_gemm(
A_fp8,
scale_inv,
[scale_inv],
0, # A_offset
tex.DType.kFloat8E4M3,
B_fp8,
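A note on the split logic above: dist.append(dist[-1]) duplicates the last cut point on purpose, so one of the resulting groups always has size zero and the empty-GEMM case is exercised. A minimal stand-alone sketch of that logic, assuming num_gemms >= 3 and a seq_len divisible by 16 (the 2048 below is only an example value):

import torch

def make_m_splits(seq_len: int, num_gemms: int) -> torch.Tensor:
    # Mirror of the split construction in _test_grouped_linear_accuracy above.
    m = seq_len // 16
    # num_gemms - 2 sorted random cut points in [0, m).
    dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
    dist.append(dist[-1])  # duplicate the last cut point -> one zero-sized group
    m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist)
    return m_splits * 16  # multiples of 16 that sum to seq_len

splits = make_m_splits(seq_len=2048, num_gemms=6)
assert splits.sum() == 2048 and len(splits) == 6 and (splits == 0).any()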
129 changes: 87 additions & 42 deletions transformer_engine/pytorch/cpp_extensions/gemm.py
@@ -11,7 +11,12 @@
from ..utils import assert_dim_for_fp8_exec


__all__ = ["gemm", "fp8_gemm", "grouped_gemm", "fp8_grouped_gemm"]
__all__ = [
"gemm",
"fp8_gemm",
"grouped_gemm",
"fp8_grouped_gemm",
]


@functools.lru_cache(maxsize=None)
@@ -313,7 +318,7 @@ def grouped_gemm(
layout: str = "TN",
bias: Optional[List[torch.Tensor]] = None,
use_bias: bool = False,
) -> Tuple[Union[List[torch.Tensor], None], ...]:
) -> Tuple[List[torch.Tensor], ...]:
"""Non FP8 Grouped GEMM."""

assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported."
@@ -380,7 +385,7 @@

def fp8_grouped_gemm(
A: List[torch.Tensor],
A_scale_inv: torch.Tensor,
A_scale_inv: List[torch.Tensor],
A_fp8_tensor_offset: int,
A_dtype: tex.DType,
B: List[torch.Tensor],
@@ -390,6 +395,7 @@ def fp8_grouped_gemm(
out: List[torch.Tensor],
out_dtype: torch.dtype,
workspaces: List[torch.Tensor],
m_splits: Optional[List[int]] = None,
out_offset: Optional[int] = None,
fp8_meta_tensor: tex.FP8TensorMeta = None,
gelu: bool = False,
@@ -398,16 +404,25 @@
use_bias: bool = False,
use_split_accumulator: bool = False,
D_dtype: Optional[tex.DType] = None,
) -> Tuple[Union[List[torch.Tensor], None], ...]:
) -> Tuple[List[torch.Tensor], ...]:
"""
TN layout Grouped GEMM with fp8 inputs.
This method assumes the scale/scale_inv/amax of A/B/out is contiguous in the meta tensor.
scale: [ ...A_scale... | ...B_scale... | ...out_scale...]
scale_inv: [ ...A_scale_inv... | ...B_scale_inv... | ...out_scale_inv...]
amax: [ ...A_amax... | ...B_amax... | ...out_amax...]
Input requirements:
1. If len(A_scale_inv) == num_gemms, len(out) must be 1, and m_splits is not None.
This is used for the calculation of output (fwd) and dgrad (bwd).
2. If len(A_scale_inv) == 1, len(out) must be num_gemms. This is used for the
calculation of wgrad.
"""

num_gemms = len(A)
if num_gemms > 1 and len(A_scale_inv) == num_gemms:
assert len(out) == 1 and m_splits is not None
elif num_gemms > 1 and len(A_scale_inv) == 1:
assert len(out) == num_gemms
elif num_gemms == 1:
assert len(A_scale_inv) == 1 and len(out) == 1
else:
raise ValueError("Invalid input combinations of A_scale_inv and out.")

empty_tensor = _empty_tensor()
empty_tensors = [empty_tensor] * num_gemms
if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]:
@@ -420,41 +435,71 @@

# Use bfloat16 as default bias_dtype
bias_dtype = torch.bfloat16 if bias is None else bias[0].dtype
if gelu:
gelu_input = [
torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format)
for o in out
]
else:
gelu_input = empty_tensors
bias_dtype = TE_DType[bias_dtype]

gelu_input = empty_tensors
out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype

torch.ops.tex_ts.te_grouped_gemm_ts(
A,
A_scale_inv,
A_fp8_tensor_offset,
A_dtype,
True, # transa
B,
B_scale_inv,
B_fp8_tensor_offset,
B_dtype,
False, # transb
out,
0 if out_offset is None else out_offset,
empty_tensor if out_offset is None else fp8_meta_tensor.scale,
out_dtype,
empty_tensor if out_offset is None else fp8_meta_tensor.amax_history,
bias if use_bias else empty_tensors,
bias_dtype,
gelu_input, # this is pre_gelu_out
False, # grad
workspaces,
workspaces[0].shape[0],
accumulate,
use_split_accumulator,
)
if len(A_scale_inv) == 1:
if gelu:
gelu_input = [
torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format)
for o in out
]

torch.ops.tex_ts.te_grouped_gemm_ts(
A,
A_scale_inv[0],
A_fp8_tensor_offset,
A_dtype,
True, # transa
B,
B_scale_inv,
B_fp8_tensor_offset,
B_dtype,
False, # transb
out,
0 if out_offset is None else out_offset,
empty_tensor if out_offset is None else fp8_meta_tensor.scale,
out_dtype,
empty_tensor if out_offset is None else fp8_meta_tensor.amax_history,
bias if use_bias else empty_tensors,
bias_dtype,
gelu_input, # this is pre_gelu_out
False, # grad
workspaces,
workspaces[0].shape[0],
accumulate,
use_split_accumulator,
)
else:
if gelu:
gelu_input = [torch.empty((m, A[0].size(0)), dtype=bias_dtype) for m in m_splits]

torch.ops.tex_ts.te_grouped_gemm_single_output_ts(
A,
A_scale_inv,
A_fp8_tensor_offset,
A_dtype,
True, # transa
B,
B_scale_inv,
B_fp8_tensor_offset,
B_dtype,
False, # transb
m_splits,
out[0],
0 if out_offset is None else out_offset,
empty_tensor if out_offset is None else fp8_meta_tensor.scale,
out_dtype,
empty_tensor if out_offset is None else fp8_meta_tensor.amax_history,
bias if use_bias else empty_tensors,
bias_dtype,
gelu_input, # this is pre_gelu_out
False, # grad
workspaces,
workspaces[0].shape[0],
accumulate,
use_split_accumulator,
)

return out, gelu_input
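
As the docstring spells out, the wrapper now dispatches purely on len(A_scale_inv): one scale-inverse per GEMM plus m_splits routes to the new single-output kernel (forward output and dgrad), while a single shared scale-inverse with one output tensor per GEMM keeps the original multi-output path (wgrad). A minimal stand-alone sketch of the mode selection, mirroring the assertions at the top of the function (the helper name is hypothetical, not part of the library):

from typing import List, Optional

def grouped_gemm_mode(num_gemms: int, num_a_scale_inv: int, num_out: int,
                      m_splits: Optional[List[int]]) -> str:
    # Same branching as the input validation in fp8_grouped_gemm.
    if num_gemms > 1 and num_a_scale_inv == num_gemms:
        # Per-GEMM scale_inv: results are packed into a single output tensor,
        # so the row split sizes must be provided (fwd output / dgrad).
        assert num_out == 1 and m_splits is not None
        return "te_grouped_gemm_single_output_ts"
    if num_gemms > 1 and num_a_scale_inv == 1:
        # Shared scale_inv: one output tensor per GEMM (wgrad).
        assert num_out == num_gemms
        return "te_grouped_gemm_ts"
    if num_gemms == 1:
        assert num_a_scale_inv == 1 and num_out == 1
        return "te_grouped_gemm_ts"
    raise ValueError("Invalid combination of A_scale_inv and out.")

assert grouped_gemm_mode(4, 4, 1, [256, 0, 512, 256]) == "te_grouped_gemm_single_output_ts"
assert grouped_gemm_mode(4, 1, 4, None) == "te_grouped_gemm_ts"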
3 changes: 2 additions & 1 deletion transformer_engine/pytorch/cpp_extensions/transpose.py
@@ -175,14 +175,15 @@ def fp8_multi_cast_transpose_fused(
amax_indices: List[int],
scale_inv_indices: List[int],
otype: tex.DType,
scale_inv: Optional[torch.Tensor] = None,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
"""Cast + Transpose with FP8 output"""

return tex.fused_multi_cast_transpose_alloc(
input_list,
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
scale_inv if scale_inv is not None else fp8_meta_tensor.scale_inv,
scale_indices,
amax_indices,
scale_inv_indices,
10 changes: 10 additions & 0 deletions transformer_engine/pytorch/csrc/extensions.h
@@ -165,6 +165,16 @@ void te_grouped_gemm(std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int
std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, int math_sm_count);

void te_grouped_gemm_single_output(
std::vector<at::Tensor> A, std::vector<at::Tensor> A_scale_inverse, int A_offset,
transformer_engine::DType A_type, bool transa, std::vector<at::Tensor> B,
at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type, bool transb,
std::vector<int64_t> m_splits, at::Tensor D, int D_offset, at::Tensor D_scale,
transformer_engine::DType D_type, at::Tensor D_amax, std::vector<at::Tensor> bias,
transformer_engine::DType bias_type, std::vector<at::Tensor> pre_gelu_out, bool grad,
std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, int math_sm_count);

/***************************************************************************************************
* Transpose
**************************************************************************************************/
61 changes: 61 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/gemm.cu
@@ -151,3 +151,64 @@ void te_grouped_gemm(std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int
te_workspace.data(), accumulate, use_split_accumulator,
math_sm_count, at::cuda::getCurrentCUDAStream());
}

void te_grouped_gemm_single_output(
std::vector<at::Tensor> A, std::vector<at::Tensor> A_scale_inverse, int A_offset,
transformer_engine::DType A_type, bool transa, std::vector<at::Tensor> B,
at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type, bool transb,
std::vector<int64_t> m_splits, at::Tensor D, int D_offset, at::Tensor D_scale,
transformer_engine::DType D_type, at::Tensor D_amax, std::vector<at::Tensor> bias,
transformer_engine::DType bias_type, std::vector<at::Tensor> pre_gelu_out, bool grad,
std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, int math_sm_count) {
using namespace transformer_engine;
std::vector<NVTETensor> te_A, te_B, te_D, te_bias, te_pre_gelu_out, te_workspace;
std::vector<transformer_engine::TensorWrapper> tensor_wrappers;
auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector<size_t>& shape,
transformer_engine::DType dtype, void* amax_dptr,
void* scale_dptr, void* scale_inv_dptr) -> NVTETensor {
tensor_wrappers.emplace_back(
makeTransformerEngineTensor(dptr, shape, dtype, amax_dptr, scale_dptr, scale_inv_dptr));
return tensor_wrappers.back().data();
};
NVTE_CHECK(D.is_contiguous(), "D must be contiguous.");
void* d_i_ptr = reinterpret_cast<void*>(D.data_ptr());
for (size_t i = 0; i < A.size(); i++) {
if (m_splits[i] == 0) continue;
NVTE_CHECK(A[i].is_contiguous(), "A[", i, "] must be contiguous.");
NVTE_CHECK(B[i].is_contiguous(), "B[", i, "] must be contiguous.");
te_A.emplace_back(make_tensor(
A[i].data_ptr(), {static_cast<size_t>(A[i].size(0)), static_cast<size_t>(A[i].size(1))},
A_type, nullptr, nullptr, getDataPtr(A_scale_inverse[i], A_offset)));
te_B.emplace_back(make_tensor(
B[i].data_ptr(), {static_cast<size_t>(B[i].size(0)), static_cast<size_t>(B[i].size(1))},
B_type, nullptr, nullptr, getDataPtr(B_scale_inverse, B_offset + i)));
te_D.emplace_back(make_tensor(
d_i_ptr, {static_cast<size_t>(m_splits[i]), static_cast<size_t>(A[i].size(0))}, D_type,
getDataPtr(D_amax, D_offset + i), getDataPtr(D_scale, D_offset + i), nullptr));
te_bias.emplace_back(make_tensor(bias[i].data_ptr(), {static_cast<size_t>(bias[i].size(0))},
bias_type, nullptr, nullptr, nullptr));

const auto gelu_shape = pre_gelu_out[i].data_ptr() == nullptr
? std::vector<size_t>{static_cast<size_t>(pre_gelu_out[i].size(0))}
: std::vector<size_t>{static_cast<size_t>(pre_gelu_out[i].size(0)),
static_cast<size_t>(pre_gelu_out[i].size(1))};
te_pre_gelu_out.emplace_back(make_tensor(
pre_gelu_out[i].data_ptr(), gelu_shape,
GetTransformerEngineDType(pre_gelu_out[i].scalar_type()), nullptr, nullptr, nullptr));
// Move the D pointer to the next split.
char* char_ptr = reinterpret_cast<char*>(d_i_ptr);
char_ptr += m_splits[i] * A[i].size(0) * D.element_size();
d_i_ptr = reinterpret_cast<void*>(char_ptr);
}
for (size_t i = 0; i < workspace.size(); i++) {
te_workspace.emplace_back(make_tensor(workspace[i].data_ptr(), {workspaceSize}, DType::kByte,
nullptr, nullptr, nullptr));
}

// For now, we only have the multi-stream cuBLAS backend.
nvte_multi_stream_cublas_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(),
te_pre_gelu_out.data(), te_A.size(), transa, transb, grad,
te_workspace.data(), accumulate, use_split_accumulator,
math_sm_count, at::cuda::getCurrentCUDAStream());
}
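
Because every GEMM writes into the same contiguous D, the pointer advance above is m_splits[i] * A[i].size(0) * D.element_size() bytes per split, i.e. each GEMM i contributes an (m_splits[i], N) block with N = A[i].size(0), laid out back to back. A small Python sketch of how such a packed output maps back to per-GEMM views (shapes and the helper name are illustrative only):

import torch

def split_packed_output(d: torch.Tensor, m_splits, n: int):
    # View the single packed output as per-GEMM (m_i, n) blocks, mirroring
    # the byte-offset arithmetic in te_grouped_gemm_single_output.
    assert d.is_contiguous() and d.numel() == sum(m_splits) * n
    return list(torch.split(d.view(sum(m_splits), n), m_splits, dim=0))

d = torch.arange(6 * 8, dtype=torch.float32)  # packed output for m_splits summing to 6, n = 8
views = split_packed_output(d, [2, 0, 4], 8)
assert [tuple(v.shape) for v in views] == [(2, 8), (0, 8), (4, 8)]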
36 changes: 36 additions & 0 deletions transformer_engine/pytorch/csrc/ts_fp8_op.cpp
@@ -305,6 +305,41 @@ std::vector<at::Tensor> te_grouped_gemm_ts(
return D;
}

at::Tensor te_grouped_gemm_single_output_ts(
std::vector<at::Tensor> A, std::vector<at::Tensor> A_scale_inverse, int64_t A_offset,
int64_t A_type, int64_t transa, std::vector<at::Tensor> B, at::Tensor B_scale_inverse,
int64_t B_offset, int64_t B_type, int64_t transb, std::vector<int64_t> m_splits, at::Tensor D,
int64_t D_offset, at::Tensor D_scale, int64_t D_type, at::Tensor D_amax,
std::vector<at::Tensor> bias, int64_t bias_type, std::vector<at::Tensor> pre_gelu_out,
int64_t grad, std::vector<at::Tensor> workspace, int64_t workspaceSize, int64_t accumulate,
int64_t use_split_accumulator) {
// cast inputs to types accepted by te_gemm
transformer_engine::DType A_type_arg = reverse_map_dtype(A_type);
bool transa_arg = static_cast<bool>(transa);
transformer_engine::DType B_type_arg = reverse_map_dtype(B_type);
bool transb_arg = static_cast<bool>(transb);
transformer_engine::DType D_type_arg = reverse_map_dtype(D_type);
transformer_engine::DType bias_type_arg = reverse_map_dtype(bias_type);
bool grad_arg = static_cast<bool>(grad);
size_t workspaceSize_arg = static_cast<size_t>(workspaceSize);
bool accumulate_arg = static_cast<bool>(accumulate);
bool use_split_accumulator_arg = static_cast<bool>(use_split_accumulator);

// Set an external SM margin for all the GEMMs.
// This comes in handy when data-parallel communication is overlapped with GEMMs.

const int device_id = at::cuda::current_device();
const int sm_count = transformer_engine::cuda::sm_count(device_id);
int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count);

te_grouped_gemm_single_output(A, A_scale_inverse, A_offset, A_type_arg, transa_arg, B,
B_scale_inverse, B_offset, B_type_arg, transb_arg, m_splits, D,
D_offset, D_scale, D_type_arg, D_amax, bias, bias_type_arg,
pre_gelu_out, grad_arg, workspace, workspaceSize_arg,
accumulate_arg, use_split_accumulator_arg, num_math_sms);
return D;
}

at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input, const at::Tensor &weight,
const at::Tensor &bias, double eps, at::Tensor scale,
at::Tensor amax, at::Tensor scale_inv, int64_t fp8_tensor,
@@ -371,6 +406,7 @@ TORCH_LIBRARY(tex_ts, m) {
m.def("srelu_ts", &srelu_ts);
m.def("te_gemm_ts", &te_gemm_ts);
m.def("te_grouped_gemm_ts", &te_grouped_gemm_ts);
m.def("te_grouped_gemm_single_output_ts", &te_grouped_gemm_single_output_ts);
m.def("layernorm_fwd_fp8_inf_ts", &layernorm_fwd_fp8_inf_ts);
m.def("layernorm_fwd_inf_ts", &layernorm_fwd_inf_ts);
m.def("rmsnorm_fwd_fp8_inf_ts", &rmsnorm_fwd_fp8_inf_ts);