Skip to content

Commit

Permalink
Fix space-to-batch propagation bug on reduce.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 674118535
  • Loading branch information
amitsabne1 authored and Google-ML-Automation committed Sep 13, 2024
1 parent b7de8d2 commit 3633c96
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 3 deletions.
5 changes: 2 additions & 3 deletions xla/service/space_to_batch_converter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2149,15 +2149,14 @@ absl::StatusOr<bool> ConvolutionVisitor::Propagate(HloInstruction* consumer,
const int64_t rank = first_operand->shape().rank();

const int64_t output_rank = new_consumer->shape().rank();

// Make a map of each dim in original reduce output to input.
std::vector<int64_t> old_reduce_output_to_input(output_rank);
int dim_number_to_assign_old = 0;
for (int64_t i = 0; i < rank; ++i) {
if (auto it = absl::c_find(reduce_dims, i); it != reduce_dims.end()) {
continue;
}
old_reduce_output_to_input[i] = dim_number_to_assign_old++;
old_reduce_output_to_input[dim_number_to_assign_old++] = i;
}

// Make a map of each dim in new reduce output to the new input.
Expand All @@ -2167,7 +2166,7 @@ absl::StatusOr<bool> ConvolutionVisitor::Propagate(HloInstruction* consumer,
if (auto it = absl::c_find(changed_dims, i); it != changed_dims.end()) {
continue;
}
new_reduce_output_to_input[i] = dim_number_to_assign_new++;
new_reduce_output_to_input[dim_number_to_assign_new++] = i;
}

std::vector<int64_t> new_permute_dims(output_rank);
Expand Down
34 changes: 34 additions & 0 deletions xla/service/space_to_batch_converter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -351,5 +351,39 @@ TEST_F(SpaceToBatchConverterTest, DoNotPropagateOnTupleReduce) {
EXPECT_THAT(root, op::Reduce());
}

// Regression test: a convolution output with a degenerate (size-1) spatial
// dimension that is subsequently reduced away (together with the feature
// dimension). Space-to-batch propagation through the reduce must map old
// output dims to input dims correctly in this case.
TEST_F(SpaceToBatchConverterTest, ReduceDegenerateDim) {
  const std::string hlo_text = R"(
  HloModule module
  %region_42.4982 {
  %Arg_0.38 = f32[] parameter(0)
  %Arg_1.39 = f32[] parameter(1)
  ROOT %add.40 = f32[] add(f32[] %Arg_0.38, f32[] %Arg_1.39)
  }
 ENTRY computation {
    %p0 = f32[2,1,84,84,3]{4,3,2,1,0} parameter(0)
    %p1 = f32[3,3,3,3,32]{4,3,2,1,0} parameter(1)
    %constant.10559 = f32[] constant(0)
    %convolution.98 = f32[2,1,84,84,32]{4,3,2,1,0} convolution(%p0, %p1),
      window={size=3x3x3 pad=1_1x1_1x1_1}, dim_labels=b012f_012io->b012f
   ROOT %reduce.2606 = f32[2,84,84]{2,1,0} reduce(f32[2,1,84,84,32]{4,3,2,1,0}
    %convolution.98, f32[] %constant.10559), dimensions={1,4}, to_apply=%region_42.4982
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));

  // Grab the entry computation up front; the pointer stays valid across the
  // pass even though the pass rewrites the instructions inside it.
  auto* entry = module->entry_computation();
  SpaceToBatchConverter s2b_converter(
      SpaceToBatchController{true, true, true, true, /*number_of_splits=*/8});
  // The pass must fire (return true) on this module.
  ASSERT_TRUE(s2b_converter.Run(module.get()).value());

  // After propagation the root should be the batch-to-space reshuffle:
  // a transpose fed by a slice of the space-to-batch form.
  HloInstruction* new_root = entry->root_instruction();
  EXPECT_THAT(new_root, op::Transpose());
  EXPECT_THAT(new_root->operand(0), op::Slice());
}

} // namespace
} // namespace xla

0 comments on commit 3633c96

Please sign in to comment.