Skip to content

Commit

Permalink
[XLA:GPU] Bail out during SymbolicTileAnalysis if standalone tile d…
Browse files Browse the repository at this point in the history
…erivation is impossible for a reshape.

This complements the previously submitted workaround for power-of-two
tiles in 4aee555.

PiperOrigin-RevId: 675175664
  • Loading branch information
bchetioui authored and Google-ML-Automation committed Sep 16, 2024
1 parent dc5d8b0 commit 0e52ad0
Show file tree
Hide file tree
Showing 7 changed files with 195 additions and 17 deletions.
2 changes: 2 additions & 0 deletions xla/service/gpu/model/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -764,6 +764,7 @@ cc_library(
":symbolic_tiled_hlo_instruction",
"//xla:shape_util",
"//xla/hlo/ir:hlo",
"//xla/service/gpu:hlo_traversal",
"//xla/stream_executor:device_description",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/memory",
Expand All @@ -783,6 +784,7 @@ xla_cc_test(
"//xla/hlo/ir:hlo",
"//xla/service:instruction_fusion",
"//xla/service/gpu:gpu_device_info_for_tests",
"//xla/service/gpu:hlo_traversal",
"//xla/stream_executor:device_description",
"//xla/tests:hlo_test_base",
"//xla/tests:verified_hlo_module",
Expand Down
69 changes: 60 additions & 9 deletions xla/service/gpu/model/symbolic_tile_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -217,26 +217,77 @@ class OrderedUniquePtrValueHashSet {
std::vector<std::unique_ptr<T>> data_;
};

// Decides whether symbolic tile derivation should be attempted for
// `tiled_hlo_instruction`, or whether we should bail out early.
//
// Returns a default-constructed `FusionDecision` when derivation may proceed,
// and a `FusionDecision` carrying an explanation otherwise.
//
// Note that the filters below work around temporary limitations of the
// current infrastructure; they are not fundamental limitations of the
// approach.
FusionDecision ShouldProceedWithSymbolicTileDerivation(
    const SymbolicTiledHloInstruction& tiled_hlo_instruction) {
  const HloInstruction* hlo = tiled_hlo_instruction.hlo();
  const HloOpcode opcode = hlo->opcode();

  // Dots and concatenates are known to cause problems down the line in the
  // current implementation, so they are rejected unconditionally.
  if (opcode == HloOpcode::kDot || opcode == HloOpcode::kConcatenate) {
    return FusionDecision{} << "Bailing out on " << hlo->ToString();
  }

  // Only reshapes and bitcasts need the additional check below; everything
  // else is allowed to proceed.
  const bool is_reshape_like =
      opcode == HloOpcode::kReshape || opcode == HloOpcode::kBitcast;
  if (!is_reshape_like) {
    return {};
  }

  // Due to the issue highlighted in b/365727080, and the related workaround
  // that derives a standalone symbolic tile when constructing Triton-specific
  // constraints, reshapes and bitcasts may cause problems down the line.
  // We therefore attempt the standalone derivation here, and bail out early
  // if it fails for this instruction.
  //
  // TODO(b/365727080): get rid of this filter once the issue is properly
  // fixed.
  mlir::MLIRContext* ctx =
      tiled_hlo_instruction.indexing_map().GetMLIRContext();

  // Output-to-input indexing of the reshape/bitcast taken in isolation, i.e.
  // independently of the rest of the computation.
  const auto operand_indexing =
      ComputeOutputToInputIndexing(hlo, /*output_id=*/0, ctx);
  IndexingMap reshape_indexing_map =
      *operand_indexing.indexing_maps[0].begin();

  if (SymbolicTile::FromIndexingMap(reshape_indexing_map).has_value()) {
    return {};
  }

  return FusionDecision{} << "Bailing out on reshape " << hlo->ToString()
                          << " with indexing map "
                          << reshape_indexing_map.ToString();
}

// Sets a SymbolicTile for each tiled hlo instruction and computes their
// combined constraints. Returns a FusionDecision if a SymbolicTile cannot be
// computed for some instruction or if the constraints are unsatisfiable.
// Returns the combined constraints otherwise.
std::variant<ConstraintExpression, FusionDecision>
SetSymbolicTilesAndComputeConstraints(
std::vector<std::unique_ptr<SymbolicTiledHloInstruction>>&
tiled_hlo_instructions) {
tiled_hlo_instructions,
const HloFusionAdaptor& fusion_adaptor) {
ConstraintExpression constraints;
for (const std::unique_ptr<SymbolicTiledHloInstruction>&
tiled_hlo_instruction : tiled_hlo_instructions) {
const HloInstruction* hlo = tiled_hlo_instruction->hlo();
const IndexingMap& indexing_map = tiled_hlo_instruction->indexing_map();

// Bail out on instructions that are known to cause problems down the
// line. This is not an inherent limitation of the approach, but simply
// issues to be resolved in the current implementation.
if (hlo->opcode() == HloOpcode::kDot ||
hlo->opcode() == HloOpcode::kConcatenate) {
return FusionDecision{} << "Bailing out on " << hlo->ToString();
// We first verify some preconditions on the instructions we intend to
// codegen. Before doing so, we check whether an instruction is part of the
// fusion adaptor, as `tiled_hlo_instructions` may contain instructions that
// won't be codegen'd (the operands to the fusion computation).
if (fusion_adaptor.ContainsInstruction(hlo)) {
FusionDecision should_proceed =
ShouldProceedWithSymbolicTileDerivation(*tiled_hlo_instruction);
if (!should_proceed) {
return should_proceed;
}
}

auto symbolic_tile = SymbolicTile::FromIndexingMap(indexing_map);
Expand Down Expand Up @@ -378,7 +429,7 @@ void SortTiledHloInstructionsInPostOrder(
// Set symbolic tiles for each tiled hlo instruction and compute combined
// constraints.
std::variant<ConstraintExpression, FusionDecision> constraints_or =
SetSymbolicTilesAndComputeConstraints(tiled_hlo_instructions);
SetSymbolicTilesAndComputeConstraints(tiled_hlo_instructions, fusion);
if (std::holds_alternative<FusionDecision>(constraints_or)) {
return std::get<FusionDecision>(constraints_or);
}
Expand All @@ -387,7 +438,7 @@ void SortTiledHloInstructionsInPostOrder(
std::unique_ptr<EmitterSpecificConstraints> emitter_specific_constraints;
if (emitter_specific_constraints_builder != nullptr) {
emitter_specific_constraints =
emitter_specific_constraints_builder(tiled_hlo_instructions);
emitter_specific_constraints_builder(tiled_hlo_instructions, fusion);
}

return SymbolicTileAnalysis(
Expand Down
5 changes: 4 additions & 1 deletion xla/service/gpu/model/symbolic_tile_analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,12 @@ class EmitterSpecificConstraints {
absl::Span<const int64_t> tile_parameters) const = 0;
};

// TODO(b/367306544): get rid of the HloFusionAdaptor parameter once the
// abstraction exists.
using EmitterSpecificConstraintsBuilder =
std::function<std::unique_ptr<EmitterSpecificConstraints>(
const std::vector<std::unique_ptr<SymbolicTiledHloInstruction>>&)>;
const std::vector<std::unique_ptr<SymbolicTiledHloInstruction>>&,
const HloFusionAdaptor&)>;

// Constructs and holds symbolic tiles for all the instructions within a
// computation. We may hold several different symbolic tiles for the same
Expand Down
74 changes: 73 additions & 1 deletion xla/service/gpu/model/symbolic_tile_analysis_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ class FakeEmitterSpecificConstraints : public EmitterSpecificConstraints {

static EmitterSpecificConstraintsBuilder GetBuilder() {
return [](const std::vector<std::unique_ptr<SymbolicTiledHloInstruction>>&
instructions) {
instructions,
const HloFusionAdaptor&) {
const SymbolicTiledHloInstruction* root = instructions[0].get();
int64_t dim0_size = root->hlo()->shape().dimensions(0);
return std::make_unique<FakeEmitterSpecificConstraints>(
Expand Down Expand Up @@ -948,6 +949,77 @@ ENTRY main {
)"));
}

// Checks that `SymbolicTileAnalysis::AnalyzeComputation` returns a
// `FusionDecision` (rather than an analysis) whose explanation mentions the
// reshape bail-out when the fused computation contains a bitcast for which a
// standalone symbolic tile cannot be derived.
TEST_F(SymbolicTileAnalysisTest,
       BailsOutOnReshapeWhenStandaloneSymbolicTileDerivationFails) {
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                          ParseAndReturnVerifiedModule(R"(
HloModule m
add_computation {
p0 = f32[] parameter(0)
p1 = f32[] parameter(1)
ROOT add = f32[] add(p0, p1)
}
fused_computation {
p0 = f32[2,128,128] parameter(0)
// We use two successive bitcasts here as a hack to produce the right
// failure---otherwise, the derivation failure may occur on the parameter
// instruction.
bitcast_fix = f32[16384,1,2] bitcast(p0)
bitcast = f32[2,128,128] bitcast(bitcast_fix)
c0 = f32[] constant(0)
reduce = f32[2,128] reduce(bitcast, c0), dimensions={2},
to_apply=add_computation
}
ENTRY main {
p0 = f32[2,128,128] parameter(0)
ROOT fusion = f32[2,128] fusion(p0), kind=kLoop, calls=fused_computation
})"));

  SymbolicTileAnalysisOrError analysis_or_error =
      SymbolicTileAnalysis::AnalyzeComputation(
          *module->entry_computation()
               ->root_instruction()
               ->fused_instructions_computation(),
          &mlir_context_, /*emitter_specific_constraints_builder=*/nullptr);

  // This must be a fatal assertion: `std::get<FusionDecision>` below is only
  // valid when the variant actually holds a `FusionDecision`. With a
  // non-fatal `EXPECT_TRUE`, a failure here would fall through and access the
  // wrong alternative instead of reporting a clean test failure.
  ASSERT_TRUE(std::holds_alternative<FusionDecision>(analysis_or_error));
  EXPECT_THAT(std::get<FusionDecision>(analysis_or_error).Explain(),
              ::testing::HasSubstr("Bailing out on reshape"));
}

// Regression test: an instruction that `SymbolicTileAnalysis` filters out
// (here, a dot) must not prevent tiling of a fusion when that instruction is
// merely an operand of the fusion and not part of the fused computation.
TEST_F(SymbolicTileAnalysisTest,
       DoesNotBailOutOnFilteredOutHloIfThatHloIsOnlyAnOperand) {
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                          ParseAndReturnVerifiedModule(R"(
HloModule m
fused_computation {
p0 = f32[10,10] parameter(0)
ROOT reshape = f32[100] reshape(p0)
}
ENTRY main {
p0 = f32[10,2] parameter(0)
p1 = f32[2,10] parameter(1)
// Note: this will need upgrading once `SymbolicTileAnalysis` stops filtering
// out dots.
untileable_dot = f32[10,10] dot(p0, p1),
lhs_batch_dims={}, rhs_batch_dims={},
lhs_contracting_dims={1}, rhs_contracting_dims={0}
ROOT fusion = f32[100] fusion(untileable_dot),
kind=kLoop, calls=fused_computation
})"));

  // The analysis is expected to succeed even though the fusion's operand is a
  // dot.
  EXPECT_TRUE(TryAnalyzeModule(module.get()).has_value());
}

} // namespace
} // namespace gpu
} // namespace xla
22 changes: 18 additions & 4 deletions xla/service/gpu/model/triton_emitter_constraints.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ limitations under the License.
#include "mlir/IR/AffineMap.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/service/gpu/hlo_traversal.h"
#include "xla/service/gpu/model/affine_map_evaluator.h"
#include "xla/service/gpu/model/indexing_analysis.h"
#include "xla/service/gpu/model/indexing_map.h"
Expand Down Expand Up @@ -64,13 +65,22 @@ llvm::SmallVector<int64_t> GetPaddedTileSizes(
/*static*/ std::vector<TritonEmitterConstraints::CustomConstraints>
TritonEmitterConstraints::DeriveCustomConstraints(
const std::vector<std::unique_ptr<SymbolicTiledHloInstruction>>&
instructions) {
instructions,
const HloFusionAdaptor& fusion_adaptor) {
std::vector<CustomConstraints> result;

for (const auto& instruction : instructions) {
const HloInstruction* hlo = instruction->hlo();
// Construct custom constraints for parameters of bitcasts and reshapes
// within `instructions`. If the operation's parameter is not part of
// `instructions`, then the bitcast/reshape node is an operand of the
// fusion computation, and there is no need to add constraints.
if (hlo->opcode() == HloOpcode::kReshape ||
hlo->opcode() == HloOpcode::kBitcast) {
if (!fusion_adaptor.ContainsInstruction(hlo)) {
continue;
}

mlir::MLIRContext* ctx =
instruction->symbolic_tile().size_map().getContext();

Expand All @@ -81,9 +91,12 @@ TritonEmitterConstraints::DeriveCustomConstraints(

std::optional<SymbolicTile> reshape_symbolic_tile =
SymbolicTile::FromIndexingMap(reshape_indexing_map);

// Since we managed to create a `SymbolicTiledHloInstruction` for this
// instruction, it should never be the case that we fail to derive a
// `SymbolicTile`, so we `CHECK`.
// `SymbolicTile`, so we `CHECK`. This is enforced by checks in
// `SymbolicTileAnalysis`'s internal function
// `ShouldProceedWithSymbolicTileDerivation`.
CHECK(reshape_symbolic_tile.has_value());

ConstraintExpression reshape_constraints =
Expand All @@ -101,15 +114,16 @@ TritonEmitterConstraints::DeriveCustomConstraints(
TritonEmitterConstraints::GetBuilder(
const se::DeviceDescription& device_description) {
return [=](const std::vector<std::unique_ptr<SymbolicTiledHloInstruction>>&
instructions) {
instructions,
const HloFusionAdaptor& fusion_adaptor) {
llvm::DenseSet<mlir::AffineMap> unique_tile_size_maps;
for (const auto& tiled_hlo_instruction : instructions) {
unique_tile_size_maps.insert(
tiled_hlo_instruction->symbolic_tile().size_map());
}

std::vector<CustomConstraints> custom_constraints =
DeriveCustomConstraints(instructions);
DeriveCustomConstraints(instructions, fusion_adaptor);

llvm::SmallVector<mlir::AffineMap, 4> tile_size_maps(
unique_tile_size_maps.begin(), unique_tile_size_maps.end());
Expand Down
9 changes: 7 additions & 2 deletions xla/service/gpu/model/triton_emitter_constraints.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ limitations under the License.
#include "absl/types/span.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/AffineMap.h"
#include "xla/service/gpu/hlo_traversal.h"
#include "xla/service/gpu/model/symbolic_tile.h"
#include "xla/service/gpu/model/symbolic_tile_analysis.h"
#include "xla/service/gpu/model/symbolic_tiled_hlo_instruction.h"
Expand All @@ -43,6 +44,8 @@ class TritonEmitterConstraints : public EmitterSpecificConstraints {
absl::StatusOr<bool> ParametersSatisfyConstraints(
absl::Span<const int64_t> tile_parameters) const override;

bool HasCustomConstraints() const { return !custom_constraints_.empty(); }

private:
// Holds a constraint expression over derived parameters (s'0, ..., s'm) where
// (s'0, ..., s'm) = tile_parameters_transform(tile_parameters).
Expand All @@ -63,7 +66,8 @@ class TritonEmitterConstraints : public EmitterSpecificConstraints {
// Derives a vector of `CustomConstraints` to be checked within
// `ParametersSatisfyConstraints` from a vector of
// `SymbolicTiledHloInstruction`s representing a symbolically tiled HLO
// computation.
// computation. The fusion adaptor is used to figure out which instructions
// within the computation are operands of the fusion.
//
// Currently, this is used to work around an issue with reshapes/bitcasts when
// instructions are tiled with non-power-of-2 shapes. The resulting custom
Expand All @@ -78,7 +82,8 @@ class TritonEmitterConstraints : public EmitterSpecificConstraints {
// everywhere, and deprecate this.
static std::vector<CustomConstraints> DeriveCustomConstraints(
const std::vector<std::unique_ptr<SymbolicTiledHloInstruction>>&
instructions);
instructions,
const HloFusionAdaptor& fusion_adaptor);

// A collection of unique size maps from all the SymbolicTiledHloInstructions.
//
Expand Down
31 changes: 31 additions & 0 deletions xla/service/gpu/model/triton_emitter_constraints_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@ limitations under the License.
#include <gtest/gtest.h>
#include "absl/log/log.h"
#include "mlir/IR/MLIRContext.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/gpu/gpu_device_info_for_tests.h"
#include "xla/service/gpu/hlo_traversal.h"
#include "xla/service/gpu/model/symbolic_tile_analysis.h"
#include "xla/service/instruction_fusion.h"
#include "xla/stream_executor/device_description.h"
Expand Down Expand Up @@ -196,6 +198,35 @@ ENTRY entry_computation {
IsOkAndHolds(true));
}

// Checks that no custom (reshape/bitcast) constraints are derived when the
// bitcast lives outside the fused computation, i.e. when it is only an
// operand of the fusion. The fusion adaptor passed to the builder is created
// for `triton_computation`, which does not contain the bitcast.
TEST_F(TritonEmitterConstraintsTest,
       ReshapeConstraintsAreNotDerivedForFusionOperands) {
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                          ParseAndReturnVerifiedModule(R"(
triton_computation {
p = s8[6,6] parameter(0)
ROOT add = s8[6,6] add(p, p)
}
ENTRY entry_computation {
p = s8[36] parameter(0)
bitcast = s8[6,6] bitcast(p)
ROOT fusion = s8[6,6] fusion(bitcast),
kind=kCustom, calls=triton_computation
})"));
  std::optional<SymbolicTileAnalysis> analysis = TryAnalyzeModule(module.get());
  ASSERT_TRUE(analysis.has_value());

  const HloComputation* triton_computation =
      FindComputation(module.get(), "triton_computation");

  std::unique_ptr<EmitterSpecificConstraints> constraints =
      TritonEmitterConstraints::GetBuilder(device_description_)(
          analysis->GetSymbolicTiledHloComputation(),
          *HloFusionAdaptor::ForComputation(triton_computation));

  // The builder above is `TritonEmitterConstraints::GetBuilder`, so the
  // object is downcast to `TritonEmitterConstraints`. `static_cast` is the
  // correct cast for a base-to-derived pointer conversion; `reinterpret_cast`
  // performs no pointer adjustment and is not meant for class hierarchies.
  const auto* triton_constraints =
      static_cast<const TritonEmitterConstraints*>(constraints.get());
  EXPECT_FALSE(triton_constraints->HasCustomConstraints());
}

} // namespace
} // namespace gpu
} // namespace xla

0 comments on commit 0e52ad0

Please sign in to comment.