From 67325a0f76c0e989c2648dc57b13fbf2884f995b Mon Sep 17 00:00:00 2001 From: Dirk Hornung Date: Fri, 13 Sep 2024 06:00:16 -0700 Subject: [PATCH] Add custom kernel fusion to gemm fusion autotuner. The GemmFusionAutotuner currently takes a fusion and compares its runtime on different backends (Triton, CuBLAS and CuDNN). We add CustomKernelFusions (mostly Cutlass kernels) to the autotuner. PiperOrigin-RevId: 674266098 --- xla/autotuning.proto | 7 +- xla/service/gpu/BUILD | 1 + xla/service/gpu/autotuning/BUILD | 9 +- .../custom_kernel_fusion_autotuner.cc | 18 +- .../gpu/autotuning/gemm_fusion_autotuner.cc | 247 ++++++++++++------ .../gpu/autotuning/gemm_fusion_autotuner.h | 20 +- .../autotuning/gemm_fusion_autotuner_test.cc | 130 ++++++++- xla/service/gpu/gpu_compiler.cc | 11 +- xla/service/gpu/gpu_compiler_test.cc | 9 +- .../custom_kernel_fusion_rewriter.cc | 7 + 10 files changed, 362 insertions(+), 97 deletions(-) diff --git a/xla/autotuning.proto b/xla/autotuning.proto index a7ffcbb57ae6e..4cadf6dbb250e 100644 --- a/xla/autotuning.proto +++ b/xla/autotuning.proto @@ -83,6 +83,10 @@ message AutotuneResult { int64 num_ctas = 7; } + message CustomKernelFusionKey { + int64 kernel_index = 1; + } + int64 scratch_bytes = 8; google.protobuf.Duration run_time = 9; @@ -93,10 +97,11 @@ message AutotuneResult { GemmKey gemm = 6; TritonGemmKey triton = 17; CudaConvPlanKey cuda_conv_plan = 15; + CustomKernelFusionKey custom_kernel_fusion = 18; stream_executor.dnn.AlgorithmProto algorithm = 16; } - // Next ID: 17 + // Next ID: 19 } message AutotuningLog { diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index 151417710d85c..fbe509210668e 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -1627,6 +1627,7 @@ xla_test( "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "//xla/tests:literal_test_util", + "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/log", diff --git a/xla/service/gpu/autotuning/BUILD b/xla/service/gpu/autotuning/BUILD index 487da266d968e..0f287c6ec3095 100644 --- a/xla/service/gpu/autotuning/BUILD +++ b/xla/service/gpu/autotuning/BUILD @@ -58,12 +58,15 @@ cc_library( "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:buffer_comparator", "//xla/service/gpu:gpu_float_support", - "//xla/service/gpu:hlo_traversal", "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu:matmul_utils", "//xla/service/gpu:split_k_gemm_rewriter", "//xla/service/gpu:stream_executor_util", + "//xla/service/gpu/kernels:custom_kernel", + "//xla/service/gpu/kernels:custom_kernel_fusion", + "//xla/service/gpu/kernels:custom_kernel_fusion_pattern", "//xla/service/gpu/transforms:cudnn_fusion_compiler", + "//xla/service/gpu/transforms:custom_kernel_fusion_rewriter", "//xla/service/gpu/transforms:fusion_wrapper", "//xla/service/gpu/transforms:gemm_rewriter", "//xla/service/gpu/transforms:priority_fusion", @@ -72,11 +75,9 @@ cc_library( "//xla/stream_executor:device_memory", "//xla/stream_executor:semantic_version", "//xla/stream_executor:stream_executor_memory_allocator", - "//xla/stream_executor/gpu:redzone_allocator", "//xla/tools:hlo_decomposer_lib", "//xla/tsl/lib/core:bits", "//xla/tsl/util/proto:proto_utils", - "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -137,6 +138,8 @@ xla_test( "//xla/stream_executor:device_description", 
"//xla/stream_executor:device_description_proto_cc", "//xla/stream_executor:semantic_version", + "//xla/stream_executor:stream_executor_h", + "//xla/stream_executor/gpu:gpu_executor_header", "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "//xla/tests:test_utils", diff --git a/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.cc b/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.cc index 164252eb83312..4a02c599cb987 100644 --- a/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.cc +++ b/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.cc @@ -192,7 +192,7 @@ absl::StatusOr AutotuneCustomKernelFusion( return previous_kernel_index != fastest_kernel_index; } -bool IsCustomFusion(const HloComputation* computation) { +bool IsCutlassCustomFusion(const HloComputation* computation) { if (!computation->IsFusionComputation()) { return false; } @@ -212,8 +212,18 @@ bool IsCustomFusion(const HloComputation* computation) { return false; } - return gpu_backend_config->fusion_backend_config().kind() == - kCustomFusionKind; + if (gpu_backend_config->fusion_backend_config().kind() != kCustomFusionKind) { + return false; + } + + if (gpu_backend_config->fusion_backend_config() + .custom_fusion_config() + .name() + .rfind("cutlass", 0) != 0) { + return false; + } + + return true; } } // namespace @@ -231,7 +241,7 @@ absl::StatusOr CustomKernelFusionAutotuner::Run( bool hlo_changed = false; for (const HloComputation* computation : module->computations()) { - if (IsCustomFusion(computation)) { + if (IsCutlassCustomFusion(computation)) { TF_ASSIGN_OR_RETURN( bool instruction_changed, AutotuneCustomKernelFusion(computation->FusionInstruction(), config_, diff --git a/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc b/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc index c12c45a4abfd8..a0d98975bf440 100644 --- a/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc +++ b/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -27,7 +26,6 @@ limitations under the License. #include #include -#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" @@ -56,19 +54,21 @@ limitations under the License. 
#include "xla/primitive_util.h" #include "xla/service/algorithm_util.h" #include "xla/service/dump.h" -#include "xla/service/executable.h" #include "xla/service/float_normalization.h" #include "xla/service/gpu/autotuning/autotuner_compile_util.h" #include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/buffer_comparator.h" #include "xla/service/gpu/gpu_float_support.h" -#include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/kernels/custom_kernel.h" +#include "xla/service/gpu/kernels/custom_kernel_fusion.h" +#include "xla/service/gpu/kernels/custom_kernel_fusion_pattern.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/split_k_gemm_rewriter.h" #include "xla/service/gpu/stream_executor_util.h" #include "xla/service/gpu/transforms/cudnn_fusion_compiler.h" +#include "xla/service/gpu/transforms/custom_kernel_fusion_rewriter.h" #include "xla/service/gpu/transforms/fusion_wrapper.h" #include "xla/service/gpu/transforms/gemm_rewriter.h" #include "xla/service/gpu/transforms/priority_fusion.h" @@ -82,7 +82,6 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" -#include "xla/stream_executor/gpu/redzone_allocator.h" #include "xla/stream_executor/semantic_version.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" @@ -140,76 +139,6 @@ constexpr std::array kNumCtas = {1, 2, 4, 8, 16}; using AutoTuneCacheKeyCount = absl::flat_hash_map; -class GemmFusionAutotunerVisitor : public DfsHloRewriteVisitor { - public: - explicit GemmFusionAutotunerVisitor(const AutotuneConfig& config) - : config_(config) {} - - absl::Status HandleFusion(HloInstruction* hlo) override { - TF_ASSIGN_OR_RETURN(auto gpu_config, - hlo->backend_config()); - FusionBackendConfig& backend_config = - *gpu_config.mutable_fusion_backend_config(); - if (backend_config.kind() != kTritonGemmFusionKind && - backend_config.kind() != kCuDnnFusionKind) { - return absl::OkStatus(); - } - - VLOG(4) << "Processing " << hlo->ToString(); - if (!backend_config.has_triton_gemm_config() && - !backend_config.has_cudnn_fusion_config()) { - TF_ASSIGN_OR_RETURN( - AutotuneResult autotune_result, - AutotunerUtil::Autotune( - hlo, config_, [&]() -> absl::StatusOr { - if (config_.IsDeviceless()) { - return absl::InternalError(absl::StrCat( - "Expect autotune result cache hit for deviceless " - "compilation (HLO: ", - hlo->ToString(), ")")); - } - return absl::InternalError("Expect autotune result cache hit."); - })); - VLOG(4) << "Result: " << autotune_result.ShortDebugString(); - - if (autotune_result.has_triton()) { - *backend_config.mutable_triton_gemm_config() = autotune_result.triton(); - TF_RETURN_IF_ERROR(hlo->set_backend_config(gpu_config)); - } else if (autotune_result.has_gemm()) { - // Falling back to cuBLAS: Converting the fusion to a Call, so that it - // can be inlined back again. 
- HloComputation* const computation = hlo->parent(); - HloInstruction* const call = computation->AddInstruction( - HloInstruction::CreateCall(hlo->shape(), hlo->operands(), - hlo->fused_instructions_computation())); - TF_RETURN_IF_ERROR(computation->ReplaceInstruction(hlo, call)); - hlo = call; - } else { - CHECK(autotune_result.has_algorithm()); - backend_config.set_kind(std::string(kCuDnnFusionKind)); - backend_config.mutable_cudnn_fusion_config()->set_plan_id( - autotune_result.algorithm().algo_id()); - TF_RETURN_IF_ERROR(hlo->set_backend_config(gpu_config)); - } - } - - if (backend_config.has_triton_gemm_config()) { - TF_ASSIGN_OR_RETURN( - const TritonGemmConfig config, - TritonGemmConfig::FromProto(backend_config.triton_gemm_config())); - if (config.split_k > 1) { - TF_RETURN_IF_ERROR(MakeDotSplitKBatch(hlo, config)); - } - } - - MarkAsChanged(); - return absl::OkStatus(); - } - - private: - AutotuneConfig config_; -}; - class GemmConfigSetCollector : public ConstDfsHloVisitorWithDefault { public: explicit GemmConfigSetCollector(GemmFusionAutotunerImpl* impl) @@ -259,7 +188,9 @@ class GemmConfigSetCollector : public ConstDfsHloVisitorWithDefault { bool missing_config = (backend_config.kind() == kTritonGemmFusionKind && !backend_config.has_triton_gemm_config()) || (backend_config.kind() == kCuDnnFusionKind && - !backend_config.has_cudnn_fusion_config()); + !backend_config.has_cudnn_fusion_config()) || + (backend_config.kind() == kCustomFusionKind && + !backend_config.has_custom_fusion_config()); if (missing_config) { if (error_out_on_cache_miss_) { return absl::NotFoundError(absl::StrCat( @@ -426,6 +357,46 @@ absl::StatusOr> CublasGemmAutotuneExtractor( return new_module; } +absl::Status UpdateFusionInstructionKernelIndex( + HloInstruction* fusion_instruction, int kernel_index) { + GpuBackendConfig gpu_config = + fusion_instruction->backend_config().value(); + gpu_config.mutable_fusion_backend_config() + ->mutable_custom_fusion_config() + ->set_kernel_index(kernel_index); + TF_RETURN_IF_ERROR(fusion_instruction->set_backend_config(gpu_config)); + + return absl::OkStatus(); +} + +absl::StatusOr> CutlassGemmAutotuneExtractor( + const GemmFusionAutotunerImpl::CustomKernelFusionConfig& cutlass_config, + const AutotuneConfig& config, const se::SemanticVersion& toolkit_version, + const HloFusionInstruction* fusion, const DebugOptions& debug_opts) { + const HloComputation* fusion_computation = fusion->called_computation(); + std::unique_ptr new_module = + ExtractComputationIntoNewModule(*fusion_computation); + new_module->mutable_config().set_debug_options(debug_opts); + + CustomKernelFusionRewriter rewriter( + &config.GetExecutor()->GetDeviceDescription()); + PriorityFusion fusion_pass( + /*thread_pool=*/nullptr, config.GetExecutor()->GetDeviceDescription(), + PriorityFusionOptions()); + TF_RETURN_IF_ERROR(rewriter.Run(new_module.get()).status()); + TF_RETURN_IF_ERROR(fusion_pass.Run(new_module.get()).status()); + + // Select custom kernel fusion kernel. 
+ HloInstruction* custom_kernel_fusion = + hlo_query::GetFirstInstructionWithOpcode(*new_module->entry_computation(), + HloOpcode::kFusion); + int64_t kernel_index = cutlass_config.kernel_index; + TF_RETURN_IF_ERROR( + UpdateFusionInstructionKernelIndex(custom_kernel_fusion, kernel_index)); + + return new_module; +} + absl::StatusOr> FusionExtractor( const HloFusionInstruction& fusion, const DebugOptions& debug_opts) { std::unique_ptr module = ExtractInstructionIntoNewModule(fusion); @@ -474,6 +445,11 @@ AutotuneResult FromConfig(const BackendConfig& config) { AutotuneResult res; if (std::holds_alternative(config)) { res.mutable_gemm()->set_algorithm(CUBLAS_GEMM_DEFAULT); + } else if (std::holds_alternative< + GemmFusionAutotunerImpl::CustomKernelFusionConfig>(config)) { + res.mutable_custom_kernel_fusion()->set_kernel_index( + std::get(config) + .kernel_index); } else if (std::holds_alternative( config)) { res.mutable_algorithm()->set_algo_id( @@ -573,6 +549,75 @@ std::string Serialize(const BackendConfig& config) { } // anonymous namespace +absl::Status GemmFusionAutotunerVisitor::HandleFusion(HloInstruction* hlo) { + TF_ASSIGN_OR_RETURN(auto gpu_config, hlo->backend_config()); + FusionBackendConfig& backend_config = + *gpu_config.mutable_fusion_backend_config(); + if (backend_config.kind() != kTritonGemmFusionKind && + backend_config.kind() != kCuDnnFusionKind && + backend_config.kind() != kCustomFusionKind) { + return absl::OkStatus(); + } + + VLOG(4) << "Processing " << hlo->ToString(); + if (!backend_config.has_triton_gemm_config() && + !backend_config.has_cudnn_fusion_config() && + !backend_config.has_custom_fusion_config()) { + TF_ASSIGN_OR_RETURN( + AutotuneResult autotune_result, + AutotunerUtil::Autotune( + hlo, config_, [&]() -> absl::StatusOr { + if (config_.IsDeviceless()) { + return absl::InternalError(absl::StrCat( + "Expect autotune result cache hit for deviceless " + "compilation (HLO: ", + hlo->ToString(), ")")); + } + return absl::InternalError("Expect autotune result cache hit."); + })); + VLOG(4) << "Result: " << autotune_result.ShortDebugString(); + + if (autotune_result.has_triton()) { + *backend_config.mutable_triton_gemm_config() = autotune_result.triton(); + TF_RETURN_IF_ERROR(hlo->set_backend_config(gpu_config)); + } else if (autotune_result.has_gemm()) { + // Falling back to cuBLAS: Converting the fusion to a Call, so that it + // can be inlined back again. 
+ HloComputation* const computation = hlo->parent(); + HloInstruction* const call = computation->AddInstruction( + HloInstruction::CreateCall(hlo->shape(), hlo->operands(), + hlo->fused_instructions_computation())); + TF_RETURN_IF_ERROR(computation->ReplaceInstruction(hlo, call)); + hlo = call; + } else if (autotune_result.has_custom_kernel_fusion()) { + HloComputation* const computation = hlo->parent(); + HloInstruction* const call = computation->AddInstruction( + HloInstruction::CreateCall(hlo->shape(), hlo->operands(), + hlo->fused_instructions_computation())); + TF_RETURN_IF_ERROR(computation->ReplaceInstruction(hlo, call)); + hlo = call; + } else { + CHECK(autotune_result.has_algorithm()); + backend_config.set_kind(std::string(kCuDnnFusionKind)); + backend_config.mutable_cudnn_fusion_config()->set_plan_id( + autotune_result.algorithm().algo_id()); + TF_RETURN_IF_ERROR(hlo->set_backend_config(gpu_config)); + } + } + + if (backend_config.has_triton_gemm_config()) { + TF_ASSIGN_OR_RETURN( + const TritonGemmConfig config, + TritonGemmConfig::FromProto(backend_config.triton_gemm_config())); + if (config.split_k > 1) { + TF_RETURN_IF_ERROR(MakeDotSplitKBatch(hlo, config)); + } + } + + MarkAsChanged(); + return absl::OkStatus(); +} + // Methods required for sorting the configs. bool GemmFusionAutotunerImpl::CuBlasConfig::operator<( const CuBlasConfig& other) const { @@ -582,6 +627,10 @@ bool GemmFusionAutotunerImpl::CuDnnConfig::operator<( const CuDnnConfig& other) const { return plan_id < other.plan_id; } +bool GemmFusionAutotunerImpl::CustomKernelFusionConfig::operator<( + const CustomKernelFusionConfig& other) const { + return false; +} bool GemmFusionAutotunerImpl::IsAutotuningEnabled() const { return debug_options_.xla_gpu_autotune_level() > 0 && @@ -639,6 +688,43 @@ GemmFusionAutotunerImpl::GenerateConfigs(const HloFusionInstruction& fusion) { return configs; } + // Add CustomKernelFusion (Cutlass) configs, if available. + // Go through all the instructions in the fusion body try to match them to + // a custom kernel fusion pattern. + if ((IsFusionKind(fusion, kCustomFusionKind) || + IsFusionKind(fusion, kTritonGemmFusionKind)) && + IsAutotuningEnabled() && !config_.IsDeviceless()) { + const CustomKernelFusionPatternRegistry* patterns = + CustomKernelFusionPatternRegistry::Default(); + HloComputation* computation = fusion.called_computation(); + // Get the first dot instruction in the fusion body. + HloInstruction* dot_instruction = + hlo_query::GetFirstInstructionWithOpcode(*computation, HloOpcode::kDot); + std::vector match = patterns->Match( + config_.GetExecutor()->GetDeviceDescription(), dot_instruction); + + // For Cutlass we expect only one match for a gemm fusion. + if (match.size() == 1) { + CustomKernelFusionRegistry* registry = + CustomKernelFusionRegistry::Default(); + auto* custom_kernel_fusion = registry->Lookup(match[0].config().name()); + + // If custom fusion is not found it means that some of the build targets + // might not be statically linked into the binary. + if (custom_kernel_fusion != nullptr) { + // Load custom kernels that can implement a fusion computation. + TF_ASSIGN_OR_RETURN(std::vector kernels, + custom_kernel_fusion->LoadKernels( + config_.GetExecutor()->GetDeviceDescription(), + fusion.fused_instructions_computation())); + for (int i = 0; i < kernels.size(); ++i) { + CustomKernelFusionConfig config{/*kernel_index=*/i}; + configs.push_back(config); + } + } + } + } + // Add triton configs. 
TF_ASSIGN_OR_RETURN(std::vector triton_configs, GenerateTritonConfigs(*dot)); @@ -802,6 +888,15 @@ GemmFusionAutotunerImpl::CompileAll(AutotunerCompileUtil& compile_util, config_, config_.GetExecutor()->GetDeviceDescription(), toolkit_version_, fusion, opts); })); + } else if (std::holds_alternative(config)) { + CustomKernelFusionConfig cutlass_config = + std::get(config); + TF_ASSIGN_OR_RETURN( + executable, compile_util.Compile([&](const DebugOptions& opts) { + return CutlassGemmAutotuneExtractor(cutlass_config, config_, + toolkit_version_, fusion, opts); + })); + } else { LOG(FATAL) << "Unsupported config type: " << config.index(); } diff --git a/xla/service/gpu/autotuning/gemm_fusion_autotuner.h b/xla/service/gpu/autotuning/gemm_fusion_autotuner.h index 791535062ce3a..7d852b57b1386 100644 --- a/xla/service/gpu/autotuning/gemm_fusion_autotuner.h +++ b/xla/service/gpu/autotuning/gemm_fusion_autotuner.h @@ -29,7 +29,9 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/autotuning.pb.h" +#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/pjrt/distributed/key_value_store_interface.h" @@ -46,6 +48,17 @@ limitations under the License. namespace xla { namespace gpu { +class GemmFusionAutotunerVisitor : public DfsHloRewriteVisitor { + public: + explicit GemmFusionAutotunerVisitor(const AutotuneConfig& config) + : config_(config) {} + + absl::Status HandleFusion(HloInstruction* hlo) override; + + private: + AutotuneConfig config_; +}; + // Takes a gemm fusion and chooses between cuBLAS, cuDNN, and Triton backends. // In the case of Triton, it also chooses the best tiling configuration. // @@ -99,8 +112,13 @@ class GemmFusionAutotunerImpl { int64_t plan_id; bool operator<(const CuDnnConfig& other) const; }; + struct CustomKernelFusionConfig { + int64_t kernel_index; + bool operator<(const CustomKernelFusionConfig& other) const; + }; using BackendConfig = - std::variant; + std::variant; using BackendConfigs = std::vector< std::pair>>; diff --git a/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc b/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc index 3fec7e03e7125..29f2c2133e9b4 100644 --- a/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc +++ b/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc @@ -50,7 +50,9 @@ limitations under the License. 
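// The header change above extends the BackendConfig variant with
// CustomKernelFusionConfig; CompileAll() and FromConfig() then dispatch on the
// variant with std::holds_alternative/std::get. A reduced, self-contained
// sketch of that dispatch pattern (the struct and function names below are
// simplified stand-ins, not the actual XLA declarations):
#include <cstdint>
#include <variant>

struct CuBlasConfig {};
struct CustomKernelFusionConfig {
  int64_t kernel_index;
};
using BackendConfig = std::variant<CuBlasConfig, CustomKernelFusionConfig>;

// Returns the kernel index for a custom-kernel-fusion candidate, or -1 for any
// other backend, mirroring the branch added to FromConfig().
int64_t KernelIndexOrDefault(const BackendConfig& config) {
  if (std::holds_alternative<CustomKernelFusionConfig>(config)) {
    return std::get<CustomKernelFusionConfig>(config).kernel_index;
  }
  return -1;
}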
#include "xla/service/pattern_matcher_gmock.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_description.pb.h" +#include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/semantic_version.h" +#include "xla/stream_executor/stream_executor.h" #include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_utils.h" @@ -195,6 +197,25 @@ class GemmFusionAutotunerTest : public StatelessAutotunerTest { .cuda_compute_capability(); } + absl::StatusOr> + GetPossibleMatmulAutotuneConfigs( + const HloFusionInstruction& fusion, + const se::CudaComputeCapability& compute_capability, + const se::SemanticVersion& toolkit_version, + const DebugOptions& debug_options) { + se::GpuDeviceInfoProto deviceless_proto; + auto ccc = deviceless_proto.mutable_cuda_compute_capability(); + ccc->set_major(compute_capability.major); + ccc->set_minor(compute_capability.minor); + + DeviceConfig test_config{backend().default_stream_executor(), + backend().memory_allocator()}; + AutotuneConfig autotune_config{test_config, debug_options}; + GemmFusionAutotunerImpl autotuner(autotune_config, toolkit_version, + debug_options, nullptr); + return autotuner.GenerateConfigs(fusion); + } + void CheckTritonAutotuning(absl::string_view hlo, absl::string_view expected) { HloPassPipeline pipeline("gemm_rewrite"); @@ -247,7 +268,8 @@ class GemmFusionAutotunerTestWithMorePreciseReduction } }; -absl::StatusOr> GetPossibleMatmulAutotuneConfigs( +absl::StatusOr> +GetPossibleMatmulAutotuneTritonConfigs( const HloDotInstruction& dot, const se::CudaComputeCapability& compute_capability, const se::SemanticVersion& toolkit_version, @@ -276,7 +298,7 @@ ENTRY e { se::CudaComputeCapability::AMPERE, /*minor=*/0}; TF_ASSERT_OK_AND_ASSIGN( const std::vector configs, - GetPossibleMatmulAutotuneConfigs( + GetPossibleMatmulAutotuneTritonConfigs( *Cast( module->entry_computation()->root_instruction()), compute_capability, GetToolkitVersion(), GetDebugOptionsForTest())); @@ -298,7 +320,7 @@ ENTRY e { se::CudaComputeCapability::AMPERE, /*minor=*/0}; TF_ASSERT_OK_AND_ASSIGN( const std::vector configs, - GetPossibleMatmulAutotuneConfigs( + GetPossibleMatmulAutotuneTritonConfigs( *Cast( module->entry_computation()->root_instruction()), compute_capability, GetToolkitVersion(), GetDebugOptionsForTest())); @@ -320,7 +342,7 @@ ENTRY e { se::CudaComputeCapability::AMPERE, /*minor=*/0}; TF_ASSERT_OK_AND_ASSIGN( const std::vector configs, - GetPossibleMatmulAutotuneConfigs( + GetPossibleMatmulAutotuneTritonConfigs( *Cast( module->entry_computation()->root_instruction()), compute_capability, GetToolkitVersion(), GetDebugOptionsForTest())); @@ -874,7 +896,7 @@ ENTRY e { se::CudaComputeCapability::AMPERE, /*minor=*/0}; TF_ASSERT_OK_AND_ASSIGN( const std::vector configs, - GetPossibleMatmulAutotuneConfigs( + GetPossibleMatmulAutotuneTritonConfigs( *Cast( module->entry_computation()->root_instruction()), compute_capability, GetToolkitVersion(), GetDebugOptionsForTest())); @@ -906,7 +928,7 @@ ENTRY e { se::CudaComputeCapability::AMPERE, /*minor=*/0}; TF_ASSERT_OK_AND_ASSIGN( const std::vector configs, - GetPossibleMatmulAutotuneConfigs( + GetPossibleMatmulAutotuneTritonConfigs( *Cast( module->entry_computation()->root_instruction()), compute_capability, GetToolkitVersion(), GetDebugOptionsForTest())); @@ -937,7 +959,7 @@ ENTRY wais { TF_ASSERT_OK_AND_ASSIGN( const std::vector configs, - GetPossibleMatmulAutotuneConfigs( + GetPossibleMatmulAutotuneTritonConfigs( *Cast( 
module->entry_computation()->root_instruction()), compute_capability, GetToolkitVersion(), debug_options)); @@ -1001,6 +1023,100 @@ ENTRY entry { CHECK_OK(autotuner.CompileAll(*compile_util, configs)); } +TEST_F(GemmFusionAutotunerTest, CreatesCustomKernelFusionConfigs) { + const std::string kHlo = R"( + HloModule module, entry_computation_layout={(bf16[1024,1024]{1,0}, bf16[1024,1024]{1,0})->f32[1024,1024]{1,0}} + + %gemm_fusion_r_computation { + %parameter_0 = bf16[1024,1024]{1,0} parameter(0) + %convert.2 = f32[1024,1024]{1,0} convert(%parameter_0) + %parameter_1 = bf16[1024,1024]{1,0} parameter(1) + %convert.3 = f32[1024,1024]{1,0} convert(%parameter_1) + ROOT %r.1 = f32[1024,1024]{1,0} dot(%convert.2, %convert.3), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + + ENTRY main { + %p0 = bf16[1024,1024]{1,0} parameter(0) + %p1 = bf16[1024,1024]{1,0} parameter(1) + ROOT %gemm_fusion_r = f32[1024,1024]{1,0} fusion(%p0, %p1), kind=kCustom, calls=gemm_fusion_r_computation, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm"},"force_earliest_schedule":false} + })"; + + std::unique_ptr module = + ParseAndReturnVerifiedModule(kHlo).value(); + const se::CudaComputeCapability compute_capability{ + se::CudaComputeCapability::AMPERE, /*minor=*/0}; + + TF_ASSERT_OK_AND_ASSIGN( + const std::vector configs, + GetPossibleMatmulAutotuneConfigs( + *Cast( + module->entry_computation()->root_instruction()), + compute_capability, GetToolkitVersion(), GetDebugOptionsForTest())); + EXPECT_TRUE(std::any_of( + configs.begin(), configs.end(), + [](const GemmFusionAutotunerImpl::BackendConfig& config) { + return std::holds_alternative< + GemmFusionAutotunerImpl::CustomKernelFusionConfig>(config); + })); +} + +TEST_F(GemmFusionAutotunerTest, + InlinesTritonFusionIfCustomKernelFusionIsMorePerformant) { + const std::string kHlo = R"( + HloModule module, entry_computation_layout={(bf16[1024,1024]{1,0}, bf16[1024,1024]{1,0})->f32[1024,1024]{1,0}} + + %gemm_fusion_r_computation { + %parameter_0 = bf16[1024,1024]{1,0} parameter(0) + %convert.2 = f32[1024,1024]{1,0} convert(%parameter_0) + %parameter_1 = bf16[1024,1024]{1,0} parameter(1) + %convert.3 = f32[1024,1024]{1,0} convert(%parameter_1) + ROOT %r.1 = f32[1024,1024]{1,0} dot(%convert.2, %convert.3), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + + ENTRY main { + %p0 = bf16[1024,1024]{1,0} parameter(0) + %p1 = bf16[1024,1024]{1,0} parameter(1) + ROOT %gemm_fusion_r = f32[1024,1024]{1,0} fusion(%p0, %p1), kind=kCustom, calls=gemm_fusion_r_computation, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm"},"force_earliest_schedule":false} + } +)"; + + std::unique_ptr module = + ParseAndReturnVerifiedModule(kHlo).value(); + + DebugOptions opts; + AutotuneConfig autotune_config{ + DeviceConfig{backend().default_stream_executor(), + backend().memory_allocator()}, + opts}; + AutotuneCacheKey cache_key(autotune_config.GetModelStr(), + *module->entry_computation()->root_instruction()); + TF_ASSERT_OK_AND_ASSIGN(AutotuneResults autotune_results_override, + ParseTextProto(R"pb( + version: 3 + results { + device: "..." + hlo: "..." 
+ result { + custom_kernel_fusion { kernel_index: 0 } + run_time { nanos: 14 } + } + })pb")); + autotune_results_override.mutable_results(0)->set_device( + std::string(cache_key.GetModelStr())); + autotune_results_override.mutable_results(0)->set_hlo( + std::string(cache_key.GetHlo())); + + GemmFusionAutotunerVisitor visitor(autotune_config); + + CHECK_OK(AutotunerUtil::LoadAutotuneResults(autotune_results_override)); + visitor.RunOnModule(module.get(), {}).value(); + std::string pattern = R"( + CHECK-NOT: ROOT %gemm_fusion_r = f32[1024,1024]{1,0} fusion + CHECK: ROOT %call = f32[1024,1024]{1,0} call + )"; + TF_ASSERT_OK(RunFileCheck(module->ToString(), pattern)); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc index 6cb0d4d37ecd1..0e2d5e477ba1d 100644 --- a/xla/service/gpu/gpu_compiler.cc +++ b/xla/service/gpu/gpu_compiler.cc @@ -1539,8 +1539,17 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( &pipeline, hlo_module, autotune_config, thread_pool, options.key_value_store, gpu_target_config.device_description.runtime_version())); - // Inline back the calls which have better performance with cuBLAS. + // Inline back the calls which have better performance with cuBLAS or Custom + // Kernel Fusion. pipeline.AddPass(); + + // Greedily pattern match and replace with Custom Kernel Fusions (e.g. + // Cutlass kernels with upcasts). + pipeline.AddPass(); + pipeline.AddPass( + &gpu_target_config.device_description); + pipeline.AddPass(autotune_config); + // TODO(tdanyluk): Apply CublasPadForGemms to the cuBLAS GEMMs generated // here for possibly better cuBLAS performance. AddGemmRewriterPasses(pipeline, debug_options, gpu_version, diff --git a/xla/service/gpu/gpu_compiler_test.cc b/xla/service/gpu/gpu_compiler_test.cc index 51b459e8a81a0..2f423fd6d6e4e 100644 --- a/xla/service/gpu/gpu_compiler_test.cc +++ b/xla/service/gpu/gpu_compiler_test.cc @@ -55,6 +55,7 @@ limitations under the License. #include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" +#include "xla/tests/verified_hlo_module.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/casts.h" @@ -1013,8 +1014,9 @@ TEST_F(GpuCompilerPassTest, ->GetDeviceDescription() .cuda_compute_capability(); - bool expect_custom_kernel_fusion_rewriter_has_run = - cc.major == se::CudaComputeCapability::VOLTA; + if (cc.major != se::CudaComputeCapability::VOLTA) { + GTEST_SKIP(); + } constexpr absl::string_view constant_module = R"( HloModule noop @@ -1036,8 +1038,7 @@ ENTRY main { pass_metadata.pass_name() == "custom-kernel-fusion-rewriter"; } - EXPECT_EQ(custom_kernel_fusion_rewriter_has_run, - expect_custom_kernel_fusion_rewriter_has_run); + EXPECT_EQ(custom_kernel_fusion_rewriter_has_run, true); } } // namespace diff --git a/xla/service/gpu/transforms/custom_kernel_fusion_rewriter.cc b/xla/service/gpu/transforms/custom_kernel_fusion_rewriter.cc index af9591a16c67a..c0b77ffaedb50 100644 --- a/xla/service/gpu/transforms/custom_kernel_fusion_rewriter.cc +++ b/xla/service/gpu/transforms/custom_kernel_fusion_rewriter.cc @@ -186,6 +186,13 @@ absl::StatusOr CustomKernelFusionRewriter::Run( // Collect all potential custom fusion matches in the module. 
for (HloComputation* computation : module->computations()) { + if (computation->FusionInstruction() != nullptr && + computation->FusionInstruction()->fusion_kind() == + HloInstruction::FusionKind::kCustom) { + // Skip computations based on a custom fusion to avoid recursive fusion. + return false; + } + for (HloInstruction* instr : computation->instructions()) { auto matched = patterns_->Match(*device_, instr); matches.insert(matches.end(), matched.begin(), matched.end());
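// The guard added above makes CustomKernelFusionRewriter::Run() return early
// (reporting no change) once the module already contains the body of a kCustom
// fusion, so fusions produced by an earlier rewrite or by the autotuner are
// never matched again recursively. A self-contained sketch of the same check
// as a standalone predicate (the helper name is illustrative only):
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"

bool ModuleAlreadyContainsCustomFusion(xla::HloModule* module) {
  for (xla::HloComputation* computation : module->computations()) {
    xla::HloInstruction* fusion = computation->FusionInstruction();
    if (fusion != nullptr &&
        fusion->fusion_kind() == xla::HloInstruction::FusionKind::kCustom) {
      return true;
    }
  }
  return false;
}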