From 047a50722780e7b647f9107783e210021190edc3 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Mon, 9 Sep 2024 22:30:48 +0800 Subject: [PATCH 1/3] [PyTorch] Propagate fp8 scale-inverse modification to `GroupedLinear` (#1128) * propagate scale_inv modification to GroupedLinear Signed-off-by: Xin Yao * optimization for separate scale_inv of weights and single output Signed-off-by: Xin Yao * let grouped gemm support different input combinations Signed-off-by: Xin Yao * fix type Signed-off-by: Xin Yao * add contiguous check Signed-off-by: Xin Yao * use len() instead of isinstance Signed-off-by: Xin Yao * fix ut Signed-off-by: Xin Yao --------- Signed-off-by: Xin Yao Co-authored-by: Kirthi Shankar Sivamani --- tests/pytorch/test_numerics.py | 31 +++-- .../pytorch/cpp_extensions/gemm.py | 129 ++++++++++++------ .../pytorch/cpp_extensions/transpose.py | 3 +- transformer_engine/pytorch/csrc/extensions.h | 10 ++ .../pytorch/csrc/extensions/gemm.cu | 61 +++++++++ transformer_engine/pytorch/csrc/ts_fp8_op.cpp | 36 +++++ .../pytorch/module/grouped_linear.py | 43 ++++-- 7 files changed, 249 insertions(+), 64 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 723f68369b..ad34b4996f 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -1266,12 +1266,15 @@ def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False ) inp_hidden_states.retain_grad() - m = config.seq_len // 16 - dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist() - dist.append(dist[-1]) # Manually add a zero - m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist) - m_splits = m_splits * 16 - assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms + if num_gemms > 1: + m = config.seq_len // 16 + dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist() + dist.append(dist[-1]) # Manually add a zero + m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist) + m_splits = m_splits * 16 + assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms + else: + m_splits = torch.tensor([config.seq_len]) with fp8_autocast(enabled=fp8): if isinstance(block, GroupedLinear): @@ -1353,7 +1356,7 @@ def test_grouped_linear_accuracy( @pytest.mark.parametrize("parallel_mode", ["column", "row"]) def test_grouped_linear_accuracy_parallel_mode(parallel_mode): - """Split the tests to reduce CI time""" + """Split the tests to save CI time""" test_grouped_linear_accuracy( dtype=torch.float32, num_gemms=6, @@ -1365,6 +1368,18 @@ def test_grouped_linear_accuracy_parallel_mode(parallel_mode): ) +def test_grouped_linear_accuracy_single_gemm(): + """Split the tests to save CI time""" + test_grouped_linear_accuracy( + dtype=torch.float32, + num_gemms=1, + bs=2, + model=list(model_configs.keys())[0], + fp8=True, + fp8_model_params=True, + ) + + def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False): def _pad_tensor_for_fp8(hidden_states, tokens_per_expert): @@ -2034,7 +2049,7 @@ def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate): fp8_grouped_gemm( A_fp8, - scale_inv, + [scale_inv], 0, # A_offset tex.DType.kFloat8E4M3, B_fp8, diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 8502f70491..fd1eb4a810 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -11,7 +11,12 @@ from ..utils import assert_dim_for_fp8_exec -__all__ = ["gemm", 
"fp8_gemm", "grouped_gemm", "fp8_grouped_gemm"] +__all__ = [ + "gemm", + "fp8_gemm", + "grouped_gemm", + "fp8_grouped_gemm", +] @functools.lru_cache(maxsize=None) @@ -313,7 +318,7 @@ def grouped_gemm( layout: str = "TN", bias: Optional[List[torch.Tensor]] = None, use_bias: bool = False, -) -> Tuple[Union[List[torch.Tensor], None], ...]: +) -> Tuple[List[torch.Tensor], ...]: """Non FP8 Grouped GEMM.""" assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported." @@ -380,7 +385,7 @@ def grouped_gemm( def fp8_grouped_gemm( A: List[torch.Tensor], - A_scale_inv: torch.Tensor, + A_scale_inv: List[torch.Tensor], A_fp8_tensor_offset: int, A_dtype: tex.DType, B: List[torch.Tensor], @@ -390,6 +395,7 @@ def fp8_grouped_gemm( out: List[torch.Tensor], out_dtype: torch.dtype, workspaces: List[torch.Tensor], + m_splits: Optional[List[int]] = None, out_offset: Optional[int] = None, fp8_meta_tensor: tex.FP8TensorMeta = None, gelu: bool = False, @@ -398,16 +404,25 @@ def fp8_grouped_gemm( use_bias: bool = False, use_split_accumulator: bool = False, D_dtype: Optional[tex.DType] = None, -) -> Tuple[Union[List[torch.Tensor], None], ...]: +) -> Tuple[List[torch.Tensor], ...]: """ TN layout Grouped GEMM with fp8 inputs. - This method assumes the scale/scale_inv/amax of A/B/out is contiguous in the meta tensor. - scale: [ ...A_scale... | ...B_scale... | ...out_scale...] - scale_inv: [ ...A_scale_inv... | ...B_scale_inv... | ...out_scale_inv...] - amax: [ ...A_amax... | ...B_amax... | ...out_amax...] + Input requirements: + 1. If len(A_scale_inv) == num_gemms, len(out) must be 1, and m_splits is not None. + This is used for the calculation of output (fwd) and dgrad (bwd). + 2. if len(A_scale_inv) == 1, len(out) must be num_gemms. This is used for the + calculation of wgrad. 
""" - num_gemms = len(A) + if num_gemms > 1 and len(A_scale_inv) == num_gemms: + assert len(out) == 1 and m_splits is not None + elif num_gemms > 1 and len(A_scale_inv) == 1: + assert len(out) == num_gemms + elif num_gemms == 1: + assert len(A_scale_inv) == 1 and len(out) == 1 + else: + raise ValueError("Invalid input combinations of A_scale_inv and out.") + empty_tensor = _empty_tensor() empty_tensors = [empty_tensor] * num_gemms if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]: @@ -420,41 +435,71 @@ def fp8_grouped_gemm( # Use bfloat16 as default bias_dtype bias_dtype = torch.bfloat16 if bias is None else bias[0].dtype - if gelu: - gelu_input = [ - torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format) - for o in out - ] - else: - gelu_input = empty_tensors bias_dtype = TE_DType[bias_dtype] - + gelu_input = empty_tensors out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype - torch.ops.tex_ts.te_grouped_gemm_ts( - A, - A_scale_inv, - A_fp8_tensor_offset, - A_dtype, - True, # transa - B, - B_scale_inv, - B_fp8_tensor_offset, - B_dtype, - False, # transb - out, - 0 if out_offset is None else out_offset, - empty_tensor if out_offset is None else fp8_meta_tensor.scale, - out_dtype, - empty_tensor if out_offset is None else fp8_meta_tensor.amax_history, - bias if use_bias else empty_tensors, - bias_dtype, - gelu_input, # this is pre_gelu_out - False, # grad - workspaces, - workspaces[0].shape[0], - accumulate, - use_split_accumulator, - ) + if len(A_scale_inv) == 1: + if gelu: + gelu_input = [ + torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format) + for o in out + ] + + torch.ops.tex_ts.te_grouped_gemm_ts( + A, + A_scale_inv[0], + A_fp8_tensor_offset, + A_dtype, + True, # transa + B, + B_scale_inv, + B_fp8_tensor_offset, + B_dtype, + False, # transb + out, + 0 if out_offset is None else out_offset, + empty_tensor if out_offset is None else fp8_meta_tensor.scale, + out_dtype, + empty_tensor if out_offset is None else fp8_meta_tensor.amax_history, + bias if use_bias else empty_tensors, + bias_dtype, + gelu_input, # this is pre_gelu_out + False, # grad + workspaces, + workspaces[0].shape[0], + accumulate, + use_split_accumulator, + ) + else: + if gelu: + gelu_input = [torch.empty((m, A[0].size(0)), dtype=bias_dtype) for m in m_splits] + + torch.ops.tex_ts.te_grouped_gemm_single_output_ts( + A, + A_scale_inv, + A_fp8_tensor_offset, + A_dtype, + True, # transa + B, + B_scale_inv, + B_fp8_tensor_offset, + B_dtype, + False, # transb + m_splits, + out[0], + 0 if out_offset is None else out_offset, + empty_tensor if out_offset is None else fp8_meta_tensor.scale, + out_dtype, + empty_tensor if out_offset is None else fp8_meta_tensor.amax_history, + bias if use_bias else empty_tensors, + bias_dtype, + gelu_input, # this is pre_gelu_out + False, # grad + workspaces, + workspaces[0].shape[0], + accumulate, + use_split_accumulator, + ) return out, gelu_input diff --git a/transformer_engine/pytorch/cpp_extensions/transpose.py b/transformer_engine/pytorch/cpp_extensions/transpose.py index 37a1b59da2..ddc3b67e9e 100644 --- a/transformer_engine/pytorch/cpp_extensions/transpose.py +++ b/transformer_engine/pytorch/cpp_extensions/transpose.py @@ -175,6 +175,7 @@ def fp8_multi_cast_transpose_fused( amax_indices: List[int], scale_inv_indices: List[int], otype: tex.DType, + scale_inv: Optional[torch.Tensor] = None, ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: """Cast + Transpose with FP8 output""" @@ -182,7 +183,7 
@@ def fp8_multi_cast_transpose_fused( input_list, fp8_meta_tensor.scale, fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, + scale_inv if scale_inv is not None else fp8_meta_tensor.scale_inv, scale_indices, amax_indices, scale_inv_indices, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 31103cbe8e..c797208e06 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -165,6 +165,16 @@ void te_grouped_gemm(std::vector A, at::Tensor A_scale_inverse, int std::vector workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count); +void te_grouped_gemm_single_output( + std::vector A, std::vector A_scale_inverse, int A_offset, + transformer_engine::DType A_type, bool transa, std::vector B, + at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type, bool transb, + std::vector m_splits, at::Tensor D, int D_offset, at::Tensor D_scale, + transformer_engine::DType D_type, at::Tensor D_amax, std::vector bias, + transformer_engine::DType bias_type, std::vector pre_gelu_out, bool grad, + std::vector workspace, size_t workspaceSize, bool accumulate, + bool use_split_accumulator, int math_sm_count); + /*************************************************************************************************** * Transpose **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cu b/transformer_engine/pytorch/csrc/extensions/gemm.cu index 7405914a0e..ba9851e7e8 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cu +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cu @@ -151,3 +151,64 @@ void te_grouped_gemm(std::vector A, at::Tensor A_scale_inverse, int te_workspace.data(), accumulate, use_split_accumulator, math_sm_count, at::cuda::getCurrentCUDAStream()); } + +void te_grouped_gemm_single_output( + std::vector A, std::vector A_scale_inverse, int A_offset, + transformer_engine::DType A_type, bool transa, std::vector B, + at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type, bool transb, + std::vector m_splits, at::Tensor D, int D_offset, at::Tensor D_scale, + transformer_engine::DType D_type, at::Tensor D_amax, std::vector bias, + transformer_engine::DType bias_type, std::vector pre_gelu_out, bool grad, + std::vector workspace, size_t workspaceSize, bool accumulate, + bool use_split_accumulator, int math_sm_count) { + using namespace transformer_engine; + std::vector te_A, te_B, te_D, te_bias, te_pre_gelu_out, te_workspace; + std::vector tensor_wrappers; + auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector& shape, + transformer_engine::DType dtype, void* amax_dptr, + void* scale_dptr, void* scale_inv_dptr) -> NVTETensor { + tensor_wrappers.emplace_back( + makeTransformerEngineTensor(dptr, shape, dtype, amax_dptr, scale_dptr, scale_inv_dptr)); + return tensor_wrappers.back().data(); + }; + NVTE_CHECK(D.is_contiguous(), "D must be contiguous."); + void* d_i_ptr = reinterpret_cast(D.data_ptr()); + for (size_t i = 0; i < A.size(); i++) { + if (m_splits[i] == 0) continue; + NVTE_CHECK(A[i].is_contiguous(), "A[", i, "] must be contiguous."); + NVTE_CHECK(B[i].is_contiguous(), "B[", i, "] must be contiguous."); + te_A.emplace_back(make_tensor( + A[i].data_ptr(), {static_cast(A[i].size(0)), static_cast(A[i].size(1))}, + A_type, nullptr, nullptr, getDataPtr(A_scale_inverse[i], A_offset))); + 
te_B.emplace_back(make_tensor( + B[i].data_ptr(), {static_cast(B[i].size(0)), static_cast(B[i].size(1))}, + B_type, nullptr, nullptr, getDataPtr(B_scale_inverse, B_offset + i))); + te_D.emplace_back(make_tensor( + d_i_ptr, {static_cast(m_splits[i]), static_cast(A[i].size(0))}, D_type, + getDataPtr(D_amax, D_offset + i), getDataPtr(D_scale, D_offset + i), nullptr)); + te_bias.emplace_back(make_tensor(bias[i].data_ptr(), {static_cast(bias[i].size(0))}, + bias_type, nullptr, nullptr, nullptr)); + + const auto gelu_shape = pre_gelu_out[i].data_ptr() == nullptr + ? std::vector{static_cast(pre_gelu_out[i].size(0))} + : std::vector{static_cast(pre_gelu_out[i].size(0)), + static_cast(pre_gelu_out[i].size(1))}; + te_pre_gelu_out.emplace_back(make_tensor( + pre_gelu_out[i].data_ptr(), gelu_shape, + GetTransformerEngineDType(pre_gelu_out[i].scalar_type()), nullptr, nullptr, nullptr)); + // Move the D pointer to the next split. + char* char_ptr = reinterpret_cast(d_i_ptr); + char_ptr += m_splits[i] * A[i].size(0) * D.element_size(); + d_i_ptr = reinterpret_cast(char_ptr); + } + for (size_t i = 0; i < workspace.size(); i++) { + te_workspace.emplace_back(make_tensor(workspace[i].data_ptr(), {workspaceSize}, DType::kByte, + nullptr, nullptr, nullptr)); + } + + // For now, we only have multi-stream cublas backend. + nvte_multi_stream_cublas_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(), + te_pre_gelu_out.data(), te_A.size(), transa, transb, grad, + te_workspace.data(), accumulate, use_split_accumulator, + math_sm_count, at::cuda::getCurrentCUDAStream()); +} diff --git a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp index 8c480e8343..9f31dba669 100644 --- a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp +++ b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp @@ -305,6 +305,41 @@ std::vector te_grouped_gemm_ts( return D; } +at::Tensor te_grouped_gemm_single_output_ts( + std::vector A, std::vector A_scale_inverse, int64_t A_offset, + int64_t A_type, int64_t transa, std::vector B, at::Tensor B_scale_inverse, + int64_t B_offset, int64_t B_type, int64_t transb, std::vector m_splits, at::Tensor D, + int64_t D_offset, at::Tensor D_scale, int64_t D_type, at::Tensor D_amax, + std::vector bias, int64_t bias_type, std::vector pre_gelu_out, + int64_t grad, std::vector workspace, int64_t workspaceSize, int64_t accumulate, + int64_t use_split_accumulator) { + // cast inputs to types accepted by te_gemm + transformer_engine::DType A_type_arg = reverse_map_dtype(A_type); + bool transa_arg = static_cast(transa); + transformer_engine::DType B_type_arg = reverse_map_dtype(B_type); + bool transb_arg = static_cast(transb); + transformer_engine::DType D_type_arg = reverse_map_dtype(D_type); + transformer_engine::DType bias_type_arg = reverse_map_dtype(bias_type); + bool grad_arg = static_cast(grad); + size_t workspaceSize_arg = static_cast(workspaceSize); + bool accumulate_arg = static_cast(accumulate); + bool use_split_accumulator_arg = static_cast(use_split_accumulator); + + // Set an external SM Margin to all the GEMMs. 
+ // This comes in handy when DP is overlapped with GEMMs + + const int device_id = at::cuda::current_device(); + const int sm_count = transformer_engine::cuda::sm_count(device_id); + int num_math_sms = sm_count - transformer_engine::getenv("NVTE_EXT_MARGIN_SM", sm_count); + + te_grouped_gemm_single_output(A, A_scale_inverse, A_offset, A_type_arg, transa_arg, B, + B_scale_inverse, B_offset, B_type_arg, transb_arg, m_splits, D, + D_offset, D_scale, D_type_arg, D_amax, bias, bias_type_arg, + pre_gelu_out, grad_arg, workspace, workspaceSize_arg, + accumulate_arg, use_split_accumulator_arg, num_math_sms); + return D; +} + at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, double eps, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, int64_t fp8_tensor, @@ -371,6 +406,7 @@ TORCH_LIBRARY(tex_ts, m) { m.def("srelu_ts", &srelu_ts); m.def("te_gemm_ts", &te_gemm_ts); m.def("te_grouped_gemm_ts", &te_grouped_gemm_ts); + m.def("te_grouped_gemm_single_output_ts", &te_grouped_gemm_single_output_ts); m.def("layernorm_fwd_fp8_inf_ts", &layernorm_fwd_fp8_inf_ts); m.def("layernorm_fwd_inf_ts", &layernorm_fwd_inf_ts); m.def("rmsnorm_fwd_fp8_inf_ts", &rmsnorm_fwd_fp8_inf_ts); diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index a91ff5c361..ca100392c7 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -42,6 +42,7 @@ from ..jit import no_torch_dynamo from ..graph import is_graph_capturing from ..float8_tensor import Float8Tensor +from ..export import is_in_onnx_export_mode __all__ = ["GroupedLinear"] @@ -102,10 +103,12 @@ def forward( inputmats_no_fp8 = [cast_if_needed(mat, activation_dtype) for mat in inputmats] inputmats = [] inputmats_t = [] + inputmat_scale_inv = None global _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT if fp8: fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + inputmat_scale_inv = torch.empty([num_gemms], dtype=torch.float32, device=inp.device) if ( not fp8_meta["recipe"].override_linear_precision.wgrad and is_grad_enabled @@ -121,6 +124,7 @@ def forward( indices, # amax_indices indices, # scale_inv_indices fp8_dtype_forward, + scale_inv=inputmat_scale_inv, ) else: # FP8 input for forward @@ -130,9 +134,22 @@ def forward( fp8_meta["scaling_fwd"], _GEMM_INPUT + i, fp8_dtype_forward, + scale_inv=inputmat_scale_inv, ) for i in range(num_gemms) ] + + # Hack for ONNX export + # Note: ONNX models are represented as a graph of tensor + # operations, so the in-place scale-inv update doesn't fit + # very well. We work around this by making it look like + # the scale-inv tensor is initialized with a copy. + # Note: ONNX export expects FP8 scales can be represented + # with constant ops. However, copying into a buffer + # involves an expand op for array broadcasting. We work + # around this by filling the buffer instead. 
+ if is_in_onnx_export_mode(): + inputmat_scale_inv.fill_(inputmat_scale_inv.item()) else: inputmats = inputmats_no_fp8 @@ -153,16 +170,17 @@ def forward( _ = fp8_grouped_gemm( [w._data for w in weights_fp8], - fp8_meta["scaling_fwd"].scale_inv, - _GEMM_WEIGHT, + [w._scale_inv for w in weights_fp8], + 0, # weight offset is 0 for the newly created _scale_inv fp8_dtype_forward, inputmats, - fp8_meta["scaling_fwd"].scale_inv, - _GEMM_INPUT, + inputmat_scale_inv, + 0, fp8_dtype_forward, - torch.split(out, m_splits), + [out], activation_dtype, get_multi_stream_cublas_workspace(), + m_splits=m_splits, bias=biases, use_bias=use_bias, use_split_accumulator=_2X_ACC_FPROP, @@ -230,7 +248,7 @@ def forward( t.activation_offloading = True ctx.save_for_backward( - fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, + inputmat_scale_inv, *saved_inputmats, *saved_inputmats_t, *weights, @@ -270,7 +288,7 @@ def forward( def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: with torch.cuda.nvtx.range("_GroupedLinear_backward"): ( - fwd_scale_inverses, + inputmat_scale_inv, *saved_tensors, ) = ctx.saved_tensors inputmats = saved_tensors[: ctx.num_gemms] @@ -342,18 +360,17 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ) fp8_grouped_gemm( [w.transpose_2d() for w in weights_fp8], - torch.cat( - [w._scale_inv for w in weights_fp8] - ), # avoiding torch.cat requires another interface + [w._scale_inv for w in weights_fp8], 0, # weight offset is 0 for the newly created _scale_inv weights_fp8[0]._fp8_dtype, grad_output_c, ctx.fp8_meta["scaling_bwd"].scale_inv, _GRAD_OUTPUT, fp8_dtype_backward, - torch.split(dgrad, ctx.m_splits), + [dgrad], ctx.activation_dtype, get_multi_stream_cublas_workspace(), + m_splits=ctx.m_splits, use_split_accumulator=_2X_ACC_DGRAD, ) else: @@ -396,8 +413,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inp._data if isinstance(inp, Float8Tensor) else inp for inp in inputmats_t ], - fwd_scale_inverses, - _GEMM_INPUT, + [inputmat_scale_inv], + 0, fp8_dtype_forward, grad_output_t, ctx.fp8_meta["scaling_bwd"].scale_inv, From 2a9845e1d93440d3c0f65427985e66208d09eff8 Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Mon, 9 Sep 2024 11:34:45 -0700 Subject: [PATCH 2/3] Added Adobe analytics to the documentation (#1162) Signed-off-by: Przemyslaw Tredak --- docs/_templates/layout.html | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/docs/_templates/layout.html b/docs/_templates/layout.html index a68b4531e3..f94e526f57 100644 --- a/docs/_templates/layout.html +++ b/docs/_templates/layout.html @@ -1,4 +1,11 @@ {% extends "!layout.html" %} + + {% block extrahead %} + + + + {% endblock %} + {% block sidebartitle %} {{ super() }} - {%- if nvidia_analytics_id %} - - {%- endif %} + {% endblock %} + + {% block footer %} + + {% endblock %} From 40dda924a52866c3a5e9b56f1907b4a2602f2fac Mon Sep 17 00:00:00 2001 From: Xiaowei Ren <103958965+xrennvidia@users.noreply.github.com> Date: Mon, 9 Sep 2024 13:50:47 -0700 Subject: [PATCH 3/3] Add a context parallelism implementation with QKVO all-to-all (#1160) * clean code for CP function args Signed-off-by: Xiaowei Ren * add a placeholder for Ulysses implementation Signed-off-by: Xiaowei Ren * commit code change to CP+A2A Signed-off-by: Xiaowei Ren * finish the draft fwd implementation of Ulysses Signed-off-by: Xiaowei Ren * add draft bwd implementation of Ulysses Signed-off-by: Xiaowei Ren * make swa work with 
ulysses Signed-off-by: Xiaowei Ren * commit FP8 code for Ulysses Signed-off-by: Xiaowei Ren * fix qkv type in the bwd of FP8+CP Signed-off-by: Xiaowei Ren * typo fix Signed-off-by: Xiaowei Ren * fix qkv_dtype of FP8+CP Signed-off-by: Xiaowei Ren * code refactoring Signed-off-by: Xiaowei Ren * minor code change Signed-off-by: Xiaowei Ren * config cp correction dtype of FP8+CP Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * code style change Signed-off-by: Xiaowei Ren * save chunk_ids Signed-off-by: Xiaowei Ren * try to make Ulysses A2A async Signed-off-by: Xiaowei Ren * make more a2a async Signed-off-by: Xiaowei Ren * fix a2a_outputs Signed-off-by: Xiaowei Ren * fix chunk_ids generation for A2A Signed-off-by: Xiaowei Ren * avoid code duplication of a2a before attn Signed-off-by: Xiaowei Ren * remove code duplication of a2a after attn Signed-off-by: Xiaowei Ren * add cp_stream in A2A implementation Signed-off-by: Xiaowei Ren * bug fix Signed-off-by: Xiaowei Ren * fix qkv of fp8_fwd + bf16_bwd Signed-off-by: Xiaowei Ren * fix kernel order in cp a2a communication Signed-off-by: Xiaowei Ren * code cleaning for CP a2a Signed-off-by: Xiaowei Ren * fix merging with main Signed-off-by: Xiaowei Ren * fix a2a communication order Signed-off-by: Xiaowei Ren * adjust sequence chunk reordering for a2a Signed-off-by: Xiaowei Ren * add docstring for A2A implementation Signed-off-by: Xiaowei Ren * change an assert info Signed-off-by: Xiaowei Ren * add unit tests of A2A implementation Signed-off-by: Xiaowei Ren * add more A2A unit test Signed-off-by: Xiaowei Ren * fix CP unit tests Signed-off-by: Xiaowei Ren * add more cp unit tests Signed-off-by: Xiaowei Ren * fix window size of no_mask Signed-off-by: Xiaowei Ren * fused attn does not support swa+no_mask Signed-off-by: Xiaowei Ren * change num_gqa_groups to 2 for A2A implementation Signed-off-by: Xiaowei Ren * function and variable renaming Signed-off-by: Xiaowei Ren * code cleaning for CP all-gather implementation Signed-off-by: Xiaowei Ren * some function renaming Signed-off-by: Xiaowei Ren * remove redundant code Signed-off-by: Xiaowei Ren * commit code change for kv all-gather implementation Signed-off-by: Xiaowei Ren * fix all-gather implementation Signed-off-by: Xiaowei Ren * add a window size check Signed-off-by: Xiaowei Ren * code cleaning Signed-off-by: Xiaowei Ren * add unit test of all_gather+no_mask Signed-off-by: Xiaowei Ren * fix all-gather cp implementation Signed-off-by: Xiaowei Ren * code cleaning Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * code format fix Signed-off-by: Xiaowei Ren * code format fix Signed-off-by: Xiaowei Ren * fix FP8 with A2A implementation Signed-off-by: Xiaowei Ren * add paper references to CP implementations with all-gather and all-to-all Signed-off-by: Xiaowei Ren * change pdf to abs Signed-off-by: Xiaowei Ren * elaborate cp_comm_type Signed-off-by: Xiaowei Ren * fix CP docstring Signed-off-by: Xiaowei Ren --------- Signed-off-by: Xiaowei Ren Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../fused_attn/test_fused_attn_with_cp.py | 110 +- transformer_engine/pytorch/attention.py | 1020 +++++++++++++---- transformer_engine/pytorch/transformer.py | 8 +- 3 files changed, 849 insertions(+), 289 deletions(-) diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py 
b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py index 82875e2791..d6358d1062 100644 --- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py @@ -22,10 +22,16 @@ "cp_1_2": ModelConfig( 2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0) ), # MHA - "cp_2_0": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA - "cp_2_1": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA + "cp_1_3": ModelConfig( + 2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias", window_size=(512, 512) + ), # MHA + "cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA + "cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA "cp_2_2": ModelConfig( - 2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0) + 2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0) + ), # GQA + "cp_2_3": ModelConfig( + 2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias", window_size=(512, 512) ), # GQA } @@ -45,31 +51,32 @@ def get_bash_arguments(**kwargs): @pytest.mark.parametrize("dtype", ["bf16", "fp16"]) @pytest.mark.parametrize("model", model_configs_flash_attn.keys()) @pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"]) -@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather"]) +@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a"]) def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): config = model_configs_flash_attn[model] + if cp_comm_type == "p2p" and config.window_size != (-1, 0) and config.window_size != (-1, -1): + pytest.skip("CP implementation with KV P2P does not support sliding window yet!") if cp_comm_type == "all_gather" and qkv_format == "thd": - pytest.skip( - f"CP implementation with KV all-gather does not support {qkv_format} format yet!" - ) - if cp_comm_type == "all_gather" and "causal" not in config.attn_mask_type: - pytest.skip( - f"CP implementation with KV all-gather does not support {config.attn_mask_type} mask" - " type yet!" - ) + pytest.skip("CP implementation with KV all-gather does not support THD format yet!") if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias": + pytest.skip("CP implementation with KV all-gather does not support bias yet!") + if cp_comm_type == "a2a" and qkv_format == "thd": + pytest.skip("CP implementation with QKVO A2A does not support THD format yet!") + if cp_comm_type == "a2a" and config.attn_bias_type != "no_bias": + pytest.skip("CP implementation with QKVO A2A does not support bias yet!") + if cp_comm_type == "a2a" and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0): pytest.skip( - f"CP implementation with KV all-gather does not support {config.attn_bias_type} bias" - " type yet!" - ) - if cp_comm_type == "p2p" and config.window_size != (-1, 0) and config.window_size != (-1, -1): - pytest.skip( - f"CP implementation with KV P2P does not support window size {config.window_size} yet!" + f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and" + f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!" 
) subprocess.run( get_bash_arguments( - dtype=dtype, model=model, qkv_format=qkv_format, kernel_backend="FlashAttention" + dtype=dtype, + model=model, + qkv_format=qkv_format, + kernel_backend="FlashAttention", + cp_comm_type=cp_comm_type, ), check=True, ) @@ -81,10 +88,16 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): "cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA "cp_1_2": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # MHA "cp_1_3": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # MHA - "cp_2_0": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA - "cp_2_1": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA - "cp_2_2": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # GQA - "cp_2_3": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # GQA + "cp_1_4": ModelConfig( + 2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0) + ), # MHA + "cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA + "cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA + "cp_2_2": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # GQA + "cp_2_3": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # GQA + "cp_2_4": ModelConfig( + 2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0) + ), # GQA } @@ -93,37 +106,27 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): @pytest.mark.parametrize("dtype", ["bf16", "fp16", "fp8"]) @pytest.mark.parametrize("model", model_configs_fused_attn.keys()) @pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"]) -@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather"]) +@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a"]) def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type): if qkv_format == "thd" and get_device_compute_capability() < (9, 0): - pytest.skip("THD format is only supported on sm90+.") + pytest.skip("THD format is only supported on sm90+!") if cp_comm_type == "all_gather" and get_cudnn_version() < (9, 3, 0): - pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0") + pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0!") config = model_configs_fused_attn[model] if qkv_format == "thd" and config.num_heads != config.num_gqa_groups: - pytest.skip(f"{qkv_format} format does not support QGA/MQA yet!") + pytest.skip("THD format does not support QGA/MQA yet!") if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias": - pytest.skip(f"{qkv_format} format does not support {config.attn_bias_type} bias type yet!") - if cp_comm_type == "all_gather" and qkv_format == "thd": - pytest.skip( - f"CP implementation with KV all-gather does not support {qkv_format} format yet!" - ) - if cp_comm_type == "all_gather" and "causal" not in config.attn_mask_type: - pytest.skip( - f"CP implementation with KV all-gather does not support {config.attn_mask_type} mask" - " type yet!" - ) - if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias": - pytest.skip( - f"CP implementation with KV all-gather does not support {config.attn_bias_type} bias" - " type yet!" 
- ) - if config.window_size != (-1, 0) and config.window_size != (-1, -1): + pytest.skip("THD format does not support post_scale_bias yet!") + if qkv_format == "thd" and cp_comm_type == "all_gather": + pytest.skip("CP implementation with KV all-gather does not support THD format yet!") + if qkv_format == "thd" and cp_comm_type == "a2a": + pytest.skip("CP implementation with QKVO A2A does not support THD format yet!") + if config.window_size != (-1, 0) and config.window_size != (-1, -1) and cp_comm_type != "a2a": pytest.skip( - "Fused attention does not support sliding window attention + context parallelism yet!" + "Sliding window attention only can be supported with the implementation of QKVO A2A!" ) - if cp_comm_type == "all_gather" and dtype == "fp8": + if dtype == "fp8" and cp_comm_type == "all_gather": pytest.skip( "CP implementation with KV all-gather does not support FP8 + context parallelism yet!" ) @@ -131,10 +134,25 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type): pytest.skip("FP8 attention cannot work with THD format yet!") if dtype == "fp8" and config.attn_bias_type != "no_bias": pytest.skip("FP8 attention cannot work with bias yet!") + if dtype == "fp8" and config.window_size != (-1, 0) and config.window_size != (-1, -1): + pytest.skip("FP8 attention cannot work with sliding window yet!") + if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias": + pytest.skip("CP implementation with KV all-gather does not support bias yet!") + if cp_comm_type == "a2a" and config.attn_bias_type != "no_bias": + pytest.skip("CP implementation with QKVO A2A does not support bias yet!") + if cp_comm_type == "a2a" and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0): + pytest.skip( + f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and" + f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!" 
+ ) subprocess.run( get_bash_arguments( - dtype=dtype, model=model, qkv_format=qkv_format, kernel_backend="FusedAttention" + dtype=dtype, + model=model, + qkv_format=qkv_format, + kernel_backend="FusedAttention", + cp_comm_type=cp_comm_type, ), check=True, ) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 91c14899ec..f8ba46b2ea 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -614,12 +614,6 @@ def get_attention_backend( "with causal mask, no dropout, and qkv_format = bshd/sbhd" ) use_fused_attention = False - elif context_parallel: - logger.debug( - "Disabling FusedAttention as it does not support sliding window attention " - "with context parallelism" - ) - use_fused_attention = False elif max_seqlen_q != max_seqlen_kv and attn_mask_type in [ "no_mask", "padding", @@ -1429,9 +1423,6 @@ def forward( cu_seqlens_q_padded, cu_seqlens_kv_padded, dropout_p, - cp_group, - cp_global_ranks, - cp_stream, softmax_scale, qkv_format, attn_mask_type, @@ -1441,6 +1432,9 @@ def forward( use_fused_attention, fp8, fp8_meta, + cp_group, + cp_global_ranks, + cp_stream, ): if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) @@ -2946,10 +2940,10 @@ def backward(ctx, dout): None, None, None, + attn_dbias, None, None, None, - attn_dbias, None, None, None, @@ -2958,30 +2952,56 @@ def backward(ctx, dout): @torch.compile -def get_seq_chunk_ids_to_all_gathered_kv( - local_chunk_id, cp_size, max_seqlen_q, max_seqlen_kv, window_size_left, device +def get_seq_chunk_ids_for_reordering(cp_size, device, to_contiguous): + """ + Context parallelism assigns two discontiguous sequence chunks to each GPU for load balancing. + To make sure tokens are ordered correctly for compute, we need to reorder sequence chunks + before or after CP communications (e.g., all-gather, all-to-all). This function is to compute + sequence chunk ids for reordering. 
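+    For example, with cp_size = 2, rank r holds chunks (r, 2*cp_size - 1 - r), so all-gathered
+    data arrives in rank-major chunk order [0, 3, 1, 2]. to_contiguous=True returns the index
+    [0, 2, 3, 1], which restores the contiguous order [0, 1, 2, 3]; to_contiguous=False returns
+    [0, 3, 1, 2], which maps a contiguously ordered tensor back to the rank-major layout.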
+ """ + chunk_ids = torch.empty(2 * cp_size, dtype=torch.int32, device=device) + if to_contiguous: + for rank in range(cp_size): + chunk_ids[rank] = 2 * rank + chunk_ids[rank + cp_size] = 2 * cp_size - 2 * rank - 1 + else: + for rank in range(cp_size): + chunk_ids[2 * rank] = rank + chunk_ids[2 * rank + 1] = 2 * cp_size - rank - 1 + return chunk_ids + + +def get_kv_seq_info_after_all_gather( + local_chunk_id, cp_size, max_seqlen_q, max_seqlen_kv, window_size, causal ): - """Compute sequence chunk ids to the all-gathered KV.""" - seq_end_idx = (local_chunk_id + 1) * max_seqlen_kv - seq_start_idx = max(0, seq_end_idx - max_seqlen_q - window_size_left) - seqlen = seq_end_idx - seq_start_idx - num_chunks = (seqlen + max_seqlen_kv - 1) // max_seqlen_kv - chunk_ids = torch.arange( - local_chunk_id - num_chunks + 1, - local_chunk_id + 1, - dtype=torch.int32, - device=device, - ) - chunk_ids_to_all_gathered_kv = torch.where( - chunk_ids < cp_size, 2 * chunk_ids, 2 * (2 * cp_size - chunk_ids) - 1 - ) - return chunk_ids_to_all_gathered_kv + """Compute KV sequence index range and update window size after all-gather.""" + local_chunk_end_idx = (local_chunk_id + 1) * max_seqlen_kv + full_seq_end_idx = max_seqlen_kv * cp_size * 2 + + if window_size is None: + window_size = (-1, 0) if causal else (-1, -1) + + if window_size[1] == -1: + seq_end_idx = full_seq_end_idx + window_size_right = -1 + else: + seq_end_idx = min(full_seq_end_idx, local_chunk_end_idx + window_size[1]) + window_size_right = local_chunk_end_idx + window_size[1] - seq_end_idx + + if window_size[0] == -1: + seq_start_idx = 0 + window_size_left = -1 + else: + seq_start_idx = max(0, local_chunk_end_idx - max_seqlen_q - window_size[0]) + window_size_left = window_size[0] + seq_end_idx - local_chunk_end_idx + + return (seq_start_idx, seq_end_idx), (window_size_left, window_size_right) class AttnFuncWithCPAndKVAllGather(torch.autograd.Function): """ - Attention implementation with context parallelism. - KV all-gather between CP ranks is exposed. + Attention implementation with context parallelism. KV all-gather between CP ranks is exposed. + Refer section 3.3.2 of `The Llama 3 Herd of Models `_. """ @staticmethod @@ -2992,14 +3012,10 @@ def forward( k, v, cu_seqlens_q, - cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, cu_seqlens_q_padded, - cu_seqlens_kv_padded, dropout_p, - cp_group, - cp_stream, softmax_scale, qkv_format, attn_mask_type, @@ -3008,6 +3024,8 @@ def forward( deterministic, use_fused_attention, window_size, + cp_group, + cp_stream, ): if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) @@ -3017,10 +3035,9 @@ def forward( causal = "causal" in attn_mask_type padding = "padding" in attn_mask_type - assert causal and not padding, f"{attn_mask_type} mask type is not supported!" + assert not padding, f"{attn_mask_type} mask type is not supported!" if use_fused_attention and causal and "bottom_right" not in attn_mask_type: attn_mask_type = attn_mask_type + "_bottom_right" - assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!" assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!" assert ( @@ -3029,6 +3046,8 @@ def forward( fa_optional_forward_kwargs = {} if _flash_attn_2_4_plus: fa_optional_forward_kwargs["alibi_slopes"] = None + if _flash_attn_2_5_7_plus: + fa_optional_forward_kwargs["block_table"] = None assert qkv_format != "thd", f"{qkv_format} format is not supported!" 
qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format @@ -3041,31 +3060,35 @@ def forward( max_seqlen_q = max_seqlen_q // (2 * cp_size) max_seqlen_kv = max_seqlen_kv // (2 * cp_size) cu_seqlens_q = cu_seqlens_q // (2 * cp_size) - cu_seqlens_kv = cu_seqlens_kv // (2 * cp_size) cu_seqlens_q_padded = cu_seqlens_q_padded // (2 * cp_size) - cu_seqlens_kv_padded = cu_seqlens_kv_padded // (2 * cp_size) - - if causal: - if qkv_format == "bshd": - # [b, s, np, hn] -> [b, 2, s//2, np, hn] - q = q.view(q.shape[0], 2, q.shape[1] // 2, *q.shape[2:]) - # [b, s, np, hn] -> [s, b, np, hn] - k, v = [x.transpose(0, 1).contiguous() for x in [k, v]] - elif qkv_format == "sbhd": - # [s, b, np, hn] -> [2, s//2, b, np, hn] - q = q.view(2, q.shape[0] // 2, *q.shape[1:]) - # create two streams to resolve wave quantization issue of Flash Attn in each step - flash_attn_streams = [torch.cuda.current_stream(), cp_stream] + # [b, s, np, hn] -> [b, 2, s//2, np, hn] or [s, b, np, hn] -> [2, s//2, b, np, hn] + q = q.view(*q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1) :]) + # [b, s, np, hn] or [s, b, np, hn] -> [s, b, np, hn] + k, v = [x.movedim(seq_dim, 0).contiguous() for x in [k, v]] + # [s, b, np, hn] -> [cp, s, b, np, hn] k_ag, _ = gather_along_first_dim(k, cp_group) v_ag, _ = gather_along_first_dim(v, cp_group) - cp_stream.wait_stream(torch.cuda.current_stream()) + + # [cp, s, b, np, hn] -> [cp*2, s//2, b, np, hn] k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) + chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering(cp_size, k.device, True) + k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) + v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) + # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn] + k_ag = k_ag.view(-1, *k.shape[1:]) + v_ag = v_ag.view(-1, *v.shape[1:]) + cp_stream.wait_stream(torch.cuda.current_stream()) + + # create two streams to resolve wave quantization issue of Flash Attn in each step + flash_attn_streams = [torch.cuda.current_stream(), cp_stream] local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1] - chunk_ids_to_kv_ag_per_step = [None, None] + kv_seq_range_per_step = [None, None] + window_size_per_step = [None, None] + cu_seqlens_kv_per_step = [None, None] out_per_step = [None, None] softmax_lse_per_step = [None, None] rng_states = [None, None] @@ -3074,53 +3097,36 @@ def forward( for i in range(len(local_seq_chunk_ids) + 1): if i < len(local_seq_chunk_ids): with torch.cuda.stream(flash_attn_streams[i]): - chunk_ids_to_kv_ag = get_seq_chunk_ids_to_all_gathered_kv( - local_seq_chunk_ids[i], - cp_size, - max_seqlen_q, - max_seqlen_kv, - ( - max_seqlen_kv * cp_size * 2 - if (window_size is None or window_size[0] == -1) - else window_size[0] - ), - k.device, - ) - chunk_ids_to_kv_ag_per_step[i] = chunk_ids_to_kv_ag - num_kv_chunks = chunk_ids_to_kv_ag.numel() - if qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] - q_ = q[:, i].contiguous() - # [num_kv_chunks, sq//2, b, np, hn] -> [b, num_kv_chunks*sq//2, np, hn] - k_ = ( - torch.index_select(k_ag, dim=0, index=chunk_ids_to_kv_ag) - .movedim(2, 0) - .contiguous() - .view(k.shape[1], -1, *k.shape[-2:]) - ) - v_ = ( - torch.index_select(v_ag, dim=0, index=chunk_ids_to_kv_ag) - .movedim(2, 0) - .contiguous() - .view(v.shape[1], -1, *v.shape[-2:]) - ) - elif qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] - q_ = q[i].contiguous() - # [num_kv_chunks, sq//2, b, np, hn] -> 
[num_kv_chunks*sq//2, b, np, hn] - k_ = torch.index_select(k_ag, dim=0, index=chunk_ids_to_kv_ag).view( - -1, *k.shape[-3:] - ) - v_ = torch.index_select(v_ag, dim=0, index=chunk_ids_to_kv_ag).view( - -1, *v.shape[-3:] + # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] or [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] + q_ = q.select(seq_dim, i).contiguous() + kv_seq_range_per_step[i], window_size_per_step[i] = ( + get_kv_seq_info_after_all_gather( + local_seq_chunk_ids[i], + cp_size, + max_seqlen_q, + max_seqlen_kv, + window_size, + causal, ) + ) + seq_start_idx, seq_end_idx = ( + kv_seq_range_per_step[i][0], + kv_seq_range_per_step[i][1], + ) + max_seqlen_kv_ = seq_end_idx - seq_start_idx + cu_seqlens_kv_per_step[i] = _get_full_cu_seqlens( + k.shape[1], max_seqlen_kv_, k.device + ) + k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] + # [s_range, b, np, hn] -> [b, s_range, np, hn] or [s_range, b, np, hn] + k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]] if use_fused_attention: out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = fused_attn_fwd( is_training, max_seqlen_q, - max_seqlen_kv * num_kv_chunks, + max_seqlen_kv_, cu_seqlens_q, - cu_seqlens_kv * num_kv_chunks, + cu_seqlens_kv_per_step[i], q_, k_, v_, @@ -3133,8 +3139,8 @@ def forward( attn_bias_type=attn_bias_type, attn_bias=attn_bias, cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded * num_kv_chunks, - window_size=window_size, + cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i], + window_size=window_size_per_step[i], ) else: q_, k_, v_ = [x.view(-1, *x.shape[-2:]) for x in [q_, k_, v_]] @@ -3144,14 +3150,14 @@ def forward( k_, v_, cu_seqlens_q, - cu_seqlens_kv * num_kv_chunks, + cu_seqlens_kv_per_step[i], max_seqlen_q, - max_seqlen_kv * num_kv_chunks, + max_seqlen_kv_, dropout_p, softmax_scale, - causal=True, + causal=causal, return_softmax=False, - window_size=window_size, + window_size=window_size_per_step[i], **fa_optional_forward_kwargs, ) ) @@ -3159,9 +3165,9 @@ def forward( if i > 0: with torch.cuda.stream(flash_attn_streams[i - 1]): if qkv_format == "bshd": - out[:, i - 1].copy_(out_per_step[i - 1].view_as(out[:, i - 1])) + out[:, i - 1].copy_(out_per_step[i - 1].view(out[:, i - 1].shape)) elif qkv_format == "sbhd": - out[i - 1].copy_(out_per_step[i - 1].view_as(out[i - 1])) + out[i - 1].copy_(out_per_step[i - 1].view(out[i - 1].shape)) torch.cuda.current_stream().wait_stream(cp_stream) @@ -3178,26 +3184,24 @@ def forward( k, v, cu_seqlens_q, - cu_seqlens_kv, cu_seqlens_q_padded, - cu_seqlens_kv_padded, - *chunk_ids_to_kv_ag_per_step, + *cu_seqlens_kv_per_step, *out_per_step, *softmax_lse_per_step, *rng_states, ) + ctx.kv_seq_range_per_step = kv_seq_range_per_step + ctx.window_size_per_step = window_size_per_step ctx.cp_group = cp_group ctx.cp_stream = cp_stream ctx.dropout_p = dropout_p ctx.max_seqlen_q = max_seqlen_q - ctx.max_seqlen_kv = max_seqlen_kv ctx.softmax_scale = softmax_scale ctx.qkv_format = qkv_format - ctx.attn_mask_type = attn_mask_type ctx.attn_bias_type = attn_bias_type + ctx.attn_mask_type = attn_mask_type ctx.deterministic = deterministic ctx.use_fused_attention = use_fused_attention - ctx.window_size = window_size return out @staticmethod @@ -3205,21 +3209,20 @@ def backward(ctx, dout): cp_size = get_distributed_world_size(ctx.cp_group) rank = get_distributed_rank(ctx.cp_group) - (q, k, v, cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded) = ( - ctx.saved_tensors[:7] - ) - chunk_ids_to_kv_ag_per_step = 
ctx.saved_tensors[7:9] - out_per_step = ctx.saved_tensors[9:11] - softmax_lse_per_step = ctx.saved_tensors[11:13] - rng_states = ctx.saved_tensors[13:15] + (q, k, v, cu_seqlens_q, cu_seqlens_q_padded) = ctx.saved_tensors[:5] + cu_seqlens_kv_per_step = ctx.saved_tensors[5:7] + out_per_step = ctx.saved_tensors[7:9] + softmax_lse_per_step = ctx.saved_tensors[9:11] + rng_states = ctx.saved_tensors[11:13] + kv_seq_range_per_step = ctx.kv_seq_range_per_step + window_size_per_step = ctx.window_size_per_step + seq_dim = ctx.qkv_format.index("s") qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format - dout = dout.view_as(q) + dout = dout.view(q.shape) dq = torch.empty_like(q) - dk = torch.zeros( - (2 * cp_size, k.shape[0] // 2, *k.shape[1:]), dtype=k.dtype, device=k.device - ) + dk = torch.zeros((k.shape[0] * cp_size, *k.shape[1:]), dtype=k.dtype, device=k.device) dv = torch.zeros_like(dk) dq_per_step = [None, None] dk_per_step = [None, None] @@ -3230,11 +3233,20 @@ def backward(ctx, dout): # synchronize dkv update across steps dkv_update_done = torch.cuda.Event() + # [s, b, np, hn] -> [cp, s, b, np, hn] k_ag, _ = gather_along_first_dim(k, ctx.cp_group) v_ag, _ = gather_along_first_dim(v, ctx.cp_group) - ctx.cp_stream.wait_stream(torch.cuda.current_stream()) + + # [cp, s, b, np, hn] -> [cp*2, s//2, b, np, hn] k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) + chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering(cp_size, k.device, True) + k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) + v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) + # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn] + k_ag = k_ag.view(-1, *k.shape[1:]) + v_ag = v_ag.view(-1, *v.shape[1:]) + ctx.cp_stream.wait_stream(torch.cuda.current_stream()) local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1] @@ -3247,66 +3259,46 @@ def backward(ctx, dout): for i in range(len(local_seq_chunk_ids) + 1): if i < len(local_seq_chunk_ids): with torch.cuda.stream(flash_attn_streams[i]): - chunk_ids_to_kv_ag = chunk_ids_to_kv_ag_per_step[i] - num_kv_chunks = chunk_ids_to_kv_ag.numel() + # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] or [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] + q_ = q.select(seq_dim, i).contiguous() + seq_start_idx, seq_end_idx = ( + kv_seq_range_per_step[i][0], + kv_seq_range_per_step[i][1], + ) + max_seqlen_kv = seq_end_idx - seq_start_idx + k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] + # [cp*s, b, np, hn] -> [b, s_range, np, hn] or [s_range, b, np, hn] + k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]] out_ = out_per_step[i] - if ctx.qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] - q_ = q[:, i].contiguous() - # [num_kv_chunks, sq//2, b, np, hn] -> [b, num_kv_chunks*sq//2, np, hn] - k_ = ( - torch.index_select(k_ag, dim=0, index=chunk_ids_to_kv_ag) - .movedim(2, 0) - .contiguous() - .view(k.shape[1], -1, *k.shape[-2:]) - ) - v_ = ( - torch.index_select(v_ag, dim=0, index=chunk_ids_to_kv_ag) - .movedim(2, 0) - .contiguous() - .view(v.shape[1], -1, *v.shape[-2:]) - ) - dout_ = dout[:, i].contiguous().view_as(out_) - elif ctx.qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] - q_ = q[i].contiguous() - # [num_kv_chunks, sq//2, b, np, hn] -> [num_kv_chunks*sq//2, b, np, hn] - k_ = torch.index_select(k_ag, dim=0, index=chunk_ids_to_kv_ag).view( - -1, *k.shape[-3:] - ) - v_ = torch.index_select(v_ag, dim=0, 
index=chunk_ids_to_kv_ag).view( - -1, *v.shape[-3:] - ) - dout_ = dout[i].contiguous().view_as(out_) + dout_ = dout.select(seq_dim, i).contiguous().view(out_.shape) if ctx.use_fused_attention: - dq_per_step[i], dk_per_step[i], dv_per_step[i] = [ - torch.empty_like(x) for x in [q_, k_, v_] - ] aux_ctx_tensors = [softmax_lse_per_step[i], rng_states[i]] dq_per_step[i], dk_per_step[i], dv_per_step[i], _ = fused_attn_bwd( ctx.max_seqlen_q, - ctx.max_seqlen_kv * num_kv_chunks, + max_seqlen_kv, cu_seqlens_q, - cu_seqlens_kv * num_kv_chunks, + cu_seqlens_kv_per_step[i], q_, k_, v_, out_, dout_, TE_DType[q.dtype], - TE_DType[k.dtype], + TE_DType[dout.dtype], aux_ctx_tensors, tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded * num_kv_chunks, + cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i], attn_scale=ctx.softmax_scale, dropout=ctx.dropout_p, qkv_layout=qkv_layout, attn_mask_type=ctx.attn_mask_type, attn_bias_type=ctx.attn_bias_type, - window_size=ctx.window_size, + window_size=window_size_per_step[i], + deterministic=ctx.deterministic, ) else: + batch_size = k_.shape[0] q_, k_, v_ = [x.view(-1, *x.shape[-2:]) for x in [q_, k_, v_]] dq_per_step[i], dk_per_step[i], dv_per_step[i] = [ torch.empty_like(x) for x in [q_, k_, v_] @@ -3322,65 +3314,601 @@ def backward(ctx, dout): dk_per_step[i], dv_per_step[i], cu_seqlens_q, - cu_seqlens_kv * num_kv_chunks, + cu_seqlens_kv_per_step[i], ctx.max_seqlen_q, - ctx.max_seqlen_kv * num_kv_chunks, + max_seqlen_kv, ctx.dropout_p, ctx.softmax_scale, - True, - window_size=ctx.window_size, + "causal" in ctx.attn_mask_type, + window_size=window_size_per_step[i], rng_state=rng_states[i], **fa_optional_backward_kwargs, ) + # [b*sq//2, np, hn] -> [b, sq//2, np, hn] + dq_per_step[i] = dq_per_step[i].view(dq[:, i].shape) + # [b*s_range, np, hn] -> [b, s_range, np, hn] + dk_per_step[i], dv_per_step[i] = [ + x.view(batch_size, -1, *x.shape[-2:]) + for x in [dk_per_step[i], dv_per_step[i]] + ] if i > 0: with torch.cuda.stream(flash_attn_streams[i - 1]): - chunk_ids_to_kv_ag = chunk_ids_to_kv_ag_per_step[i - 1] - num_kv_chunks = chunk_ids_to_kv_ag.numel() if ctx.qkv_format == "bshd": - dq[:, i - 1].copy_(dq_per_step[i - 1].view_as(dq[:, i - 1])) - dk_per_step[i - 1] = ( - dk_per_step[i - 1] - .view(k.shape[1], num_kv_chunks, -1, *k.shape[-2:]) - .movedim(0, 2) - .contiguous() - ) - dv_per_step[i - 1] = ( - dv_per_step[i - 1] - .view(v.shape[1], num_kv_chunks, -1, *v.shape[-2:]) - .movedim(0, 2) - .contiguous() - ) + dq[:, i - 1].copy_(dq_per_step[i - 1]) elif ctx.qkv_format == "sbhd": - dq[i - 1].copy_(dq_per_step[i - 1].view_as(dq[i - 1])) - dk_per_step[i - 1] = dk_per_step[i - 1].view( - num_kv_chunks, -1, *k.shape[-3:] - ) - dv_per_step[i - 1] = dv_per_step[i - 1].view( - num_kv_chunks, -1, *v.shape[-3:] - ) - + dq[i - 1].copy_(dq_per_step[i - 1]) + # [b, s_range, np, hn] or [s_range, b, np, hn] -> [s_range, b, np, hn] + dk_per_step[i - 1], dv_per_step[i - 1] = [ + x.movedim(seq_dim, 0).contiguous() + for x in [dk_per_step[i - 1], dv_per_step[i - 1]] + ] # wait until dkv update of last step is done if i > 1: flash_attn_streams[i - 1].wait_event(dkv_update_done) - dk.index_add_(0, chunk_ids_to_kv_ag, dk_per_step[i - 1]) - dv.index_add_(0, chunk_ids_to_kv_ag, dv_per_step[i - 1]) + seq_start_idx, seq_end_idx = ( + kv_seq_range_per_step[i - 1][0], + kv_seq_range_per_step[i - 1][1], + ) + dk[seq_start_idx:seq_end_idx].add_(dk_per_step[i - 1]) + 
dv[seq_start_idx:seq_end_idx].add_(dv_per_step[i - 1]) if i < len(local_seq_chunk_ids): flash_attn_streams[i - 1].record_event(dkv_update_done) torch.cuda.current_stream().wait_stream(ctx.cp_stream) + # [cp*s, b, np, hn] -> [cp*2, s//2, b, np, hn] + dk = dk.view(2 * cp_size, -1, *dk.shape[-3:]) + dv = dv.view(2 * cp_size, -1, *dv.shape[-3:]) + chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering(cp_size, dk.device, False) + dk = torch.index_select(dk, dim=0, index=chunk_ids_for_kv_ag) + dv = torch.index_select(dv, dim=0, index=chunk_ids_for_kv_ag) + # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn] dk = dk.view(-1, *dk.shape[-3:]) dv = dv.view(-1, *dv.shape[-3:]) dk, _ = reduce_scatter_along_first_dim(dk, ctx.cp_group) dv, _ = reduce_scatter_along_first_dim(dv, ctx.cp_group) + dq = dq.view(*dq.shape[:seq_dim], -1, *dq.shape[(seq_dim + 2) :]) + dk = dk.movedim(0, seq_dim).contiguous() + dv = dv.movedim(0, seq_dim).contiguous() + + return ( + None, + dq, + dk, + dv, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + +@torch.compile +def reorder_seq_chunks_for_a2a(x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn): + """Reorder sequence chunk for A2A communication.""" + if before_attn: + # [cp, b, s, np//cp, hn] -> [b, cp, s, np//cp, hn] or [cp, s, b, np//cp, hn] -> [cp, s, b, np//cp, hn] + x = x.movedim(0, seq_dim).contiguous() + # [b, cp, s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn] or [cp, s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] + x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 2) :]) + # reorder the sequence chunks + x = torch.index_select(x, dim=seq_dim, index=chunk_ids_for_a2a) + else: + # [b, cp*2, s//2, np//cp, hn] -> [cp*2, b, s//2, np//cp, hn] or [cp*2, s//2, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] + x = x.movedim(seq_dim, 0).contiguous() + # reorder the sequence chunks + x = torch.index_select(x, dim=0, index=chunk_ids_for_a2a) + # [cp*2, b, s//2, np//cp, hn] -> [cp, 2, b, s//2, np//cp, hn] or [cp*2, s//2, b, np//cp, hn] -> [cp, 2, s//2, b, np//cp, hn] + x = x.view(cp_size, 2, *x.shape[1:]) + return x + + +def flash_attn_a2a_communicate( + a2a_inputs: Union[torch.Tensor, List[torch.Tensor]], + chunk_ids_for_a2a: torch.Tensor, + seq_dim: int, + cp_size: int, + cp_group: dist_group_type, + cp_stream: torch.cuda.Stream, + before_attn: bool, +) -> Union[torch.Tensor, List[torch.Tensor]]: + """A2A communication for context parallelism.""" + a2a_inputs = [a2a_inputs] if not isinstance(a2a_inputs, list) else a2a_inputs + a2a_outputs, a2a_reqs = [None] * len(a2a_inputs), [None] * len(a2a_inputs) + if before_attn: + for i in range(len(a2a_inputs) + 2): + if 0 < i < len(a2a_inputs) + 1: + a2a_outputs[i - 1] = torch.empty_like(a2a_inputs[i - 1]) + a2a_reqs[i - 1] = torch.distributed.all_to_all_single( + a2a_outputs[i - 1], a2a_inputs[i - 1], group=cp_group, async_op=True + ) + if i > 1: + with torch.cuda.stream(cp_stream): + a2a_reqs[i - 2].wait() + x = a2a_outputs[i - 2] + # reorder the sequence chunks + x = reorder_seq_chunks_for_a2a( + x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn + ) + # [b, cp*2, s//2, np//cp, hn] -> [b, cp*s, np//cp, hn] or [cp*2, s//2, b, np//cp, hn] -> [cp*s, b, np//cp, hn] + a2a_outputs[i - 2] = x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) + if i < len(a2a_inputs): + x = a2a_inputs[i] + # [b, s, np, hn] -> [b, s, cp, np//cp, hn] or [s, b, np, hn] -> [s, b, cp, np//cp, hn] + x = x.view(*x.shape[:-2], cp_size, x.shape[-2] // cp_size, 
x.shape[-1]) + # [b, s, cp, np//cp, hn] -> [cp, b, s, np//cp, hn] or [s, b, cp, np//cp, hn] -> [cp, s, b, np//cp, hn] + a2a_inputs[i] = x.movedim(-3, 0).contiguous() + else: + for i in range(len(a2a_inputs) + 2): + if 0 < i < len(a2a_inputs) + 1: + a2a_outputs[i - 1] = torch.empty_like(a2a_inputs[i - 1]) + a2a_reqs[i - 1] = torch.distributed.all_to_all_single( + a2a_outputs[i - 1], a2a_inputs[i - 1], group=cp_group, async_op=True + ) + if i < len(a2a_inputs): + x = a2a_inputs[i] + # [b, cp*s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn] or [cp*s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] + x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 1) :]) + # reorder the sequence chunks + a2a_inputs[i] = reorder_seq_chunks_for_a2a( + x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn + ) + if i > 1: + with torch.cuda.stream(cp_stream): + a2a_reqs[i - 2].wait() + x = a2a_outputs[i - 2] + # [cp, 2, b, s//2, np//cp, hn] -> [b, 2, s//2, cp, np//cp, hn] or [cp, 2, s//2, b, np//cp, hn] -> [2, s//2, b, cp, np//cp, hn] + x = x.movedim(0, -3).movedim(0, seq_dim).contiguous() + # [b, 2, s//2, cp, np//cp, hn] -> [b*s, np, hn] or [2, s//2, b, cp, np//cp, hn] -> [s*b, np, hn] + a2a_outputs[i - 2] = x.view(-1, x.shape[-3] * x.shape[-2], x.shape[-1]) + torch.cuda.current_stream().wait_stream(cp_stream) + return a2a_outputs[0] if len(a2a_inputs) == 1 else a2a_outputs + + +class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function): + """ + Attention implementation with context parallelism. Like Ulysses, applying A2A to QKVO. + Refer the paper `DeepSpeed Ulysses `_. + """ + + @staticmethod + def forward( + ctx, + is_training, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + dropout_p, + softmax_scale, + qkv_format, + attn_mask_type, + attn_bias_type, + attn_bias, + deterministic, + use_fused_attention, + window_size, + fp8, + fp8_meta, + cp_group, + cp_stream, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + cp_size = get_distributed_world_size(cp_group) + + causal = "causal" in attn_mask_type + padding = "padding" in attn_mask_type + assert not padding, f"{attn_mask_type} mask type is not supported!" + assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!" + assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!" + assert ( + window_size == (-1, 0) + or window_size == (-1, -1) + or use_fused_attention + or _flash_attn_2_3_plus + ), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!" + fa_optional_forward_kwargs = {} + if _flash_attn_2_3_plus: + fa_optional_forward_kwargs["window_size"] = window_size + if _flash_attn_2_4_plus: + fa_optional_forward_kwargs["alibi_slopes"] = None + if _flash_attn_2_5_7_plus: + fa_optional_forward_kwargs["block_table"] = None + + assert ( + q.shape[-2] % cp_size == 0 and k.shape[-2] % cp_size == 0 + ), "The number of attention heads needs to be divisible by CP size!" + + assert qkv_format != "thd", f"{qkv_format} format is not supported!" + qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format + + batch_dim = qkv_format.index("b") + seq_dim = qkv_format.index("s") + assert ( + q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 + ), "Sequence length per GPU needs to be divisible by 2!" 
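# --- Editor's illustration (not part of the patch) ---------------------------
# A minimal, single-process sketch of the shape bookkeeping behind
# flash_attn_a2a_communicate for qkv_format="bshd". It only mimics the
# view/movedim steps around the all_to_all; the real communication, the async
# streams, and the cp*2 chunk reordering for causal load balancing are omitted.
# All names and sizes below are illustrative assumptions.
import torch

def ulysses_a2a_shape_sketch(b=2, s_total=16, np_heads=8, hn=64, cp_size=4):
    # Before attention, each CP rank holds its local sequence chunk with all heads.
    q_local = torch.randn(b, s_total // cp_size, np_heads, hn)
    # Split the head dim into cp_size groups: [b, s/cp, np, hn] -> [b, s/cp, cp, np/cp, hn].
    x = q_local.view(b, s_total // cp_size, cp_size, np_heads // cp_size, hn)
    # Move the CP split to the front so all_to_all_single can exchange
    # contiguous per-rank slices: [cp, b, s/cp, np/cp, hn].
    x = x.movedim(-3, 0).contiguous()
    # After the (omitted) all_to_all, the leading cp dim indexes the sequence
    # chunks received from every rank; folding it back into the sequence dim
    # yields the full sequence with 1/cp of the heads: [b, s, np/cp, hn].
    q_post_a2a = x.movedim(0, 1).reshape(b, s_total, np_heads // cp_size, hn)
    return tuple(q_local.shape), tuple(q_post_a2a.shape)

# ((2, 4, 8, 64), (2, 16, 2, 64)) -- heads scattered, sequence gathered.
print(ulysses_a2a_shape_sketch())
# -----------------------------------------------------------------------------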
+ + if fp8: + if use_fused_attention: + fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + fused_attn_qkv_dtype = fp8_dtype_forward + fused_attn_backend = FusedAttnBackend["FP8"] + if fp8_meta["recipe"].fp8_mha: + assert ( + isinstance(q, Float8Tensor) + and isinstance(k, Float8Tensor) + and isinstance(v, Float8Tensor) + ), "q/k/v must be Float8Tensors for FP8 MHA!" + fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv + q_fp8, k_fp8, v_fp8 = q, k, v + q, k, v = q_fp8._data, k_fp8._data, v_fp8._data + elif int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + q_f16, k_f16, v_f16 = q, k, v + q, k, v = [ + cast_to_fp8(x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) + for x in [q_f16, k_f16, v_f16] + ] + fp8_meta_kwargs = {} + fp8_meta_kwargs["d_scale_qkv"] = fp8_meta["scaling_fwd"].scale_inv + fp8_meta_kwargs["d_scale_qkv_offset"] = META_QKV + fp8_meta_kwargs["d_scale_s"] = fp8_meta["scaling_fwd"].scale_inv + fp8_meta_kwargs["d_scale_s_offset"] = META_S + fp8_meta_kwargs["q_scale_s"] = fp8_meta["scaling_fwd"].scale + fp8_meta_kwargs["q_scale_s_offset"] = META_S + fp8_meta_kwargs["q_scale_o"] = fp8_meta["scaling_fwd"].scale + fp8_meta_kwargs["q_scale_o_offset"] = META_O + fp8_meta_kwargs["amax_s"] = fp8_meta["scaling_fwd"].amax_history + fp8_meta_kwargs["amax_s_offset"] = META_S + fp8_meta_kwargs["amax_o"] = fp8_meta["scaling_fwd"].amax_history + fp8_meta_kwargs["amax_o_offset"] = META_O + else: + assert False, "FP8 is only supported with Fused Attention!" + else: + if use_fused_attention: + fp8_meta_kwargs = {} + fused_attn_qkv_dtype = TE_DType[q.dtype] + fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] + + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, q.device, True) + q, k, v = flash_attn_a2a_communicate( + [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, True + ) + + if fp8 and not fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + q_f16, k_f16, v_f16 = q, k, v + q, k, v = [ + cast_to_fp8(x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) + for x in [q_f16, k_f16, v_f16] + ] + + batch_size = q.shape[batch_dim] + if use_fused_attention: + out, aux_ctx_tensors = fused_attn_fwd( + is_training, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q, + cu_seqlens_kv, + q, + k, + v, + fused_attn_qkv_dtype, + fused_attn_backend, + attn_scale=softmax_scale, + dropout=dropout_p, + qkv_layout=qkv_layout, + attn_mask_type=attn_mask_type, + attn_bias_type=attn_bias_type, + attn_bias=attn_bias, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + window_size=window_size, + **fp8_meta_kwargs, + ) + else: + # [b, cp*s, np//cp, hn] -> [b*cp*s, np//cp, hn] + q, k, v = [x.view(-1, *x.shape[-2:]) for x in [q, k, v]] + ( + _, + _, + _, + _, + out, + softmax_lse, + _, + rng_state, + ) = _flash_attn_forward( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + dropout_p, + softmax_scale, + causal=causal, + return_softmax=False, + **fa_optional_forward_kwargs, + ) + aux_ctx_tensors = [softmax_lse, rng_state] + # [b*cp*s, np//cp, hn] -> [b, cp*s, np//cp, hn] + out = out.view(batch_size, -1, *out.shape[-2:]) + + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, out.device, False) + out = flash_attn_a2a_communicate( + out, chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, False + ) + + if use_fused_attention: + if qkv_format == "bshd": + # [b*s, np, hn] -> [b, s, np, hn] + out = out.view(batch_size, -1, *out.shape[-2:]) + 
elif qkv_format == "sbhd": + # [s*b, np, hn] -> [s, b, np, hn] + out = out.view(-1, batch_size, *out.shape[-2:]) + + if fp8: + if fp8_meta["recipe"].fp8_mha: + out_fp8 = Float8Tensor( + data=out, + fp8_meta=fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=META_O, + fp8_dtype=fp8_dtype_forward, + dtype=q_fp8.dtype, + ) + out = out_fp8._data + out_ret = out_fp8 + else: + out_f16 = cast_from_fp8( + out, + fp8_meta["scaling_fwd"], + META_O, + fp8_dtype_forward, + TE_DType[q_f16.dtype], + ) + out_ret = out_f16 + else: + out_ret = out + + if fp8: + if int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + q_save, k_save, v_save, out_save = q, k, v, out + elif fp8_meta["recipe"].fp8_mha: + q_fp8, k_fp8, v_fp8 = [ + Float8Tensor( + data=x, + fp8_meta=fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=META_QKV, + fp8_dtype=fp8_dtype_forward, + dtype=out_fp8.dtype, + ) + for x in [q, k, v] + ] + q_save, k_save, v_save, out_save = q_fp8, k_fp8, v_fp8, out_fp8 + else: + q_save, k_save, v_save, out_save = q_f16, k_f16, v_f16, out_f16 + else: + q_save, k_save, v_save, out_save = q, k, v, out + + if fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + fp8_fwd_scales = fp8_meta["scaling_fwd"].scale.clone() + fp8_fwd_scale_invs = fp8_meta["scaling_fwd"].scale_inv.clone() + else: + fp8_fwd_scales, fp8_fwd_scale_invs = None, None + + ctx.save_for_backward( + q_save, + k_save, + v_save, + out_save, + cu_seqlens_q, + cu_seqlens_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + fp8_fwd_scales, + fp8_fwd_scale_invs, + *aux_ctx_tensors, + ) + ctx.batch_size = batch_size + ctx.cp_group = cp_group + ctx.cp_stream = cp_stream + ctx.dropout_p = dropout_p + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_kv = max_seqlen_kv + ctx.softmax_scale = softmax_scale + ctx.qkv_format = qkv_format + ctx.attn_mask_type = attn_mask_type + ctx.attn_bias_type = attn_bias_type + ctx.deterministic = deterministic + ctx.window_size = window_size + ctx.use_fused_attention = use_fused_attention + ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + ctx.fp8_meta = fp8_meta + return out_ret + + @staticmethod + def backward(ctx, dout): + cp_size = get_distributed_world_size(ctx.cp_group) + + q, k, v, out = ctx.saved_tensors[:4] + cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded = ctx.saved_tensors[ + 4:8 + ] + fp8_fwd_scales, fp8_fwd_scale_invs = ctx.saved_tensors[8:10] + aux_ctx_tensors = ctx.saved_tensors[10:] + + qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format + causal = "causal" in ctx.attn_mask_type + seq_dim = ctx.qkv_format.index("s") + + if ctx.fp8: + if ctx.use_fused_attention: + fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) + fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) + fused_attn_qkv_dtype = fp8_dtype_forward + fused_attn_dqkv_dtype = fp8_dtype_backward + fused_attn_backend = FusedAttnBackend["FP8"] + if ctx.fp8_meta["recipe"].fp8_mha: + assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" 
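# --- Editor's illustration (not part of the patch) ---------------------------
# Toy sketch of the scale / scale_inv bookkeeping the surrounding FP8 path
# relies on (e.g. propagating dout's scale_inv into scaling_bwd for META_DO):
# an FP8 payload is only meaningful together with the scale_inv needed to
# dequantize it. This is NOT the Float8Tensor API; the names and the E4M3 max
# of 448 are assumptions, and torch.float8_e4m3fn requires PyTorch >= 2.1.
import torch

def toy_fp8_cast(x: torch.Tensor, amax: float):
    fp8_max = 448.0                 # E4M3 representable maximum
    scale = fp8_max / amax          # delayed scaling would take amax from history
    data = (x * scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return data, 1.0 / scale        # raw payload plus the scale_inv to keep with it

def toy_fp8_dequant(data: torch.Tensor, scale_inv: float):
    return data.to(torch.float32) * scale_inv

dout_ref = torch.randn(4, 8)
data, scale_inv = toy_fp8_cast(dout_ref, amax=dout_ref.abs().max().item())
# Losing scale_inv would make `data` uninterpretable; with it, the round trip
# is accurate to FP8 precision.
print((toy_fp8_dequant(data, scale_inv) - dout_ref).abs().max())
# -----------------------------------------------------------------------------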
+ ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = dout._scale_inv + dout_fp8 = dout + dout = dout_fp8._data + else: + dout_f16 = dout + dout = cast_to_fp8( + dout_f16, ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward + ) + fp8_meta_kwargs = {} + fp8_meta_kwargs["d_scale_qkv"] = fp8_fwd_scale_invs[META_QKV] + fp8_meta_kwargs["d_scale_s"] = fp8_fwd_scale_invs[META_S] + fp8_meta_kwargs["d_scale_o"] = fp8_fwd_scale_invs[META_O] + fp8_meta_kwargs["d_scale_do"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] + fp8_meta_kwargs["d_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP] + fp8_meta_kwargs["q_scale_s"] = fp8_fwd_scales[META_S] + fp8_meta_kwargs["q_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale[META_DP] + fp8_meta_kwargs["q_scale_dqkv"] = ctx.fp8_meta["scaling_bwd"].scale[META_DQKV] + fp8_meta_kwargs["amax_dp"] = ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP] + fp8_meta_kwargs["amax_dqkv"] = ctx.fp8_meta["scaling_bwd"].amax_history[0][ + META_DQKV + ] + else: + assert False, "FP8 is only supported with Fused Attention!" + else: + if ctx.fp8_meta is not None and ctx.fp8_meta["recipe"].fp8_mha: + assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" + q, k, v, out, dout = [x.from_float8(x.dtype) for x in [q, k, v, out, dout]] + if ctx.use_fused_attention: + fp8_meta_kwargs = {} + fused_attn_qkv_dtype = TE_DType[q.dtype] + fused_attn_dqkv_dtype = TE_DType[dout.dtype] + fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] + + if not ctx.use_fused_attention: + out = out.view(ctx.batch_size, -1, *out.shape[-2:]) + dout = dout.view(*out.shape) + + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, out.device, True) + out, dout = flash_attn_a2a_communicate( + [out, dout], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, True + ) + + fa_optional_backward_kwargs = {} + if _flash_attn_2_3_plus: + fa_optional_backward_kwargs["window_size"] = ctx.window_size + if _flash_attn_2_4_plus: + fa_optional_backward_kwargs["alibi_slopes"] = None + if _flash_attn_2_4_1_plus: + fa_optional_backward_kwargs["deterministic"] = ctx.deterministic + + if ctx.use_fused_attention: + dq, dk, dv, _ = fused_attn_bwd( + ctx.max_seqlen_q, + ctx.max_seqlen_kv, + cu_seqlens_q, + cu_seqlens_kv, + q, + k, + v, + out, + dout, + fused_attn_qkv_dtype, + fused_attn_dqkv_dtype, + aux_ctx_tensors, + fused_attn_backend, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + attn_scale=ctx.softmax_scale, + dropout=ctx.dropout_p, + qkv_layout=qkv_layout, + attn_mask_type=ctx.attn_mask_type, + attn_bias_type=ctx.attn_bias_type, + window_size=ctx.window_size, + deterministic=ctx.deterministic, + **fp8_meta_kwargs, + ) + else: + softmax_lse, rng_state = aux_ctx_tensors + out, dout = [x.view(-1, *x.shape[-2:]) for x in [out, dout]] + dq, dk, dv = [torch.empty_like(x) for x in [q, k, v]] + _flash_attn_backward( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + cu_seqlens_q, + cu_seqlens_kv, + ctx.max_seqlen_q, + ctx.max_seqlen_kv, + ctx.dropout_p, + ctx.softmax_scale, + causal, + rng_state=rng_state, + **fa_optional_backward_kwargs, + ) + dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]] + + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, q.device, False) + dq, dk, dv = flash_attn_a2a_communicate( + [dq, dk, dv], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, False + ) + if ctx.qkv_format == "bshd": - dq = dq.view(dq.shape[0], -1, *dq.shape[-2:]) 
- dk = dk.transpose(0, 1).contiguous() - dv = dv.transpose(0, 1).contiguous() + dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]] elif ctx.qkv_format == "sbhd": - dq = dq.view(-1, *dq.shape[-3:]) + dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]] + + if ctx.fp8: + if ctx.fp8_meta["recipe"].fp8_mha: + dq, dk, dv = [ + Float8Tensor( + data=x, + fp8_meta=ctx.fp8_meta, + fp8_meta_forward=False, + fp8_meta_index=META_DQKV, + fp8_dtype=fp8_dtype_backward, + dtype=dout_fp8.dtype, + ) + for x in [dq, dk, dv] + ] + else: + dq, dk, dv = [ + cast_from_fp8( + x, + ctx.fp8_meta["scaling_bwd"], + META_DQKV, + fp8_dtype_backward, + TE_DType[dout_f16.dtype], + ) + for x in [dq, dk, dv] + ] return ( None, @@ -3404,6 +3932,9 @@ def backward(ctx, dout): None, None, None, + None, + None, + None, ) @@ -3465,57 +3996,44 @@ def attn_forward_func_with_cp( sliding_window_attn = ( window_size is not None and window_size != (-1, 0) and window_size != (-1, -1) ) + assert ( + not sliding_window_attn + or cp_comm_type == "a2a" + or (cp_comm_type == "all_gather" and not use_fused_attention) + ), "The context parallel running configs cannot support sliding window attetnion!" - if sliding_window_attn or cp_comm_type == "all_gather": - out = AttnFuncWithCPAndKVAllGather.apply( - is_training, - q, - k, - v, - cu_seqlens_q, - cu_seqlens_kv, - max_seqlen_q, - max_seqlen_kv, - cu_seqlens_q_padded, - cu_seqlens_kv_padded, - dropout_p, - cp_group, - cp_stream, - softmax_scale, - qkv_format, - attn_mask_type, - attn_bias_type, - attn_bias, - deterministic, - use_fused_attention, - window_size, - ) - elif cp_comm_type == "p2p": - out = AttnFuncWithCPAndKVP2P.apply( - is_training, - q, - k, - v, - cu_seqlens_q, - cu_seqlens_kv, - max_seqlen_q, - max_seqlen_kv, - cu_seqlens_q_padded, - cu_seqlens_kv_padded, - dropout_p, - cp_group, - cp_global_ranks, - cp_stream, - softmax_scale, - qkv_format, - attn_mask_type, - attn_bias_type, - attn_bias, - deterministic, - use_fused_attention, - fp8, - fp8_meta, - ) + args = [ + is_training, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + dropout_p, + softmax_scale, + qkv_format, + attn_mask_type, + attn_bias_type, + attn_bias, + deterministic, + use_fused_attention, + ] + + if cp_comm_type == "p2p": + args += [fp8, fp8_meta, cp_group, cp_global_ranks, cp_stream] + out = AttnFuncWithCPAndKVP2P.apply(*args) + elif cp_comm_type == "all_gather": + args.pop(5) + args.pop(8) + args += [window_size, cp_group, cp_stream] + out = AttnFuncWithCPAndKVAllGather.apply(*args) + elif cp_comm_type == "a2a": + args += [window_size, fp8, fp8_meta, cp_group, cp_stream] + out = AttnFuncWithCPAndQKVOA2A.apply(*args) else: raise ValueError(f"Unsupported communication type: {cp_comm_type}!") @@ -6416,7 +6934,13 @@ class DotProductAttention(TransformerEngineBaseModule): can overlap two flash attention kernels. cp_comm_type : str inter-gpu communication type for context parallelism. - Can be "p2p" or "all_gather". + Can be "p2p" or "all_gather" or "a2a". + "p2p": Exchange KV chunks with P2P communications in ring topology. + P2P is async and can be overlapped with attention compute. + "all_gather": All-gather to get full sequence of KV before attention. + The all-gather is not async, and cannot be overlapped. + "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP + group, and gather to get full sequence of QKV. 
""" def __init__( @@ -6608,7 +7132,13 @@ def set_context_parallel_group( cuda stream for context parallel execution. cp_comm_type : str inter-gpu communication type for context parallelism. - Can be "p2p" or "all_gather". + Can be "p2p" or "all_gather" or "a2a". + "p2p": Exchange KV chunks with P2P communications in ring topology. + P2P is async and can be overlapped with attention compute. + "all_gather": All-gather to get full sequence of KV before attention. + The all-gather is not async, and cannot be overlapped. + "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP + group, and gather to get full sequence of QKV. """ self.cp_group = cp_group self.cp_global_ranks = cp_global_ranks @@ -7633,7 +8163,13 @@ def set_context_parallel_group( cuda stream for context parallel execution. cp_comm_type : str inter-gpu communication type for context parallelism. - Can be "p2p" or "all_gather". + Can be "p2p" or "all_gather" or "a2a". + "p2p": Exchange KV chunks with P2P communications in ring topology. + P2P is async and can be overlapped with attention compute. + "all_gather": All-gather to get full sequence of KV before attention. + The all-gather is not async, and cannot be overlapped. + "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP + group, and gather to get full sequence of QKV. """ # Deep iterate but skip self to avoid infinite recursion. for index, child in enumerate(self.modules()): diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index bd6e27594d..958c7019ba 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -503,7 +503,13 @@ def set_context_parallel_group( cuda stream for context parallel execution. cp_comm_type : str inter-gpu communication type for context parallelism. - Can be "p2p" or "all_gather". + Can be "p2p" or "all_gather" or "a2a". + "p2p": Exchange KV chunks with P2P communications in ring topology. + P2P is async and can be overlapped with attention compute. + "all_gather": All-gather to get full sequence of KV before attention. + The all-gather is not async, and cannot be overlapped. + "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP + group, and gather to get full sequence of QKV. """ # Deep iterate but skip self to avoid infinite recursion. for index, child in enumerate(self.modules()):