Skip to content

Commit

Permalink
Add support for kConvolution to FusionWrapper.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 674413074
  • Loading branch information
klucke authored and Google-ML-Automation committed Sep 13, 2024
1 parent 32ebd69 commit 536ba0b
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 0 deletions.
1 change: 1 addition & 0 deletions xla/service/gpu/transforms/fusion_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ absl::StatusOr<bool> FusionWrapper::Run(
case HloOpcode::kCompare:
case HloOpcode::kComplex:
case HloOpcode::kConcatenate:
case HloOpcode::kConvolution:
case HloOpcode::kConvert:
case HloOpcode::kCopy:
case HloOpcode::kCos:
Expand Down
22 changes: 22 additions & 0 deletions xla/service/gpu/transforms/fusion_wrapper_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,28 @@ namespace {

class FusionWrapperTest : public HloTestBase {};

TEST_F(FusionWrapperTest, ConvolutionWorks) {
RunAndFilecheckHloRewrite(R"(HloModule TestModule
ENTRY TestComputation {
input = f32[1,10,1,10,5,20]{5,4,3,2,1,0} parameter(0)
kernel = f32[20,1,2,1,4,15]{5,4,3,2,1,0} parameter(1)
ROOT conv = f32[15,1,9,1,7,5]{5,4,3,2,1,0} convolution(input, kernel), dim_labels=0123bf_i0123o->f0123b, window={size=1x2x1x4}
})",
FusionWrapper(), R"(
// CHECK: %wrapped_convolution_computation (param_0: f32[1,10,1,10,5,20], param_1: f32[20,1,2,1,4,15]) -> f32[15,1,9,1,7,5] {
// CHECK: %param_0 = f32[1,10,1,10,5,20]{5,4,3,2,1,0} parameter(0)
// CHECK: %param_1 = f32[20,1,2,1,4,15]{5,4,3,2,1,0} parameter(1)
// CHECK: ROOT %conv.1 = f32[15,1,9,1,7,5]{5,4,3,2,1,0} convolution(%param_0, %param_1), window={size=1x2x1x4}, dim_labels=0123bf_i0123o->f0123b
// CHECK: }
// CHECK: ENTRY %TestComputation (input: f32[1,10,1,10,5,20], kernel: f32[20,1,2,1,4,15]) -> f32[15,1,9,1,7,5] {
// CHECK: %input = f32[1,10,1,10,5,20]{5,4,3,2,1,0} parameter(0)
// CHECK: %kernel = f32[20,1,2,1,4,15]{5,4,3,2,1,0} parameter(1)
// CHECK: ROOT %wrapped_convolution = f32[15,1,9,1,7,5]{5,4,3,2,1,0} fusion(%input, %kernel), kind=kLoop, calls=%wrapped_convolution_computation
// CHECK: })");
}

TEST_F(FusionWrapperTest, SimpleOp) {
RunAndFilecheckHloRewrite(R"(
HloModule TestModule
Expand Down

0 comments on commit 536ba0b

Please sign in to comment.