Add a command line flag to disable XLA GPU passes based on binary libraries. #17162

Merged (1 commit) on Sep 17, 2024
8 changes: 8 additions & 0 deletions xla/debug_options_flags.cc
@@ -290,6 +290,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {

opts.set_xla_gpu_executable_warn_stuck_timeout_seconds(10);
opts.set_xla_gpu_executable_terminate_timeout_seconds(30);
opts.set_xla_gpu_experimental_disable_binary_libraries(false);
return opts;
}

@@ -1936,6 +1937,13 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
&DebugOptions::set_xla_gpu_executable_terminate_timeout_seconds),
debug_options->xla_gpu_executable_terminate_timeout_seconds(),
"Set timeout for RendezvousSingle termination"));
flag_list->push_back(tsl::Flag(
"xla_gpu_experimental_disable_binary_libraries",
bool_setter_for(
&DebugOptions::set_xla_gpu_experimental_disable_binary_libraries),
debug_options->xla_gpu_experimental_disable_binary_libraries(),
"Disable XLA GPU passes that depend on non-open source binary "
"libraries"));
} // NOLINT(readability/fn_size)

// Allocates flag_values and flag_objects; this function must not be called more
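For reference, the new option behaves like any other XLA debug option: it can be passed through the XLA_FLAGS environment variable as --xla_gpu_experimental_disable_binary_libraries=true, or set programmatically. A minimal programmatic sketch, not part of this change and assuming the usual xla/debug_options_flags.h and xla/service/hlo_module_config.h headers:

#include "xla/debug_options_flags.h"
#include "xla/service/hlo_module_config.h"
#include "xla/xla.pb.h"

// Sketch: build an HloModuleConfig whose debug options skip the GPU passes
// that depend on non-open source binary libraries.
xla::HloModuleConfig ConfigWithoutBinaryLibraries() {
  xla::DebugOptions debug_options = xla::GetDebugOptionsFromFlags();
  debug_options.set_xla_gpu_experimental_disable_binary_libraries(true);
  xla::HloModuleConfig config;
  config.set_debug_options(debug_options);
  return config;
}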
66 changes: 45 additions & 21 deletions xla/service/gpu/nvptx_compiler.cc
@@ -207,13 +207,17 @@ absl::Status NVPTXCompiler::OptimizeHloConvolutionCanonicalization(
pipeline.AddPass<FloatNormalization>(&matmul_bf16_support);

pipeline.AddPass<GpusolverRewriter>();
pipeline.AddPass<ConvRewriter>(cuda_compute_capability);
pipeline.AddPass<CudnnFusedConvRewriter>(cuda_compute_capability, dnn_version,
toolkit_version);
pipeline.AddPass<ConvPaddingLegalization>();
pipeline.AddPass<CudnnPadForConvolutions>(cuda_compute_capability);
pipeline.AddPass<CudnnVectorizeConvolutions>(cuda_compute_capability,
dnn_version);
if (!hlo_module->config()
.debug_options()
.xla_gpu_experimental_disable_binary_libraries()) {
pipeline.AddPass<ConvRewriter>(cuda_compute_capability);
pipeline.AddPass<CudnnFusedConvRewriter>(cuda_compute_capability,
dnn_version, toolkit_version);
pipeline.AddPass<ConvPaddingLegalization>();
pipeline.AddPass<CudnnPadForConvolutions>(cuda_compute_capability);
pipeline.AddPass<CudnnVectorizeConvolutions>(cuda_compute_capability,
dnn_version);
}
// The conv padding/vectorization passes which we need to get rid of. They
// also leave behind unnecessary tuple/get-tuple-element pairs that
// TupleSimplifier fixes.
@@ -228,12 +232,16 @@ absl::Status NVPTXCompiler::OptimizeHloConvolutionCanonicalization(
pipeline.AddPass<HloPassFix<GpuAlgebraicSimplifier>>(algsimp_options,
gpu_version);

// CudnnSimplifyPadding gets rid of some padding introduced by
// CudnnPadForConvolutions and used by CudnnVectorizeConvolutions. The
// pattern-matches in this pass need to be run after inlining and simplifying
// tuples from CudnnVectorizeConvolutions. We also need to run algsimp to
// e.g. clean up unnecessary nop `convert`s.
pipeline.AddPass<CudnnSimplifyPadding>();
if (!hlo_module->config()
.debug_options()
.xla_gpu_experimental_disable_binary_libraries()) {
// CudnnSimplifyPadding gets rid of some padding introduced by
// CudnnPadForConvolutions and used by CudnnVectorizeConvolutions. The
// pattern-matches in this pass need to be run after inlining and
// simplifying tuples from CudnnVectorizeConvolutions. We also need to run
// algsimp to e.g. clean up unnecessary nop `convert`s.
pipeline.AddPass<CudnnSimplifyPadding>();
}

// tf2xla bridge, DepthwiseConvolutionConverter, ConvRewriter, and
// CudnnSimplifyPadding introduce reshapes and transposes. Run ReshapeMover
@@ -275,7 +283,10 @@ absl::Status NVPTXCompiler::OptimizeHloPostLayoutAssignment(
auto cuda_compute_capability = std::get<se::CudaComputeCapability>(
gpu_target_config.device_description.gpu_compute_capability());

if (hlo_module->config().debug_options().xla_gpu_enable_cudnn_fmha()) {
if (hlo_module->config().debug_options().xla_gpu_enable_cudnn_fmha() &&
!hlo_module->config()
.debug_options()
.xla_gpu_experimental_disable_binary_libraries()) {
HloPassPipeline mha_fusion_pipeline(
"nvptx cudnn multi-headed attention fusion");
// The LayoutAssignment pass may leave behind kCopy instructions which are
@@ -314,20 +325,28 @@ absl::Status NVPTXCompiler::OptimizeHloPostLayoutAssignment(
}

HloPassPipeline pre_pipeline("nvptx post-layout_assignment part 1");
if (hlo_module->config().debug_options().xla_gpu_enable_cudnn_layer_norm()) {
if (hlo_module->config().debug_options().xla_gpu_enable_cudnn_layer_norm() &&
!hlo_module->config()
.debug_options()
.xla_gpu_experimental_disable_binary_libraries()) {
// Rewrite normalization patterns into cuDNN Custom Calls.
pre_pipeline.AddPass<CudnnNormRewriter>(cuda_compute_capability);
}

pre_pipeline.AddPass<DotDimensionMerger>();
pre_pipeline.AddPass<DotSparsityRewriter>();

for (const CublasPaddingRequirement& requirement :
CublasPaddingRequirements) {
if (cuda_compute_capability.IsAtLeast(requirement.min_compute_capability)) {
pre_pipeline.AddPass<CublasPadForGemms>(cuda_compute_capability,
requirement.data_type,
requirement.multiple_of);
if (!hlo_module->config()
.debug_options()
.xla_gpu_experimental_disable_binary_libraries()) {
for (const CublasPaddingRequirement& requirement :
CublasPaddingRequirements) {
if (cuda_compute_capability.IsAtLeast(
requirement.min_compute_capability)) {
pre_pipeline.AddPass<CublasPadForGemms>(cuda_compute_capability,
requirement.data_type,
requirement.multiple_of);
}
}
}
// Padding a gemm operand that's a constant results in pad(constant). Run
@@ -397,6 +416,11 @@ absl::Status NVPTXCompiler::AddCustomKernelReplacementPasses(
absl::Status NVPTXCompiler::RunCudnnCompilerPasses(
HloModule* module, se::StreamExecutor* stream_exec,
BinaryMap* dnn_compiled_graphs) {
if (module->config()
.debug_options()
.xla_gpu_experimental_disable_binary_libraries()) {
return absl::OkStatus();
}
tsl::profiler::ScopedAnnotation annotation([&] {
return absl::StrFormat("XlaCompileCudnnFusion:#module=%s,program_id=%d#",
module->name(), module->unique_id());
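The same three-line debug-option lookup gates each of the cuDNN- and cuBLAS-dependent stages above (conv rewriting/padding/vectorization, fMHA fusion, norm rewriting, gemm padding, and RunCudnnCompilerPasses). A hypothetical helper, not part of this change, that names the repeated pattern:

#include "xla/hlo/ir/hlo_module.h"

// Hypothetical convenience wrapper: true when passes that rely on
// non-open source binary libraries should be skipped for this module.
bool BinaryLibrariesDisabled(const xla::HloModule& module) {
  return module.config()
      .debug_options()
      .xla_gpu_experimental_disable_binary_libraries();
}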
5 changes: 4 additions & 1 deletion xla/xla.proto
@@ -117,6 +117,9 @@ message DebugOptions {
// Specifies the behavior of per kernel autotuning cache.
AutotuneCacheMode xla_gpu_experimental_autotune_cache_mode = 324;

// Experimentally disables binary libraries in GPU compiler passes.
bool xla_gpu_experimental_disable_binary_libraries = 329;

// Gates the experimental feature coupling the Triton Softmax pattern matcher
// with priority fusion.
bool xla_gpu_experimental_enable_triton_softmax_priority_fusion = 325;
@@ -971,7 +974,7 @@ message DebugOptions {
int32 xla_gpu_executable_warn_stuck_timeout_seconds = 327;
int32 xla_gpu_executable_terminate_timeout_seconds = 328;

// Next id: 329
// Next id: 330

// Extra options to pass to the compilation backend (e.g. LLVM); specific
// interpretation of these values is left to the backend.
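Field 329 is a plain proto3 bool, so it defaults to false, matching the explicit default added in debug_options_flags.cc; nothing changes unless the flag is turned on. A small sanity-check sketch, assuming DefaultDebugOptionsIgnoringFlags is exported by xla/debug_options_flags.h:

#include <cassert>

#include "xla/debug_options_flags.h"

int main() {
  xla::DebugOptions defaults = xla::DefaultDebugOptionsIgnoringFlags();
  // The new field defaults to false, so the binary-library passes keep
  // running unless the flag is explicitly enabled.
  assert(!defaults.xla_gpu_experimental_disable_binary_libraries());
  return 0;
}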