From 0e52ad0db652eadfc519a0d108c62c024c3e5e97 Mon Sep 17 00:00:00 2001
From: Benjamin Chetioui
Date: Mon, 16 Sep 2024 09:19:36 -0700
Subject: [PATCH] [XLA:GPU] Bail out during `SymbolicTileAnalysis` if
 standalone tile derivation is impossible for a reshape.

This is a complement to the previously submitted workaround involving
power-of-two tiles in
https://github.com/openxla/xla/commit/4aee555551c2be2e3e7891eab7b4343bf14ab279.

PiperOrigin-RevId: 675175664
---
 xla/service/gpu/model/BUILD                   |  2 +
 .../gpu/model/symbolic_tile_analysis.cc       | 69 ++++++++++++++---
 .../gpu/model/symbolic_tile_analysis.h        |  5 +-
 .../gpu/model/symbolic_tile_analysis_test.cc  | 74 ++++++++++++++++++-
 .../gpu/model/triton_emitter_constraints.cc   | 22 +++++-
 .../gpu/model/triton_emitter_constraints.h    |  9 ++-
 .../model/triton_emitter_constraints_test.cc  | 31 ++++++++
 7 files changed, 195 insertions(+), 17 deletions(-)

diff --git a/xla/service/gpu/model/BUILD b/xla/service/gpu/model/BUILD
index 3867852a74c4f..ef948e8a43115 100644
--- a/xla/service/gpu/model/BUILD
+++ b/xla/service/gpu/model/BUILD
@@ -764,6 +764,7 @@ cc_library(
         ":symbolic_tiled_hlo_instruction",
         "//xla:shape_util",
         "//xla/hlo/ir:hlo",
+        "//xla/service/gpu:hlo_traversal",
         "//xla/stream_executor:device_description",
         "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/memory",
@@ -783,6 +784,7 @@ xla_cc_test(
         "//xla/hlo/ir:hlo",
         "//xla/service:instruction_fusion",
         "//xla/service/gpu:gpu_device_info_for_tests",
+        "//xla/service/gpu:hlo_traversal",
         "//xla/stream_executor:device_description",
         "//xla/tests:hlo_test_base",
         "//xla/tests:verified_hlo_module",
diff --git a/xla/service/gpu/model/symbolic_tile_analysis.cc b/xla/service/gpu/model/symbolic_tile_analysis.cc
index f08c52d265b17..1ab971ff2a4a9 100644
--- a/xla/service/gpu/model/symbolic_tile_analysis.cc
+++ b/xla/service/gpu/model/symbolic_tile_analysis.cc
@@ -217,6 +217,52 @@ class OrderedUniquePtrValueHashSet {
   std::vector<std::unique_ptr<T>> data_;
 };

+// Detects pathological cases on which symbolic tile derivation should bail
+// out. Note that this function works around temporary limitations of the
+// infrastructure, and not actual fundamental limitations.
+FusionDecision ShouldProceedWithSymbolicTileDerivation(
+    const SymbolicTiledHloInstruction& tiled_hlo_instruction) {
+  const HloInstruction* hlo = tiled_hlo_instruction.hlo();
+  const IndexingMap& indexing_map = tiled_hlo_instruction.indexing_map();
+
+  // Bail out on instructions that are known to cause problems down the
+  // line. This is not an inherent limitation of the approach, but simply
+  // issues to be resolved in the current implementation.
+  if (hlo->opcode() == HloOpcode::kDot ||
+      hlo->opcode() == HloOpcode::kConcatenate) {
+    return FusionDecision{} << "Bailing out on " << hlo->ToString();
+  }
+
+  // Due to the issue highlighted in b/365727080, and the related workaround
+  // deriving a standalone symbolic tile when constructing Triton-specific
+  // constraints, reshapes and bitcasts may cause problems down the line.
+  // The check added here allows us to bail out early when we reach such a
+  // problematic instruction.
+  //
+  // TODO(b/365727080): get rid of this filter once the issue is properly
+  // fixed.
+  if (hlo->opcode() == HloOpcode::kReshape ||
+      hlo->opcode() == HloOpcode::kBitcast) {
+    mlir::MLIRContext* ctx = indexing_map.GetMLIRContext();
+
+    IndexingMap reshape_indexing_map =
+        *ComputeOutputToInputIndexing(hlo, /*output_id=*/0, ctx)
+             .indexing_maps[0]
+             .begin();
+
+    std::optional<SymbolicTile> reshape_symbolic_tile =
+        SymbolicTile::FromIndexingMap(reshape_indexing_map);
+
+    if (!reshape_symbolic_tile.has_value()) {
+      return FusionDecision{} << "Bailing out on reshape " << hlo->ToString()
+                              << " with indexing map "
+                              << reshape_indexing_map.ToString();
+    }
+  }
+
+  return {};
+}
+
 // Sets a SymbolicTile for each tiled hlo instruction and computes their
 // combined constraints. Returns a FusionDecision if a SymbolicTile cannot be
 // computed for some instruction or if the constraints are unsatisfiable.
@@ -224,19 +270,24 @@ class OrderedUniquePtrValueHashSet {
 std::variant<ConstraintExpression, FusionDecision>
 SetSymbolicTilesAndComputeConstraints(
     std::vector<std::unique_ptr<SymbolicTiledHloInstruction>>&
-        tiled_hlo_instructions) {
+        tiled_hlo_instructions,
+    const HloFusionAdaptor& fusion_adaptor) {
   ConstraintExpression constraints;
   for (const std::unique_ptr<SymbolicTiledHloInstruction>&
            tiled_hlo_instruction : tiled_hlo_instructions) {
     const HloInstruction* hlo = tiled_hlo_instruction->hlo();
     const IndexingMap& indexing_map = tiled_hlo_instruction->indexing_map();

-    // Bail out on instructions that are known to cause problems down the
-    // line. This is not an inherent limitation of the approach, but simply
-    // issues to be resolved in the current implementation.
-    if (hlo->opcode() == HloOpcode::kDot ||
-        hlo->opcode() == HloOpcode::kConcatenate) {
-      return FusionDecision{} << "Bailing out on " << hlo->ToString();
+    // We first verify some preconditions on the instructions we intend to
+    // codegen. In particular, we check whether an instruction is part of the
+    // fusion adaptor, as `tiled_hlo_instructions` may contain instructions
+    // that won't be codegen'd (the operands to the fusion computation).
+    if (fusion_adaptor.ContainsInstruction(hlo)) {
+      FusionDecision should_proceed =
+          ShouldProceedWithSymbolicTileDerivation(*tiled_hlo_instruction);
+      if (!should_proceed) {
+        return should_proceed;
+      }
     }

     auto symbolic_tile = SymbolicTile::FromIndexingMap(indexing_map);
@@ -378,7 +429,7 @@ void SortTiledHloInstructionsInPostOrder(
   // Set symbolic tiles for each tiled hlo instruction and compute combined
   // constraints.
   std::variant<ConstraintExpression, FusionDecision> constraints_or =
-      SetSymbolicTilesAndComputeConstraints(tiled_hlo_instructions);
+      SetSymbolicTilesAndComputeConstraints(tiled_hlo_instructions, fusion);
   if (std::holds_alternative<FusionDecision>(constraints_or)) {
     return std::get<FusionDecision>(constraints_or);
   }
@@ -387,7 +438,7 @@ void SortTiledHloInstructionsInPostOrder(
   std::unique_ptr<EmitterSpecificConstraints> emitter_specific_constraints;
   if (emitter_specific_constraints_builder != nullptr) {
     emitter_specific_constraints =
-        emitter_specific_constraints_builder(tiled_hlo_instructions);
+        emitter_specific_constraints_builder(tiled_hlo_instructions, fusion);
   }

   return SymbolicTileAnalysis(
diff --git a/xla/service/gpu/model/symbolic_tile_analysis.h b/xla/service/gpu/model/symbolic_tile_analysis.h
index 692e88db11b99..58a08afde9ba1 100644
--- a/xla/service/gpu/model/symbolic_tile_analysis.h
+++ b/xla/service/gpu/model/symbolic_tile_analysis.h
@@ -55,9 +55,12 @@ class EmitterSpecificConstraints {
       absl::Span<const int64_t> tile_parameters) const = 0;
 };

+// TODO(b/367306544): get rid of the HloFusionAdaptor parameter once the
+// abstraction exists.
using EmitterSpecificConstraintsBuilder = std::function( - const std::vector>&)>; + const std::vector>&, + const HloFusionAdaptor&)>; // Constructs and holds symbolic tiles for all the instructions within a // computation. We may hold several different symbolic tiles for the same diff --git a/xla/service/gpu/model/symbolic_tile_analysis_test.cc b/xla/service/gpu/model/symbolic_tile_analysis_test.cc index f8200ef2568eb..99f136d3aecae 100644 --- a/xla/service/gpu/model/symbolic_tile_analysis_test.cc +++ b/xla/service/gpu/model/symbolic_tile_analysis_test.cc @@ -90,7 +90,8 @@ class FakeEmitterSpecificConstraints : public EmitterSpecificConstraints { static EmitterSpecificConstraintsBuilder GetBuilder() { return [](const std::vector>& - instructions) { + instructions, + const HloFusionAdaptor&) { const SymbolicTiledHloInstruction* root = instructions[0].get(); int64_t dim0_size = root->hlo()->shape().dimensions(0); return std::make_unique( @@ -948,6 +949,77 @@ ENTRY main { )")); } +TEST_F(SymbolicTileAnalysisTest, + BailsOutOnReshapeWhenStandaloneSymbolicTileDerivationFails) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule m + +add_computation { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) +} + +fused_computation { + p0 = f32[2,128,128] parameter(0) + // We use two successive bitcasts here as a hack to produce the right + // failure---otherwise, the derivation failure may occur on the parameter + // instruction. + bitcast_fix = f32[16384,1,2] bitcast(p0) + bitcast = f32[2,128,128] bitcast(bitcast_fix) + c0 = f32[] constant(0) + reduce = f32[2,128] reduce(bitcast, c0), dimensions={2}, + to_apply=add_computation +} + +ENTRY main { + p0 = f32[2,128,128] parameter(0) + ROOT fusion = f32[2,128] fusion(p0), kind=kLoop, calls=fused_computation +})")); + + SymbolicTileAnalysisOrError analysis_or_error = + SymbolicTileAnalysis::AnalyzeComputation( + *module->entry_computation() + ->root_instruction() + ->fused_instructions_computation(), + &mlir_context_, /*emitter_specific_constraints_builder=*/nullptr); + + EXPECT_TRUE(std::holds_alternative(analysis_or_error)); + EXPECT_THAT(std::get(analysis_or_error).Explain(), + ::testing::HasSubstr("Bailing out on reshape")); +} + +TEST_F(SymbolicTileAnalysisTest, + DoesNotBailOutOnFilteredOutHloIfThatHloIsOnlyAnOperand) { + // This is a regression test for a bug where we would refuse to tile a + // computation if its operand could not be tiled according to + // `SymbolicTileAnalysis`. + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule m + +fused_computation { + p0 = f32[10,10] parameter(0) + ROOT reshape = f32[100] reshape(p0) +} + +ENTRY main { + p0 = f32[10,2] parameter(0) + p1 = f32[2,10] parameter(1) + // Note: this will need upgrading once `SymbolicTileAnalysis` stops filtering + // out dots. 
+ untileable_dot = f32[10,10] dot(p0, p1), + lhs_batch_dims={}, rhs_batch_dims={}, + lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT fusion = f32[100] fusion(untileable_dot), + kind=kLoop, calls=fused_computation +})")); + + std::optional analysis = TryAnalyzeModule(module.get()); + EXPECT_TRUE(analysis.has_value()); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/model/triton_emitter_constraints.cc b/xla/service/gpu/model/triton_emitter_constraints.cc index d7a66899840b2..6beffb69badad 100644 --- a/xla/service/gpu/model/triton_emitter_constraints.cc +++ b/xla/service/gpu/model/triton_emitter_constraints.cc @@ -32,6 +32,7 @@ limitations under the License. #include "mlir/IR/AffineMap.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/model/affine_map_evaluator.h" #include "xla/service/gpu/model/indexing_analysis.h" #include "xla/service/gpu/model/indexing_map.h" @@ -64,13 +65,22 @@ llvm::SmallVector GetPaddedTileSizes( /*static*/ std::vector TritonEmitterConstraints::DeriveCustomConstraints( const std::vector>& - instructions) { + instructions, + const HloFusionAdaptor& fusion_adaptor) { std::vector result; for (const auto& instruction : instructions) { const HloInstruction* hlo = instruction->hlo(); + // Construct custom constraints for parameters of bitcasts and reshapes + // within `instructions`. If the operation's parameter is not part of + // `instructions`, then the bitcast/reshape node is an operand of the + // fusion computation, and there is no need to add constraints. if (hlo->opcode() == HloOpcode::kReshape || hlo->opcode() == HloOpcode::kBitcast) { + if (!fusion_adaptor.ContainsInstruction(hlo)) { + continue; + } + mlir::MLIRContext* ctx = instruction->symbolic_tile().size_map().getContext(); @@ -81,9 +91,12 @@ TritonEmitterConstraints::DeriveCustomConstraints( std::optional reshape_symbolic_tile = SymbolicTile::FromIndexingMap(reshape_indexing_map); + // Since we managed to create a `SymbolicTiledHloInstruction` for this // instruction, it should never be the case that we fail to derive a - // `SymbolicTile`, so we `CHECK`. + // `SymbolicTile`, so we `CHECK`. This is enforced by checks in + // `SymbolicTileAnalysis`'s internal function + // `ShouldProceedWithSymbolicTileDerivation`. CHECK(reshape_symbolic_tile.has_value()); ConstraintExpression reshape_constraints = @@ -101,7 +114,8 @@ TritonEmitterConstraints::DeriveCustomConstraints( TritonEmitterConstraints::GetBuilder( const se::DeviceDescription& device_description) { return [=](const std::vector>& - instructions) { + instructions, + const HloFusionAdaptor& fusion_adaptor) { llvm::DenseSet unique_tile_size_maps; for (const auto& tiled_hlo_instruction : instructions) { unique_tile_size_maps.insert( @@ -109,7 +123,7 @@ TritonEmitterConstraints::GetBuilder( } std::vector custom_constraints = - DeriveCustomConstraints(instructions); + DeriveCustomConstraints(instructions, fusion_adaptor); llvm::SmallVector tile_size_maps( unique_tile_size_maps.begin(), unique_tile_size_maps.end()); diff --git a/xla/service/gpu/model/triton_emitter_constraints.h b/xla/service/gpu/model/triton_emitter_constraints.h index fa7ac50404614..3c36ba14d24f8 100644 --- a/xla/service/gpu/model/triton_emitter_constraints.h +++ b/xla/service/gpu/model/triton_emitter_constraints.h @@ -22,6 +22,7 @@ limitations under the License. 
#include "absl/types/span.h" #include "llvm/ADT/SmallVector.h" #include "mlir/IR/AffineMap.h" +#include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/model/symbolic_tile.h" #include "xla/service/gpu/model/symbolic_tile_analysis.h" #include "xla/service/gpu/model/symbolic_tiled_hlo_instruction.h" @@ -43,6 +44,8 @@ class TritonEmitterConstraints : public EmitterSpecificConstraints { absl::StatusOr ParametersSatisfyConstraints( absl::Span tile_parameters) const override; + bool HasCustomConstraints() const { return !custom_constraints_.empty(); } + private: // Holds a constraint expression over derived parameters (s'0, ..., s'm) where // (s'0, ..., s'm) = tile_parameters_transform(tile_parameters). @@ -63,7 +66,8 @@ class TritonEmitterConstraints : public EmitterSpecificConstraints { // Derives a vector of `CustomConstraints` to be checked within // `ParametersSatisfyConstraints` from a vector of // `SymbolicTiledHloInstruction`s representing a symbolically tiled HLO - // computation. + // computation. The fusion adaptor is used to figure out which instructions + // within the computation are operands of the fusion. // // Currently, this is used to work around an issue with reshapes/bitcasts when // instructions are tiled with non-power-of-2 shapes. The resulting custom @@ -78,7 +82,8 @@ class TritonEmitterConstraints : public EmitterSpecificConstraints { // everywhere, and deprecate this. static std::vector DeriveCustomConstraints( const std::vector>& - instructions); + instructions, + const HloFusionAdaptor& fusion_adaptor); // A collection of unique size maps from all the SymbolicTiledHloInstructions. // diff --git a/xla/service/gpu/model/triton_emitter_constraints_test.cc b/xla/service/gpu/model/triton_emitter_constraints_test.cc index cb2c51884ce57..7a171bf1c76a6 100644 --- a/xla/service/gpu/model/triton_emitter_constraints_test.cc +++ b/xla/service/gpu/model/triton_emitter_constraints_test.cc @@ -24,8 +24,10 @@ limitations under the License. #include #include "absl/log/log.h" #include "mlir/IR/MLIRContext.h" +#include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/model/symbolic_tile_analysis.h" #include "xla/service/instruction_fusion.h" #include "xla/stream_executor/device_description.h" @@ -196,6 +198,35 @@ ENTRY entry_computation { IsOkAndHolds(true)); } +TEST_F(TritonEmitterConstraintsTest, + ReshapeConstraintsAreNotDerivedForFusionOperands) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +triton_computation { + p = s8[6,6] parameter(0) + ROOT add = s8[6,6] add(p, p) +} + +ENTRY entry_computation { + p = s8[36] parameter(0) + bitcast = s8[6,6] bitcast(p) + ROOT fusion = s8[6,6] fusion(bitcast), + kind=kCustom, calls=triton_computation +})")); + std::optional analysis = TryAnalyzeModule(module.get()); + ASSERT_TRUE(analysis.has_value()); + + const HloComputation* triton_computation = + FindComputation(module.get(), "triton_computation"); + + std::unique_ptr constraints = + TritonEmitterConstraints::GetBuilder(device_description_)( + analysis->GetSymbolicTiledHloComputation(), + *HloFusionAdaptor::ForComputation(triton_computation)); + EXPECT_FALSE(reinterpret_cast(constraints.get()) + ->HasCustomConstraints()); +} + } // namespace } // namespace gpu } // namespace xla
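
Read on its own, the new bail-out in `ShouldProceedWithSymbolicTileDerivation` reduces to a single question: does the reshape (or bitcast) admit a standalone symbolic tile? The sketch below distills that check. It is illustrative rather than code from the patch; the helper name `ReshapeAdmitsStandaloneSymbolicTile` is invented here, and it only reuses helpers the patch itself calls (`ComputeOutputToInputIndexing`, `SymbolicTile::FromIndexingMap`).

#include "absl/log/check.h"
#include "mlir/IR/MLIRContext.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/service/gpu/model/indexing_analysis.h"
#include "xla/service/gpu/model/indexing_map.h"
#include "xla/service/gpu/model/symbolic_tile.h"

namespace xla::gpu {

// Illustrative helper (not part of the patch): returns true if a standalone
// symbolic tile can be derived for a reshape or bitcast, i.e. if the
// Triton-specific constraint derivation would not hit the b/365727080
// failure mode that the new bail-out guards against.
bool ReshapeAdmitsStandaloneSymbolicTile(const HloInstruction* hlo,
                                         mlir::MLIRContext* ctx) {
  CHECK(hlo->opcode() == HloOpcode::kReshape ||
        hlo->opcode() == HloOpcode::kBitcast);
  // Indexing map of the reshape on its own, independent of how it is tiled
  // inside the surrounding fusion.
  IndexingMap reshape_indexing_map =
      *ComputeOutputToInputIndexing(hlo, /*output_id=*/0, ctx)
           .indexing_maps[0]
           .begin();
  return SymbolicTile::FromIndexingMap(reshape_indexing_map).has_value();
}

}  // namespace xla::gpu

When this check fails, the analysis now reports a `FusionDecision` up front instead of tripping the `CHECK` later in `TritonEmitterConstraints::DeriveCustomConstraints`; the new `BailsOutOnReshapeWhenStandaloneSymbolicTileDerivationFails` test exercises exactly that path.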
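
On the API side, the visible change is that `EmitterSpecificConstraintsBuilder` now also receives the `HloFusionAdaptor` of the fusion being analyzed, which the analysis passes to the builder alongside the tiled instructions. A minimal usage sketch follows; the wrapper name `AnalyzeWithTritonConstraints` is invented here for illustration, and only the entry points that appear in the patch are assumed.

#include "mlir/IR/MLIRContext.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/service/gpu/model/symbolic_tile_analysis.h"
#include "xla/service/gpu/model/triton_emitter_constraints.h"
#include "xla/stream_executor/device_description.h"

namespace xla::gpu {

// Illustrative wrapper (not part of the patch): runs SymbolicTileAnalysis
// with Triton-specific constraints. The builder returned by
// TritonEmitterConstraints::GetBuilder now has the two-argument form
// (instructions, fusion_adaptor); the analysis supplies the adaptor when it
// invokes the builder.
SymbolicTileAnalysisOrError AnalyzeWithTritonConstraints(
    const HloComputation& computation, mlir::MLIRContext* ctx,
    const se::DeviceDescription& device_info) {
  EmitterSpecificConstraintsBuilder builder =
      TritonEmitterConstraints::GetBuilder(device_info);
  return SymbolicTileAnalysis::AnalyzeComputation(computation, ctx, builder);
}

}  // namespace xla::gpu

Passing the adaptor is what lets `DeriveCustomConstraints` skip reshapes and bitcasts that are only operands of the fusion; the new `ReshapeConstraintsAreNotDerivedForFusionOperands` test verifies this via `HasCustomConstraints()`.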