diff --git a/xla/debug_options_flags.cc b/xla/debug_options_flags.cc
index 149ddc3563a832..c19a48a324d084 100644
--- a/xla/debug_options_flags.cc
+++ b/xla/debug_options_flags.cc
@@ -229,7 +229,6 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
   opts.set_xla_gpu_enable_cudnn_layer_norm(false);
   opts.set_xla_gpu_threshold_for_windowed_einsum_mib(100000);
 
-  opts.set_xla_gpu_enable_triton_hopper(false);
 
   opts.set_xla_gpu_experimental_enable_fusion_block_level_rewriter(false);
   opts.set_xla_gpu_enable_llvm_module_compilation_parallelism(false);
@@ -1787,11 +1786,6 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
      "Einsums that have partitioned operand(can be either LHS or RHS) that's "
      "larger than this threshold will be transformed to use windowed einsums."
      "Default is 100000"));
-  flag_list->push_back(tsl::Flag(
-      "xla_gpu_enable_triton_hopper",
-      bool_setter_for(&DebugOptions::set_xla_gpu_enable_triton_hopper),
-      debug_options->xla_gpu_enable_triton_hopper(),
-      "Currently used to enable MMA_V3 for Hopper in Triton"));
   flag_list->push_back(tsl::Flag(
      "xla_gpu_experimental_enable_fusion_block_level_rewriter",
      bool_setter_for(
diff --git a/xla/service/gpu/autotuning/BUILD b/xla/service/gpu/autotuning/BUILD
index 25042e4d5da929..ea34a37200c218 100644
--- a/xla/service/gpu/autotuning/BUILD
+++ b/xla/service/gpu/autotuning/BUILD
@@ -148,6 +148,7 @@ xla_test(
         "//xla/tests:xla_internal_test_main",  # fixdeps: keep
         "//xla/tools:hlo_decomposer_lib",
         "//xla/tsl/lib/core:status_test_util",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/log",
         "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/status:statusor",
diff --git a/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc b/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc
index 45d820ca0f3271..321b312e4c34e0 100644
--- a/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc
+++ b/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc
@@ -1190,8 +1190,9 @@ std::vector<TritonGemmConfig>
 GemmFusionAutotunerImpl::GetExhaustiveTritonConfigs() const {
   std::vector<TritonGemmConfig> configs;
   se::CudaComputeCapability cc = GetComputeCapability();
-  bool tune_ctas =
-      debug_options_.xla_gpu_enable_triton_hopper() && cc.IsAtLeastHopper();
+  // Clusters are only supported from Hopper. We don't autotune them by default.
+  bool should_tune_ctas =
+      debug_options_.xla_gpu_exhaustive_tiling_search() && cc.IsAtLeastHopper();
 
   for (int num_stages : kNumStages) {
     // Volta doesn't support num_stages > 2.
@@ -1214,18 +1215,19 @@
             split_k > 1) {
           break;
         }
-        for (int num_ctas : kNumCtas) {
-          // Clusters are only supported on Hopper.
-          // Autotuning this parameter is enabled by a flag.
-          if (!tune_ctas && num_ctas > 1) {
-            break;
-          }
-          if (num_ctas > num_warps) {
-            break;
+
+        if (should_tune_ctas) {
+          for (int num_ctas : kNumCtas) {
+            if (num_ctas <= num_warps) {
+              configs.push_back(TritonGemmConfig(tile_m, tile_n, tile_k,
+                                                 split_k, num_stages,
+                                                 num_warps, num_ctas));
+            }
           }
+        } else {
           configs.push_back(TritonGemmConfig(tile_m, tile_n, tile_k,
                                              split_k, num_stages,
-                                             num_warps, num_ctas));
+                                             num_warps, /*num_ctas=*/1));
         }
       }
     }
diff --git a/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc b/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc
index 01e6e630a99529..6d592b5589b6b5 100644
--- a/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc
+++ b/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc
@@ -23,6 +23,7 @@ limitations under the License.
 
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
+#include "absl/container/flat_hash_set.h"
 #include "absl/log/check.h"
 #include "absl/log/log.h"
 #include "absl/status/statusor.h"
@@ -1062,6 +1063,61 @@ ENTRY wais {
 INSTANTIATE_TEST_SUITE_P(GemmFusionAutotunerConfigSweep,
                          GemmFusionAutotunerConfigTest, ::testing::Bool());
 
+TEST_F(StatelessAutotunerTest,
+       ExhaustiveAutotuningTunesNumberOfCtasFromHopper) {
+  const std::string kHloText = R"(
+HloModule test
+
+ENTRY main {
+  lhs = f32[5,1600] parameter(0)
+  rhs = f32[1600,10] parameter(1)
+  ROOT dot = f32[5,10] dot(lhs, rhs),
+      lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
+  DebugOptions debug_options_with_exhaustive_autotuning =
+      GetDebugOptionsForTest();
+  debug_options_with_exhaustive_autotuning.set_xla_gpu_exhaustive_tiling_search(
+      true);
+
+  auto get_configs = [&](const se::CudaComputeCapability& cc,
+                         const DebugOptions& debug_options) {
+    return GetPossibleMatmulAutotuneTritonConfigs(
+               *Cast<HloDotInstruction>(
+                   module->entry_computation()->root_instruction()),
+               cc, GetToolkitVersion(), debug_options)
+        .value();
+  };
+
+  for (const auto& config :
+       get_configs(se::CudaComputeCapability::Ampere(),
+                   debug_options_with_exhaustive_autotuning)) {
+    // We do not tune the number of CTAs on Ampere...
+    EXPECT_EQ(config.num_ctas, 1);
+  }
+
+  absl::flat_hash_set<int> config_num_ctas;
+  for (const auto& config :
+       get_configs(se::CudaComputeCapability::Hopper(),
+                   debug_options_with_exhaustive_autotuning)) {
+    // ... but we do on Hopper...
+    config_num_ctas.insert(config.num_ctas);
+  }
+  EXPECT_GT(config_num_ctas.size(), 1);
+
+  DebugOptions debug_options_without_exhaustive_autotuning =
+      GetDebugOptionsForTest();
+  debug_options_without_exhaustive_autotuning
+      .set_xla_gpu_exhaustive_tiling_search(false);
+
+  for (const auto& config :
+       get_configs(se::CudaComputeCapability::Hopper(),
+                   debug_options_without_exhaustive_autotuning)) {
+    // ... except if exhaustive autotuning is disabled.
+    EXPECT_EQ(config.num_ctas, 1);
+  }
+}
+
 TEST_F(GemmFusionAutotunerTest, SplitKFLoatNormalization) {
   if (!GetCudaComputeCapability().IsAtLeastHopper()) {
     GTEST_SKIP() << "f8 types are only supported from Hopper onwards.";
diff --git a/xla/xla.proto b/xla/xla.proto
index fc161b0bfca2df..38b1626663775a 100644
--- a/xla/xla.proto
+++ b/xla/xla.proto
@@ -782,8 +782,7 @@ message DebugOptions {
   // Threshold to enable windowed einsum (collective matmul) in MB.
   int64 xla_gpu_threshold_for_windowed_einsum_mib = 265;
 
-  // Enables currently disabled features within Triton for Hopper.
-  bool xla_gpu_enable_triton_hopper = 266;
+  reserved 266;  // was xla_gpu_enable_triton_hopper
 
   // Enable NCCL user buffers.
   bool xla_gpu_enable_nccl_user_buffers = 267;