Skip to content

Commit

Permalink
[XLA:GPU] Build a workaround to fix tiling propagation for Triton whe…
Browse files Browse the repository at this point in the history
…n tile sizes

are not set to be a power of two.

Because we currently have cases when tile sizes are not set to be a power of two
(e.g. when capturing the full dimension at the output, or when introducing a
non-power-of-two contracting dimension in the middle of tiling propagation), we
can get into issues with `reshape`s. Prior to this change, we would
verify that constraints coming out of a `reshape` are satisfied by a
tuple of exact tile sizes, but later pad these tile sizes at code generation
time---making the lowering of the `reshape` incorrect.

The proper fix for this issue would be to always propagate tile sizes that are
a power of 2 in fusions that will be code generated using Triton. Alas, this
is easier said than done, and requires introducing a principled way of tiling
along newly introduced contracting dimensions.

We already know how we intend
to solve this, but it'll take a while. For this reason, we instead introduce
here Triton-specific constraints aimed at weeding out improper tile sizes as
they arise.

PiperOrigin-RevId: 673977887
  • Loading branch information
bchetioui authored and Google-ML-Automation committed Sep 12, 2024
1 parent 29048bb commit 4aee555
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 10 deletions.
5 changes: 5 additions & 0 deletions xla/service/gpu/model/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -754,10 +754,15 @@ cc_library(
hdrs = ["triton_emitter_constraints.h"],
deps = [
":affine_map_evaluator",
":indexing_analysis",
":symbolic_tile",
":symbolic_tile_analysis",
":symbolic_tiled_hlo_instruction",
"//xla:shape_util",
"//xla/hlo/ir:hlo",
"//xla/stream_executor:device_description",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
Expand Down
83 changes: 79 additions & 4 deletions xla/service/gpu/model/triton_emitter_constraints.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,25 @@ limitations under the License.

#include <cstdint>
#include <memory>
#include <optional>
#include <utility>
#include <vector>

#include "absl/log/check.h"
#include "absl/memory/memory.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/MathExtras.h"
#include "mlir/IR/AffineMap.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/service/gpu/model/affine_map_evaluator.h"
#include "xla/service/gpu/model/indexing_analysis.h"
#include "xla/service/gpu/model/indexing_map.h"
#include "xla/service/gpu/model/symbolic_tile.h"
#include "xla/service/gpu/model/symbolic_tile_analysis.h"
#include "xla/service/gpu/model/symbolic_tiled_hlo_instruction.h"
#include "xla/stream_executor/device_description.h"
Expand All @@ -41,8 +49,54 @@ namespace {
// elements, otherwise it will fail to compile.
constexpr int64_t kMaxTensorNumElements = 1048576;

// Returns a copy of `tile_sizes` in which every entry has been rounded up to
// the next power of two (entries that already are a power of two are kept
// as-is). This mirrors the padding of tile sizes that happens at code
// generation time, so that constraints can be checked against the padded
// values.
//
// The sizes are taken as a non-owning span: the function only reads them, so
// callers holding a `SmallVector` (or any contiguous container) do not pay for
// an extra copy.
llvm::SmallVector<int64_t> GetPaddedTileSizes(
    absl::Span<const int64_t> tile_sizes) {
  llvm::SmallVector<int64_t> result;
  result.reserve(tile_sizes.size());
  for (int64_t value : tile_sizes) {
    result.push_back(llvm::PowerOf2Ceil(value));
  }
  return result;
}

} // namespace

/*static*/ std::vector<TritonEmitterConstraints::CustomConstraints>
TritonEmitterConstraints::DeriveCustomConstraints(
const std::vector<std::unique_ptr<SymbolicTiledHloInstruction>>&
instructions) {
std::vector<CustomConstraints> result;

for (const auto& instruction : instructions) {
const HloInstruction* hlo = instruction->hlo();
if (hlo->opcode() == HloOpcode::kReshape ||
hlo->opcode() == HloOpcode::kBitcast) {
mlir::MLIRContext* ctx =
instruction->symbolic_tile().size_map().getContext();

IndexingMap reshape_indexing_map =
*ComputeOutputToInputIndexing(hlo, /*output_id=*/0, ctx)
.indexing_maps[0]
.begin();

std::optional<SymbolicTile> reshape_symbolic_tile =
SymbolicTile::FromIndexingMap(reshape_indexing_map);
// Since we managed to create a `SymbolicTiledHloInstruction` for this
// instruction, it should never be the case that we fail to derive a
// `SymbolicTile`, so we `CHECK`.
CHECK(reshape_symbolic_tile.has_value());

ConstraintExpression reshape_constraints =
reshape_symbolic_tile->constraints();
result.push_back(
CustomConstraints{instruction->symbolic_tile().size_map(),
std::move(reshape_constraints)});
}
}

return result;
}

/*static*/ EmitterSpecificConstraintsBuilder
TritonEmitterConstraints::GetBuilder(
const se::DeviceDescription& device_description) {
Expand All @@ -54,12 +108,17 @@ TritonEmitterConstraints::GetBuilder(
tiled_hlo_instruction->symbolic_tile().size_map());
}

std::vector<CustomConstraints> custom_constraints =
DeriveCustomConstraints(instructions);

llvm::SmallVector<mlir::AffineMap, 4> tile_size_maps(
unique_tile_size_maps.begin(), unique_tile_size_maps.end());

return std::make_unique<TritonEmitterConstraints>(
std::move(tile_size_maps),
/*root_shape=*/instructions.back()->hlo()->shape(), device_description);
return std::unique_ptr<TritonEmitterConstraints>(
absl::WrapUnique(new TritonEmitterConstraints(
std::move(tile_size_maps), std::move(custom_constraints),
/*root_shape=*/instructions.back()->hlo()->shape(),
device_description)));
};
}

