Skip to content

Commit

Permalink
[XLA:GPU] Remove the forced cast to f32 when generating Triton reductions.
Browse files Browse the repository at this point in the history

Triton can now handle reductions of types other than `f32`. Removing the cast makes a lot of the code simpler and also yields more "correct" numerics - in some cases this means less precise. I had to relax the error tolerance in a couple of `f16` tests because the calculations are now actually done in `f16` unlike the previous `f32`.

Simplifications enabled by this:
- No more casts in the generated code.
- Removed the need for the `is_within_reduction_computation` parameter in `triton_support.cc`.
- Removed a lot of cases that needed `skip_failure_branch_to_avoid_crash`.

PiperOrigin-RevId: 674022630
  • Loading branch information
dimitar-asenov authored and Google-ML-Automation committed Sep 12, 2024
1 parent bce64e5 commit 84ab33a
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 121 deletions.
1 change: 1 addition & 0 deletions xla/service/gpu/fusions/triton/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,7 @@ xla_test(
tags = ["no_mac"],
deps = [
":triton_support",
":triton_test_utils",
"//xla:comparison_util",
"//xla:error_spec",
"//xla:xla_data_proto_cc",
Expand Down
20 changes: 7 additions & 13 deletions xla/service/gpu/fusions/triton/triton_fusion_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -787,16 +787,13 @@ absl::StatusOr<Value> EmitReduce(
input = b.create<ma::SelectOp>(mask, input, neutral);
}

// Triton actually only performs reductions on float32 inputs, and we must
// thus upcast/downcast our input if its data type is different.
input = Cast(b, input, b.getF32Type());

mt::ReduceOp reduction = b.create<mt::ReduceOp>(input, reduction_dimension);
{
TF_ASSIGN_OR_RETURN(Type result_ty,
TritonType(b, hlo_reduce.shape().element_type()));
mlir::Location loc = b.getLoc();
mlir::Block* reducer =
b.createBlock(&reduction->getRegion(0), {},
{b.getF32Type(), b.getF32Type()}, {loc, loc});
mlir::Block* reducer = b.createBlock(&reduction->getRegion(0), {},
{result_ty, result_ty}, {loc, loc});

HloComputation* reduction_computation = hlo_reduce.to_apply();

Expand Down Expand Up @@ -830,16 +827,13 @@ absl::StatusOr<Value> EmitReduce(

Value result = reduction.getResult().front();

// We want to return a tensor of float32, but the ReturnReduceOp produces an
// f32 constant when reducing a single dim. To convert to a tensor we splat
// the result.
// We want to return a tensor, but the ReturnReduceOp produces a raw scalar
// when reducing a single dim. To convert to a tensor we splat the result.
if (!mlir::dyn_cast<TensorValue>(reduction.getResult().front())) {
result = Splat(b, result, {});
}

TF_ASSIGN_OR_RETURN(Type result_ty,
TritonType(b, hlo_reduce.shape().element_type()));
return Cast(b, result, result_ty);
return result;
}

// Emit code corresponding to a fusion instruction somehow nested within the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/primitive_util.h"
#include "xla/service/gpu/fusions/triton/triton_support_legacy.h"
#include "xla/service/gpu/fusions/triton/triton_test_utils.h"
#include "xla/service/gpu/tests/gpu_codegen_test.h"
#include "xla/stream_executor/device_description.h"
#include "xla/xla.pb.h"
Expand Down Expand Up @@ -731,11 +732,9 @@ ENTRY e {
/*run_hlo_passes=*/false));
}

INSTANTIATE_TEST_SUITE_P(
ConstantTestSuite, ConstantTest, ::testing::ValuesIn(kSupportedDataTypes),
[](const ::testing::TestParamInfo<PrimitiveType> type) {
return primitive_util::LowercasePrimitiveTypeName(type.param);
});
INSTANTIATE_TEST_SUITE_P(ConstantTestSuite, ConstantTest,
::testing::ValuesIn(kSupportedDataTypes),
TritonSupportTestTypeToString);

class ConvertTest : public TritonTest,
public ::testing::WithParamInterface<
Expand Down Expand Up @@ -1161,8 +1160,6 @@ ENTRY main {
tolerance = 1e-6;
break;
case F16:
tolerance = 2e-4;
break;
case BF16:
tolerance = 2e-2;
break;
Expand Down Expand Up @@ -1685,8 +1682,6 @@ ENTRY main {
tolerance = 1e-6;
break;
case F16:
tolerance = 2e-4;
break;
case BF16:
tolerance = 2e-2;
break;
Expand Down Expand Up @@ -2228,6 +2223,47 @@ ENTRY main {
ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance)));
}

