From aec5fce2f940b2c97b988c6a0719a536f6f6c6e9 Mon Sep 17 00:00:00 2001 From: shw Date: Thu, 25 Apr 2024 11:12:21 +0800 Subject: [PATCH 01/14] fused_rope start --- oneflow/core/functional/functional_api.yaml | 10 +- .../impl/fused_attention_functor.cpp | 55 +++++- oneflow/ir/include/OneFlow/OneFlowUserOps.td | 52 ++--- oneflow/user/ops/fused_attention_ops.cpp | 185 ++++++++++++++++++ 4 files changed, 270 insertions(+), 32 deletions(-) diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml index 5829165e708..d646ed021ad 100644 --- a/oneflow/core/functional/functional_api.yaml +++ b/oneflow/core/functional/functional_api.yaml @@ -1279,6 +1279,10 @@ signature: 'Tensor (Tensor x, *, Tensor cos=None, Tensor sin=None, Tensor position_ids=None, String x_layout="BHMK", String output_layout=None, String mode="plane", Int64 tensor_index=None, Int64 k_size=None, Float base=1e4, Int64 rotary_size=None) => FusedApplyRotaryEmb' bind_python: True +- name: "fused_apply_rotary_emb_grad" + signature: 'Tensor (Tensor x, Tensor dy, Tensor cos=None, Tensor sin=None, Tensor position_ids=None, String x_layout="BHMK", String output_layout=None, String mode="plane", Int64 tensor_index=None, Int64 k_size=None, Float base=1e4, Int64 rotary_size=None) => FusedApplyRotaryEmbGrad' + bind_python: False + - name: "fused_relu_dropout_grad" signature: "Tensor (Tensor dy, Tensor mask, Float scale) => FusedReluDropoutGrad" bind_python: False @@ -2404,7 +2408,7 @@ bind_python: False - name: "to_global" - signature: "Tensor (Tensor x, Placement placement, SbpList sbp, SbpList grad_sbp, Bool check_meta, Bool copy=False) => ToGlobal" + signature: "Tensor (Tensor x, Placement placement, SbpList sbp, SbpList grad_sbp, Bool check_meta, Bool sync_data, Bool copy=False) => ToGlobal" bind_python: True - name: "to_local" @@ -2684,10 +2688,6 @@ signature: "TensorTuple (Tensor x, Tensor bias, Tensor mask, *, Float fill_value=0.0, Float scale=1.0, Float p=0.5, Bool training=True, Generator generator=None) => FusedBiasAddScaleMaskSoftmaxDropout" bind_python: True -- name: "scaled_dot_product_attention" - signature: "Tensor (Tensor query, Tensor key, Tensor value, Tensor attn_mask=None, Float dropout_p=0.0, Bool is_causal=False, Float scale=None, Int64 seed=0) => ScaledDotProductFlashAttention" - bind_python: True - - name: "fused_multi_head_attention_inference" signature: "Tensor (Tensor query, Tensor key, Tensor value, Int64 num_heads, Bool causal=False, Int64 query_hidden_slice_start=0, Int64 query_hidden_slice_end=-1, Int64 key_hidden_slice_start=0, Int64 key_hidden_slice_end=-1, Int64 value_hidden_slice_start=0, Int64 value_hidden_slice_end=-1, Tensor attn_bias=None, Int64 causal_diagonal_offset=0) => FusedMultiHeadAttentionInference" bind_python: True diff --git a/oneflow/core/functional/impl/fused_attention_functor.cpp b/oneflow/core/functional/impl/fused_attention_functor.cpp index 413caf69e5f..4c7c9dd97f1 100644 --- a/oneflow/core/functional/impl/fused_attention_functor.cpp +++ b/oneflow/core/functional/impl/fused_attention_functor.cpp @@ -644,6 +644,7 @@ class FusedApplyRotaryEmbFunctor { const Optional& tensor_index, const Optional& k_size, const float base, const Optional& rotary_size) const { int64_t b = 0, m = 0, h = 0, k = 0; + // std::cout << "go here!" 
<< std::endl; if (tensor_index) { CHECK_OR_RETURN((JUST(tensor_index) >= 0) && (JUST(tensor_index) <= 2)) @@ -653,7 +654,8 @@ class FusedApplyRotaryEmbFunctor { << "mode should be \"intervel\" or \"plane\""; ParseDims("x", *x->shape(), x_layout, Optional(), k_size, &b, &m, &h, &k); - + // std::cout << "cpp head_size: " << h << std::endl; + // printf("k_size: %ld\n", k_size); if (k_size) { CHECK_EQ_OR_RETURN(JUST(k_size), k) << "k_size if given should be equal to K of cos, sin and x."; @@ -709,6 +711,7 @@ class FusedApplyRotaryEmbFunctor { attrs.SetAllAttrs(x_layout, output_layout.value_or(x_layout), mode, tensor_index.value_or(0), k_size.value_or(k), base, rotary_size.value_or(k)); + // std::cout << "dispatch!" << std::endl; if (position_ids) { if (cos && sin) { return OpInterpUtil::Dispatch(*op_with_position_sinuous_, @@ -733,6 +736,55 @@ class FusedApplyRotaryEmbFunctor { std::shared_ptr op_without_position_sinuous_; }; +class FusedApplyRotaryEmbGradFunctor { + public: + FusedApplyRotaryEmbGradFunctor() { + op_with_position_sinuous_ = CHECK_JUST(one::OpBuilder("fused_apply_rotary_emb_grad") + .Input("x") + .Input("dy") + .Input("cos") + .Input("sin") + .Input("position_ids") + .Output("out") + .Build()); + op_with_position_ = CHECK_JUST(one::OpBuilder("fused_apply_rotary_emb_grad") + .Input("x") + .Input("dy") + .Input("position_ids") + .Output("out") + .Build()); + op_without_position_ = CHECK_JUST(one::OpBuilder("fused_apply_rotary_emb_grad") + .Input("x") + .Input("dy") + .Input("cos") + .Input("sin") + .Output("out") + .Build()); + op_without_position_sinuous_ = + CHECK_JUST(one::OpBuilder("fused_apply_rotary_emb_grad").Input("x") + .Input("dy") + .Output("out") + .Build()); + } + Maybe operator()(const std::shared_ptr& x, + const std::shared_ptr& dy, + const Optional& cos, + const Optional& sin, + const Optional& position_ids, const std::string& x_layout, + const Optional& output_layout, const std::string& mode, + const Optional& tensor_index, const Optional& k_size, + const float base, const Optional& rotary_size) const { + std::cout << "FusedApplyRotaryEmbGradFunctor" << std::endl; + return dy; + } + + private: + std::shared_ptr op_with_position_; + std::shared_ptr op_with_position_sinuous_; + std::shared_ptr op_without_position_; + std::shared_ptr op_without_position_sinuous_; +}; + } // namespace impl ONEFLOW_FUNCTION_LIBRARY(m) { @@ -741,6 +793,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) { "FusedMultiHeadAttentionInferenceV2"); m.add_functor("FusedAttentionConcatPastKeyValue"); m.add_functor("FusedApplyRotaryEmb"); + m.add_functor("FusedApplyRotaryEmbGrad"); } } // namespace functional diff --git a/oneflow/ir/include/OneFlow/OneFlowUserOps.td b/oneflow/ir/include/OneFlow/OneFlowUserOps.td index eb7c6da6e58..7f51d4dfffe 100644 --- a/oneflow/ir/include/OneFlow/OneFlowUserOps.td +++ b/oneflow/ir/include/OneFlow/OneFlowUserOps.td @@ -2877,32 +2877,6 @@ def OneFlow_FusedCrossFeatureInteractionV2GradOp : OneFlow_BaseOp<"fused_cross_f let has_data_type_infer_fn = 1; } -def OneFlow_ScaledDotProductFlashAttentionOp : OneFlow_BaseOp<"scaled_dot_product_flash_attention", [NoMemoryEffect, DeclareOpInterfaceMethods]> { - let input = (ins - OneFlow_Tensor:$query, - OneFlow_Tensor:$key, - OneFlow_Tensor:$value, - Optional:$alibi_slopes_ - ); - let output = (outs - OneFlow_Tensor:$out, - OneFlow_Tensor:$softmax_lse, - OneFlow_Tensor:$rng_state - ); - let attrs = (ins - DefaultValuedAttr:$p_dropout, - DefaultValuedAttr:$softmax_scale, - DefaultValuedAttr:$is_causal, - SI32Attr:$window_size_left, - 
SI32Attr:$window_size_right, - DefaultValuedAttr:$seed - ); - let has_logical_tensor_desc_infer_fn = 1; - let has_physical_tensor_desc_infer_fn = 1; - let has_get_sbp_fn = 1; - let has_data_type_infer_fn = 1; -} - def OneFlow_FusedMultiHeadAttentionInferenceOp : OneFlow_BaseOp<"fused_multi_head_attention_inference", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$query, @@ -3950,6 +3924,32 @@ def OneFlow_FusedApplyRotaryEmbOp : OneFlow_BaseOp<"fused_apply_rotary_emb", [No let has_data_type_infer_fn = 1; } +def OneFlow_FusedApplyRotaryEmbGradOp : OneFlow_BaseOp<"fused_apply_rotary_emb_grad", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy, + Optional:$cos, + Optional:$sin, + Optional:$position_ids + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let attrs = (ins + DefaultValuedAttr:$x_layout, + DefaultValuedAttr:$output_layout, + DefaultValuedAttr:$mode, + DefaultValuedAttr:$tensor_index, + DefaultValuedAttr:$base, + DefaultValuedAttr:$k_size, + DefaultValuedAttr:$rotary_size + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + def OneFlow_EmbeddingGradOp : OneFlow_BaseOp<"embedding_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, diff --git a/oneflow/user/ops/fused_attention_ops.cpp b/oneflow/user/ops/fused_attention_ops.cpp index 123c09e16fc..d02cf2e1950 100644 --- a/oneflow/user/ops/fused_attention_ops.cpp +++ b/oneflow/user/ops/fused_attention_ops.cpp @@ -806,4 +806,189 @@ Maybe ParseSplitAxis(const std::string& layout, bool can_hk_split, int64_t return Maybe::Ok(); } +/* static */ Maybe FusedApplyRotaryEmbGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); + const std::string& x_layout = ctx->Attr("x_layout"); + const std::string& output_layout = ctx->Attr("output_layout"); + const std::string& mode = ctx->Attr("mode"); + const int64_t rotary_size = ctx->Attr("rotary_size"); + const int64_t k_size = ctx->Attr("k_size"); + const int64_t tensor_index = ctx->Attr("tensor_index"); + + CHECK_OR_RETURN((tensor_index >= 0) && (tensor_index <= 2)) + << "tensor_index should be in range [0, 2]."; + CHECK_OR_RETURN((mode == "interval") || (mode == "plane")) + << "mode should be either \"interval\" or \"plane\"."; + + CHECK_OR_RETURN(output_layout != "BM(H2K)" && output_layout != "BM(H3K)" + && output_layout != "MB(H2K)" && output_layout != "MB(H3K)") + << "output_layout should not be \"BM(H2k)\", \"BM(H3K)\", \"MB(H2K)\", \"MB(H3K)\"."; + + int64_t b = 0, m = 0, h = 0, k = 0; + + JUST(ParseDims(x_desc.shape(), x_layout, Optional(), Optional(k_size), &b, &m, + &h, &k)); // 这里需要检查是否正确; + + CHECK_LE_OR_RETURN(rotary_size, k) << "rotary_size should be no more than K of input x."; + + int64_t rotary_emb_dim = 1; + + if (ctx->has_input("position_ids", 0)) { + const user_op::TensorDesc& position_ids_desc = ctx->InputTensorDesc("position_ids", 0); + CHECK_EQ_OR_RETURN(position_ids_desc.shape().NumAxes(), 3) + << "ndims of position_ids should be equal to 3, either in form of B1M or B2M."; + CHECK_EQ_OR_RETURN(position_ids_desc.shape().At(0), b) + << "1st dim of position_ids should be equal to B."; + CHECK_EQ_OR_RETURN(position_ids_desc.shape().At(2), m) + << "3rd dim of position_ids should be equal to M."; + rotary_emb_dim = 
position_ids_desc.shape().At(1); + CHECK_OR_RETURN(rotary_emb_dim == 1 || rotary_emb_dim == 2) + << "2nd dim of position_ids should be 1 or 2."; + } + + const int64_t actual_rotary_size = rotary_size / rotary_emb_dim; + CHECK_EQ_OR_RETURN(actual_rotary_size % 2, 0) + << "rotary_size should be a multiple of 2 * rotary_encoding_dim."; + + bool has_cos = ctx->has_input("cos", 0); + bool has_sin = ctx->has_input("sin", 0); + // TODO: fused_apply_rotary_emb have same logic no matter name + if (has_cos && has_sin) { + const user_op::TensorDesc& cos_desc = ctx->InputTensorDesc("cos", 0); + const user_op::TensorDesc& sin_desc = ctx->InputTensorDesc("sin", 0); + CHECK_EQ_OR_RETURN(cos_desc.shape().NumAxes(), 2) + << "The number of dimensions of cos should be equal to 2."; + CHECK_OR_RETURN(cos_desc.shape() == sin_desc.shape()) + << "The dimensions of cos & sin should be the same."; + CHECK_EQ_OR_RETURN(cos_desc.shape().At(1), actual_rotary_size) + << "The 1st dimension of cos & sin should equal to rotary_size // " + "rotary_embedding_dimension."; + } else if (!has_cos && !has_sin) { + // Do nothing + } else { + UNIMPLEMENTED_THEN_RETURN(); + } + + if (!ctx->has_input("position_ids", 0)) { + if (has_cos && has_sin) { + const user_op::TensorDesc& cos_desc = ctx->InputTensorDesc("cos", 0); + CHECK_GE_OR_RETURN(cos_desc.shape().At(0), m) + << "M of cos should be no less than M of x if position_ids is not given."; + // K of cos & sin is checked inside ParseDims + } + } + + Shape out_shape = *JUST(LayoutToShape(b, m, h, k, x_layout)); + ctx->SetOutputShape("dx", 0, out_shape); + return Maybe::Ok(); +} + +/* static */ Maybe FusedApplyRotaryEmbGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + + +/* static */ Maybe FusedApplyRotaryEmbGradOp::GetSbp(user_op::SbpContext* ctx) { + /* + 1. 获取layout; + 2. check dy shape; + 3. 获取split_axis; + 4. 
设置sbp; + 反向算子中的split应该与前向一致; + */ + const user_op::TensorDesc& x_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + int num_heads = -1; + const int64_t k_size = ctx->Attr("k_size"); + const std::string& x_layout = ctx->Attr("x_layout"); + const std::string& output_layout = ctx->Attr("output_layout"); + + if (x_desc.shape().NumAxes() == 2) { + if (x_layout == "(BM)(HK)") { + CHECK_EQ_OR_RETURN(x_desc.shape().At(1) % k_size, 0); + num_heads = x_desc.shape().At(1) / k_size; + } else if (x_layout == "(BM)(H3K)") { + CHECK_EQ_OR_RETURN(x_desc.shape().At(1) % (k_size * 3), 0); + num_heads = x_desc.shape().At(1) / (k_size * 3); + } else { + UNIMPLEMENTED_THEN_RETURN(); + } + } else if (x_desc.shape().NumAxes() == 3) { + if (x_layout == "BM(HK)" || x_layout == "MB(HK)") { + CHECK_EQ_OR_RETURN(x_desc.shape().At(2) % k_size, 0); + num_heads = x_desc.shape().At(2) / k_size; + } else if (x_layout == "BM(H3K)" || x_layout == "MB(H3K)") { + CHECK_EQ_OR_RETURN(x_desc.shape().At(2) % (k_size * 3), 0); + num_heads = x_desc.shape().At(2) / (k_size * 3); + } else if (x_layout == "(BM)HK") { + num_heads = x_desc.shape().At(1); + } else { + UNIMPLEMENTED_THEN_RETURN(); + } + } else if (x_desc.shape().NumAxes() == 4) { + if (x_layout == "BMHK") { + num_heads = x_desc.shape().At(2); + } else if (x_layout == "BHMK") { + num_heads = x_desc.shape().At(1); + } else { + UNIMPLEMENTED_THEN_RETURN(); + } + } else { + UNIMPLEMENTED_THEN_RETURN(); + } + + const bool can_hk_split = num_heads % ctx->parallel_num() == 0; + int64_t x_b_split_axis = -1; + int64_t x_h_split_axis = -1; + JUST(ParseSplitAxis(x_layout, can_hk_split, &x_b_split_axis, &x_h_split_axis)); + int64_t o_b_split_axis = -1; + int64_t o_h_split_axis = -1; + JUST(ParseSplitAxis(output_layout, can_hk_split, &o_b_split_axis, &o_h_split_axis)); + + if (x_b_split_axis >= 0 && o_b_split_axis >= 0) { + auto builder = ctx->NewBuilder() + .Split(user_op::OpArg("dy", 0), o_b_split_axis) + .Split(user_op::OpArg("dx", 0), x_b_split_axis); + if (ctx->user_op_conf().has_input("cos", 0)) + builder = builder.Broadcast(user_op::OpArg("cos", 0)).Broadcast(user_op::OpArg("sin", 0)); + if (ctx->user_op_conf().has_input("position_ids", 0)) + builder = builder.Split(user_op::OpArg("position_ids", 0), 0); // 这里怎么split? + builder.Build(); + } + if (x_h_split_axis >= 0 && o_h_split_axis >= 0) { + auto builder = ctx->NewBuilder() + .Split(user_op::OpArg("dy", 0), o_h_split_axis) + .Split(user_op::OpArg("dx", 0), x_h_split_axis); + if (ctx->user_op_conf().has_input("cos", 0)) + builder = builder.Broadcast(user_op::OpArg("cos", 0)).Broadcast(user_op::OpArg("sin", 0)); + if (ctx->user_op_conf().has_input("position_ids", 0)) + builder = builder.Broadcast(user_op::OpArg("position_ids", 0)); + builder.Build(); + } + + return Maybe::Ok(); +} + +/* static */ Maybe FusedApplyRotaryEmbGradOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& first_in_desc = ctx->InputTensorDesc("dy", 0); + + bool has_sinuous = ctx->has_input("cos", 0); + + if (has_sinuous) { + const user_op::TensorDesc& cos_desc = ctx->InputTensorDesc("cos", 0); + const user_op::TensorDesc& sin_desc = ctx->InputTensorDesc("sin", 0); + + CHECK_EQ_OR_RETURN(cos_desc.data_type(), first_in_desc.data_type()) + << "InferDataType Failed. Expected " << DataType_Name(first_in_desc.data_type()) + << ", but got " << DataType_Name(cos_desc.data_type()); + CHECK_EQ_OR_RETURN(sin_desc.data_type(), first_in_desc.data_type()) + << "InferDataType Failed. 
Expected " << DataType_Name(first_in_desc.data_type()) + << ", but got " << DataType_Name(sin_desc.data_type()); + } + + user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("dx", 0); + out_desc->set_data_type(first_in_desc.data_type()); + + return Maybe::Ok(); +} + } // namespace oneflow From 357fcaa8adb1b95d769614c1b512fdb3eecbe470 Mon Sep 17 00:00:00 2001 From: shw Date: Thu, 25 Apr 2024 14:57:38 +0800 Subject: [PATCH 02/14] test fused_rope --- oneflow/core/functional/functional_api.yaml | 6 ++++- oneflow/ir/include/OneFlow/OneFlowUserOps.td | 26 +++++++++++++++++++ .../modules/test_fused_rotary_embedding.py | 10 +++---- 3 files changed, 36 insertions(+), 6 deletions(-) diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml index d646ed021ad..ea9bb659220 100644 --- a/oneflow/core/functional/functional_api.yaml +++ b/oneflow/core/functional/functional_api.yaml @@ -2408,7 +2408,7 @@ bind_python: False - name: "to_global" - signature: "Tensor (Tensor x, Placement placement, SbpList sbp, SbpList grad_sbp, Bool check_meta, Bool sync_data, Bool copy=False) => ToGlobal" + signature: "Tensor (Tensor x, Placement placement, SbpList sbp, SbpList grad_sbp, Bool check_meta, Bool copy=False) => ToGlobal" bind_python: True - name: "to_local" @@ -2688,6 +2688,10 @@ signature: "TensorTuple (Tensor x, Tensor bias, Tensor mask, *, Float fill_value=0.0, Float scale=1.0, Float p=0.5, Bool training=True, Generator generator=None) => FusedBiasAddScaleMaskSoftmaxDropout" bind_python: True +- name: "scaled_dot_product_attention" + signature: "Tensor (Tensor query, Tensor key, Tensor value, Tensor attn_mask=None, Float dropout_p=0.0, Bool is_causal=False, Float scale=None, Int64 seed=0) => ScaledDotProductFlashAttention" + bind_python: True + - name: "fused_multi_head_attention_inference" signature: "Tensor (Tensor query, Tensor key, Tensor value, Int64 num_heads, Bool causal=False, Int64 query_hidden_slice_start=0, Int64 query_hidden_slice_end=-1, Int64 key_hidden_slice_start=0, Int64 key_hidden_slice_end=-1, Int64 value_hidden_slice_start=0, Int64 value_hidden_slice_end=-1, Tensor attn_bias=None, Int64 causal_diagonal_offset=0) => FusedMultiHeadAttentionInference" bind_python: True diff --git a/oneflow/ir/include/OneFlow/OneFlowUserOps.td b/oneflow/ir/include/OneFlow/OneFlowUserOps.td index 7f51d4dfffe..9042aac9422 100644 --- a/oneflow/ir/include/OneFlow/OneFlowUserOps.td +++ b/oneflow/ir/include/OneFlow/OneFlowUserOps.td @@ -2877,6 +2877,32 @@ def OneFlow_FusedCrossFeatureInteractionV2GradOp : OneFlow_BaseOp<"fused_cross_f let has_data_type_infer_fn = 1; } +def OneFlow_ScaledDotProductFlashAttentionOp : OneFlow_BaseOp<"scaled_dot_product_flash_attention", [NoMemoryEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$query, + OneFlow_Tensor:$key, + OneFlow_Tensor:$value, + Optional:$alibi_slopes_ + ); + let output = (outs + OneFlow_Tensor:$out, + OneFlow_Tensor:$softmax_lse, + OneFlow_Tensor:$rng_state + ); + let attrs = (ins + DefaultValuedAttr:$p_dropout, + DefaultValuedAttr:$softmax_scale, + DefaultValuedAttr:$is_causal, + SI32Attr:$window_size_left, + SI32Attr:$window_size_right, + DefaultValuedAttr:$seed + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + def OneFlow_FusedMultiHeadAttentionInferenceOp : OneFlow_BaseOp<"fused_multi_head_attention_inference", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> 
{ let input = (ins OneFlow_Tensor:$query, diff --git a/python/oneflow/test/modules/test_fused_rotary_embedding.py b/python/oneflow/test/modules/test_fused_rotary_embedding.py index 68e8838fddb..a7677c718cc 100644 --- a/python/oneflow/test/modules/test_fused_rotary_embedding.py +++ b/python/oneflow/test/modules/test_fused_rotary_embedding.py @@ -289,7 +289,7 @@ def _test_without_position( for m in range(M) ] ).reshape(M, rotary_size // rotary_ndims) - fused_x = flow.tensor(x, dtype=dtype, device=device) + fused_x = flow.tensor(x, dtype=dtype, device=device, requires_grad=True) fused_cos = flow.tensor(fused_cos, dtype=dtype, device=device) fused_sin = flow.tensor(fused_sin, dtype=dtype, device=device) @@ -435,7 +435,7 @@ def _test_without_position_sinuous( mode, ) - fused_x = flow.tensor(x, dtype=dtype, device=device) + fused_x = flow.tensor(x, dtype=dtype, device=device, requires_grad=True) if x_layout == "BM(H3K)": out0 = flow._C.fused_apply_rotary_emb( @@ -627,7 +627,7 @@ def _test_with_position_sinuous( ] ) - fused_x = flow.tensor(x, dtype=dtype, device=device) + fused_x = flow.tensor(x, dtype=dtype, device=device, requires_grad=True) fused_cos = flow.tensor(fused_cos, dtype=dtype, device=device) fused_sin = flow.tensor(fused_sin, dtype=dtype, device=device) fused_position_ids = flow.tensor(position_ids, dtype=flow.int32, device=device) @@ -778,7 +778,7 @@ def _test_with_position( mode, ) - fused_x = flow.tensor(x, dtype=dtype, device=device) + fused_x = flow.tensor(x, dtype=dtype, device=device, requires_grad=True) fused_position_ids = flow.tensor(position_ids, dtype=flow.int32, device=device) if x_layout == "BM(H3K)": @@ -935,7 +935,7 @@ def _test_plane( mode, ) - fused_x = flow.tensor(x, dtype=dtype, device=device) + fused_x = flow.tensor(x, dtype=dtype, device=device, requires_grad=True) fused_position_ids = flow.tensor(position_ids, dtype=flow.int32, device=device) if x_layout == "MB(H3K)": From 9fe84e16190d825ebbf82826a9cc60d50da5e574 Mon Sep 17 00:00:00 2001 From: shw Date: Thu, 25 Apr 2024 20:19:58 +0800 Subject: [PATCH 03/14] add test_fused_rope --- .../modules/test_fused_rotary_embedding.py | 167 +++++++++++++++++- 1 file changed, 163 insertions(+), 4 deletions(-) diff --git a/python/oneflow/test/modules/test_fused_rotary_embedding.py b/python/oneflow/test/modules/test_fused_rotary_embedding.py index a7677c718cc..f44e28d43bd 100644 --- a/python/oneflow/test/modules/test_fused_rotary_embedding.py +++ b/python/oneflow/test/modules/test_fused_rotary_embedding.py @@ -22,7 +22,115 @@ import numpy as np import math +# tensor version: +def plane_shuffle_tensor(x): + x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] + return flow.cat((-x2, x1), dim=-1) + +def shuffle_adjacent_two_elem_tensor(x): + y = x.clone() # 使用.clone()以确保我们不会在原始输入上进行操作 + for i in range(x.shape[-1] // 2): + y[..., 2*i] = -x[..., 2*i + 1] + y[..., 2*i + 1] = x[..., 2*i] + return y + + +def parseDims_tensor(dims, x_layout): + B, M, H, K = 1, 1, 1, 1 # 初始化维度 + merged_dims = dims + if x_layout == "BHMK": + B, H, M, K = dims + elif x_layout == "BMHK": + B, M, H, K = dims + elif x_layout == "MBHK": + M, B, H, K = dims + elif x_layout == "BM(HK)": + B, M, H, K = dims + merged_dims = [dims[0], dims[1], dims[2] * dims[3]] # merge H and K + elif x_layout == "MB(HK)": + M, B, H, K = dims + merged_dims = [dims[0], dims[1], dims[2] * dims[3]] + elif x_layout == "BM(H3K)": + B, M, H, K = dims + merged_dims = [dims[0], dims[1], 3 * dims[2] * dims[3]] # merge and scale + elif x_layout == "MB(H3K)": + M, B, H, K = dims 
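+        # "MB(H3K)": sequence-major layout with q/k/v fused, so the merged last dim is 3*H*K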
+ merged_dims = [dims[0], dims[1], 3 * dims[2] * dims[3]] + + return B, M, H, K, merged_dims + + +def naive_embedding_tensor(x, cos, sin, x_layout, B, M, H, K, dims, merged_dims, rotary_size, rotary_ndims, mode): + naive_out = None + if mode == "plane": + if rotary_ndims == 2: + y1 = plane_shuffle_tensor(x[..., : rotary_size // 2]) + y2 = plane_shuffle_tensor(x[..., rotary_size // 2 : rotary_size]) + y3 = x[..., rotary_size:] + y = flow.cat((y1, y2, y3), dim=-1) + else: + y1 = plane_shuffle_tensor(x[..., :rotary_size]) + y2 = x[..., rotary_size:] + y = flow.cat((y1, y2), dim=-1) + else: + y = shuffle_adjacent_two_elem_tensor(x) + + if x_layout == "BHMK": + naive_out = x * cos + y * sin + elif x_layout == "BMHK": + naive_out = x.reshape(dims) * cos.reshape([B, M, 1, K]) + y.reshape( + dims + ) * sin.reshape( + [B, M, 1, K] + ) # un-merge + elif x_layout == "MBHK" or x_layout == "MB(HK)": + naive_out = x.reshape(dims) * cos.transpose([2, 0, 1, 3]).reshape( + [M, B, 1, K] + ) + y.reshape(dims) * sin.transpose([2, 0, 1, 3]).reshape( + [M, B, 1, K] + ) # un-merge + elif x_layout == "BM(HK)": + naive_out = x.reshape(dims) * cos.reshape([B, M, 1, K]) + y.reshape( + dims + ) * sin.reshape( + [B, M, 1, K] + ) # un-merge + elif x_layout == "BM(H3K)": + out0 = x[..., 0, :].reshape(dims) * cos.reshape([B, M, 1, K]) + y[ + ..., 0, : + ].reshape(dims) * sin.reshape([B, M, 1, K]) + out1 = x[..., 1, :].reshape(dims) * cos.reshape([B, M, 1, K]) + y[ + ..., 1, : + ].reshape(dims) * sin.reshape([B, M, 1, K]) + out2 = x[..., 2, :].reshape(dims) * cos.reshape([B, M, 1, K]) + y[ + ..., 2, : + ].reshape(dims) * sin.reshape([B, M, 1, K]) + + naive_out = flow.cat((out0, out1, out2), axis=-1) + elif x_layout == "MB(H3K)": + out0 = x[..., 0, :].reshape(dims) * cos.transpose([2, 0, 1, 3]).reshape( + [M, B, 1, K] + ) + y[..., 0, :].reshape(dims) * sin.transpose([2, 0, 1, 3]).reshape( + [M, B, 1, K] + ) + out1 = x[..., 1, :].reshape(dims) * cos.transpose([2, 0, 1, 3]).reshape( + [M, B, 1, K] + ) + y[..., 1, :].reshape(dims) * sin.transpose([2, 0, 1, 3]).reshape( + [M, B, 1, K] + ) + out2 = x[..., 2, :].reshape(dims) * cos.transpose([2, 0, 1, 3]).reshape( + [M, B, 1, K] + ) + y[..., 2, :].reshape(dims) * sin.transpose([2, 0, 1, 3]).reshape( + [M, B, 1, K] + ) + + naive_out = flow.cat((out0, out1, out2), axis=-1) + + return naive_out + + +# numpy version: def plane_shuffle(x): x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] return np.concatenate((-x2, x1), axis=-1) @@ -571,7 +679,7 @@ def _test_with_position_sinuous( naive_x = x.reshape([B, M, H, -1, K]) elif x_layout == "MB(HK)" or x_layout == "MB(H2K)" or x_layout == "MB(H3K)": naive_x = x.reshape([M, B, H, -1, K]) - + naive_out = naive_embedding( naive_x, naive_cos, @@ -588,6 +696,41 @@ def _test_with_position_sinuous( mode, ) + naive_x_tensor = flow.tensor(naive_x, dtype=dtype, device=device, requires_grad=True) + naive_cos_tensor = flow.tensor(naive_cos, dtype=dtype, device=device) # 不用grad + naive_sin_tensor = flow.tensor(naive_sin, dtype=dtype, device=device) + naive_out_tensor = naive_embedding_tensor( + naive_x_tensor, + naive_cos_tensor, + naive_sin_tensor, + x_layout, + B, + M, + H, + K, + dims, + merged_dims, + rotary_size, + rotary_ndims, + mode, + ) + #print('naive_out_tensor') + #print(naive_out_tensor) # 验证这里的naive_out_tensor与naive_out是否有精度误差; + # get grad + test_case.assertTrue( + np.allclose( + naive_out.reshape(merged_dims), + naive_out_tensor.numpy().reshape(merged_dims), + atol=5e-2, + rtol=5e-3, + ) + ) + + naive_out_backward = 
naive_out_tensor.sum() + naive_out_backward.backward() + naive_out_grad = naive_x_tensor.grad + #print(naive_out_grad) + fused_cos = np.array( [ [ @@ -673,7 +816,8 @@ def _test_with_position_sinuous( tensor_index=2, ) - fused_out = np.concatenate((out0, out1, out2), axis=-1) + #fused_out = np.concatenate((out0, out1, out2), axis=-1) + fused_out = flow.cat((out0, out1, out2), dim = -1) else: fused_out = flow._C.fused_apply_rotary_emb( fused_x, @@ -685,12 +829,27 @@ def _test_with_position_sinuous( base=base, rotary_size=rotary_size, mode=mode, - ).numpy() + ) + fused_out_backward = fused_out.sum() + fused_out_backward.backward() + fused_out_grad = fused_out_backward.grad + #print(fused_out_grad) + # test forward test_case.assertTrue( np.allclose( naive_out.reshape(merged_dims), - fused_out.reshape(merged_dims), + fused_out.numpy().reshape(merged_dims), + atol=5e-2, + rtol=5e-3, + ) + ) + + # test backward + test_case.assertTrue( + np.allclose( + naive_out_grad.numpy().reshape(merged_dims), + fused_out_grad.numpy().reshape(merged_dims), atol=5e-2, rtol=5e-3, ) From 897d772da06985b88e2fbbb6e897cc14bb5170f1 Mon Sep 17 00:00:00 2001 From: shw Date: Sun, 28 Apr 2024 16:29:20 +0800 Subject: [PATCH 04/14] rope grad kernel start --- .../gradient_funcs/fused_apply_rotary_emb.cpp | 136 +++++++++ .../impl/fused_attention_functor.cpp | 98 ++++++- .../user/kernels/fused_attention_kernels.cu | 41 +++ oneflow/user/ops/fused_attention_ops.cpp | 32 ++- .../modules/test_fused_rotary_embedding.py | 272 +++++++++++++++--- 5 files changed, 522 insertions(+), 57 deletions(-) create mode 100644 oneflow/core/autograd/gradient_funcs/fused_apply_rotary_emb.cpp diff --git a/oneflow/core/autograd/gradient_funcs/fused_apply_rotary_emb.cpp b/oneflow/core/autograd/gradient_funcs/fused_apply_rotary_emb.cpp new file mode 100644 index 00000000000..2afc8f3d6d0 --- /dev/null +++ b/oneflow/core/autograd/gradient_funcs/fused_apply_rotary_emb.cpp @@ -0,0 +1,136 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#include "oneflow/core/framework/op_expr_grad_function.h" +#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" +#include "oneflow/core/functional/functional.h" + +namespace oneflow { +namespace one { + +struct FusedApplyRotaryEmbCaptureState : public AutoGradCaptureState { + bool requires_grad; // 输入x是否需要梯度 只有一个输入x; + std::string x_layout{}; + std::string output_layout{}; + std::string mode{}; + int64_t tensor_index{}; + int64_t k_size{}; + float base; + int64_t rotary_size{}; +}; + +class FusedApplyRotaryEmb : public OpExprGradFunction { + public: + Maybe Init(const OpExpr& op) override; + Maybe Capture(FusedApplyRotaryEmbCaptureState* ctx, const TensorTuple& inputs, + const TensorTuple& outputs, const AttrMap& attrs) const override; + Maybe Apply(const FusedApplyRotaryEmbCaptureState* ctx, const TensorTuple& out_grads, + TensorTuple* in_grads) const override; + + private: + AttrMap base_attrs_; +}; + +Maybe FusedApplyRotaryEmb::Init(const OpExpr& op) { // 是否需要实现存疑; + const UserOpExpr* fw_op_expr = dynamic_cast(&op); + CHECK_NOTNULL_OR_RETURN(fw_op_expr); + base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); + return Maybe::Ok(); +} + +Maybe FusedApplyRotaryEmb::Capture(FusedApplyRotaryEmbCaptureState* ctx, + const TensorTuple& inputs, const TensorTuple& outputs, + const AttrMap& attrs) const { + // 这里需要检查sin和cos同时出现,或同时不出现; + CHECK_OR_RETURN((inputs.size() >= 1) && (inputs.size() <= 4)); // 这里的输入应该是 1 - 4; + ctx->requires_grad = inputs.at(0)->requires_grad(); + if (!ctx->requires_grad) { return Maybe::Ok(); } // 如果不需要梯度,也就不需要求导了,直接返回 Maybe::Ok() + + ComposedAttrMap composed_attrs(attrs, base_attrs_); // 写法确认; + ctx->SaveTensorForBackward(inputs.at(0)); + if (inputs.size() == 2) // position_ids + ctx->SaveTensorForBackward(inputs.at(1)); + if (inputs.size() == 3) { // cos, sin + ctx->SaveTensorForBackward(inputs.at(1)); + ctx->SaveTensorForBackward(inputs.at(2)); + } + + if (inputs.size() == 4) { // cos, sin, position_ids; + ctx->SaveTensorForBackward(inputs.at(1)); + ctx->SaveTensorForBackward(inputs.at(2)); + ctx->SaveTensorForBackward(inputs.at(3)); + } + + ctx->x_layout = JUST(composed_attrs.GetAttr("x_layout")); + ctx->output_layout = JUST(composed_attrs.GetAttr("output_layout")); + ctx->mode = JUST(composed_attrs.GetAttr("mode")); + ctx->tensor_index = JUST(composed_attrs.GetAttr("tensor_index")); + ctx->k_size = JUST(composed_attrs.GetAttr("k_size")); + ctx->base = JUST(composed_attrs.GetAttr("base")); + ctx->rotary_size = JUST(composed_attrs.GetAttr("rotary_size")); + + return Maybe::Ok(); +} + +Maybe FusedApplyRotaryEmb::Apply(const FusedApplyRotaryEmbCaptureState* ctx, + const TensorTuple& out_grads, TensorTuple* in_grads) const { + CHECK_EQ_OR_RETURN(out_grads.size(), 1); // 检查梯度 Tensor 个数是否为 1 TODO: 不确定是否输入为1 -- (dy) + // in_grads->resize(1); // 这里不能resize, 需要和input_meta_data_.size() 一致; + const auto& saved_tensors = ctx->SavedTensors(); + + CHECK_OR_RETURN((saved_tensors.size() >= 1) && (saved_tensors.size() <= 4)); + // 输出backward拿到的参数 + if (ctx->requires_grad) { + if (saved_tensors.size() == 1) { + const auto& x = ctx->SavedTensors().at(0); + in_grads->at(0) = JUST(functional::FusedApplyRotaryEmbGrad(x, out_grads.at(0), NullOpt/*cos*/, + NullOpt/*sin*/, NullOpt/*position_ids*/, + ctx->x_layout, ctx->output_layout, ctx->mode, ctx->tensor_index, ctx->k_size, ctx->base, ctx->rotary_size)); + } + + if (saved_tensors.size() == 2) { + const auto& x = ctx->SavedTensors().at(0); + const auto& position_ids = ctx->SavedTensors().at(1); + 
in_grads->at(0) = JUST(functional::FusedApplyRotaryEmbGrad(x, out_grads.at(0), NullOpt/*cos*/, + NullOpt/*sin*/, position_ids, + ctx->x_layout, ctx->output_layout, ctx->mode, ctx->tensor_index, ctx->k_size, ctx->base, ctx->rotary_size)); + } + + if (saved_tensors.size() == 3) { + const auto& x = ctx->SavedTensors().at(0); + const auto& cos = ctx->SavedTensors().at(1); + const auto& sin = ctx->SavedTensors().at(2); + + in_grads->at(0) = JUST(functional::FusedApplyRotaryEmbGrad(x, out_grads.at(0), cos, sin, NullOpt/*position_ids*/, + ctx->x_layout, ctx->output_layout, ctx->mode, ctx->tensor_index, ctx->k_size, ctx->base, ctx->rotary_size)); + } + + if (saved_tensors.size() == 4) { + const auto& x = ctx->SavedTensors().at(0); // 调用 SavedTensors 接口,拿到先前通过 SaveTensorForBackward 接口保存的 Tensor + const auto& cos = ctx->SavedTensors().at(1); + const auto& sin = ctx->SavedTensors().at(2); + const auto& position_ids = ctx->SavedTensors().at(3); + in_grads->at(0) = JUST(functional::FusedApplyRotaryEmbGrad(x, out_grads.at(0), cos, sin, position_ids, + ctx->x_layout, ctx->output_layout, ctx->mode, ctx->tensor_index, ctx->k_size, ctx->base, ctx->rotary_size)); + } + } + + return Maybe::Ok(); +} + +REGISTER_OP_EXPR_GRAD_FUNCTION("fused_apply_rotary_emb", FusedApplyRotaryEmb); + +} // namespace one +} // namespace oneflow diff --git a/oneflow/core/functional/impl/fused_attention_functor.cpp b/oneflow/core/functional/impl/fused_attention_functor.cpp index 4c7c9dd97f1..7caf7606137 100644 --- a/oneflow/core/functional/impl/fused_attention_functor.cpp +++ b/oneflow/core/functional/impl/fused_attention_functor.cpp @@ -644,7 +644,6 @@ class FusedApplyRotaryEmbFunctor { const Optional& tensor_index, const Optional& k_size, const float base, const Optional& rotary_size) const { int64_t b = 0, m = 0, h = 0, k = 0; - // std::cout << "go here!" << std::endl; if (tensor_index) { CHECK_OR_RETURN((JUST(tensor_index) >= 0) && (JUST(tensor_index) <= 2)) @@ -654,8 +653,7 @@ class FusedApplyRotaryEmbFunctor { << "mode should be \"intervel\" or \"plane\""; ParseDims("x", *x->shape(), x_layout, Optional(), k_size, &b, &m, &h, &k); - // std::cout << "cpp head_size: " << h << std::endl; - // printf("k_size: %ld\n", k_size); + if (k_size) { CHECK_EQ_OR_RETURN(JUST(k_size), k) << "k_size if given should be equal to K of cos, sin and x."; @@ -711,7 +709,6 @@ class FusedApplyRotaryEmbFunctor { attrs.SetAllAttrs(x_layout, output_layout.value_or(x_layout), mode, tensor_index.value_or(0), k_size.value_or(k), base, rotary_size.value_or(k)); - // std::cout << "dispatch!" 
<< std::endl; if (position_ids) { if (cos && sin) { return OpInterpUtil::Dispatch(*op_with_position_sinuous_, @@ -745,25 +742,25 @@ class FusedApplyRotaryEmbGradFunctor { .Input("cos") .Input("sin") .Input("position_ids") - .Output("out") + .Output("dx") .Build()); op_with_position_ = CHECK_JUST(one::OpBuilder("fused_apply_rotary_emb_grad") .Input("x") .Input("dy") .Input("position_ids") - .Output("out") + .Output("dx") .Build()); op_without_position_ = CHECK_JUST(one::OpBuilder("fused_apply_rotary_emb_grad") .Input("x") .Input("dy") .Input("cos") .Input("sin") - .Output("out") + .Output("dx") .Build()); op_without_position_sinuous_ = CHECK_JUST(one::OpBuilder("fused_apply_rotary_emb_grad").Input("x") .Input("dy") - .Output("out") + .Output("dx") .Build()); } Maybe operator()(const std::shared_ptr& x, @@ -774,8 +771,89 @@ class FusedApplyRotaryEmbGradFunctor { const Optional& output_layout, const std::string& mode, const Optional& tensor_index, const Optional& k_size, const float base, const Optional& rotary_size) const { - std::cout << "FusedApplyRotaryEmbGradFunctor" << std::endl; - return dy; + int64_t b = 0, m = 0, h = 0, k = 0; + + if (tensor_index) { + CHECK_OR_RETURN((JUST(tensor_index) >= 0) && (JUST(tensor_index) <= 2)) + << "tensor_index should be set between [0, 2]"; + } + CHECK_OR_RETURN((mode == "interval") || (mode == "plane")) + << "mode should be \"intervel\" or \"plane\""; + + ParseDims("x", *x->shape(), x_layout, Optional(), k_size, &b, &m, &h, &k); + + if (k_size) { + CHECK_EQ_OR_RETURN(JUST(k_size), k) + << "k_size if given should be equal to K of cos, sin and x."; + } + if (rotary_size) { + CHECK_LE_OR_RETURN(JUST(rotary_size), k) << "rotary_size should be no more than k."; + } + + int64_t rotary_emd_dim = 1; + + if (position_ids) { + CHECK_EQ_OR_RETURN(JUST(position_ids)->shape()->NumAxes(), 3) + << "ndims of position_ids should be equal to 3, either in form of B1M or B2M."; + CHECK_EQ_OR_RETURN(JUST(position_ids)->shape()->At(0), b) + << "1st dim of position_ids should be equal to B."; + CHECK_EQ_OR_RETURN(JUST(position_ids)->shape()->At(2), m) + << "3rd dim of position_ids should be equal to M."; + rotary_emd_dim = JUST(position_ids)->shape()->At(1); + CHECK_OR_RETURN(rotary_emd_dim == 1 || rotary_emd_dim == 2) + << "2nd dim of position_ids should be 1 or 2."; + } + + const int64_t actual_rotary_size = rotary_size.value_or(k) / rotary_emd_dim; + CHECK_EQ_OR_RETURN(actual_rotary_size % 2, 0) + << "k ,or rotary_size if given, should be a multiple of 2 * rotary_encoding_dim."; + + if (cos && sin) { + CHECK_EQ_OR_RETURN(JUST(cos)->shape()->NumAxes(), 2) + << "The number of dimensions of cos should be equal to 2."; + CHECK_OR_RETURN(JUST(cos)->shape() == JUST(sin)->shape()) + << "Each dimension of cos & sin should be the same."; + CHECK_EQ_OR_RETURN(JUST(cos)->shape()->At(1), actual_rotary_size) + << "The 1st dimension of cos & sin should equal to rotary_size // " + "rotary_embedding_dimension."; + } else if (!cos && !sin) { + // do nothing + } else { + UNIMPLEMENTED_THEN_RETURN() << "cos & sin should both be given or not given."; + } + + if (!position_ids) { + if (cos && sin) { + CHECK_GE_OR_RETURN(JUST(cos)->shape()->At(0), m) + << "M of cos & sin should be to no less than " + "M of x when position_ids is not " + "given."; // K of cos & sin is checked + // inside ParseDims + } + } + + auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("x_layout", "output_layout", "mode", + "tensor_index", "k_size", "base", "rotary_size"); + attrs.SetAllAttrs(x_layout, 
output_layout.value_or(x_layout), mode, tensor_index.value_or(0), + k_size.value_or(k), base, rotary_size.value_or(k)); + + if (position_ids) { + if (cos && sin) { + return OpInterpUtil::Dispatch(*op_with_position_sinuous_, + {x, dy, JUST(cos), JUST(sin), JUST(position_ids)}, attrs); + } else { + return OpInterpUtil::Dispatch(*op_with_position_, {x, dy, JUST(position_ids)}, attrs); + } + } else { + if (cos && sin) { + return OpInterpUtil::Dispatch(*op_without_position_, {x, dy, JUST(cos), JUST(sin)}, + attrs); + } else { + return OpInterpUtil::Dispatch(*op_without_position_sinuous_, {x, dy}, attrs); + } + } + + return dy; } private: diff --git a/oneflow/user/kernels/fused_attention_kernels.cu b/oneflow/user/kernels/fused_attention_kernels.cu index d432a1f30a8..dc7d1e8d303 100644 --- a/oneflow/user/kernels/fused_attention_kernels.cu +++ b/oneflow/user/kernels/fused_attention_kernels.cu @@ -1405,6 +1405,22 @@ class FusedApplyRotaryEmbKernel final : public user_op::OpKernel { bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; +template +class FusedApplyRotaryEmbGradKernel final : public user_op::OpKernel { + public: + FusedApplyRotaryEmbGradKernel() = default; + ~FusedApplyRotaryEmbGradKernel() override = default; + + private: + using user_op::OpKernel::Compute; + void Compute(user_op::KernelComputeContext* ctx) const override { + printf("%s\n", "FusedApplyRotaryEmbGradKernel"); + } + + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +}; + + #define REGISTER_FUSED_APPLY_ROTARY_EMB_GPU(dtype, position_type) \ REGISTER_USER_KERNEL("fused_apply_rotary_emb") \ .SetCreateFn>() \ @@ -1429,6 +1445,31 @@ REGISTER_FUSED_APPLY_ROTARY_EMB_GPU_DTYPE(half); REGISTER_FUSED_APPLY_ROTARY_EMB_GPU_DTYPE(nv_bfloat16); #endif // CUDA_VERSION >= 11000 + +#define REGISTER_FUSED_APPLY_ROTARY_EMB_GRAD_GPU(dtype, position_type) \ + REGISTER_USER_KERNEL("fused_apply_rotary_emb_grad") \ + .SetCreateFn>() \ + .SetIsMatchedHob( \ + (user_op::HobDeviceType() == DeviceType::kCUDA) \ + && (user_op::HobDataType("dx", 0) == GetDataType::value) \ + && (user_op::HobInputSize("position_ids") == 1) \ + && (user_op::HobDataType("position_ids", 0) == GetDataType::value)); + +#define REGISTER_FUSED_APPLY_ROTARY_EMB_GRAD_GPU_DTYPE(dtype) \ + REGISTER_FUSED_APPLY_ROTARY_EMB_GRAD_GPU(dtype, int64_t); \ + REGISTER_FUSED_APPLY_ROTARY_EMB_GRAD_GPU(dtype, int32_t); \ + REGISTER_USER_KERNEL("fused_apply_rotary_emb_grad") \ + .SetCreateFn>() \ + .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ + && (user_op::HobDataType("dx", 0) == GetDataType::value) \ + && (user_op::HobInputSize("position_ids") == 0)); + +REGISTER_FUSED_APPLY_ROTARY_EMB_GRAD_GPU_DTYPE(float); +REGISTER_FUSED_APPLY_ROTARY_EMB_GRAD_GPU_DTYPE(half); +#if CUDA_VERSION >= 11000 +REGISTER_FUSED_APPLY_ROTARY_EMB_GRAD_GPU_DTYPE(nv_bfloat16); +#endif // CUDA_VERSION >= 11000 + } // namespace } // namespace user_op diff --git a/oneflow/user/ops/fused_attention_ops.cpp b/oneflow/user/ops/fused_attention_ops.cpp index d02cf2e1950..a1b65464756 100644 --- a/oneflow/user/ops/fused_attention_ops.cpp +++ b/oneflow/user/ops/fused_attention_ops.cpp @@ -815,6 +815,8 @@ Maybe ParseSplitAxis(const std::string& layout, bool can_hk_split, int64_t const int64_t k_size = ctx->Attr("k_size"); const int64_t tensor_index = ctx->Attr("tensor_index"); + auto Inputs = ctx->inputs(); + CHECK_OR_RETURN((tensor_index >= 0) && (tensor_index <= 2)) << "tensor_index should be in range [0, 2]."; CHECK_OR_RETURN((mode == "interval") || (mode == 
"plane")) @@ -833,18 +835,20 @@ Maybe ParseSplitAxis(const std::string& layout, bool can_hk_split, int64_t int64_t rotary_emb_dim = 1; + if (ctx->has_input("position_ids", 0)) { - const user_op::TensorDesc& position_ids_desc = ctx->InputTensorDesc("position_ids", 0); - CHECK_EQ_OR_RETURN(position_ids_desc.shape().NumAxes(), 3) + const Shape& position_id_shape = ctx->InputShape("position_ids", 0); + CHECK_EQ_OR_RETURN(position_id_shape.NumAxes(), 3) << "ndims of position_ids should be equal to 3, either in form of B1M or B2M."; - CHECK_EQ_OR_RETURN(position_ids_desc.shape().At(0), b) + CHECK_EQ_OR_RETURN(position_id_shape.At(0), b) << "1st dim of position_ids should be equal to B."; - CHECK_EQ_OR_RETURN(position_ids_desc.shape().At(2), m) + CHECK_EQ_OR_RETURN(position_id_shape.At(2), m) << "3rd dim of position_ids should be equal to M."; - rotary_emb_dim = position_ids_desc.shape().At(1); + rotary_emb_dim = position_id_shape.At(1); CHECK_OR_RETURN(rotary_emb_dim == 1 || rotary_emb_dim == 2) << "2nd dim of position_ids should be 1 or 2."; } + // 这里是重复检查,且会报错; const int64_t actual_rotary_size = rotary_size / rotary_emb_dim; CHECK_EQ_OR_RETURN(actual_rotary_size % 2, 0) @@ -854,13 +858,14 @@ Maybe ParseSplitAxis(const std::string& layout, bool can_hk_split, int64_t bool has_sin = ctx->has_input("sin", 0); // TODO: fused_apply_rotary_emb have same logic no matter name if (has_cos && has_sin) { - const user_op::TensorDesc& cos_desc = ctx->InputTensorDesc("cos", 0); - const user_op::TensorDesc& sin_desc = ctx->InputTensorDesc("sin", 0); - CHECK_EQ_OR_RETURN(cos_desc.shape().NumAxes(), 2) + const Shape& cos_shape = ctx->InputShape("cos", 0); + const Shape& sin_shape = ctx->InputShape("sin", 0); + CHECK_EQ_OR_RETURN(cos_shape.NumAxes(), 2) << "The number of dimensions of cos should be equal to 2."; - CHECK_OR_RETURN(cos_desc.shape() == sin_desc.shape()) + + CHECK_OR_RETURN(cos_shape == sin_shape) << "The dimensions of cos & sin should be the same."; - CHECK_EQ_OR_RETURN(cos_desc.shape().At(1), actual_rotary_size) + CHECK_EQ_OR_RETURN(cos_shape.At(1), actual_rotary_size) << "The 1st dimension of cos & sin should equal to rotary_size // " "rotary_embedding_dimension."; } else if (!has_cos && !has_sin) { @@ -871,8 +876,9 @@ Maybe ParseSplitAxis(const std::string& layout, bool can_hk_split, int64_t if (!ctx->has_input("position_ids", 0)) { if (has_cos && has_sin) { - const user_op::TensorDesc& cos_desc = ctx->InputTensorDesc("cos", 0); - CHECK_GE_OR_RETURN(cos_desc.shape().At(0), m) + // const user_op::TensorDesc& cos_desc = ctx->InputTensorDesc("cos", 0); + const Shape& cos_shape = ctx->InputShape("cos", 0); + CHECK_GE_OR_RETURN(cos_shape.At(0), m) << "M of cos should be no less than M of x if position_ids is not given."; // K of cos & sin is checked inside ParseDims } @@ -969,7 +975,7 @@ Maybe ParseSplitAxis(const std::string& layout, bool can_hk_split, int64_t } /* static */ Maybe FusedApplyRotaryEmbGradOp::InferDataType(user_op::InferContext* ctx) { - const user_op::TensorDesc& first_in_desc = ctx->InputTensorDesc("dy", 0); + const user_op::TensorDesc& first_in_desc = ctx->InputTensorDesc("x", 0); bool has_sinuous = ctx->has_input("cos", 0); diff --git a/python/oneflow/test/modules/test_fused_rotary_embedding.py b/python/oneflow/test/modules/test_fused_rotary_embedding.py index f44e28d43bd..7b9431b9f5b 100644 --- a/python/oneflow/test/modules/test_fused_rotary_embedding.py +++ b/python/oneflow/test/modules/test_fused_rotary_embedding.py @@ -85,9 +85,9 @@ def naive_embedding_tensor(x, cos, sin, 
x_layout, B, M, H, K, dims, merged_dims, [B, M, 1, K] ) # un-merge elif x_layout == "MBHK" or x_layout == "MB(HK)": - naive_out = x.reshape(dims) * cos.transpose([2, 0, 1, 3]).reshape( + naive_out = x.reshape(dims) * cos.permute([2, 0, 1, 3]).reshape( [M, B, 1, K] - ) + y.reshape(dims) * sin.transpose([2, 0, 1, 3]).reshape( + ) + y.reshape(dims) * sin.permute([2, 0, 1, 3]).reshape( [M, B, 1, K] ) # un-merge elif x_layout == "BM(HK)": @@ -109,23 +109,25 @@ def naive_embedding_tensor(x, cos, sin, x_layout, B, M, H, K, dims, merged_dims, naive_out = flow.cat((out0, out1, out2), axis=-1) elif x_layout == "MB(H3K)": - out0 = x[..., 0, :].reshape(dims) * cos.transpose([2, 0, 1, 3]).reshape( + #test = cos.transpose([2, 0, 1, 3]).reshape( [M, B, 1, K] ) + #test = cos.permute([2, 0, 1, 3]) + out0 = x[..., 0, :].reshape(dims) * cos.permute([2, 0, 1, 3]).reshape( [M, B, 1, K] - ) + y[..., 0, :].reshape(dims) * sin.transpose([2, 0, 1, 3]).reshape( + ) + y[..., 0, :].reshape(dims) * sin.permute([2, 0, 1, 3]).reshape( [M, B, 1, K] ) - out1 = x[..., 1, :].reshape(dims) * cos.transpose([2, 0, 1, 3]).reshape( + out1 = x[..., 1, :].reshape(dims) * cos.permute([2, 0, 1, 3]).reshape( [M, B, 1, K] - ) + y[..., 1, :].reshape(dims) * sin.transpose([2, 0, 1, 3]).reshape( + ) + y[..., 1, :].reshape(dims) * sin.permute([2, 0, 1, 3]).reshape( [M, B, 1, K] ) - out2 = x[..., 2, :].reshape(dims) * cos.transpose([2, 0, 1, 3]).reshape( + out2 = x[..., 2, :].reshape(dims) * cos.permute([2, 0, 1, 3]).reshape( [M, B, 1, K] - ) + y[..., 2, :].reshape(dims) * sin.transpose([2, 0, 1, 3]).reshape( + ) + y[..., 2, :].reshape(dims) * sin.permute([2, 0, 1, 3]).reshape( [M, B, 1, K] ) - naive_out = flow.cat((out0, out1, out2), axis=-1) + naive_out = flow.cat((out0, out1, out2), dim=-1) return naive_out @@ -359,6 +361,39 @@ def _test_without_position( mode, ) + naive_x_tensor = flow.tensor(naive_x, dtype=dtype, device=device, requires_grad=True) + naive_cos_tensor = flow.tensor(naive_cos, dtype=dtype, device=device) # 不用grad + naive_sin_tensor = flow.tensor(naive_sin, dtype=dtype, device=device) + naive_out_tensor = naive_embedding_tensor( + naive_x_tensor, + naive_cos_tensor, + naive_sin_tensor, + x_layout, + B, + M, + H, + K, + dims, + merged_dims, + rotary_size, + rotary_ndims, + mode, + ) + + # 验证这里的naive_out_tensor与naive_out是否有精度误差; + test_case.assertTrue( + np.allclose( + naive_out.reshape(merged_dims), + naive_out_tensor.numpy().reshape(merged_dims), + atol=5e-2, + rtol=5e-3, + ) + ) + + naive_out_backward = naive_out_tensor.sum() + naive_out_backward.backward() + naive_out_grad = naive_x_tensor.grad + fused_cos = np.array( [ [ @@ -442,7 +477,8 @@ def _test_without_position( tensor_index=2, ) - fused_out = np.concatenate((out0, out1, out2), axis=-1) + # fused_out = np.concatenate((out0, out1, out2), axis=-1) + fused_out = flow.cat((out0, out1, out2), dim = -1) else: fused_out = flow._C.fused_apply_rotary_emb( fused_x, @@ -454,12 +490,27 @@ def _test_without_position( base=base, rotary_size=rotary_size, mode=mode, - ).numpy() - + ) + + fused_out_backward = fused_out.sum() + fused_out_backward.backward() + fused_out_grad = fused_out_backward.grad + #print(fused_out_grad) + # test forward test_case.assertTrue( np.allclose( naive_out.reshape(merged_dims), - fused_out.reshape(merged_dims), + fused_out.numpy().reshape(merged_dims), + atol=5e-2, + rtol=5e-3, + ) + ) + + # test backward + test_case.assertTrue( + np.allclose( + naive_out_grad.numpy().reshape(merged_dims), + fused_out_grad.numpy().reshape(merged_dims), atol=5e-2, 
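+                # loose tolerances: the tensor and numpy references only need to
+                # agree up to accumulated floating-point rounding error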
rtol=5e-3, ) @@ -542,9 +593,41 @@ def _test_without_position_sinuous( rotary_ndims, mode, ) + naive_x_tensor = flow.tensor(naive_x, dtype=dtype, device=device, requires_grad=True) + naive_cos_tensor = flow.tensor(naive_cos, dtype=dtype, device=device) # 不用grad + naive_sin_tensor = flow.tensor(naive_sin, dtype=dtype, device=device) + naive_out_tensor = naive_embedding_tensor( + naive_x_tensor, + naive_cos_tensor, + naive_sin_tensor, + x_layout, + B, + M, + H, + K, + dims, + merged_dims, + rotary_size, + rotary_ndims, + mode, + ) - fused_x = flow.tensor(x, dtype=dtype, device=device, requires_grad=True) + # 验证这里的naive_out_tensor与naive_out是否有精度误差; + test_case.assertTrue( + np.allclose( + naive_out.reshape(merged_dims), + naive_out_tensor.numpy().reshape(merged_dims), + atol=5e-2, + rtol=5e-3, + ) + ) + + naive_out_backward = naive_out_tensor.sum() + naive_out_backward.backward() + naive_out_grad = naive_x_tensor.grad + fused_x = flow.tensor(x, dtype=dtype, device=device, requires_grad=True) + if x_layout == "BM(H3K)": out0 = flow._C.fused_apply_rotary_emb( fused_x, @@ -586,7 +669,8 @@ def _test_without_position_sinuous( tensor_index=2, ) - fused_out = np.concatenate((out0, out1, out2), axis=-1) + #fused_out = np.concatenate((out0, out1, out2), axis=-1) + fused_out = flow.cat((out0, out1, out2), dim = -1) else: fused_out = flow._C.fused_apply_rotary_emb( fused_x, @@ -598,12 +682,27 @@ def _test_without_position_sinuous( base=base, rotary_size=rotary_size, mode=mode, - ).numpy() - + ) + + fused_out_backward = fused_out.sum() + fused_out_backward.backward() + fused_out_grad = fused_out_backward.grad + #print(fused_out_grad) + # test forward test_case.assertTrue( np.allclose( naive_out.reshape(merged_dims), - fused_out.reshape(merged_dims), + fused_out.numpy().reshape(merged_dims), + atol=5e-2, + rtol=5e-3, + ) + ) + + # test backward + test_case.assertTrue( + np.allclose( + naive_out_grad.numpy().reshape(merged_dims), + fused_out_grad.numpy().reshape(merged_dims), atol=5e-2, rtol=5e-3, ) @@ -714,9 +813,8 @@ def _test_with_position_sinuous( rotary_ndims, mode, ) - #print('naive_out_tensor') - #print(naive_out_tensor) # 验证这里的naive_out_tensor与naive_out是否有精度误差; - # get grad + + # 验证这里的naive_out_tensor与naive_out是否有精度误差; test_case.assertTrue( np.allclose( naive_out.reshape(merged_dims), @@ -937,6 +1035,39 @@ def _test_with_position( mode, ) + naive_x_tensor = flow.tensor(naive_x, dtype=dtype, device=device, requires_grad=True) + naive_cos_tensor = flow.tensor(naive_cos, dtype=dtype, device=device) # 不用grad + naive_sin_tensor = flow.tensor(naive_sin, dtype=dtype, device=device) + naive_out_tensor = naive_embedding_tensor( + naive_x_tensor, + naive_cos_tensor, + naive_sin_tensor, + x_layout, + B, + M, + H, + K, + dims, + merged_dims, + rotary_size, + rotary_ndims, + mode, + ) + + # 验证这里的naive_out_tensor与naive_out是否有精度误差; + test_case.assertTrue( + np.allclose( + naive_out.reshape(merged_dims), + naive_out_tensor.numpy().reshape(merged_dims), + atol=5e-2, + rtol=5e-3, + ) + ) + + naive_out_backward = naive_out_tensor.sum() + naive_out_backward.backward() + naive_out_grad = naive_x_tensor.grad + fused_x = flow.tensor(x, dtype=dtype, device=device, requires_grad=True) fused_position_ids = flow.tensor(position_ids, dtype=flow.int32, device=device) @@ -981,7 +1112,8 @@ def _test_with_position( tensor_index=2, ) - fused_out = np.concatenate((out0, out1, out2), axis=-1) + #fused_out = np.concatenate((out0, out1, out2), axis=-1) + fused_out = flow.cat((out0, out1, out2), dim=-1) else: fused_out = 
flow._C.fused_apply_rotary_emb( fused_x, @@ -993,12 +1125,27 @@ def _test_with_position( base=base, rotary_size=rotary_size, mode=mode, - ).numpy() + ) + fused_out_backward = fused_out.sum() + fused_out_backward.backward() + fused_out_grad = fused_out_backward.grad + #print(fused_out_grad) + # test forward test_case.assertTrue( np.allclose( naive_out.reshape(merged_dims), - fused_out.reshape(merged_dims), + fused_out.numpy().reshape(merged_dims), + atol=5e-2, + rtol=5e-3, + ) + ) + + # test backward + test_case.assertTrue( + np.allclose( + naive_out_grad.numpy().reshape(merged_dims), + fused_out_grad.numpy().reshape(merged_dims), atol=5e-2, rtol=5e-3, ) @@ -1094,9 +1241,43 @@ def _test_plane( mode, ) + naive_x_tensor = flow.tensor(naive_x, dtype=dtype, device=device, requires_grad=True) + naive_cos_tensor = flow.tensor(naive_cos, dtype=dtype, device=device) # 不用grad + naive_sin_tensor = flow.tensor(naive_sin, dtype=dtype, device=device) + naive_out_tensor = naive_embedding_tensor( + naive_x_tensor, + naive_cos_tensor, + naive_sin_tensor, + x_layout, + B, + M, + H, + K, + dims, + merged_dims, + rotary_size, + rotary_ndims, + mode, + ) + + # 验证这里的naive_out_tensor与naive_out是否有精度误差; + test_case.assertTrue( + np.allclose( + naive_out.reshape(merged_dims), + naive_out_tensor.numpy().reshape(merged_dims), + atol=5e-2, + rtol=5e-3, + ) + ) + + naive_out_backward = naive_out_tensor.sum() + naive_out_backward.backward() + naive_out_grad = naive_x_tensor.grad + fused_x = flow.tensor(x, dtype=dtype, device=device, requires_grad=True) fused_position_ids = flow.tensor(position_ids, dtype=flow.int32, device=device) + if x_layout == "MB(H3K)": out0 = flow._C.fused_apply_rotary_emb( fused_x, @@ -1138,7 +1319,7 @@ def _test_plane( tensor_index=2, ) - fused_out = np.concatenate((out0, out1, out2), axis=-1) + fused_out = flow.cat((out0, out1, out2), dim=-1) else: fused_out = flow._C.fused_apply_rotary_emb( fused_x, @@ -1150,18 +1331,34 @@ def _test_plane( base=base, rotary_size=rotary_size, mode=mode, - ).numpy() + ) + fused_out_backward = fused_out.sum() + fused_out_backward.backward() + fused_out_grad = fused_out_backward.grad + #print(fused_out_grad) + # test forward test_case.assertTrue( np.allclose( naive_out.reshape(merged_dims), - fused_out.reshape(merged_dims), + fused_out.numpy().reshape(merged_dims), + atol=5e-2, + rtol=5e-3, + ) + ) + + # test backward + test_case.assertTrue( + np.allclose( + naive_out_grad.numpy().reshape(merged_dims), + fused_out_grad.numpy().reshape(merged_dims), atol=5e-2, rtol=5e-3, ) ) + """ 1. if cos&sin is given, then base will not be used 2. if cos&sin is not given, then any form of x_layout which cannot infer the dimension of k is not allowed, e.g. 
BM(HK) @@ -1173,10 +1370,12 @@ def _test_plane( @flow.unittest.skip_unless_1n1d() class TestFusedRotaryEmbedding(flow.unittest.TestCase): # because rule no.2, kernels without cos&sin cannot work under specific x_layout + def test_fused_rotary_embedding_op_plane(test_case): args_dict = OrderedDict() args_dict["test_fun"] = [_test_plane] - args_dict["x_layout"] = ["MB(H3K)", "MB(HK)"] + #args_dict["x_layout"] = ["MB(H3K)", "MB(HK)"] + args_dict["x_layout"] = ["BMHK"] args_dict["mode"] = ["plane"] args_dict["base"] = [1e1] args_dict["rotary_size"] = [4, 8] @@ -1189,10 +1388,14 @@ def test_fused_rotary_embedding_op_plane(test_case): for arg in GenArgList(args_dict): arg[0](test_case, *arg[1:]) + + def test_fused_rotary_embedding_op_interval_2d(test_case): args_dict = OrderedDict() - args_dict["test_fun"] = [_test_with_position, _test_with_position_sinuous] + args_dict["test_fun"] = [_test_with_position, + _test_with_position_sinuous + ] args_dict["x_layout"] = ["BMHK"] args_dict["mode"] = ["interval"] args_dict["base"] = [1e1] @@ -1206,17 +1409,18 @@ def test_fused_rotary_embedding_op_interval_2d(test_case): for arg in GenArgList(args_dict): arg[0](test_case, *arg[1:]) + def test_fused_rotary_embedding_op_interval_1d(test_case): args_dict = OrderedDict() args_dict["test_fun"] = [ - _test_without_position_sinuous, + #_test_without_position_sinuous, _test_without_position, - _test_with_position, - _test_with_position_sinuous, + #_test_with_position, + #_test_with_position_sinuous, ] args_dict["x_layout"] = ["BMHK"] - args_dict["mode"] = ["interval"] + args_dict["mode"] = ["plane"] args_dict["base"] = [1e1] args_dict["rotary_size"] = [4] args_dict["dims"] = [(3, 2, 5, 8)] @@ -1228,7 +1432,7 @@ def test_fused_rotary_embedding_op_interval_1d(test_case): for arg in GenArgList(args_dict): arg[0](test_case, *arg[1:]) - + if __name__ == "__main__": unittest.main() From ef72669433ac41ea969f573e4b24b469f44a6e16 Mon Sep 17 00:00:00 2001 From: shw Date: Tue, 30 Apr 2024 20:34:40 +0800 Subject: [PATCH 05/14] plane grad done! 
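
The plane-mode backward reuses the forward rotation. A minimal numpy sketch
(helper name and shapes made up here) of the identity PlaneGradKernel relies
on: plane-mode RoPE applies an orthogonal 2D rotation to each (i, i + d/2)
pair, so the gradient is the same rotation with the angle negated,
dx = R(theta)^T dy = R(-theta) dy:

    import numpy as np

    def rope_plane(x, theta, sign=1.0):
        # plane ("NeoX") pairing: element i rotates with element i + d/2
        d = x.shape[-1]
        x1, x2 = x[..., : d // 2], x[..., d // 2 :]
        cos, sin = np.cos(theta), sign * np.sin(theta)
        return np.concatenate((x1 * cos - x2 * sin, x2 * cos + x1 * sin), axis=-1)

    x = np.random.randn(2, 8)
    theta = np.random.randn(2, 4)  # one angle per rotated pair
    dy = np.random.randn(2, 8)     # upstream gradient
    dx = rope_plane(dy, theta, sign=-1.0)  # backward: rotate dy by -theta
    assert np.allclose(rope_plane(rope_plane(x, theta), theta, sign=-1.0), x)

This is why the kernel only flips the sign pattern on the rotated half and
shifts the sin lookup, instead of materializing a separate backward formula.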
--- .../impl/fused_attention_functor.cpp | 2 - .../user/kernels/fused_attention_kernels.cu | 183 ++++++++++++++++-- .../modules/test_fused_rotary_embedding.py | 23 ++- 3 files changed, 181 insertions(+), 27 deletions(-) diff --git a/oneflow/core/functional/impl/fused_attention_functor.cpp b/oneflow/core/functional/impl/fused_attention_functor.cpp index 7caf7606137..d8cf614f8d0 100644 --- a/oneflow/core/functional/impl/fused_attention_functor.cpp +++ b/oneflow/core/functional/impl/fused_attention_functor.cpp @@ -852,8 +852,6 @@ class FusedApplyRotaryEmbGradFunctor { return OpInterpUtil::Dispatch(*op_without_position_sinuous_, {x, dy}, attrs); } } - - return dy; } private: diff --git a/oneflow/user/kernels/fused_attention_kernels.cu b/oneflow/user/kernels/fused_attention_kernels.cu index dc7d1e8d303..0cb49437c95 100644 --- a/oneflow/user/kernels/fused_attention_kernels.cu +++ b/oneflow/user/kernels/fused_attention_kernels.cu @@ -1128,6 +1128,15 @@ __global__ void IntervalKernel( } } +template +__global__ void IntervalGradKernel( + FusedApplyRotaryEmbParam param) { + printf("IntervalGradKernel TODO!\n"); + } +} + + template __global__ void PlaneKernel( @@ -1199,6 +1208,86 @@ __global__ void PlaneKernel( } } +template +__global__ void PlaneGradKernel( + FusedApplyRotaryEmbParam param) { + for (IndexType offset = threadIdx.x + blockIdx.x * blockDim.x; offset < param.num_elements; + offset += blockDim.x * gridDim.x) { + using LoadPack = cuda::elementwise::Packed; + IndexType temp_offset = offset; + IndexType index[num_dims]; +#pragma unroll + for (int i = 0; i < num_dims - 1; i++) { + IndexType ref_stride = param.ref_stride[i]; + IndexType idx = temp_offset / ref_stride; + index[i] = idx; + temp_offset = temp_offset - idx * ref_stride; + } + index[num_dims - 1] = temp_offset; + + const IndexType b_index = index[0], m_index = index[1], k_index = index[num_dims - 1]; + const IndexType position_rotate_index = (k_index >= param.k0) ? 1 : 0; + const IndexType position_id_offset = b_index * param.position_b_stride + + position_rotate_index * param.position_rotate_stride + + m_index; + + const PositionType position = + param.position_ids ? param.position_ids[position_id_offset] : m_index; + const IndexType actual_k_index = k_index % param.actual_rotary_size; + const IndexType sinuous_offset = position * param.k + actual_k_index; + + T cos_val, sin_val, out_val; + + if (param.cos && param.sin) { + cos_val = *(param.cos + sinuous_offset); + IndexType offset_; // 针对grad, sin_val需要有 size / 2的偏移; + if (k_index < param.k0) { + offset_ = (param.k0 - k_index > param.rotate_stride) ? param.rotate_stride : -param.rotate_stride; + } else if (k_index < param.k1) { + offset_ = (param.k1 - k_index > param.rotate_stride) ? param.rotate_stride : -param.rotate_stride; + } + + sin_val = *(param.sin + sinuous_offset + offset_); + } else { + // TODO: plane grad kernel without sin & cos; + + T val = position + * expf(2.0f * static_cast(k_index % (param.actual_rotary_size >> 1)) + * param.inv_actual_rotary_size * logf(param.theta)); + cos_val = cosf(val); + sin_val = sinf(val); + } + + LoadPack x_vec; + IndexType x_offset = param.x_offset; + IndexType out_offset = 0; +#pragma unroll + for (int i = 0; i < num_dims; i++) { + x_offset = x_offset + param.x_stride[i] * index[i]; + out_offset = out_offset + param.out_stride[i] * index[i]; + } + + if (k_index < param.k0) { + x_vec.elem[0] = *(param.x + x_offset); + x_vec.elem[1] = (param.k0 - k_index > param.rotate_stride) + ? 
static_cast(*(param.x + x_offset + param.rotate_stride)) + : -*(param.x + x_offset - param.rotate_stride); + out_val = cos_val * x_vec.elem[0] + sin_val * x_vec.elem[1]; + } else if (k_index < param.k1) { + x_vec.elem[0] = *(param.x + x_offset); + x_vec.elem[1] = (param.k1 - k_index > param.rotate_stride) + ? static_cast(*(param.x + x_offset + param.rotate_stride)) + : -*(param.x + x_offset - param.rotate_stride); + out_val = cos_val * x_vec.elem[0] + sin_val * x_vec.elem[1]; + } else { + out_val = *(param.x + x_offset); + } + + *(param.out + out_offset) = out_val; + } +} + template void LaunchKernel(ep::CudaStream* stream, const T* x, const T* cos, const T* sin, @@ -1209,7 +1298,7 @@ void LaunchKernel(ep::CudaStream* stream, const T* x, const T* cos, const T* sin const IndexType x_b_stride, const IndexType x_m_stride, const IndexType x_h_stride, const IndexType x_offset, const IndexType out_b_stride, const IndexType out_m_stride, - const IndexType out_h_stride, IndexType num_elements) { + const IndexType out_h_stride, IndexType num_elements, const bool is_forward) { const IndexType k0 = rotary_size / rotary_emb_dim, k1 = rotary_size; // TODO: this only support 1d, 2d, rotary postional encoding @@ -1243,16 +1332,31 @@ void LaunchKernel(ep::CudaStream* stream, const T* x, const T* cos, const T* sin constexpr size_t blk_size = 128; - if (mode == "plane") { + if (is_forward) { + if (mode == "plane") { param.num_elements = param.num_elements * PackSize; PlaneKernel <<<(param.num_elements + blk_size - 1) / blk_size, blk_size, 0, stream->cuda_stream()>>>( param); - } else { + } else { IntervalKernel <<<(param.num_elements + blk_size - 1) / blk_size, blk_size, 0, stream->cuda_stream()>>>( param); + } + } + else { + if (mode == "plane") { + param.num_elements = param.num_elements * PackSize; + PlaneGradKernel + <<<(param.num_elements + blk_size - 1) / blk_size, blk_size, 0, stream->cuda_stream()>>>( + param); + } else { + IntervalGradKernel + <<<(param.num_elements + blk_size - 1) / blk_size, blk_size, 0, stream->cuda_stream()>>>( + param); + } } + } template(x) % (sizeof(T) * PackSize)) == 0) && (((rotary_size / rotary_emb_dim) % PackSize) == 0) @@ -1278,19 +1382,19 @@ void DispatchPackSize(ep::CudaStream* stream, const T* x, const T* cos, const T* LaunchKernel( stream, x, cos, sin, position_ids, out, position_shape, x_layout, output_layout, mode, theta, rotary_size, b, m, h, k, x_b_stride, x_m_stride, x_h_stride, x_offset, out_b_stride, - out_m_stride, out_h_stride, num_elements); + out_m_stride, out_h_stride, num_elements, is_forward); } else if (CheckPackSize(4)) { num_elements /= 4; LaunchKernel( stream, x, cos, sin, position_ids, out, position_shape, x_layout, output_layout, mode, theta, rotary_size, b, m, h, k, x_b_stride, x_m_stride, x_h_stride, x_offset, out_b_stride, - out_m_stride, out_h_stride, num_elements); + out_m_stride, out_h_stride, num_elements, is_forward); } else { num_elements /= 2; LaunchKernel( stream, x, cos, sin, position_ids, out, position_shape, x_layout, output_layout, mode, theta, rotary_size, b, m, h, k, x_b_stride, x_m_stride, x_h_stride, x_offset, out_b_stride, - out_m_stride, out_h_stride, num_elements); + out_m_stride, out_h_stride, num_elements, is_forward); } } @@ -1302,8 +1406,9 @@ void DispatchIndex(ep::CudaStream* stream, const T* x, const T* cos, const T* si const int64_t b, const int64_t m, const int64_t h, const int64_t k, const int64_t x_b_stride, const int64_t x_m_stride, const int64_t x_h_stride, const int64_t x_offset, const int64_t out_b_stride, 
const int64_t out_m_stride, - const int64_t out_h_stride) { + const int64_t out_h_stride, const bool is_forward) { int64_t num_elements = b * m * h * k; + if (num_elements < (1 << 30)) { DispatchPackSize( stream, x, cos, sin, position_ids, out, position_shape, x_layout, output_layout, mode, @@ -1312,12 +1417,12 @@ void DispatchIndex(ep::CudaStream* stream, const T* x, const T* cos, const T* si static_cast(x_m_stride), static_cast(x_h_stride), static_cast(x_offset), static_cast(out_b_stride), static_cast(out_m_stride), static_cast(out_h_stride), - static_cast(num_elements)); + static_cast(num_elements), is_forward); } else { DispatchPackSize( stream, x, cos, sin, position_ids, out, position_shape, x_layout, output_layout, mode, theta, rotary_size, b, m, h, k, x_b_stride, x_m_stride, x_h_stride, x_offset, out_b_stride, - out_m_stride, out_h_stride, num_elements); + out_m_stride, out_h_stride, num_elements, is_forward); } } @@ -1331,17 +1436,17 @@ void DispatchRotaryEmbeddingDimension(ep::CudaStream* stream, const T* x, const const int64_t h, const int64_t k, const int64_t x_b_stride, const int64_t x_m_stride, const int64_t x_h_stride, const int64_t x_offset, const int64_t out_b_stride, - const int64_t out_m_stride, const int64_t out_h_stride) { + const int64_t out_m_stride, const int64_t out_h_stride, const bool is_forward) { if (rotary_emb_dim == 1) { DispatchIndex( stream, x, cos, sin, position_ids, out, position_shape, x_layout, output_layout, mode, theta, rotary_size, b, m, h, k, x_b_stride, x_m_stride, x_h_stride, x_offset, out_b_stride, - out_m_stride, out_h_stride); + out_m_stride, out_h_stride, is_forward); } else if (rotary_emb_dim == 2) { DispatchIndex( stream, x, cos, sin, position_ids, out, position_shape, x_layout, output_layout, mode, theta, rotary_size, b, m, h, k, x_b_stride, x_m_stride, x_h_stride, x_offset, out_b_stride, - out_m_stride, out_h_stride); + out_m_stride, out_h_stride, is_forward); } } @@ -1389,7 +1494,7 @@ class FusedApplyRotaryEmbKernel final : public user_op::OpKernel { &out_b_stride, &out_m_stride, &out_h_stride, &out_offset); ParseDims(x->shape_view(), x_layout, Optional(), k_size, tensor_index, &b, &m, &h, &k, &x_b_stride, &x_m_stride, &x_h_stride, &x_offset); - + bool is_forward = true; // TODO: hard code num_dims & seems redundant template problem... DispatchRotaryEmbeddingDimension( ctx->stream()->As(), reinterpret_cast(x->dptr()), @@ -1399,7 +1504,7 @@ class FusedApplyRotaryEmbKernel final : public user_op::OpKernel { reinterpret_cast(out->mut_dptr()), position_ids ? 
position_ids->shape_view().data() : nullptr, x_layout, output_layout, mode, static_cast(theta), rotary_size, rotary_emb_dim, b, m, h, k, x_b_stride, x_m_stride, - x_h_stride, x_offset, out_b_stride, out_m_stride, out_h_stride); + x_h_stride, x_offset, out_b_stride, out_m_stride, out_h_stride, is_forward); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } @@ -1414,7 +1519,53 @@ class FusedApplyRotaryEmbGradKernel final : public user_op::OpKernel { private: using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx) const override { - printf("%s\n", "FusedApplyRotaryEmbGradKernel"); + const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); + const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); + user_op::Tensor* cos = nullptr; + user_op::Tensor* sin = nullptr; + user_op::Tensor* position_ids = nullptr; + user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); + const std::string& x_layout = ctx->Attr("x_layout"); + const std::string& output_layout = ctx->Attr("output_layout"); + const std::string& mode = ctx->Attr("mode"); + const int64_t tensor_index = ctx->Attr("tensor_index"); + const int64_t k_size = ctx->Attr("k_size"); + const int64_t rotary_size = ctx->Attr("rotary_size"); + const float theta = 1.0f / ctx->Attr("base"); + int rotary_emb_dim = 1; + + if (ctx->has_input("cos", 0)) { cos = ctx->Tensor4ArgNameAndIndex("cos", 0); } + + if (ctx->has_input("sin", 0)) { sin = ctx->Tensor4ArgNameAndIndex("sin", 0); } + + if (ctx->has_input("position_ids", 0)) { + position_ids = ctx->Tensor4ArgNameAndIndex("position_ids", 0); + rotary_emb_dim = position_ids->shape_view().At(1); + } + + constexpr size_t ndims = 4; + int64_t b = 0; + int64_t m = 0; + int64_t h = 0; + int64_t k = 0; + int64_t out_b_stride = 0, out_m_stride = 0, out_h_stride = 0, out_offset = 0; + int64_t x_b_stride = 0, x_m_stride = 0, x_h_stride = 0, x_offset = 0; + + ParseDims(dx->shape_view(), output_layout, Optional(), k_size, 0, &b, &m, &h, &k, + &out_b_stride, &out_m_stride, &out_h_stride, &out_offset); + ParseDims(dy->shape_view(), x_layout, Optional(), k_size, tensor_index, &b, &m, &h, &k, + &x_b_stride, &x_m_stride, &x_h_stride, &x_offset); + bool is_forward = false; + // TODO: hard code num_dims & seems redundant template problem... + DispatchRotaryEmbeddingDimension( + ctx->stream()->As(), reinterpret_cast(dy->dptr()), + cos ? reinterpret_cast(cos->dptr()) : nullptr, + sin ? reinterpret_cast(sin->dptr()) : nullptr, + position_ids ? reinterpret_cast(position_ids->dptr()) : nullptr, + reinterpret_cast(dx->mut_dptr()), + position_ids ? 
position_ids->shape_view().data() : nullptr, x_layout, output_layout, mode, + static_cast(theta), rotary_size, rotary_emb_dim, b, m, h, k, x_b_stride, x_m_stride, + x_h_stride, x_offset, out_b_stride, out_m_stride, out_h_stride, is_forward); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } diff --git a/python/oneflow/test/modules/test_fused_rotary_embedding.py b/python/oneflow/test/modules/test_fused_rotary_embedding.py index 7b9431b9f5b..bf3a1bafea0 100644 --- a/python/oneflow/test/modules/test_fused_rotary_embedding.py +++ b/python/oneflow/test/modules/test_fused_rotary_embedding.py @@ -480,6 +480,7 @@ def _test_without_position( # fused_out = np.concatenate((out0, out1, out2), axis=-1) fused_out = flow.cat((out0, out1, out2), dim = -1) else: + print("apply!") fused_out = flow._C.fused_apply_rotary_emb( fused_x, cos=fused_cos, @@ -491,10 +492,12 @@ def _test_without_position( rotary_size=rotary_size, mode=mode, ) + print("end!") fused_out_backward = fused_out.sum() fused_out_backward.backward() - fused_out_grad = fused_out_backward.grad + fused_out_grad = fused_x.grad + print("get grad!") #print(fused_out_grad) # test forward test_case.assertTrue( @@ -686,7 +689,7 @@ def _test_without_position_sinuous( fused_out_backward = fused_out.sum() fused_out_backward.backward() - fused_out_grad = fused_out_backward.grad + fused_out_grad = fused_x.grad #print(fused_out_grad) # test forward test_case.assertTrue( @@ -931,7 +934,7 @@ def _test_with_position_sinuous( fused_out_backward = fused_out.sum() fused_out_backward.backward() - fused_out_grad = fused_out_backward.grad + fused_out_grad = fused_x.grad #print(fused_out_grad) # test forward test_case.assertTrue( @@ -1129,7 +1132,7 @@ def _test_with_position( fused_out_backward = fused_out.sum() fused_out_backward.backward() - fused_out_grad = fused_out_backward.grad + fused_out_grad = fused_x.grad #print(fused_out_grad) # test forward test_case.assertTrue( @@ -1335,7 +1338,7 @@ def _test_plane( fused_out_backward = fused_out.sum() fused_out_backward.backward() - fused_out_grad = fused_out_backward.grad + fused_out_grad = fused_x.grad #print(fused_out_grad) # test forward test_case.assertTrue( @@ -1356,6 +1359,7 @@ def _test_plane( rtol=5e-3, ) ) + print("tset plane done") @@ -1390,7 +1394,7 @@ def test_fused_rotary_embedding_op_plane(test_case): arg[0](test_case, *arg[1:]) - + ''' def test_fused_rotary_embedding_op_interval_2d(test_case): args_dict = OrderedDict() args_dict["test_fun"] = [_test_with_position, @@ -1409,8 +1413,9 @@ def test_fused_rotary_embedding_op_interval_2d(test_case): for arg in GenArgList(args_dict): arg[0](test_case, *arg[1:]) - + ''' + ''' def test_fused_rotary_embedding_op_interval_1d(test_case): args_dict = OrderedDict() args_dict["test_fun"] = [ @@ -1420,7 +1425,7 @@ def test_fused_rotary_embedding_op_interval_1d(test_case): #_test_with_position_sinuous, ] args_dict["x_layout"] = ["BMHK"] - args_dict["mode"] = ["plane"] + args_dict["mode"] = ["interval"] args_dict["base"] = [1e1] args_dict["rotary_size"] = [4] args_dict["dims"] = [(3, 2, 5, 8)] @@ -1432,7 +1437,7 @@ def test_fused_rotary_embedding_op_interval_1d(test_case): for arg in GenArgList(args_dict): arg[0](test_case, *arg[1:]) - + ''' if __name__ == "__main__": unittest.main() From 8746815329a95ccb9d08c7ef2fb59b2cb1ad2cbd Mon Sep 17 00:00:00 2001 From: shw Date: Thu, 2 May 2024 15:31:10 +0800 Subject: [PATCH 06/14] check format --- .../gradient_funcs/fused_apply_rotary_emb.cpp | 120 ++++++++-------- .../impl/fused_attention_functor.cpp | 
21 ++-
 oneflow/user/ops/fused_attention_ops.cpp       |  26 ++--
 .../modules/test_fused_rotary_embedding.py     | 133 +++++++++----------
 4 files changed, 151 insertions(+), 149 deletions(-)

diff --git a/oneflow/core/autograd/gradient_funcs/fused_apply_rotary_emb.cpp b/oneflow/core/autograd/gradient_funcs/fused_apply_rotary_emb.cpp
index 2afc8f3d6d0..412ee63acd9 100644
--- a/oneflow/core/autograd/gradient_funcs/fused_apply_rotary_emb.cpp
+++ b/oneflow/core/autograd/gradient_funcs/fused_apply_rotary_emb.cpp
@@ -21,7 +21,7 @@ namespace oneflow {
 namespace one {
 
 struct FusedApplyRotaryEmbCaptureState : public AutoGradCaptureState {
-  bool requires_grad;  // 输入x是否需要梯度 只有一个输入x;
+  bool requires_grad;
   std::string x_layout{};
   std::string output_layout{};
   std::string mode{};
@@ -43,35 +43,37 @@ class FusedApplyRotaryEmb : public OpExprGradFunction<FusedApplyRotaryEmbCaptureState> {
-Maybe<void> FusedApplyRotaryEmb::Init(const OpExpr& op) {  // 是否需要实现存疑;
+Maybe<void> FusedApplyRotaryEmb::Init(const OpExpr& op) {
   const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
-  CHECK_NOTNULL_OR_RETURN(fw_op_expr);
+  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
   base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
   return Maybe<void>::Ok();
 }
 
 Maybe<void> FusedApplyRotaryEmb::Capture(FusedApplyRotaryEmbCaptureState* ctx,
                                          const TensorTuple& inputs, const TensorTuple& outputs,
-                                         const AttrMap& attrs) const {
-  // 这里需要检查sin和cos同时出现,或同时不出现;
-  CHECK_OR_RETURN((inputs.size() >= 1) && (inputs.size() <= 4));  // 这里的输入应该是 1 - 4;
-  ctx->requires_grad = inputs.at(0)->requires_grad();
-  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }  // 如果不需要梯度,也就不需要求导了,直接返回 Maybe<void>::Ok()
-
-  ComposedAttrMap composed_attrs(attrs, base_attrs_);  // 写法确认;
-  ctx->SaveTensorForBackward(inputs.at(0));
-  if (inputs.size() == 2)  // position_ids
+                                         const AttrMap& attrs) const {
+  CHECK_OR_RETURN((inputs.size() >= 1) && (inputs.size() <= 4))
+      << Error::RuntimeError()
+      << "the number of inputs of fused_apply_rotary_emb should be between 1 and 4";
+  ctx->requires_grad = inputs.at(0)->requires_grad();
+  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
+
+  ComposedAttrMap composed_attrs(attrs, base_attrs_);
+  ctx->SaveTensorForBackward(inputs.at(0));
+  if (inputs.size() == 2)  // position_ids
     ctx->SaveTensorForBackward(inputs.at(1));
-  if (inputs.size() == 3) {  // cos, sin
+
+  if (inputs.size() == 3) {  // cos, sin
     ctx->SaveTensorForBackward(inputs.at(1));
     ctx->SaveTensorForBackward(inputs.at(2));
   }
-
-  if (inputs.size() == 4) {  // cos, sin, position_ids;
+
+  if (inputs.size() == 4) {  // cos, sin, position_ids;
     ctx->SaveTensorForBackward(inputs.at(1));
     ctx->SaveTensorForBackward(inputs.at(2));
     ctx->SaveTensorForBackward(inputs.at(3));
-  }
+  }
 
   ctx->x_layout = JUST(composed_attrs.GetAttr<std::string>("x_layout"));
   ctx->output_layout = JUST(composed_attrs.GetAttr<std::string>("output_layout"));
@@ -86,48 +88,54 @@ Maybe<void> FusedApplyRotaryEmb::Capture(FusedApplyRotaryEmbCaptureState* ctx,
 
 Maybe<void> FusedApplyRotaryEmb::Apply(const FusedApplyRotaryEmbCaptureState* ctx,
                                        const TensorTuple& out_grads, TensorTuple* in_grads) const {
-  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // 检查梯度 Tensor 个数是否为 1 TODO: 不确定是否输入为1 -- (dy)
-  // in_grads->resize(1);  // 这里不能resize, 需要和input_meta_data_.size() 一致;
-  const auto& saved_tensors = ctx->SavedTensors();
-
-  CHECK_OR_RETURN((saved_tensors.size() >= 1) && (saved_tensors.size() <= 4));
-  // 输出backward拿到的参数
-  if (ctx->requires_grad) {
-    if (saved_tensors.size() == 1) {
-      const auto& x = ctx->SavedTensors().at(0);
-      in_grads->at(0) = JUST(functional::FusedApplyRotaryEmbGrad(x, out_grads.at(0), NullOpt/*cos*/,
-                                                     NullOpt/*sin*/, NullOpt/*position_ids*/,
-                                                     ctx->x_layout, ctx->output_layout, ctx->mode, ctx->tensor_index, ctx->k_size, ctx->base, ctx->rotary_size));
-    }
-
-    if (saved_tensors.size() == 2) {
-      const auto& x = ctx->SavedTensors().at(0);
-      const auto& position_ids = ctx->SavedTensors().at(1);
-      in_grads->at(0) = JUST(functional::FusedApplyRotaryEmbGrad(x, out_grads.at(0), NullOpt/*cos*/,
-                                                     NullOpt/*sin*/, position_ids,
-                                                     ctx->x_layout, ctx->output_layout, ctx->mode, ctx->tensor_index, ctx->k_size, ctx->base, ctx->rotary_size));
-    }
-
-    if (saved_tensors.size() == 3) {
-      const auto& x = ctx->SavedTensors().at(0);
-      const auto& cos = ctx->SavedTensors().at(1);
-      const auto& sin = ctx->SavedTensors().at(2);
-
-      in_grads->at(0) = JUST(functional::FusedApplyRotaryEmbGrad(x, out_grads.at(0), cos, sin, NullOpt/*position_ids*/,
-                                                     ctx->x_layout, ctx->output_layout, ctx->mode, ctx->tensor_index, ctx->k_size, ctx->base, ctx->rotary_size));
-    }
-
-    if (saved_tensors.size() == 4) {
-      const auto& x = ctx->SavedTensors().at(0);  // 调用 SavedTensors 接口,拿到先前通过 SaveTensorForBackward 接口保存的 Tensor
-      const auto& cos = ctx->SavedTensors().at(1);
-      const auto& sin = ctx->SavedTensors().at(2);
-      const auto& position_ids = ctx->SavedTensors().at(3);
-      in_grads->at(0) = JUST(functional::FusedApplyRotaryEmbGrad(x, out_grads.at(0), cos, sin, position_ids,
-                                                     ctx->x_layout, ctx->output_layout, ctx->mode, ctx->tensor_index, ctx->k_size, ctx->base, ctx->rotary_size));
-    }
+  CHECK_EQ_OR_RETURN(out_grads.size(), 1)
+      << Error::RuntimeError() << "fused_apply_rotary_emb expects exactly one output gradient";
+  const auto& saved_tensors = ctx->SavedTensors();
+
+  CHECK_OR_RETURN((saved_tensors.size() >= 1) && (saved_tensors.size() <= 4))
+      << Error::RuntimeError()
+      << "the number of saved tensors of fused_apply_rotary_emb should be between 1 and 4";
+
+  if (ctx->requires_grad) {
+    if (saved_tensors.size() == 1) {  // x
+      const auto& x = ctx->SavedTensors().at(0);
+      in_grads->at(0) = JUST(functional::FusedApplyRotaryEmbGrad(
+          x, out_grads.at(0), NullOpt /*cos*/, NullOpt /*sin*/, NullOpt /*position_ids*/,
+          ctx->x_layout, ctx->output_layout, ctx->mode, ctx->tensor_index, ctx->k_size, ctx->base,
+          ctx->rotary_size));
+    }
+
+    if (saved_tensors.size() == 2) {  // x, position_ids
+      const auto& x = ctx->SavedTensors().at(0);
+      const auto& position_ids = ctx->SavedTensors().at(1);
+      in_grads->at(0) = JUST(functional::FusedApplyRotaryEmbGrad(
+          x, out_grads.at(0), NullOpt /*cos*/, NullOpt /*sin*/, position_ids, ctx->x_layout,
+          ctx->output_layout, ctx->mode, ctx->tensor_index, ctx->k_size, ctx->base,
+          ctx->rotary_size));
+    }
+
+    if (saved_tensors.size() == 3) {  // x, cos, sin
+      const auto& x = ctx->SavedTensors().at(0);
+      const auto& cos = ctx->SavedTensors().at(1);
+      const auto& sin = ctx->SavedTensors().at(2);
+
+      in_grads->at(0) = JUST(functional::FusedApplyRotaryEmbGrad(
+          x, out_grads.at(0), cos, sin, NullOpt /*position_ids*/, ctx->x_layout, ctx->output_layout,
+          ctx->mode, ctx->tensor_index, ctx->k_size, ctx->base, ctx->rotary_size));
+    }
+
+    if (saved_tensors.size() == 4) {  // x, cos, sin, position_ids
+      const auto& x = ctx->SavedTensors().at(0);
+      const auto& cos = ctx->SavedTensors().at(1);
+      const auto& sin = ctx->SavedTensors().at(2);
+      const auto& position_ids = ctx->SavedTensors().at(3);
+      in_grads->at(0) = JUST(functional::FusedApplyRotaryEmbGrad(
+          x, out_grads.at(0), cos, sin, position_ids, ctx->x_layout, ctx->output_layout, ctx->mode,
+          ctx->tensor_index, ctx->k_size, ctx->base, ctx->rotary_size));
+    }
+  }
 
-  return Maybe<void>::Ok();
+  return Maybe<void>::Ok();
 }
REGISTER_OP_EXPR_GRAD_FUNCTION("fused_apply_rotary_emb", FusedApplyRotaryEmb); diff --git a/oneflow/core/functional/impl/fused_attention_functor.cpp b/oneflow/core/functional/impl/fused_attention_functor.cpp index d8cf614f8d0..4e57253c229 100644 --- a/oneflow/core/functional/impl/fused_attention_functor.cpp +++ b/oneflow/core/functional/impl/fused_attention_functor.cpp @@ -757,22 +757,18 @@ class FusedApplyRotaryEmbGradFunctor { .Input("sin") .Output("dx") .Build()); - op_without_position_sinuous_ = - CHECK_JUST(one::OpBuilder("fused_apply_rotary_emb_grad").Input("x") - .Input("dy") - .Output("dx") - .Build()); + op_without_position_sinuous_ = CHECK_JUST( + one::OpBuilder("fused_apply_rotary_emb_grad").Input("x").Input("dy").Output("dx").Build()); } Maybe operator()(const std::shared_ptr& x, - const std::shared_ptr& dy, - const Optional& cos, + const std::shared_ptr& dy, const Optional& cos, const Optional& sin, const Optional& position_ids, const std::string& x_layout, const Optional& output_layout, const std::string& mode, const Optional& tensor_index, const Optional& k_size, const float base, const Optional& rotary_size) const { int64_t b = 0, m = 0, h = 0, k = 0; - + if (tensor_index) { CHECK_OR_RETURN((JUST(tensor_index) >= 0) && (JUST(tensor_index) <= 2)) << "tensor_index should be set between [0, 2]"; @@ -780,7 +776,7 @@ class FusedApplyRotaryEmbGradFunctor { CHECK_OR_RETURN((mode == "interval") || (mode == "plane")) << "mode should be \"intervel\" or \"plane\""; - ParseDims("x", *x->shape(), x_layout, Optional(), k_size, &b, &m, &h, &k); + JUST(ParseDims("x", *x->shape(), x_layout, Optional(), k_size, &b, &m, &h, &k)); if (k_size) { CHECK_EQ_OR_RETURN(JUST(k_size), k) @@ -839,10 +835,11 @@ class FusedApplyRotaryEmbGradFunctor { if (position_ids) { if (cos && sin) { - return OpInterpUtil::Dispatch(*op_with_position_sinuous_, - {x, dy, JUST(cos), JUST(sin), JUST(position_ids)}, attrs); + return OpInterpUtil::Dispatch( + *op_with_position_sinuous_, {x, dy, JUST(cos), JUST(sin), JUST(position_ids)}, attrs); } else { - return OpInterpUtil::Dispatch(*op_with_position_, {x, dy, JUST(position_ids)}, attrs); + return OpInterpUtil::Dispatch(*op_with_position_, {x, dy, JUST(position_ids)}, + attrs); } } else { if (cos && sin) { diff --git a/oneflow/user/ops/fused_attention_ops.cpp b/oneflow/user/ops/fused_attention_ops.cpp index a1b65464756..d731e9893a2 100644 --- a/oneflow/user/ops/fused_attention_ops.cpp +++ b/oneflow/user/ops/fused_attention_ops.cpp @@ -806,7 +806,8 @@ Maybe ParseSplitAxis(const std::string& layout, bool can_hk_split, int64_t return Maybe::Ok(); } -/* static */ Maybe FusedApplyRotaryEmbGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { +/* static */ Maybe FusedApplyRotaryEmbGradOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); const std::string& x_layout = ctx->Attr("x_layout"); const std::string& output_layout = ctx->Attr("output_layout"); @@ -829,13 +830,12 @@ Maybe ParseSplitAxis(const std::string& layout, bool can_hk_split, int64_t int64_t b = 0, m = 0, h = 0, k = 0; JUST(ParseDims(x_desc.shape(), x_layout, Optional(), Optional(k_size), &b, &m, - &h, &k)); // 这里需要检查是否正确; + &h, &k)); CHECK_LE_OR_RETURN(rotary_size, k) << "rotary_size should be no more than K of input x."; int64_t rotary_emb_dim = 1; - if (ctx->has_input("position_ids", 0)) { const Shape& position_id_shape = ctx->InputShape("position_ids", 0); CHECK_EQ_OR_RETURN(position_id_shape.NumAxes(), 3) @@ -848,7 +848,6 @@ Maybe 
ParseSplitAxis(const std::string& layout, bool can_hk_split, int64_t CHECK_OR_RETURN(rotary_emb_dim == 1 || rotary_emb_dim == 2) << "2nd dim of position_ids should be 1 or 2."; } - // 这里是重复检查,且会报错; const int64_t actual_rotary_size = rotary_size / rotary_emb_dim; CHECK_EQ_OR_RETURN(actual_rotary_size % 2, 0) @@ -856,15 +855,14 @@ Maybe ParseSplitAxis(const std::string& layout, bool can_hk_split, int64_t bool has_cos = ctx->has_input("cos", 0); bool has_sin = ctx->has_input("sin", 0); - // TODO: fused_apply_rotary_emb have same logic no matter name + // TODO: fused_apply_rotary_emb_grad have same logic no matter name if (has_cos && has_sin) { const Shape& cos_shape = ctx->InputShape("cos", 0); const Shape& sin_shape = ctx->InputShape("sin", 0); CHECK_EQ_OR_RETURN(cos_shape.NumAxes(), 2) << "The number of dimensions of cos should be equal to 2."; - CHECK_OR_RETURN(cos_shape == sin_shape) - << "The dimensions of cos & sin should be the same."; + CHECK_OR_RETURN(cos_shape == sin_shape) << "The dimensions of cos & sin should be the same."; CHECK_EQ_OR_RETURN(cos_shape.At(1), actual_rotary_size) << "The 1st dimension of cos & sin should equal to rotary_size // " "rotary_embedding_dimension."; @@ -876,7 +874,6 @@ Maybe ParseSplitAxis(const std::string& layout, bool can_hk_split, int64_t if (!ctx->has_input("position_ids", 0)) { if (has_cos && has_sin) { - // const user_op::TensorDesc& cos_desc = ctx->InputTensorDesc("cos", 0); const Shape& cos_shape = ctx->InputShape("cos", 0); CHECK_GE_OR_RETURN(cos_shape.At(0), m) << "M of cos should be no less than M of x if position_ids is not given."; @@ -889,19 +886,12 @@ Maybe ParseSplitAxis(const std::string& layout, bool can_hk_split, int64_t return Maybe::Ok(); } -/* static */ Maybe FusedApplyRotaryEmbGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { +/* static */ Maybe FusedApplyRotaryEmbGradOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } - /* static */ Maybe FusedApplyRotaryEmbGradOp::GetSbp(user_op::SbpContext* ctx) { - /* - 1. 获取layout; - 2. check dy shape; - 3. 获取split_axis; - 4. 设置sbp; - 反向算子中的split应该与前向一致; - */ const user_op::TensorDesc& x_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); int num_heads = -1; const int64_t k_size = ctx->Attr("k_size"); @@ -957,7 +947,7 @@ Maybe ParseSplitAxis(const std::string& layout, bool can_hk_split, int64_t if (ctx->user_op_conf().has_input("cos", 0)) builder = builder.Broadcast(user_op::OpArg("cos", 0)).Broadcast(user_op::OpArg("sin", 0)); if (ctx->user_op_conf().has_input("position_ids", 0)) - builder = builder.Split(user_op::OpArg("position_ids", 0), 0); // 这里怎么split? 
+ builder = builder.Split(user_op::OpArg("position_ids", 0), 0); builder.Build(); } if (x_h_split_axis >= 0 && o_h_split_axis >= 0) { diff --git a/python/oneflow/test/modules/test_fused_rotary_embedding.py b/python/oneflow/test/modules/test_fused_rotary_embedding.py index bf3a1bafea0..d3db8c1246a 100644 --- a/python/oneflow/test/modules/test_fused_rotary_embedding.py +++ b/python/oneflow/test/modules/test_fused_rotary_embedding.py @@ -24,20 +24,20 @@ # tensor version: def plane_shuffle_tensor(x): - x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] return flow.cat((-x2, x1), dim=-1) def shuffle_adjacent_two_elem_tensor(x): - y = x.clone() # 使用.clone()以确保我们不会在原始输入上进行操作 + y = x.clone() for i in range(x.shape[-1] // 2): - y[..., 2*i] = -x[..., 2*i + 1] - y[..., 2*i + 1] = x[..., 2*i] + y[..., 2 * i] = -x[..., 2 * i + 1] + y[..., 2 * i + 1] = x[..., 2 * i] return y def parseDims_tensor(dims, x_layout): - B, M, H, K = 1, 1, 1, 1 # 初始化维度 + B, M, H, K = 1, 1, 1, 1 merged_dims = dims if x_layout == "BHMK": B, H, M, K = dims @@ -61,7 +61,21 @@ def parseDims_tensor(dims, x_layout): return B, M, H, K, merged_dims -def naive_embedding_tensor(x, cos, sin, x_layout, B, M, H, K, dims, merged_dims, rotary_size, rotary_ndims, mode): +def naive_embedding_tensor( + x, + cos, + sin, + x_layout, + B, + M, + H, + K, + dims, + merged_dims, + rotary_size, + rotary_ndims, + mode, +): naive_out = None if mode == "plane": if rotary_ndims == 2: @@ -109,23 +123,15 @@ def naive_embedding_tensor(x, cos, sin, x_layout, B, M, H, K, dims, merged_dims, naive_out = flow.cat((out0, out1, out2), axis=-1) elif x_layout == "MB(H3K)": - #test = cos.transpose([2, 0, 1, 3]).reshape( [M, B, 1, K] ) - #test = cos.permute([2, 0, 1, 3]) out0 = x[..., 0, :].reshape(dims) * cos.permute([2, 0, 1, 3]).reshape( [M, B, 1, K] - ) + y[..., 0, :].reshape(dims) * sin.permute([2, 0, 1, 3]).reshape( - [M, B, 1, K] - ) + ) + y[..., 0, :].reshape(dims) * sin.permute([2, 0, 1, 3]).reshape([M, B, 1, K]) out1 = x[..., 1, :].reshape(dims) * cos.permute([2, 0, 1, 3]).reshape( [M, B, 1, K] - ) + y[..., 1, :].reshape(dims) * sin.permute([2, 0, 1, 3]).reshape( - [M, B, 1, K] - ) + ) + y[..., 1, :].reshape(dims) * sin.permute([2, 0, 1, 3]).reshape([M, B, 1, K]) out2 = x[..., 2, :].reshape(dims) * cos.permute([2, 0, 1, 3]).reshape( [M, B, 1, K] - ) + y[..., 2, :].reshape(dims) * sin.permute([2, 0, 1, 3]).reshape( - [M, B, 1, K] - ) + ) + y[..., 2, :].reshape(dims) * sin.permute([2, 0, 1, 3]).reshape([M, B, 1, K]) naive_out = flow.cat((out0, out1, out2), dim=-1) @@ -361,7 +367,9 @@ def _test_without_position( mode, ) - naive_x_tensor = flow.tensor(naive_x, dtype=dtype, device=device, requires_grad=True) + naive_x_tensor = flow.tensor( + naive_x, dtype=dtype, device=device, requires_grad=True + ) naive_cos_tensor = flow.tensor(naive_cos, dtype=dtype, device=device) # 不用grad naive_sin_tensor = flow.tensor(naive_sin, dtype=dtype, device=device) naive_out_tensor = naive_embedding_tensor( @@ -380,7 +388,7 @@ def _test_without_position( mode, ) - # 验证这里的naive_out_tensor与naive_out是否有精度误差; + # check naive_out_tensor and naive_out; test_case.assertTrue( np.allclose( naive_out.reshape(merged_dims), @@ -390,6 +398,7 @@ def _test_without_position( ) ) + # get naive_out_grad naive_out_backward = naive_out_tensor.sum() naive_out_backward.backward() naive_out_grad = naive_x_tensor.grad @@ -477,10 +486,8 @@ def _test_without_position( tensor_index=2, ) - # fused_out = np.concatenate((out0, out1, 
out2), axis=-1) - fused_out = flow.cat((out0, out1, out2), dim = -1) + fused_out = flow.cat((out0, out1, out2), dim=-1) else: - print("apply!") fused_out = flow._C.fused_apply_rotary_emb( fused_x, cos=fused_cos, @@ -492,13 +499,12 @@ def _test_without_position( rotary_size=rotary_size, mode=mode, ) - print("end!") - + + # get fused_out_grad fused_out_backward = fused_out.sum() fused_out_backward.backward() fused_out_grad = fused_x.grad - print("get grad!") - #print(fused_out_grad) + # test forward test_case.assertTrue( np.allclose( @@ -519,7 +525,6 @@ def _test_without_position( ) ) - # this assume that rotary_ndims is by default 1 def _test_without_position_sinuous( test_case, x_layout, mode, base, rotary_size, dims, rotary_ndims, dtype, device @@ -596,8 +601,10 @@ def _test_without_position_sinuous( rotary_ndims, mode, ) - naive_x_tensor = flow.tensor(naive_x, dtype=dtype, device=device, requires_grad=True) - naive_cos_tensor = flow.tensor(naive_cos, dtype=dtype, device=device) # 不用grad + naive_x_tensor = flow.tensor( + naive_x, dtype=dtype, device=device, requires_grad=True + ) + naive_cos_tensor = flow.tensor(naive_cos, dtype=dtype, device=device) naive_sin_tensor = flow.tensor(naive_sin, dtype=dtype, device=device) naive_out_tensor = naive_embedding_tensor( naive_x_tensor, @@ -615,7 +622,7 @@ def _test_without_position_sinuous( mode, ) - # 验证这里的naive_out_tensor与naive_out是否有精度误差; + # check naive_out_tensor and naive_out; test_case.assertTrue( np.allclose( naive_out.reshape(merged_dims), @@ -625,12 +632,13 @@ def _test_without_position_sinuous( ) ) + # get naive_out_grad naive_out_backward = naive_out_tensor.sum() naive_out_backward.backward() naive_out_grad = naive_x_tensor.grad fused_x = flow.tensor(x, dtype=dtype, device=device, requires_grad=True) - + if x_layout == "BM(H3K)": out0 = flow._C.fused_apply_rotary_emb( fused_x, @@ -672,8 +680,7 @@ def _test_without_position_sinuous( tensor_index=2, ) - #fused_out = np.concatenate((out0, out1, out2), axis=-1) - fused_out = flow.cat((out0, out1, out2), dim = -1) + fused_out = flow.cat((out0, out1, out2), dim=-1) else: fused_out = flow._C.fused_apply_rotary_emb( fused_x, @@ -686,11 +693,11 @@ def _test_without_position_sinuous( rotary_size=rotary_size, mode=mode, ) - + # get fused_out_grad fused_out_backward = fused_out.sum() fused_out_backward.backward() fused_out_grad = fused_x.grad - #print(fused_out_grad) + # test forward test_case.assertTrue( np.allclose( @@ -781,7 +788,7 @@ def _test_with_position_sinuous( naive_x = x.reshape([B, M, H, -1, K]) elif x_layout == "MB(HK)" or x_layout == "MB(H2K)" or x_layout == "MB(H3K)": naive_x = x.reshape([M, B, H, -1, K]) - + naive_out = naive_embedding( naive_x, naive_cos, @@ -798,8 +805,10 @@ def _test_with_position_sinuous( mode, ) - naive_x_tensor = flow.tensor(naive_x, dtype=dtype, device=device, requires_grad=True) - naive_cos_tensor = flow.tensor(naive_cos, dtype=dtype, device=device) # 不用grad + naive_x_tensor = flow.tensor( + naive_x, dtype=dtype, device=device, requires_grad=True + ) + naive_cos_tensor = flow.tensor(naive_cos, dtype=dtype, device=device) naive_sin_tensor = flow.tensor(naive_sin, dtype=dtype, device=device) naive_out_tensor = naive_embedding_tensor( naive_x_tensor, @@ -817,7 +826,7 @@ def _test_with_position_sinuous( mode, ) - # 验证这里的naive_out_tensor与naive_out是否有精度误差; + # check naive_out_tensor and naive_out; test_case.assertTrue( np.allclose( naive_out.reshape(merged_dims), @@ -827,10 +836,10 @@ def _test_with_position_sinuous( ) ) + # get naive_out_grad; naive_out_backward 
= naive_out_tensor.sum() naive_out_backward.backward() naive_out_grad = naive_x_tensor.grad - #print(naive_out_grad) fused_cos = np.array( [ @@ -917,8 +926,7 @@ def _test_with_position_sinuous( tensor_index=2, ) - #fused_out = np.concatenate((out0, out1, out2), axis=-1) - fused_out = flow.cat((out0, out1, out2), dim = -1) + fused_out = flow.cat((out0, out1, out2), dim=-1) else: fused_out = flow._C.fused_apply_rotary_emb( fused_x, @@ -931,11 +939,11 @@ def _test_with_position_sinuous( rotary_size=rotary_size, mode=mode, ) - + # get fused_out_grad; fused_out_backward = fused_out.sum() fused_out_backward.backward() fused_out_grad = fused_x.grad - #print(fused_out_grad) + # test forward test_case.assertTrue( np.allclose( @@ -1038,7 +1046,9 @@ def _test_with_position( mode, ) - naive_x_tensor = flow.tensor(naive_x, dtype=dtype, device=device, requires_grad=True) + naive_x_tensor = flow.tensor( + naive_x, dtype=dtype, device=device, requires_grad=True + ) naive_cos_tensor = flow.tensor(naive_cos, dtype=dtype, device=device) # 不用grad naive_sin_tensor = flow.tensor(naive_sin, dtype=dtype, device=device) naive_out_tensor = naive_embedding_tensor( @@ -1057,7 +1067,7 @@ def _test_with_position( mode, ) - # 验证这里的naive_out_tensor与naive_out是否有精度误差; + # check naive_out_tensor and naive_out; test_case.assertTrue( np.allclose( naive_out.reshape(merged_dims), @@ -1066,7 +1076,7 @@ def _test_with_position( rtol=5e-3, ) ) - + # get naive_out_grad naive_out_backward = naive_out_tensor.sum() naive_out_backward.backward() naive_out_grad = naive_x_tensor.grad @@ -1115,7 +1125,6 @@ def _test_with_position( tensor_index=2, ) - #fused_out = np.concatenate((out0, out1, out2), axis=-1) fused_out = flow.cat((out0, out1, out2), dim=-1) else: fused_out = flow._C.fused_apply_rotary_emb( @@ -1129,11 +1138,10 @@ def _test_with_position( rotary_size=rotary_size, mode=mode, ) - + # get fused_out_grad fused_out_backward = fused_out.sum() fused_out_backward.backward() fused_out_grad = fused_x.grad - #print(fused_out_grad) # test forward test_case.assertTrue( np.allclose( @@ -1244,7 +1252,9 @@ def _test_plane( mode, ) - naive_x_tensor = flow.tensor(naive_x, dtype=dtype, device=device, requires_grad=True) + naive_x_tensor = flow.tensor( + naive_x, dtype=dtype, device=device, requires_grad=True + ) naive_cos_tensor = flow.tensor(naive_cos, dtype=dtype, device=device) # 不用grad naive_sin_tensor = flow.tensor(naive_sin, dtype=dtype, device=device) naive_out_tensor = naive_embedding_tensor( @@ -1263,7 +1273,7 @@ def _test_plane( mode, ) - # 验证这里的naive_out_tensor与naive_out是否有精度误差; + # check naive_out_tensor and naive_out; test_case.assertTrue( np.allclose( naive_out.reshape(merged_dims), @@ -1276,11 +1286,10 @@ def _test_plane( naive_out_backward = naive_out_tensor.sum() naive_out_backward.backward() naive_out_grad = naive_x_tensor.grad - + fused_x = flow.tensor(x, dtype=dtype, device=device, requires_grad=True) fused_position_ids = flow.tensor(position_ids, dtype=flow.int32, device=device) - if x_layout == "MB(H3K)": out0 = flow._C.fused_apply_rotary_emb( fused_x, @@ -1339,7 +1348,6 @@ def _test_plane( fused_out_backward = fused_out.sum() fused_out_backward.backward() fused_out_grad = fused_x.grad - #print(fused_out_grad) # test forward test_case.assertTrue( np.allclose( @@ -1362,7 +1370,6 @@ def _test_plane( print("tset plane done") - """ 1. if cos&sin is given, then base will not be used 2. if cos&sin is not given, then any form of x_layout which cannot infer the dimension of k is not allowed, e.g. 
BM(HK) @@ -1374,11 +1381,11 @@ def _test_plane( @flow.unittest.skip_unless_1n1d() class TestFusedRotaryEmbedding(flow.unittest.TestCase): # because rule no.2, kernels without cos&sin cannot work under specific x_layout - + def test_fused_rotary_embedding_op_plane(test_case): args_dict = OrderedDict() args_dict["test_fun"] = [_test_plane] - #args_dict["x_layout"] = ["MB(H3K)", "MB(HK)"] + # args_dict["x_layout"] = ["MB(H3K)", "MB(HK)"] args_dict["x_layout"] = ["BMHK"] args_dict["mode"] = ["plane"] args_dict["base"] = [1e1] @@ -1392,9 +1399,8 @@ def test_fused_rotary_embedding_op_plane(test_case): for arg in GenArgList(args_dict): arg[0](test_case, *arg[1:]) - - ''' + """ TODO: interval mode grad kernel def test_fused_rotary_embedding_op_interval_2d(test_case): args_dict = OrderedDict() args_dict["test_fun"] = [_test_with_position, @@ -1413,9 +1419,9 @@ def test_fused_rotary_embedding_op_interval_2d(test_case): for arg in GenArgList(args_dict): arg[0](test_case, *arg[1:]) - ''' + """ - ''' + """ TODO: interval mode grad kernel def test_fused_rotary_embedding_op_interval_1d(test_case): args_dict = OrderedDict() args_dict["test_fun"] = [ @@ -1437,7 +1443,8 @@ def test_fused_rotary_embedding_op_interval_1d(test_case): for arg in GenArgList(args_dict): arg[0](test_case, *arg[1:]) - ''' + """ + if __name__ == "__main__": unittest.main() From 6154888da9de664de839365988f303342a2c3293 Mon Sep 17 00:00:00 2001 From: shw Date: Thu, 2 May 2024 16:47:39 +0800 Subject: [PATCH 07/14] fix some bugs --- oneflow/user/kernels/fused_attention_kernels.cu | 2 -- python/oneflow/test/modules/test_fused_rotary_embedding.py | 6 +++--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/oneflow/user/kernels/fused_attention_kernels.cu b/oneflow/user/kernels/fused_attention_kernels.cu index 0cb49437c95..dfabc62e32b 100644 --- a/oneflow/user/kernels/fused_attention_kernels.cu +++ b/oneflow/user/kernels/fused_attention_kernels.cu @@ -1133,7 +1133,6 @@ template param) { printf("IntervalGradKernel TODO!\n"); - } } @@ -1356,7 +1355,6 @@ void LaunchKernel(ep::CudaStream* stream, const T* x, const T* cos, const T* sin param); } } - } template Date: Sun, 5 May 2024 22:18:24 +0800 Subject: [PATCH 08/14] fix MB(H3k) --- .../user/kernels/fused_attention_kernels.cu | 90 +++++++++---------- .../modules/test_fused_rotary_embedding.py | 5 +- 2 files changed, 47 insertions(+), 48 deletions(-) diff --git a/oneflow/user/kernels/fused_attention_kernels.cu b/oneflow/user/kernels/fused_attention_kernels.cu index dfabc62e32b..8ec00615d6d 100644 --- a/oneflow/user/kernels/fused_attention_kernels.cu +++ b/oneflow/user/kernels/fused_attention_kernels.cu @@ -1132,10 +1132,9 @@ template __global__ void IntervalGradKernel( FusedApplyRotaryEmbParam param) { - printf("IntervalGradKernel TODO!\n"); + printf("IntervalGradKernel TODO!\n"); } - template __global__ void PlaneKernel( @@ -1210,9 +1209,9 @@ __global__ void PlaneKernel( template __global__ void PlaneGradKernel( - FusedApplyRotaryEmbParam param) { - for (IndexType offset = threadIdx.x + blockIdx.x * blockDim.x; offset < param.num_elements; - offset += blockDim.x * gridDim.x) { + FusedApplyRotaryEmbParam param) { + for (IndexType offset = threadIdx.x + blockIdx.x * blockDim.x; offset < param.num_elements; + offset += blockDim.x * gridDim.x) { using LoadPack = cuda::elementwise::Packed; IndexType temp_offset = offset; IndexType index[num_dims]; @@ -1240,24 +1239,26 @@ __global__ void PlaneGradKernel( if (param.cos && param.sin) { cos_val = *(param.cos + sinuous_offset); - IndexType 
offset_; // 针对grad, sin_val需要有 size / 2的偏移; + IndexType offset_; // 针对grad, sin_val需要有 size / 2的偏移; if (k_index < param.k0) { - offset_ = (param.k0 - k_index > param.rotate_stride) ? param.rotate_stride : -param.rotate_stride; + offset_ = + (param.k0 - k_index > param.rotate_stride) ? param.rotate_stride : -param.rotate_stride; } else if (k_index < param.k1) { - offset_ = (param.k1 - k_index > param.rotate_stride) ? param.rotate_stride : -param.rotate_stride; - } + offset_ = + (param.k1 - k_index > param.rotate_stride) ? param.rotate_stride : -param.rotate_stride; + } sin_val = *(param.sin + sinuous_offset + offset_); } else { // TODO: plane grad kernel without sin & cos; - + T val = position * expf(2.0f * static_cast(k_index % (param.actual_rotary_size >> 1)) * param.inv_actual_rotary_size * logf(param.theta)); cos_val = cosf(val); sin_val = sinf(val); } - + LoadPack x_vec; IndexType x_offset = param.x_offset; IndexType out_offset = 0; @@ -1333,26 +1334,25 @@ void LaunchKernel(ep::CudaStream* stream, const T* x, const T* cos, const T* sin if (is_forward) { if (mode == "plane") { - param.num_elements = param.num_elements * PackSize; - PlaneKernel - <<<(param.num_elements + blk_size - 1) / blk_size, blk_size, 0, stream->cuda_stream()>>>( - param); + param.num_elements = param.num_elements * PackSize; + PlaneKernel + <<<(param.num_elements + blk_size - 1) / blk_size, blk_size, 0, stream->cuda_stream()>>>( + param); } else { - IntervalKernel - <<<(param.num_elements + blk_size - 1) / blk_size, blk_size, 0, stream->cuda_stream()>>>( - param); + IntervalKernel + <<<(param.num_elements + blk_size - 1) / blk_size, blk_size, 0, stream->cuda_stream()>>>( + param); } - } - else { + } else { if (mode == "plane") { - param.num_elements = param.num_elements * PackSize; - PlaneGradKernel - <<<(param.num_elements + blk_size - 1) / blk_size, blk_size, 0, stream->cuda_stream()>>>( - param); + param.num_elements = param.num_elements * PackSize; + PlaneGradKernel + <<<(param.num_elements + blk_size - 1) / blk_size, blk_size, 0, stream->cuda_stream()>>>( + param); } else { - IntervalGradKernel - <<<(param.num_elements + blk_size - 1) / blk_size, blk_size, 0, stream->cuda_stream()>>>( - param); + IntervalGradKernel + <<<(param.num_elements + blk_size - 1) / blk_size, blk_size, 0, stream->cuda_stream()>>>( + param); } } } @@ -1434,7 +1434,8 @@ void DispatchRotaryEmbeddingDimension(ep::CudaStream* stream, const T* x, const const int64_t h, const int64_t k, const int64_t x_b_stride, const int64_t x_m_stride, const int64_t x_h_stride, const int64_t x_offset, const int64_t out_b_stride, - const int64_t out_m_stride, const int64_t out_h_stride, const bool is_forward) { + const int64_t out_m_stride, const int64_t out_h_stride, + const bool is_forward) { if (rotary_emb_dim == 1) { DispatchIndex( stream, x, cos, sin, position_ids, out, position_shape, x_layout, output_layout, mode, @@ -1487,7 +1488,6 @@ class FusedApplyRotaryEmbKernel final : public user_op::OpKernel { int64_t k = 0; int64_t out_b_stride = 0, out_m_stride = 0, out_h_stride = 0, out_offset = 0; int64_t x_b_stride = 0, x_m_stride = 0, x_h_stride = 0, x_offset = 0; - ParseDims(out->shape_view(), output_layout, Optional(), k_size, 0, &b, &m, &h, &k, &out_b_stride, &out_m_stride, &out_h_stride, &out_offset); ParseDims(x->shape_view(), x_layout, Optional(), k_size, tensor_index, &b, &m, &h, &k, @@ -1549,10 +1549,10 @@ class FusedApplyRotaryEmbGradKernel final : public user_op::OpKernel { int64_t out_b_stride = 0, out_m_stride = 0, out_h_stride = 0, 
out_offset = 0; int64_t x_b_stride = 0, x_m_stride = 0, x_h_stride = 0, x_offset = 0; - ParseDims(dx->shape_view(), output_layout, Optional(), k_size, 0, &b, &m, &h, &k, + ParseDims(dx->shape_view(), x_layout, Optional(), k_size, 0, &b, &m, &h, &k, &out_b_stride, &out_m_stride, &out_h_stride, &out_offset); - ParseDims(dy->shape_view(), x_layout, Optional(), k_size, tensor_index, &b, &m, &h, &k, - &x_b_stride, &x_m_stride, &x_h_stride, &x_offset); + ParseDims(dy->shape_view(), output_layout, Optional(), k_size, tensor_index, &b, &m, + &h, &k, &x_b_stride, &x_m_stride, &x_h_stride, &x_offset); bool is_forward = false; // TODO: hard code num_dims & seems redundant template problem... DispatchRotaryEmbeddingDimension( @@ -1569,7 +1569,6 @@ class FusedApplyRotaryEmbGradKernel final : public user_op::OpKernel { bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; - #define REGISTER_FUSED_APPLY_ROTARY_EMB_GPU(dtype, position_type) \ REGISTER_USER_KERNEL("fused_apply_rotary_emb") \ .SetCreateFn>() \ @@ -1594,22 +1593,21 @@ REGISTER_FUSED_APPLY_ROTARY_EMB_GPU_DTYPE(half); REGISTER_FUSED_APPLY_ROTARY_EMB_GPU_DTYPE(nv_bfloat16); #endif // CUDA_VERSION >= 11000 - -#define REGISTER_FUSED_APPLY_ROTARY_EMB_GRAD_GPU(dtype, position_type) \ - REGISTER_USER_KERNEL("fused_apply_rotary_emb_grad") \ - .SetCreateFn>() \ - .SetIsMatchedHob( \ - (user_op::HobDeviceType() == DeviceType::kCUDA) \ +#define REGISTER_FUSED_APPLY_ROTARY_EMB_GRAD_GPU(dtype, position_type) \ + REGISTER_USER_KERNEL("fused_apply_rotary_emb_grad") \ + .SetCreateFn>() \ + .SetIsMatchedHob( \ + (user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("dx", 0) == GetDataType::value) \ - && (user_op::HobInputSize("position_ids") == 1) \ + && (user_op::HobInputSize("position_ids") == 1) \ && (user_op::HobDataType("position_ids", 0) == GetDataType::value)); -#define REGISTER_FUSED_APPLY_ROTARY_EMB_GRAD_GPU_DTYPE(dtype) \ - REGISTER_FUSED_APPLY_ROTARY_EMB_GRAD_GPU(dtype, int64_t); \ - REGISTER_FUSED_APPLY_ROTARY_EMB_GRAD_GPU(dtype, int32_t); \ - REGISTER_USER_KERNEL("fused_apply_rotary_emb_grad") \ - .SetCreateFn>() \ - .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ +#define REGISTER_FUSED_APPLY_ROTARY_EMB_GRAD_GPU_DTYPE(dtype) \ + REGISTER_FUSED_APPLY_ROTARY_EMB_GRAD_GPU(dtype, int64_t); \ + REGISTER_FUSED_APPLY_ROTARY_EMB_GRAD_GPU(dtype, int32_t); \ + REGISTER_USER_KERNEL("fused_apply_rotary_emb_grad") \ + .SetCreateFn>() \ + .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("dx", 0) == GetDataType::value) \ && (user_op::HobInputSize("position_ids") == 0)); diff --git a/python/oneflow/test/modules/test_fused_rotary_embedding.py b/python/oneflow/test/modules/test_fused_rotary_embedding.py index 537fc3ec5e8..d698623534f 100644 --- a/python/oneflow/test/modules/test_fused_rotary_embedding.py +++ b/python/oneflow/test/modules/test_fused_rotary_embedding.py @@ -525,6 +525,7 @@ def _test_without_position( ) ) + # this assume that rotary_ndims is by default 1 def _test_without_position_sinuous( test_case, x_layout, mode, base, rotary_size, dims, rotary_ndims, dtype, device @@ -1385,8 +1386,8 @@ class TestFusedRotaryEmbedding(flow.unittest.TestCase): def test_fused_rotary_embedding_op_plane(test_case): args_dict = OrderedDict() args_dict["test_fun"] = [_test_plane] - #args_dict["x_layout"] = ["MB(H3K)", "MB(HK)"] - args_dict["x_layout"] = ["BMHK", "MB(H3K)"] # TODO: MB(H3K) paramdims bug; + # args_dict["x_layout"] = ["MB(H3K)"] + args_dict["x_layout"] 
= ["BMHK", "MB(HK)"] # TODO: MB(H3K) paramdims bug; args_dict["mode"] = ["plane"] args_dict["base"] = [1e1] args_dict["rotary_size"] = [4, 8] From 4179b3d8bd656f715bb66b803d84ad312267185c Mon Sep 17 00:00:00 2001 From: shw Date: Sun, 5 May 2024 22:19:15 +0800 Subject: [PATCH 09/14] MB(H3K) todo --- python/oneflow/test/modules/test_fused_rotary_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/oneflow/test/modules/test_fused_rotary_embedding.py b/python/oneflow/test/modules/test_fused_rotary_embedding.py index d698623534f..efa9fb90750 100644 --- a/python/oneflow/test/modules/test_fused_rotary_embedding.py +++ b/python/oneflow/test/modules/test_fused_rotary_embedding.py @@ -1387,7 +1387,7 @@ def test_fused_rotary_embedding_op_plane(test_case): args_dict = OrderedDict() args_dict["test_fun"] = [_test_plane] # args_dict["x_layout"] = ["MB(H3K)"] - args_dict["x_layout"] = ["BMHK", "MB(HK)"] # TODO: MB(H3K) paramdims bug; + args_dict["x_layout"] = ["BMHK", "MB(HK)"] # TODO: MB(H3K) bug; args_dict["mode"] = ["plane"] args_dict["base"] = [1e1] args_dict["rotary_size"] = [4, 8] From 52874a92fa9f3e4625ec08321e6a94c63f0e2798 Mon Sep 17 00:00:00 2001 From: chende Date: Sat, 11 May 2024 05:16:06 +0000 Subject: [PATCH 10/14] finish interval grad kernel. --- .../user/kernels/fused_attention_kernels.cu | 82 +++++++++++++++++-- 1 file changed, 77 insertions(+), 5 deletions(-) diff --git a/oneflow/user/kernels/fused_attention_kernels.cu b/oneflow/user/kernels/fused_attention_kernels.cu index 8ec00615d6d..f54964c52cf 100644 --- a/oneflow/user/kernels/fused_attention_kernels.cu +++ b/oneflow/user/kernels/fused_attention_kernels.cu @@ -1132,7 +1132,79 @@ template __global__ void IntervalGradKernel( FusedApplyRotaryEmbParam param) { - printf("IntervalGradKernel TODO!\n"); + // printf("IntervalGradKernel TODO!\n"); + for (IndexType packed_offset = threadIdx.x + blockIdx.x * blockDim.x; + packed_offset < param.num_elements; packed_offset += blockDim.x * gridDim.x) { + using LoadPack = cuda::elementwise::Packed; + IndexType offset = packed_offset * PackSize; + IndexType index[num_dims]; // b, m, h, k + + IndexType temp_offset = offset; + + for (int i = 0; i < num_dims - 1; i++) { + IndexType ref_stride = param.ref_stride[i]; + IndexType idx = temp_offset / ref_stride; + index[i] = idx; + temp_offset = temp_offset - idx * ref_stride; + } + index[num_dims - 1] = temp_offset; + + IndexType x_offset = param.x_offset; + IndexType out_offset = 0; +#pragma unroll + for (int i = 0; i < num_dims; i++) { + x_offset = x_offset + param.x_stride[i] * index[i]; + out_offset = out_offset + param.out_stride[i] * index[i]; + } + const LoadPack x_vec = *reinterpret_cast(param.x + x_offset); + + const IndexType k_index = index[num_dims - 1]; + if (k_index < param.rotary_size) { + const IndexType position_rotate_index = (k_index >= param.k0) ? 1 : 0; + const IndexType b_index = index[0], m_index = index[1]; + const IndexType position_id_offset = b_index * param.position_b_stride + + position_rotate_index * param.position_rotate_stride + + m_index; + + const PositionType position = + param.position_ids ? 
param.position_ids[position_id_offset] : m_index; + const IndexType actual_k_index = k_index % param.actual_rotary_size; + const IndexType sinuous_offset = position * param.sinuous_m_stride + actual_k_index; + + LoadPack cos_vec, sin_vec, out_vec; + + if (param.cos && param.sin) { + cos_vec = *reinterpret_cast(param.cos + sinuous_offset); + sin_vec = *reinterpret_cast(param.sin + sinuous_offset); + } else { + const IndexType actual_ndim = param.rotary_size / rotary_emb_dim; +#pragma unroll + for (int i = 0; i < PackSize / 2; i++) { + T val = position + * expf(2.0f * static_cast(((actual_k_index >> 1) + i)) + * param.inv_actual_rotary_size * logf(param.theta)); + T cos_val = cosf(val); + T sin_val = sinf(val); + cos_vec.elem[i * 2] = cos_val; + cos_vec.elem[i * 2 + 1] = cos_val; + sin_vec.elem[i * 2] = sin_val; + sin_vec.elem[i * 2 + 1] = sin_val; + } + } + +#pragma unroll + for (int i = 0; i < PackSize / 2; i++) { + out_vec.elem[i * 2] = + x_vec.elem[i * 2] * cos_vec.elem[i * 2] + x_vec.elem[i * 2 + 1] * sin_vec.elem[i * 2 + 1]; + out_vec.elem[i * 2 + 1] = x_vec.elem[i * 2 + 1] * cos_vec.elem[i * 2 + 1] + - x_vec.elem[i * 2] * sin_vec.elem[i * 2]; + } + + *(reinterpret_cast(param.out + out_offset)) = out_vec; + } else { + *(reinterpret_cast(param.out + out_offset)) = x_vec; + } + } } template param.rotate_stride) - ? static_cast(*(param.x + x_offset + param.rotate_stride)) - : -*(param.x + x_offset - param.rotate_stride); + ? *(param.x + x_offset + param.rotate_stride) + : static_cast(-*(param.x + x_offset - param.rotate_stride)); out_val = cos_val * x_vec.elem[0] + sin_val * x_vec.elem[1]; } else if (k_index < param.k1) { x_vec.elem[0] = *(param.x + x_offset); x_vec.elem[1] = (param.k1 - k_index > param.rotate_stride) - ? static_cast(*(param.x + x_offset + param.rotate_stride)) - : -*(param.x + x_offset - param.rotate_stride); + ? *(param.x + x_offset + param.rotate_stride) + : static_cast(-*(param.x + x_offset - param.rotate_stride)); out_val = cos_val * x_vec.elem[0] + sin_val * x_vec.elem[1]; } else { out_val = *(param.x + x_offset); From 278295d2afe73a118a5ad56cfa3a50f709e7e580 Mon Sep 17 00:00:00 2001 From: oneflow-ci-bot Date: Sat, 11 May 2024 05:21:18 +0000 Subject: [PATCH 11/14] auto format by CI --- oneflow/user/kernels/fused_attention_kernels.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/oneflow/user/kernels/fused_attention_kernels.cu b/oneflow/user/kernels/fused_attention_kernels.cu index f54964c52cf..66e59190a18 100644 --- a/oneflow/user/kernels/fused_attention_kernels.cu +++ b/oneflow/user/kernels/fused_attention_kernels.cu @@ -1194,8 +1194,8 @@ __global__ void IntervalGradKernel( #pragma unroll for (int i = 0; i < PackSize / 2; i++) { - out_vec.elem[i * 2] = - x_vec.elem[i * 2] * cos_vec.elem[i * 2] + x_vec.elem[i * 2 + 1] * sin_vec.elem[i * 2 + 1]; + out_vec.elem[i * 2] = x_vec.elem[i * 2] * cos_vec.elem[i * 2] + + x_vec.elem[i * 2 + 1] * sin_vec.elem[i * 2 + 1]; out_vec.elem[i * 2 + 1] = x_vec.elem[i * 2 + 1] * cos_vec.elem[i * 2 + 1] - x_vec.elem[i * 2] * sin_vec.elem[i * 2]; } From d79cda1d595e8724a9584595c7a03de44ef58929 Mon Sep 17 00:00:00 2001 From: chende Date: Sat, 11 May 2024 05:58:33 +0000 Subject: [PATCH 12/14] change test file. 
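
With IntervalGradKernel implemented two commits back, the interval 1d/2d
cases no longer need to be fenced off, so the docstring guards are dropped
and the original test matrix is restored; the plane test is narrowed to
MB(HK) until the MB(H3K) addressing bug is fixed. A quick manual smoke test
of the restored backward path (usage assumed from the module tests, not a
documented public API):

    import oneflow as flow

    x = flow.randn(3, 2, 5, 8, device="cuda", requires_grad=True)  # BMHK
    out = flow._C.fused_apply_rotary_emb(
        x, x_layout="BMHK", mode="interval", base=10.0, rotary_size=4
    )
    out.sum().backward()  # routes through fused_apply_rotary_emb_grad
    assert x.grad.shape == x.shape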
--- .../test/modules/test_fused_rotary_embedding.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/python/oneflow/test/modules/test_fused_rotary_embedding.py b/python/oneflow/test/modules/test_fused_rotary_embedding.py index efa9fb90750..d671ac6422e 100644 --- a/python/oneflow/test/modules/test_fused_rotary_embedding.py +++ b/python/oneflow/test/modules/test_fused_rotary_embedding.py @@ -1387,7 +1387,7 @@ def test_fused_rotary_embedding_op_plane(test_case): args_dict = OrderedDict() args_dict["test_fun"] = [_test_plane] # args_dict["x_layout"] = ["MB(H3K)"] - args_dict["x_layout"] = ["BMHK", "MB(HK)"] # TODO: MB(H3K) bug; + args_dict["x_layout"] = ["MB(HK)"] # TODO: MB(H3K) bug; args_dict["mode"] = ["plane"] args_dict["base"] = [1e1] args_dict["rotary_size"] = [4, 8] @@ -1401,7 +1401,6 @@ def test_fused_rotary_embedding_op_plane(test_case): for arg in GenArgList(args_dict): arg[0](test_case, *arg[1:]) - """ TODO: interval mode grad kernel def test_fused_rotary_embedding_op_interval_2d(test_case): args_dict = OrderedDict() args_dict["test_fun"] = [_test_with_position, @@ -1420,16 +1419,14 @@ def test_fused_rotary_embedding_op_interval_2d(test_case): for arg in GenArgList(args_dict): arg[0](test_case, *arg[1:]) - """ - """ TODO: interval mode grad kernel def test_fused_rotary_embedding_op_interval_1d(test_case): args_dict = OrderedDict() args_dict["test_fun"] = [ - #_test_without_position_sinuous, + _test_without_position_sinuous, _test_without_position, - #_test_with_position, - #_test_with_position_sinuous, + _test_with_position, + _test_with_position_sinuous, ] args_dict["x_layout"] = ["BMHK"] args_dict["mode"] = ["interval"] @@ -1444,7 +1441,6 @@ def test_fused_rotary_embedding_op_interval_1d(test_case): for arg in GenArgList(args_dict): arg[0](test_case, *arg[1:]) - """ if __name__ == "__main__": From d43b03ce1299b6838b3b834d8c9f446c2570e860 Mon Sep 17 00:00:00 2001 From: oneflow-ci-bot Date: Sat, 11 May 2024 06:00:29 +0000 Subject: [PATCH 13/14] auto format by CI --- python/oneflow/test/modules/test_fused_rotary_embedding.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/oneflow/test/modules/test_fused_rotary_embedding.py b/python/oneflow/test/modules/test_fused_rotary_embedding.py index d671ac6422e..2f154986aed 100644 --- a/python/oneflow/test/modules/test_fused_rotary_embedding.py +++ b/python/oneflow/test/modules/test_fused_rotary_embedding.py @@ -1403,9 +1403,7 @@ def test_fused_rotary_embedding_op_plane(test_case): def test_fused_rotary_embedding_op_interval_2d(test_case): args_dict = OrderedDict() - args_dict["test_fun"] = [_test_with_position, - _test_with_position_sinuous - ] + args_dict["test_fun"] = [_test_with_position, _test_with_position_sinuous] args_dict["x_layout"] = ["BMHK"] args_dict["mode"] = ["interval"] args_dict["base"] = [1e1] From 2eb8a20babaee1c9765f1db93807004d863fc8ec Mon Sep 17 00:00:00 2001 From: chende Date: Mon, 13 May 2024 02:16:06 +0000 Subject: [PATCH 14/14] fix MB(H3K) bug. 
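
For the packed q/k/v layouts the three logical tensors share one buffer, so
reads and writes both carry a per-slice base offset, while the grad kernels
were writing dx from a hardcoded base of 0, which is correct only for the
first slice. This commit threads the output base offset through
FusedApplyRotaryEmbParam as out_offset, uses it in the kernels instead of
the literal 0, derives sin at the shifted index in the no-cos/sin plane
branch, and initializes the capture-state defaults. An illustrative sketch
of the shared-buffer addressing (shapes and names are illustrative only):

    import numpy as np

    M, B, H, K = 2, 1, 4, 8
    buf = np.zeros((M, B, H, 3 * K))  # one MB(H3K) buffer holding q, k, v
    for tensor_index in range(3):
        base = tensor_index * K  # same base for x/dy reads and dx writes
        assert np.shares_memory(buf, buf[..., base : base + K])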
From 2eb8a20babaee1c9765f1db93807004d863fc8ec Mon Sep 17 00:00:00 2001
From: chende
Date: Mon, 13 May 2024 02:16:06 +0000
Subject: [PATCH 14/14] fix MB(H3K) bug.

---
 .../gradient_funcs/fused_apply_rotary_emb.cpp |  4 +-
 .../user/kernels/fused_attention_kernels.cu   | 79 ++++++++++++-------
 oneflow/user/ops/fused_attention_ops.cpp      | 12 ++-
 .../modules/test_fused_rotary_embedding.py    |  9 +--
 4 files changed, 63 insertions(+), 41 deletions(-)

diff --git a/oneflow/core/autograd/gradient_funcs/fused_apply_rotary_emb.cpp b/oneflow/core/autograd/gradient_funcs/fused_apply_rotary_emb.cpp
index 412ee63acd9..8d6b219b148 100644
--- a/oneflow/core/autograd/gradient_funcs/fused_apply_rotary_emb.cpp
+++ b/oneflow/core/autograd/gradient_funcs/fused_apply_rotary_emb.cpp
@@ -21,13 +21,13 @@ namespace oneflow {
 namespace one {
 
 struct FusedApplyRotaryEmbCaptureState : public AutoGradCaptureState {
-  bool requires_grad;
+  bool requires_grad = false;
   std::string x_layout{};
   std::string output_layout{};
   std::string mode{};
   int64_t tensor_index{};
   int64_t k_size{};
-  float base;
+  float base = 0.0f;
   int64_t rotary_size{};
 };

diff --git a/oneflow/user/kernels/fused_attention_kernels.cu b/oneflow/user/kernels/fused_attention_kernels.cu
index 66e59190a18..4e4a4f396c6 100644
--- a/oneflow/user/kernels/fused_attention_kernels.cu
+++ b/oneflow/user/kernels/fused_attention_kernels.cu
@@ -24,6 +24,7 @@ limitations under the License.
 #include "cutlass/gemm/warp/mma.h"
 #include "kernel_forward.h"
 #include "oneflow/core/kernel/cuda_graph_support.h"
+#include "oneflow/core/kernel/kernel_util.h"
 #include "trt_flash_attention/fmha.h"
 #include "trt_flash_attention/fmha_flash_attention.h"
 
@@ -1017,6 +1018,7 @@ struct FusedApplyRotaryEmbParam {
   IndexType num_elements;
   const IndexType k;
   const IndexType x_offset;
+  const IndexType out_offset;
 
   IndexType ref_stride[num_dims];  // b, m, h, k
   IndexType out_stride[num_dims];  // ordered descendingly by stride
@@ -1032,7 +1034,7 @@ struct FusedApplyRotaryEmbParam {
                             const IndexType actual_rotary_size, const IndexType rotary_size,
                             const IndexType rotate_stride, const IndexType num_elements,
                             const IndexType k, const IndexType k0, const IndexType k1,
-                            const IndexType x_offset)
+                            const IndexType x_offset, const IndexType out_offset)
       : x(x),
         cos(cos),
         sin(sin),
         k(k),
         k0(k0),
         k1(k1),
-        x_offset(x_offset) {}
+        x_offset(x_offset),
+        out_offset(out_offset) {}
 };
 
 template<typename T, typename PositionType, typename IndexType, size_t PackSize, size_t num_dims,
          size_t rotary_emb_dim>
 __global__ void IntervalGradKernel(
     FusedApplyRotaryEmbParam<T, PositionType, IndexType, num_dims> param) {
-  // printf("IntervalGradKernel TODO!\n");
   for (IndexType packed_offset = threadIdx.x + blockIdx.x * blockDim.x;
        packed_offset < param.num_elements; packed_offset += blockDim.x * gridDim.x) {
     using LoadPack = cuda::elementwise::Packed<T, PackSize>;
     index[num_dims - 1] = temp_offset;
 
     IndexType x_offset = param.x_offset;
-    IndexType out_offset = 0;
+    IndexType out_offset = param.out_offset;
 #pragma unroll
     for (int i = 0; i < num_dims; i++) {
       x_offset = x_offset + param.x_stride[i] * index[i];
@@ -1251,7 +1253,7 @@ __global__ void PlaneKernel(
     LoadPack x_vec;
 
     IndexType x_offset = param.x_offset;
-    IndexType out_offset = 0;
+    IndexType out_offset = param.out_offset;
 #pragma unroll
     for (int i = 0; i < num_dims; i++) {
       x_offset = x_offset + param.x_stride[i] * index[i];
@@ -1322,18 +1324,30 @@ __global__ void PlaneGradKernel(
       sin_val = *(param.sin + sinuous_offset + offset_);
 
     } else {
-      // TODO: plane grad kernel without sin & cos;
+      T val_cos = position
+                  * expf(2.0f * static_cast<float>(k_index % (param.actual_rotary_size >> 1))
+                         * param.inv_actual_rotary_size * logf(param.theta));
 
-      T val = position
-              * expf(2.0f * static_cast<float>(k_index % (param.actual_rotary_size >> 1))
-                     * param.inv_actual_rotary_size * logf(param.theta));
-      cos_val = cosf(val);
-      sin_val = sinf(val);
+      IndexType offset_;  // for the grad, sin_val needs the partner offset of rotary_size / 2;
+      if (k_index < param.k0) {
+        offset_ =
+            (param.k0 - k_index > param.rotate_stride) ? param.rotate_stride : -param.rotate_stride;
+      } else if (k_index < param.k1) {
+        offset_ =
+            (param.k1 - k_index > param.rotate_stride) ? param.rotate_stride : -param.rotate_stride;
+      }
+      T val_sin =
+          position
+          * expf(2.0f * static_cast<float>((k_index + offset_) % (param.actual_rotary_size >> 1))
+                 * param.inv_actual_rotary_size * logf(param.theta));
+
+      cos_val = cosf(val_cos);
+      sin_val = sinf(val_sin);
     }
 
     LoadPack x_vec;
 
     IndexType x_offset = param.x_offset;
-    IndexType out_offset = 0;
+    IndexType out_offset = param.out_offset;
 #pragma unroll
     for (int i = 0; i < num_dims; i++) {
       x_offset = x_offset + param.x_stride[i] * index[i];
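The PlaneGradKernel hunk above is the heart of the fix for the on-the-fly sinusoid path: in plane ("rotate half") mode the forward pairs element k with element k +/- rotary_size/2, and the sin factor that multiplies the rotated partner belongs to the partner's index. The backward therefore re-derives sin at k_index + offset_ rather than k_index (the removed lines evaluated both cos and sin at k_index). A NumPy sketch of the adjoint, assuming a single head with the whole head rotated (rotary_size == head size; names are illustrative):

import numpy as np

# Sketch of plane ("rotate half") rotary and its backward.
def rope_plane_fwd(x, cos, sin):
    h = x.shape[-1] // 2
    rot = np.concatenate([-x[..., h:], x[..., :h]], axis=-1)
    return x * cos + rot * sin

def rope_plane_bwd(dy, cos, sin):
    h = dy.shape[-1] // 2
    # The sin that multiplied the partner lives at the partner's index,
    # hence the +h / -h shift mirrored by `offset_` in PlaneGradKernel.
    rot_t = np.concatenate([(sin * dy)[..., h:], -(sin * dy)[..., :h]], axis=-1)
    return dy * cos + rot_t

rng = np.random.default_rng(1)
x, dy, th = (rng.normal(size=(3, 8)) for _ in range(3))
cos, sin = np.cos(th), np.sin(th)
assert np.allclose((rope_plane_fwd(x, cos, sin) * dy).sum(),
                   (x * rope_plane_bwd(dy, cos, sin)).sum())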
@@ -1370,7 +1384,8 @@ void LaunchKernel(ep::CudaStream* stream, const T* x, const T* cos, const T* sin,
                   const IndexType x_b_stride, const IndexType x_m_stride,
                   const IndexType x_h_stride, const IndexType x_offset,
                   const IndexType out_b_stride, const IndexType out_m_stride,
-                  const IndexType out_h_stride, IndexType num_elements, const bool is_forward) {
+                  const IndexType out_h_stride, const IndexType out_offset, IndexType num_elements,
+                  const bool is_forward) {
   const IndexType k0 = rotary_size / rotary_emb_dim, k1 = rotary_size;
 
   // TODO: this only support 1d, 2d, rotary postional encoding
 
   struct FusedApplyRotaryEmbParam<T, PositionType, IndexType, num_dims> param(
       x, cos, sin, position_ids, out, theta, inv_actual_rotary_size, actual_rotary_size,
-      rotary_size, rotate_stride, num_elements, k, k0, k1, x_offset);
+      rotary_size, rotate_stride, num_elements, k, k0, k1, x_offset, out_offset);
 
   const IndexType ref_strides[num_dims] = {m * h * k, h * k, k, 1};
   const IndexType out_strides[num_dims] = {out_b_stride, out_m_stride, out_h_stride, 1};
@@ -1439,7 +1454,8 @@ void DispatchPackSize(ep::CudaStream* stream, const T* x, const T* cos, const T* sin,
                       const IndexType x_b_stride, const IndexType x_m_stride,
                       const IndexType x_h_stride, const IndexType x_offset,
                       const IndexType out_b_stride, const IndexType out_m_stride,
-                      const IndexType out_h_stride, IndexType num_elements, const bool is_forward) {
+                      const IndexType out_h_stride, const IndexType out_offset,
+                      IndexType num_elements, const bool is_forward) {
   const auto CheckPackSize = [&](const size_t PackSize) {
     bool r = (((reinterpret_cast<uintptr_t>(x) % (sizeof(T) * PackSize)) == 0)
               && (((rotary_size / rotary_emb_dim) % PackSize) == 0)
   if (CheckPackSize(8)) {
     num_elements /= 8;
     LaunchKernel<T, PositionType, IndexType, 8, num_dims, rotary_emb_dim>(
         stream, x, cos, sin, position_ids, out, position_shape, x_layout, output_layout, mode,
         theta, rotary_size, b, m, h, k, x_b_stride, x_m_stride, x_h_stride, x_offset, out_b_stride,
-        out_m_stride, out_h_stride, num_elements, is_forward);
+        out_m_stride, out_h_stride, out_offset, num_elements, is_forward);
   } else if (CheckPackSize(4)) {
     num_elements /= 4;
     LaunchKernel<T, PositionType, IndexType, 4, num_dims, rotary_emb_dim>(
         stream, x, cos, sin, position_ids, out, position_shape, x_layout, output_layout, mode,
         theta, rotary_size, b, m, h, k, x_b_stride, x_m_stride, x_h_stride, x_offset, out_b_stride,
-        out_m_stride, out_h_stride, num_elements, is_forward);
+        out_m_stride, out_h_stride, out_offset, num_elements, is_forward);
   } else {
     num_elements /= 2;
     LaunchKernel<T, PositionType, IndexType, 2, num_dims, rotary_emb_dim>(
         stream, x, cos, sin, position_ids, out, position_shape, x_layout, output_layout, mode,
         theta, rotary_size, b, m, h, k, x_b_stride, x_m_stride, x_h_stride, x_offset, out_b_stride,
-        out_m_stride, out_h_stride, num_elements, is_forward);
+        out_m_stride, out_h_stride, out_offset, num_elements, is_forward);
   }
 }
 
@@ -1476,7 +1492,7 @@ void DispatchIndex(ep::CudaStream* stream, const T* x, const T* cos, const T* sin,
                    const int64_t b, const int64_t m, const int64_t h, const int64_t k,
                    const int64_t x_b_stride, const int64_t x_m_stride, const int64_t x_h_stride,
                    const int64_t x_offset, const int64_t out_b_stride, const int64_t out_m_stride,
-                   const int64_t out_h_stride, const bool is_forward) {
+                   const int64_t out_h_stride, const int64_t out_offset, const bool is_forward) {
   int64_t num_elements = b * m * h * k;
 
   if (num_elements < (1 << 30)) {
     DispatchPackSize<T, PositionType, int32_t, num_dims, rotary_emb_dim>(
         stream, x, cos, sin, position_ids, out, position_shape, x_layout, output_layout, mode,
         theta, rotary_size, b, m, h, k, static_cast<int32_t>(x_b_stride),
         static_cast<int32_t>(x_m_stride), static_cast<int32_t>(x_h_stride),
         static_cast<int32_t>(x_offset), static_cast<int32_t>(out_b_stride),
         static_cast<int32_t>(out_m_stride), static_cast<int32_t>(out_h_stride),
-        static_cast<int32_t>(num_elements), is_forward);
+        static_cast<int32_t>(out_offset), static_cast<int32_t>(num_elements), is_forward);
   } else {
     DispatchPackSize<T, PositionType, int64_t, num_dims, rotary_emb_dim>(
         stream, x, cos, sin, position_ids, out, position_shape, x_layout, output_layout, mode,
         theta, rotary_size, b, m, h, k, x_b_stride, x_m_stride, x_h_stride, x_offset, out_b_stride,
-        out_m_stride, out_h_stride, num_elements, is_forward);
+        out_m_stride, out_h_stride, out_offset, num_elements, is_forward);
   }
 }
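The dispatch plumbing above only threads out_offset through otherwise unchanged logic: DispatchPackSize still picks the widest vectorized access (8, then 4, else 2 elements) whose alignment and divisibility checks pass, and DispatchIndex still narrows to 32-bit indexing when the element count fits below 2^30. A sketch of the selection rule (choose_pack_size is an illustrative name, and the k % pack term is an assumed stand-in for the truncated tail of CheckPackSize):

# Illustrative sketch; `addr` stands in for the x pointer, `elem_size`
# for sizeof(T). Not the exact C++ predicate.
def choose_pack_size(addr, elem_size, rotary_size, rotary_emb_dim, k):
    for pack in (8, 4, 2):
        aligned = addr % (elem_size * pack) == 0
        divides = (rotary_size // rotary_emb_dim) % pack == 0 and k % pack == 0
        if aligned and divides:
            return pack
    return 2  # the C++ falls through to the pack-2 launch

assert choose_pack_size(addr=256, elem_size=2, rotary_size=8,
                        rotary_emb_dim=1, k=64) == 8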
@@ -1507,17 +1523,17 @@ void DispatchRotaryEmbeddingDimension(ep::CudaStream* stream, const T* x, const T* cos,
                                       const int64_t x_m_stride, const int64_t x_h_stride,
                                       const int64_t x_offset, const int64_t out_b_stride,
                                       const int64_t out_m_stride, const int64_t out_h_stride,
-                                      const bool is_forward) {
+                                      const int64_t out_offset, bool is_forward) {
   if (rotary_emb_dim == 1) {
     DispatchIndex<T, PositionType, num_dims, 1>(
         stream, x, cos, sin, position_ids, out, position_shape, x_layout, output_layout, mode,
         theta, rotary_size, b, m, h, k, x_b_stride, x_m_stride, x_h_stride, x_offset, out_b_stride,
-        out_m_stride, out_h_stride, is_forward);
+        out_m_stride, out_h_stride, out_offset, is_forward);
   } else if (rotary_emb_dim == 2) {
     DispatchIndex<T, PositionType, num_dims, 2>(
         stream, x, cos, sin, position_ids, out, position_shape, x_layout, output_layout, mode,
         theta, rotary_size, b, m, h, k, x_b_stride, x_m_stride, x_h_stride, x_offset, out_b_stride,
-        out_m_stride, out_h_stride, is_forward);
+        out_m_stride, out_h_stride, out_offset, is_forward);
   }
 }
 
@@ -1574,7 +1590,7 @@ class FusedApplyRotaryEmbKernel final : public user_op::OpKernel {
         reinterpret_cast<T*>(out->mut_dptr()),
         position_ids ? position_ids->shape_view().data() : nullptr, x_layout, output_layout, mode,
         static_cast<float>(theta), rotary_size, rotary_emb_dim, b, m, h, k, x_b_stride, x_m_stride,
-        x_h_stride, x_offset, out_b_stride, out_m_stride, out_h_stride, is_forward);
+        x_h_stride, x_offset, out_b_stride, out_m_stride, out_h_stride, out_offset, is_forward);
   }
 
   bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
@@ -1604,6 +1620,9 @@ class FusedApplyRotaryEmbGradKernel final : public user_op::OpKernel {
     const float theta = 1.0f / ctx->Attr<float>("base");
     int rotary_emb_dim = 1;
 
+    size_t dx_byte_size = dx->shape_view().elem_cnt() * sizeof(T);
+    Memset<DeviceType::kCUDA>(ctx->stream(), dx->mut_dptr(), 0, dx_byte_size);
+
     if (ctx->has_input("cos", 0)) { cos = ctx->Tensor4ArgNameAndIndex("cos", 0); }
     if (ctx->has_input("sin", 0)) { sin = ctx->Tensor4ArgNameAndIndex("sin", 0); }
 
@@ -1621,10 +1640,10 @@ class FusedApplyRotaryEmbGradKernel final : public user_op::OpKernel {
     int64_t out_b_stride = 0, out_m_stride = 0, out_h_stride = 0, out_offset = 0;
     int64_t x_b_stride = 0, x_m_stride = 0, x_h_stride = 0, x_offset = 0;
 
-    ParseDims(dx->shape_view(), x_layout, Optional<int64_t>(), k_size, 0, &b, &m, &h, &k,
-              &out_b_stride, &out_m_stride, &out_h_stride, &out_offset);
-    ParseDims(dy->shape_view(), output_layout, Optional<int64_t>(), k_size, tensor_index, &b, &m,
-              &h, &k, &x_b_stride, &x_m_stride, &x_h_stride, &x_offset);
+    ParseDims(dx->shape_view(), x_layout, Optional<int64_t>(), k_size, tensor_index, &b, &m, &h,
+              &k, &out_b_stride, &out_m_stride, &out_h_stride, &out_offset);
+    ParseDims(dy->shape_view(), output_layout, Optional<int64_t>(), k_size, 0, &b, &m, &h, &k,
+              &x_b_stride, &x_m_stride, &x_h_stride, &x_offset);
 
     bool is_forward = false;
     // TODO: hard code num_dims & seems redundant template problem...
     DispatchRotaryEmbeddingDimension<T, PositionType, num_dims>(
         reinterpret_cast<T*>(dx->mut_dptr()),
         position_ids ? position_ids->shape_view().data() : nullptr, x_layout, output_layout, mode,
         static_cast<float>(theta), rotary_size, rotary_emb_dim, b, m, h, k, x_b_stride, x_m_stride,
-        x_h_stride, x_offset, out_b_stride, out_m_stride, out_h_stride, is_forward);
+        x_h_stride, x_offset, out_b_stride, out_m_stride, out_h_stride, out_offset, is_forward);
   }
 
   bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
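Two details in the grad kernel's Compute above are easy to miss. First, dx is zero-filled before launch: with a packed layout such as MB(H3K), the grad op writes only the q/k/v slice selected by tensor_index, so the untouched slices must hold zeros rather than stale memory. Second, the two ParseDims calls swap roles relative to the forward pass, since dx carries the forward input's layout (with its tensor_index slicing) while dy arrives in the output layout. A sketch of the zero-then-write pattern (shapes and names are illustrative, and the actual buffer order for MB layouts differs):

import numpy as np

# Why the Memset matters: only one interleaved slice of dx is written.
b, m, h, k = 2, 4, 3, 8
dx = np.zeros((b, m, h * 3 * k))                 # packed (H3K)-style buffer
dy = np.random.default_rng(2).normal(size=(b, m, h, k))

tensor_index = 0                                  # which of q/k/v this grad targets
view = dx.reshape(b, m, h, 3, k)[:, :, :, tensor_index, :]
view[...] = dy                                    # the rotated grad lands here
assert dx.reshape(b, m, h, 3, k)[:, :, :, 1:, :].sum() == 0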
diff --git a/oneflow/user/ops/fused_attention_ops.cpp b/oneflow/user/ops/fused_attention_ops.cpp
index d731e9893a2..d0b0687689b 100644
--- a/oneflow/user/ops/fused_attention_ops.cpp
+++ b/oneflow/user/ops/fused_attention_ops.cpp
@@ -900,20 +900,24 @@ Maybe<void> ParseSplitAxis(const std::string& layout, bool can_hk_split, int64_t
 
   if (x_desc.shape().NumAxes() == 2) {
     if (x_layout == "(BM)(HK)") {
-      CHECK_EQ_OR_RETURN(x_desc.shape().At(1) % k_size, 0);
+      CHECK_EQ_OR_RETURN(x_desc.shape().At(1) % k_size, 0)
+          << "shape cannot be divided by head dimension size.";
       num_heads = x_desc.shape().At(1) / k_size;
     } else if (x_layout == "(BM)(H3K)") {
-      CHECK_EQ_OR_RETURN(x_desc.shape().At(1) % (k_size * 3), 0);
+      CHECK_EQ_OR_RETURN(x_desc.shape().At(1) % (k_size * 3), 0)
+          << "shape cannot be divided by head dimension size.";
       num_heads = x_desc.shape().At(1) / (k_size * 3);
     } else {
       UNIMPLEMENTED_THEN_RETURN();
     }
   } else if (x_desc.shape().NumAxes() == 3) {
     if (x_layout == "BM(HK)" || x_layout == "MB(HK)") {
-      CHECK_EQ_OR_RETURN(x_desc.shape().At(2) % k_size, 0);
+      CHECK_EQ_OR_RETURN(x_desc.shape().At(2) % k_size, 0)
+          << "shape cannot be divided by head dimension size.";
       num_heads = x_desc.shape().At(2) / k_size;
     } else if (x_layout == "BM(H3K)" || x_layout == "MB(H3K)") {
-      CHECK_EQ_OR_RETURN(x_desc.shape().At(2) % (k_size * 3), 0);
+      CHECK_EQ_OR_RETURN(x_desc.shape().At(2) % (k_size * 3), 0)
+          << "shape cannot be divided by head dimension size.";
       num_heads = x_desc.shape().At(2) / (k_size * 3);
     } else if (x_layout == "(BM)HK") {
       num_heads = x_desc.shape().At(1);
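The new messages above fire when a flattened hidden dimension is not a multiple of the per-head size; for the fused q/k/v layouts the divisor is 3 * k_size since the three projections are interleaved per head. A sketch of the derivation the checks guard (num_heads here is an illustrative helper, not the op's API):

# Illustrative helper mirroring the divisibility checks above.
def num_heads(hidden, k_size, layout):
    per_head = k_size * 3 if "H3K" in layout else k_size
    assert hidden % per_head == 0, "shape cannot be divided by head dimension size."
    return hidden // per_head

assert num_heads(hidden=768, k_size=64, layout="BM(HK)") == 12
assert num_heads(hidden=768, k_size=64, layout="BM(H3K)") == 4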
diff --git a/python/oneflow/test/modules/test_fused_rotary_embedding.py b/python/oneflow/test/modules/test_fused_rotary_embedding.py
index 2f154986aed..dc70cb24ec9 100644
--- a/python/oneflow/test/modules/test_fused_rotary_embedding.py
+++ b/python/oneflow/test/modules/test_fused_rotary_embedding.py
@@ -121,7 +121,7 @@ def naive_embedding_tensor(
         ..., 2, :
     ].reshape(dims) * sin.reshape([B, M, 1, K])
 
-    naive_out = flow.cat((out0, out1, out2), axis=-1)
+    naive_out = flow.cat((out0, out1, out2), dim=-1)
 elif x_layout == "MB(H3K)":
     out0 = x[..., 0, :].reshape(dims) * cos.permute([2, 0, 1, 3]).reshape(
         [M, B, 1, K]
@@ -1386,8 +1386,7 @@ class TestFusedRotaryEmbedding(flow.unittest.TestCase):
     def test_fused_rotary_embedding_op_plane(test_case):
         args_dict = OrderedDict()
         args_dict["test_fun"] = [_test_plane]
-        # args_dict["x_layout"] = ["MB(H3K)"]
-        args_dict["x_layout"] = ["MB(HK)"]  # TODO: MB(H3K) bug;
+        args_dict["x_layout"] = ["MB(HK)", "MB(H3K)"]
         args_dict["mode"] = ["plane"]
         args_dict["base"] = [1e1]
         args_dict["rotary_size"] = [4, 8]
@@ -1403,7 +1403,7 @@ def test_fused_rotary_embedding_op_plane(test_case):
     def test_fused_rotary_embedding_op_interval_2d(test_case):
         args_dict = OrderedDict()
         args_dict["test_fun"] = [_test_with_position, _test_with_position_sinuous]
-        args_dict["x_layout"] = ["BMHK"]
+        args_dict["x_layout"] = ["BMHK", "BM(H3K)"]
         args_dict["mode"] = ["interval"]
         args_dict["base"] = [1e1]
         args_dict["rotary_size"] = [4]
@@ -1426,7 +1425,7 @@ def test_fused_rotary_embedding_op_interval_1d(test_case):
             _test_with_position,
             _test_with_position_sinuous,
         ]
-        args_dict["x_layout"] = ["BMHK"]
+        args_dict["x_layout"] = ["BMHK", "BM(H3K)"]
         args_dict["mode"] = ["interval"]
         args_dict["base"] = [1e1]
         args_dict["rotary_size"] = [4]