From dedab4f8cf4a6f892311d331ea0a6b93568ff2b0 Mon Sep 17 00:00:00 2001 From: Amit Sabne Date: Sat, 14 Sep 2024 07:02:08 -0700 Subject: [PATCH] Allow propagations on reduce to occur PiperOrigin-RevId: 674646492 --- xla/service/space_to_batch_converter.cc | 9 +++--- xla/service/space_to_batch_converter_test.cc | 33 ++++++++++++++++++++ 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/xla/service/space_to_batch_converter.cc b/xla/service/space_to_batch_converter.cc index ea2c6a7f5df71..5a6b8c5e0627b 100644 --- a/xla/service/space_to_batch_converter.cc +++ b/xla/service/space_to_batch_converter.cc @@ -1747,11 +1747,12 @@ bool ConvolutionVisitor::SupportedOpForPropagation(HloInstruction* consumer, const int64_t space_dim = result[DimMapper(SpaceToBatchDimMap::kSpace0)]; // Support the trivial case where none of the batch and split spatial dim // are being reduced. - return !absl::c_linear_search(reduce_dims, batch_dim) && - !absl::c_linear_search(reduce_dims, space_dim); + if (!absl::c_linear_search(reduce_dims, batch_dim) && + !absl::c_linear_search(reduce_dims, space_dim)) { + return true; + } - // Support only the trivial case where both batch and split spatial dim are - // being reduced + // If both batch and space dim are being reduced, propagate. return absl::c_linear_search(reduce_dims, batch_dim) && absl::c_linear_search(reduce_dims, space_dim); } diff --git a/xla/service/space_to_batch_converter_test.cc b/xla/service/space_to_batch_converter_test.cc index 315320627de1d..6f1d86b618216 100644 --- a/xla/service/space_to_batch_converter_test.cc +++ b/xla/service/space_to_batch_converter_test.cc @@ -385,5 +385,38 @@ TEST_F(SpaceToBatchConverterTest, ReduceDegenerateDim) { EXPECT_THAT(root->operand(0), op::Slice()); } +TEST_F(SpaceToBatchConverterTest, PropagateOnReduce) { + std::string hlo_string = R"( +HloModule xla_computation_unknown.14 + +region_0.134 { + Arg_0.135 = f32[] parameter(0) + Arg_1.136 = f32[] parameter(1) + ROOT add.137 = f32[] add(Arg_0.135, Arg_1.136) +} + +ENTRY main.140 { + p0 = bf16[1,512,32,128]{3,2,1,0} parameter(0) + p1 = f32[3,3,128,128]{3,2,1,0} parameter(1) + %convolution.755 = f32[1,512,32,128]{3,2,1,0} + convolution(p0, p1), + window={size=3x3 pad=1_1x1_1 rhs_reversal=1x1}, dim_labels=b01f_01oi->b01f + %constant.19458 = f32[] constant(0) + ROOT %reduce.1354 = f32[128]{0} reduce(%convolution.755, %constant.19458), + dimensions={0,1,2}, to_apply=%region_0.134 +} + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + + auto computation = module->entry_computation(); + SpaceToBatchConverter converter( + SpaceToBatchController{true, true, true, true, /*number_of_splits=*/8}); + ASSERT_TRUE(converter.Run(module.get()).value()); + + HloInstruction* root = computation->root_instruction(); + EXPECT_THAT(root, op::Reduce()); +} + } // namespace } // namespace xla