// Test fixture parameterized by PrimitiveType: each test runs once per element
// type to check that Triton-emitted reductions handle that type.
class ReductionTypeTest : public TritonTest,
                          public ::testing::WithParamInterface<PrimitiveType> {
};

// Verifies that a Triton-codegen'd max-reduction works when the element /
// accumulator type is the parameterized PrimitiveType itself (the commit
// removes the previous forced upcast to f32), comparing against the reference
// interpreter with zero tolerance.
TEST_P(ReductionTypeTest, DifferentReductionTypes) {
  PrimitiveType data_type = GetParam();

  // HLO template; every "$0" is substituted below with the lowercase name of
  // data_type (so the whole computation, including the constant and the
  // reduction computation, runs in that type).
  const std::string kHloTestTemplate = R"(
max {
p0 = $0[] parameter(0)
p1 = $0[] parameter(1)
ROOT max = $0[] maximum(p0, p1)
}
triton_computation {
p = $0[400,16] parameter(0)
zero = $0[] constant(0)
ROOT reduce = $0[400] reduce(p, zero), dimensions={1}, to_apply=max
}
ENTRY entry_computation {
p = $0[400,16] parameter(0)
ROOT fusion = $0[400] fusion(p), kind=kCustom, calls=triton_computation,
backend_config={ "operation_queue_id":"0", "wait_on_operation_queues":[],
"fusion_backend_config":{ "kind":"__triton", "block_level_fusion_config":{
"output_tile_sizes":["400"], "num_warps":"1"}},
"force_earliest_schedule":false}
})";
  const std::string hlo_test = absl::Substitute(
      kHloTestTemplate, primitive_util::LowercasePrimitiveTypeName(data_type));
  // ErrorSpec{0, 0}: a `maximum` reduction is exact in every tested type, so
  // the Triton result must match the reference bit-for-bit.
  EXPECT_TRUE(
      RunAndCompareNoHloPasses(hlo_test, ErrorSpec{/*aabs=*/0, /*arel=*/0}));
}

// Element types for which Triton reductions are expected to codegen. The F8
// types are deliberately absent (CanTritonHandleReduce rejects F8E4M3FN and
// F8E5M2 for reductions).
constexpr std::array<PrimitiveType, 9> kReductionSupportedDataTypes{
    PRED, S8, S16, S32, S64, F16, F32, F64, BF16};

INSTANTIATE_TEST_SUITE_P(ReductionTypeTestSuite, ReductionTypeTest,
                         ::testing::ValuesIn(kReductionSupportedDataTypes),
                         TritonSupportTestTypeToString);

} // namespace
} // namespace gpu
} // namespace xla
102 changes: 30 additions & 72 deletions xla/service/gpu/fusions/triton/triton_support.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,57 +146,32 @@ CodegenDecision IsTritonSupportedConversion(
}