Expand All @@ -84,13 +143,29 @@ absl::StatusOr<bool> TritonEmitterConstraints::ParametersSatisfyConstraints(
num_tiles *= (dim_size + tile_size - 1) / tile_size;
}

// Number of blocks will excede the hardware limit. This limitation comes from
// Number of blocks will exceed the hardware limit. This limitation comes from
// the fact that one tile is mapped to one block. This constraint can be
// potentially hoisted to more generic "gpu-specific constraint".
if (num_tiles >= device_info_.block_dim_limit().x) {
return false;
}

// Ensure that we satisfy the custom constraints we derived when padding tile
// sizes to a power of 2. This is a workaround while nested fusions are not
// landed.
//
// TODO(b/365727080): get rid of this once tiling is using power of twos
// everywhere, including when propagating into the prologue of reductions.
for (const auto& custom_constraint : custom_constraints_) {
llvm::SmallVector<int64_t> transformed_tile_parameters =
EvaluateAffineMap(custom_constraint.tile_parameters_transform,
/*dim_values=*/tile_parameters);
if (!custom_constraint.constraints.IsSatisfiedBy(
GetPaddedTileSizes(transformed_tile_parameters))) {
return false;
}
}

return true;
}

Expand Down
43 changes: 40 additions & 3 deletions xla/service/gpu/model/triton_emitter_constraints.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,17 @@ limitations under the License.
==============================================================================*/

#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/AffineMap.h"
#include "xla/service/gpu/model/symbolic_tile.h"
#include "xla/service/gpu/model/symbolic_tile_analysis.h"
#include "xla/service/gpu/model/symbolic_tiled_hlo_instruction.h"
#include "xla/shape.h"
#include "xla/stream_executor/device_description.h"

