Skip to content

Commit

Permalink
[PyTorch] Propagate fp8 scale-inverse modification to GroupedLinear (
Browse files Browse the repository at this point in the history
…#1128)

* propagate scale_inv modification to GroupedLinear

Signed-off-by: Xin Yao <[email protected]>

* optimization for separate scale_inv of weights and single output

Signed-off-by: Xin Yao <[email protected]>

* let grouped gemm support different input combinations

Signed-off-by: Xin Yao <[email protected]>

* fix type

Signed-off-by: Xin Yao <[email protected]>

* add contiguous check

Signed-off-by: Xin Yao <[email protected]>

* use len() instead of isinstance

Signed-off-by: Xin Yao <[email protected]>

* fix ut

Signed-off-by: Xin Yao <[email protected]>

---------

Signed-off-by: Xin Yao <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
  • Loading branch information
yaox12 and ksivaman committed Sep 9, 2024
1 parent bdea56f commit 047a507
Show file tree
Hide file tree
Showing 7 changed files with 249 additions and 64 deletions.
31 changes: 23 additions & 8 deletions tests/pytorch/test_numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1266,12 +1266,15 @@ def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False
)
inp_hidden_states.retain_grad()

m = config.seq_len // 16
dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
dist.append(dist[-1]) # Manually add a zero
m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist)
m_splits = m_splits * 16
assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms
if num_gemms > 1:
m = config.seq_len // 16
dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
dist.append(dist[-1]) # Manually add a zero
m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist)
m_splits = m_splits * 16
assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms
else:
m_splits = torch.tensor([config.seq_len])

with fp8_autocast(enabled=fp8):
if isinstance(block, GroupedLinear):
Expand Down Expand Up @@ -1353,7 +1356,7 @@ def test_grouped_linear_accuracy(

@pytest.mark.parametrize("parallel_mode", ["column", "row"])
def test_grouped_linear_accuracy_parallel_mode(parallel_mode):
"""Split the tests to reduce CI time"""
"""Split the tests to save CI time"""
test_grouped_linear_accuracy(
dtype=torch.float32,
num_gemms=6,
Expand All @@ -1365,6 +1368,18 @@ def test_grouped_linear_accuracy_parallel_mode(parallel_mode):
)


def test_grouped_linear_accuracy_single_gemm():
"""Split the tests to save CI time"""
test_grouped_linear_accuracy(
dtype=torch.float32,
num_gemms=1,
bs=2,
model=list(model_configs.keys())[0],
fp8=True,
fp8_model_params=True,
)


def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False):

def _pad_tensor_for_fp8(hidden_states, tokens_per_expert):
Expand Down Expand Up @@ -2034,7 +2049,7 @@ def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate):

fp8_grouped_gemm(
A_fp8,
scale_inv,
[scale_inv],
0, # A_offset
tex.DType.kFloat8E4M3,
B_fp8,
Expand Down
129 changes: 87 additions & 42 deletions transformer_engine/pytorch/cpp_extensions/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@
from ..utils import assert_dim_for_fp8_exec


__all__ = ["gemm", "fp8_gemm", "grouped_gemm", "fp8_grouped_gemm"]
__all__ = [
"gemm",
"fp8_gemm",
"grouped_gemm",
"fp8_grouped_gemm",
]


@functools.lru_cache(maxsize=None)
Expand Down Expand Up @@ -313,7 +318,7 @@ def grouped_gemm(
layout: str = "TN",
bias: Optional[List[torch.Tensor]] = None,
use_bias: bool = False,
) -> Tuple[Union[List[torch.Tensor], None], ...]:
) -> Tuple[List[torch.Tensor], ...]:
"""Non FP8 Grouped GEMM."""

assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported."
Expand Down Expand Up @@ -380,7 +385,7 @@ def grouped_gemm(

def fp8_grouped_gemm(
A: List[torch.Tensor],
A_scale_inv: torch.Tensor,
A_scale_inv: List[torch.Tensor],
A_fp8_tensor_offset: int,
A_dtype: tex.DType,
B: List[torch.Tensor],
Expand All @@ -390,6 +395,7 @@ def fp8_grouped_gemm(
out: List[torch.Tensor],
out_dtype: torch.dtype,
workspaces: List[torch.Tensor],
m_splits: Optional[List[int]] = None,
out_offset: Optional[int] = None,
fp8_meta_tensor: tex.FP8TensorMeta = None,
gelu: bool = False,
Expand All @@ -398,16 +404,25 @@ def fp8_grouped_gemm(
use_bias: bool = False,
use_split_accumulator: bool = False,
D_dtype: Optional[tex.DType] = None,
) -> Tuple[Union[List[torch.Tensor], None], ...]:
) -> Tuple[List[torch.Tensor], ...]:
"""
TN layout Grouped GEMM with fp8 inputs.
This method assumes the scale/scale_inv/amax of A/B/out is contiguous in the meta tensor.
scale: [ ...A_scale... | ...B_scale... | ...out_scale...]
scale_inv: [ ...A_scale_inv... | ...B_scale_inv... | ...out_scale_inv...]
amax: [ ...A_amax... | ...B_amax... | ...out_amax...]
Input requirements:
1. If len(A_scale_inv) == num_gemms, len(out) must be 1, and m_splits is not None.
This is used for the calculation of output (fwd) and dgrad (bwd).
2. if len(A_scale_inv) == 1, len(out) must be num_gemms. This is used for the
calculation of wgrad.
"""

num_gemms = len(A)
if num_gemms > 1 and len(A_scale_inv) == num_gemms:
assert len(out) == 1 and m_splits is not None
elif num_gemms > 1 and len(A_scale_inv) == 1:
assert len(out) == num_gemms
elif num_gemms == 1:
assert len(A_scale_inv) == 1 and len(out) == 1
else:
raise ValueError("Invalid input combinations of A_scale_inv and out.")

empty_tensor = _empty_tensor()
empty_tensors = [empty_tensor] * num_gemms
if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]:
Expand All @@ -420,41 +435,71 @@ def fp8_grouped_gemm(

# Use bfloat16 as default bias_dtype
bias_dtype = torch.bfloat16 if bias is None else bias[0].dtype
if gelu:
gelu_input = [
torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format)
for o in out
]
else:
gelu_input = empty_tensors
bias_dtype = TE_DType[bias_dtype]

gelu_input = empty_tensors
out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype

torch.ops.tex_ts.te_grouped_gemm_ts(
A,
A_scale_inv,
A_fp8_tensor_offset,
A_dtype,
True, # transa
B,
B_scale_inv,
B_fp8_tensor_offset,
B_dtype,
False, # transb
out,
0 if out_offset is None else out_offset,
empty_tensor if out_offset is None else fp8_meta_tensor.scale,
out_dtype,
empty_tensor if out_offset is None else fp8_meta_tensor.amax_history,
bias if use_bias else empty_tensors,
bias_dtype,
gelu_input, # this is pre_gelu_out
False, # grad
workspaces,
workspaces[0].shape[0],
accumulate,
use_split_accumulator,
)
if len(A_scale_inv) == 1:
if gelu:
gelu_input = [
torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format)
for o in out
]

torch.ops.tex_ts.te_grouped_gemm_ts(
A,
A_scale_inv[0],
A_fp8_tensor_offset,
A_dtype,
True, # transa
B,
B_scale_inv,
B_fp8_tensor_offset,
B_dtype,
False, # transb
out,
0 if out_offset is None else out_offset,
empty_tensor if out_offset is None else fp8_meta_tensor.scale,
out_dtype,
empty_tensor if out_offset is None else fp8_meta_tensor.amax_history,
bias if use_bias else empty_tensors,
bias_dtype,
gelu_input, # this is pre_gelu_out
False, # grad
workspaces,
workspaces[0].shape[0],
accumulate,
use_split_accumulator,
)
else:
if gelu:
gelu_input = [torch.empty((m, A[0].size(0)), dtype=bias_dtype) for m in m_splits]

torch.ops.tex_ts.te_grouped_gemm_single_output_ts(
A,
A_scale_inv,
A_fp8_tensor_offset,
A_dtype,
True, # transa
B,
B_scale_inv,
B_fp8_tensor_offset,
B_dtype,
False, # transb
m_splits,
out[0],
0 if out_offset is None else out_offset,
empty_tensor if out_offset is None else fp8_meta_tensor.scale,
out_dtype,
empty_tensor if out_offset is None else fp8_meta_tensor.amax_history,
bias if use_bias else empty_tensors,
bias_dtype,
gelu_input, # this is pre_gelu_out
False, # grad
workspaces,
workspaces[0].shape[0],
accumulate,
use_split_accumulator,
)

