diff --git a/xla/service/gpu/fusions/reduction_mlir.cc b/xla/service/gpu/fusions/reduction_mlir.cc
index b2db8fa5cd173..d195d40fe7e0b 100644
--- a/xla/service/gpu/fusions/reduction_mlir.cc
+++ b/xla/service/gpu/fusions/reduction_mlir.cc
@@ -770,32 +770,11 @@ llvm::SmallVector<mlir::Value> MlirSmallColumnReductionFusion::EmitReduction(
                            shared_rows_ / 2);
 }
 
-std::unique_ptr<MlirReductionFusion> CreateMlirReductionFusion(
-    const HloFusionAnalysis& analysis) {
-  auto* hero_reduction = analysis.FindHeroReduction();
-  CHECK_NE(hero_reduction, nullptr);
-  ReductionDimensions reduction_dimensions =
-      GetReductionKindAndContiguousComponents(*hero_reduction);
-  if (reduction_dimensions.is_row_reduction) {
-    if (RowReductionGetRowsPerWarp(
-            reduction_dimensions.dimensions[kRowMinorReduced]) > 1) {
-      return std::make_unique<MlirMultiRowReductionFusion>(analysis);
-    }
-    return std::make_unique<MlirRowReductionFusion>(analysis);
-  }
-
-  if (WarpSize() % reduction_dimensions.dimensions[kColMinorKept] == 0) {
-    return std::make_unique<MlirSmallColumnReductionFusion>(analysis);
-  }
-  return std::make_unique<MlirColumnReductionFusion>(analysis);
-}
-
 MlirRowReductionFusion::MlirRowReductionFusion(
     const HloFusionAnalysis& analysis)
     : MlirReductionFusion(analysis) {
   CHECK(reduction_dimensions_.is_row_reduction);
   Vector3 shape = reduction_dimensions_.dimensions;
-  CHECK_EQ(RowReductionGetRowsPerWarp(shape[kRowMinorReduced]), 1);
   constexpr int64_t kMinorReducedElementsPerThread = 16;
 
   int64_t num_threads_kept = 1;
@@ -931,33 +910,28 @@ llvm::SmallVector<mlir::Value> MlirRowReductionFusion::EmitReduction(
 }
 
 MlirMultiRowReductionFusion::MlirMultiRowReductionFusion(
-    const HloFusionAnalysis& analysis)
+    const HloFusionAnalysis& analysis, int vector_size)
     : MlirReductionFusion(analysis) {
   CHECK(reduction_dimensions_.is_row_reduction);
   Vector3 shape = reduction_dimensions_.dimensions;
-  int64_t rows_per_warp = RowReductionGetRowsPerWarp(shape[kRowMinorReduced]);
   input_shape_ = {shape[0], shape[1], shape[2]};
-  CHECK_GT(rows_per_warp, 1);
-
-  auto compute_block_size = [&](int vector_size) {
-    int64_t num_threads_reduced = shape[kRowMinorReduced] / vector_size;
-
-    constexpr int64_t kThreadsPerBlockTarget = 256;
-    int64_t kept_size = reduction_dimensions_.dimensions[kRowKept];
-    int64_t num_threads_kept = 1;
-    if (kept_size * num_threads_reduced <= kThreadsPerBlockTarget) {
-      num_threads_kept = kept_size;
-    } else {
-      num_threads_kept = kThreadsPerBlockTarget / num_threads_reduced;
-    }
-    num_threads_ = {num_threads_kept, num_threads_reduced};
-    tile_sizes_per_thread_ = {shape[0], vector_size};
-    num_blocks_ = {CeilOfRatio(input_shape_[kRowKept], num_threads_kept)};
-  };
+  num_threads_ = GetNumThreads(reduction_dimensions_, vector_size);
+  num_blocks_ = {GetNumBlocks(reduction_dimensions_, num_threads_)};
+  tile_sizes_per_thread_ = {shape[0], vector_size};
+}
 
