
Commit

[XLA:GPU][NFC] Move "tile size fits in registers" check to a separate function.

A huge comment in the middle of the loop doesn't improve readability.

PiperOrigin-RevId: 674223948
olegshyshkov authored and Google-ML-Automation committed Sep 13, 2024
1 parent 474036b commit bad4cc8
Showing 1 changed file with 35 additions and 27 deletions.

xla/service/gpu/model/gpu_indexing_performance_model.cc
@@ -48,12 +48,44 @@ limitations under the License.
 #include "xla/service/instruction_fusion.h"
 #include "xla/shape.h"
 #include "xla/shape_util.h"
+#include "xla/stream_executor/device_description.h"
 #include "xla/util.h"
 #include "tsl/platform/status.h"
 #include "tsl/platform/statusor.h"
 
 namespace xla {
 namespace gpu {
+namespace {
+
+// Checks if the tile is too large to fit in registers and would result in
+// spilling.
+//
+// Spilling almost always causes significant performance regressions, so this
+// heuristic tries to be safe and increase recall at the cost of precision.
+bool DoesTileFitsInRegisters(int64_t tile_size,
+                             const se::DeviceDescription& device_info) {
+  // Register allocation happens at PTX->SASS level, so we can't know the exact
+  // number of registers used by a kernel. We make a few assumptions about the
+  // kernel we will generate (this may not hold in the future):
+  //
+  // * We'll need at least 1 register to store 1 element of the tile.
+  // * All values of the tile are live at the same time.
+  // * If all values don't need to be live at the same time (for example to
+  //   compute a reduction), it will be modeled by an explicit loop with
+  //   smaller tiles inside during tiling propagation.
+  //
+  // TODO(b/363194951): Check how many registers we need for scratch memory
+  // for indexing computation and expensive instructions like exponential or
+  // cosine.
+  //
+  // TODO(b/363194951): Check how the number of registers used depends on the
+  // data type. `registers_per_block_limit()` returns the number of 32-bit
+  // registers. Check if 64-bit types need twice as many registers. Check if
+  // smaller types can fit into one register.
+  return tile_size <= device_info.registers_per_block_limit();
+}
+
+}  // namespace
 
 int64_t GpuPerformanceModelWithIndexingAnalysis::FlopsPerElement(
     const HloInstruction* instr) {
@@ -302,33 +334,9 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTiledHloComputation(
 
   // Check if the tile is too large to fit in registers and would result in
   // spilling.
-  //
-  // Spilling almost always causes significant performance regressions, so
-  // this heuristic tries be safe and increase recall at the cost of
-  // precision.
-  //
-  // Register allocation happens at PTX->SASS level, so we can't know the
-  // exact number of registers used by a kernel. We make a few assumptions
-  // about the kernel we will generate (this may not hold in the future):
-  //
-  // * We'll need to at least 1 register to store 1 element of the tile.
-  // * All value of the tile are live at the same time.
-  // * If all values don't need to be live at the same time (for example to
-  //   compute a reduction), it will be modeled by an explicit loop with
-  //   smaller tiles inside during tiling propagation.
-  //
-  // TODO(b/363194951): Check how many registers we need for scratch memory
-  // for indexing computation and expensive instructions like exponential or
-  // cosine.
-  //
-  // TODO(b/363194951): Check how the number of registers used depends on the
-  // data type. `registers_per_block_limit()` returns the number of 32-bit
-  // registers. Check if 64-bit types need twice as many registers. Check if
-  // smaller types can fit into one register.
-  //
-  // TODO(b/363194951): Estimate performance regression due to spilling in
-  // terms of memory bandwidth instead of returning infinite run time.
-  if (tile_size > device_info_->registers_per_block_limit()) {
+  if (!DoesTileFitsInRegisters(tile_size, *device_info_)) {
+    // TODO(b/363194951): Estimate performance regression due to spilling in
+    // terms of memory bandwidth instead of returning infinite run time.
     return EstimateRunTimeData::Infinite();
   }

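As a quick illustration of the heuristic (not part of the commit), the standalone sketch below inlines the same check with a hypothetical register budget. `kRegistersPerBlockLimit` and the `main` driver are assumed stand-ins for illustration; in the real code the value comes from `se::DeviceDescription::registers_per_block_limit()`, which for many recent NVIDIA GPUs reports 64K 32-bit registers per thread block.

#include <cstdint>
#include <iostream>

namespace {

// Assumed stand-in for se::DeviceDescription::registers_per_block_limit():
// 64K 32-bit registers per thread block, typical for recent NVIDIA GPUs.
constexpr int64_t kRegistersPerBlockLimit = 64 * 1024;

// Same check as the helper added in this commit: one register per tile
// element, with every element assumed live at the same time.
bool DoesTileFitsInRegisters(int64_t tile_size) {
  return tile_size <= kRegistersPerBlockLimit;
}

}  // namespace

int main() {
  // A 128x256 tile (32768 elements) fits in the register budget; a 512x512
  // tile (262144 elements) does not, so the cost model would treat it as
  // spilling and return an infinite run time estimate.
  std::cout << DoesTileFitsInRegisters(128 * 256) << "\n";  // prints 1
  std::cout << DoesTileFitsInRegisters(512 * 512) << "\n";  // prints 0
  return 0;
}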
