diff --git a/xla/service/gpu/fusions/triton/triton_fusion_emitter_parametrized_test.cc b/xla/service/gpu/fusions/triton/triton_fusion_emitter_parametrized_test.cc index 0735d65c6d387..a3cafcdf40628 100644 --- a/xla/service/gpu/fusions/triton/triton_fusion_emitter_parametrized_test.cc +++ b/xla/service/gpu/fusions/triton/triton_fusion_emitter_parametrized_test.cc @@ -2211,8 +2211,7 @@ ENTRY main { ; CHECK: ENTRY ; CHECK-DAG: %[[P0:.*]] = f32[125,127]{1,0} parameter(0) ; CHECK-DAG: %[[P1:.*]] = f32[10,125,127]{2,1,0} parameter(1) -; CHECK-DAG: %[[C0:.*]] = f32[] constant(0) -; CHECK: ROOT %[[FUSION:.*]] = f32[125,127]{1,0} fusion(%[[P0]], %[[P1]], %[[C0]]) +; CHECK: ROOT %[[FUSION:.*]] = f32[125,127]{1,0} fusion(%[[P0]], %[[P1]]) ; CHECK-SAME: kind=kCustom ; CHECK-SAME: __triton )"; diff --git a/xla/service/gpu/transforms/priority_fusion.cc b/xla/service/gpu/transforms/priority_fusion.cc index 9d55836861c64..f887161d869fd 100644 --- a/xla/service/gpu/transforms/priority_fusion.cc +++ b/xla/service/gpu/transforms/priority_fusion.cc @@ -1023,7 +1023,8 @@ absl::StatusOr PriorityFusion::Run( for (auto* constant : constants) { auto users = constant->users(); for (auto* user : users) { - if (IsFusible(*user) && CanEmitInputFusedScatter(*constant, *user)) { + if ((IsFusible(*user) || IsGenericTritonFusion(*user)) && + CanEmitInputFusedScatter(*constant, *user)) { Fuse(constant, user); changed = true; } diff --git a/xla/service/gpu/transforms/priority_fusion_test.cc b/xla/service/gpu/transforms/priority_fusion_test.cc index 58c23381b4d8d..5f49fa4fc5c67 100644 --- a/xla/service/gpu/transforms/priority_fusion_test.cc +++ b/xla/service/gpu/transforms/priority_fusion_test.cc @@ -818,6 +818,35 @@ TEST_F(PriorityFusionTest, FuseOnlySmallConstant) { m::Add(m::Parameter(), m::Broadcast(m::Constant()))))); } +TEST_F(PriorityFusionTest, FuseSmallConstantIntoTritonFusion) { + auto module = *ParseAndReturnVerifiedModule(R"( +HloModule module + +add { + Arg_0 = f32[] parameter(0) + Arg_1 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0, Arg_1) +} + +triton_computation { + param_0 = f32[32,64] parameter(0) + param_1 = f32[] parameter(1) + ROOT reduce = f32[32] reduce(param_0, param_1), dimensions={1}, to_apply=add +} + +ENTRY main { + param_0 = f32[32,64] parameter(0) + c_0 = f32[] constant(0) + ROOT triton_softmax = f32[32] fusion(param_0, c_0), kind=kCustom, calls=triton_computation, backend_config={"fusion_backend_config": {"kind":"__triton","block_level_fusion_config":{"output_tile_sizes":["1"],"num_warps":"1"}}} +})"); + EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(true)); + + HloInstruction* root = module->entry_computation()->root_instruction(); + ASSERT_THAT(root, GmockMatch(m::Fusion(m::Parameter()))); + EXPECT_THAT(root->fused_expression_root(), + GmockMatch(m::Reduce(m::Parameter(), m::Constant()))); +} + TEST_F(PriorityFusionTest, DoNotFuseProducerConsumerMergedTooLarge) { auto module = *ParseAndReturnVerifiedModule(R"( HloModule module