From 87e62eef0c582d0a61c7a12338cf99bbb8bd0576 Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Mon, 16 Sep 2024 09:20:30 -0700 Subject: [PATCH] Reverts 58cce1ac69eed19488179756419da13a7b8ec470 PiperOrigin-RevId: 675175961 --- xla/service/gpu/fusions/reduction_mlir.cc | 160 +++++++----------- xla/service/gpu/fusions/reduction_mlir.h | 17 +- .../fusions/tests/reduce_multirow/f16_v4.hlo | 22 --- 3 files changed, 59 insertions(+), 140 deletions(-) delete mode 100644 xla/service/gpu/fusions/tests/reduce_multirow/f16_v4.hlo diff --git a/xla/service/gpu/fusions/reduction_mlir.cc b/xla/service/gpu/fusions/reduction_mlir.cc index d195d40fe7e0b..b2db8fa5cd173 100644 --- a/xla/service/gpu/fusions/reduction_mlir.cc +++ b/xla/service/gpu/fusions/reduction_mlir.cc @@ -770,11 +770,32 @@ llvm::SmallVector MlirSmallColumnReductionFusion::EmitReduction( shared_rows_ / 2); } +std::unique_ptr CreateMlirReductionFusion( + const HloFusionAnalysis& analysis) { + auto* hero_reduction = analysis.FindHeroReduction(); + CHECK_NE(hero_reduction, nullptr); + ReductionDimensions reduction_dimensions = + GetReductionKindAndContiguousComponents(*hero_reduction); + if (reduction_dimensions.is_row_reduction) { + if (RowReductionGetRowsPerWarp( + reduction_dimensions.dimensions[kRowMinorReduced]) > 1) { + return std::make_unique(analysis); + } + return std::make_unique(analysis); + } + + if (WarpSize() % reduction_dimensions.dimensions[kColMinorKept] == 0) { + return std::make_unique(analysis); + } + return std::make_unique(analysis); +} + MlirRowReductionFusion::MlirRowReductionFusion( const HloFusionAnalysis& analysis) : MlirReductionFusion(analysis) { CHECK(reduction_dimensions_.is_row_reduction); Vector3 shape = reduction_dimensions_.dimensions; + CHECK_EQ(RowReductionGetRowsPerWarp(shape[kRowMinorReduced]), 1); constexpr int64_t kMinorReducedElementsPerThread = 16; int64_t num_threads_kept = 1; @@ -910,28 +931,33 @@ llvm::SmallVector MlirRowReductionFusion::EmitReduction( } MlirMultiRowReductionFusion::MlirMultiRowReductionFusion( - const HloFusionAnalysis& analysis, int vector_size) + const HloFusionAnalysis& analysis) : MlirReductionFusion(analysis) { CHECK(reduction_dimensions_.is_row_reduction); Vector3 shape = reduction_dimensions_.dimensions; + int64_t rows_per_warp = RowReductionGetRowsPerWarp(shape[kRowMinorReduced]); input_shape_ = {shape[0], shape[1], shape[2]}; - num_threads_ = GetNumThreads(reduction_dimensions_, vector_size); - num_blocks_ = {GetNumBlocks(reduction_dimensions_, num_threads_)}; - tile_sizes_per_thread_ = {shape[0], vector_size}; -} + CHECK_GT(rows_per_warp, 1); -std::unique_ptr MlirMultiRowReductionFusion::TryCreate( - const HloFusionAnalysis& analysis) { - auto* hero_reduction = analysis.FindHeroReduction(); - CHECK_NE(hero_reduction, nullptr); - auto reduction_dimensions = - GetReductionKindAndContiguousComponents(*hero_reduction); - auto shape = reduction_dimensions.dimensions; - // This emitter only supports reductions where the reduced dimension is a - // power of 2. - if (shape[kRowMinorReduced] & (shape[kRowMinorReduced] - 1)) { - return nullptr; - } + auto compute_block_size = [&](int vector_size) { + int64_t num_threads_reduced = shape[kRowMinorReduced] / vector_size; + + constexpr int64_t kThreadsPerBlockTarget = 256; + int64_t kept_size = reduction_dimensions_.dimensions[kRowKept]; + int64_t num_threads_kept = 1; + if (kept_size * num_threads_reduced <= kThreadsPerBlockTarget) { + num_threads_kept = kept_size; + } else { + num_threads_kept = kThreadsPerBlockTarget / num_threads_reduced; + } + num_threads_ = {num_threads_kept, num_threads_reduced}; + tile_sizes_per_thread_ = {shape[0], vector_size}; + num_blocks_ = {CeilOfRatio(input_shape_[kRowKept], num_threads_kept)}; + }; + + // Compute the launch grid without vectorization. We use the results to + // compute the vectorized launch grid. + compute_block_size(1); // Normally, we only consider input types for vectorization. However, in // multi-row reductions, the input:output ratio is much higher, so we consider @@ -939,75 +965,24 @@ std::unique_ptr MlirMultiRowReductionFusion::TryCreate( int smallest_input_or_output_bits = std::min(analysis.input_output_info().smallest_input_dtype_bits, analysis.input_output_info().smallest_output_dtype_bits); - int largest_input_or_output_bits = - std::max(analysis.input_output_info().smallest_input_dtype_bits, - analysis.input_output_info().smallest_output_dtype_bits); + // This vector size is always valid: we know that the reduced dimension is a + // power of 2, since otherwise RowReductionGetRowsPerWarp would have + // returned 1. // Our codegen can't currently deal with vectorization across rows, so we // limit the vector size to the size of the row. Note that this emitter // essentially reverts to the loop emitter in this case, except for side // outputs. - int vector_size = std::min(static_cast(shape[kRowMinorReduced]), - 64 / smallest_input_or_output_bits); - - // Very large vector sizes for f32 can be detrimental, so we limit the vector - // size to 16 bytes if we have some >= 32 bit inputs or outputs. This is still - // a bit on the high side, but remember that we also have very small inputs - // or outputs. - if (largest_input_or_output_bits >= 32) { - vector_size = std::min(128 / largest_input_or_output_bits, vector_size); - } - - // The reduced dimension must fit into a single warp. - if (shape[kRowMinorReduced] > WarpSize() * vector_size) { - return nullptr; - } - - // At the very least, we want to have work for every SM. - // TODO(jreiffers): This limit is probably too low: if we have as many blocks - // as SMs, we'll only run about 8 warps per SM, so occupancy will be very low. - // Further measurements are needed to refine this heuristic. - int64_t min_desired_blocks = analysis.device_info().core_count(); - while (vector_size > 1 && - GetNumBlocks(reduction_dimensions, - GetNumThreads(reduction_dimensions, vector_size)) < - min_desired_blocks) { - vector_size /= 2; + int vector_size = std::min(static_cast(input_shape_[kRowMinorReduced]), + 32 / smallest_input_or_output_bits); + + // We target 8 warps per block, which means there could be up to 8 blocks per + // SM, but we have no good way of knowing. In practice, enabling vectorization + // for decently sized reductions at least does not hurt. + if (num_blocks_.front() > analysis.device_info().core_count() && + vector_size > 1) { + compute_block_size(vector_size); } - - // Check again that the reduced dimension fits after potentially reducing the - // vector size. - if (shape[kRowMinorReduced] > WarpSize() * vector_size) { - return nullptr; - } - - return std::make_unique(analysis, vector_size); -} - -absl::InlinedVector MlirMultiRowReductionFusion::GetNumThreads( - const ReductionDimensions& reduction_dimensions, int vector_size) { - int64_t num_threads_reduced = - reduction_dimensions.dimensions[kRowMinorReduced] / vector_size; - - constexpr int64_t kThreadsPerBlockTarget = 256; - int64_t kept_size = reduction_dimensions.dimensions[kRowKept]; - int64_t num_threads_kept = 1; - if (kept_size * num_threads_reduced <= kThreadsPerBlockTarget) { - num_threads_kept = kept_size; - } else { - num_threads_kept = kThreadsPerBlockTarget / num_threads_reduced; - } - return {num_threads_kept, num_threads_reduced}; -} - -int64_t MlirMultiRowReductionFusion::GetNumBlocks( - const ReductionDimensions& reduction_dimensions, - const absl::InlinedVector& num_threads) { - CHECK_EQ(num_threads.size(), 2) - << "Expected num_threads to contain the number of threads in the {kept, " - "reduced} dimensions."; - return CeilOfRatio(reduction_dimensions.dimensions[kRowKept], - num_threads.front()); } IndexingMap MlirMultiRowReductionFusion::ComputeReductionInputIndexing( @@ -1038,7 +1013,8 @@ IndexingMap MlirMultiRowReductionFusion::ComputeReductionOutputIndexing( : mlir::getAffineDimExpr(3, ctx); IndexingMap projected_index = GetIndexingMap(block_id * num_threads_[0] + thread_id[0]); - projected_index.AddConstraint(thread_id[1] % num_threads_[1], {0, 0}); + projected_index.AddConstraint(thread_id[1] % (WarpSize() / GetRowsPerWarp()), + {0, 0}); // We don't need a constraint on the loop dimensions, because they are removed // by GetIndexingMap (since they don't show up in the output index // computation). @@ -1058,30 +1034,10 @@ llvm::SmallVector MlirMultiRowReductionFusion::EmitReduction( auto per_thread = state.EmitPerThreadElements(group_id, inits, state.FusionOutputs()); auto reduced = state.ShuffleReduce(reductions, per_thread.reduction_scalars, - num_threads_[1] / 2); + WarpSize() / 2 / GetRowsPerWarp()); return EvaluateEpilogue(reduced, std::move(per_thread.outputs), state, group_id, /*symbol_values=*/{}); } -std::unique_ptr CreateMlirReductionFusion( - const HloFusionAnalysis& analysis) { - auto* hero_reduction = analysis.FindHeroReduction(); - CHECK_NE(hero_reduction, nullptr); - ReductionDimensions reduction_dimensions = - GetReductionKindAndContiguousComponents(*hero_reduction); - if (reduction_dimensions.is_row_reduction) { - auto multi_row_emitter = MlirMultiRowReductionFusion::TryCreate(analysis); - if (multi_row_emitter != nullptr) { - return multi_row_emitter; - } - return std::make_unique(analysis); - } - - if (WarpSize() % reduction_dimensions.dimensions[kColMinorKept] == 0) { - return std::make_unique(analysis); - } - return std::make_unique(analysis); -} - } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/fusions/reduction_mlir.h b/xla/service/gpu/fusions/reduction_mlir.h index db0fbd2b45c31..838729254070a 100644 --- a/xla/service/gpu/fusions/reduction_mlir.h +++ b/xla/service/gpu/fusions/reduction_mlir.h @@ -16,7 +16,6 @@ limitations under the License. #define XLA_SERVICE_GPU_FUSIONS_REDUCTION_MLIR_H_ #include -#include #include #include #include @@ -169,23 +168,9 @@ class MlirRowReductionFusion : public MlirReductionFusion { class MlirMultiRowReductionFusion : public MlirReductionFusion { public: - MlirMultiRowReductionFusion(const HloFusionAnalysis& analysis, - int vector_size); - - // Attempts to create a multi-row reduction emitter for the given analysis. - // Returns nullptr if the fusion is not supported. - static std::unique_ptr TryCreate( - const HloFusionAnalysis& analysis); + explicit MlirMultiRowReductionFusion(const HloFusionAnalysis& analysis); protected: - // Returns the number of {kept, reduced} threads for the given reduction and - // vector size. - static absl::InlinedVector GetNumThreads( - const ReductionDimensions& reduction_dimensions, int vector_size); - static int64_t GetNumBlocks( - const ReductionDimensions& reduction_dimensions, - const absl::InlinedVector& num_threads); - int GetRowsPerWarp() const; llvm::SmallVector EmitReduction( int group_id, EmitterState& state) const override; diff --git a/xla/service/gpu/fusions/tests/reduce_multirow/f16_v4.hlo b/xla/service/gpu/fusions/tests/reduce_multirow/f16_v4.hlo deleted file mode 100644 index d2e3928bdfd56..0000000000000 --- a/xla/service/gpu/fusions/tests/reduce_multirow/f16_v4.hlo +++ /dev/null @@ -1,22 +0,0 @@ -// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize \ -// RUN: -xla-gpu-test-transform-loops | FileCheck %s - -// The reference implementation reduces in f64, so we need a larger tolerance. -// RUN: test_correctness %s --bijection_inputs=reduce:0 \ -// RUN: --bijection_outputs=reduce --abs_error_bound=0.005 --rel_error_bound=0.005 - -add { - lhs = f16[] parameter(0) - rhs = f16[] parameter(1) - ROOT add = f16[] add(lhs, rhs) -} - -fusion { - param_0 = f16[2048,64] parameter(0) - c = f16[] constant(0) - ROOT reduce = f16[2048] reduce(param_0, c), dimensions={1}, to_apply=add -} - -// If unvectorized, this would be a regular row reduction. However, since we can -// vectorize to size four, we can emit this as a multi-row reduction. -// CHECK: vector.transfer_read {{.*}} vector<4xf16>