diff --git a/xla/autotuning.proto b/xla/autotuning.proto index 4cadf6dbb250e..a7ffcbb57ae6e 100644 --- a/xla/autotuning.proto +++ b/xla/autotuning.proto @@ -83,10 +83,6 @@ message AutotuneResult { int64 num_ctas = 7; } - message CustomKernelFusionKey { - int64 kernel_index = 1; - } - int64 scratch_bytes = 8; google.protobuf.Duration run_time = 9; @@ -97,11 +93,10 @@ message AutotuneResult { GemmKey gemm = 6; TritonGemmKey triton = 17; CudaConvPlanKey cuda_conv_plan = 15; - CustomKernelFusionKey custom_kernel_fusion = 18; stream_executor.dnn.AlgorithmProto algorithm = 16; } - // Next ID: 19 + // Next ID: 17 } message AutotuningLog { diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index fbe509210668e..151417710d85c 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -1627,7 +1627,6 @@ xla_test( "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "//xla/tests:literal_test_util", - "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/log", diff --git a/xla/service/gpu/autotuning/BUILD b/xla/service/gpu/autotuning/BUILD index cae9e277097f9..310b957aca83a 100644 --- a/xla/service/gpu/autotuning/BUILD +++ b/xla/service/gpu/autotuning/BUILD @@ -58,15 +58,12 @@ cc_library( "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:buffer_comparator", "//xla/service/gpu:gpu_float_support", + "//xla/service/gpu:hlo_traversal", "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu:matmul_utils", "//xla/service/gpu:split_k_gemm_rewriter", "//xla/service/gpu:stream_executor_util", - "//xla/service/gpu/kernels:custom_kernel", - "//xla/service/gpu/kernels:custom_kernel_fusion", - "//xla/service/gpu/kernels:custom_kernel_fusion_pattern", "//xla/service/gpu/transforms:cudnn_fusion_compiler", - "//xla/service/gpu/transforms:custom_kernel_fusion_rewriter", "//xla/service/gpu/transforms:fusion_wrapper", "//xla/service/gpu/transforms:gemm_rewriter", "//xla/service/gpu/transforms:priority_fusion", @@ -75,9 +72,11 @@ cc_library( "//xla/stream_executor:device_memory", "//xla/stream_executor:semantic_version", "//xla/stream_executor:stream_executor_memory_allocator", + "//xla/stream_executor/gpu:redzone_allocator", "//xla/tools:hlo_decomposer_lib", "//xla/tsl/lib/core:bits", "//xla/tsl/util/proto:proto_utils", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -138,8 +137,6 @@ xla_test( "//xla/stream_executor:device_description", "//xla/stream_executor:device_description_proto_cc", "//xla/stream_executor:semantic_version", - "//xla/stream_executor:stream_executor_h", - "//xla/stream_executor/gpu:gpu_executor_header", "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "//xla/tests:test_utils", diff --git a/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.cc b/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.cc index 4a02c599cb987..164252eb83312 100644 --- a/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.cc +++ b/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.cc @@ -192,7 +192,7 @@ absl::StatusOr AutotuneCustomKernelFusion( return previous_kernel_index != fastest_kernel_index; } -bool IsCutlassCustomFusion(const HloComputation* computation) { +bool IsCustomFusion(const HloComputation* computation) { if (!computation->IsFusionComputation()) { return false; } @@ -212,18 +212,8 @@ bool IsCutlassCustomFusion(const HloComputation* computation) { return false; } - if (gpu_backend_config->fusion_backend_config().kind() != kCustomFusionKind) { - return false; - } - - if (gpu_backend_config->fusion_backend_config() - .custom_fusion_config() - .name() - .rfind("cutlass", 0) != 0) { - return false; - } - - return true; + return gpu_backend_config->fusion_backend_config().kind() == + kCustomFusionKind; } } // namespace @@ -241,7 +231,7 @@ absl::StatusOr CustomKernelFusionAutotuner::Run( bool hlo_changed = false; for (const HloComputation* computation : module->computations()) { - if (IsCutlassCustomFusion(computation)) { + if (IsCustomFusion(computation)) { TF_ASSIGN_OR_RETURN( bool instruction_changed, AutotuneCustomKernelFusion(computation->FusionInstruction(), config_, diff --git a/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc b/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc index a0d98975bf440..c12c45a4abfd8 100644 --- a/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc +++ b/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -26,6 +27,7 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" @@ -54,21 +56,19 @@ limitations under the License. #include "xla/primitive_util.h" #include "xla/service/algorithm_util.h" #include "xla/service/dump.h" +#include "xla/service/executable.h" #include "xla/service/float_normalization.h" #include "xla/service/gpu/autotuning/autotuner_compile_util.h" #include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/buffer_comparator.h" #include "xla/service/gpu/gpu_float_support.h" +#include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/ir_emission_utils.h" -#include "xla/service/gpu/kernels/custom_kernel.h" -#include "xla/service/gpu/kernels/custom_kernel_fusion.h" -#include "xla/service/gpu/kernels/custom_kernel_fusion_pattern.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/split_k_gemm_rewriter.h" #include "xla/service/gpu/stream_executor_util.h" #include "xla/service/gpu/transforms/cudnn_fusion_compiler.h" -#include "xla/service/gpu/transforms/custom_kernel_fusion_rewriter.h" #include "xla/service/gpu/transforms/fusion_wrapper.h" #include "xla/service/gpu/transforms/gemm_rewriter.h" #include "xla/service/gpu/transforms/priority_fusion.h" @@ -82,6 +82,7 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/gpu/redzone_allocator.h" #include "xla/stream_executor/semantic_version.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" @@ -139,6 +140,76 @@ constexpr std::array kNumCtas = {1, 2, 4, 8, 16}; using AutoTuneCacheKeyCount = absl::flat_hash_map; +class GemmFusionAutotunerVisitor : public DfsHloRewriteVisitor { + public: + explicit GemmFusionAutotunerVisitor(const AutotuneConfig& config) + : config_(config) {} + + absl::Status HandleFusion(HloInstruction* hlo) override { + TF_ASSIGN_OR_RETURN(auto gpu_config, + hlo->backend_config()); + FusionBackendConfig& backend_config = + *gpu_config.mutable_fusion_backend_config(); + if (backend_config.kind() != kTritonGemmFusionKind && + backend_config.kind() != kCuDnnFusionKind) { + return absl::OkStatus(); + } + + VLOG(4) << "Processing " << hlo->ToString(); + if (!backend_config.has_triton_gemm_config() && + !backend_config.has_cudnn_fusion_config()) { + TF_ASSIGN_OR_RETURN( + AutotuneResult autotune_result, + AutotunerUtil::Autotune( + hlo, config_, [&]() -> absl::StatusOr { + if (config_.IsDeviceless()) { + return absl::InternalError(absl::StrCat( + "Expect autotune result cache hit for deviceless " + "compilation (HLO: ", + hlo->ToString(), ")")); + } + return absl::InternalError("Expect autotune result cache hit."); + })); + VLOG(4) << "Result: " << autotune_result.ShortDebugString(); + + if (autotune_result.has_triton()) { + *backend_config.mutable_triton_gemm_config() = autotune_result.triton(); + TF_RETURN_IF_ERROR(hlo->set_backend_config(gpu_config)); + } else if (autotune_result.has_gemm()) { + // Falling back to cuBLAS: Converting the fusion to a Call, so that it + // can be inlined back again. + HloComputation* const computation = hlo->parent(); + HloInstruction* const call = computation->AddInstruction( + HloInstruction::CreateCall(hlo->shape(), hlo->operands(), + hlo->fused_instructions_computation())); + TF_RETURN_IF_ERROR(computation->ReplaceInstruction(hlo, call)); + hlo = call; + } else { + CHECK(autotune_result.has_algorithm()); + backend_config.set_kind(std::string(kCuDnnFusionKind)); + backend_config.mutable_cudnn_fusion_config()->set_plan_id( + autotune_result.algorithm().algo_id()); + TF_RETURN_IF_ERROR(hlo->set_backend_config(gpu_config)); + } + } + + if (backend_config.has_triton_gemm_config()) { + TF_ASSIGN_OR_RETURN( + const TritonGemmConfig config, + TritonGemmConfig::FromProto(backend_config.triton_gemm_config())); + if (config.split_k > 1) { + TF_RETURN_IF_ERROR(MakeDotSplitKBatch(hlo, config)); + } + } + + MarkAsChanged(); + return absl::OkStatus(); + } + + private: + AutotuneConfig config_; +}; + class GemmConfigSetCollector : public ConstDfsHloVisitorWithDefault { public: explicit GemmConfigSetCollector(GemmFusionAutotunerImpl* impl) @@ -188,9 +259,7 @@ class GemmConfigSetCollector : public ConstDfsHloVisitorWithDefault { bool missing_config = (backend_config.kind() == kTritonGemmFusionKind && !backend_config.has_triton_gemm_config()) || (backend_config.kind() == kCuDnnFusionKind && - !backend_config.has_cudnn_fusion_config()) || - (backend_config.kind() == kCustomFusionKind && - !backend_config.has_custom_fusion_config()); + !backend_config.has_cudnn_fusion_config()); if (missing_config) { if (error_out_on_cache_miss_) { return absl::NotFoundError(absl::StrCat( @@ -357,46 +426,6 @@ absl::StatusOr> CublasGemmAutotuneExtractor( return new_module; } -absl::Status UpdateFusionInstructionKernelIndex( - HloInstruction* fusion_instruction, int kernel_index) { - GpuBackendConfig gpu_config = - fusion_instruction->backend_config().value(); - gpu_config.mutable_fusion_backend_config() - ->mutable_custom_fusion_config() - ->set_kernel_index(kernel_index); - TF_RETURN_IF_ERROR(fusion_instruction->set_backend_config(gpu_config)); - - return absl::OkStatus(); -} - -absl::StatusOr> CutlassGemmAutotuneExtractor( - const GemmFusionAutotunerImpl::CustomKernelFusionConfig& cutlass_config, - const AutotuneConfig& config, const se::SemanticVersion& toolkit_version, - const HloFusionInstruction* fusion, const DebugOptions& debug_opts) { - const HloComputation* fusion_computation = fusion->called_computation(); - std::unique_ptr new_module = - ExtractComputationIntoNewModule(*fusion_computation); - new_module->mutable_config().set_debug_options(debug_opts); - - CustomKernelFusionRewriter rewriter( - &config.GetExecutor()->GetDeviceDescription()); - PriorityFusion fusion_pass( - /*thread_pool=*/nullptr, config.GetExecutor()->GetDeviceDescription(), - PriorityFusionOptions()); - TF_RETURN_IF_ERROR(rewriter.Run(new_module.get()).status()); - TF_RETURN_IF_ERROR(fusion_pass.Run(new_module.get()).status()); - - // Select custom kernel fusion kernel. - HloInstruction* custom_kernel_fusion = - hlo_query::GetFirstInstructionWithOpcode(*new_module->entry_computation(), - HloOpcode::kFusion); - int64_t kernel_index = cutlass_config.kernel_index; - TF_RETURN_IF_ERROR( - UpdateFusionInstructionKernelIndex(custom_kernel_fusion, kernel_index)); - - return new_module; -} - absl::StatusOr> FusionExtractor( const HloFusionInstruction& fusion, const DebugOptions& debug_opts) { std::unique_ptr module = ExtractInstructionIntoNewModule(fusion); @@ -445,11 +474,6 @@ AutotuneResult FromConfig(const BackendConfig& config) { AutotuneResult res; if (std::holds_alternative(config)) { res.mutable_gemm()->set_algorithm(CUBLAS_GEMM_DEFAULT); - } else if (std::holds_alternative< - GemmFusionAutotunerImpl::CustomKernelFusionConfig>(config)) { - res.mutable_custom_kernel_fusion()->set_kernel_index( - std::get(config) - .kernel_index); } else if (std::holds_alternative( config)) { res.mutable_algorithm()->set_algo_id( @@ -549,75 +573,6 @@ std::string Serialize(const BackendConfig& config) { } // anonymous namespace -absl::Status GemmFusionAutotunerVisitor::HandleFusion(HloInstruction* hlo) { - TF_ASSIGN_OR_RETURN(auto gpu_config, hlo->backend_config()); - FusionBackendConfig& backend_config = - *gpu_config.mutable_fusion_backend_config(); - if (backend_config.kind() != kTritonGemmFusionKind && - backend_config.kind() != kCuDnnFusionKind && - backend_config.kind() != kCustomFusionKind) { - return absl::OkStatus(); - } - - VLOG(4) << "Processing " << hlo->ToString(); - if (!backend_config.has_triton_gemm_config() && - !backend_config.has_cudnn_fusion_config() && - !backend_config.has_custom_fusion_config()) { - TF_ASSIGN_OR_RETURN( - AutotuneResult autotune_result, - AutotunerUtil::Autotune( - hlo, config_, [&]() -> absl::StatusOr { - if (config_.IsDeviceless()) { - return absl::InternalError(absl::StrCat( - "Expect autotune result cache hit for deviceless " - "compilation (HLO: ", - hlo->ToString(), ")")); - } - return absl::InternalError("Expect autotune result cache hit."); - })); - VLOG(4) << "Result: " << autotune_result.ShortDebugString(); - - if (autotune_result.has_triton()) { - *backend_config.mutable_triton_gemm_config() = autotune_result.triton(); - TF_RETURN_IF_ERROR(hlo->set_backend_config(gpu_config)); - } else if (autotune_result.has_gemm()) { - // Falling back to cuBLAS: Converting the fusion to a Call, so that it - // can be inlined back again. - HloComputation* const computation = hlo->parent(); - HloInstruction* const call = computation->AddInstruction( - HloInstruction::CreateCall(hlo->shape(), hlo->operands(), - hlo->fused_instructions_computation())); - TF_RETURN_IF_ERROR(computation->ReplaceInstruction(hlo, call)); - hlo = call; - } else if (autotune_result.has_custom_kernel_fusion()) { - HloComputation* const computation = hlo->parent(); - HloInstruction* const call = computation->AddInstruction( - HloInstruction::CreateCall(hlo->shape(), hlo->operands(), - hlo->fused_instructions_computation())); - TF_RETURN_IF_ERROR(computation->ReplaceInstruction(hlo, call)); - hlo = call; - } else { - CHECK(autotune_result.has_algorithm()); - backend_config.set_kind(std::string(kCuDnnFusionKind)); - backend_config.mutable_cudnn_fusion_config()->set_plan_id( - autotune_result.algorithm().algo_id()); - TF_RETURN_IF_ERROR(hlo->set_backend_config(gpu_config)); - } - } - - if (backend_config.has_triton_gemm_config()) { - TF_ASSIGN_OR_RETURN( - const TritonGemmConfig config, - TritonGemmConfig::FromProto(backend_config.triton_gemm_config())); - if (config.split_k > 1) { - TF_RETURN_IF_ERROR(MakeDotSplitKBatch(hlo, config)); - } - } - - MarkAsChanged(); - return absl::OkStatus(); -} - // Methods required for sorting the configs. bool GemmFusionAutotunerImpl::CuBlasConfig::operator<( const CuBlasConfig& other) const { @@ -627,10 +582,6 @@ bool GemmFusionAutotunerImpl::CuDnnConfig::operator<( const CuDnnConfig& other) const { return plan_id < other.plan_id; } -bool GemmFusionAutotunerImpl::CustomKernelFusionConfig::operator<( - const CustomKernelFusionConfig& other) const { - return false; -} bool GemmFusionAutotunerImpl::IsAutotuningEnabled() const { return debug_options_.xla_gpu_autotune_level() > 0 && @@ -688,43 +639,6 @@ GemmFusionAutotunerImpl::GenerateConfigs(const HloFusionInstruction& fusion) { return configs; } - // Add CustomKernelFusion (Cutlass) configs, if available. - // Go through all the instructions in the fusion body try to match them to - // a custom kernel fusion pattern. - if ((IsFusionKind(fusion, kCustomFusionKind) || - IsFusionKind(fusion, kTritonGemmFusionKind)) && - IsAutotuningEnabled() && !config_.IsDeviceless()) { - const CustomKernelFusionPatternRegistry* patterns = - CustomKernelFusionPatternRegistry::Default(); - HloComputation* computation = fusion.called_computation(); - // Get the first dot instruction in the fusion body. - HloInstruction* dot_instruction = - hlo_query::GetFirstInstructionWithOpcode(*computation, HloOpcode::kDot); - std::vector match = patterns->Match( - config_.GetExecutor()->GetDeviceDescription(), dot_instruction); - - // For Cutlass we expect only one match for a gemm fusion. - if (match.size() == 1) { - CustomKernelFusionRegistry* registry = - CustomKernelFusionRegistry::Default(); - auto* custom_kernel_fusion = registry->Lookup(match[0].config().name()); - - // If custom fusion is not found it means that some of the build targets - // might not be statically linked into the binary. - if (custom_kernel_fusion != nullptr) { - // Load custom kernels that can implement a fusion computation. - TF_ASSIGN_OR_RETURN(std::vector kernels, - custom_kernel_fusion->LoadKernels( - config_.GetExecutor()->GetDeviceDescription(), - fusion.fused_instructions_computation())); - for (int i = 0; i < kernels.size(); ++i) { - CustomKernelFusionConfig config{/*kernel_index=*/i}; - configs.push_back(config); - } - } - } - } - // Add triton configs. TF_ASSIGN_OR_RETURN(std::vector triton_configs, GenerateTritonConfigs(*dot)); @@ -888,15 +802,6 @@ GemmFusionAutotunerImpl::CompileAll(AutotunerCompileUtil& compile_util, config_, config_.GetExecutor()->GetDeviceDescription(), toolkit_version_, fusion, opts); })); - } else if (std::holds_alternative(config)) { - CustomKernelFusionConfig cutlass_config = - std::get(config); - TF_ASSIGN_OR_RETURN( - executable, compile_util.Compile([&](const DebugOptions& opts) { - return CutlassGemmAutotuneExtractor(cutlass_config, config_, - toolkit_version_, fusion, opts); - })); - } else { LOG(FATAL) << "Unsupported config type: " << config.index(); } diff --git a/xla/service/gpu/autotuning/gemm_fusion_autotuner.h b/xla/service/gpu/autotuning/gemm_fusion_autotuner.h index b49eefb0fabd8..7c262ffc8c613 100644 --- a/xla/service/gpu/autotuning/gemm_fusion_autotuner.h +++ b/xla/service/gpu/autotuning/gemm_fusion_autotuner.h @@ -29,9 +29,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/autotuning.pb.h" -#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" @@ -48,17 +46,6 @@ limitations under the License. namespace xla { namespace gpu { -class GemmFusionAutotunerVisitor : public DfsHloRewriteVisitor { - public: - explicit GemmFusionAutotunerVisitor(const AutotuneConfig& config) - : config_(config) {} - - absl::Status HandleFusion(HloInstruction* hlo) override; - - private: - AutotuneConfig config_; -}; - // Takes a gemm fusion and chooses between cuBLAS, cuDNN, and Triton backends. // In the case of Triton, it also chooses the best tiling configuration. // @@ -112,13 +99,8 @@ class GemmFusionAutotunerImpl { int64_t plan_id; bool operator<(const CuDnnConfig& other) const; }; - struct CustomKernelFusionConfig { - int64_t kernel_index; - bool operator<(const CustomKernelFusionConfig& other) const; - }; using BackendConfig = - std::variant; + std::variant; using BackendConfigs = std::vector< std::pair>>; diff --git a/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc b/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc index 4a57c241f6336..92a47f4313e55 100644 --- a/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc +++ b/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc @@ -50,9 +50,7 @@ limitations under the License. #include "xla/service/pattern_matcher_gmock.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_description.pb.h" -#include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/semantic_version.h" -#include "xla/stream_executor/stream_executor.h" #include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_utils.h" @@ -197,25 +195,6 @@ class GemmFusionAutotunerTest : public StatelessAutotunerTest { .cuda_compute_capability(); } - absl::StatusOr> - GetPossibleMatmulAutotuneConfigs( - const HloFusionInstruction& fusion, - const se::CudaComputeCapability& compute_capability, - const se::SemanticVersion& toolkit_version, - const DebugOptions& debug_options) { - se::GpuDeviceInfoProto deviceless_proto; - auto ccc = deviceless_proto.mutable_cuda_compute_capability(); - ccc->set_major(compute_capability.major); - ccc->set_minor(compute_capability.minor); - - DeviceConfig test_config{backend().default_stream_executor(), - backend().memory_allocator()}; - AutotuneConfig autotune_config{test_config, debug_options}; - GemmFusionAutotunerImpl autotuner(autotune_config, toolkit_version, - debug_options, nullptr); - return autotuner.GenerateConfigs(fusion); - } - void CheckTritonAutotuning(absl::string_view hlo, absl::string_view expected) { HloPassPipeline pipeline("gemm_rewrite"); @@ -268,8 +247,7 @@ class GemmFusionAutotunerTestWithMorePreciseReduction } }; -absl::StatusOr> -GetPossibleMatmulAutotuneTritonConfigs( +absl::StatusOr> GetPossibleMatmulAutotuneConfigs( const HloDotInstruction& dot, const se::CudaComputeCapability& compute_capability, const se::SemanticVersion& toolkit_version, @@ -298,7 +276,7 @@ ENTRY e { se::CudaComputeCapability::AMPERE, /*minor=*/0}; TF_ASSERT_OK_AND_ASSIGN( const std::vector configs, - GetPossibleMatmulAutotuneTritonConfigs( + GetPossibleMatmulAutotuneConfigs( *Cast( module->entry_computation()->root_instruction()), compute_capability, GetToolkitVersion(), GetDebugOptionsForTest())); @@ -320,7 +298,7 @@ ENTRY e { se::CudaComputeCapability::AMPERE, /*minor=*/0}; TF_ASSERT_OK_AND_ASSIGN( const std::vector configs, - GetPossibleMatmulAutotuneTritonConfigs( + GetPossibleMatmulAutotuneConfigs( *Cast( module->entry_computation()->root_instruction()), compute_capability, GetToolkitVersion(), GetDebugOptionsForTest())); @@ -342,7 +320,7 @@ ENTRY e { se::CudaComputeCapability::AMPERE, /*minor=*/0}; TF_ASSERT_OK_AND_ASSIGN( const std::vector configs, - GetPossibleMatmulAutotuneTritonConfigs( + GetPossibleMatmulAutotuneConfigs( *Cast( module->entry_computation()->root_instruction()), compute_capability, GetToolkitVersion(), GetDebugOptionsForTest())); @@ -896,7 +874,7 @@ ENTRY e { se::CudaComputeCapability::AMPERE, /*minor=*/0}; TF_ASSERT_OK_AND_ASSIGN( const std::vector configs, - GetPossibleMatmulAutotuneTritonConfigs( + GetPossibleMatmulAutotuneConfigs( *Cast( module->entry_computation()->root_instruction()), compute_capability, GetToolkitVersion(), GetDebugOptionsForTest())); @@ -928,7 +906,7 @@ ENTRY e { se::CudaComputeCapability::AMPERE, /*minor=*/0}; TF_ASSERT_OK_AND_ASSIGN( const std::vector configs, - GetPossibleMatmulAutotuneTritonConfigs( + GetPossibleMatmulAutotuneConfigs( *Cast( module->entry_computation()->root_instruction()), compute_capability, GetToolkitVersion(), GetDebugOptionsForTest())); @@ -959,7 +937,7 @@ ENTRY wais { TF_ASSERT_OK_AND_ASSIGN( const std::vector configs, - GetPossibleMatmulAutotuneTritonConfigs( + GetPossibleMatmulAutotuneConfigs( *Cast( module->entry_computation()->root_instruction()), compute_capability, GetToolkitVersion(), debug_options)); @@ -1023,100 +1001,6 @@ ENTRY entry { CHECK_OK(autotuner.CompileAll(*compile_util, configs)); } -TEST_F(GemmFusionAutotunerTest, CreatesCustomKernelFusionConfigs) { - const std::string kHlo = R"( - HloModule module, entry_computation_layout={(bf16[1024,1024]{1,0}, bf16[1024,1024]{1,0})->f32[1024,1024]{1,0}} - - %gemm_fusion_r_computation { - %parameter_0 = bf16[1024,1024]{1,0} parameter(0) - %convert.2 = f32[1024,1024]{1,0} convert(%parameter_0) - %parameter_1 = bf16[1024,1024]{1,0} parameter(1) - %convert.3 = f32[1024,1024]{1,0} convert(%parameter_1) - ROOT %r.1 = f32[1024,1024]{1,0} dot(%convert.2, %convert.3), lhs_contracting_dims={1}, rhs_contracting_dims={0} - } - - ENTRY main { - %p0 = bf16[1024,1024]{1,0} parameter(0) - %p1 = bf16[1024,1024]{1,0} parameter(1) - ROOT %gemm_fusion_r = f32[1024,1024]{1,0} fusion(%p0, %p1), kind=kCustom, calls=gemm_fusion_r_computation, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm"},"force_earliest_schedule":false} - })"; - - std::unique_ptr module = - ParseAndReturnVerifiedModule(kHlo).value(); - const se::CudaComputeCapability compute_capability{ - se::CudaComputeCapability::AMPERE, /*minor=*/0}; - - TF_ASSERT_OK_AND_ASSIGN( - const std::vector configs, - GetPossibleMatmulAutotuneConfigs( - *Cast( - module->entry_computation()->root_instruction()), - compute_capability, GetToolkitVersion(), GetDebugOptionsForTest())); - EXPECT_TRUE(std::any_of( - configs.begin(), configs.end(), - [](const GemmFusionAutotunerImpl::BackendConfig& config) { - return std::holds_alternative< - GemmFusionAutotunerImpl::CustomKernelFusionConfig>(config); - })); -} - -TEST_F(GemmFusionAutotunerTest, - InlinesTritonFusionIfCustomKernelFusionIsMorePerformant) { - const std::string kHlo = R"( - HloModule module, entry_computation_layout={(bf16[1024,1024]{1,0}, bf16[1024,1024]{1,0})->f32[1024,1024]{1,0}} - - %gemm_fusion_r_computation { - %parameter_0 = bf16[1024,1024]{1,0} parameter(0) - %convert.2 = f32[1024,1024]{1,0} convert(%parameter_0) - %parameter_1 = bf16[1024,1024]{1,0} parameter(1) - %convert.3 = f32[1024,1024]{1,0} convert(%parameter_1) - ROOT %r.1 = f32[1024,1024]{1,0} dot(%convert.2, %convert.3), lhs_contracting_dims={1}, rhs_contracting_dims={0} - } - - ENTRY main { - %p0 = bf16[1024,1024]{1,0} parameter(0) - %p1 = bf16[1024,1024]{1,0} parameter(1) - ROOT %gemm_fusion_r = f32[1024,1024]{1,0} fusion(%p0, %p1), kind=kCustom, calls=gemm_fusion_r_computation, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm"},"force_earliest_schedule":false} - } -)"; - - std::unique_ptr module = - ParseAndReturnVerifiedModule(kHlo).value(); - - DebugOptions opts; - AutotuneConfig autotune_config{ - DeviceConfig{backend().default_stream_executor(), - backend().memory_allocator()}, - opts}; - AutotuneCacheKey cache_key(autotune_config.GetModelStr(), - *module->entry_computation()->root_instruction()); - TF_ASSERT_OK_AND_ASSIGN(AutotuneResults autotune_results_override, - ParseTextProto(R"pb( - version: 3 - results { - device: "..." - hlo: "..." - result { - custom_kernel_fusion { kernel_index: 0 } - run_time { nanos: 14 } - } - })pb")); - autotune_results_override.mutable_results(0)->set_device( - std::string(cache_key.GetModelStr())); - autotune_results_override.mutable_results(0)->set_hlo( - std::string(cache_key.GetHlo())); - - GemmFusionAutotunerVisitor visitor(autotune_config); - - CHECK_OK(AutotunerUtil::LoadAutotuneResults(autotune_results_override)); - visitor.RunOnModule(module.get(), {}).value(); - std::string pattern = R"( - CHECK-NOT: ROOT %gemm_fusion_r = f32[1024,1024]{1,0} fusion - CHECK: ROOT %call = f32[1024,1024]{1,0} call - )"; - TF_ASSERT_OK(RunFileCheck(module->ToString(), pattern)); -} - } // namespace } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc index 0e2d5e477ba1d..6cb0d4d37ecd1 100644 --- a/xla/service/gpu/gpu_compiler.cc +++ b/xla/service/gpu/gpu_compiler.cc @@ -1539,17 +1539,8 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( &pipeline, hlo_module, autotune_config, thread_pool, options.key_value_store, gpu_target_config.device_description.runtime_version())); - // Inline back the calls which have better performance with cuBLAS or Custom - // Kernel Fusion. + // Inline back the calls which have better performance with cuBLAS. pipeline.AddPass(); - - // Greedily pattern match and replace with Custom Kernel Fusions (e.g. - // Cutlass kernels with upcasts). - pipeline.AddPass(); - pipeline.AddPass( - &gpu_target_config.device_description); - pipeline.AddPass(autotune_config); - // TODO(tdanyluk): Apply CublasPadForGemms to the cuBLAS GEMMs generated // here for possibly better cuBLAS performance. AddGemmRewriterPasses(pipeline, debug_options, gpu_version, diff --git a/xla/service/gpu/gpu_compiler_test.cc b/xla/service/gpu/gpu_compiler_test.cc index 2f423fd6d6e4e..51b459e8a81a0 100644 --- a/xla/service/gpu/gpu_compiler_test.cc +++ b/xla/service/gpu/gpu_compiler_test.cc @@ -55,7 +55,6 @@ limitations under the License. #include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" -#include "xla/tests/verified_hlo_module.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/casts.h" @@ -1014,9 +1013,8 @@ TEST_F(GpuCompilerPassTest, ->GetDeviceDescription() .cuda_compute_capability(); - if (cc.major != se::CudaComputeCapability::VOLTA) { - GTEST_SKIP(); - } + bool expect_custom_kernel_fusion_rewriter_has_run = + cc.major == se::CudaComputeCapability::VOLTA; constexpr absl::string_view constant_module = R"( HloModule noop @@ -1038,7 +1036,8 @@ ENTRY main { pass_metadata.pass_name() == "custom-kernel-fusion-rewriter"; } - EXPECT_EQ(custom_kernel_fusion_rewriter_has_run, true); + EXPECT_EQ(custom_kernel_fusion_rewriter_has_run, + expect_custom_kernel_fusion_rewriter_has_run); } } // namespace diff --git a/xla/service/gpu/transforms/custom_kernel_fusion_rewriter.cc b/xla/service/gpu/transforms/custom_kernel_fusion_rewriter.cc index c0b77ffaedb50..af9591a16c67a 100644 --- a/xla/service/gpu/transforms/custom_kernel_fusion_rewriter.cc +++ b/xla/service/gpu/transforms/custom_kernel_fusion_rewriter.cc @@ -186,13 +186,6 @@ absl::StatusOr CustomKernelFusionRewriter::Run( // Collect all potential custom fusion matches in the module. for (HloComputation* computation : module->computations()) { - if (computation->FusionInstruction() != nullptr && - computation->FusionInstruction()->fusion_kind() == - HloInstruction::FusionKind::kCustom) { - // Skip computations based on a custom fusion to avoid recursive fusion. - return false; - } - for (HloInstruction* instr : computation->instructions()) { auto matched = patterns_->Match(*device_, instr); matches.insert(matches.end(), matched.begin(), matched.end());