From 536ba0b7d74f6637a7a772471a99ecf4f578aef2 Mon Sep 17 00:00:00 2001
From: Kyle Lucke
Date: Fri, 13 Sep 2024 13:24:51 -0700
Subject: [PATCH] Add support for kConvolution to FusionWrapper.

PiperOrigin-RevId: 674413074
---
 xla/service/gpu/transforms/fusion_wrapper.cc  |  1 +
 .../gpu/transforms/fusion_wrapper_test.cc     | 22 +++++++++++++++++++
 2 files changed, 23 insertions(+)

diff --git a/xla/service/gpu/transforms/fusion_wrapper.cc b/xla/service/gpu/transforms/fusion_wrapper.cc
index d7f8505c420fb..16957f80d370e 100644
--- a/xla/service/gpu/transforms/fusion_wrapper.cc
+++ b/xla/service/gpu/transforms/fusion_wrapper.cc
@@ -60,6 +60,7 @@ absl::StatusOr<bool> FusionWrapper::Run(
       case HloOpcode::kCompare:
       case HloOpcode::kComplex:
       case HloOpcode::kConcatenate:
+      case HloOpcode::kConvolution:
       case HloOpcode::kConvert:
       case HloOpcode::kCopy:
       case HloOpcode::kCos:
diff --git a/xla/service/gpu/transforms/fusion_wrapper_test.cc b/xla/service/gpu/transforms/fusion_wrapper_test.cc
index a46338f93ea0a..be5f0d7dfd49c 100644
--- a/xla/service/gpu/transforms/fusion_wrapper_test.cc
+++ b/xla/service/gpu/transforms/fusion_wrapper_test.cc
@@ -25,6 +25,28 @@ namespace {
 
 class FusionWrapperTest : public HloTestBase {};
 
+TEST_F(FusionWrapperTest, ConvolutionWorks) {
+  RunAndFilecheckHloRewrite(R"(HloModule TestModule
+
+ENTRY TestComputation {
+  input = f32[1,10,1,10,5,20]{5,4,3,2,1,0} parameter(0)
+  kernel = f32[20,1,2,1,4,15]{5,4,3,2,1,0} parameter(1)
+  ROOT conv = f32[15,1,9,1,7,5]{5,4,3,2,1,0} convolution(input, kernel), dim_labels=0123bf_i0123o->f0123b, window={size=1x2x1x4}
+})",
+                            FusionWrapper(), R"(
+// CHECK: %wrapped_convolution_computation (param_0: f32[1,10,1,10,5,20], param_1: f32[20,1,2,1,4,15]) -> f32[15,1,9,1,7,5] {
+// CHECK:   %param_0 = f32[1,10,1,10,5,20]{5,4,3,2,1,0} parameter(0)
+// CHECK:   %param_1 = f32[20,1,2,1,4,15]{5,4,3,2,1,0} parameter(1)
+// CHECK:   ROOT %conv.1 = f32[15,1,9,1,7,5]{5,4,3,2,1,0} convolution(%param_0, %param_1), window={size=1x2x1x4}, dim_labels=0123bf_i0123o->f0123b
+// CHECK: }
+
+// CHECK: ENTRY %TestComputation (input: f32[1,10,1,10,5,20], kernel: f32[20,1,2,1,4,15]) -> f32[15,1,9,1,7,5] {
+// CHECK:   %input = f32[1,10,1,10,5,20]{5,4,3,2,1,0} parameter(0)
+// CHECK:   %kernel = f32[20,1,2,1,4,15]{5,4,3,2,1,0} parameter(1)
+// CHECK:   ROOT %wrapped_convolution = f32[15,1,9,1,7,5]{5,4,3,2,1,0} fusion(%input, %kernel), kind=kLoop, calls=%wrapped_convolution_computation
+// CHECK: })");
+}
+
 TEST_F(FusionWrapperTest, SimpleOp) {
   RunAndFilecheckHloRewrite(R"(
 HloModule TestModule