From dec5edc7a80c20427e950f5fea33e50c6b12be21 Mon Sep 17 00:00:00 2001
From: Akhil Goel
Date: Wed, 25 Sep 2024 23:44:47 -0700
Subject: [PATCH] Refactor and add new post-op

---
 xla/service/cpu/onednn_config.proto         |  1 +
 .../cpu/onednn_contraction_rewriter.cc      | 29 ++++++--
 xla/service/cpu/onednn_matmul.cc            | 23 ++----
 xla/service/cpu/onednn_memory_util.cc       | 18 +++--
 xla/service/cpu/tests/onednn_matmul_test.cc | 73 +++++++++++++++++++
 5 files changed, 115 insertions(+), 29 deletions(-)

diff --git a/xla/service/cpu/onednn_config.proto b/xla/service/cpu/onednn_config.proto
index 9f38673eaaceb..e18b6da541486 100644
--- a/xla/service/cpu/onednn_config.proto
+++ b/xla/service/cpu/onednn_config.proto
@@ -50,6 +50,7 @@ message OneDnnFusionConfig {
     ELU = 8;
     RELU6 = 9;
     SIGMOID = 10;
+    SUM = 11;
   }
   repeated FusionKind ops = 1;
   // To avoid protobuf failures for specific decimal values,
diff --git a/xla/service/cpu/onednn_contraction_rewriter.cc b/xla/service/cpu/onednn_contraction_rewriter.cc
index 19122b393ce23..1e6d68904e160 100644
--- a/xla/service/cpu/onednn_contraction_rewriter.cc
+++ b/xla/service/cpu/onednn_contraction_rewriter.cc
@@ -640,7 +640,7 @@ class OneDnnContractionRewriteVisitor : public DfsHloRewriteVisitor {
                                       m::Op(&addend));
     if (!Match(addend_intermediate, addend_pattern)) return absl::OkStatus();
 
