From 59b99ca20cda8c0a47d9ed0381ba64ec73ea3955 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Mon, 12 Aug 2024 20:30:36 -0700 Subject: [PATCH 01/17] fp8 mha with rope Signed-off-by: Xin Yao --- transformer_engine/pytorch/attention.py | 61 ++++++++++--------- .../pytorch/module/layernorm_linear.py | 8 ++- 2 files changed, 39 insertions(+), 30 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index b2fb22c8fc..656d534eac 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -3560,19 +3560,19 @@ def forward( fp8_meta, deterministic, ): + is_input_fp8 = False if fp8: - if fp8_meta["recipe"].fp8_mha: - assert isinstance(qkv, Float8Tensor), "qkv must be Float8Tensors for FP8 MHA." + is_input_fp8 = isinstance(qkv, Float8Tensor) + if is_input_fp8: fp8_meta["scaling_fwd"].scale_inv[META_QKV] = qkv._scale_inv fused_attention_backend = FusedAttnBackend["FP8"] fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) # 1: qkv packed, 2: kv packed, 3: qkv separate qkv_group = len(qkv_layout.split("_")) - assert qkv_group == 1, ( - "qkv layout should conform to 3hd or h3d, e.g. sb3hd, but found" - f" {qkv_layout}." - ) - if fp8_meta["recipe"].fp8_mha: + assert ( + qkv_group == 1 + ), f"qkv layout should conform to 3hd or h3d, e.g. sb3hd, but found {qkv_layout}." + if is_input_fp8: qkv_fp8 = qkv._data else: qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) @@ -3621,7 +3621,7 @@ def forward( qkv_dtype, ).view(out_fp8.shape) out_save = out_ret - if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + if is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) qkv = cast_from_fp8( qkv_c._data, @@ -3672,6 +3672,7 @@ def forward( out_save = out_ret ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + ctx.is_input_fp8 = is_input_fp8 qkvo_tensors = (qkv, out_save) if not ctx.fp8 else (None, None) ctx.save_for_backward( *qkvo_tensors, cu_seqlens, cu_seqlens_padded, *fp8_tensors, *aux_ctx_tensors @@ -3793,7 +3794,7 @@ def backward(ctx, d_out): ctx.window_size, ctx.deterministic, ) - if ctx.fp8_meta["recipe"].fp8_mha: + if ctx.is_input_fp8: dqkv = Float8Tensor( data=dqkv_fp8, fp8_meta=ctx.fp8_meta, @@ -3931,22 +3932,22 @@ def forward( fp8_meta, deterministic, ): + is_input_fp8 = False if fp8: - if fp8_meta["recipe"].fp8_mha: - assert isinstance(q, Float8Tensor) and isinstance( - kv, Float8Tensor - ), "q/kv must be Float8Tensors for FP8 MHA." + assert type(q) == type(kv), "q and kv must have the same type." + is_input_fp8 = isinstance(q, Float8Tensor) + if is_input_fp8: fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv fused_attention_backend = FusedAttnBackend["FP8"] fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - if fp8_meta["recipe"].fp8_mha: + if is_input_fp8: q_fp8, kv_fp8 = q._data, kv._data else: # 1: qkv packed, 2: kv packed, 3: qkv separate qkv_group = len(qkv_layout.split("_")) assert qkv_group == 2, ( - "qkv layout should conform to hd_2hd or hd_h2d, e.g. sbhd_sb2hd, " - f" but found {qkv_layout}." + "qkv layout should conform to hd_2hd or hd_h2d, e.g. sbhd_sb2hd, " + f"but found {qkv_layout}." 
) q_fp8 = cast_to_fp8(q, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward).view( q.shape @@ -4001,7 +4002,7 @@ def forward( qkv_dtype, ).view(out_fp8.shape) out_save = out_ret - if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + if is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): q = cast_from_fp8( q._data, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward, TE_DType[q.dtype] ).view(q.shape) @@ -4060,6 +4061,7 @@ def forward( fp8_tensors = (None, None, None, None, None) ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + ctx.is_input_fp8 = is_input_fp8 qkvo_tensors = (q, kv, out_save) if not ctx.fp8 else (None, None, None) ctx.save_for_backward( *qkvo_tensors, @@ -4196,7 +4198,7 @@ def backward(ctx, d_out): ctx.window_size, ctx.deterministic, ) - if ctx.fp8_meta["recipe"].fp8_mha: + if ctx.is_input_fp8: dq = Float8Tensor( data=dq_fp8, fp8_meta=ctx.fp8_meta, @@ -4362,15 +4364,13 @@ def forward( fp8_meta, deterministic, ): + is_input_fp8 = False if fp8: fused_attention_backend = FusedAttnBackend["FP8"] fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - if fp8_meta["recipe"].fp8_mha: - assert ( - isinstance(q, Float8Tensor) - and isinstance(k, Float8Tensor) - and isinstance(v, Float8Tensor) - ), "q/k/v must be Float8Tensors for FP8 MHA." + assert type(q) == type(k) == type(v), "q, k, and v must have the same type." + is_input_fp8 = isinstance(q, Float8Tensor) + if is_input_fp8: fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv q_fp8, k_fp8, v_fp8 = q._data, k._data, v._data else: @@ -4455,7 +4455,7 @@ def forward( ).view(out_fp8.shape) out_save = out_ret - if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + if is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): # 1: qkv packed, 2: kv packed, 3: qkv separate qkv_group = len(qkv_layout.split("_")) if qkv_group == 1: @@ -4572,6 +4572,7 @@ def forward( tensor.activation_offloading = True ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + ctx.is_input_fp8 = is_input_fp8 qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None) ctx.save_for_backward( *qkvo_tensors, @@ -4714,7 +4715,7 @@ def backward(ctx, d_out): ctx.deterministic, ) - if ctx.fp8_meta["recipe"].fp8_mha: + if ctx.is_input_fp8: dq = Float8Tensor( data=dq_fp8, fp8_meta=ctx.fp8_meta, @@ -6592,6 +6593,7 @@ def forward( layernorm_qkv_outputs = self.layernorm_qkv( hidden_states, is_first_microbatch=is_first_microbatch, + is_first_module_in_mha=rotary_pos_emb is None, # specific to FP8 MHA ) if self.return_layernorm_output: mixed_x_layer, layernorm_output = layernorm_qkv_outputs @@ -6601,7 +6603,7 @@ def forward( mixed_x_layer = self.qkv( hidden_states, is_first_microbatch=is_first_microbatch, - is_first_module_in_mha=True, # specific to FP8 MHA + is_first_module_in_mha=rotary_pos_emb is None, # specific to FP8 MHA ) num_queries_per_key_value = ( @@ -6657,7 +6659,7 @@ def forward( mixed_kv_layer = self.key_value( encoder_output, is_first_microbatch=is_first_microbatch, - is_first_module_in_mha=True, # specific to FP8 MHA + is_first_module_in_mha=rotary_pos_emb is None, # specific to FP8 MHA ) if self.qkv_weight_interleaved: @@ -6707,6 +6709,7 @@ def forward( layernorm_query_outputs = self.layernorm_query( hidden_states, is_first_microbatch=is_first_microbatch, + is_first_module_in_mha=rotary_pos_emb is None, # specific to FP8 MHA ) if self.return_layernorm_output: query_layer, layernorm_output = layernorm_query_outputs @@ -6716,7 
+6719,7 @@ def forward( query_layer = self.query_layer( hidden_states, is_first_microbatch=is_first_microbatch, - is_first_module_in_mha=True, # specific to FP8 MHA + is_first_module_in_mha=rotary_pos_emb is None, # specific to FP8 MHA ) # [sq, b, hp] --> [sq, b, np, hn] diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 10560cdad6..37eb1d84db 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -90,6 +90,7 @@ def forward( ub_overlap_rs_dgrad: bool, ub_overlap_ag: bool, ub_name: str, + is_first_module_in_mha: bool, fsdp_group: Union[dist_group_type, None], ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # Make sure input dimensions are compatible @@ -199,7 +200,7 @@ def forward( assert isinstance(weight_fp8, Float8Tensor) - if fp8_meta["recipe"].fp8_mha: + if is_first_module_in_mha: out_index, meta_tensor, output_te_dtype, output_dtype = ( tex.FP8FwdTensors.GEMM1_OUTPUT, fp8_meta["scaling_fwd"], @@ -744,6 +745,7 @@ def backward( None, # ub_overlap_rs_dgrad None, # ub_overlap_ag None, # ub_name + None, # is_first_module_in_mha None, # fsdp_group ) @@ -1096,6 +1098,7 @@ def forward( self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None, + is_first_module_in_mha: Optional[bool] = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: """ Apply layer normalization to the input followed by a linear transformation. @@ -1125,6 +1128,8 @@ def forward( with self.prepare_forward(inp, is_first_microbatch) as inp: + is_first_module_in_mha = is_first_module_in_mha and self.fp8_meta["recipe"].fp8_mha + # Get concatenated weight and bias tensors unfused_weights = [getattr(self, name) for name in self.weight_names] if any(isinstance(w, Float8Tensor) for w in unfused_weights): @@ -1223,6 +1228,7 @@ def forward( self.ub_overlap_rs_dgrad, self.ub_overlap_ag, self.ub_name, + is_first_module_in_mha, self.fsdp_group, ) out = fwd_fn(*args) From c46f82cc511fc56f445b24562241d5b0719298fc Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Tue, 13 Aug 2024 19:26:33 -0700 Subject: [PATCH 02/17] avoid index select in cast ops Signed-off-by: Xin Yao --- transformer_engine/pytorch/attention.py | 4 +-- transformer_engine/pytorch/csrc/extensions.h | 11 ++++-- .../pytorch/csrc/extensions/cast.cu | 36 +++++++++++++------ .../pytorch/csrc/extensions/pybind.cpp | 12 +++++-- transformer_engine/pytorch/csrc/ts_fp8_op.cpp | 8 ++--- 5 files changed, 48 insertions(+), 23 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 656d534eac..130b8bd224 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -3152,8 +3152,8 @@ def run_iteratively(q, k, v): stride = q.stride() check_strides_qkv = all(stride == x.stride() for x in [q, k, v]) stride = k.stride() - check_strides_kv = torch.equal( - torch.Tensor(stride[:-1]) / k.shape[-1], torch.Tensor(v.stride()[:-1]) / v.shape[-1] + check_strides_kv = tuple(sk / k.shape[-1] for sk in k.stride()[:-1]) == tuple( + sv / v.shape[-1] for sv in v.stride()[:-1] ) shape = q.shape diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index cd5bda8b63..64e6909d71 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -315,13 +315,18 @@ at::Tensor rmsnorm_fwd_inf(const at::Tensor &input, const at::Tensor &weight, fl 
**************************************************************************************************/ at::Tensor cast_to_fp8(const at::Tensor &input, const at::Tensor &scale, at::Tensor amax, - at::Tensor scale_inv, transformer_engine::DType otype); + at::Tensor scale_inv, transformer_engine::DType otype, + const int scale_offset = 0, const int amax_offset = 0, + const int scale_inv_offset = 0); void cast_to_fp8_noalloc(const at::Tensor &input, const at::Tensor &scale, at::Tensor output, - at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype); + at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype, + const int scale_offset = 0, const int amax_offset = 0, + const int scale_inv_offset = 0); at::Tensor cast_from_fp8(const at::Tensor &input, const at::Tensor &scale_inv, - transformer_engine::DType itype, transformer_engine::DType otype); + transformer_engine::DType itype, transformer_engine::DType otype, + const int scale_inv_offset = 0); /*************************************************************************************************** * Softmax diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cu b/transformer_engine/pytorch/csrc/extensions/cast.cu index c783c9d988..47f5825866 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cu +++ b/transformer_engine/pytorch/csrc/extensions/cast.cu @@ -6,8 +6,9 @@ #include "extensions.h" -at::Tensor cast_to_fp8(const at::Tensor &input, const at::Tensor &scale, at::Tensor amax, - at::Tensor scale_inv, transformer_engine::DType otype) { +at::Tensor cast_to_fp8(const at::Tensor& input, const at::Tensor& scale, at::Tensor amax, + at::Tensor scale_inv, transformer_engine::DType otype, + const int scale_offset, const int amax_offset, const int scale_inv_offset) { using namespace transformer_engine; auto input_shape = input.sizes().vec(); std::vector shape{input_shape.begin(), input_shape.end()}; @@ -16,32 +17,45 @@ at::Tensor cast_to_fp8(const at::Tensor &input, const at::Tensor &scale, at::Ten if (input.numel() == 0) return output; + // Get pointers for FP8 scale, amax, scale-inverse + void* scale_dptr = getDataPtr(scale, scale_offset); + void* amax_dptr = getDataPtr(amax, amax_offset); + void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset); + auto input_cu = makeTransformerEngineTensor(input); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), shape, otype, amax.data_ptr(), - scale.data_ptr(), scale_inv.data_ptr()); + auto output_cu = makeTransformerEngineTensor(output.data_ptr(), shape, otype, amax_dptr, + scale_dptr, scale_inv_dptr); nvte_fp8_quantize(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); return output; } -void cast_to_fp8_noalloc(const at::Tensor &input, const at::Tensor &scale, at::Tensor output, - at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype) { +void cast_to_fp8_noalloc(const at::Tensor& input, const at::Tensor& scale, at::Tensor output, + at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype, + const int scale_offset, const int amax_offset, + const int scale_inv_offset) { using namespace transformer_engine; size_t N = static_cast(input.size(0)); size_t H = static_cast(input.size(1)); + // Get pointers for FP8 scale, amax, scale-inverse + void* scale_dptr = getDataPtr(scale, scale_offset); + void* amax_dptr = getDataPtr(amax, amax_offset); + void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset); + auto input_cu = makeTransformerEngineTensor(input); - auto output_cu = 
makeTransformerEngineTensor(output.data_ptr(), {N, H}, otype, amax.data_ptr(), - scale.data_ptr(), scale_inv.data_ptr()); + auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, H}, otype, amax_dptr, + scale_dptr, scale_inv_dptr); nvte_fp8_quantize(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); return; } -at::Tensor cast_from_fp8(const at::Tensor &input, const at::Tensor &scale_inv, - transformer_engine::DType itype, transformer_engine::DType otype) { +at::Tensor cast_from_fp8(const at::Tensor& input, const at::Tensor& scale_inv, + transformer_engine::DType itype, transformer_engine::DType otype, + const int scale_inv_offset) { using namespace transformer_engine; auto input_shape = input.sizes().vec(); std::vector shape{input_shape.begin(), input_shape.end()}; @@ -49,7 +63,7 @@ at::Tensor cast_from_fp8(const at::Tensor &input, const at::Tensor &scale_inv, auto output = at::empty_like(input, at::CUDA(GetATenDType(otype))); auto input_cu = makeTransformerEngineTensor(input.data_ptr(), shape, itype, nullptr, nullptr, - scale_inv.data_ptr()); + getDataPtr(scale_inv, scale_inv_offset)); auto output_cu = makeTransformerEngineTensor(output); nvte_fp8_dequantize(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index c97c66dd98..d0f470c76a 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -87,10 +87,16 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("fused_multi_cast_transpose_alloc", &fused_multi_cast_transpose_alloc, "Fused Multi-tensor Cast + Transpose with allocating output tensors", py::call_guard()); - m.def("cast_to_fp8", &cast_to_fp8, "Cast to FP8", py::call_guard()); + m.def("cast_to_fp8", &cast_to_fp8, "Cast to FP8", py::call_guard(), + py::arg("input"), py::arg("scale"), py::arg("amax"), py::arg("scale_inv"), py::arg("otype"), + py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); m.def("cast_to_fp8_noalloc", &cast_to_fp8_noalloc, "Cast to FP8", - py::call_guard()); - m.def("cast_from_fp8", &cast_from_fp8, "Cast from FP8", py::call_guard()); + py::call_guard(), py::arg("input"), py::arg("scale"), + py::arg("output"), py::arg("amax"), py::arg("scale_inv"), py::arg("otype"), + py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); + m.def("cast_from_fp8", &cast_from_fp8, "Cast from FP8", py::call_guard(), + py::arg("input"), py::arg("scale_inv"), py::arg("itype"), py::arg("otype"), + py::arg("scale_inv_offset") = 0); m.def("te_gemm", &te_gemm, "CublasLt GEMM"); /// TODO Think m.def("te_grouped_gemm", &te_grouped_gemm, "Grouped GEMM"); m.def("fused_attn_fwd_qkvpacked", &fused_attn_fwd_qkvpacked, diff --git a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp index 8515092ae0..8c480e8343 100644 --- a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp +++ b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp @@ -26,7 +26,7 @@ at::Tensor cast_to_fp8_ts(const at::Tensor &input, const at::Tensor &scale, at:: at::Tensor scale_inv, int64_t fp8_tensor, int64_t otype) { transformer_engine::DType otype_arg = reverse_map_dtype(otype); at::Tensor output = - cast_to_fp8(input, scale[fp8_tensor], amax[0][fp8_tensor], scale_inv[fp8_tensor], otype_arg); + cast_to_fp8(input, scale, amax, scale_inv, otype_arg, fp8_tensor, fp8_tensor, 
fp8_tensor); return output; } @@ -34,8 +34,8 @@ at::Tensor cast_to_fp8_noalloc_ts(const at::Tensor &input, const at::Tensor &sca at::Tensor output, at::Tensor amax, at::Tensor scale_inv, int64_t fp8_tensor, int64_t otype) { transformer_engine::DType otype_arg = reverse_map_dtype(otype); - cast_to_fp8_noalloc(input, scale[fp8_tensor], output, amax[0][fp8_tensor], scale_inv[fp8_tensor], - otype_arg); + cast_to_fp8_noalloc(input, scale, output, amax, scale_inv, otype_arg, fp8_tensor, fp8_tensor, + fp8_tensor); return output; } @@ -43,7 +43,7 @@ at::Tensor cast_from_fp8_ts(const at::Tensor &input, const at::Tensor &scale_inv int64_t fp8_tensor, int64_t itype, int64_t otype) { transformer_engine::DType itype_arg = reverse_map_dtype(itype); transformer_engine::DType otype_arg = reverse_map_dtype(otype); - at::Tensor output = cast_from_fp8(input, scale_inv[fp8_tensor], itype_arg, otype_arg); + at::Tensor output = cast_from_fp8(input, scale_inv, itype_arg, otype_arg, fp8_tensor); return output; } From dafd73f0a0e621381df74558dc30eb66eaf1356a Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Tue, 13 Aug 2024 22:11:31 -0700 Subject: [PATCH 03/17] avoid index select in fused_attn_fwd Signed-off-by: Xin Yao --- transformer_engine/pytorch/attention.py | 66 +++---- .../pytorch/cpp_extensions/fused_attn.py | 187 +++++++++--------- transformer_engine/pytorch/csrc/extensions.h | 25 +-- .../pytorch/csrc/extensions/attention.cu | 85 ++++---- 4 files changed, 174 insertions(+), 189 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 130b8bd224..25cbbf5fbb 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -37,6 +37,12 @@ AttnBiasType, AttnMaskType, FusedAttnBackend, + META_QKV, + META_DQKV, + META_O, + META_DO, + META_S, + META_DP, ) from transformer_engine.pytorch.fp8 import get_fp8_te_dtype from transformer_engine.pytorch.float8_tensor import Float8Tensor @@ -87,12 +93,6 @@ from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd -META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT -META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1 -META_O = tex.FP8FwdTensors.GEMM2_INPUT -META_DO = tex.FP8BwdTensors.GRAD_INPUT2 -META_S = tex.FP8FwdTensors.GEMM3_OUTPUT -META_DP = tex.FP8BwdTensors.GRAD_INPUT3 # NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0 _NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0")) @@ -3588,12 +3588,12 @@ def forward( fused_attention_backend, attn_bias, cu_seqlens_padded, - fp8_meta["scaling_fwd"].scale_inv[META_QKV], - fp8_meta["scaling_fwd"].scale_inv[META_S], - fp8_meta["scaling_fwd"].scale[META_S], - fp8_meta["scaling_fwd"].scale[META_O], - fp8_meta["scaling_fwd"].amax_history[0][META_S], - fp8_meta["scaling_fwd"].amax_history[0][META_O], + fp8_meta["scaling_fwd"].scale_inv, + fp8_meta["scaling_fwd"].scale, + fp8_meta["scaling_fwd"].amax_history, + META_QKV, + META_S, + META_O, attn_scale, dropout_p, fast_zero_fill, @@ -3656,9 +3656,9 @@ def forward( None, None, None, - None, - None, - None, + 0, + 0, + 0, attn_scale, dropout_p, fast_zero_fill, @@ -3969,12 +3969,12 @@ def forward( attn_bias, cu_seqlens_q_padded, cu_seqlens_kv_padded, - fp8_meta["scaling_fwd"].scale_inv[META_QKV], - fp8_meta["scaling_fwd"].scale_inv[META_S], - fp8_meta["scaling_fwd"].scale[META_S], - fp8_meta["scaling_fwd"].scale[META_O], - fp8_meta["scaling_fwd"].amax_history[0][META_S], - 
fp8_meta["scaling_fwd"].amax_history[0][META_O], + fp8_meta["scaling_fwd"].scale_inv, + fp8_meta["scaling_fwd"].scale, + fp8_meta["scaling_fwd"].amax_history, + META_QKV, + META_S, + META_O, attn_scale, dropout_p, fast_zero_fill, @@ -4045,9 +4045,9 @@ def forward( None, None, None, - None, - None, - None, + 0, + 0, + 0, attn_scale, dropout_p, fast_zero_fill, @@ -4421,12 +4421,12 @@ def forward( attn_bias, cu_seqlens_q_padded, cu_seqlens_kv_padded, - fp8_meta["scaling_fwd"].scale_inv[META_QKV], - fp8_meta["scaling_fwd"].scale_inv[META_S], - fp8_meta["scaling_fwd"].scale[META_S], - fp8_meta["scaling_fwd"].scale[META_O], - fp8_meta["scaling_fwd"].amax_history[0][META_S], - fp8_meta["scaling_fwd"].amax_history[0][META_O], + fp8_meta["scaling_fwd"].scale_inv, + fp8_meta["scaling_fwd"].scale, + fp8_meta["scaling_fwd"].amax_history, + META_QKV, + META_S, + META_O, attn_scale, dropout_p, fast_zero_fill, @@ -4547,9 +4547,9 @@ def forward( None, None, None, - None, - None, - None, + 0, + 0, + 0, attn_scale, dropout_p, fast_zero_fill, diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index d0ba644621..6c4e01ce45 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -78,6 +78,13 @@ BACKEND_F16m512_FP8_THREADS_PER_CTA = 128 BACKEND_F16arb_ELTS_PER_THREADS = 16 +META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT +META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1 +META_O = tex.FP8FwdTensors.GEMM2_INPUT +META_DO = tex.FP8BwdTensors.GRAD_INPUT2 +META_S = tex.FP8FwdTensors.GEMM3_OUTPUT +META_DP = tex.FP8BwdTensors.GRAD_INPUT3 + def fused_attn_fwd_qkvpacked( is_training: bool, @@ -88,12 +95,12 @@ def fused_attn_fwd_qkvpacked( fused_attention_backend: tex.NVTE_Fused_Attn_Backend, attn_bias: torch.Tensor = None, cu_seqlens_padded: torch.Tensor = None, - d_scale_qkv: torch.Tensor = None, - d_scale_s: torch.Tensor = None, - q_scale_s: torch.Tensor = None, - q_scale_o: torch.Tensor = None, - amax_s: torch.Tensor = None, - amax_o: torch.Tensor = None, + d_scale: torch.Tensor = None, + q_scale: torch.Tensor = None, + amax: torch.Tensor = None, + offset_QKV: int = META_QKV, + offset_S: int = META_S, + offset_O: int = META_O, attn_scale: float = None, dropout: float = 0.0, fast_zero_fill: bool = True, @@ -126,18 +133,19 @@ def fused_attn_fwd_qkvpacked( shape [1, num_heads, max_seqlen, max_seqlen], same data type as qkv cu_seqlens_padded: torch.Tensor, default = None cumulative sequence offsets for QKV; shape [batch_size + 1] - d_scale_qkv: torch.Tensor, default = None - input tensor for the dequantization of QKV in FP8 computations - d_scale_s: torch.Tensor, default = None - input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T) - q_scale_s: torch.Tensor, default = None - input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T) - q_scale_o: torch.Tensor, default = None - input tensor for the quantization of O in FP8 computations - amax_s: torch.Tensor, default = None - output tensor, amax of S, used by the next iteration in FP8 computations - amax_o: torch.Tensor, default = None - output tensor, amax of O, used by the next iteration in FP8 computations + d_scale: torch.Tensor, default = None + input tensor for the dequantization of Q, K, V and S in FP8 computations, + S = Softmax(Q * K.T) + q_scale: torch.Tensor, default = None + input tensor for the quantization of S and O in FP8 computations, S = Softmax(Q * K.T) + amax: 
torch.Tensor, default = None + output tensor, amax of S and O, used by the next iteration in FP8 computations + offset_QKV: int, default = 0 + QKV offset in d_scale + offset_S: int, default = 0 + S offset in d_scale, q_scale and amax + offset_O: int, default = 0 + O offset in q_scale and amax attn_scale: float, default = None if not None, use attn_scale as the attention scale for Q*K.T BMM; if None, use 1.0/sqrt(head_dim_qk) as the default @@ -223,14 +231,9 @@ def fused_attn_fwd_qkvpacked( max_seqlen * max_seqlen + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 ) // BACKEND_F16m512_FP8_THREADS_PER_CTA - assert ( - d_scale_qkv is not None - ), "d_scale_qkv is required as an input for FP8 fused attention." - assert d_scale_s is not None, "q_scale_s is required as an input for FP8 fused attention." - assert q_scale_s is not None, "q_scale_s is required as an input for FP8 fused attention." - assert q_scale_o is not None, "q_scale_o is required as an input for FP8 fused attention." - assert amax_s is not None, "amax_s is required as an input for FP8 fused attention." - assert amax_o is not None, "amax_o is required as an input for FP8 fused attention." + assert d_scale is not None, "d_scale is required as an input for FP8 fused attention." + assert q_scale is not None, "q_scale is required as an input for FP8 fused attention." + assert amax is not None, "amax is required as an input for FP8 fused attention." # execute kernel output_tensors = tex.fused_attn_fwd_qkvpacked( @@ -247,12 +250,12 @@ def fused_attn_fwd_qkvpacked( qkv, qkv_dtype, cu_seqlens_padded, - d_scale_qkv, - d_scale_s, - q_scale_s, - q_scale_o, - amax_s, - amax_o, + d_scale, + q_scale, + amax, + offset_QKV, + offset_S, + offset_O, attn_bias, rng_gen, rng_elts_per_thread, @@ -447,12 +450,12 @@ def fused_attn_fwd_kvpacked( attn_bias: torch.Tensor = None, cu_seqlens_q_padded: torch.Tensor = None, cu_seqlens_kv_padded: torch.Tensor = None, - d_scale_qkv: torch.Tensor = None, - d_scale_s: torch.Tensor = None, - q_scale_s: torch.Tensor = None, - q_scale_o: torch.Tensor = None, - amax_s: torch.Tensor = None, - amax_o: torch.Tensor = None, + d_scale: torch.Tensor = None, + q_scale: torch.Tensor = None, + amax: torch.Tensor = None, + offset_QKV: int = META_QKV, + offset_S: int = META_S, + offset_O: int = META_O, attn_scale: float = None, dropout: float = 0.0, fast_zero_fill: bool = True, @@ -494,18 +497,19 @@ def fused_attn_fwd_kvpacked( cumulative sequence offsets for Q; shape [batch_size + 1] cu_seqlens_kv_padded: torch.Tensor, default = None cumulative sequence offsets for KV; shape [batch_size + 1] - d_scale_qkv: torch.Tensor, default = None - input tensor for the dequantization of QKV in FP8 computations - d_scale_s: torch.Tensor, default = None - input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T) - q_scale_s: torch.Tensor, default = None - input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T) - q_scale_o: torch.Tensor, default = None - input tensor for the quantization of O in FP8 computations - amax_s: torch.Tensor, default = None - output tensor, amax of S, used by the next iteration in FP8 computations - amax_o: torch.Tensor, default = None - output tensor, amax of O, used by the next iteration in FP8 computations + d_scale: torch.Tensor, default = None + input tensor for the dequantization of Q, K, V and S in FP8 computations, + S = Softmax(Q * K.T) + q_scale: torch.Tensor, default = None + input tensor for the quantization of S and O in FP8 computations, S = Softmax(Q * K.T) 
+ amax: torch.Tensor, default = None + output tensor, amax of S and O, used by the next iteration in FP8 computations + offset_QKV: int, default = 0 + QKV offset in d_scale + offset_S: int, default = 0 + S offset in d_scale, q_scale and amax + offset_O: int, default = 0 + O offset in q_scale and amax attn_scale: float, default = None if not None, use attn_scale as the attention scale for Q*K.T BMM; if None, use 1.0/sqrt(head_dim_qk) as the default @@ -592,14 +596,9 @@ def fused_attn_fwd_kvpacked( max_seqlen_q * max_seqlen_q + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 ) // BACKEND_F16m512_FP8_THREADS_PER_CTA - assert ( - d_scale_qkv is not None - ), "d_scale_qkv is required as an input for FP8 fused attention." - assert d_scale_s is not None, "q_scale_s is required as an input for FP8 fused attention." - assert q_scale_s is not None, "q_scale_s is required as an input for FP8 fused attention." - assert q_scale_o is not None, "q_scale_o is required as an input for FP8 fused attention." - assert amax_s is not None, "amax_s is required as an input for FP8 fused attention." - assert amax_o is not None, "amax_o is required as an input for FP8 fused attention." + assert d_scale is not None, "d_scale is required as an input for FP8 fused attention." + assert q_scale is not None, "q_scale is required as an input for FP8 fused attention." + assert amax is not None, "amax is required as an input for FP8 fused attention." # execute kernel output_tensors = tex.fused_attn_fwd_kvpacked( @@ -620,12 +619,12 @@ def fused_attn_fwd_kvpacked( qkv_dtype, cu_seqlens_q_padded, cu_seqlens_kv_padded, - d_scale_qkv, - d_scale_s, - q_scale_s, - q_scale_o, - amax_s, - amax_o, + d_scale, + q_scale, + amax, + offset_QKV, + offset_S, + offset_O, attn_bias, rng_gen, rng_elts_per_thread, @@ -842,12 +841,12 @@ def fused_attn_fwd( attn_bias: torch.Tensor = None, cu_seqlens_q_padded: torch.Tensor = None, cu_seqlens_kv_padded: torch.Tensor = None, - d_scale_qkv: torch.Tensor = None, - d_scale_s: torch.Tensor = None, - q_scale_s: torch.Tensor = None, - q_scale_o: torch.Tensor = None, - amax_s: torch.Tensor = None, - amax_o: torch.Tensor = None, + d_scale: torch.Tensor = None, + q_scale: torch.Tensor = None, + amax: torch.Tensor = None, + offset_QKV: int = META_QKV, + offset_S: int = META_S, + offset_O: int = META_O, attn_scale: float = None, dropout: float = 0.0, fast_zero_fill: bool = True, @@ -893,18 +892,19 @@ def fused_attn_fwd( cumulative sequence offsets for Q; shape [batch_size + 1] cu_seqlens_kv_padded: torch.Tensor, default = None cumulative sequence offsets for KV; shape [batch_size + 1] - d_scale_qkv: torch.Tensor, default = None - input tensor for the dequantization of Q, K and V in FP8 computations - d_scale_s: torch.Tensor, default = None - input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T) - q_scale_s: torch.Tensor, default = None - input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T) - q_scale_o: torch.Tensor, default = None - input tensor for the quantization of O in FP8 computations - amax_s: torch.Tensor, default = None - output tensor, amax of S, used by the next iteration in FP8 computations - amax_o: torch.Tensor, default = None - output tensor, amax of O, used by the next iteration in FP8 computations + d_scale: torch.Tensor, default = None + input tensor for the dequantization of Q, K, V and S in FP8 computations, + S = Softmax(Q * K.T) + q_scale: torch.Tensor, default = None + input tensor for the quantization of S and O in FP8 computations, S = 
Softmax(Q * K.T) + amax: torch.Tensor, default = None + output tensor, amax of S and O, used by the next iteration in FP8 computations + offset_QKV: int, default = 0 + QKV offset in d_scale + offset_S: int, default = 0 + S offset in d_scale, q_scale and amax + offset_O: int, default = 0 + O offset in q_scale and amax attn_scale: float, default = None if not None, use attn_scale as the attention scale for Q*K.T BMM; if None, use 1.0/sqrt(head_dim_qk) as the default @@ -993,14 +993,9 @@ def fused_attn_fwd( max_seqlen_q * max_seqlen_q + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 ) // BACKEND_F16m512_FP8_THREADS_PER_CTA - assert ( - d_scale_qkv is not None - ), "d_scale_qkv is required as an input for FP8 fused attention." - assert d_scale_s is not None, "q_scale_s is required as an input for FP8 fused attention." - assert q_scale_s is not None, "q_scale_s is required as an input for FP8 fused attention." - assert q_scale_o is not None, "q_scale_o is required as an input for FP8 fused attention." - assert amax_s is not None, "amax_s is required as an input for FP8 fused attention." - assert amax_o is not None, "amax_o is required as an input for FP8 fused attention." + assert d_scale is not None, "d_scale is required as an input for FP8 fused attention." + assert q_scale is not None, "q_scale is required as an input for FP8 fused attention." + assert amax is not None, "amax is required as an input for FP8 fused attention." # execute kernel output_tensors = tex.fused_attn_fwd( @@ -1022,12 +1017,12 @@ def fused_attn_fwd( qkv_dtype, cu_seqlens_q_padded, cu_seqlens_kv_padded, - d_scale_qkv, - d_scale_s, - q_scale_s, - q_scale_o, - amax_s, - amax_o, + d_scale, + q_scale, + amax, + offset_QKV, + offset_S, + offset_O, attn_bias, rng_gen, rng_elts_per_thread, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 64e6909d71..dc8bb8b1a3 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -28,9 +28,8 @@ std::vector fused_attn_fwd_qkvpacked( NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const std::vector window_size, const at::Tensor cu_seqlens, const at::Tensor QKV, const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_padded, - const c10::optional descale_QKV, const c10::optional descale_S, - const c10::optional scale_S, const c10::optional scale_O, - c10::optional amax_S, c10::optional amax_O, + const c10::optional descale, const c10::optional scale, + c10::optional amax, const int offset_QKV, const int offset_S, const int offset_O, const c10::optional Bias, const c10::optional rng_gen, size_t rng_elts_per_thread); @@ -54,12 +53,10 @@ std::vector fused_attn_fwd_kvpacked( const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor KV, const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_q_padded, - const c10::optional cu_seqlens_kv_padded, - const c10::optional descale_QKV, const c10::optional descale_S, - const c10::optional scale_S, const c10::optional scale_O, - c10::optional amax_S, c10::optional amax_O, - const c10::optional Bias, const c10::optional rng_gen, - size_t rng_elts_per_thread); + const c10::optional cu_seqlens_kv_padded, const c10::optional descale, + const c10::optional scale, c10::optional amax, const int offset_QKV, + const int offset_S, const int offset_O, const c10::optional Bias, + const c10::optional rng_gen, size_t rng_elts_per_thread); std::vector 
fused_attn_bwd_kvpacked( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, @@ -83,12 +80,10 @@ std::vector fused_attn_fwd( const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor K, const at::Tensor V, const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_q_padded, - const c10::optional cu_seqlens_kv_padded, - const c10::optional descale_QKV, const c10::optional descale_S, - const c10::optional scale_S, const c10::optional scale_O, - c10::optional amax_S, c10::optional amax_O, - const c10::optional Bias, const c10::optional rng_gen, - size_t rng_elts_per_thread); + const c10::optional cu_seqlens_kv_padded, const c10::optional descale, + const c10::optional scale, c10::optional amax, const int offset_QKV, + const int offset_S, const int offset_O, const c10::optional Bias, + const c10::optional rng_gen, size_t rng_elts_per_thread); std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index 50eb7b830f..3daa4c317e 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -83,9 +83,8 @@ std::vector fused_attn_fwd_qkvpacked( NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const std::vector window_size, const at::Tensor cu_seqlens, const at::Tensor QKV, const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_padded, - const c10::optional descale_QKV, const c10::optional descale_S, - const c10::optional scale_S, const c10::optional scale_O, - c10::optional amax_S, c10::optional amax_O, + const c10::optional descale, const c10::optional scale, + c10::optional amax, const int offset_QKV, const int offset_S, const int offset_O, const c10::optional Bias, const c10::optional rng_gen, size_t rng_elts_per_thread) { using namespace transformer_engine; @@ -116,17 +115,17 @@ std::vector fused_attn_fwd_qkvpacked( } else { O.fill_(0); } - if ((!descale_QKV.has_value()) || (!descale_S.has_value()) || (!scale_S.has_value()) || - (!scale_O.has_value()) || (!amax_S.has_value()) || (!amax_O.has_value())) { - std::string err_tensors = "descale_QKV, descale_S, scale_S, scale_O, amax_S and amax_O "; - NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n")); + if ((!descale.has_value()) || (!scale.has_value()) || (!amax.has_value())) { + NVTE_ERROR("descale, scale, and amax are required for FP8 operation. 
\n"); } te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape, qkv_type, nullptr, nullptr, - descale_QKV.value().data_ptr()); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, amax_S.value().data_ptr(), - scale_S.value().data_ptr(), descale_S.value().data_ptr()); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, amax_O.value().data_ptr(), - scale_O.value().data_ptr(), nullptr); + getDataPtr(descale.value(), offset_QKV)); + te_S = makeTransformerEngineTensor( + nullptr, {0}, DType::kFloat32, getDataPtr(amax.value(), offset_S), + getDataPtr(scale.value(), offset_S), getDataPtr(descale.value(), offset_S)); + te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, + getDataPtr(amax.value(), offset_O), + getDataPtr(scale.value(), offset_O), nullptr); } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { O.fill_(0); @@ -392,12 +391,10 @@ std::vector fused_attn_fwd_kvpacked( const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor KV, const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_q_padded, - const c10::optional cu_seqlens_kv_padded, - const c10::optional descale_QKV, const c10::optional descale_S, - const c10::optional scale_S, const c10::optional scale_O, - c10::optional amax_S, c10::optional amax_O, - const c10::optional Bias, const c10::optional rng_gen, - size_t rng_elts_per_thread) { + const c10::optional cu_seqlens_kv_padded, const c10::optional descale, + const c10::optional scale, c10::optional amax, const int offset_QKV, + const int offset_S, const int offset_O, const c10::optional Bias, + const c10::optional rng_gen, size_t rng_elts_per_thread) { using namespace transformer_engine; auto q_sizes = Q.sizes().vec(); @@ -423,19 +420,19 @@ std::vector fused_attn_fwd_kvpacked( } else { O.fill_(0); } - if ((!descale_QKV.has_value()) || (!descale_S.has_value()) || (!scale_S.has_value()) || - (!scale_O.has_value()) || (!amax_S.has_value()) || (!amax_O.has_value())) { - std::string err_tensors = "descale_QKV, descale_S, scale_S, scale_O, amax_S and amax_O "; - NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n")); + if ((!descale.has_value()) || (!scale.has_value()) || (!amax.has_value())) { + NVTE_ERROR("descale, scale, and amax are required for FP8 operation. 
\n"); } te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, - descale_QKV.value().data_ptr()); + getDataPtr(descale.value(), offset_QKV)); te_KV = makeTransformerEngineTensor(KV.data_ptr(), kv_shape, qkv_type, nullptr, nullptr, - descale_QKV.value().data_ptr()); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, amax_S.value().data_ptr(), - scale_S.value().data_ptr(), descale_S.value().data_ptr()); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, amax_O.value().data_ptr(), - scale_O.value().data_ptr(), nullptr); + getDataPtr(descale.value(), offset_QKV)); + te_S = makeTransformerEngineTensor( + nullptr, {0}, DType::kFloat32, getDataPtr(amax.value(), offset_S), + getDataPtr(scale.value(), offset_S), getDataPtr(descale.value(), offset_S)); + te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, + getDataPtr(amax.value(), offset_O), + getDataPtr(scale.value(), offset_O), nullptr); } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { O.fill_(0); @@ -746,12 +743,10 @@ std::vector fused_attn_fwd( const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor K, const at::Tensor V, const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_q_padded, - const c10::optional cu_seqlens_kv_padded, - const c10::optional descale_QKV, const c10::optional descale_S, - const c10::optional scale_S, const c10::optional scale_O, - c10::optional amax_S, c10::optional amax_O, - const c10::optional Bias, const c10::optional rng_gen, - size_t rng_elts_per_thread) { + const c10::optional cu_seqlens_kv_padded, const c10::optional descale, + const c10::optional scale, c10::optional amax, const int offset_QKV, + const int offset_S, const int offset_O, const c10::optional Bias, + const c10::optional rng_gen, size_t rng_elts_per_thread) { using namespace transformer_engine; auto q_sizes = Q.sizes().vec(); @@ -782,21 +777,21 @@ std::vector fused_attn_fwd( } else { O.fill_(0); } - if ((!descale_QKV.has_value()) || (!descale_S.has_value()) || (!scale_S.has_value()) || - (!scale_O.has_value()) || (!amax_S.has_value()) || (!amax_O.has_value())) { - std::string err_tensors = "descale_QKV, descale_S, scale_S, scale_O, amax_S and amax_O "; - NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n")); + if ((!descale.has_value()) || (!scale.has_value()) || (!amax.has_value())) { + NVTE_ERROR("descale, scale, and amax are required for FP8 operation. 
\n"); } te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, - descale_QKV.value().data_ptr()); + getDataPtr(descale.value(), offset_QKV)); te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, qkv_type, nullptr, nullptr, - descale_QKV.value().data_ptr()); + getDataPtr(descale.value(), offset_QKV)); te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, qkv_type, nullptr, nullptr, - descale_QKV.value().data_ptr()); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, amax_S.value().data_ptr(), - scale_S.value().data_ptr(), descale_S.value().data_ptr()); - te_O = makeTransformerEngineTensor(O.data_ptr(), o_shape, qkv_type, amax_O.value().data_ptr(), - scale_O.value().data_ptr(), nullptr); + getDataPtr(descale.value(), offset_QKV)); + te_S = makeTransformerEngineTensor( + nullptr, {0}, DType::kFloat32, getDataPtr(amax.value(), offset_S), + getDataPtr(scale.value(), offset_S), getDataPtr(descale.value(), offset_S)); + te_O = makeTransformerEngineTensor(O.data_ptr(), o_shape, qkv_type, + getDataPtr(amax.value(), offset_O), + getDataPtr(scale.value(), offset_O), nullptr); } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { O.fill_(0); From 0d2ff345f6a6c59f410696e6a790b657eddf4ff7 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Tue, 13 Aug 2024 22:33:56 -0700 Subject: [PATCH 04/17] rename is_first_module_in_mha to fp8_output Signed-off-by: Xin Yao --- transformer_engine/pytorch/attention.py | 10 +++++----- .../pytorch/module/layernorm_linear.py | 12 ++++++------ transformer_engine/pytorch/module/linear.py | 14 +++++++------- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 25cbbf5fbb..ba7b687b49 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -6593,7 +6593,7 @@ def forward( layernorm_qkv_outputs = self.layernorm_qkv( hidden_states, is_first_microbatch=is_first_microbatch, - is_first_module_in_mha=rotary_pos_emb is None, # specific to FP8 MHA + fp8_output=rotary_pos_emb is None, # specific to FP8 MHA ) if self.return_layernorm_output: mixed_x_layer, layernorm_output = layernorm_qkv_outputs @@ -6603,7 +6603,7 @@ def forward( mixed_x_layer = self.qkv( hidden_states, is_first_microbatch=is_first_microbatch, - is_first_module_in_mha=rotary_pos_emb is None, # specific to FP8 MHA + fp8_output=rotary_pos_emb is None, # specific to FP8 MHA ) num_queries_per_key_value = ( @@ -6659,7 +6659,7 @@ def forward( mixed_kv_layer = self.key_value( encoder_output, is_first_microbatch=is_first_microbatch, - is_first_module_in_mha=rotary_pos_emb is None, # specific to FP8 MHA + fp8_output=rotary_pos_emb is None, # specific to FP8 MHA ) if self.qkv_weight_interleaved: @@ -6709,7 +6709,7 @@ def forward( layernorm_query_outputs = self.layernorm_query( hidden_states, is_first_microbatch=is_first_microbatch, - is_first_module_in_mha=rotary_pos_emb is None, # specific to FP8 MHA + fp8_output=rotary_pos_emb is None, # specific to FP8 MHA ) if self.return_layernorm_output: query_layer, layernorm_output = layernorm_query_outputs @@ -6719,7 +6719,7 @@ def forward( query_layer = self.query_layer( hidden_states, is_first_microbatch=is_first_microbatch, - is_first_module_in_mha=rotary_pos_emb is None, # specific to FP8 MHA + fp8_output=rotary_pos_emb is None, # specific to FP8 MHA ) # [sq, b, hp] --> [sq, b, np, hn] diff --git 
a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 37eb1d84db..06acdb032e 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -90,7 +90,7 @@ def forward( ub_overlap_rs_dgrad: bool, ub_overlap_ag: bool, ub_name: str, - is_first_module_in_mha: bool, + fp8_output: bool, fsdp_group: Union[dist_group_type, None], ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # Make sure input dimensions are compatible @@ -200,7 +200,7 @@ def forward( assert isinstance(weight_fp8, Float8Tensor) - if is_first_module_in_mha: + if fp8_output: out_index, meta_tensor, output_te_dtype, output_dtype = ( tex.FP8FwdTensors.GEMM1_OUTPUT, fp8_meta["scaling_fwd"], @@ -745,7 +745,7 @@ def backward( None, # ub_overlap_rs_dgrad None, # ub_overlap_ag None, # ub_name - None, # is_first_module_in_mha + None, # fp8_output None, # fsdp_group ) @@ -1098,7 +1098,7 @@ def forward( self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None, - is_first_module_in_mha: Optional[bool] = False, + fp8_output: Optional[bool] = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: """ Apply layer normalization to the input followed by a linear transformation. @@ -1128,7 +1128,7 @@ def forward( with self.prepare_forward(inp, is_first_microbatch) as inp: - is_first_module_in_mha = is_first_module_in_mha and self.fp8_meta["recipe"].fp8_mha + fp8_output = fp8_output and self.fp8_meta["recipe"].fp8_mha # Get concatenated weight and bias tensors unfused_weights = [getattr(self, name) for name in self.weight_names] @@ -1228,7 +1228,7 @@ def forward( self.ub_overlap_rs_dgrad, self.ub_overlap_ag, self.ub_name, - is_first_module_in_mha, + fp8_output, self.fsdp_group, ) out = fwd_fn(*args) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 68d333262d..b9fd92f767 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -81,7 +81,7 @@ def forward( ub_overlap_rs: bool, ub_overlap_ag: bool, ub_name: str, - is_first_module_in_mha: bool, + fp8_output: bool, fsdp_group: Union[dist_group_type, None], ) -> torch.Tensor: is_input_fp8 = isinstance(inp, Float8Tensor) @@ -153,7 +153,7 @@ def forward( assert isinstance(weight_fp8, Float8Tensor) - if is_first_module_in_mha: + if fp8_output: proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = ( tex.FP8FwdTensors.GEMM1_OUTPUT, fp8_meta["scaling_fwd"], @@ -222,7 +222,7 @@ def forward( fp8_meta_tensor=meta_tensor, D_dtype=proj_out_tetype, ) - if is_first_module_in_mha: + if fp8_output: out = Float8Tensor( data=out, fp8_meta=fp8_meta, @@ -621,7 +621,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], None, # ub_overlap_rs None, # ub_overlap_ag None, # ub_name - None, # is_first_module_in_mha + None, # fp8_output None, # fsdp_group ) @@ -899,7 +899,7 @@ def forward( self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None, - is_first_module_in_mha: Optional[bool] = False, + fp8_output: Optional[bool] = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: """ Apply the linear transformation to the input. 
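(Illustrative sketch, not part of the patch: a caller-side view of the renamed flag. It assumes a TE build with this series applied and FP8-capable hardware; `fp8_mha` as a `DelayedScaling` field is inferred from the `fp8_meta["recipe"].fp8_mha` checks elsewhere in these diffs, and the `fp8_output=` keyword is the one introduced by this commit.)

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# fp8_mha is assumed to be a recipe field, per the recipe checks in this series.
recipe = DelayedScaling(fp8_format=Format.HYBRID, fp8_mha=True)
qkv_proj = te.Linear(1024, 3 * 1024, bias=True, params_dtype=torch.bfloat16)
x = torch.randn(128, 2, 1024, device="cuda", dtype=torch.bfloat16)

with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    # fp8_output=True (formerly is_first_module_in_mha=True) asks the GEMM to
    # return a Float8Tensor so the fused FP8 attention kernel can consume it
    # directly (the no-RoPE path in the MHA changes above).
    qkv = qkv_proj(x, fp8_output=True)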
@@ -933,7 +933,7 @@ def forward( allow_non_contiguous=isinstance(inp, Float8Tensor), ) as inp: - is_first_module_in_mha = is_first_module_in_mha and self.fp8_meta["recipe"].fp8_mha + fp8_output = fp8_output and self.fp8_meta["recipe"].fp8_mha # Get concatenated weight and bias tensors unfused_weights = [getattr(self, name) for name in self.weight_names] @@ -1019,7 +1019,7 @@ def forward( self.ub_overlap_rs, self.ub_overlap_ag, self.ub_name, - is_first_module_in_mha, + fp8_output, self.fsdp_group, ) out = linear_fn(*args) From 0e837c3580036204b685c0cfdbec24ace5f96e72 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Wed, 14 Aug 2024 20:23:59 -0700 Subject: [PATCH 05/17] resolve comments Signed-off-by: Xin Yao --- transformer_engine/pytorch/attention.py | 22 +++++++++++++------ .../pytorch/module/layernorm_linear.py | 2 -- transformer_engine/pytorch/module/linear.py | 2 -- 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index ba7b687b49..9989d41faf 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -3934,7 +3934,7 @@ def forward( ): is_input_fp8 = False if fp8: - assert type(q) == type(kv), "q and kv must have the same type." + assert isinstance(kv, q.__class__), "q and kv must have the same type." is_input_fp8 = isinstance(q, Float8Tensor) if is_input_fp8: fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv @@ -4368,7 +4368,9 @@ def forward( if fp8: fused_attention_backend = FusedAttnBackend["FP8"] fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - assert type(q) == type(k) == type(v), "q, k, and v must have the same type." + assert isinstance(k, q.__class__) and isinstance( + v, q.__class__ + ), "q, k, and v must have the same type." 
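(A short aside on the assert above, illustrative only: `isinstance(k, q.__class__)` accepts instances of subclasses of q's class, while the old `type(q) == type(k)` required an exact class match. A toy stand-in:)

class Base: ...
class Child(Base): ...

a, b = Base(), Child()
assert isinstance(b, a.__class__)   # passes: Child is a Base
assert type(a) != type(b)           # the exact-type comparison would have rejected it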
is_input_fp8 = isinstance(q, Float8Tensor) if is_input_fp8: fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv @@ -6593,7 +6595,9 @@ def forward( layernorm_qkv_outputs = self.layernorm_qkv( hidden_states, is_first_microbatch=is_first_microbatch, - fp8_output=rotary_pos_emb is None, # specific to FP8 MHA + fp8_output=( + self.layernorm_qkv.fp8_meta["recipe"].fp8_mha and rotary_pos_emb is None + ), ) if self.return_layernorm_output: mixed_x_layer, layernorm_output = layernorm_qkv_outputs @@ -6603,7 +6607,7 @@ def forward( mixed_x_layer = self.qkv( hidden_states, is_first_microbatch=is_first_microbatch, - fp8_output=rotary_pos_emb is None, # specific to FP8 MHA + fp8_output=self.qkv.fp8_meta["recipe"].fp8_mha and rotary_pos_emb is None, ) num_queries_per_key_value = ( @@ -6659,7 +6663,7 @@ def forward( mixed_kv_layer = self.key_value( encoder_output, is_first_microbatch=is_first_microbatch, - fp8_output=rotary_pos_emb is None, # specific to FP8 MHA + fp8_output=self.key_value.fp8_meta["recipe"].fp8_mha and rotary_pos_emb is None, ) if self.qkv_weight_interleaved: @@ -6709,7 +6713,9 @@ def forward( layernorm_query_outputs = self.layernorm_query( hidden_states, is_first_microbatch=is_first_microbatch, - fp8_output=rotary_pos_emb is None, # specific to FP8 MHA + fp8_output=( + self.layernorm_query.fp8_meta["recipe"].fp8_mha and rotary_pos_emb is None + ), ) if self.return_layernorm_output: query_layer, layernorm_output = layernorm_query_outputs @@ -6719,7 +6725,9 @@ def forward( query_layer = self.query_layer( hidden_states, is_first_microbatch=is_first_microbatch, - fp8_output=rotary_pos_emb is None, # specific to FP8 MHA + fp8_output=( + self.query_layer.fp8_meta["recipe"].fp8_mha and rotary_pos_emb is None + ), ) # [sq, b, hp] --> [sq, b, np, hn] diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 06acdb032e..f8b36acb38 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -1128,8 +1128,6 @@ def forward( with self.prepare_forward(inp, is_first_microbatch) as inp: - fp8_output = fp8_output and self.fp8_meta["recipe"].fp8_mha - # Get concatenated weight and bias tensors unfused_weights = [getattr(self, name) for name in self.weight_names] if any(isinstance(w, Float8Tensor) for w in unfused_weights): diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index b9fd92f767..d6fb6061f1 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -933,8 +933,6 @@ def forward( allow_non_contiguous=isinstance(inp, Float8Tensor), ) as inp: - fp8_output = fp8_output and self.fp8_meta["recipe"].fp8_mha - # Get concatenated weight and bias tensors unfused_weights = [getattr(self, name) for name in self.weight_names] if any(isinstance(w, Float8Tensor) for w in unfused_weights): From 33c3ed66759d0b5939a8d5dc4bb445ce8b865519 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Wed, 14 Aug 2024 22:09:51 -0700 Subject: [PATCH 06/17] resolve comments Signed-off-by: Xin Yao --- transformer_engine/pytorch/attention.py | 108 +++++--- .../pytorch/cpp_extensions/fused_attn.py | 252 ++++++++++++------ transformer_engine/pytorch/csrc/extensions.h | 31 ++- .../pytorch/csrc/extensions/attention.cu | 96 ++++--- 4 files changed, 319 insertions(+), 168 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 
9989d41faf..13a6c1c374 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -3588,12 +3588,18 @@ def forward( fused_attention_backend, attn_bias, cu_seqlens_padded, - fp8_meta["scaling_fwd"].scale_inv, - fp8_meta["scaling_fwd"].scale, - fp8_meta["scaling_fwd"].amax_history, - META_QKV, - META_S, - META_O, + fp8_meta["scaling_fwd"].scale_inv, # d_scale_qkv + META_QKV, # d_scale_qkv_offset + fp8_meta["scaling_fwd"].scale_inv, # d_scale_s + META_S, # d_scale_s_offset + fp8_meta["scaling_fwd"].scale, # q_scale_s + META_S, # q_scale_s_offset + fp8_meta["scaling_fwd"].scale, # q_scale_o + META_O, # q_scale_o_offset + fp8_meta["scaling_fwd"].amax_history, # amax_s + META_S, # amax_s_offset + fp8_meta["scaling_fwd"].amax_history, # amax_o + META_O, # amax_o_offset attn_scale, dropout_p, fast_zero_fill, @@ -3653,12 +3659,18 @@ def forward( fused_attention_backend, attn_bias, cu_seqlens_padded, - None, - None, - None, - 0, - 0, - 0, + None, # d_scale_qkv + 0, # d_scale_qkv_offset + None, # d_scale_s + 0, # d_scale_s_offset + None, # q_scale_s + 0, # q_scale_s_offset + None, # q_scale_o + 0, # q_scale_o_offset + None, # amax_s + 0, # amax_s_offset + None, # amax_o + 0, # amax_o_offset attn_scale, dropout_p, fast_zero_fill, @@ -3969,12 +3981,18 @@ def forward( attn_bias, cu_seqlens_q_padded, cu_seqlens_kv_padded, - fp8_meta["scaling_fwd"].scale_inv, - fp8_meta["scaling_fwd"].scale, - fp8_meta["scaling_fwd"].amax_history, - META_QKV, - META_S, - META_O, + fp8_meta["scaling_fwd"].scale_inv, # d_scale_qkv + META_QKV, # d_scale_qkv_offset + fp8_meta["scaling_fwd"].scale_inv, # d_scale_s + META_S, # d_scale_s_offset + fp8_meta["scaling_fwd"].scale, # q_scale_s + META_S, # q_scale_s_offset + fp8_meta["scaling_fwd"].scale, # q_scale_o + META_O, # q_scale_o_offset + fp8_meta["scaling_fwd"].amax_history, # amax_s + META_S, # amax_s_offset + fp8_meta["scaling_fwd"].amax_history, # amax_o + META_O, # amax_o_offset attn_scale, dropout_p, fast_zero_fill, @@ -4042,12 +4060,18 @@ def forward( attn_bias, cu_seqlens_q_padded, cu_seqlens_kv_padded, - None, - None, - None, - 0, - 0, - 0, + None, # d_scale_qkv + 0, # d_scale_qkv_offset + None, # d_scale_s + 0, # d_scale_s_offset + None, # q_scale_s + 0, # q_scale_s_offset + None, # q_scale_o + 0, # q_scale_o_offset + None, # amax_s + 0, # amax_s_offset + None, # amax_o + 0, # amax_o_offset attn_scale, dropout_p, fast_zero_fill, @@ -4423,12 +4447,18 @@ def forward( attn_bias, cu_seqlens_q_padded, cu_seqlens_kv_padded, - fp8_meta["scaling_fwd"].scale_inv, - fp8_meta["scaling_fwd"].scale, - fp8_meta["scaling_fwd"].amax_history, - META_QKV, - META_S, - META_O, + fp8_meta["scaling_fwd"].scale_inv, # d_scale_qkv + META_QKV, # d_scale_qkv_offset + fp8_meta["scaling_fwd"].scale_inv, # d_scale_s + META_S, # d_scale_s_offset + fp8_meta["scaling_fwd"].scale, # q_scale_s + META_S, # q_scale_s_offset + fp8_meta["scaling_fwd"].scale, # q_scale_o + META_O, # q_scale_o_offset + fp8_meta["scaling_fwd"].amax_history, # amax_s + META_S, # amax_s_offset + fp8_meta["scaling_fwd"].amax_history, # amax_o + META_O, # amax_o_offset attn_scale, dropout_p, fast_zero_fill, @@ -4546,12 +4576,18 @@ def forward( attn_bias, cu_seqlens_q_padded, cu_seqlens_kv_padded, - None, - None, - None, - 0, - 0, - 0, + None, # d_scale_qkv + 0, # d_scale_qkv_offset + None, # d_scale_s + 0, # d_scale_s_offset + None, # q_scale_s + 0, # q_scale_s_offset + None, # q_scale_o + 0, # q_scale_o_offset + None, # amax_s + 0, # amax_s_offset + None, # amax_o + 0, # 
amax_o_offset attn_scale, dropout_p, fast_zero_fill, diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 6c4e01ce45..ceb90b619e 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -95,12 +95,18 @@ def fused_attn_fwd_qkvpacked( fused_attention_backend: tex.NVTE_Fused_Attn_Backend, attn_bias: torch.Tensor = None, cu_seqlens_padded: torch.Tensor = None, - d_scale: torch.Tensor = None, - q_scale: torch.Tensor = None, - amax: torch.Tensor = None, - offset_QKV: int = META_QKV, - offset_S: int = META_S, - offset_O: int = META_O, + d_scale_qkv: torch.Tensor = None, + d_scale_qkv_offset: int = META_QKV, + d_scale_s: torch.Tensor = None, + d_scale_s_offset: int = META_S, + q_scale_s: torch.Tensor = None, + q_scale_s_offset: int = META_S, + q_scale_o: torch.Tensor = None, + q_scale_o_offset: int = META_O, + amax_s: torch.Tensor = None, + amax_s_offset: int = META_S, + amax_o: torch.Tensor = None, + amax_o_offset: int = META_O, attn_scale: float = None, dropout: float = 0.0, fast_zero_fill: bool = True, @@ -133,19 +139,30 @@ def fused_attn_fwd_qkvpacked( shape [1, num_heads, max_seqlen, max_seqlen], same data type as qkv cu_seqlens_padded: torch.Tensor, default = None cumulative sequence offsets for QKV; shape [batch_size + 1] - d_scale: torch.Tensor, default = None - input tensor for the dequantization of Q, K, V and S in FP8 computations, - S = Softmax(Q * K.T) - q_scale: torch.Tensor, default = None - input tensor for the quantization of S and O in FP8 computations, S = Softmax(Q * K.T) - amax: torch.Tensor, default = None - output tensor, amax of S and O, used by the next iteration in FP8 computations - offset_QKV: int, default = 0 - QKV offset in d_scale - offset_S: int, default = 0 - S offset in d_scale, q_scale and amax - offset_O: int, default = 0 - O offset in q_scale and amax + d_scale_qkv: torch.Tensor, default = None + input tensor for the dequantization of QKV in FP8 computations + d_scale_qkv_offset: int, default = META_QKV + offset in d_scale_qkv for QKV + d_scale_s: torch.Tensor, default = None + input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T) + d_scale_s_offset: int, default = META_S + offset in d_scale_s for S + q_scale_s: torch.Tensor, default = None + input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T) + q_scale_s_offset: int, default = META_S + offset in q_scale_s for S + q_scale_o: torch.Tensor, default = None + input tensor for the quantization of O in FP8 computations + q_scale_o_offset: int, default = META_O + offset in q_scale_o for O + amax_s: torch.Tensor, default = None + output tensor, amax of S, used by the next iteration in FP8 computations + amax_s_offset: int, default = META_S + offset in amax_s for S + amax_o: torch.Tensor, default = None + output tensor, amax of O, used by the next iteration in FP8 computations + amax_o_offset: int, default = META_O + offset in amax_o for O attn_scale: float, default = None if not None, use attn_scale as the attention scale for Q*K.T BMM; if None, use 1.0/sqrt(head_dim_qk) as the default @@ -231,9 +248,14 @@ def fused_attn_fwd_qkvpacked( max_seqlen * max_seqlen + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 ) // BACKEND_F16m512_FP8_THREADS_PER_CTA - assert d_scale is not None, "d_scale is required as an input for FP8 fused attention." 
- assert q_scale is not None, "q_scale is required as an input for FP8 fused attention." - assert amax is not None, "amax is required as an input for FP8 fused attention." + assert ( + d_scale_qkv is not None + ), "d_scale_qkv is required as an input for FP8 fused attention." + assert d_scale_s is not None, "q_scale_s is required as an input for FP8 fused attention." + assert q_scale_s is not None, "q_scale_s is required as an input for FP8 fused attention." + assert q_scale_o is not None, "q_scale_o is required as an input for FP8 fused attention." + assert amax_s is not None, "amax_s is required as an input for FP8 fused attention." + assert amax_o is not None, "amax_o is required as an input for FP8 fused attention." # execute kernel output_tensors = tex.fused_attn_fwd_qkvpacked( @@ -250,12 +272,18 @@ def fused_attn_fwd_qkvpacked( qkv, qkv_dtype, cu_seqlens_padded, - d_scale, - q_scale, - amax, - offset_QKV, - offset_S, - offset_O, + d_scale_qkv, + d_scale_qkv_offset, + d_scale_s, + d_scale_s_offset, + q_scale_s, + q_scale_s_offset, + q_scale_o, + q_scale_o_offset, + amax_s, + amax_s_offset, + amax_o, + amax_o_offset, attn_bias, rng_gen, rng_elts_per_thread, @@ -450,12 +478,18 @@ def fused_attn_fwd_kvpacked( attn_bias: torch.Tensor = None, cu_seqlens_q_padded: torch.Tensor = None, cu_seqlens_kv_padded: torch.Tensor = None, - d_scale: torch.Tensor = None, - q_scale: torch.Tensor = None, - amax: torch.Tensor = None, - offset_QKV: int = META_QKV, - offset_S: int = META_S, - offset_O: int = META_O, + d_scale_qkv: torch.Tensor = None, + d_scale_qkv_offset: int = META_QKV, + d_scale_s: torch.Tensor = None, + d_scale_s_offset: int = META_S, + q_scale_s: torch.Tensor = None, + q_scale_s_offset: int = META_S, + q_scale_o: torch.Tensor = None, + q_scale_o_offset: int = META_O, + amax_s: torch.Tensor = None, + amax_s_offset: int = META_S, + amax_o: torch.Tensor = None, + amax_o_offset: int = META_O, attn_scale: float = None, dropout: float = 0.0, fast_zero_fill: bool = True, @@ -497,19 +531,30 @@ def fused_attn_fwd_kvpacked( cumulative sequence offsets for Q; shape [batch_size + 1] cu_seqlens_kv_padded: torch.Tensor, default = None cumulative sequence offsets for KV; shape [batch_size + 1] - d_scale: torch.Tensor, default = None - input tensor for the dequantization of Q, K, V and S in FP8 computations, - S = Softmax(Q * K.T) - q_scale: torch.Tensor, default = None - input tensor for the quantization of S and O in FP8 computations, S = Softmax(Q * K.T) - amax: torch.Tensor, default = None - output tensor, amax of S and O, used by the next iteration in FP8 computations - offset_QKV: int, default = 0 - QKV offset in d_scale - offset_S: int, default = 0 - S offset in d_scale, q_scale and amax - offset_O: int, default = 0 - O offset in q_scale and amax + d_scale_qkv: torch.Tensor, default = None + input tensor for the dequantization of QKV in FP8 computations + d_scale_qkv_offset: int, default = META_QKV + offset in d_scale_qkv for QKV + d_scale_s: torch.Tensor, default = None + input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T) + d_scale_s_offset: int, default = META_S + offset in d_scale_s for S + q_scale_s: torch.Tensor, default = None + input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T) + q_scale_s_offset: int, default = META_S + offset in q_scale_s for S + q_scale_o: torch.Tensor, default = None + input tensor for the quantization of O in FP8 computations + q_scale_o_offset: int, default = META_O + offset in q_scale_o for O + amax_s: 
torch.Tensor, default = None + output tensor, amax of S, used by the next iteration in FP8 computations + amax_s_offset: int, default = META_S + offset in amax_s for S + amax_o: torch.Tensor, default = None + output tensor, amax of O, used by the next iteration in FP8 computations + amax_o_offset: int, default = META_O + offset in amax_o for O attn_scale: float, default = None if not None, use attn_scale as the attention scale for Q*K.T BMM; if None, use 1.0/sqrt(head_dim_qk) as the default @@ -596,9 +641,14 @@ def fused_attn_fwd_kvpacked( max_seqlen_q * max_seqlen_q + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 ) // BACKEND_F16m512_FP8_THREADS_PER_CTA - assert d_scale is not None, "d_scale is required as an input for FP8 fused attention." - assert q_scale is not None, "q_scale is required as an input for FP8 fused attention." - assert amax is not None, "amax is required as an input for FP8 fused attention." + assert ( + d_scale_qkv is not None + ), "d_scale_qkv is required as an input for FP8 fused attention." + assert d_scale_s is not None, "q_scale_s is required as an input for FP8 fused attention." + assert q_scale_s is not None, "q_scale_s is required as an input for FP8 fused attention." + assert q_scale_o is not None, "q_scale_o is required as an input for FP8 fused attention." + assert amax_s is not None, "amax_s is required as an input for FP8 fused attention." + assert amax_o is not None, "amax_o is required as an input for FP8 fused attention." # execute kernel output_tensors = tex.fused_attn_fwd_kvpacked( @@ -619,12 +669,18 @@ def fused_attn_fwd_kvpacked( qkv_dtype, cu_seqlens_q_padded, cu_seqlens_kv_padded, - d_scale, - q_scale, - amax, - offset_QKV, - offset_S, - offset_O, + d_scale_qkv, + d_scale_qkv_offset, + d_scale_s, + d_scale_s_offset, + q_scale_s, + q_scale_s_offset, + q_scale_o, + q_scale_o_offset, + amax_s, + amax_s_offset, + amax_o, + amax_o_offset, attn_bias, rng_gen, rng_elts_per_thread, @@ -841,12 +897,18 @@ def fused_attn_fwd( attn_bias: torch.Tensor = None, cu_seqlens_q_padded: torch.Tensor = None, cu_seqlens_kv_padded: torch.Tensor = None, - d_scale: torch.Tensor = None, - q_scale: torch.Tensor = None, - amax: torch.Tensor = None, - offset_QKV: int = META_QKV, - offset_S: int = META_S, - offset_O: int = META_O, + d_scale_qkv: torch.Tensor = None, + d_scale_qkv_offset: int = META_QKV, + d_scale_s: torch.Tensor = None, + d_scale_s_offset: int = META_S, + q_scale_s: torch.Tensor = None, + q_scale_s_offset: int = META_S, + q_scale_o: torch.Tensor = None, + q_scale_o_offset: int = META_O, + amax_s: torch.Tensor = None, + amax_s_offset: int = META_S, + amax_o: torch.Tensor = None, + amax_o_offset: int = META_O, attn_scale: float = None, dropout: float = 0.0, fast_zero_fill: bool = True, @@ -892,19 +954,30 @@ def fused_attn_fwd( cumulative sequence offsets for Q; shape [batch_size + 1] cu_seqlens_kv_padded: torch.Tensor, default = None cumulative sequence offsets for KV; shape [batch_size + 1] - d_scale: torch.Tensor, default = None - input tensor for the dequantization of Q, K, V and S in FP8 computations, - S = Softmax(Q * K.T) - q_scale: torch.Tensor, default = None - input tensor for the quantization of S and O in FP8 computations, S = Softmax(Q * K.T) - amax: torch.Tensor, default = None - output tensor, amax of S and O, used by the next iteration in FP8 computations - offset_QKV: int, default = 0 - QKV offset in d_scale - offset_S: int, default = 0 - S offset in d_scale, q_scale and amax - offset_O: int, default = 0 - O offset in q_scale and amax + d_scale_qkv: 
torch.Tensor, default = None + input tensor for the dequantization of QKV in FP8 computations + d_scale_qkv_offset: int, default = META_QKV + offset in d_scale_qkv for QKV + d_scale_s: torch.Tensor, default = None + input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T) + d_scale_s_offset: int, default = META_S + offset in d_scale_s for S + q_scale_s: torch.Tensor, default = None + input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T) + q_scale_s_offset: int, default = META_S + offset in q_scale_s for S + q_scale_o: torch.Tensor, default = None + input tensor for the quantization of O in FP8 computations + q_scale_o_offset: int, default = META_O + offset in q_scale_o for O + amax_s: torch.Tensor, default = None + output tensor, amax of S, used by the next iteration in FP8 computations + amax_s_offset: int, default = META_S + offset in amax_s for S + amax_o: torch.Tensor, default = None + output tensor, amax of O, used by the next iteration in FP8 computations + amax_o_offset: int, default = META_O + offset in amax_o for O attn_scale: float, default = None if not None, use attn_scale as the attention scale for Q*K.T BMM; if None, use 1.0/sqrt(head_dim_qk) as the default @@ -993,9 +1066,14 @@ def fused_attn_fwd( max_seqlen_q * max_seqlen_q + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 ) // BACKEND_F16m512_FP8_THREADS_PER_CTA - assert d_scale is not None, "d_scale is required as an input for FP8 fused attention." - assert q_scale is not None, "q_scale is required as an input for FP8 fused attention." - assert amax is not None, "amax is required as an input for FP8 fused attention." + assert ( + d_scale_qkv is not None + ), "d_scale_qkv is required as an input for FP8 fused attention." + assert d_scale_s is not None, "q_scale_s is required as an input for FP8 fused attention." + assert q_scale_s is not None, "q_scale_s is required as an input for FP8 fused attention." + assert q_scale_o is not None, "q_scale_o is required as an input for FP8 fused attention." + assert amax_s is not None, "amax_s is required as an input for FP8 fused attention." + assert amax_o is not None, "amax_o is required as an input for FP8 fused attention." 
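# Illustrative sketch of the new calling convention (assumes the usual
# TransformerEngine fp8_meta layout and the META_QKV/META_S/META_O constants):
# each FP8 meta tensor is now passed whole and paired with an explicit offset,
# instead of one shared d_scale/q_scale/amax triple. The pairings below mirror
# the attention.py hunks earlier in this patch.
scaling_fwd = fp8_meta["scaling_fwd"]
fp8_meta_kwargs = dict(
    d_scale_qkv=scaling_fwd.scale_inv, d_scale_qkv_offset=META_QKV,
    d_scale_s=scaling_fwd.scale_inv,   d_scale_s_offset=META_S,
    q_scale_s=scaling_fwd.scale,       q_scale_s_offset=META_S,
    q_scale_o=scaling_fwd.scale,       q_scale_o_offset=META_O,
    amax_s=scaling_fwd.amax_history,   amax_s_offset=META_S,
    amax_o=scaling_fwd.amax_history,   amax_o_offset=META_O,
)
# e.g. fused_attn_fwd(..., **fp8_meta_kwargs) in place of the old positional scale/amax arguments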
# execute kernel output_tensors = tex.fused_attn_fwd( @@ -1017,12 +1095,18 @@ def fused_attn_fwd( qkv_dtype, cu_seqlens_q_padded, cu_seqlens_kv_padded, - d_scale, - q_scale, - amax, - offset_QKV, - offset_S, - offset_O, + d_scale_qkv, + d_scale_qkv_offset, + d_scale_s, + d_scale_s_offset, + q_scale_s, + q_scale_s_offset, + q_scale_o, + q_scale_o_offset, + amax_s, + amax_s_offset, + amax_o, + amax_o_offset, attn_bias, rng_gen, rng_elts_per_thread, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index dc8bb8b1a3..a532e194b5 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -28,10 +28,13 @@ std::vector fused_attn_fwd_qkvpacked( NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const std::vector window_size, const at::Tensor cu_seqlens, const at::Tensor QKV, const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_padded, - const c10::optional descale, const c10::optional scale, - c10::optional amax, const int offset_QKV, const int offset_S, const int offset_O, - const c10::optional Bias, const c10::optional rng_gen, - size_t rng_elts_per_thread); + const c10::optional descale_QKV, const int descale_QKV_offset, + const c10::optional descale_S, const int descale_S_offset, + const c10::optional scale_S, const int scale_S_offset, + const c10::optional scale_O, const int scale_O_offset, + c10::optional amax_S, const int amax_S_offset, c10::optional amax_O, + const int amax_O_offset, const c10::optional Bias, + const c10::optional rng_gen, size_t rng_elts_per_thread); std::vector fused_attn_bwd_qkvpacked( size_t max_seqlen, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, @@ -53,9 +56,13 @@ std::vector fused_attn_fwd_kvpacked( const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor KV, const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_q_padded, - const c10::optional cu_seqlens_kv_padded, const c10::optional descale, - const c10::optional scale, c10::optional amax, const int offset_QKV, - const int offset_S, const int offset_O, const c10::optional Bias, + const c10::optional cu_seqlens_kv_padded, + const c10::optional descale_QKV, const int descale_QKV_offset, + const c10::optional descale_S, const int descale_S_offset, + const c10::optional scale_S, const int scale_S_offset, + const c10::optional scale_O, const int scale_O_offset, + c10::optional amax_S, const int amax_S_offset, c10::optional amax_O, + const int amax_O_offset, const c10::optional Bias, const c10::optional rng_gen, size_t rng_elts_per_thread); std::vector fused_attn_bwd_kvpacked( @@ -80,9 +87,13 @@ std::vector fused_attn_fwd( const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor K, const at::Tensor V, const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_q_padded, - const c10::optional cu_seqlens_kv_padded, const c10::optional descale, - const c10::optional scale, c10::optional amax, const int offset_QKV, - const int offset_S, const int offset_O, const c10::optional Bias, + const c10::optional cu_seqlens_kv_padded, + const c10::optional descale_QKV, const int descale_QKV_offset, + const c10::optional descale_S, const int descale_S_offset, + const c10::optional scale_S, const int scale_S_offset, + const c10::optional scale_O, const int scale_O_offset, + c10::optional amax_S, const int amax_S_offset, 
c10::optional amax_O, + const int amax_O_offset, const c10::optional Bias, const c10::optional rng_gen, size_t rng_elts_per_thread); std::vector fused_attn_bwd( diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index 3daa4c317e..fb1fc97a33 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -83,10 +83,13 @@ std::vector fused_attn_fwd_qkvpacked( NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const std::vector window_size, const at::Tensor cu_seqlens, const at::Tensor QKV, const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_padded, - const c10::optional descale, const c10::optional scale, - c10::optional amax, const int offset_QKV, const int offset_S, const int offset_O, - const c10::optional Bias, const c10::optional rng_gen, - size_t rng_elts_per_thread) { + const c10::optional descale_QKV, const int descale_QKV_offset, + const c10::optional descale_S, const int descale_S_offset, + const c10::optional scale_S, const int scale_S_offset, + const c10::optional scale_O, const int scale_O_offset, + c10::optional amax_S, const int amax_S_offset, c10::optional amax_O, + const int amax_O_offset, const c10::optional Bias, + const c10::optional rng_gen, size_t rng_elts_per_thread) { using namespace transformer_engine; auto qkv_sizes = QKV.sizes().vec(); @@ -115,17 +118,20 @@ std::vector fused_attn_fwd_qkvpacked( } else { O.fill_(0); } - if ((!descale.has_value()) || (!scale.has_value()) || (!amax.has_value())) { - NVTE_ERROR("descale, scale, and amax are required for FP8 operation. \n"); + if ((!descale_QKV.has_value()) || (!descale_S.has_value()) || (!scale_S.has_value()) || + (!scale_O.has_value()) || (!amax_S.has_value()) || (!amax_O.has_value())) { + std::string err_tensors = "descale_QKV, descale_S, scale_S, scale_O, amax_S and amax_O "; + NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. 
\n")); } te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape, qkv_type, nullptr, nullptr, - getDataPtr(descale.value(), offset_QKV)); - te_S = makeTransformerEngineTensor( - nullptr, {0}, DType::kFloat32, getDataPtr(amax.value(), offset_S), - getDataPtr(scale.value(), offset_S), getDataPtr(descale.value(), offset_S)); + getDataPtr(descale_QKV.value(), descale_QKV_offset)); + te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, + getDataPtr(amax_S.value(), amax_S_offset), + getDataPtr(scale_S.value(), scale_S_offset), + getDataPtr(descale_S.value(), descale_S_offset)); te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, - getDataPtr(amax.value(), offset_O), - getDataPtr(scale.value(), offset_O), nullptr); + getDataPtr(amax_O.value(), amax_O_offset), + getDataPtr(scale_O.value(), scale_O_offset), nullptr); } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { O.fill_(0); @@ -391,9 +397,13 @@ std::vector fused_attn_fwd_kvpacked( const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor KV, const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_q_padded, - const c10::optional cu_seqlens_kv_padded, const c10::optional descale, - const c10::optional scale, c10::optional amax, const int offset_QKV, - const int offset_S, const int offset_O, const c10::optional Bias, + const c10::optional cu_seqlens_kv_padded, + const c10::optional descale_QKV, const int descale_QKV_offset, + const c10::optional descale_S, const int descale_S_offset, + const c10::optional scale_S, const int scale_S_offset, + const c10::optional scale_O, const int scale_O_offset, + c10::optional amax_S, const int amax_S_offset, c10::optional amax_O, + const int amax_O_offset, const c10::optional Bias, const c10::optional rng_gen, size_t rng_elts_per_thread) { using namespace transformer_engine; @@ -420,19 +430,22 @@ std::vector fused_attn_fwd_kvpacked( } else { O.fill_(0); } - if ((!descale.has_value()) || (!scale.has_value()) || (!amax.has_value())) { - NVTE_ERROR("descale, scale, and amax are required for FP8 operation. \n"); + if ((!descale_QKV.has_value()) || (!descale_S.has_value()) || (!scale_S.has_value()) || + (!scale_O.has_value()) || (!amax_S.has_value()) || (!amax_O.has_value())) { + std::string err_tensors = "descale_QKV, descale_S, scale_S, scale_O, amax_S and amax_O "; + NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. 
\n")); } te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, - getDataPtr(descale.value(), offset_QKV)); + getDataPtr(descale_QKV.value(), descale_QKV_offset)); te_KV = makeTransformerEngineTensor(KV.data_ptr(), kv_shape, qkv_type, nullptr, nullptr, - getDataPtr(descale.value(), offset_QKV)); - te_S = makeTransformerEngineTensor( - nullptr, {0}, DType::kFloat32, getDataPtr(amax.value(), offset_S), - getDataPtr(scale.value(), offset_S), getDataPtr(descale.value(), offset_S)); + getDataPtr(descale_QKV.value(), descale_QKV_offset)); + te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, + getDataPtr(amax_S.value(), amax_S_offset), + getDataPtr(scale_S.value(), scale_S_offset), + getDataPtr(descale_S.value(), descale_S_offset)); te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, - getDataPtr(amax.value(), offset_O), - getDataPtr(scale.value(), offset_O), nullptr); + getDataPtr(amax_O.value(), amax_O_offset), + getDataPtr(scale_O.value(), scale_O_offset), nullptr); } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { O.fill_(0); @@ -743,9 +756,13 @@ std::vector fused_attn_fwd( const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor K, const at::Tensor V, const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_q_padded, - const c10::optional cu_seqlens_kv_padded, const c10::optional descale, - const c10::optional scale, c10::optional amax, const int offset_QKV, - const int offset_S, const int offset_O, const c10::optional Bias, + const c10::optional cu_seqlens_kv_padded, + const c10::optional descale_QKV, const int descale_QKV_offset, + const c10::optional descale_S, const int descale_S_offset, + const c10::optional scale_S, const int scale_S_offset, + const c10::optional scale_O, const int scale_O_offset, + c10::optional amax_S, const int amax_S_offset, c10::optional amax_O, + const int amax_O_offset, const c10::optional Bias, const c10::optional rng_gen, size_t rng_elts_per_thread) { using namespace transformer_engine; @@ -777,21 +794,24 @@ std::vector fused_attn_fwd( } else { O.fill_(0); } - if ((!descale.has_value()) || (!scale.has_value()) || (!amax.has_value())) { - NVTE_ERROR("descale, scale, and amax are required for FP8 operation. \n"); + if ((!descale_QKV.has_value()) || (!descale_S.has_value()) || (!scale_S.has_value()) || + (!scale_O.has_value()) || (!amax_S.has_value()) || (!amax_O.has_value())) { + std::string err_tensors = "descale_QKV, descale_S, scale_S, scale_O, amax_S and amax_O "; + NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. 
\n")); } te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, - getDataPtr(descale.value(), offset_QKV)); + getDataPtr(descale_QKV.value(), descale_QKV_offset)); te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, qkv_type, nullptr, nullptr, - getDataPtr(descale.value(), offset_QKV)); + getDataPtr(descale_QKV.value(), descale_QKV_offset)); te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, qkv_type, nullptr, nullptr, - getDataPtr(descale.value(), offset_QKV)); - te_S = makeTransformerEngineTensor( - nullptr, {0}, DType::kFloat32, getDataPtr(amax.value(), offset_S), - getDataPtr(scale.value(), offset_S), getDataPtr(descale.value(), offset_S)); - te_O = makeTransformerEngineTensor(O.data_ptr(), o_shape, qkv_type, - getDataPtr(amax.value(), offset_O), - getDataPtr(scale.value(), offset_O), nullptr); + getDataPtr(descale_QKV.value(), descale_QKV_offset)); + te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, + getDataPtr(amax_S.value(), amax_S_offset), + getDataPtr(scale_S.value(), scale_S_offset), + getDataPtr(descale_S.value(), descale_S_offset)); + te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, + getDataPtr(amax_O.value(), amax_O_offset), + getDataPtr(scale_O.value(), scale_O_offset), nullptr); } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { O.fill_(0); From 13feabb5933f2dc23c56f3f87514bf82d2b77578 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 15 Aug 2024 05:10:20 +0000 Subject: [PATCH 07/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/attention.py | 132 ++++++++++++------------ 1 file changed, 66 insertions(+), 66 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 13a6c1c374..b8eee019bf 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -3588,18 +3588,18 @@ def forward( fused_attention_backend, attn_bias, cu_seqlens_padded, - fp8_meta["scaling_fwd"].scale_inv, # d_scale_qkv - META_QKV, # d_scale_qkv_offset - fp8_meta["scaling_fwd"].scale_inv, # d_scale_s - META_S, # d_scale_s_offset - fp8_meta["scaling_fwd"].scale, # q_scale_s - META_S, # q_scale_s_offset - fp8_meta["scaling_fwd"].scale, # q_scale_o - META_O, # q_scale_o_offset + fp8_meta["scaling_fwd"].scale_inv, # d_scale_qkv + META_QKV, # d_scale_qkv_offset + fp8_meta["scaling_fwd"].scale_inv, # d_scale_s + META_S, # d_scale_s_offset + fp8_meta["scaling_fwd"].scale, # q_scale_s + META_S, # q_scale_s_offset + fp8_meta["scaling_fwd"].scale, # q_scale_o + META_O, # q_scale_o_offset fp8_meta["scaling_fwd"].amax_history, # amax_s - META_S, # amax_s_offset + META_S, # amax_s_offset fp8_meta["scaling_fwd"].amax_history, # amax_o - META_O, # amax_o_offset + META_O, # amax_o_offset attn_scale, dropout_p, fast_zero_fill, @@ -3659,18 +3659,18 @@ def forward( fused_attention_backend, attn_bias, cu_seqlens_padded, - None, # d_scale_qkv - 0, # d_scale_qkv_offset - None, # d_scale_s - 0, # d_scale_s_offset - None, # q_scale_s - 0, # q_scale_s_offset - None, # q_scale_o - 0, # q_scale_o_offset - None, # amax_s - 0, # amax_s_offset - None, # amax_o - 0, # amax_o_offset + None, # d_scale_qkv + 0, # d_scale_qkv_offset + None, # d_scale_s + 0, # d_scale_s_offset + None, # q_scale_s + 0, # q_scale_s_offset + None, # 
q_scale_o + 0, # q_scale_o_offset + None, # amax_s + 0, # amax_s_offset + None, # amax_o + 0, # amax_o_offset attn_scale, dropout_p, fast_zero_fill, @@ -3981,18 +3981,18 @@ def forward( attn_bias, cu_seqlens_q_padded, cu_seqlens_kv_padded, - fp8_meta["scaling_fwd"].scale_inv, # d_scale_qkv - META_QKV, # d_scale_qkv_offset - fp8_meta["scaling_fwd"].scale_inv, # d_scale_s - META_S, # d_scale_s_offset - fp8_meta["scaling_fwd"].scale, # q_scale_s - META_S, # q_scale_s_offset - fp8_meta["scaling_fwd"].scale, # q_scale_o - META_O, # q_scale_o_offset + fp8_meta["scaling_fwd"].scale_inv, # d_scale_qkv + META_QKV, # d_scale_qkv_offset + fp8_meta["scaling_fwd"].scale_inv, # d_scale_s + META_S, # d_scale_s_offset + fp8_meta["scaling_fwd"].scale, # q_scale_s + META_S, # q_scale_s_offset + fp8_meta["scaling_fwd"].scale, # q_scale_o + META_O, # q_scale_o_offset fp8_meta["scaling_fwd"].amax_history, # amax_s - META_S, # amax_s_offset + META_S, # amax_s_offset fp8_meta["scaling_fwd"].amax_history, # amax_o - META_O, # amax_o_offset + META_O, # amax_o_offset attn_scale, dropout_p, fast_zero_fill, @@ -4060,18 +4060,18 @@ def forward( attn_bias, cu_seqlens_q_padded, cu_seqlens_kv_padded, - None, # d_scale_qkv - 0, # d_scale_qkv_offset - None, # d_scale_s - 0, # d_scale_s_offset - None, # q_scale_s - 0, # q_scale_s_offset - None, # q_scale_o - 0, # q_scale_o_offset - None, # amax_s - 0, # amax_s_offset - None, # amax_o - 0, # amax_o_offset + None, # d_scale_qkv + 0, # d_scale_qkv_offset + None, # d_scale_s + 0, # d_scale_s_offset + None, # q_scale_s + 0, # q_scale_s_offset + None, # q_scale_o + 0, # q_scale_o_offset + None, # amax_s + 0, # amax_s_offset + None, # amax_o + 0, # amax_o_offset attn_scale, dropout_p, fast_zero_fill, @@ -4447,18 +4447,18 @@ def forward( attn_bias, cu_seqlens_q_padded, cu_seqlens_kv_padded, - fp8_meta["scaling_fwd"].scale_inv, # d_scale_qkv - META_QKV, # d_scale_qkv_offset - fp8_meta["scaling_fwd"].scale_inv, # d_scale_s - META_S, # d_scale_s_offset - fp8_meta["scaling_fwd"].scale, # q_scale_s - META_S, # q_scale_s_offset - fp8_meta["scaling_fwd"].scale, # q_scale_o - META_O, # q_scale_o_offset + fp8_meta["scaling_fwd"].scale_inv, # d_scale_qkv + META_QKV, # d_scale_qkv_offset + fp8_meta["scaling_fwd"].scale_inv, # d_scale_s + META_S, # d_scale_s_offset + fp8_meta["scaling_fwd"].scale, # q_scale_s + META_S, # q_scale_s_offset + fp8_meta["scaling_fwd"].scale, # q_scale_o + META_O, # q_scale_o_offset fp8_meta["scaling_fwd"].amax_history, # amax_s - META_S, # amax_s_offset + META_S, # amax_s_offset fp8_meta["scaling_fwd"].amax_history, # amax_o - META_O, # amax_o_offset + META_O, # amax_o_offset attn_scale, dropout_p, fast_zero_fill, @@ -4576,18 +4576,18 @@ def forward( attn_bias, cu_seqlens_q_padded, cu_seqlens_kv_padded, - None, # d_scale_qkv - 0, # d_scale_qkv_offset - None, # d_scale_s - 0, # d_scale_s_offset - None, # q_scale_s - 0, # q_scale_s_offset - None, # q_scale_o - 0, # q_scale_o_offset - None, # amax_s - 0, # amax_s_offset - None, # amax_o - 0, # amax_o_offset + None, # d_scale_qkv + 0, # d_scale_qkv_offset + None, # d_scale_s + 0, # d_scale_s_offset + None, # q_scale_s + 0, # q_scale_s_offset + None, # q_scale_o + 0, # q_scale_o_offset + None, # amax_s + 0, # amax_s_offset + None, # amax_o + 0, # amax_o_offset attn_scale, dropout_p, fast_zero_fill, From ae856e48ffb1ec8da085d1346ec2c15f2523b059 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Thu, 15 Aug 2024 23:11:24 -0700 Subject: [PATCH 08/17] move transpose to backward for fp8 input Signed-off-by: Xin Yao --- 
transformer_engine/pytorch/module/linear.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index d6fb6061f1..efd459e1c0 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -106,16 +106,7 @@ def forward( if fp8: fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - if isinstance(inputmat, Float8Tensor): - if ( - not fp8_meta["recipe"].override_linear_precision.wgrad - and is_grad_enabled - and weight.requires_grad - and not sequence_parallel - ): - # FP8 input for forward, FP8 input transpose for backward wgrad - inputmat_t = inputmat.transpose_2d() - else: + if not is_input_fp8: if ( not fp8_meta["recipe"].override_linear_precision.wgrad and is_grad_enabled From 7e26d22499eafa2d512e82491a08d8fd7227c349 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Sun, 18 Aug 2024 20:29:30 -0700 Subject: [PATCH 09/17] fix ut Signed-off-by: Xin Yao --- tests/pytorch/fused_attn/test_fused_attn.py | 18 ++++++++++++------ transformer_engine/pytorch/attention.py | 20 ++++++++------------ 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index 82a3c8576b..8d8214cc3d 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -1996,12 +1996,18 @@ def forward( None, None, None, - fp8_meta["scaling_fwd"].scale_inv[META_QKV], - fp8_meta["scaling_fwd"].scale_inv[META_S], - fp8_meta["scaling_fwd"].scale[META_S], - fp8_meta["scaling_fwd"].scale[META_O], - fp8_meta["scaling_fwd"].amax_history[0][META_S], - fp8_meta["scaling_fwd"].amax_history[0][META_O], + fp8_meta["scaling_fwd"].scale_inv, # d_scale_qkv + META_QKV, # d_scale_qkv_offset + fp8_meta["scaling_fwd"].scale_inv, # d_scale_s + META_S, # d_scale_s_offset + fp8_meta["scaling_fwd"].scale, # q_scale_s + META_S, # q_scale_s_offset + fp8_meta["scaling_fwd"].scale, # q_scale_o + META_O, # q_scale_o_offset + fp8_meta["scaling_fwd"].amax_history, # amax_s + META_S, # amax_s_offset + fp8_meta["scaling_fwd"].amax_history, # amax_o + META_O, # amax_o_offset attn_scale=None, dropout=p_dropout, fast_zero_fill=fast_zero_fill, diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index b8eee019bf..8b774d527f 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -44,7 +44,7 @@ META_S, META_DP, ) -from transformer_engine.pytorch.fp8 import get_fp8_te_dtype +from transformer_engine.pytorch.fp8 import get_fp8_te_dtype, FP8GlobalStateManager from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.module import LayerNormLinear, Linear from transformer_engine.pytorch.module.base import TransformerEngineBaseModule @@ -6625,15 +6625,15 @@ def forward( # Query, Key, and Value # ====================== + fp8_mha = FP8GlobalStateManager.is_fp8_enabled() and FP8GlobalStateManager.get_fp8_recipe().fp8_mha + if self.attention_type == "self": # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn] if self.input_layernorm: layernorm_qkv_outputs = self.layernorm_qkv( hidden_states, is_first_microbatch=is_first_microbatch, - fp8_output=( - self.layernorm_qkv.fp8_meta["recipe"].fp8_mha and rotary_pos_emb is None - ), + fp8_output=fp8_mha and rotary_pos_emb is None, ) if self.return_layernorm_output: 
mixed_x_layer, layernorm_output = layernorm_qkv_outputs @@ -6643,7 +6643,7 @@ def forward( mixed_x_layer = self.qkv( hidden_states, is_first_microbatch=is_first_microbatch, - fp8_output=self.qkv.fp8_meta["recipe"].fp8_mha and rotary_pos_emb is None, + fp8_output=fp8_mha and rotary_pos_emb is None, ) num_queries_per_key_value = ( @@ -6699,7 +6699,7 @@ def forward( mixed_kv_layer = self.key_value( encoder_output, is_first_microbatch=is_first_microbatch, - fp8_output=self.key_value.fp8_meta["recipe"].fp8_mha and rotary_pos_emb is None, + fp8_output=fp8_mha and rotary_pos_emb is None, ) if self.qkv_weight_interleaved: @@ -6749,9 +6749,7 @@ def forward( layernorm_query_outputs = self.layernorm_query( hidden_states, is_first_microbatch=is_first_microbatch, - fp8_output=( - self.layernorm_query.fp8_meta["recipe"].fp8_mha and rotary_pos_emb is None - ), + fp8_output=fp8_mha and rotary_pos_emb is None, ) if self.return_layernorm_output: query_layer, layernorm_output = layernorm_query_outputs @@ -6761,9 +6759,7 @@ def forward( query_layer = self.query_layer( hidden_states, is_first_microbatch=is_first_microbatch, - fp8_output=( - self.query_layer.fp8_meta["recipe"].fp8_mha and rotary_pos_emb is None - ), + fp8_output=fp8_mha and rotary_pos_emb is None, ) # [sq, b, hp] --> [sq, b, np, hn] From 521c77a2b5a37b9c69274a49bf15782fbe4dcdf0 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Tue, 20 Aug 2024 20:57:22 -0700 Subject: [PATCH 10/17] resolve comments Signed-off-by: Xin Yao --- tests/pytorch/fused_attn/test_fused_attn.py | 14 +- transformer_engine/pytorch/attention.py | 225 ++++++++++---------- 2 files changed, 128 insertions(+), 111 deletions(-) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index 8d8214cc3d..37806ec5e3 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -1320,7 +1320,8 @@ def _rmse(a, b): @pytest.mark.parametrize("qkv_format", qkv_format_fp8_vs_f16) @pytest.mark.parametrize("input_layernorm", [True, False]) @pytest.mark.parametrize("fp8_dpa_bwd", [True, False]) -def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd): +@pytest.mark.parametrize("RoPE", [True, False]) +def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, RoPE): os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "1" os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" @@ -1332,12 +1333,12 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd): logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True") fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16( - dtype, config, True, qkv_format, input_layernorm + dtype, config, True, qkv_format, input_layernorm, RoPE ) logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False") fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16( - dtype, config, False, qkv_format, input_layernorm + dtype, config, False, qkv_format, input_layernorm, RoPE ) tols = dict(atol=5e-1, rtol=5e-1) @@ -1399,7 +1400,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd): ) -def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm): +def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, RoPE): reset_rng_states() _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) @@ -1418,6 +1419,10 @@ def 
get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: ) with fp8_model_init(enabled=fp8_mha): + rotary_pos_emb = None + if RoPE: + PE = RotaryPositionEmbedding(dim=config.head_dim_qk) + rotary_pos_emb = PE(config.max_seqlen_q).to(device="cuda") mha = MultiheadAttention( hidden_size=config.hidden_size, num_attention_heads=config.num_heads, @@ -1475,6 +1480,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: checkpoint_core_attention=False, core_attention_bias_type=config.attn_bias_type, is_first_microbatch=None, + rotary_pos_emb=rotary_pos_emb, ) out.backward(out_grad) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index e569077fc7..01790fd9af 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -3708,7 +3708,6 @@ def run_iteratively(q, k, v): stride = q.stride() check_strides_qkv = all(stride == x.stride() for x in [q, k, v]) - stride = k.stride() check_strides_kv = tuple(sk / k.shape[-1] for sk in k.stride()[:-1]) == tuple( sv / v.shape[-1] for sv in v.stride()[:-1] ) @@ -4117,6 +4116,7 @@ def forward( deterministic, ): is_input_fp8 = False + is_output_fp8 = fp8_meta["recipe"].fp8_mha if fp8: is_input_fp8 = isinstance(qkv, Float8Tensor) if is_input_fp8: @@ -4165,7 +4165,7 @@ def forward( window_size, rng_gen, ) - if fp8_meta["recipe"].fp8_mha: + if is_output_fp8: out_ret = Float8Tensor( data=out_fp8, fp8_meta=fp8_meta, @@ -4183,22 +4183,24 @@ def forward( qkv_dtype, ).view(out_fp8.shape) out_save = out_ret - if is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) - qkv = cast_from_fp8( - qkv_c._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[qkv.dtype], - ).view(qkv.shape) - out_save = cast_from_fp8( - out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), - fp8_meta["scaling_fwd"], - META_O, - fp8_dtype_forward, - qkv_dtype, - ).view(out_fp8.shape) + if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + if is_input_fp8: + qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) + qkv = cast_from_fp8( + qkv_c._data, + fp8_meta["scaling_fwd"], + META_QKV, + fp8_dtype_forward, + TE_DType[qkv.dtype], + ).view(qkv.shape) + if is_output_fp8: + out_save = cast_from_fp8( + out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), + fp8_meta["scaling_fwd"], + META_O, + fp8_dtype_forward, + qkv_dtype, + ).view(out_fp8.shape) fp8_tensors = ( qkv_fp8, out_fp8, @@ -4241,6 +4243,7 @@ def forward( ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) ctx.is_input_fp8 = is_input_fp8 + ctx.is_output_fp8 = is_output_fp8 qkvo_tensors = (qkv, out_save) if not ctx.fp8 else (None, None) ctx.save_for_backward( *qkvo_tensors, cu_seqlens, cu_seqlens_padded, *fp8_tensors, *aux_ctx_tensors @@ -4265,7 +4268,7 @@ def forward( @staticmethod def backward(ctx, d_out): - if ctx.fp8_meta["recipe"].fp8_mha: + if ctx.is_output_fp8: assert isinstance( d_out, Float8Tensor ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA." 
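# Illustrative sketch of the input/output FP8 split this patch introduces
# (qkv and fp8_meta as in the hunks above): whether the inputs are FP8 is now
# inferred from the tensor type, while whether the output should be FP8 follows
# the recipe flag, so the two can differ (for example, high-precision QKV after
# RoPE while fp8_mha still requests an FP8 attention output).
is_input_fp8 = isinstance(qkv, Float8Tensor)   # what the caller handed in
is_output_fp8 = fp8_meta["recipe"].fp8_mha     # what the caller expects back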
@@ -4322,7 +4325,7 @@ def backward(ctx, d_out): fp8_dtype_backward = get_fp8_te_dtype( ctx.fp8_meta["recipe"], fprop_tensor=False ) - if ctx.fp8_meta["recipe"].fp8_mha: + if ctx.is_output_fp8: d_out_fp8 = d_out ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv else: @@ -4501,6 +4504,7 @@ def forward( deterministic, ): is_input_fp8 = False + is_output_fp8 = fp8_meta["recipe"].fp8_mha if fp8: assert isinstance(kv, q.__class__), "q and kv must have the same type." is_input_fp8 = isinstance(q, Float8Tensor) @@ -4558,7 +4562,7 @@ def forward( window_size, rng_gen, ) - if fp8_meta["recipe"].fp8_mha: + if is_output_fp8: out_ret = Float8Tensor( data=out_fp8, fp8_meta=fp8_meta, @@ -4576,25 +4580,27 @@ def forward( qkv_dtype, ).view(out_fp8.shape) out_save = out_ret - if is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - q = cast_from_fp8( - q._data, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward, TE_DType[q.dtype] - ).view(q.shape) - kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) - kv = cast_from_fp8( - kv_c._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[kv.dtype], - ).view(kv.shape) - out_save = cast_from_fp8( - out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), - fp8_meta["scaling_fwd"], - META_O, - fp8_dtype_forward, - qkv_dtype, - ).view(out_fp8.shape) + if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + if is_input_fp8: + q = cast_from_fp8( + q._data, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward, TE_DType[q.dtype] + ).view(q.shape) + kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) + kv = cast_from_fp8( + kv_c._data, + fp8_meta["scaling_fwd"], + META_QKV, + fp8_dtype_forward, + TE_DType[kv.dtype], + ).view(kv.shape) + if is_output_fp8: + out_save = cast_from_fp8( + out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), + fp8_meta["scaling_fwd"], + META_O, + fp8_dtype_forward, + qkv_dtype, + ).view(out_fp8.shape) fp8_tensors = ( q_fp8, kv_fp8, @@ -4642,6 +4648,7 @@ def forward( ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) ctx.is_input_fp8 = is_input_fp8 + ctx.is_output_fp8 = is_output_fp8 qkvo_tensors = (q, kv, out_save) if not ctx.fp8 else (None, None, None) ctx.save_for_backward( *qkvo_tensors, @@ -4673,7 +4680,7 @@ def forward( @staticmethod def backward(ctx, d_out): - if ctx.fp8_meta["recipe"].fp8_mha: + if ctx.is_output_fp8: assert isinstance( d_out, Float8Tensor ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA." 
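# Condensed control flow of the save-for-backward path after this change
# (a sketch; the cast_* helpers below are hypothetical stand-ins for the
# cast_from_fp8 calls in the hunks above): when the FP8 backward pass is
# disabled via NVTE_FP8_DPA_BWD=0, only the tensors that are actually FP8
# get cast back to high precision before being saved.
if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
    if is_input_fp8:
        q, kv = cast_inputs_from_fp8(q, kv)        # hypothetical helper
    if is_output_fp8:
        out_save = cast_output_from_fp8(out_fp8)   # hypothetical helper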
@@ -4734,7 +4741,7 @@ def backward(ctx, d_out): fp8_dtype_backward = get_fp8_te_dtype( ctx.fp8_meta["recipe"], fprop_tensor=False ) - if ctx.fp8_meta["recipe"].fp8_mha: + if ctx.is_output_fp8: d_out_fp8 = d_out ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv else: @@ -4945,6 +4952,7 @@ def forward( deterministic, ): is_input_fp8 = False + is_output_fp8 = fp8_meta["recipe"].fp8_mha if fp8: fused_attention_backend = FusedAttnBackend["FP8"] fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) @@ -5024,7 +5032,7 @@ def forward( window_size, rng_gen, ) - if fp8_meta["recipe"].fp8_mha: + if is_output_fp8: out_ret = Float8Tensor( data=out_fp8, fp8_meta=fp8_meta, @@ -5043,71 +5051,73 @@ def forward( ).view(out_fp8.shape) out_save = out_ret - if is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): # 1: qkv packed, 2: kv packed, 3: qkv separate - qkv_group = len(qkv_layout.split("_")) - if qkv_group == 1: - dim = qkv_layout.find("3") - qkv = _combine_tensors([q, k, v], dim) - qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) - qkv_no_fp8 = cast_from_fp8( - qkv_c._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[qkv.dtype], - ).view(qkv.shape) - q, k, v = _SplitAlongDim.apply(qkv_no_fp8, dim, [1, 1, 1]) - q, k, v = [x.squeeze(dim) for x in [q, k, v]] - if qkv_group == 2: - q = cast_from_fp8( - q._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[q.dtype], - ).view(q.shape) - dim = qkv_layout.split("_")[1].find("2") - kv = _combine_tensors([k, v], dim) - kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) - kv_no_fp8 = cast_from_fp8( - kv_c._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[kv.dtype], - ).view(kv.shape) - k, v = _SplitAlongDim.apply(kv_no_fp8, dim, [1, 1]) - k, v = [x.squeeze(dim) for x in [k, v]] - if qkv_group == 3: - q = cast_from_fp8( - q._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[q.dtype], - ).view(q.shape) - k = cast_from_fp8( - k._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[k.dtype], - ).view(k.shape) - v = cast_from_fp8( - v._data, + if is_input_fp8: + qkv_group = len(qkv_layout.split("_")) + if qkv_group == 1: + dim = qkv_layout.find("3") + qkv = _combine_tensors([q, k, v], dim) + qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) + qkv_no_fp8 = cast_from_fp8( + qkv_c._data, + fp8_meta["scaling_fwd"], + META_QKV, + fp8_dtype_forward, + TE_DType[qkv.dtype], + ).view(qkv.shape) + q, k, v = _SplitAlongDim.apply(qkv_no_fp8, dim, [1, 1, 1]) + q, k, v = [x.squeeze(dim) for x in [q, k, v]] + if qkv_group == 2: + q = cast_from_fp8( + q._data, + fp8_meta["scaling_fwd"], + META_QKV, + fp8_dtype_forward, + TE_DType[q.dtype], + ).view(q.shape) + dim = qkv_layout.split("_")[1].find("2") + kv = _combine_tensors([k, v], dim) + kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) + kv_no_fp8 = cast_from_fp8( + kv_c._data, + fp8_meta["scaling_fwd"], + META_QKV, + fp8_dtype_forward, + TE_DType[kv.dtype], + ).view(kv.shape) + k, v = _SplitAlongDim.apply(kv_no_fp8, dim, [1, 1]) + k, v = [x.squeeze(dim) for x in [k, v]] + if qkv_group == 3: + q = cast_from_fp8( + q._data, + fp8_meta["scaling_fwd"], + META_QKV, + fp8_dtype_forward, + TE_DType[q.dtype], + ).view(q.shape) + k = cast_from_fp8( + k._data, + fp8_meta["scaling_fwd"], + META_QKV, + fp8_dtype_forward, + TE_DType[k.dtype], + 
).view(k.shape) + v = cast_from_fp8( + v._data, + fp8_meta["scaling_fwd"], + META_QKV, + fp8_dtype_forward, + TE_DType[v.dtype], + ).view(v.shape) + if is_output_fp8: + out_save = cast_from_fp8( + out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), fp8_meta["scaling_fwd"], - META_QKV, + META_O, fp8_dtype_forward, - TE_DType[v.dtype], - ).view(v.shape) - out_save = cast_from_fp8( - out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), - fp8_meta["scaling_fwd"], - META_O, - fp8_dtype_forward, - qkv_dtype, - ).view(out_fp8.shape) + qkv_dtype, + ).view(out_fp8.shape) fp8_tensors = ( q_fp8, @@ -5167,6 +5177,7 @@ def forward( ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) ctx.is_input_fp8 = is_input_fp8 + ctx.is_output_fp8 = is_output_fp8 qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None) ctx.save_for_backward( *qkvo_tensors, @@ -5198,7 +5209,7 @@ def forward( @staticmethod def backward(ctx, d_out): - if ctx.fp8_meta["recipe"].fp8_mha: + if ctx.is_output_fp8: assert isinstance( d_out, Float8Tensor ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA." @@ -5263,7 +5274,7 @@ def backward(ctx, d_out): fp8_dtype_backward = get_fp8_te_dtype( ctx.fp8_meta["recipe"], fprop_tensor=False ) - if ctx.fp8_meta["recipe"].fp8_mha: + if ctx.is_output_fp8: d_out_fp8 = d_out ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv else: From dd30c2d686a39eea3b6f2d8054e576425d849f77 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Wed, 21 Aug 2024 01:41:20 -0700 Subject: [PATCH 11/17] update argument list for CP Signed-off-by: Xin Yao --- transformer_engine/pytorch/attention.py | 18 ++++++++++++------ .../pytorch/cpp_extensions/fused_attn.py | 3 +++ 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 72bc5a39af..119103a9ce 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1439,10 +1439,14 @@ def forward( for x in [k_f16, v_f16] ] fp8_meta_kwargs = {} - fp8_meta_kwargs["d_scale_qkv"] = fp8_meta["scaling_fwd"].scale_inv[META_QKV] - fp8_meta_kwargs["d_scale_s"] = fp8_meta["scaling_fwd"].scale_inv[META_S] - fp8_meta_kwargs["q_scale_s"] = fp8_meta["scaling_fwd"].scale[META_S] - fp8_meta_kwargs["q_scale_o"] = fp8_meta["scaling_fwd"].scale[META_O_CP] + fp8_meta_kwargs["d_scale_qkv"] = fp8_meta["scaling_fwd"].scale_inv + fp8_meta_kwargs["d_scale_qkv_offset"] = META_QKV + fp8_meta_kwargs["d_scale_s"] = fp8_meta["scaling_fwd"].scale_inv + fp8_meta_kwargs["d_scale_s_offset"] = META_S + fp8_meta_kwargs["q_scale_s"] = fp8_meta["scaling_fwd"].scale + fp8_meta_kwargs["q_scale_s_offset"] = META_S + fp8_meta_kwargs["q_scale_o"] = fp8_meta["scaling_fwd"].scale + fp8_meta_kwargs["q_scale_o_offset"] = META_O_CP amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) else: assert False, "FP8 is only supported with Fused Attention!" 
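# Sketch of how the context-parallel path reuses the new offset arguments
# (buffer shape as in the hunk above; the per-step wiring lands in the next
# hunk): one (2, cp_size) buffer collects every step's softmax and output
# amaxes, and step i is addressed on the flattened buffer with
#   amax_s_offset = i            -> amax_per_step[0][i]  (softmax amax)
#   amax_o_offset = cp_size + i  -> amax_per_step[1][i]  (output amax)
amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device)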
@@ -1494,8 +1498,10 @@ def forward( fp8_dtype_forward, ) if fp8 and use_fused_attention: - fp8_meta_kwargs["amax_s"] = amax_per_step[0][i] - fp8_meta_kwargs["amax_o"] = amax_per_step[1][i] + fp8_meta_kwargs["amax_s"] = amax_per_step + fp8_meta_kwargs["amax_s_offset"] = i + fp8_meta_kwargs["amax_o"] = amax_per_step + fp8_meta_kwargs["amax_o_offset"] = cp_size + i if causal: if i == 0: if pad_between_seqs_q: diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index ceb90b619e..cd0ecbaa6c 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -84,6 +84,9 @@ META_DO = tex.FP8BwdTensors.GRAD_INPUT2 META_S = tex.FP8FwdTensors.GEMM3_OUTPUT META_DP = tex.FP8BwdTensors.GRAD_INPUT3 +# repurpose some unused amax history buffers for partial results of CP fwd and bwd +META_O_CP = tex.FP8FwdTensors.GEMM2_OUTPUT +META_DQKV_CP = tex.FP8BwdTensors.GRAD_INPUT1 def fused_attn_fwd_qkvpacked( From a94b3ad44790fa653d99b72ba4cf679c99eb6ec8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 21 Aug 2024 08:44:41 +0000 Subject: [PATCH 12/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/attention.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 119103a9ce..d25bd7a399 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -4926,7 +4926,11 @@ def forward( if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): if is_input_fp8: q = cast_from_fp8( - q._data, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward, TE_DType[q.dtype] + q._data, + fp8_meta["scaling_fwd"], + META_QKV, + fp8_dtype_forward, + TE_DType[q.dtype], ).view(q.shape) kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) kv = cast_from_fp8( @@ -7588,7 +7592,10 @@ def forward( # Query, Key, and Value # ====================== - fp8_mha = FP8GlobalStateManager.is_fp8_enabled() and FP8GlobalStateManager.get_fp8_recipe().fp8_mha + fp8_mha = ( + FP8GlobalStateManager.is_fp8_enabled() + and FP8GlobalStateManager.get_fp8_recipe().fp8_mha + ) if self.attention_type == "self": # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn] From 400d5266f249f2b2c8149cba1b10962f707a4502 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Sun, 25 Aug 2024 19:55:43 -0700 Subject: [PATCH 13/17] fix for FA3 Signed-off-by: Xin Yao --- tests/pytorch/fused_attn/test_fused_attn.py | 4 +++- transformer_engine/pytorch/attention.py | 4 ++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index aaf370a7e3..f5e8685e5a 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -1355,12 +1355,14 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, global _attention_backends if not is_training: + if RoPE: + pytest.skip("Flash Attention doesn't support FP8 MHA with RoPE.") os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True") flash_attn_fwd_fp8, param_names, flash_attn_bwd_fp8 = _run_mha_fp8_vs_f16( 
- dtype, config, True, qkv_format, input_layernorm, is_training + dtype, config, True, qkv_format, input_layernorm, RoPE, is_training ) os.environ["NVTE_FLASH_ATTN"] = "0" diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 06984cf94a..91991ce0ff 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -385,6 +385,10 @@ def get_attention_backend( # Filter: Execution type if fp8 and fp8_meta["recipe"].fp8_dpa: + if use_flash_attention and not _flash_attn_3_plus: + logger.debug("Disabling FlashAttention as FlashAttention 3 is not available for " + "FP8 DPA/FP8 MHA.") + use_flash_attention = False if use_flash_attention and is_training: logger.debug("Disabling FlashAttention as it does not support FP8 training") use_flash_attention = False From b935e13d111e386f17d5cdfd0b4ae56c66e2eab7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 26 Aug 2024 02:56:11 +0000 Subject: [PATCH 14/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/attention.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 91991ce0ff..4a64e5c5f4 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -386,8 +386,9 @@ def get_attention_backend( # Filter: Execution type if fp8 and fp8_meta["recipe"].fp8_dpa: if use_flash_attention and not _flash_attn_3_plus: - logger.debug("Disabling FlashAttention as FlashAttention 3 is not available for " - "FP8 DPA/FP8 MHA.") + logger.debug( + "Disabling FlashAttention as FlashAttention 3 is not available for FP8 DPA/FP8 MHA." 
+ ) use_flash_attention = False if use_flash_attention and is_training: logger.debug("Disabling FlashAttention as it does not support FP8 training") From 9eca36982b65fe8f9b85972b895b85f594fc9683 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Sun, 25 Aug 2024 20:03:17 -0700 Subject: [PATCH 15/17] remove unnecessary copy of scale_inv Signed-off-by: Xin Yao --- transformer_engine/pytorch/module/linear.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index ba50f46842..f92a2db2d9 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -86,8 +86,6 @@ def forward( fsdp_group: Union[dist_group_type, None], ) -> torch.Tensor: is_input_fp8 = isinstance(inp, Float8Tensor) - if is_input_fp8: - fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM1_INPUT] = inp._scale_inv[0] # Make sure input dimensions are compatible in_features = weight.shape[-1] From e3b75db6d45dd34697818684383de4dc8e933638 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Mon, 26 Aug 2024 19:28:50 -0700 Subject: [PATCH 16/17] skip fp8 dpa/mha tests when fa3 is not available Signed-off-by: Xin Yao --- tests/pytorch/fused_attn/test_fused_attn.py | 5 +++++ transformer_engine/pytorch/attention.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index f5e8685e5a..781cc2bdb2 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -22,6 +22,7 @@ get_attention_backend, _flash_attn_2_plus, _flash_attn_2_3_plus, + _flash_attn_3_plus, check_set_window_size, AttentionParams, _attention_backends, @@ -1357,6 +1358,8 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, if not is_training: if RoPE: pytest.skip("Flash Attention doesn't support FP8 MHA with RoPE.") + if not _flash_attn_3_plus: + pytest.skip("FP8 MHA requires Flash Attention 3.") os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True @@ -1537,6 +1540,8 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): global _attention_backends if not is_training: + if not _flash_attn_3_plus: + pytest.skip("FP8 DPA requires Flash Attention 3.") os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 4a64e5c5f4..7ca0dec6e4 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -387,7 +387,7 @@ def get_attention_backend( if fp8 and fp8_meta["recipe"].fp8_dpa: if use_flash_attention and not _flash_attn_3_plus: logger.debug( - "Disabling FlashAttention as FlashAttention 3 is not available for FP8 DPA/FP8 MHA." + "Disabling FlashAttention as FlashAttention 2 does not support FP8 DPA/FP8 MHA." 
) use_flash_attention = False if use_flash_attention and is_training: From f9da6d79f2aaea841443041203840f7b8b7f3bba Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Wed, 4 Sep 2024 16:47:00 -0700 Subject: [PATCH 17/17] fix a merge bug Signed-off-by: Xin Yao --- transformer_engine/pytorch/attention.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index fcd80f9c56..59bc26140d 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -385,14 +385,14 @@ def get_attention_backend( # Filter: Execution type if fp8 and fp8_meta["recipe"].fp8_dpa: - if use_flash_attention and not _flash_attn_3_plus: + if use_flash_attention and not _use_flash_attn_3: + logger.debug("Disabling FlashAttention as FlashAttention 2 does not support FP8") + use_flash_attention = False + if use_flash_attention and _use_flash_attn_3 and is_training: logger.debug( - "Disabling FlashAttention as FlashAttention 2 does not support FP8 DPA/FP8 MHA." + "Disabling FlashAttention as FlashAttention 3 does not support FP8 training" ) use_flash_attention = False - if use_flash_attention and is_training: - logger.debug("Disabling FlashAttention as it does not support FP8 training") - use_flash_attention = False if use_unfused_attention: logger.debug("Disabling UnfusedDotProductAttention as it does not support FP8") use_unfused_attention = False
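# Net effect of the FP8 execution-type filter at the end of this series
# (a condensed sketch of the final hunk above): FlashAttention is only
# eligible for FP8 DPA/MHA when FlashAttention 3 is available, and even then
# not for training; the unfused backend is always ruled out under FP8.
if fp8 and fp8_meta["recipe"].fp8_dpa:
    if use_flash_attention and not _use_flash_attn_3:
        use_flash_attention = False    # FlashAttention 2 has no FP8 path
    if use_flash_attention and _use_flash_attn_3 and is_training:
        use_flash_attention = False    # FlashAttention 3 FP8 is inference-only here
    if use_unfused_attention:
        use_unfused_attention = False  # unfused attention does not support FP8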