From 42471262ab390def128e6cb7b461d79d0ab7af68 Mon Sep 17 00:00:00 2001
From: Kyle Lucke
Date: Tue, 17 Sep 2024 07:56:49 -0700
Subject: [PATCH] Disable more binary libraries if the disable flag is true.

PiperOrigin-RevId: 675567404
---
 .../gpu/autotuning/gemm_fusion_autotuner.cc   | 54 ++++++++++---------
 xla/service/gpu/nvptx_compiler.cc             |  5 ++
 2 files changed, 33 insertions(+), 26 deletions(-)

diff --git a/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc b/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc
index c12c45a4abfd82..0f8f06ebdf15d2 100644
--- a/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc
+++ b/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc
@@ -607,36 +607,38 @@ GemmFusionAutotunerImpl::GenerateConfigs(const HloFusionInstruction& fusion) {
   const HloDotInstruction* dot =
       Cast<HloDotInstruction>(hlo_query::GetFirstInstructionWithOpcode(
           *fusion.called_computations().at(0), HloOpcode::kDot));
 
-  // Add cuBLAS reference config, if available.
   std::vector<BackendConfig> configs;
-  if (algorithm_util::IsSupportedByCublasOrCublasLt(
-          dot->precision_config().algorithm()) &&
-      !dot->sparse_operands() && IsAutotuningEnabled()) {
-    configs.push_back(CuBlasConfig{});
-  }
-
-  // Add cuDNN plans, if available.
-  bool is_hopper =
-      !config_.IsDeviceless() && GetComputeCapability().IsAtLeastHopper();
-  bool is_cudnn_enabled =
-      debug_options_.xla_gpu_cudnn_gemm_fusion_level() > 0 && is_hopper &&
-      GetDnnVersionInfoOrDefault(config_.GetExecutor()).major_version() >= 9;
-  if ((IsFusionKind(fusion, kCuDnnFusionKind) && IsAutotuningEnabled()) ||
-      (IsFusionKind(fusion, kTritonGemmFusionKind) && is_cudnn_enabled &&
-       algorithm_util::IsSupportedByCudnn(
-           dot->precision_config().algorithm()) &&
-       !dot->sparse_operands() && IsAutotuningEnabled())) {
-    const int plan_count = GetCuDnnPlanCount(fusion, config_);
-    for (int plan_id = 0; plan_id < plan_count; ++plan_id) {
-      configs.push_back(CuDnnConfig{plan_id});
+  if (!debug_options_.xla_gpu_experimental_disable_binary_libraries()) {
+    // Add cuBLAS reference config, if available.
+    if (algorithm_util::IsSupportedByCublasOrCublasLt(
+            dot->precision_config().algorithm()) &&
+        !dot->sparse_operands() && IsAutotuningEnabled()) {
+      configs.push_back(CuBlasConfig{});
     }
-  }
-  if (IsFusionKind(fusion, kCuDnnFusionKind)) {
-    if (!IsAutotuningEnabled()) {
-      configs.push_back(CuDnnConfig{-1});
+
+    // Add cuDNN plans, if available.
+    bool is_hopper =
+        !config_.IsDeviceless() && GetComputeCapability().IsAtLeastHopper();
+    bool is_cudnn_enabled =
+        debug_options_.xla_gpu_cudnn_gemm_fusion_level() > 0 && is_hopper &&
+        GetDnnVersionInfoOrDefault(config_.GetExecutor()).major_version() >= 9;
+    if ((IsFusionKind(fusion, kCuDnnFusionKind) && IsAutotuningEnabled()) ||
+        (IsFusionKind(fusion, kTritonGemmFusionKind) && is_cudnn_enabled &&
+         algorithm_util::IsSupportedByCudnn(
+             dot->precision_config().algorithm()) &&
+         !dot->sparse_operands() && IsAutotuningEnabled())) {
+      const int plan_count = GetCuDnnPlanCount(fusion, config_);
+      for (int plan_id = 0; plan_id < plan_count; ++plan_id) {
+        configs.push_back(CuDnnConfig{plan_id});
+      }
+    }
+    if (IsFusionKind(fusion, kCuDnnFusionKind)) {
+      if (!IsAutotuningEnabled()) {
+        configs.push_back(CuDnnConfig{-1});
+      }
+      return configs;
     }
-    return configs;
   }
 
   // Add triton configs.
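[Review note, not part of the patch: with the new flag set, GenerateConfigs() now skips the cuBLAS and cuDNN candidates entirely, including the early "return configs" previously taken for kCuDnnFusionKind fusions, and falls through to the Triton configs below. A minimal sketch of the resulting control flow, with the candidate-generation details elided:

    std::vector<BackendConfig> configs;
    if (!debug_options_.xla_gpu_experimental_disable_binary_libraries()) {
      // cuBLAS reference config and cuDNN plans are added here; for
      // kCuDnnFusionKind fusions this branch may return early.
    }
    // Triton configs are generated below regardless of the flag.
]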
diff --git a/xla/service/gpu/nvptx_compiler.cc b/xla/service/gpu/nvptx_compiler.cc
index 17833f1f444f20..6a691cea969e25 100644
--- a/xla/service/gpu/nvptx_compiler.cc
+++ b/xla/service/gpu/nvptx_compiler.cc
@@ -388,6 +388,11 @@ bool NVPTXCompiler::RequiresCollectiveScheduleLinearizer(
 absl::Status NVPTXCompiler::AddConvAndGemmAutotuningPasses(
     HloPassPipeline* pipeline, HloModule* hlo_module,
     AutotuneConfig& autotune_config, tsl::thread::ThreadPool* thread_pool) {
+  if (hlo_module->config()
+          .debug_options()
+          .xla_gpu_experimental_disable_binary_libraries()) {
+    return absl::OkStatus();
+  }
   if (GpuConvAlgorithmPicker::IsEnabled(hlo_module)) {
     pipeline->AddPass<GpuConvAlgorithmPicker>(autotune_config);
   }
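[Review note, not part of the patch: the flag is an ordinary DebugOptions proto field, so a caller could set it per module roughly as below; this is a sketch assuming the usual proto-generated setter and HloModuleConfig plumbing, where module_config is an xla::HloModuleConfig:

    xla::DebugOptions opts = module_config.debug_options();
    opts.set_xla_gpu_experimental_disable_binary_libraries(true);
    module_config.set_debug_options(opts);

If the flag is also registered with XLA's environment-variable flag parser, setting XLA_FLAGS=--xla_gpu_experimental_disable_binary_libraries=true should have the same effect.]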