return out, gelu_input
3 changes: 2 additions & 1 deletion transformer_engine/pytorch/cpp_extensions/transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,14 +175,15 @@ def fp8_multi_cast_transpose_fused(
amax_indices: List[int],
scale_inv_indices: List[int],
otype: tex.DType,
scale_inv: Optional[torch.Tensor] = None,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
"""Cast + Transpose with FP8 output"""

return tex.fused_multi_cast_transpose_alloc(
input_list,
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
scale_inv if scale_inv is not None else fp8_meta_tensor.scale_inv,
scale_indices,
amax_indices,
scale_inv_indices,
Expand Down
10 changes: 10 additions & 0 deletions transformer_engine/pytorch/csrc/extensions.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,16 @@ void te_grouped_gemm(std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int
std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, int math_sm_count);

void te_grouped_gemm_single_output(
std::vector<at::Tensor> A, std::vector<at::Tensor> A_scale_inverse, int A_offset,
transformer_engine::DType A_type, bool transa, std::vector<at::Tensor> B,
at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type, bool transb,
std::vector<int64_t> m_splits, at::Tensor D, int D_offset, at::Tensor D_scale,
transformer_engine::DType D_type, at::Tensor D_amax, std::vector<at::Tensor> bias,
transformer_engine::DType bias_type, std::vector<at::Tensor> pre_gelu_out, bool grad,
std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, int math_sm_count);

/***************************************************************************************************
* Transpose
**************************************************************************************************/
Expand Down
61 changes: 61 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,64 @@ void te_grouped_gemm(std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int
te_workspace.data(), accumulate, use_split_accumulator,
math_sm_count, at::cuda::getCurrentCUDAStream());
}

void te_grouped_gemm_single_output(
std::vector<at::Tensor> A, std::vector<at::Tensor> A_scale_inverse, int A_offset,
transformer_engine::DType A_type, bool transa, std::vector<at::Tensor> B,
at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type, bool transb,
std::vector<int64_t> m_splits, at::Tensor D, int D_offset, at::Tensor D_scale,
transformer_engine::DType D_type, at::Tensor D_amax, std::vector<at::Tensor> bias,
transformer_engine::DType bias_type, std::vector<at::Tensor> pre_gelu_out, bool grad,
std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, int math_sm_count) {
using namespace transformer_engine;
std::vector<NVTETensor> te_A, te_B, te_D, te_bias, te_pre_gelu_out, te_workspace;
std::vector<transformer_engine::TensorWrapper> tensor_wrappers;
auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector<size_t>& shape,
transformer_engine::DType dtype, void* amax_dptr,
void* scale_dptr, void* scale_inv_dptr) -> NVTETensor {
tensor_wrappers.emplace_back(
makeTransformerEngineTensor(dptr, shape, dtype, amax_dptr, scale_dptr, scale_inv_dptr));
return tensor_wrappers.back().data();
};
NVTE_CHECK(D.is_contiguous(), "D must be contiguous.");
void* d_i_ptr = reinterpret_cast<void*>(D.data_ptr());
for (size_t i = 0; i < A.size(); i++) {
if (m_splits[i] == 0) continue;
NVTE_CHECK(A[i].is_contiguous(), "A[", i, "] must be contiguous.");
NVTE_CHECK(B[i].is_contiguous(), "B[", i, "] must be contiguous.");
te_A.emplace_back(make_tensor(
A[i].data_ptr(), {static_cast<size_t>(A[i].size(0)), static_cast<size_t>(A[i].size(1))},
A_type, nullptr, nullptr, getDataPtr(A_scale_inverse[i], A_offset)));
te_B.emplace_back(make_tensor(
B[i].data_ptr(), {static_cast<size_t>(B[i].size(0)), static_cast<size_t>(B[i].size(1))},
B_type, nullptr, nullptr, getDataPtr(B_scale_inverse, B_offset + i)));
te_D.emplace_back(make_tensor(
d_i_ptr, {static_cast<size_t>(m_splits[i]), static_cast<size_t>(A[i].size(0))}, D_type,
getDataPtr(D_amax, D_offset + i), getDataPtr(D_scale, D_offset + i), nullptr));
te_bias.emplace_back(make_tensor(bias[i].data_ptr(), {static_cast<size_t>(bias[i].size(0))},
bias_type, nullptr, nullptr, nullptr));

const auto gelu_shape = pre_gelu_out[i].data_ptr() == nullptr
? std::vector<size_t>{static_cast<size_t>(pre_gelu_out[i].size(0))}
: std::vector<size_t>{static_cast<size_t>(pre_gelu_out[i].size(0)),
static_cast<size_t>(pre_gelu_out[i].size(1))};
te_pre_gelu_out.emplace_back(make_tensor(
pre_gelu_out[i].data_ptr(), gelu_shape,
GetTransformerEngineDType(pre_gelu_out[i].scalar_type()), nullptr, nullptr, nullptr));
// Move the D pointer to the next split.
char* char_ptr = reinterpret_cast<char*>(d_i_ptr);
char_ptr += m_splits[i] * A[i].size(0) * D.element_size();
d_i_ptr = reinterpret_cast<void*>(char_ptr);
}
for (size_t i = 0; i < workspace.size(); i++) {
te_workspace.emplace_back(make_tensor(workspace[i].data_ptr(), {workspaceSize}, DType::kByte,
nullptr, nullptr, nullptr));
}

// For now, we only have multi-stream cublas backend.
nvte_multi_stream_cublas_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(),
te_pre_gelu_out.data(), te_A.size(), transa, transb, grad,
te_workspace.data(), accumulate, use_split_accumulator,
math_sm_count, at::cuda::getCurrentCUDAStream());
}
36 changes: 36 additions & 0 deletions transformer_engine/pytorch/csrc/ts_fp8_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,41 @@ std::vector<at::Tensor> te_grouped_gemm_ts(
return D;
}

