diff --git a/xla/service/gpu/fusions/triton/BUILD b/xla/service/gpu/fusions/triton/BUILD
index b68ef80040178..f2eebf5c7dcf6 100644
--- a/xla/service/gpu/fusions/triton/BUILD
+++ b/xla/service/gpu/fusions/triton/BUILD
@@ -391,6 +391,7 @@ xla_test(
     tags = ["no_mac"],
     deps = [
         ":triton_support",
+        ":triton_test_utils",
        "//xla:comparison_util",
        "//xla:error_spec",
        "//xla:xla_data_proto_cc",
diff --git a/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc b/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc
index e286aaa568714..688be3893072e 100644
--- a/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc
+++ b/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc
@@ -787,16 +787,13 @@ absl::StatusOr<Value> EmitReduce(
     input = b.create<ma::SelectOp>(mask, input, neutral);
   }
 
-  // Triton actually only performs reductions on float32 inputs, and we must
-  // thus upcast/downcast our input if its data type is different.
-  input = Cast(b, input, b.getF32Type());
-
   mt::ReduceOp reduction = b.create<mt::ReduceOp>(input, reduction_dimension);
   {
+    TF_ASSIGN_OR_RETURN(Type result_ty,
+                        TritonType(b, hlo_reduce.shape().element_type()));
     mlir::Location loc = b.getLoc();
-    mlir::Block* reducer =
-        b.createBlock(&reduction->getRegion(0), {},
-                      {b.getF32Type(), b.getF32Type()}, {loc, loc});
+    mlir::Block* reducer = b.createBlock(&reduction->getRegion(0), {},
+                                         {result_ty, result_ty}, {loc, loc});
 
     HloComputation* reduction_computation = hlo_reduce.to_apply();
 
@@ -830,16 +827,13 @@ absl::StatusOr<Value> EmitReduce(
 
   Value result = reduction.getResult().front();
 
-  // We want to return a tensor of float32, but the ReturnReduceOp produces an
-  // f32 constant when reducing a single dim. To convert to a tensor we splat
-  // the result.
+  // We want to return a tensor, but the ReturnReduceOp produces a raw scalar
+  // when reducing a single dim. To convert to a tensor we splat the result.
   if (!mlir::dyn_cast<TensorValue>(reduction.getResult().front())) {
     result = Splat(b, result, {});
   }
 
-  TF_ASSIGN_OR_RETURN(Type result_ty,
-                      TritonType(b, hlo_reduce.shape().element_type()));
-  return Cast(b, result, result_ty);
+  return result;
 }
 
 // Emit code corresponding to a fusion instruction somehow nested within the
diff --git a/xla/service/gpu/fusions/triton/triton_fusion_emitter_parametrized_test.cc b/xla/service/gpu/fusions/triton/triton_fusion_emitter_parametrized_test.cc
index 99195059d6959..0735d65c6d387 100644
--- a/xla/service/gpu/fusions/triton/triton_fusion_emitter_parametrized_test.cc
+++ b/xla/service/gpu/fusions/triton/triton_fusion_emitter_parametrized_test.cc
@@ -30,6 +30,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_opcode.h" #include "xla/primitive_util.h" #include "xla/service/gpu/fusions/triton/triton_support_legacy.h" +#include "xla/service/gpu/fusions/triton/triton_test_utils.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" #include "xla/stream_executor/device_description.h" #include "xla/xla.pb.h" @@ -731,11 +732,9 @@ ENTRY e { /*run_hlo_passes=*/false)); } -INSTANTIATE_TEST_SUITE_P( - ConstantTestSuite, ConstantTest, ::testing::ValuesIn(kSupportedDataTypes), - [](const ::testing::TestParamInfo type) { - return primitive_util::LowercasePrimitiveTypeName(type.param); - }); +INSTANTIATE_TEST_SUITE_P(ConstantTestSuite, ConstantTest, + ::testing::ValuesIn(kSupportedDataTypes), + TritonSupportTestTypeToString); class ConvertTest : public TritonTest, public ::testing::WithParamInterface< @@ -1161,8 +1160,6 @@ ENTRY main { tolerance = 1e-6; break; case F16: - tolerance = 2e-4; - break; case BF16: tolerance = 2e-2; break; @@ -1685,8 +1682,6 @@ ENTRY main { tolerance = 1e-6; break; case F16: - tolerance = 2e-4; - break; case BF16: tolerance = 2e-2; break; @@ -2228,6 +2223,47 @@ ENTRY main { ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance))); } +class ReductionTypeTest : public TritonTest, + public ::testing::WithParamInterface { +}; + +TEST_P(ReductionTypeTest, DifferentReductionTypes) { + PrimitiveType data_type = GetParam(); + + const std::string kHloTestTemplate = R"( +max { + p0 = $0[] parameter(0) + p1 = $0[] parameter(1) + ROOT max = $0[] maximum(p0, p1) +} + +triton_computation { + p = $0[400,16] parameter(0) + zero = $0[] constant(0) + ROOT reduce = $0[400] reduce(p, zero), dimensions={1}, to_apply=max +} + +ENTRY entry_computation { + p = $0[400,16] parameter(0) + ROOT fusion = $0[400] fusion(p), kind=kCustom, calls=triton_computation, + backend_config={ "operation_queue_id":"0", "wait_on_operation_queues":[], + "fusion_backend_config":{ "kind":"__triton", "block_level_fusion_config":{ + "output_tile_sizes":["400"], "num_warps":"1"}}, + "force_earliest_schedule":false} +})"; + const std::string hlo_test = absl::Substitute( + kHloTestTemplate, primitive_util::LowercasePrimitiveTypeName(data_type)); + EXPECT_TRUE( + RunAndCompareNoHloPasses(hlo_test, ErrorSpec{/*aabs=*/0, /*arel=*/0})); +} + +constexpr std::array kReductionSupportedDataTypes{ + PRED, S8, S16, S32, S64, F16, F32, F64, BF16}; + +INSTANTIATE_TEST_SUITE_P(ReductionTypeTestSuite, ReductionTypeTest, + ::testing::ValuesIn(kReductionSupportedDataTypes), + TritonSupportTestTypeToString); + } // namespace } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/fusions/triton/triton_support.cc b/xla/service/gpu/fusions/triton/triton_support.cc index f8ca2b080e3b0..69eeb7461bba6 100644 --- a/xla/service/gpu/fusions/triton/triton_support.cc +++ b/xla/service/gpu/fusions/triton/triton_support.cc @@ -146,57 +146,32 @@ CodegenDecision IsTritonSupportedConversion( } // Set of binary element-wise ops that are genuinely supported by Triton. -// -// Note that there is a difference between ops inside a reduction computation -// and "regular" ops. The reason is that ops inside a reduction computation -// operate on "unwrapped" values (e.g. scalars represented as f32 instead of -// tensor) and that codegen supports a different set of operations. -// -// In principle `is_within_reduction_computation` can be added also to the -// functions that check support for unary and ternary ops, but there was no need -// to do this so far. 
 absl::flat_hash_set<HloOpcode> TritonSupportedBinaryElementwiseOps(
-    PrimitiveType element_type, const se::GpuComputeCapability& gpu_version,
-    bool is_within_reduction_computation) {
-  absl::flat_hash_set<HloOpcode> ret;
-
-  if (!is_within_reduction_computation &&
-      (element_type == PrimitiveType::F8E5M2 ||
-       element_type == PrimitiveType::F8E4M3FN)) {
-    return ret;
+    PrimitiveType element_type, const se::GpuComputeCapability& gpu_version) {
+  if (element_type == PrimitiveType::U16 ||
+      element_type == PrimitiveType::F8E5M2 ||
+      element_type == PrimitiveType::F8E4M3FN) {
+    return {};
   }
 
+  absl::flat_hash_set<HloOpcode> ret{HloOpcode::kAdd, HloOpcode::kCompare,
+                                     HloOpcode::kMaximum, HloOpcode::kMinimum,
+                                     HloOpcode::kMultiply};
+
   if (element_type == PrimitiveType::PRED) {
-    ret.insert(HloOpcode::kCompare);
-    ret.insert(HloOpcode::kAdd);
-    ret.insert(HloOpcode::kMultiply);
-    ret.insert(HloOpcode::kMaximum);
-    ret.insert(HloOpcode::kMinimum);
-
-    if (!is_within_reduction_computation) {
-      ret.insert(HloOpcode::kAnd);
-      ret.insert(HloOpcode::kOr);
-      ret.insert(HloOpcode::kXor);
-    }
+    ret.insert(HloOpcode::kAnd);
+    ret.insert(HloOpcode::kOr);
+    ret.insert(HloOpcode::kXor);
     return ret;
   }
 
-  if (element_type != PrimitiveType::U16 || is_within_reduction_computation) {
-    ret.insert(HloOpcode::kAdd);
-    ret.insert(HloOpcode::kCompare);
-    ret.insert(HloOpcode::kSubtract);
-    ret.insert(HloOpcode::kMaximum);
-    ret.insert(HloOpcode::kMinimum);
-    ret.insert(HloOpcode::kMultiply);
-
-    if (primitive_util::IsIntegralType(element_type)) {
-      ret.insert(HloOpcode::kDivide);
-      if (!is_within_reduction_computation) {
-        ret.insert(HloOpcode::kAnd);
-        ret.insert(HloOpcode::kOr);
-        ret.insert(HloOpcode::kXor);
-      }
-    }
+  ret.insert(HloOpcode::kSubtract);
+
+  if (primitive_util::IsIntegralType(element_type)) {
+    ret.insert(HloOpcode::kDivide);
+    ret.insert(HloOpcode::kAnd);
+    ret.insert(HloOpcode::kOr);
+    ret.insert(HloOpcode::kXor);
   }
 
   if (element_type == PrimitiveType::F32 ||
@@ -207,11 +182,6 @@ absl::flat_hash_set<HloOpcode> TritonSupportedBinaryElementwiseOps(
     ret.insert(HloOpcode::kPower);
   }
 
-  if (is_within_reduction_computation &&
-      primitive_util::IsFloatingPointType(element_type)) {
-    ret.insert(HloOpcode::kDivide);
-  }
-
   return ret;
 }
 
@@ -237,40 +207,29 @@ absl::flat_hash_set<HloOpcode> TritonSupportedTernaryElementwiseOps(
 // responsible for ensuring that the relevant data type is supported on the
 // device of interest.
 bool IsTritonSupportedElementwise(HloOpcode opcode, PrimitiveType element_type,
-                                  const se::GpuComputeCapability& gpu_version,
-                                  bool is_within_reduction_computation) {
+                                  const se::GpuComputeCapability& gpu_version) {
   return TritonSupportedUnaryElementwiseOps(element_type).contains(opcode) ||
-         TritonSupportedBinaryElementwiseOps(element_type, gpu_version,
-                                             is_within_reduction_computation)
+         TritonSupportedBinaryElementwiseOps(element_type, gpu_version)
             .contains(opcode) ||
         TritonSupportedTernaryElementwiseOps(element_type, gpu_version)
            .contains(opcode);
 }
 
 CodegenDecision IsTritonSupportedInstructionImpl(
-    const HloInstruction& instr, const se::GpuComputeCapability& gpu_version,
-    bool is_within_reduction_computation);
+    const HloInstruction& instr, const se::GpuComputeCapability& gpu_version);
 
 // Filters Reduces which can be handled using Triton.
 CodegenDecision CanTritonHandleReduce(
     const HloReduceInstruction& reduce,
     const se::GpuComputeCapability& gpu_version) {
-  // The reduction has already passed the generic input/output type checks.
-  // Now we just need to check specific restrictions for reduce.
-  if (reduce.shape().element_type() == PrimitiveType::F8E4M3FN) {
-    if (auto cc = std::get_if<se::CudaComputeCapability>(&gpu_version)) {
-      if (!cc->IsAtLeastHopper()) {
-        return "F8E4M3FN is not supported before Hopper.";
-      }
-    }
+  if (reduce.shape().element_type() == PrimitiveType::F8E4M3FN ||
+      reduce.shape().element_type() == PrimitiveType::F8E5M2) {
+    return "F8E4M3FN and F8E5M2 are not supported for reductions.";
   }
 
   bool is_triton_supported_reduction_computation = absl::c_all_of(
       reduce.to_apply()->instructions(), [&](const HloInstruction* instr) {
-        return IsTritonSupportedInstructionImpl(
-                   *instr, gpu_version,
-                   /*is_within_reduction_computation=*/true)
-            .CanFuse();
+        return IsTritonSupportedInstructionImpl(*instr, gpu_version).CanFuse();
       });
   if (!is_triton_supported_reduction_computation) {
     return "Unsupported reduction computation by Triton.";
@@ -283,8 +242,7 @@ CodegenDecision CanTritonHandleReduce(
 }
 
 CodegenDecision IsTritonSupportedInstructionImpl(
-    const HloInstruction& instr, const se::GpuComputeCapability& gpu_version,
-    bool is_within_reduction_computation) {
+    const HloInstruction& instr, const se::GpuComputeCapability& gpu_version) {
   if (internal::IsTritonUnsupportedOpcode(instr.opcode())) {
     return "Unsupported opcode.";
   }
@@ -330,7 +288,7 @@ CodegenDecision IsTritonSupportedInstructionImpl(
             // and `select` which have a fixed PRED type in the output and first
            // operand.
            instr.operand(instr.operand_count() - 1)->shape().element_type(),
-            gpu_version, is_within_reduction_computation)) {
+            gpu_version)) {
       return "Unsupported elementwise operation.";
     }
     return CodegenDecision{};
@@ -445,8 +403,8 @@ absl::Status EnsureTritonSupportsComputeCapability(
 
 CodegenDecision IsTritonSupportedInstruction(
     const HloInstruction& instr, const se::GpuComputeCapability& gpu_version) {
-  CodegenDecision decision = IsTritonSupportedInstructionImpl(
-      instr, gpu_version, /*is_within_reduction_computation=*/false);
+  CodegenDecision decision =
+      IsTritonSupportedInstructionImpl(instr, gpu_version);
   VLOG(2) << "IsTritonSupportedInstruction: " << instr.ToString() << " "
           << bool(decision);
   return decision;
diff --git a/xla/service/gpu/fusions/triton/triton_support_test.cc b/xla/service/gpu/fusions/triton/triton_support_test.cc
index 067282630f670..6b21a8c7c82ac 100644
--- a/xla/service/gpu/fusions/triton/triton_support_test.cc
+++ b/xla/service/gpu/fusions/triton/triton_support_test.cc
@@ -522,14 +522,7 @@ ENTRY triton_computation {
   TF_ASSERT_OK_AND_ASSIGN(
       TestedInstruction ti,
       ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode));
-
-  bool skip_failure_branch_to_avoid_crash =
-      data_type == PrimitiveType::F8E4M3FN &&
-      std::holds_alternative<se::CudaComputeCapability>(cc) &&
-      !std::get<se::CudaComputeCapability>(cc).IsAtLeastHopper();
-
-  RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1}, cc,
-                 skip_failure_branch_to_avoid_crash);
+  RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1}, cc);
 }
 
 TEST_F(ReduceTest, IsTritonSupportedReductionWithMultidimensionalTile) {
@@ -601,13 +594,7 @@ ENTRY triton_computation {
       TestedInstruction ti,
       ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode));
 
-  bool skip_failure_branch_to_avoid_crash =
-      data_type == PrimitiveType::F8E4M3FN &&
-      std::holds_alternative<se::CudaComputeCapability>(cc) &&
-      !std::get<se::CudaComputeCapability>(cc).IsAtLeastHopper();
-
-  RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1}, cc,
-                 skip_failure_branch_to_avoid_crash);
+  RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1}, cc);
 }
 
 TEST_P(ReduceTest,
@@ -728,18 +715,9 @@ ENTRY triton_computation {
   // TODO(b/361526623): Reduce the cases where setting
   // skip_failure_branch_to_avoid_crash is needed.
   bool skip_failure_branch_to_avoid_crash =
-      data_type == PrimitiveType::F8E4M3FN &&
-      std::holds_alternative<se::CudaComputeCapability>(cc) &&
-      !std::get<se::CudaComputeCapability>(cc).IsAtLeastHopper();
-
-  skip_failure_branch_to_avoid_crash |=
-      (data_type == S8 || data_type == S16 || data_type == S32 ||
-       data_type == S64 || data_type == PrimitiveType::F16 ||
-       data_type == PrimitiveType::BF16 ||
-       data_type == PrimitiveType::F8E4M3FN ||
-       data_type == PrimitiveType::F8E5M2) &&
-      (opcode == HloOpcode::kRemainder || opcode == HloOpcode::kPower ||
-       opcode == HloOpcode::kAtan2);
+      opcode == HloOpcode::kDivide &&
+      (data_type == BF16 || data_type == F16 || data_type == F8E4M3FN ||
+       data_type == F8E5M2);
 
   RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1}, cc,
                  skip_failure_branch_to_avoid_crash);