-  // Compute the launch grid without vectorization. We use the results to
-  // compute the vectorized launch grid.
-  compute_block_size(1);
+std::unique_ptr<MlirReductionFusion> MlirMultiRowReductionFusion::TryCreate(
+    const HloFusionAnalysis& analysis) {
+  auto* hero_reduction = analysis.FindHeroReduction();
+  CHECK_NE(hero_reduction, nullptr);
+  auto reduction_dimensions =
+      GetReductionKindAndContiguousComponents(*hero_reduction);
+  auto shape = reduction_dimensions.dimensions;
+  // This emitter only supports reductions where the reduced dimension is a
+  // power of 2.
+  if (shape[kRowMinorReduced] & (shape[kRowMinorReduced] - 1)) {
+    return nullptr;
+  }
 
   // Normally, we only consider input types for vectorization. However, in
   // multi-row reductions, the input:output ratio is much higher, so we consider
@@ -965,24 +939,75 @@ MlirMultiRowReductionFusion::MlirMultiRowReductionFusion(
   int smallest_input_or_output_bits =
       std::min(analysis.input_output_info().smallest_input_dtype_bits,
                analysis.input_output_info().smallest_output_dtype_bits);
+  int largest_input_or_output_bits =
+      std::max(analysis.input_output_info().smallest_input_dtype_bits,
+               analysis.input_output_info().smallest_output_dtype_bits);
 
-  // This vector size is always valid: we know that the reduced dimension is a
-  // power of 2, since otherwise RowReductionGetRowsPerWarp would have
-  // returned 1.
   // Our codegen can't currently deal with vectorization across rows, so we
   // limit the vector size to the size of the row. Note that this emitter
   // essentially reverts to the loop emitter in this case, except for side
   // outputs.
-  int vector_size = std::min(static_cast<int>(input_shape_[kRowMinorReduced]),
-                             32 / smallest_input_or_output_bits);
-
-  // We target 8 warps per block, which means there could be up to 8 blocks per
-  // SM, but we have no good way of knowing. In practice, enabling vectorization
-  // for decently sized reductions at least does not hurt.
-  if (num_blocks_.front() > analysis.device_info().core_count() &&
-      vector_size > 1) {
-    compute_block_size(vector_size);
+  int vector_size = std::min(static_cast<int>(shape[kRowMinorReduced]),
+                             64 / smallest_input_or_output_bits);
+
+  // Very large vector sizes for f32 can be detrimental, so we limit the vector
+  // size to 16 bytes if we have some >= 32 bit inputs or outputs. This is still
+  // a bit on the high side, but remember that we also have very small inputs
+  // or outputs.
+  if (largest_input_or_output_bits >= 32) {
+    vector_size = std::min(128 / largest_input_or_output_bits, vector_size);
+  }
+
+  // The reduced dimension must fit into a single warp.
+  if (shape[kRowMinorReduced] > WarpSize() * vector_size) {
+    return nullptr;
+  }
+
+  // At the very least, we want to have work for every SM.
+  // TODO(jreiffers): This limit is probably too low: if we have as many blocks
+  // as SMs, we'll only run about 8 warps per SM, so occupancy will be very low.
+  // Further measurements are needed to refine this heuristic.
+  int64_t min_desired_blocks = analysis.device_info().core_count();
+  while (vector_size > 1 &&
+         GetNumBlocks(reduction_dimensions,
+                      GetNumThreads(reduction_dimensions, vector_size)) <
+             min_desired_blocks) {
+    vector_size /= 2;
   }
+
+  // Check again that the reduced dimension fits after potentially reducing the
+  // vector size.
+  if (shape[kRowMinorReduced] > WarpSize() * vector_size) {
+    return nullptr;
+  }
+
+  return std::make_unique<MlirMultiRowReductionFusion>(analysis, vector_size);
+}
+
+absl::InlinedVector<int64_t, 4> MlirMultiRowReductionFusion::GetNumThreads(
+    const ReductionDimensions& reduction_dimensions, int vector_size) {
+  int64_t num_threads_reduced =
+      reduction_dimensions.dimensions[kRowMinorReduced] / vector_size;
+
+  constexpr int64_t kThreadsPerBlockTarget = 256;
+  int64_t kept_size = reduction_dimensions.dimensions[kRowKept];
+  int64_t num_threads_kept = 1;
+  if (kept_size * num_threads_reduced <= kThreadsPerBlockTarget) {
+    num_threads_kept = kept_size;
+  } else {
+    num_threads_kept = kThreadsPerBlockTarget / num_threads_reduced;
+  }
+  return {num_threads_kept, num_threads_reduced};
+}
+
+int64_t MlirMultiRowReductionFusion::GetNumBlocks(
+    const ReductionDimensions& reduction_dimensions,
+    const absl::InlinedVector<int64_t, 4>& num_threads) {
+  CHECK_EQ(num_threads.size(), 2)
+      << "Expected num_threads to contain the number of threads in the {kept, "
+         "reduced} dimensions.";
+  return CeilOfRatio(reduction_dimensions.dimensions[kRowKept],
+                     num_threads.front());
 }
 
 IndexingMap MlirMultiRowReductionFusion::ComputeReductionInputIndexing(
@@ -1013,8 +1038,7 @@ IndexingMap MlirMultiRowReductionFusion::ComputeReductionOutputIndexing(
                        : mlir::getAffineDimExpr(3, ctx);
   IndexingMap projected_index =
       GetIndexingMap(block_id * num_threads_[0] + thread_id[0]);
-  projected_index.AddConstraint(thread_id[1] % (WarpSize() / GetRowsPerWarp()),
-                                {0, 0});
+  projected_index.AddConstraint(thread_id[1] % num_threads_[1], {0, 0});
   // We don't need a constraint on the loop dimensions, because they are removed
   // by GetIndexingMap (since they don't show up in the output index
   // computation).
@@ -1034,10 +1058,30 @@ llvm::SmallVector<mlir::Value> MlirMultiRowReductionFusion::EmitReduction(
   auto per_thread =
       state.EmitPerThreadElements(group_id, inits, state.FusionOutputs());
   auto reduced = state.ShuffleReduce(reductions, per_thread.reduction_scalars,
-                                     WarpSize() / 2 / GetRowsPerWarp());
+                                     num_threads_[1] / 2);
   return EvaluateEpilogue(reduced, std::move(per_thread.outputs), state,
                           group_id, /*symbol_values=*/{});
 }
 
+std::unique_ptr<MlirReductionFusion> CreateMlirReductionFusion(
+    const HloFusionAnalysis& analysis) {
+  auto* hero_reduction = analysis.FindHeroReduction();
+  CHECK_NE(hero_reduction, nullptr);
+  ReductionDimensions reduction_dimensions =
+      GetReductionKindAndContiguousComponents(*hero_reduction);
+  if (reduction_dimensions.is_row_reduction) {
+    auto multi_row_emitter = MlirMultiRowReductionFusion::TryCreate(analysis);
+    if (multi_row_emitter != nullptr) {
+      return multi_row_emitter;
+    }
+    return std::make_unique<MlirRowReductionFusion>(analysis);
+  }
+
+  if (WarpSize() % reduction_dimensions.dimensions[kColMinorKept] == 0) {
+    return std::make_unique<MlirSmallColumnReductionFusion>(analysis);
+  }
+  return std::make_unique<MlirColumnReductionFusion>(analysis);
+}
+
 }  // namespace gpu
 }  // namespace xla
diff --git a/xla/service/gpu/fusions/reduction_mlir.h b/xla/service/gpu/fusions/reduction_mlir.h
index 838729254070a..db0fbd2b45c31 100644
--- a/xla/service/gpu/fusions/reduction_mlir.h
+++ b/xla/service/gpu/fusions/reduction_mlir.h
@@ -16,6 +16,7 @@ limitations under the License.
 #define XLA_SERVICE_GPU_FUSIONS_REDUCTION_MLIR_H_
 
 #include
+#include
 #include
 #include
 #include
@@ -168,9 +169,23 @@ class MlirRowReductionFusion : public MlirReductionFusion {
 
 class MlirMultiRowReductionFusion : public MlirReductionFusion {
  public:
-  explicit MlirMultiRowReductionFusion(const HloFusionAnalysis& analysis);
+  MlirMultiRowReductionFusion(const HloFusionAnalysis& analysis,
+                              int vector_size);
+
+  // Attempts to create a multi-row reduction emitter for the given analysis.
+  // Returns nullptr if the fusion is not supported.
+  static std::unique_ptr<MlirReductionFusion> TryCreate(
+      const HloFusionAnalysis& analysis);
 
  protected:
+  // Returns the number of {kept, reduced} threads for the given reduction and
+  // vector size.
+  static absl::InlinedVector<int64_t, 4> GetNumThreads(
+      const ReductionDimensions& reduction_dimensions, int vector_size);
+  static int64_t GetNumBlocks(
+      const ReductionDimensions& reduction_dimensions,
+      const absl::InlinedVector<int64_t, 4>& num_threads);
+
   int GetRowsPerWarp() const;
   llvm::SmallVector<mlir::Value> EmitReduction(
       int group_id, EmitterState& state) const override;
diff --git a/xla/service/gpu/fusions/tests/reduce_multirow/f16_v4.hlo b/xla/service/gpu/fusions/tests/reduce_multirow/f16_v4.hlo
new file mode 100644
index 0000000000000..d2e3928bdfd56
--- /dev/null
+++ b/xla/service/gpu/fusions/tests/reduce_multirow/f16_v4.hlo
@@ -0,0 +1,22 @@
+// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize \
+// RUN:   -xla-gpu-test-transform-loops | FileCheck %s
+
+// The reference implementation reduces in f64, so we need a larger tolerance.
+// RUN: test_correctness %s --bijection_inputs=reduce:0 \
+// RUN:   --bijection_outputs=reduce --abs_error_bound=0.005 --rel_error_bound=0.005
+
+add {
+  lhs = f16[] parameter(0)
+  rhs = f16[] parameter(1)
+  ROOT add = f16[] add(lhs, rhs)
+}
+
+fusion {
+  param_0 = f16[2048,64] parameter(0)
+  c = f16[] constant(0)
+  ROOT reduce = f16[2048] reduce(param_0, c), dimensions={1}, to_apply=add
+}
+
+// If unvectorized, this would be a regular row reduction. However, since we can
+// vectorize to size four, we can emit this as a multi-row reduction.
+// CHECK: vector.transfer_read {{.*}} vector<4xf16>