// Set of binary element-wise ops that are genuinely supported by Triton.
//
// Note that there is a difference between ops inside a reduction computation
// and "regular" ops. The reason is that ops inside a reduction computation
// operate on "unwrapped" values (e.g. scalars represented as f32 instead of
// tensor<f32>) and that codegen supports a different set of operations.
//
// In principle `is_within_reduction_computation` can be added also to the
// functions that check support for unary and ternary ops, but there was no need
// to do this so far.
absl::flat_hash_set<HloOpcode> TritonSupportedBinaryElementwiseOps(
PrimitiveType element_type, const se::GpuComputeCapability& gpu_version,
bool is_within_reduction_computation) {
absl::flat_hash_set<HloOpcode> ret;

if (!is_within_reduction_computation &&
(element_type == PrimitiveType::F8E5M2 ||
element_type == PrimitiveType::F8E4M3FN)) {
return ret;
PrimitiveType element_type, const se::GpuComputeCapability& gpu_version) {
if (element_type == PrimitiveType::U16 ||
element_type == PrimitiveType::F8E5M2 ||
element_type == PrimitiveType::F8E4M3FN) {
return {};
}

absl::flat_hash_set<HloOpcode> ret{HloOpcode::kAdd, HloOpcode::kCompare,
HloOpcode::kMaximum, HloOpcode::kMinimum,
HloOpcode::kMultiply};

if (element_type == PrimitiveType::PRED) {
ret.insert(HloOpcode::kCompare);
ret.insert(HloOpcode::kAdd);
ret.insert(HloOpcode::kMultiply);
ret.insert(HloOpcode::kMaximum);
ret.insert(HloOpcode::kMinimum);

if (!is_within_reduction_computation) {
ret.insert(HloOpcode::kAnd);
ret.insert(HloOpcode::kOr);
ret.insert(HloOpcode::kXor);
}
ret.insert(HloOpcode::kAnd);
ret.insert(HloOpcode::kOr);
ret.insert(HloOpcode::kXor);
return ret;
}

if (element_type != PrimitiveType::U16 || is_within_reduction_computation) {
ret.insert(HloOpcode::kAdd);
ret.insert(HloOpcode::kCompare);
ret.insert(HloOpcode::kSubtract);
ret.insert(HloOpcode::kMaximum);
ret.insert(HloOpcode::kMinimum);
ret.insert(HloOpcode::kMultiply);

if (primitive_util::IsIntegralType(element_type)) {
ret.insert(HloOpcode::kDivide);
if (!is_within_reduction_computation) {
ret.insert(HloOpcode::kAnd);
ret.insert(HloOpcode::kOr);
ret.insert(HloOpcode::kXor);
}
}
ret.insert(HloOpcode::kSubtract);

if (primitive_util::IsIntegralType(element_type)) {
ret.insert(HloOpcode::kDivide);
ret.insert(HloOpcode::kAnd);
ret.insert(HloOpcode::kOr);
ret.insert(HloOpcode::kXor);
}

if (element_type == PrimitiveType::F32 ||
Expand All @@ -207,11 +182,6 @@ absl::flat_hash_set<HloOpcode> TritonSupportedBinaryElementwiseOps(
ret.insert(HloOpcode::kPower);
}

if (is_within_reduction_computation &&
primitive_util::IsFloatingPointType(element_type)) {
ret.insert(HloOpcode::kDivide);
}

return ret;
}

Expand All @@ -237,40 +207,29 @@ absl::flat_hash_set<HloOpcode> TritonSupportedTernaryElementwiseOps(
// responsible for ensuring that the relevant data type is supported on the
// device of interest.
bool IsTritonSupportedElementwise(HloOpcode opcode, PrimitiveType element_type,
const se::GpuComputeCapability& gpu_version,
bool is_within_reduction_computation) {
const se::GpuComputeCapability& gpu_version) {
return TritonSupportedUnaryElementwiseOps(element_type).contains(opcode) ||
TritonSupportedBinaryElementwiseOps(element_type, gpu_version,
is_within_reduction_computation)
TritonSupportedBinaryElementwiseOps(element_type, gpu_version)
.contains(opcode) ||
TritonSupportedTernaryElementwiseOps(element_type, gpu_version)
.contains(opcode);
}

CodegenDecision IsTritonSupportedInstructionImpl(
const HloInstruction& instr, const se::GpuComputeCapability& gpu_version,
bool is_within_reduction_computation);
const HloInstruction& instr, const se::GpuComputeCapability& gpu_version);

// Filters Reduces which can be handled using Triton.
CodegenDecision CanTritonHandleReduce(
const HloReduceInstruction& reduce,
const se::GpuComputeCapability& gpu_version) {
// The reduction has already passed the generic input/output type checks.
// Now we just need to check specific restrictions for reduce.
if (reduce.shape().element_type() == PrimitiveType::F8E4M3FN) {
if (auto cc = std::get_if<se::CudaComputeCapability>(&gpu_version)) {
if (!cc->IsAtLeastHopper()) {
return "F8E4M3FN is not supported before Hopper.";
}
}
if (reduce.shape().element_type() == PrimitiveType::F8E4M3FN ||
reduce.shape().element_type() == PrimitiveType::F8E5M2) {
return "F8E4M3FN and F8E5M2 are not supported for reductions.";
}

bool is_triton_supported_reduction_computation = absl::c_all_of(
reduce.to_apply()->instructions(), [&](const HloInstruction* instr) {
return IsTritonSupportedInstructionImpl(
*instr, gpu_version,
/*is_within_reduction_computation=*/true)
.CanFuse();
return IsTritonSupportedInstructionImpl(*instr, gpu_version).CanFuse();
});
if (!is_triton_supported_reduction_computation) {
return "Unsupported reduction computation by Triton.";
Expand All @@ -283,8 +242,7 @@ CodegenDecision CanTritonHandleReduce(
}

CodegenDecision IsTritonSupportedInstructionImpl(
const HloInstruction& instr, const se::GpuComputeCapability& gpu_version,
bool is_within_reduction_computation) {
const HloInstruction& instr, const se::GpuComputeCapability& gpu_version) {
if (internal::IsTritonUnsupportedOpcode(instr.opcode())) {
return "Unsupported opcode.";
}
Expand Down Expand Up @@ -330,7 +288,7 @@ CodegenDecision IsTritonSupportedInstructionImpl(
// and `select` which have a fixed PRED type in the output and first
// operand.
instr.operand(instr.operand_count() - 1)->shape().element_type(),
gpu_version, is_within_reduction_computation)) {
gpu_version)) {
return "Unsupported elementwise operation.";
}
return CodegenDecision{};
Expand Down Expand Up @@ -445,8 +403,8 @@ absl::Status EnsureTritonSupportsComputeCapability(

CodegenDecision IsTritonSupportedInstruction(
const HloInstruction& instr, const se::GpuComputeCapability& gpu_version) {
CodegenDecision decision = IsTritonSupportedInstructionImpl(
instr, gpu_version, /*is_within_reduction_computation=*/false);
CodegenDecision decision =
IsTritonSupportedInstructionImpl(instr, gpu_version);
VLOG(2) << "IsTritonSupportedInstruction: " << instr.ToString() << " "
<< bool(decision);
return decision;
Expand Down
32 changes: 5 additions & 27 deletions xla/service/gpu/fusions/triton/triton_support_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -522,14 +522,7 @@ ENTRY triton_computation {
TF_ASSERT_OK_AND_ASSIGN(
TestedInstruction ti,
ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode));

bool skip_failure_branch_to_avoid_crash =
data_type == PrimitiveType::F8E4M3FN &&
std::holds_alternative<se::CudaComputeCapability>(cc) &&
!std::get<se::CudaComputeCapability>(cc).IsAtLeastHopper();

RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1}, cc,
skip_failure_branch_to_avoid_crash);
RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1}, cc);
}

TEST_F(ReduceTest, IsTritonSupportedReductionWithMultidimensionalTile) {
Expand Down Expand Up @@ -601,13 +594,7 @@ ENTRY triton_computation {
TestedInstruction ti,
ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode));

bool skip_failure_branch_to_avoid_crash =
data_type == PrimitiveType::F8E4M3FN &&
std::holds_alternative<se::CudaComputeCapability>(cc) &&
!std::get<se::CudaComputeCapability>(cc).IsAtLeastHopper();

RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1}, cc,
skip_failure_branch_to_avoid_crash);
RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1}, cc);
}

