diff --git a/xla/service/cpu/onednn_config.proto b/xla/service/cpu/onednn_config.proto
index 44829a6857f1f..151f38efe09a4 100644
--- a/xla/service/cpu/onednn_config.proto
+++ b/xla/service/cpu/onednn_config.proto
@@ -36,7 +36,7 @@ message OneDnnOptimizationConfig {
 }
 
 message OneDnnFusionConfig {
-  // These enum needs to be mapped to oneDNN enum for post_op algorithm.
+  // This enum needs to be mapped to oneDNN enum for post_op algorithm.
   // TODO(intel-tf): Add kinds supported by oneDNN.
   enum FusionKind {
     UNDEFINED = 0;
@@ -50,6 +50,7 @@ message OneDnnFusionConfig {
     ELU = 8;
     RELU6 = 9;
     SIGMOID = 10;
+    SUM = 11;  // This represents in-place accumulation.
   }
   repeated FusionKind ops = 1;
   // To avoid protobuf failures for specific decimal values,
diff --git a/xla/service/cpu/onednn_contraction_rewriter.cc b/xla/service/cpu/onednn_contraction_rewriter.cc
index 01ffb340e07c1..cae01ac3a1ca5 100644
--- a/xla/service/cpu/onednn_contraction_rewriter.cc
+++ b/xla/service/cpu/onednn_contraction_rewriter.cc
@@ -679,8 +679,11 @@ class OneDnnContractionRewriteVisitor : public DfsHloRewriteVisitor {
     }
     // Validate addend for fusion.
+    auto addend_user_count = addend->user_count();
+    auto addend_idx = -1;
     if (IsSupportedType(addend->shape().element_type()) &&
         IsOperandFusible(addend, contraction)) {
+      addend_idx = new_operands.size();
       new_operands.push_back(addend);
     } else {
       return absl::OkStatus();
     }
@@ -690,6 +693,10 @@ class OneDnnContractionRewriteVisitor : public DfsHloRewriteVisitor {
         contraction->CloneWithNewOperands(contraction->shape(), new_operands)));
 
     auto backend_config = custom_call->backend_config();
+    bool can_fuse_sum =
+        (ShapeUtil::Equal(custom_call->shape(), addend->shape()) &&
+         addend_user_count == 1 &&
+         custom_call->output_operand_aliasing().empty());
 
     // TODO(intel-tf): Remove this restriction once oneDNN has an optimized
     // implementation for broadcasted add across all dimensions.
@@ -699,9 +706,15 @@ class OneDnnContractionRewriteVisitor : public DfsHloRewriteVisitor {
             ? (GetKernelConfig(&backend_config)->fusions().ops().empty()
                    ? OneDnnFusionConfig::BIAS
                    : OneDnnFusionConfig::UNDEFINED)
-            : OneDnnFusionConfig::BINARY_ADD;
+        : can_fuse_sum ? OneDnnFusionConfig::SUM
+                       : OneDnnFusionConfig::BINARY_ADD;
     if (kind == OneDnnFusionConfig::UNDEFINED) return absl::OkStatus();
 
+    // Alias output buffers to addend for in-place accumulation
+    if (kind == OneDnnFusionConfig::SUM) {
+      custom_call->set_output_to_operand_aliasing({{{}, {addend_idx, {}}}});
+    }
+
     GetKernelConfig(&backend_config)->mutable_fusions()->add_ops(kind);
 
     if (optional_addend_broadcast) {
@@ -1108,6 +1121,11 @@ class OneDnnPostRewriteVisitor : public DfsHloRewriteVisitor {
     auto scratch_add = AddScratch(custom_call);
     if (scratch_add.ok()) {
       custom_call = *scratch_add;
+      auto aliases = custom_call->output_operand_aliasing();
+      if (!aliases.empty()) {
+        custom_call->set_output_to_operand_aliasing(
+            {{{0}, {aliases[0].second.first, {}}}});
+      }
     } else {
       VLOG(2) << scratch_add.status();
     }
diff --git a/xla/service/cpu/onednn_convolution.cc b/xla/service/cpu/onednn_convolution.cc
index 30e91fb4aae3e..d39b33507c354 100644
--- a/xla/service/cpu/onednn_convolution.cc
+++ b/xla/service/cpu/onednn_convolution.cc
@@ -184,6 +184,10 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnConvolution(
   std::vector fused_mds;
   std::vector fused_bufs;
   for (int64_t i = 0; i < num_fused_operands; ++i) {
+    if (conv_config.fusions().ops(i) == OneDnnFusionConfig::SUM) {
+      arg_indx++;
+      continue;
+    }
     MemrefInfo operand_minfo(args[arg_indx++]);
     fused_mds.push_back(operand_minfo.GetOneDnnMemDesc());
     fused_bufs.push_back(operand_minfo.Data());
@@ -215,6 +219,9 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnConvolution(
         post_ops.append_binary(dnnl::algorithm::binary_add, binary_md);
         fused_operand_idx++;
       } break;
+      case OneDnnFusionConfig::SUM:
+        post_ops.append_sum();
+        break;
       default:
         LOG(FATAL) << __FILE__ << ":" << __LINE__
diff --git a/xla/service/cpu/onednn_matmul.cc b/xla/service/cpu/onednn_matmul.cc
index 77f25f1b17ec5..7cef365a6be21 100644
--- a/xla/service/cpu/onednn_matmul.cc
+++ b/xla/service/cpu/onednn_matmul.cc
@@ -236,6 +236,10 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnMatMul(
   std::vector fused_mds;
   std::vector fused_bufs;
   for (int64_t i = 0; i < num_fused_operands; ++i) {
+    if (matmul_config.fusions().ops(i) == OneDnnFusionConfig::SUM) {
+      arg_indx++;
+      continue;
+    }
     MemrefInfo operand_minfo(args[arg_indx++]);
     fused_mds.push_back(operand_minfo.GetOneDnnMemDesc());
     fused_bufs.push_back(operand_minfo.Data());
diff --git a/xla/service/cpu/onednn_util.cc b/xla/service/cpu/onednn_util.cc
index 17d09230ef63a..ee954f6a4df73 100644
--- a/xla/service/cpu/onednn_util.cc
+++ b/xla/service/cpu/onednn_util.cc
@@ -67,6 +67,9 @@ dnnl::post_ops PopulateOneDnnPostOps(
       case OneDnnFusionConfig::SIGMOID:
         post_ops.append_eltwise(dnnl::algorithm::eltwise_logistic, 0.f, 0.f);
         break;
+      case OneDnnFusionConfig::SUM:
+        post_ops.append_sum();
+        break;
       case OneDnnFusionConfig::BIAS: {
         *bias_md = fused_mds.at(fused_operand_idx);
         // TODO(intel-tf): Move this check to the rewriter file
diff --git a/xla/service/cpu/tests/onednn_convolution_test.cc b/xla/service/cpu/tests/onednn_convolution_test.cc
index a710898ca8350..fa32d062c8436 100644
--- a/xla/service/cpu/tests/onednn_convolution_test.cc
+++ b/xla/service/cpu/tests/onednn_convolution_test.cc
@@ -241,6 +241,26 @@ TEST_P(ConvolutionTest, Conv2DWithBiasAndBinaryAddTest) {
   RunCompareAndMatchOptimizedHlo(outline, {"BIAS"});
 }
 
+TEST_P(ConvolutionTest, Conv2DWithSumTest) {
+  const absl::string_view outline = R"(
+  HloModule convolution.test.with.sum
+
+  ENTRY convolution.test.with.sum {
+    arg0.1 = $dtype[1,22,22,1] parameter(0)
+    arg0.2 = $dtype[1,11,11,1] parameter(1)
+    constant.3 = $dtype[] constant(1)
+    broadcast.4 = $dtype[8,8,1,1] broadcast(constant.3), dimensions={}
+    convolution.0 = $dtype[1,11,11,1] convolution(arg0.1, broadcast.4),
+        window={size=8x8 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f
+    ROOT add.10 = $dtype[1,11,11,1] add(convolution.0, arg0.2)
+  })";
+
+  // Optimized HLO must match "SUM" only for precisions that support
+  // Elementwise Add operations
+  RunCompareAndMatchOptimizedHlo(outline,
+                                 {(dtype_ == BF16) ? "BINARY_ADD" : "SUM"});
+}
+
 INSTANTIATE_TEST_SUITE_P(
     OneDnnConvolutionTestSuite, ConvolutionTest,
     ::testing::Values(F32, BF16, F16),
diff --git a/xla/service/cpu/tests/onednn_matmul_test.cc b/xla/service/cpu/tests/onednn_matmul_test.cc
index 9234a56c9dd66..350eb8992fb78 100644
--- a/xla/service/cpu/tests/onednn_matmul_test.cc
+++ b/xla/service/cpu/tests/onednn_matmul_test.cc
@@ -64,6 +64,17 @@ class MatmulTest : public HloTestBase {
   ; CHECK-DAG:     }
   ; CHECK:     }
   )";
+  const char* fused_matmul_sum_ = R"(
+  ; CHECK:     custom_call_target="__onednn$matmul",
+  ; CHECK:       backend_config={
+  ; CHECK-DAG:     "outer_dimension_partitions":[],
+  ; CHECK-DAG:     "onednn_matmul_config":{
+  ; CHECK-DAG:       "fusions":{
+  ; CHECK-DAG:         "ops":["SUM"]
+  ; CHECK-DAG:     }
+  ; CHECK-DAG:   }
+  ; CHECK:     }
+  )";
   const char* matmul_rewrite_str_ = R"(
   ; CHECK:     custom_call_target="__onednn$matmul",
   ; CHECK:     backend_config={
@@ -267,7 +278,47 @@ TEST_F(MatmulTest, SimpleTestF32WithBiasAsParameter1) {
   })";
 
   EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4}));
-  MatchOptimizedHlo(matmul_module_str, fused_matmul_binary_add_);
+  MatchOptimizedHlo(matmul_module_str, fused_matmul_sum_);
 }
+
+TEST_F(MatmulTest, SimpleTestF32Add2Dots) {
+  const char* matmul_module_str = R"(
+  HloModule matmul.biasadd.test.f32
+
+  ENTRY matmul.biasadd.test.f32 {
+    arg0.1 = f32[32,32,40,30] parameter(0)
+    arg0.2 = f32[32,32,30,40] parameter(1)
+    arg0.3 = f32[32,32,40,40] parameter(2)
+    arg0.4 = f32[32,32,40,40] parameter(3)
+    dot.7 = f32[32,32,40,40] dot(arg0.1, arg0.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+    dot.8 = f32[32,32,40,40] dot(arg0.3, arg0.4), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+    ROOT add.10 = f32[32,32,40,40] add(dot.7, dot.8)
+  })";
+
+  EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4}));
+  MatchOptimizedHlo(matmul_module_str, fused_matmul_sum_);
+}
+
+TEST_F(MatmulTest, SimpleTestF16Add2Dots) {
+  if (!IsSupportedType(PrimitiveType::F16)) {
+    GTEST_SKIP() << "CPU does not support F16.";
+  }
+
+  const char* matmul_module_str = R"(
+  HloModule matmul.biasadd.test.f16
+
+  ENTRY matmul.biasadd.test.f16 {
+    arg0.1 = f16[32,64,128] parameter(0)
+    arg0.2 = f16[32,128,64] parameter(1)
+    arg0.3 = f16[32,64,64] parameter(2)
+    arg0.4 = f16[32,64,64] parameter(3)
+    dot.7 = f16[32,64,64] dot(arg0.1, arg0.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}
+    dot.8 = f16[32,64,64] dot(arg0.3, arg0.4), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}
+    ROOT add.10 = f16[32,64,64] add(dot.7, dot.8)
+  })";
+
+  EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-2, 1e-2}));
+  MatchOptimizedHlo(matmul_module_str, fused_matmul_sum_);
+}
 
 TEST_F(MatmulTest, SimpleTestF32WithBiasAsParameter2) {
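
Background note (not part of the patch): the snippet below is a minimal standalone sketch of what the oneDNN sum post-op that this change wires up (`dnnl::post_ops::append_sum()`) does at the primitive level. A sum post-op accumulates the primitive's result into whatever data is already in the destination buffer, which is why the rewriter above must alias the custom call's output to the addend operand so that the addend ends up in that buffer. The sketch assumes oneDNN v3.x on CPU; the shapes, buffer contents, and the fact that it is a plain f32 matmul are illustrative only and not taken from the patch.

#include <cstdio>
#include <vector>
#include "oneapi/dnnl/dnnl.hpp"

// Computes dst = A * B + dst using a matmul primitive with a sum post-op.
// The destination buffer doubles as the addend (in-place accumulation),
// mirroring the OneDnnFusionConfig::SUM fusion added by this patch.
int main() {
  using namespace dnnl;
  engine eng(engine::kind::cpu, 0);
  stream strm(eng);

  const memory::dim M = 2, K = 3, N = 2;
  std::vector<float> a(M * K, 1.0f);  // A: all ones
  std::vector<float> b(K * N, 1.0f);  // B: all ones
  std::vector<float> c(M * N, 5.0f);  // dst pre-filled with the addend

  memory::desc a_md({M, K}, memory::data_type::f32, memory::format_tag::ab);
  memory::desc b_md({K, N}, memory::data_type::f32, memory::format_tag::ab);
  memory::desc c_md({M, N}, memory::data_type::f32, memory::format_tag::ab);

  memory a_mem(a_md, eng, a.data());
  memory b_mem(b_md, eng, b.data());
  memory c_mem(c_md, eng, c.data());

  // The sum post-op reads the existing contents of DST and adds them to the
  // matmul result, so DST must already hold the addend when execute() runs.
  post_ops po;
  po.append_sum(/*scale=*/1.0f);
  primitive_attr attr;
  attr.set_post_ops(po);

  matmul::primitive_desc pd(eng, a_md, b_md, c_md, attr);
  matmul(pd).execute(strm, {{DNNL_ARG_SRC, a_mem},
                            {DNNL_ARG_WEIGHTS, b_mem},
                            {DNNL_ARG_DST, c_mem}});
  strm.wait();

  // Each output element is K * 1.0 + 5.0 = 8.0.
  for (float v : c) std::printf("%.1f ", v);
  std::printf("\n");
  return 0;
}

Because the post-op reuses the destination buffer rather than taking an extra operand, the fused custom call consumes one fewer runtime argument for SUM than for BINARY_ADD; that is what the arg_indx skip in onednn_matmul.cc / onednn_convolution.cc and the output-to-operand aliasing in the rewriter account for.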