Skip to content

Commit

Permalink
[XLA:GPU] Use padded tile size to estimate FLOPs and choose the numbe…
Browse files Browse the repository at this point in the history
…r of warps.

Triton has a requirement that all tiles should be padded to the next power of 2 in each dimension. The emitted kernel will perform computation on the padded values, so padding affects the amount of computation and the number of threads in a block. However, padded values are set directly in registers, so padding doesn't affect memory access time.

PiperOrigin-RevId: 674274206
  • Loading branch information
olegshyshkov authored and Google-ML-Automation committed Sep 13, 2024
1 parent 8e91a0a commit bff7fd7
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 21 deletions.
4 changes: 4 additions & 0 deletions xla/service/gpu/model/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,7 @@ cc_library(
"@com_google_absl//absl/strings",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@tsl//tsl/platform:status",
"@tsl//tsl/platform:statusor",
Expand All @@ -388,12 +389,15 @@ xla_cc_test(
":gpu_hlo_cost_analysis",
":gpu_indexing_performance_model",
":gpu_performance_model_base",
":symbolic_tile_analysis",
":tiled_hlo_instruction_or_computation",
"//xla:shape_util",
"//xla:test_helpers",
"//xla/hlo/ir:hlo",
"//xla/service/gpu:backend_configs_cc",
"//xla/service/gpu:gpu_device_info_for_tests",
"//xla/service/gpu:hlo_traversal",
"//xla/service/gpu:ir_emission_utils",
"//xla/service/gpu:launch_dimensions",
"//xla/stream_executor:device_description",
"//xla/tests:hlo_test_base",
Expand Down
72 changes: 51 additions & 21 deletions xla/service/gpu/model/gpu_indexing_performance_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "llvm/Support/MathExtras.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/service/gpu/backend_configs.pb.h"
Expand Down Expand Up @@ -57,6 +58,18 @@ namespace xla {
namespace gpu {
namespace {

// Computes the element count of a tile once every dimension has been rounded
// up to the next power of 2, as required by Triton.
// TODO(b/353484968): Delete this function once we have constraints to only
// propagate tile sizes that are a power of 2.
int64_t GetPaddedTileSize(absl::Span<int64_t const> tile_sizes) {
  int64_t padded_num_elements = 1;
  for (int64_t dim_size : tile_sizes) {
    padded_num_elements *= llvm::PowerOf2Ceil(dim_size);
  }
  return padded_num_elements;
}

// Checks if the tile is too large to fit in registers and would result in
// spilling.
//
Expand Down Expand Up @@ -85,6 +98,18 @@ bool DoesTileFitsInRegisters(int64_t tile_size,
return tile_size <= device_info.registers_per_block_limit();
}

// Returns the number of warps to use based on the tile size. The numbers were
// originally selected from Triton SoftMax reduction row length.
// TODO(b/332714755): Make it smarter.
int64_t GetNumWarps(int64_t tile_size) {
  // Each entry maps an inclusive upper bound on the tile size to a warp
  // count; the first matching entry wins. Larger tiles fall through to 32.
  constexpr struct {
    int64_t max_tile_size;
    int64_t num_warps;
  } kWarpThresholds[] = {
      {512, 1}, {1024, 2}, {16384, 4}, {32768, 8}, {65536, 16}};
  for (const auto& threshold : kWarpThresholds) {
    if (tile_size <= threshold.max_tile_size) {
      return threshold.num_warps;
    }
  }
  return 32;
}

} // namespace

int64_t GpuPerformanceModelWithIndexingAnalysis::FlopsPerElement(
Expand Down Expand Up @@ -329,21 +354,17 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTiledHloComputation(
int64_t num_blocks = launch_dimensions.num_blocks();

for (const auto& tiled_hlo : tiled_hlo_computation.instructions()) {
// Number of elements in the tile.
int64_t tile_size = Product(tiled_hlo->tile_sizes());
// Number of elements in the tile after padding.
int64_t padded_tile_size = GetPaddedTileSize(tiled_hlo->tile_sizes());

// Check if the tile is too large to fit in registers and would result in
// spilling.
if (!DoesTileFitsInRegisters(tile_size, *device_info_)) {
if (!DoesTileFitsInRegisters(padded_tile_size, *device_info_)) {
// TODO(b/363194951): Estimate performance regression due to spilling in
// terms of memory bandwidth instead of returning infinite run time.
return EstimateRunTimeData::Infinite();
}

// Total number of elements that are read from memory or computed for this
// tile across all blocks.
int64_t num_elements = num_blocks * tile_size;

const HloInstruction* hlo = tiled_hlo->hlo();

if (hlo->opcode() == HloOpcode::kConcatenate) {
Expand All @@ -354,9 +375,28 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTiledHloComputation(
}

if (fusion_adaptor.ContainsInstruction(hlo)) {
// Total number of elements computed for this tile across all blocks.
//
// Even if real `tile_size` is smaller than `padded_tile_size`, SM will
// still perform calculations on masked values, so they should count
// towards FLOPs.
int64_t num_elements = num_blocks * padded_tile_size;

// Tiles inside the computation contribute to the total FLOPs count.
flops += FlopsPerElement(hlo) * num_elements;
} else {
// Number of elements in the tile.
int64_t tile_size = Product(tiled_hlo->tile_sizes());

// Total number of elements that are read from memory across all blocks.
//
// Triton requires that all tiles have dimensions that are padded to the
// next power of 2. However, the load masks the padded elements, so they
// are not read from memory, but set directly in registers. As a result,
// the number of elements read from memory is equal to the size of the
// original tile.
int64_t num_elements = num_blocks * tile_size;

// Tiles of the operands of the fusion contribute to the total memory
// read time.
int64_t element_type_size =
Expand Down Expand Up @@ -443,23 +483,13 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTriton(
launch_config->block_level_parameters.output_tile_sizes);
}

// Returns the number of warps to use based on the tile size. The numbers were
// originally selected from Triton SoftMax reduction row length.
// TODO(b/332714755): Make it smarter.
int64_t GetNumWarps(int64_t tile_size) {
if (tile_size <= 512) return 1;
if (tile_size <= 1024) return 2;
if (tile_size <= 16384) return 4;
if (tile_size <= 32768) return 8;
if (tile_size <= 65536) return 16;
return 32;
}

LaunchDimensions GetLaunchDimensionsForTiledFusion(
/*static*/
LaunchDimensions
GpuPerformanceModelWithIndexingAnalysis::GetLaunchDimensionsForTiledFusion(
const TiledHloComputation& tiled_hlo_computation) {
const auto* tiled_root = tiled_hlo_computation.GetRoot();
int64_t num_blocks = tiled_hlo_computation.num_output_tiles();
int64_t num_warps = GetNumWarps(Product(tiled_root->tile_sizes()));
int64_t num_warps = GetNumWarps(GetPaddedTileSize(tiled_root->tile_sizes()));

return {static_cast<uint64_t>(num_blocks),
static_cast<uint64_t>(num_warps * WarpSize())};
Expand Down
4 changes: 4 additions & 0 deletions xla/service/gpu/model/gpu_indexing_performance_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ class GpuPerformanceModelWithIndexingAnalysis : public GpuPerformanceModelBase {
*device_info_),
mlir_context_(mlir_context) {}

// Returns the launch dimensions for the given tiled HLO computation.
static LaunchDimensions GetLaunchDimensionsForTiledFusion(
const TiledHloComputation& tiled_hlo_computation);

EstimateRunTimeData EstimateRunTimeForFusion(
const HloFusionAnalysis& fusion_analysis, bool is_coalesced = true);

Expand Down
85 changes: 85 additions & 0 deletions xla/service/gpu/model/gpu_indexing_performance_model_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,13 @@ limitations under the License.
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/gpu/gpu_device_info_for_tests.h"
#include "xla/service/gpu/hlo_traversal.h"
#include "xla/service/gpu/ir_emission_utils.h"
#include "xla/service/gpu/launch_dimensions.h"
#include "xla/service/gpu/model/fusion_analysis_cache.h"
#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
#include "xla/service/gpu/model/gpu_performance_model_base.h"
#include "xla/service/gpu/model/symbolic_tile_analysis.h"
#include "xla/service/gpu/model/tiled_hlo_computation.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/stream_executor/device_description.h"
Expand Down Expand Up @@ -459,6 +462,88 @@ ENTRY main {
EXPECT_TRUE(res2.IsInfinite());
}

// Verifies that the cost model pads each tile dimension to the next power of
// 2 for the FLOPs estimate, while the memory-read estimate keeps the original
// (unpadded) tile size.
TEST_F(GpuIndexingPerformanceModelTest,
EstimateRunTimeForTiledFusion_UsesPaddedTileSizeForMemoryAccessTime) {
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
HloModule m
triton_softmax_computation {
param_0 = f32[65,65] parameter(0)
param_1 = f32[65,65] parameter(1)
ROOT add = f32[65,65] add(param_0, param_1)
}
ENTRY main {
param_0 = f32[65,65] parameter(0)
param_1 = f32[65,65] parameter(1)
ROOT triton_softmax = f32[65,65] fusion(param_0, param_1), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config": {"kind":"__triton"}}
}
)"));
auto fusion_adaptor = HloFusionAdaptor::ForInstruction(
module->entry_computation()->root_instruction());

TF_ASSERT_OK_AND_ASSIGN(
auto tiling_result,
indexing_cost_model_.TryFindBestTilingForFusion(*fusion_adaptor));

// The {65, 65} output tile covers the whole f32[65,65] shape in one block.
TF_ASSERT_OK_AND_ASSIGN(
auto res, indexing_cost_model_.EstimateRunTimeForTiledFusion(
*fusion_adaptor, /*launch_dimensions=*/{1, 2 * WarpSize()},
/*output_tile_sizes=*/{65, 65}));

// One f32[65,65] parameter: 65 * 65 elements, 4 bytes each.
constexpr int64_t kParamSizeBytes = 65 * 65 * 4;
// Each dimension of the {65, 65} tile is padded up to 128.
constexpr int64_t kPaddedOutputTileSize = 128 * 128;
constexpr int64_t kAddFlops = 3;

// Memory access time is estimated for the tile without padding to the power
// of 2, because padded values are set directly in registers.
EXPECT_EQ(res.bytes_read, 2 * kParamSizeBytes);

// Compute happens on all values in the tile, including padded ones.
EXPECT_EQ(res.flops, kPaddedOutputTileSize * kAddFlops);
}

// Verifies that GetLaunchDimensionsForTiledFusion chooses the warp count from
// the power-of-2-padded tile size rather than the raw tile size.
TEST_F(GpuIndexingPerformanceModelTest,
GetLaunchDimensionsForTiledFusion_IsSupported) {
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
HloModule m
triton_softmax_computation {
param_0 = f32[9,9,9] parameter(0)
param_1 = f32[9,9,9] parameter(1)
ROOT multiply = f32[9,9,9] multiply(param_0, param_1)
}
ENTRY main {
param_0 = f32[9,9,9] parameter(0)
param_1 = f32[9,9,9] parameter(1)
ROOT fusion = f32[9,9,9] fusion(param_0, param_1), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config": {"kind":"__triton"}}
}
)"));
auto fusion_adaptor = HloFusionAdaptor::ForInstruction(
module->entry_computation()->root_instruction());

SymbolicTileAnalysisOrError analysis_or_error =
SymbolicTileAnalysis::AnalyzeFusion(
*fusion_adaptor, &mlir_context_,
/*emitter_specific_constraints_builder=*/nullptr);
ASSERT_TRUE(std::holds_alternative<SymbolicTileAnalysis>(analysis_or_error));

// Tile the fusion with a single {9, 9, 9} tile covering the whole output.
TF_ASSERT_OK_AND_ASSIGN(
TiledHloComputation tiled_hlo_computation,
std::get<SymbolicTileAnalysis>(analysis_or_error)
.ComputeTiledHloInstructions(/*tile_parameters=*/{9, 9, 9}));

LaunchDimensions launch_dimensions = GpuPerformanceModelWithIndexingAnalysis::
GetLaunchDimensionsForTiledFusion(tiled_hlo_computation);
EXPECT_EQ(launch_dimensions.num_blocks(), 1);

// Tile size is 9 * 9 * 9 = 729, which corresponds to 2 warps. But we
// estimate the number of warps for the padded tile, which has size
// 16 * 16 * 16 = 4096 and corresponds to 4 warps.
EXPECT_EQ(launch_dimensions.num_threads_per_block(), 4 * WarpSize());
}

class FlopsPerElementTest : public GpuIndexingPerformanceModelTest {
public:
void CompareFlopsModels(absl::string_view hlo_module_string) {
Expand Down

0 comments on commit bff7fd7

Please sign in to comment.