Skip to content

Commit

Permalink
[XLA:GPU] Allow Priority Fusion to fuse small constants into Triton fusions.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 675579950
  • Loading branch information
olegshyshkov authored and Google-ML-Automation committed Sep 17, 2024
1 parent 395005d commit 7b5e183
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2211,8 +2211,7 @@ ENTRY main {
; CHECK: ENTRY
; CHECK-DAG: %[[P0:.*]] = f32[125,127]{1,0} parameter(0)
; CHECK-DAG: %[[P1:.*]] = f32[10,125,127]{2,1,0} parameter(1)
; CHECK-DAG: %[[C0:.*]] = f32[] constant(0)
; CHECK: ROOT %[[FUSION:.*]] = f32[125,127]{1,0} fusion(%[[P0]], %[[P1]], %[[C0]])
; CHECK: ROOT %[[FUSION:.*]] = f32[125,127]{1,0} fusion(%[[P0]], %[[P1]])
; CHECK-SAME: kind=kCustom
; CHECK-SAME: __triton
)";
Expand Down
3 changes: 2 additions & 1 deletion xla/service/gpu/transforms/priority_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1023,7 +1023,8 @@ absl::StatusOr<bool> PriorityFusion::Run(
for (auto* constant : constants) {
auto users = constant->users();
for (auto* user : users) {
if (IsFusible(*user) && CanEmitInputFusedScatter(*constant, *user)) {
if ((IsFusible(*user) || IsGenericTritonFusion(*user)) &&
CanEmitInputFusedScatter(*constant, *user)) {
Fuse(constant, user);
changed = true;
}
Expand Down
29 changes: 29 additions & 0 deletions xla/service/gpu/transforms/priority_fusion_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -818,6 +818,35 @@ TEST_F(PriorityFusionTest, FuseOnlySmallConstant) {
m::Add(m::Parameter(), m::Broadcast(m::Constant())))));
}

// Verifies that PriorityFusion fuses a small scalar constant into an existing
// Triton fusion: the entry computation's root is a kind=kCustom fusion with a
// "__triton" block-level backend config, and the scalar constant c_0 is one of
// its operands before the pass runs.
TEST_F(PriorityFusionTest, FuseSmallConstantIntoTritonFusion) {
auto module = *ParseAndReturnVerifiedModule(R"(
HloModule module
add {
Arg_0 = f32[] parameter(0)
Arg_1 = f32[] parameter(1)
ROOT add = f32[] add(Arg_0, Arg_1)
}
triton_computation {
param_0 = f32[32,64] parameter(0)
param_1 = f32[] parameter(1)
ROOT reduce = f32[32] reduce(param_0, param_1), dimensions={1}, to_apply=add
}
ENTRY main {
param_0 = f32[32,64] parameter(0)
c_0 = f32[] constant(0)
ROOT triton_softmax = f32[32] fusion(param_0, c_0), kind=kCustom, calls=triton_computation, backend_config={"fusion_backend_config": {"kind":"__triton","block_level_fusion_config":{"output_tile_sizes":["1"],"num_warps":"1"}}}
})");
// The pass must report that it modified the module (the constant was fused).
EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(true));

// After fusion, the root fusion should take only the remaining parameter as
// an operand — the constant operand is gone from the call site...
HloInstruction* root = module->entry_computation()->root_instruction();
ASSERT_THAT(root, GmockMatch(m::Fusion(m::Parameter())));
// ...and instead appears inside the fused computation as a direct operand of
// the reduce (replacing what was previously a parameter of the fusion body).
EXPECT_THAT(root->fused_expression_root(),
GmockMatch(m::Reduce(m::Parameter(), m::Constant())));
}

TEST_F(PriorityFusionTest, DoNotFuseProducerConsumerMergedTooLarge) {
auto module = *ParseAndReturnVerifiedModule(R"(
HloModule module
Expand Down

0 comments on commit 7b5e183

Please sign in to comment.