Skip to content

Commit

Permalink
Add a command line flag to disable XLA GPU passes based on binary lib…
Browse files Browse the repository at this point in the history
…raries.

Start using the new pass to optionally disable many cuDNN-specific passes.

PiperOrigin-RevId: 674371842
  • Loading branch information
klucke authored and Google-ML-Automation committed Sep 17, 2024
1 parent 8ace4ee commit 3e6c18b
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 22 deletions.
8 changes: 8 additions & 0 deletions xla/debug_options_flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {

opts.set_xla_gpu_executable_warn_stuck_timeout_seconds(10);
opts.set_xla_gpu_executable_terminate_timeout_seconds(30);
opts.set_xla_gpu_experimental_disable_binary_libraries(false);
return opts;
}

Expand Down Expand Up @@ -1936,6 +1937,13 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
&DebugOptions::set_xla_gpu_executable_terminate_timeout_seconds),
debug_options->xla_gpu_executable_terminate_timeout_seconds(),
"Set timeout for RendezvousSingle termination"));
flag_list->push_back(tsl::Flag(
"xla_gpu_experimental_disable_binary_libraries",
bool_setter_for(
&DebugOptions::set_xla_gpu_experimental_disable_binary_libraries),
debug_options->xla_gpu_experimental_disable_binary_libraries(),
"Disable XLA GPU passes that depend on non-open source binary "
"libraries"));
} // NOLINT(readability/fn_size)