TEST_P(ReduceTest,
Expand Down Expand Up @@ -728,18 +715,9 @@ ENTRY triton_computation {
// TODO(b/361526623): Reduce the cases where setting
// skip_failure_branch_to_avoid_crash is needed.
bool skip_failure_branch_to_avoid_crash =
data_type == PrimitiveType::F8E4M3FN &&
std::holds_alternative<se::CudaComputeCapability>(cc) &&
!std::get<se::CudaComputeCapability>(cc).IsAtLeastHopper();

skip_failure_branch_to_avoid_crash |=
(data_type == S8 || data_type == S16 || data_type == S32 ||
data_type == S64 || data_type == PrimitiveType::F16 ||
data_type == PrimitiveType::BF16 ||
data_type == PrimitiveType::F8E4M3FN ||
data_type == PrimitiveType::F8E5M2) &&
(opcode == HloOpcode::kRemainder || opcode == HloOpcode::kPower ||
opcode == HloOpcode::kAtan2);
opcode == HloOpcode::kDivide &&
(data_type == BF16 || data_type == F16 || data_type == F8E4M3FN ||
data_type == F8E5M2);

RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1}, cc,
skip_failure_branch_to_avoid_crash);
Expand Down
5 changes: 5 additions & 0 deletions xla/service/gpu/fusions/triton/triton_test_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,11 @@ std::string TritonSupportTestTwoTypesAndDeviceToString(
"_", ComputeCapabilityToString(cc));
}

// Returns the lowercase HLO name of the PrimitiveType test parameter, for use
// as a gtest parameterized-test name generator.
std::string TritonSupportTestTypeToString(
    const ::testing::TestParamInfo<PrimitiveType>& data) {
  const PrimitiveType type = data.param;
  return primitive_util::LowercasePrimitiveTypeName(type);
}

namespace {

// This function does nothing if the input module already has an entry
Expand Down
3 changes: 3 additions & 0 deletions xla/service/gpu/fusions/triton/triton_test_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,9 @@ std::string TritonSupportTestTypeAndOpcodeAndDeviceToString(
std::string TritonSupportTestTwoTypesAndDeviceToString(
const ::testing::TestParamInfo<std::tuple<PrimitiveType, PrimitiveType,
se::GpuComputeCapability>>& data);

// Converts a PrimitiveType test parameter to its lowercase type name; suitable
// as the name-generator argument of INSTANTIATE_TEST_SUITE_P.
std::string TritonSupportTestTypeToString(
    const ::testing::TestParamInfo<PrimitiveType>& data);
} // namespace xla::gpu

#endif // XLA_SERVICE_GPU_FUSIONS_TRITON_TRITON_TEST_UTILS_H_

0 comments on commit 84ab33a

Please sign in to comment.