Expand All @@ -36,23 +40,56 @@ class TritonEmitterConstraints : public EmitterSpecificConstraints {
static EmitterSpecificConstraintsBuilder GetBuilder(
const se::DeviceDescription& device_description);

absl::StatusOr<bool> ParametersSatisfyConstraints(
absl::Span<const int64_t> tile_parameters) const override;

private:
// Holds a constraint expression over derived parameters (s'0, ..., s'm) where
// (s'0, ..., s'm) = tile_parameters_transform(tile_parameters).
struct CustomConstraints {
mlir::AffineMap tile_parameters_transform;
ConstraintExpression constraints;
};

explicit TritonEmitterConstraints(
llvm::SmallVector<mlir::AffineMap, 4> tile_size_maps,
std::vector<CustomConstraints> custom_constraints,
const Shape& root_shape, const se::DeviceDescription& device_info)
: tile_size_maps_(std::move(tile_size_maps)),
custom_constraints_(std::move(custom_constraints)),
root_shape_(root_shape),
device_info_(device_info) {}

absl::StatusOr<bool> ParametersSatisfyConstraints(
absl::Span<const int64_t> tile_parameters) const override;
// Derives a vector of `CustomConstraints` to be checked within
// `ParametersSatisfyConstraints` from a vector of
// `SymbolicTiledHloInstruction`s representing a symbolically tiled HLO
// computation.
//
// Currently, this is used to work around an issue with reshapes/bitcasts when
// instructions are tiled with non-power-of-2 shapes. The resulting custom
// constraints contain
// * the reshape/bitcast's tile size map; this to allow deriving the
// output tile sizes for the reshape/bitcast instruction;
// * the constraint expression corresponding to the SymbolicTile derived
// from the reshape/bitcast instruction's output-to-input indexing map
// "in a vacuum" (i.e., without composing with any other indexing map).
//
// TODO(b/365727080): move tile derivation to have powers of 2 tiles
// everywhere, and deprecate this.
static std::vector<CustomConstraints> DeriveCustomConstraints(
const std::vector<std::unique_ptr<SymbolicTiledHloInstruction>>&
instructions);

private:
// A collection of unique size maps from all the SymbolicTiledHloInstructions.
//
// Different TiledHloInstructions often have the same size map, so we keep a
// collection of unique maps to improve compilation time.
llvm::SmallVector<mlir::AffineMap, 4> tile_size_maps_;

// Custom emitter-specific constraints to check in
// `ParametersSatisfyConstraints`.
std::vector<CustomConstraints> custom_constraints_;

// Shape of the root instruction.
Shape root_shape_;

Expand Down
57 changes: 54 additions & 3 deletions xla/service/gpu/model/triton_emitter_constraints_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,21 @@ using ::tsl::testing::IsOkAndHolds;

class TritonEmitterConstraintsTest : public HloTestBase {
public:
std::optional<SymbolicTileAnalysis> TryAnalyzeModule(HloModule* module) {
std::optional<SymbolicTileAnalysis> TryAnalyzeModule(
HloModule* module, bool with_triton_emitter_specific_constraints = true) {
EmitterSpecificConstraintsBuilder constraints_builder = nullptr;

if (with_triton_emitter_specific_constraints) {
constraints_builder =
TritonEmitterConstraints::GetBuilder(device_description_);
}

SymbolicTileAnalysisOrError analysis_or_error =
SymbolicTileAnalysis::AnalyzeComputation(
*module->entry_computation()
->root_instruction()
->fused_instructions_computation(),
&mlir_context_,
TritonEmitterConstraints::GetBuilder(device_description_));
&mlir_context_, constraints_builder);

if (std::holds_alternative<SymbolicTileAnalysis>(analysis_or_error)) {
return std::get<SymbolicTileAnalysis>(std::move(analysis_or_error));
Expand Down Expand Up @@ -145,6 +152,50 @@ ENTRY entry_computation {
IsOkAndHolds(false));
}

// Checks that the Triton-specific reshape/bitcast constraints reject a tiling
// that the generic symbolic tile analysis accepts, while still accepting a
// tiling that remains valid.
TEST_F(TritonEmitterConstraintsTest, CustomReshapeConstraintsAreEnforced) {
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                          ParseAndReturnVerifiedModule(R"(
triton_computation {
p = s8[36] parameter(0)
ROOT bitcast = s8[6,6] bitcast(p)
}
ENTRY entry_computation {
p = s8[36] parameter(0)
ROOT fusion = s8[6,6] fusion(p), kind=kCustom, calls=triton_computation
})"));

  // Baseline: analysis built without any emitter-specific constraints.
  std::optional<SymbolicTileAnalysis> generic_analysis = TryAnalyzeModule(
      module.get(), /*with_triton_emitter_specific_constraints=*/false);
  ASSERT_TRUE(generic_analysis.has_value());

  // In theory, (2, 6) is a valid tiling for this reshape, so the plain
  // SymbolicTileAnalysis is expected to accept it.
  EXPECT_THAT(generic_analysis->ParametersSatisfyConstraints({2, 6}),
              IsOkAndHolds(true));

  // Same module, this time with the Triton emitter constraints enabled.
  std::optional<SymbolicTileAnalysis> triton_analysis = TryAnalyzeModule(
      module.get(), /*with_triton_emitter_specific_constraints=*/true);
  ASSERT_TRUE(triton_analysis.has_value());

  // With Triton's power of two restriction, the theoretically valid (2, 6)
  // tiling won't work, so it must be rejected here.
  EXPECT_THAT(triton_analysis->ParametersSatisfyConstraints({2, 6}),
              IsOkAndHolds(false));

  // (1, 6), on the other hand, is valid and must keep working.
  EXPECT_THAT(triton_analysis->ParametersSatisfyConstraints({1, 6}),
              IsOkAndHolds(true));
}

} // namespace
} // namespace gpu
} // namespace xla

0 comments on commit 4aee555

Please sign in to comment.