// Allocates flag_values and flag_objects; this function must not be called more
Expand Down
66 changes: 45 additions & 21 deletions xla/service/gpu/nvptx_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -207,13 +207,17 @@ absl::Status NVPTXCompiler::OptimizeHloConvolutionCanonicalization(
pipeline.AddPass<FloatNormalization>(&matmul_bf16_support);

pipeline.AddPass<GpusolverRewriter>();
pipeline.AddPass<ConvRewriter>(cuda_compute_capability);
pipeline.AddPass<CudnnFusedConvRewriter>(cuda_compute_capability, dnn_version,
toolkit_version);
pipeline.AddPass<ConvPaddingLegalization>();
pipeline.AddPass<CudnnPadForConvolutions>(cuda_compute_capability);
pipeline.AddPass<CudnnVectorizeConvolutions>(cuda_compute_capability,
dnn_version);
if (!hlo_module->config()
.debug_options()
.xla_gpu_experimental_disable_binary_libraries()) {
pipeline.AddPass<ConvRewriter>(cuda_compute_capability);
pipeline.AddPass<CudnnFusedConvRewriter>(cuda_compute_capability,
dnn_version, toolkit_version);
pipeline.AddPass<ConvPaddingLegalization>();
pipeline.AddPass<CudnnPadForConvolutions>(cuda_compute_capability);
pipeline.AddPass<CudnnVectorizeConvolutions>(cuda_compute_capability,
dnn_version);
}
// The conv padding/vectorization passes which we need to get rid of. They
// also leave behind unnecessary tuple/get-tuple-element pairs that
// TupleSimplifier fixes.
Expand All @@ -228,12 +232,16 @@ absl::Status NVPTXCompiler::OptimizeHloConvolutionCanonicalization(
pipeline.AddPass<HloPassFix<GpuAlgebraicSimplifier>>(algsimp_options,
gpu_version);

// CudnnSimplifyPadding gets rid of some padding introduced by
// CudnnPadForConvolutions and used by CudnnVectorizeConvolutions. The
// pattern-matches in this pass need to be run after inlining and simplifying
// tuples from CudnnVectorizeConvolutions. We also need to run algsimp to
// e.g. clean up unnecessary nop `convert`s.
pipeline.AddPass<CudnnSimplifyPadding>();
if (!hlo_module->config()
.debug_options()
.xla_gpu_experimental_disable_binary_libraries()) {
// CudnnSimplifyPadding gets rid of some padding introduced by
// CudnnPadForConvolutions and used by CudnnVectorizeConvolutions. The
// pattern-matches in this pass need to be run after inlining and
// simplifying tuples from CudnnVectorizeConvolutions. We also need to run
// algsimp to e.g. clean up unnecessary nop `convert`s.
pipeline.AddPass<CudnnSimplifyPadding>();
}

// tf2xla bridge, DepthwiseConvolutionConverter, ConvRewriter, and
// CudnnSimplifyPadding introduce reshapes and transposes. Run ReshapeMover
Expand Down Expand Up @@ -275,7 +283,10 @@ absl::Status NVPTXCompiler::OptimizeHloPostLayoutAssignment(
auto cuda_compute_capability = std::get<se::CudaComputeCapability>(
gpu_target_config.device_description.gpu_compute_capability());

if (hlo_module->config().debug_options().xla_gpu_enable_cudnn_fmha()) {
if (hlo_module->config().debug_options().xla_gpu_enable_cudnn_fmha() &&
!hlo_module->config()
.debug_options()
.xla_gpu_experimental_disable_binary_libraries()) {
HloPassPipeline mha_fusion_pipeline(
"nvptx cudnn multi-headed attention fusion");
// The LayoutAssignment pass may leave behind kCopy instructions which are
Expand Down Expand Up @@ -314,20 +325,28 @@ absl::Status NVPTXCompiler::OptimizeHloPostLayoutAssignment(
}

HloPassPipeline pre_pipeline("nvptx post-layout_assignment part 1");
if (hlo_module->config().debug_options().xla_gpu_enable_cudnn_layer_norm()) {
if (hlo_module->config().debug_options().xla_gpu_enable_cudnn_layer_norm() &&
!hlo_module->config()
.debug_options()
.xla_gpu_experimental_disable_binary_libraries()) {
// Rewrite normalization patterns into cuDNN Custom Calls.
pre_pipeline.AddPass<CudnnNormRewriter>(cuda_compute_capability);
}

pre_pipeline.AddPass<DotDimensionMerger>();
pre_pipeline.AddPass<DotSparsityRewriter>();

for (const CublasPaddingRequirement& requirement :
CublasPaddingRequirements) {
if (cuda_compute_capability.IsAtLeast(requirement.min_compute_capability)) {
pre_pipeline.AddPass<CublasPadForGemms>(cuda_compute_capability,
requirement.data_type,
requirement.multiple_of);
if (!hlo_module->config()
.debug_options()
.xla_gpu_experimental_disable_binary_libraries()) {
for (const CublasPaddingRequirement& requirement :
CublasPaddingRequirements) {
if (cuda_compute_capability.IsAtLeast(
requirement.min_compute_capability)) {
pre_pipeline.AddPass<CublasPadForGemms>(cuda_compute_capability,
requirement.data_type,
requirement.multiple_of);
}
}
}
// Padding a gemm operand that's a constant results in pad(constant). Run
Expand Down Expand Up @@ -397,6 +416,11 @@ absl::Status NVPTXCompiler::AddCustomKernelReplacementPasses(
absl::Status NVPTXCompiler::RunCudnnCompilerPasses(
HloModule* module, se::StreamExecutor* stream_exec,
BinaryMap* dnn_compiled_graphs) {
if (module->config()
.debug_options()
.xla_gpu_experimental_disable_binary_libraries()) {
return absl::OkStatus();
}
tsl::profiler::ScopedAnnotation annotation([&] {
return absl::StrFormat("XlaCompileCudnnFusion:#module=%s,program_id=%d#",
module->name(), module->unique_id());
Expand Down
5 changes: 4 additions & 1 deletion xla/xla.proto
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,9 @@ message DebugOptions {
// Specifies the behavior of per kernel autotuning cache.
AutotuneCacheMode xla_gpu_experimental_autotune_cache_mode = 324;

// Experimentally disables binary libraries in GPU compiler passes.
bool xla_gpu_experimental_disable_binary_libraries = 329;

// Gates the experimental feature coupling the Triton Softmax pattern matcher
// with priority fusion.
bool xla_gpu_experimental_enable_triton_softmax_priority_fusion = 325;
Expand Down Expand Up @@ -971,7 +974,7 @@ message DebugOptions {
int32 xla_gpu_executable_warn_stuck_timeout_seconds = 327;
int32 xla_gpu_executable_terminate_timeout_seconds = 328;

// Next id: 329
// Next id: 330

// Extra options to pass to the compilation backend (e.g. LLVM); specific
// interpretation of these values is left to the backend.
Expand Down

0 comments on commit 3e6c18b

Please sign in to comment.