Disable more binary libraries if the disable flag is true.
PiperOrigin-RevId: 675686593
klucke authored and Google-ML-Automation committed Sep 17, 2024
1 parent 5128760 commit 89f818c
Showing 2 changed files with 33 additions and 26 deletions.
54 changes: 28 additions & 26 deletions xla/service/gpu/autotuning/gemm_fusion_autotuner.cc
@@ -607,36 +607,38 @@ GemmFusionAutotunerImpl::GenerateConfigs(const HloFusionInstruction& fusion) {
   const HloDotInstruction* dot =
       Cast<HloDotInstruction>(hlo_query::GetFirstInstructionWithOpcode(
           *fusion.called_computations().at(0), HloOpcode::kDot));
 
-  // Add cuBLAS reference config, if available.
   std::vector<BackendConfig> configs;
-  if (algorithm_util::IsSupportedByCublasOrCublasLt(
-          dot->precision_config().algorithm()) &&
-      !dot->sparse_operands() && IsAutotuningEnabled()) {
-    configs.push_back(CuBlasConfig{});
-  }
-
-  // Add cuDNN plans, if available.
-  bool is_hopper =
-      !config_.IsDeviceless() && GetComputeCapability().IsAtLeastHopper();
-  bool is_cudnn_enabled =
-      debug_options_.xla_gpu_cudnn_gemm_fusion_level() > 0 && is_hopper &&
-      GetDnnVersionInfoOrDefault(config_.GetExecutor()).major_version() >= 9;
-  if ((IsFusionKind(fusion, kCuDnnFusionKind) && IsAutotuningEnabled()) ||
-      (IsFusionKind(fusion, kTritonGemmFusionKind) && is_cudnn_enabled &&
-       algorithm_util::IsSupportedByCudnn(
-           dot->precision_config().algorithm()) &&
-       !dot->sparse_operands() && IsAutotuningEnabled())) {
-    const int plan_count = GetCuDnnPlanCount(fusion, config_);
-    for (int plan_id = 0; plan_id < plan_count; ++plan_id) {
-      configs.push_back(CuDnnConfig{plan_id});
-    }
-  }
-  if (IsFusionKind(fusion, kCuDnnFusionKind)) {
-    if (!IsAutotuningEnabled()) {
-      configs.push_back(CuDnnConfig{-1});
-    }
-    return configs;
-  }
+  if (!debug_options_.xla_gpu_experimental_disable_binary_libraries()) {
+    // Add cuBLAS reference config, if available.
+    if (algorithm_util::IsSupportedByCublasOrCublasLt(
+            dot->precision_config().algorithm()) &&
+        !dot->sparse_operands() && IsAutotuningEnabled()) {
+      configs.push_back(CuBlasConfig{});
+    }
+
+    // Add cuDNN plans, if available.
+    bool is_hopper =
+        !config_.IsDeviceless() && GetComputeCapability().IsAtLeastHopper();
+    bool is_cudnn_enabled =
+        debug_options_.xla_gpu_cudnn_gemm_fusion_level() > 0 && is_hopper &&
+        GetDnnVersionInfoOrDefault(config_.GetExecutor()).major_version() >= 9;
+    if ((IsFusionKind(fusion, kCuDnnFusionKind) && IsAutotuningEnabled()) ||
+        (IsFusionKind(fusion, kTritonGemmFusionKind) && is_cudnn_enabled &&
+         algorithm_util::IsSupportedByCudnn(
+             dot->precision_config().algorithm()) &&
+         !dot->sparse_operands() && IsAutotuningEnabled())) {
+      const int plan_count = GetCuDnnPlanCount(fusion, config_);
+      for (int plan_id = 0; plan_id < plan_count; ++plan_id) {
+        configs.push_back(CuDnnConfig{plan_id});
+      }
+    }
+    if (IsFusionKind(fusion, kCuDnnFusionKind)) {
+      if (!IsAutotuningEnabled()) {
+        configs.push_back(CuDnnConfig{-1});
+      }
+      return configs;
+    }
+  }
 
   // Add triton configs.
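
For orientation, here is a distilled, hypothetical sketch of the candidate selection that GenerateConfigs implements after this change. It is not XLA code; the real logic above additionally checks algorithm support, operand sparsity, device generation, and the cuDNN version:

#include <vector>

enum class Backend { kCuBlas, kCuDnn, kTriton };

// Simplified: with the disable flag set, neither cuBLAS nor cuDNN
// candidates are generated, and even kCuDnnFusionKind fusions fall
// through to Triton configs instead of returning early.
std::vector<Backend> CandidateBackends(bool disable_binary_libraries,
                                       bool is_cudnn_fusion) {
  std::vector<Backend> backends;
  if (!disable_binary_libraries) {
    backends.push_back(Backend::kCuBlas);
    backends.push_back(Backend::kCuDnn);
    if (is_cudnn_fusion) return backends;  // cuDNN fusions stop here.
  }
  backends.push_back(Backend::kTriton);
  return backends;
}
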
5 changes: 5 additions & 0 deletions xla/service/gpu/nvptx_compiler.cc
@@ -388,6 +388,11 @@ bool NVPTXCompiler::RequiresCollectiveScheduleLinearizer(
 absl::Status NVPTXCompiler::AddConvAndGemmAutotuningPasses(
     HloPassPipeline* pipeline, HloModule* hlo_module,
     AutotuneConfig& autotune_config, tsl::thread::ThreadPool* thread_pool) {
+  if (hlo_module->config()
+          .debug_options()
+          .xla_gpu_experimental_disable_binary_libraries()) {
+    return absl::OkStatus();
+  }
   if (GpuConvAlgorithmPicker::IsEnabled(hlo_module)) {
     pipeline->AddPass<GpuConvAlgorithmPicker>(autotune_config);
   }
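
A minimal sketch of how a client might opt in to this behavior, assuming the standard protobuf-generated setter for the DebugOptions field read above (the include path and setter name are inferred, not shown in this commit):

#include "xla/xla.pb.h"  // DebugOptions proto; include path assumed.

// Build debug options that disable cuBLAS/cuDNN-backed autotuning.
xla::DebugOptions DisableBinaryLibraries() {
  xla::DebugOptions opts;
  // Field checked by AddConvAndGemmAutotuningPasses() and GenerateConfigs().
  opts.set_xla_gpu_experimental_disable_binary_libraries(true);
  return opts;
}

Attached to a module's HloModuleConfig, this turns AddConvAndGemmAutotuningPasses into a no-op and restricts GEMM fusion autotuning to Triton candidates. If the field is also exposed as a command-line flag, as DebugOptions fields usually are, XLA_FLAGS=--xla_gpu_experimental_disable_binary_libraries=true would have the same effect; that mapping is assumed here, not verified.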