diff --git a/xla/service/gpu/fusions/triton/triton_test_utils.cc b/xla/service/gpu/fusions/triton/triton_test_utils.cc
index 43b9d881fc08b..a42ac3d1b7ecc 100644
--- a/xla/service/gpu/fusions/triton/triton_test_utils.cc
+++ b/xla/service/gpu/fusions/triton/triton_test_utils.cc
@@ -189,6 +189,11 @@ std::string TritonSupportTestTwoTypesAndDeviceToString(
                       "_", ComputeCapabilityToString(cc));
 }
 
+std::string TritonSupportTestTypeToString(
+    const ::testing::TestParamInfo<PrimitiveType>& data) {
+  return primitive_util::LowercasePrimitiveTypeName(data.param);
+}
+
 namespace {
 
 // This function does nothing if the input module already has an entry
diff --git a/xla/service/gpu/fusions/triton/triton_test_utils.h b/xla/service/gpu/fusions/triton/triton_test_utils.h
index 39e8eb7af637e..0a7ec78bc0043 100644
--- a/xla/service/gpu/fusions/triton/triton_test_utils.h
+++ b/xla/service/gpu/fusions/triton/triton_test_utils.h
@@ -146,6 +146,9 @@ std::string TritonSupportTestTypeAndOpcodeAndDeviceToString(
 std::string TritonSupportTestTwoTypesAndDeviceToString(
     const ::testing::TestParamInfo<std::tuple<PrimitiveType, PrimitiveType,
                                               se::GpuComputeCapability>>& data);
+
+std::string TritonSupportTestTypeToString(
+    const ::testing::TestParamInfo<PrimitiveType>& data);
 
 }  // namespace xla::gpu
 
 #endif  // XLA_SERVICE_GPU_FUSIONS_TRITON_TRITON_TEST_UTILS_H_