at::Tensor te_grouped_gemm_single_output_ts(
std::vector<at::Tensor> A, std::vector<at::Tensor> A_scale_inverse, int64_t A_offset,
int64_t A_type, int64_t transa, std::vector<at::Tensor> B, at::Tensor B_scale_inverse,
int64_t B_offset, int64_t B_type, int64_t transb, std::vector<int64_t> m_splits, at::Tensor D,
int64_t D_offset, at::Tensor D_scale, int64_t D_type, at::Tensor D_amax,
std::vector<at::Tensor> bias, int64_t bias_type, std::vector<at::Tensor> pre_gelu_out,
int64_t grad, std::vector<at::Tensor> workspace, int64_t workspaceSize, int64_t accumulate,
int64_t use_split_accumulator) {
// cast inputs to types accepted by te_gemm
transformer_engine::DType A_type_arg = reverse_map_dtype(A_type);
bool transa_arg = static_cast<bool>(transa);
transformer_engine::DType B_type_arg = reverse_map_dtype(B_type);
bool transb_arg = static_cast<bool>(transb);
transformer_engine::DType D_type_arg = reverse_map_dtype(D_type);
transformer_engine::DType bias_type_arg = reverse_map_dtype(bias_type);
bool grad_arg = static_cast<bool>(grad);
size_t workspaceSize_arg = static_cast<size_t>(workspaceSize);
bool accumulate_arg = static_cast<bool>(accumulate);
bool use_split_accumulator_arg = static_cast<bool>(use_split_accumulator);

// Set an external SM Margin to all the GEMMs.
// This comes in handy when DP is overlapped with GEMMs

const int device_id = at::cuda::current_device();
const int sm_count = transformer_engine::cuda::sm_count(device_id);
int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count);

te_grouped_gemm_single_output(A, A_scale_inverse, A_offset, A_type_arg, transa_arg, B,
B_scale_inverse, B_offset, B_type_arg, transb_arg, m_splits, D,
D_offset, D_scale, D_type_arg, D_amax, bias, bias_type_arg,
pre_gelu_out, grad_arg, workspace, workspaceSize_arg,
accumulate_arg, use_split_accumulator_arg, num_math_sms);
return D;
}

at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input, const at::Tensor &weight,
const at::Tensor &bias, double eps, at::Tensor scale,
at::Tensor amax, at::Tensor scale_inv, int64_t fp8_tensor,
Expand Down Expand Up @@ -371,6 +406,7 @@ TORCH_LIBRARY(tex_ts, m) {
m.def("srelu_ts", &srelu_ts);
m.def("te_gemm_ts", &te_gemm_ts);
m.def("te_grouped_gemm_ts", &te_grouped_gemm_ts);
m.def("te_grouped_gemm_single_output_ts", &te_grouped_gemm_single_output_ts);
m.def("layernorm_fwd_fp8_inf_ts", &layernorm_fwd_fp8_inf_ts);
m.def("layernorm_fwd_inf_ts", &layernorm_fwd_inf_ts);
m.def("rmsnorm_fwd_fp8_inf_ts", &rmsnorm_fwd_fp8_inf_ts);
Expand Down
Loading

0 comments on commit 047a507

Please sign in to comment.