From 45dea7181c0ef3c604fc17de7888330bdb996853 Mon Sep 17 00:00:00 2001 From: Dirk Hornung Date: Thu, 19 Sep 2024 02:52:23 -0700 Subject: [PATCH] Allow CustomKernelFusionRewriter to manually specify the kernel index. PiperOrigin-RevId: 676336179 --- .../gpu/kernels/cutlass_gemm_fusion_test.cc | 30 +++++++++---------- .../custom_kernel_fusion_rewriter.cc | 19 +++++++----- .../custom_kernel_fusion_rewriter.h | 3 +- .../custom_kernel_fusion_rewriter_test.cc | 24 +++++++++++++-- 4 files changed, 50 insertions(+), 26 deletions(-) diff --git a/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc b/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc index 51524048f476a..5327f234f96be 100644 --- a/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc +++ b/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc @@ -101,7 +101,7 @@ TEST_F(CutlassFusionTest, RowMajorGemm) { patterns.Emplace(); auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); - CustomKernelFusionRewriter pass(&device, &patterns); + CustomKernelFusionRewriter pass(&device, /*kernel_index=*/0, &patterns); RunAndFilecheckHloRewrite(hlo, std::move(pass), expected); } @@ -141,7 +141,7 @@ TEST_F(CutlassFusionTest, RowMajorGemmWithUpcast) { patterns.Emplace(); auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); - CustomKernelFusionRewriter pass(&device, &patterns); + CustomKernelFusionRewriter pass(&device, /*kernel_index=*/0, &patterns); RunAndFilecheckHloRewrite(hlo, std::move(pass), expected); } @@ -183,7 +183,7 @@ TEST_F(CutlassFusionTest, RowMajorGemmWithUpcastOfBothOperands) { patterns.Emplace(); auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); - CustomKernelFusionRewriter pass(&device, &patterns); + CustomKernelFusionRewriter pass(&device, /*kernel_index=*/0, &patterns); RunAndFilecheckHloRewrite(hlo, std::move(pass), expected); } @@ -209,7 +209,7 @@ TEST_F(CutlassFusionTest, DoNotPatternMatchNotImplementedKernelTypes) { ParseAndReturnVerifiedModule(hlo); auto device =
TestGpuDeviceInfo::RTXA6000DeviceInfo(); - CustomKernelFusionRewriter pass(&device, &patterns); + CustomKernelFusionRewriter pass(&device, /*kernel_index=*/0, &patterns); ASSERT_FALSE(pass.Run(hlo_module.value().get()).value()); } @@ -260,7 +260,7 @@ TEST_F(CutlassFusionTest, RowMajorGemmWithDynamicUpdateSlice) { patterns.Emplace(); auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); - CustomKernelFusionRewriter pass(&device, &patterns); + CustomKernelFusionRewriter pass(&device, /*kernel_index=*/0, &patterns); RunAndFilecheckHloRewrite(hlo, std::move(pass), expected); } @@ -319,7 +319,7 @@ TEST_F(CutlassFusionTest, RowMajorGemmWithDynamicUpdateSliceMultipleUses) { patterns.Emplace(); auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); - CustomKernelFusionRewriter pass(&device, &patterns); + CustomKernelFusionRewriter pass(&device, /*kernel_index=*/0, &patterns); RunAndFilecheckHloRewrite(hlo, std::move(pass), expected); } @@ -365,7 +365,7 @@ TEST_F(CutlassFusionTest, RowMajorGemmWithDynamicUpdateSliceWithoutBitcast) { patterns.Emplace(); auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); - CustomKernelFusionRewriter pass(&device, &patterns); + CustomKernelFusionRewriter pass(&device, /*kernel_index=*/0, &patterns); RunAndFilecheckHloRewrite(hlo, std::move(pass), expected); } @@ -654,7 +654,7 @@ TEST_F(CutlassFusionTest, GemmWithUpcastShouldBeFused) { patterns.Emplace(); auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); - CustomKernelFusionRewriter pass(&device, &patterns); + CustomKernelFusionRewriter pass(&device, /*kernel_index=*/0, &patterns); RunAndFilecheckHloRewrite(hlo, std::move(pass), expected); EXPECT_TRUE(RunAndCompare(hlo, ErrorSpec{1e-3, 1e-3})); } @@ -674,7 +674,7 @@ TEST_F(CutlassFusionTest, patterns.Emplace(); auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); - CustomKernelFusionRewriter pass(&device, &patterns); + CustomKernelFusionRewriter pass(&device, /*kernel_index=*/0, &patterns); // Check that hlo is not rewritten after 
the pass, indicating that we don't // match the upcast pattern. RunAndFilecheckHloRewrite(hlo, std::move(pass), std::nullopt); @@ -695,7 +695,7 @@ TEST_F(CutlassFusionTest, patterns.Emplace(); auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); - CustomKernelFusionRewriter pass(&device, &patterns); + CustomKernelFusionRewriter pass(&device, /*kernel_index=*/0, &patterns); // Check that hlo is not rewritten after the pass, indicating that we don't // match the upcast pattern. RunAndFilecheckHloRewrite(hlo, std::move(pass), std::nullopt); @@ -716,7 +716,7 @@ TEST_F(CutlassFusionTest, patterns.Emplace(); auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); - CustomKernelFusionRewriter pass(&device, &patterns); + CustomKernelFusionRewriter pass(&device, /*kernel_index=*/0, &patterns); // Check that hlo is not rewritten after the pass, indicating that we don't // match the upcast pattern. RunAndFilecheckHloRewrite(hlo, std::move(pass), std::nullopt); @@ -736,7 +736,7 @@ TEST_F(CutlassFusionTest, CustomKernelFusionPatternRegistry patterns; patterns.Emplace(); auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); - CustomKernelFusionRewriter pass(&device, &patterns); + CustomKernelFusionRewriter pass(&device, /*kernel_index=*/0, &patterns); // Check that hlo is not rewritten after the pass, indicating that we don't // match the upcast pattern. RunAndFilecheckHloRewrite(hlo, std::move(pass), std::nullopt); @@ -756,7 +756,7 @@ TEST_F(CutlassFusionTest, CustomKernelFusionPatternRegistry patterns; patterns.Emplace(); auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); - CustomKernelFusionRewriter pass(&device, &patterns); + CustomKernelFusionRewriter pass(&device, /*kernel_index=*/0, &patterns); // Check that hlo is not rewritten after the pass, indicating that we don't // match the upcast pattern. 
RunAndFilecheckHloRewrite(hlo, std::move(pass), std::nullopt); @@ -776,7 +776,7 @@ TEST_F(CutlassFusionTest, GemmWithUpcastWithBatchDimensionShouldNotBeFused) { CustomKernelFusionPatternRegistry patterns; patterns.Emplace(); auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); - CustomKernelFusionRewriter pass(&device, &patterns); + CustomKernelFusionRewriter pass(&device, /*kernel_index=*/0, &patterns); // Check that hlo is not rewritten after the pass, indicating that we don't // match the upcast pattern. RunAndFilecheckHloRewrite(hlo, std::move(pass), std::nullopt); @@ -795,7 +795,7 @@ TEST_F(CutlassFusionTest, GemmWithUpcastAndColumnMajorOperandsShouldBeFused) { CustomKernelFusionPatternRegistry patterns; patterns.Emplace(); auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); - CustomKernelFusionRewriter pass(&device, &patterns); + CustomKernelFusionRewriter pass(&device, /*kernel_index=*/0, &patterns); std::string expected = "CHECK: cutlass_gemm"; RunAndFilecheckHloRewrite(hlo, std::move(pass), expected); EXPECT_TRUE(RunAndCompare(hlo, ErrorSpec{1e-3, 1e-3})); diff --git a/xla/service/gpu/transforms/custom_kernel_fusion_rewriter.cc b/xla/service/gpu/transforms/custom_kernel_fusion_rewriter.cc index af9591a16c67a..300d62a896b2f 100644 --- a/xla/service/gpu/transforms/custom_kernel_fusion_rewriter.cc +++ b/xla/service/gpu/transforms/custom_kernel_fusion_rewriter.cc @@ -42,9 +42,9 @@ limitations under the License. 
namespace xla::gpu { CustomKernelFusionRewriter::CustomKernelFusionRewriter( - const se::DeviceDescription* device, + const se::DeviceDescription* device, int kernel_index, const CustomKernelFusionPatternRegistry* patterns) - : device_(device), patterns_(patterns) {} + : device_(device), kernel_index_(kernel_index), patterns_(patterns) {} // Returns a set of instruction that have users outside of a matched pattern // and have a replacement that must be applied after building a new custom @@ -148,9 +148,11 @@ static absl::StatusOr CreateFusionBody( return module->AddComputationAndUnifyNamesAndIds(builder.Build(), false); } -static absl::StatusOr CreateFusionInstruction( +namespace { +absl::StatusOr CreateFusionInstruction( HloModule* module, const CustomKernelFusionPattern::Match& match, - absl::Span captures, HloComputation* body) { + absl::Span captures, HloComputation* body, + int kernel_index) { // We'll be replacing the root operation of a custom kernel fusion with a // fusion instruction calling fusion computation. HloInstruction* root = match.root(); @@ -168,7 +170,7 @@ static absl::StatusOr CreateFusionInstruction( *gpu_config.mutable_fusion_backend_config(); backend_config.set_kind("__custom_fusion"); *backend_config.mutable_custom_fusion_config() = match.config(); - backend_config.mutable_custom_fusion_config()->set_kernel_index(0); + backend_config.mutable_custom_fusion_config()->set_kernel_index(kernel_index); TF_RETURN_IF_ERROR(fusion->set_backend_config(std::move(gpu_config))); // If we don't have workspace we can return constructed fusion instruction. 
@@ -178,6 +180,7 @@ static absl::StatusOr CreateFusionInstruction( return parent->AddInstruction( HloInstruction::CreateGetTupleElement(fusion, 0)); } +} // namespace absl::StatusOr CustomKernelFusionRewriter::Run( HloModule* module, @@ -205,9 +208,9 @@ absl::StatusOr CustomKernelFusionRewriter::Run( TF_ASSIGN_OR_RETURN(HloComputation * fusion_body, CreateFusionBody(module, match, captures)); - TF_ASSIGN_OR_RETURN( - HloInstruction * fusion, - CreateFusionInstruction(module, match, captures, fusion_body)); + TF_ASSIGN_OR_RETURN(HloInstruction * fusion, + CreateFusionInstruction(module, match, captures, + fusion_body, kernel_index_)); VLOG(2) << "Added a fusion instruction: " << fusion->name() << " for custom kernel fusion " << match.config().name() diff --git a/xla/service/gpu/transforms/custom_kernel_fusion_rewriter.h b/xla/service/gpu/transforms/custom_kernel_fusion_rewriter.h index c2f59ff0493ad..5738bf5283a34 100644 --- a/xla/service/gpu/transforms/custom_kernel_fusion_rewriter.h +++ b/xla/service/gpu/transforms/custom_kernel_fusion_rewriter.h @@ -63,7 +63,7 @@ namespace xla::gpu { class CustomKernelFusionRewriter : public HloModulePass { public: explicit CustomKernelFusionRewriter( - const se::DeviceDescription* device, + const se::DeviceDescription* device, int kernel_index = 0, const CustomKernelFusionPatternRegistry* patterns = CustomKernelFusionPatternRegistry::Default()); @@ -78,6 +78,7 @@ class CustomKernelFusionRewriter : public HloModulePass { private: const se::DeviceDescription* device_; + const int kernel_index_; const CustomKernelFusionPatternRegistry* patterns_; }; diff --git a/xla/service/gpu/transforms/custom_kernel_fusion_rewriter_test.cc b/xla/service/gpu/transforms/custom_kernel_fusion_rewriter_test.cc index 235e9ded150bf..4d44d99066e9f 100644 --- a/xla/service/gpu/transforms/custom_kernel_fusion_rewriter_test.cc +++ b/xla/service/gpu/transforms/custom_kernel_fusion_rewriter_test.cc @@ -88,10 +88,30 @@ TEST_F(CustomKernelFusionRewriterTest, 
SimpleGemm) { patterns.Emplace(); auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); - CustomKernelFusionRewriter pass(&device, &patterns); + CustomKernelFusionRewriter pass(&device, /*kernel_index=*/0, &patterns); RunAndFilecheckHloRewrite(hlo, std::move(pass), expected); } +TEST_F(CustomKernelFusionRewriterTest, SetsKernelIndex) { + const char* hlo = R"( + HloModule test + + ENTRY %main (p0: f16[15,19], p1: f16[19,17]) -> f16[15,17] { + %p0 = f16[15,19]{1,0} parameter(0) + %p1 = f16[19,17]{1,0} parameter(1) + ROOT %r = f16[15,17]{1,0} dot(%p0, %p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + )"; + + CustomKernelFusionPatternRegistry patterns; + patterns.Emplace(); + + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + CustomKernelFusionRewriter pass(&device, /*kernel_index=*/1, &patterns); + RunAndFilecheckHloRewrite(hlo, std::move(pass), "CHECK: \"kernel_index\":1"); +} + TEST_F(CustomKernelFusionRewriterTest, SimpleGemmWithWorkspace) { const char* hlo = R"( HloModule test @@ -131,7 +151,7 @@ TEST_F(CustomKernelFusionRewriterTest, SimpleGemmWithWorkspace) { patterns.Emplace(1024); auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); - CustomKernelFusionRewriter pass(&device, &patterns); + CustomKernelFusionRewriter pass(&device, /*kernel_index=*/0, &patterns); RunAndFilecheckHloRewrite(hlo, std::move(pass), expected); }