-    if (optional_addend_broadcast && addend->shape().rank() != 1) {
+    if (optional_addend_broadcast) {
       auto new_shape =
           AdjustBiasShape(optional_addend_broadcast, dot->shape());
       if (new_shape.ok()) {
@@ -653,17 +653,28 @@
     }
 
     // Validate addend for fusion.
+    auto addend_user_count = addend->user_count();
+    auto addend_idx = -1;
     if (IsSupportedType(addend->shape().element_type()) &&
         IsOperandFusible(addend, dot)) {
+      addend_idx = new_operands.size();
       new_operands.push_back(addend);
     } else {
       return absl::OkStatus();
     }
 
+    auto matmul_call = Cast<HloCustomCallInstruction>(instr->AddInstruction(
+        dot->CloneWithNewOperands(dot->shape(), new_operands)));
+
+    auto backend_config = matmul_call->backend_config<BackendConfig>();
+    bool can_fuse_sum =
+        (ShapeUtil::Equal(matmul_call->shape(), addend->shape()) &&
+         addend_user_count == 1 &&
+         matmul_call->output_operand_aliasing().empty());
+
     // TODO(intel-tf): Remove this restriction once oneDNN has an optimized
     // implementation for broadcasted add across all dimensions.
     OneDnnFusionConfig_FusionKind kind = OneDnnFusionConfig::UNDEFINED;
-    kind = (addend->shape().rank() == 1)
+    kind = (ShapeUtil::TrueRank(addend->shape()) == 1)
                ? (dot->backend_config<BackendConfig>()
                       ->mutable_onednn_matmul_config()
                       ->fusions()
                       .ops()
                       .empty()
                       ? OneDnnFusionConfig::BIAS
                       : OneDnnFusionConfig::UNDEFINED)
-               : OneDnnFusionConfig::BINARY_ADD;
+               : can_fuse_sum ? OneDnnFusionConfig::SUM
+                              : OneDnnFusionConfig::BINARY_ADD;
 
     if (kind == OneDnnFusionConfig::UNDEFINED) return absl::OkStatus();
 
-    auto matmul_call = Cast<HloCustomCallInstruction>(instr->AddInstruction(
-        dot->CloneWithNewOperands(dot->shape(), new_operands)));
+    if (kind == OneDnnFusionConfig::SUM) {
+      matmul_call->set_output_to_operand_aliasing({{{}, {addend_idx, {}}}});
+    }
 
-    auto backend_config = matmul_call->backend_config<BackendConfig>();
     backend_config->mutable_onednn_matmul_config()
         ->mutable_fusions()
         ->add_ops(kind);
@@ -1089,6 +1101,11 @@ class OneDnnPostRewriteVisitor : public DfsHloRewriteVisitor {
     auto scratch_add = AddScratch(custom_call);
     if (scratch_add.ok()) {
       custom_call = *scratch_add;
+      auto aliases = custom_call->output_operand_aliasing();
+      if (!aliases.empty()) {
+        custom_call->set_output_to_operand_aliasing(
+            {{{0}, {aliases[0].second.first, {}}}});
+      }
     } else {
       VLOG(2) << scratch_add.status();
     }
diff --git a/xla/service/cpu/onednn_matmul.cc b/xla/service/cpu/onednn_matmul.cc
index 1b2dbee81c661..4dbc687b773ec 100644
--- a/xla/service/cpu/onednn_matmul.cc
+++ b/xla/service/cpu/onednn_matmul.cc
@@ -145,16 +145,11 @@ std::unique_ptr<matmul::primitive_desc> CreateMatMulPrimDesc(
       case OneDnnFusionConfig::SIGMOID:
        post_ops.append_eltwise(dnnl::algorithm::eltwise_logistic, 0.f, 0.f);
        break;
+      case OneDnnFusionConfig::SUM:
+        post_ops.append_sum();
+        break;
       case OneDnnFusionConfig::BIAS: {
        bias_md = fused_mds.at(fused_operand_idx);
-        // Extend bias rank to match result rank.
-        auto missed_rank = output_md.get_ndims() - bias_md.get_ndims();
-        XLA_LIGHTWEIGHT_CHECK(missed_rank >= 0);
-        if (missed_rank > 0) {
-          auto bias_dims = bias_md.get_dims();
-          bias_dims.insert(bias_dims.begin(), missed_rank, 1);
-          bias_md = bias_md.reshape(bias_dims);
-        }
         if (fused_operands_ref) {
           fused_operands_ref->postop_args.emplace_back(
               DNNL_ARG_BIAS,
@@ -168,14 +163,6 @@ std::unique_ptr<matmul::primitive_desc> CreateMatMulPrimDesc(
         break;
       case OneDnnFusionConfig::BINARY_ADD: {
         auto binary_md = fused_mds.at(fused_operand_idx);
-        // Extend addend rank to match result rank.
-        auto missed_rank = output_md.get_ndims() - binary_md.get_ndims();
-        XLA_LIGHTWEIGHT_CHECK(missed_rank >= 0);
-        if (missed_rank > 0) {
-          auto binary_dims = binary_md.get_dims();
-          binary_dims.insert(binary_dims.begin(), missed_rank, 1);
-          binary_md = binary_md.reshape(binary_dims);
-        }
         if (fused_operands_ref) {
           auto arg_idx =
               DNNL_ARG_ATTR_MULTIPLE_POST_OP(post_ops.len()) | DNNL_ARG_SRC_1;
@@ -309,6 +296,10 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnMatMul(
   std::vector<memory::desc> fused_mds;
   std::vector<void*> fused_bufs;
   for (int64_t i = 0; i < num_fused_operands; ++i) {
+    if (matmul_config.fusions().ops(i) == OneDnnFusionConfig::SUM) {
+      arg_indx++;
+      continue;
+    }
     MemrefInfo operand_minfo(args[arg_indx++]);
     fused_mds.push_back(operand_minfo.GetOneDnnMemDesc());
     fused_bufs.push_back(operand_minfo.Data());
diff --git a/xla/service/cpu/onednn_memory_util.cc b/xla/service/cpu/onednn_memory_util.cc
index 6ab913161a7e4..7b3eddf26d653 100644
--- a/xla/service/cpu/onednn_memory_util.cc
+++ b/xla/service/cpu/onednn_memory_util.cc
@@ -45,9 +45,10 @@ namespace cpu {
 struct MemrefInfoPOD {
   int64_t dtype;
   int64_t rank;
+  void* data;
+  int64_t pad;
   int64_t dims[kOneDnnMaxNDims];
   int64_t strides[kOneDnnMaxNDims];
-  void* data;
 };
 
 MemrefInfoHandler CreateMemrefInfoFromLiteral(const Literal* literal) {
@@ -91,7 +92,7 @@ StackAlloca GetAllocaAndEmitMemrefInfo(llvm::IRBuilder<>& builder,
       llvm::ArrayType::get(builder.getInt64Ty(), kOneDnnMaxNDims);
   llvm::StructType* memref_info_type = llvm::StructType::get(
       builder.getContext(),
-      {i64_type, i64_type, i64_array_type, i64_array_type, ptr_type});
+      {i64_type, i64_type, ptr_type, i64_type, i64_array_type, i64_array_type});
 
   // Prepare array dims and strides.
   llvm::Value* dims_val = llvm::UndefValue::get(i64_array_type);
@@ -103,16 +104,19 @@ StackAlloca GetAllocaAndEmitMemrefInfo(llvm::IRBuilder<>& builder,
     strides_val = builder.CreateInsertValue(strides_val, stride_val, i);
   }
 
-  // Prepare values for struct MemrefInfo.
+  // Prepare values for struct MemrefInfo with padding to align to system
+  // cacheline
   llvm::Value* dtype_val = builder.getInt64(shape.element_type());
   llvm::Value* rank_val = builder.getInt64(rank);
+  llvm::Value* pad_val = builder.getInt64(0xff);
   llvm::Value* data_ptr = ir_array.GetBasePointer();
 
   llvm::Value* memref_info_val = llvm::UndefValue::get(memref_info_type);
   memref_info_val = builder.CreateInsertValue(memref_info_val, dtype_val, 0);
   memref_info_val = builder.CreateInsertValue(memref_info_val, rank_val, 1);
-  memref_info_val = builder.CreateInsertValue(memref_info_val, dims_val, 2);
-  memref_info_val = builder.CreateInsertValue(memref_info_val, strides_val, 3);
-  memref_info_val = builder.CreateInsertValue(memref_info_val, data_ptr, 4);
+  memref_info_val = builder.CreateInsertValue(memref_info_val, data_ptr, 2);
+  memref_info_val = builder.CreateInsertValue(memref_info_val, pad_val, 3);
+  memref_info_val = builder.CreateInsertValue(memref_info_val, dims_val, 4);
+  memref_info_val = builder.CreateInsertValue(memref_info_val, strides_val, 5);
 
   // Allocate MemrefInfo on the stack
   llvm::Value* memref_info_ptr = llvm_ir::EmitAllocaAtFunctionEntry(
@@ -141,8 +145,8 @@ dnnl::memory::data_type MemrefInfo::GetOneDnnDataType() const {
 }
 
 dnnl::memory::desc MemrefInfo::GetOneDnnMemDesc() const {
-  auto dims = GetOneDnnDims();
   auto dtype = GetOneDnnDataType();
+  auto dims = GetOneDnnDims();
   auto strides = GetOneDnnStrides();
   return dnnl::memory::desc{dims, dtype, strides};
 }
diff --git a/xla/service/cpu/tests/onednn_matmul_test.cc b/xla/service/cpu/tests/onednn_matmul_test.cc
index 57f7c09aba11e..1a186346b753d 100644
--- a/xla/service/cpu/tests/onednn_matmul_test.cc
+++ b/xla/service/cpu/tests/onednn_matmul_test.cc
@@ -64,6 +64,17 @@ class MatmulTest : public HloTestBase {
   ; CHECK-DAG:     }
   ; CHECK:     }
   )";
+  const char* fused_matmul_sum_ = R"(
+  ; CHECK:     custom_call_target="__onednn$matmul",
+  ; CHECK:       backend_config={
+  ; CHECK-DAG:     "outer_dimension_partitions":[],
+  ; CHECK-DAG:     "onednn_matmul_config":{
+  ; CHECK-DAG:       "fusions":{
+  ; CHECK-DAG:         "ops":["SUM"]
+  ; CHECK-DAG:     }
+  ; CHECK-DAG:   }
+  ; CHECK:     }
+  )";
   const char* matmul_rewrite_str_ = R"(
   ; CHECK:     custom_call_target="__onednn$matmul",
   ; CHECK:       backend_config={
@@ -267,6 +278,68 @@ TEST_F(MatmulTest, SimpleTestF32WithBiasAsParameter1) {
   })";
 
   EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4}));
+  MatchOptimizedHlo(matmul_module_str, fused_matmul_sum_);
+}
+
+TEST_F(MatmulTest, SimpleTestF32Add2Dots) {
+  const char* matmul_module_str = R"(
+  HloModule matmul.biasadd.test.f32
+
+  ENTRY matmul.biasadd.test.f32 {
+    arg0.1 = f32[32,32,40,30] parameter(0), parameter_replication={false}
+    arg0.2 = f32[32,32,30,40] parameter(1), parameter_replication={false}
+    arg0.3 = f32[32,32,40,40] parameter(2), parameter_replication={false}
+    arg0.4 = f32[32,32,40,40] parameter(3), parameter_replication={false}
+    dot.7 = f32[32,32,40,40] dot(arg0.1, arg0.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+    dot.8 = f32[32,32,40,40] dot(arg0.3, arg0.4), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+    ROOT add.10 = f32[32,32,40,40] add(dot.7, dot.8)
+  })";
+
+  EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4}));
+  MatchOptimizedHlo(matmul_module_str, fused_matmul_sum_);
+}
+
+TEST_F(MatmulTest, SimpleTestF16Add2Dots) {
+  if (!IsSupportedType(PrimitiveType::F16)) {
+    GTEST_SKIP() << "CPU does not support F16.";
+  }
+
+  const char* matmul_module_str = R"(
+  HloModule matmul.biasadd.test.f16
+
+  ENTRY matmul.biasadd.test.f16 {
+    arg0.1 = f16[32,32,40,30] parameter(0), parameter_replication={false}
+    arg0.2 = f16[32,32,30,40] parameter(1), parameter_replication={false}
+    arg0.3 = f16[32,32,40,40] parameter(2), parameter_replication={false}
+    arg0.4 = f16[32,32,40,40] parameter(3), parameter_replication={false}
+    dot.7 = f16[32,32,40,40] dot(arg0.1, arg0.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+    dot.8 = f16[32,32,40,40] dot(arg0.3, arg0.4), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+    ROOT add.10 = f16[32,32,40,40] add(dot.7, dot.8)
+  })";
+
+  EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-2, 1e-2}));
+  MatchOptimizedHlo(matmul_module_str, fused_matmul_sum_);
+}
+
+TEST_F(MatmulTest, SimpleTestBF16Add2Dots) {
+  if (!IsSupportedType(PrimitiveType::BF16)) {
+    GTEST_SKIP() << "CPU does not support BF16.";
+  }
+
+  const char* matmul_module_str = R"(
+  HloModule matmul.biasadd.test.bf16
+
+  ENTRY matmul.biasadd.test.bf16 {
+    arg0.1 = bf16[32,32,40,30] parameter(0), parameter_replication={false}
+    arg0.2 = bf16[32,32,30,40] parameter(1), parameter_replication={false}
+    arg0.3 = bf16[32,32,40,40] parameter(2), parameter_replication={false}
+    arg0.4 = bf16[32,32,40,40] parameter(3), parameter_replication={false}
+    dot.7 = bf16[32,32,40,40] dot(arg0.1, arg0.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+    dot.8 = bf16[32,32,40,40] dot(arg0.3, arg0.4), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+    ROOT add.10 = bf16[32,32,40,40] add(dot.7, dot.8)
+  })";
+
+  EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-2, 1e-2}));
   MatchOptimizedHlo(matmul_module_str, fused_matmul_binary_add_);
 }