diff --git a/xla/service/BUILD b/xla/service/BUILD index 8d5025b0629184..dafb526d8a1110 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -3287,6 +3287,40 @@ cc_library( ], ) +cc_library( + name = "layout_canonicalizer", + srcs = ["layout_canonicalizer.cc"], + hdrs = ["layout_canonicalizer.h"], + deps = [ + ":hlo_pass", + ":layout_assignment", + "//xla:permutation_util", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) + +xla_cc_test( + name = "layout_canonicalizer_test", + srcs = ["layout_canonicalizer_test.cc"], + deps = [ + ":layout_canonicalizer", + "//xla/hlo/ir:hlo", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/log", + "@com_google_googletest//:gtest_main", + "@tsl//tsl/platform:statusor", + ], +) + xla_cc_test( name = "while_loop_simplifier_test", srcs = ["while_loop_simplifier_test.cc"], diff --git a/xla/service/layout_canonicalizer.cc b/xla/service/layout_canonicalizer.cc new file mode 100644 index 00000000000000..d9117789bae8c9 --- /dev/null +++ b/xla/service/layout_canonicalizer.cc @@ -0,0 +1,145 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/layout_canonicalizer.h" + +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/permutation_util.h" +#include "xla/service/layout_assignment.h" +#include "xla/shape.h" +#include "xla/shape_util.h" + +namespace xla { +namespace { + +std::vector CanonicalizeInstructionLayout(HloInstruction* instr, + bool is_entry_root); + +bool IsLayoutDescending(const Shape& shape) { + return absl::c_is_sorted(shape.layout().minor_to_major(), + [](int64_t a, int64_t b) { return a > b; }); +} + +// Given an instruction (with non-tuple output shape), this function updates the +// output shape such that the layout is descending. It returns the +// major-to-minor layout ordering which will be used when instr is used as an +// operand. +std::vector HandleOutput(HloInstruction* instr) { + CHECK(!instr->shape().IsTuple()); + if (IsLayoutDescending(instr->shape())) { + return {}; + } + // Create the major-to-minor ordering to construct the new logical dimensions + std::vector major_to_minor; + absl::c_reverse_copy(instr->shape().layout().minor_to_major(), + std::back_inserter(major_to_minor)); + + // Compose shape's dimensions with the major-to-minor layout + std::vector input_new_logical_dims = + ComposePermutations(instr->shape().dimensions(), major_to_minor); + + // Update the shape + *instr->mutable_shape() = ShapeUtil::MakeShapeWithDescendingLayout( + instr->shape().element_type(), input_new_logical_dims); + return major_to_minor; +} + +std::vector HandleBroadcast(HloInstruction* broadcast, + bool is_entry_root) { + VLOG(3) << "HandleBroadcast: " << broadcast->name(); + // Handle broadcast input + HloInstruction* operand = broadcast->mutable_operand(0); + std::vector operand_map = + CanonicalizeInstructionLayout(operand, false); + VLOG(3) << "operand_map = " << absl::StrJoin(operand_map, ","); + + // Handle output + std::vector output_map; + if (!is_entry_root) { + output_map = HandleOutput(broadcast); + } + VLOG(3) << "output_map = " << absl::StrJoin(output_map, ","); + + // Compose dimension map with the inverse of the output map. + if (!output_map.empty()) { + std::vector inverse_output_map = InversePermutation(output_map); + std::vector new_broadcast_dimensions; + new_broadcast_dimensions.reserve(broadcast->dimensions().size()); + for (int64_t dim : broadcast->dimensions()) { + new_broadcast_dimensions.push_back(inverse_output_map[dim]); + } + VLOG(3) << "dimensions after applying output_map = " + << absl::StrJoin(new_broadcast_dimensions, ","); + *broadcast->mutable_dimensions() = new_broadcast_dimensions; + } + + // Compose dimension map with the operand map. + if (!operand_map.empty()) { + std::vector new_broadcast_dimensions = + ComposePermutations(broadcast->dimensions(), operand_map); + VLOG(3) << "dimensions after applying operand_map = " + << absl::StrJoin(new_broadcast_dimensions, ","); + *broadcast->mutable_dimensions() = new_broadcast_dimensions; + } + VLOG(3) << "Broadcast after: " << broadcast->ToString(); + return output_map; +} + +std::vector CanonicalizeInstructionLayout(HloInstruction* instr, + bool is_entry_root) { + if (!LayoutAssignment::InstructionCanChangeLayout(instr)) { + return {}; + } + // For now, we only handle broadcast and transpose. I will add other ops + // gradually. + switch (instr->opcode()) { + case HloOpcode::kBroadcast: + case HloOpcode::kTranspose: + return HandleBroadcast(instr, is_entry_root); + default: + break; + } + return {}; +} +}; // namespace + +absl::StatusOr LayoutCanonicalizer::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + VLOG(3) << "LayoutCanonicalizer::Run: \n" << module->ToString(); + for (auto* comp : module->MakeNonfusionComputations(execution_threads)) { + // We only canonicalize the entry computation for now. + if (comp->IsEntryComputation()) { + CanonicalizeInstructionLayout(comp->root_instruction(), true); + } + } + return true; +} + +} // namespace xla diff --git a/xla/service/layout_canonicalizer.h b/xla/service/layout_canonicalizer.h new file mode 100644 index 00000000000000..ec5b9e5f8fe992 --- /dev/null +++ b/xla/service/layout_canonicalizer.h @@ -0,0 +1,43 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_LAYOUT_CANONICALIZER_H_ +#define XLA_SERVICE_LAYOUT_CANONICALIZER_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/hlo_pass_interface.h" + +namespace xla { + +// HLO pass that canonicalizes all layouts (except input and output of module) +// to have descending layout by default. This is done by applying the layout +// order to the logical dimension ordering and transform each operation +// attributes according to the new logical shape. +class LayoutCanonicalizer : public HloModulePass { + public: + explicit LayoutCanonicalizer() = default; + ~LayoutCanonicalizer() override = default; + absl::string_view name() const override { return "cononicalize_layout"; } + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; +} // namespace xla + +#endif // XLA_SERVICE_LAYOUT_CANONICALIZER_H_ diff --git a/xla/service/layout_canonicalizer_test.cc b/xla/service/layout_canonicalizer_test.cc new file mode 100644 index 00000000000000..c98106bc6e79be --- /dev/null +++ b/xla/service/layout_canonicalizer_test.cc @@ -0,0 +1,134 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/layout_canonicalizer.h" + +#include +#include +#include + +#include +#include "absl/log/log.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace { + +using LayoutCanonicalizerTest = HloTestBase; + +TEST_F(LayoutCanonicalizerTest, CanonicalizeBroadcast) { + const std::string hlo_string = R"( + HloModule broadcast_module + ENTRY %main { + %p0 = f32[2,6]{0,1} parameter(0) + %broadcast = f32[3,2,1,6]{0,1,2,3} broadcast(%p0), dimensions={1,3} + ROOT %output = f32[3,2,1,6]{3,2,1,0} broadcast(%broadcast), dimensions={0,1,2,3} + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + LayoutCanonicalizer canonicalizer; + TF_ASSERT_OK_AND_ASSIGN(bool changed, canonicalizer.Run(m.get())); + ASSERT_TRUE(changed); + + // Layout should be descending. + HloInstruction* output = m->entry_computation()->root_instruction(); + HloInstruction* broadcast = output->mutable_operand(0); + EXPECT_EQ(broadcast->shape().layout().minor_to_major(), + std::vector({3, 2, 1, 0})); + + // Logical dimensions should be as follows. + EXPECT_EQ(broadcast->shape().dimensions(), + std::vector({6, 1, 2, 3})); + + // Dimensions should change according to the new descending layout. + EXPECT_EQ(broadcast->dimensions(), std::vector({2, 0})); + EXPECT_EQ(output->dimensions(), std::vector({3, 2, 1, 0})); + VLOG(3) << "module after:\n" << m->ToString(); +} + +TEST_F(LayoutCanonicalizerTest, CanonicalizeBroadcast2) { + const std::string hlo_string = R"( + HloModule broadcast_module + ENTRY %main { + %p0 = f32[2,6]{0,1} parameter(0) + %broadcast = f32[3,2,1,6]{2,3,1,0} broadcast(%p0), dimensions={1,3} + ROOT %output = f32[3,5,2,1,6]{3,4,2,1,0} broadcast(%broadcast), dimensions={0,2,3,4} + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + LayoutCanonicalizer canonicalizer; + TF_ASSERT_OK_AND_ASSIGN(bool changed, canonicalizer.Run(m.get())); + ASSERT_TRUE(changed); + + // Layout should be descending. + HloInstruction* output = m->entry_computation()->root_instruction(); + HloInstruction* broadcast = output->mutable_operand(0); + EXPECT_EQ(broadcast->shape().layout().minor_to_major(), + std::vector({3, 2, 1, 0})); + + // Logical dimensions should be as follows. + EXPECT_EQ(broadcast->shape().dimensions(), + std::vector({3, 2, 6, 1})); + + // Dimensions should change according to the new descending layout. + EXPECT_EQ(broadcast->dimensions(), std::vector({1, 2})); + EXPECT_EQ(output->dimensions(), std::vector({0, 2, 4, 3})); + VLOG(3) << "module after:\n" << m->ToString(); +} + +TEST_F(LayoutCanonicalizerTest, CanonicalizeBroadcast3) { + const std::string hlo_string = R"( + HloModule broadcast_module + ENTRY %main { + %p0 = f32[2,6]{0,1} parameter(0) + %broadcast = f32[3,2,1,6]{2,3,0,1} broadcast(%p0), dimensions={1,3} + %broadcast2 = f32[3,5,2,1,6]{3,4,0,2,1} broadcast(f32[3,2,1,6]{2,3,0,1} %broadcast), dimensions={0,2,3,4} + ROOT %output = f32[3,5,2,1,6]{3,4,0,1,2} broadcast(f32[3,5,2,1,6]{3,4,0,2,1} %broadcast2), dimensions={0,1,2,3,4} + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + LayoutCanonicalizer canonicalizer; + TF_ASSERT_OK_AND_ASSIGN(bool changed, canonicalizer.Run(m.get())); + ASSERT_TRUE(changed); + + // Layout should be descending. + HloInstruction* root = m->entry_computation()->root_instruction(); + HloInstruction* broadcast2 = root->mutable_operand(0); + HloInstruction* broadcast = broadcast2->mutable_operand(0); + EXPECT_EQ(broadcast->shape().layout().minor_to_major(), + std::vector({3, 2, 1, 0})); + EXPECT_EQ(broadcast2->shape().layout().minor_to_major(), + std::vector({4, 3, 2, 1, 0})); + + // Logical dimensions should be as follows. + EXPECT_EQ(broadcast->shape().dimensions(), + std::vector({2, 3, 6, 1})); + EXPECT_EQ(broadcast2->shape().dimensions(), + std::vector({5, 2, 3, 6, 1})); + + // Dimensions should change according to the new descending layout. + EXPECT_EQ(broadcast->dimensions(), std::vector({0, 2})); + EXPECT_EQ(broadcast2->dimensions(), std::vector({1, 2, 3, 4})); + EXPECT_EQ(root->dimensions(), std::vector({1, 2, 0, 4, 3})); + VLOG(3) << "module after:\n" << m->ToString(); +} + +} // namespace +} // namespace xla