[XLA:GPU] Remove the now obsolete --xla_gpu_enable_triton_hopper flag.
MMA_V3 has been enabled by default, and at this point this flag only gated
varying the number of CTAs when autotuning matrix multiplications.

This also fixes a bug where the number of CTAs was not being autotuned when
using exhaustive tiling.
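
For reference, a minimal sketch of the replacement opt-in path (illustrative
only, not code from this change; both setters appear in the diff below, and
the include is the header generated from xla.proto):

#include "xla/xla.pb.h"

// Illustrative sketch: num_ctas autotuning is now keyed off exhaustive
// tiling search instead of the removed flag, and still only takes effect
// on Hopper and newer.
DebugOptions opts;
opts.set_xla_gpu_exhaustive_tiling_search(true);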

PiperOrigin-RevId: 681535035
bchetioui authored and Google-ML-Automation committed Oct 3, 2024
1 parent 8af9774 commit d84df59
Showing 5 changed files with 71 additions and 19 deletions.
6 changes: 0 additions & 6 deletions xla/debug_options_flags.cc
@@ -229,7 +229,6 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
   opts.set_xla_gpu_enable_cudnn_layer_norm(false);
   opts.set_xla_gpu_threshold_for_windowed_einsum_mib(100000);
 
-  opts.set_xla_gpu_enable_triton_hopper(false);
   opts.set_xla_gpu_experimental_enable_fusion_block_level_rewriter(false);
 
   opts.set_xla_gpu_enable_llvm_module_compilation_parallelism(false);
@@ -1787,11 +1786,6 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
       "Einsums that have partitioned operand(can be either LHS or RHS) that's "
       "larger than this threshold will be transformed to use windowed einsums."
       "Default is 100000"));
-  flag_list->push_back(tsl::Flag(
-      "xla_gpu_enable_triton_hopper",
-      bool_setter_for(&DebugOptions::set_xla_gpu_enable_triton_hopper),
-      debug_options->xla_gpu_enable_triton_hopper(),
-      "Currently used to enable MMA_V3 for Hopper in Triton"));
   flag_list->push_back(tsl::Flag(
       "xla_gpu_experimental_enable_fusion_block_level_rewriter",
       bool_setter_for(
1 change: 1 addition & 0 deletions xla/service/gpu/autotuning/BUILD
@@ -148,6 +148,7 @@ xla_test(
         "//xla/tests:xla_internal_test_main",  # fixdeps: keep
         "//xla/tools:hlo_decomposer_lib",
         "//xla/tsl/lib/core:status_test_util",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/log",
         "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/status:statusor",
24 changes: 13 additions & 11 deletions xla/service/gpu/autotuning/gemm_fusion_autotuner.cc
@@ -1190,8 +1190,9 @@ std::vector<TritonGemmConfig>
 GemmFusionAutotunerImpl::GetExhaustiveTritonConfigs() const {
   std::vector<TritonGemmConfig> configs;
   se::CudaComputeCapability cc = GetComputeCapability();
-  bool tune_ctas =
-      debug_options_.xla_gpu_enable_triton_hopper() && cc.IsAtLeastHopper();
+  // Clusters are only supported from Hopper. We don't autotune them by default.
+  bool should_tune_ctas =
+      debug_options_.xla_gpu_exhaustive_tiling_search() && cc.IsAtLeastHopper();
 
   for (int num_stages : kNumStages) {
     // Volta doesn't support num_stages > 2.
@@ -1214,18 +1215,19 @@ GemmFusionAutotunerImpl::GetExhaustiveTritonConfigs() const {
                 split_k > 1) {
               break;
             }
-            for (int num_ctas : kNumCtas) {
-              // Clusters are only supported on Hopper.
-              // Autotuning this parameter is enabled by a flag.
-              if (!tune_ctas && num_ctas > 1) {
-                break;
-              }
-              if (num_ctas > num_warps) {
-                break;
+
+            if (should_tune_ctas) {
+              for (int num_ctas : kNumCtas) {
+                if (num_ctas <= num_warps) {
+                  configs.push_back(TritonGemmConfig(tile_m, tile_n, tile_k,
+                                                     split_k, num_stages,
+                                                     num_warps, num_ctas));
+                }
               }
+            } else {
               configs.push_back(TritonGemmConfig(tile_m, tile_n, tile_k,
                                                  split_k, num_stages,
-                                                 num_warps, num_ctas));
+                                                 num_warps, /*num_ctas=*/1));
             }
           }
         }
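
In isolation, the restructured enumeration behaves as below (a self-contained
sketch, not the XLA code itself; the kNumCtas candidate values are assumed for
illustration, the real list lives elsewhere in this file):

#include <vector>

struct Config {
  int num_warps;
  int num_ctas;
};

// Sketch of the new logic: sweep num_ctas only when requested, otherwise pin
// it to 1. Previously the sweep was gated on the removed flag, so exhaustive
// tiling alone never varied num_ctas -- the bug this commit fixes.
std::vector<Config> EnumerateCtaConfigs(bool should_tune_ctas, int num_warps) {
  const std::vector<int> kNumCtas = {1, 2, 4, 8, 16};  // assumed candidates
  std::vector<Config> configs;
  if (should_tune_ctas) {
    for (int num_ctas : kNumCtas) {
      // Only keep cluster sizes that do not exceed the warp count.
      if (num_ctas <= num_warps) configs.push_back({num_warps, num_ctas});
    }
  } else {
    configs.push_back({num_warps, /*num_ctas=*/1});
  }
  return configs;
}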
56 changes: 56 additions & 0 deletions xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc
@@ -23,6 +23,7 @@ limitations under the License.
 
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
+#include "absl/container/flat_hash_set.h"
 #include "absl/log/check.h"
 #include "absl/log/log.h"
 #include "absl/status/statusor.h"
@@ -1062,6 +1063,61 @@ ENTRY wais {
 INSTANTIATE_TEST_SUITE_P(GemmFusionAutotunerConfigSweep,
                          GemmFusionAutotunerConfigTest, ::testing::Bool());
 
+TEST_F(StatelessAutotunerTest,
+       ExhaustiveAutotuningTunesNumberOfCtasFromHopper) {
+  const std::string kHloText = R"(
+HloModule test
+ENTRY main {
+  lhs = f32[5,1600] parameter(0)
+  rhs = f32[1600,10] parameter(1)
+  ROOT dot = f32[5,10] dot(lhs, rhs),
+      lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
+  DebugOptions debug_options_with_exhaustive_autotuning =
+      GetDebugOptionsForTest();
+  debug_options_with_exhaustive_autotuning.set_xla_gpu_exhaustive_tiling_search(
+      true);
+
+  auto get_configs = [&](const se::CudaComputeCapability& cc,
+                         const DebugOptions& debug_options) {
+    return GetPossibleMatmulAutotuneTritonConfigs(
+               *Cast<HloDotInstruction>(
+                   module->entry_computation()->root_instruction()),
+               cc, GetToolkitVersion(), debug_options)
+        .value();
+  };
+
+  for (const auto& config :
+       get_configs(se::CudaComputeCapability::Ampere(),
+                   debug_options_with_exhaustive_autotuning)) {
+    // We do not tune the number of CTAs on Ampere...
+    EXPECT_EQ(config.num_ctas, 1);
+  }
+
+  absl::flat_hash_set<int> config_num_ctas;
+  for (const auto& config :
+       get_configs(se::CudaComputeCapability::Hopper(),
+                   debug_options_with_exhaustive_autotuning)) {
+    // ... but we do on Hopper...
+    config_num_ctas.insert(config.num_ctas);
+  }
+  EXPECT_GT(config_num_ctas.size(), 1);
+
+  DebugOptions debug_options_without_exhaustive_autotuning =
+      GetDebugOptionsForTest();
+  debug_options_without_exhaustive_autotuning
+      .set_xla_gpu_exhaustive_tiling_search(false);
+
+  for (const auto& config :
+       get_configs(se::CudaComputeCapability::Hopper(),
+                   debug_options_without_exhaustive_autotuning)) {
+    // ... except if exhaustive autotuning is disabled.
+    EXPECT_EQ(config.num_ctas, 1);
+  }
+}
+
 TEST_F(GemmFusionAutotunerTest, SplitKFLoatNormalization) {
   if (!GetCudaComputeCapability().IsAtLeastHopper()) {
     GTEST_SKIP() << "f8 types are only supported from Hopper onwards.";
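
The three loops in the new test pin down the gating in all three cases; pulled
out on its own, the predicate they exercise reads as follows (a sketch reduced
to plain bools, not code from this change):

// CTAs are swept only when exhaustive tiling search is on AND the device is
// Hopper or newer.
//   Ampere + exhaustive      -> false (num_ctas stays 1)
//   Hopper + exhaustive      -> true  (num_ctas is swept)
//   Hopper + non-exhaustive  -> false (num_ctas stays 1)
bool ShouldTuneCtas(bool exhaustive_tiling_search, bool is_at_least_hopper) {
  return exhaustive_tiling_search && is_at_least_hopper;
}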
3 changes: 1 addition & 2 deletions xla/xla.proto
@@ -782,8 +782,7 @@ message DebugOptions {
   // Threshold to enable windowed einsum (collective matmul) in MB.
   int64 xla_gpu_threshold_for_windowed_einsum_mib = 265;
 
-  // Enables currently disabled features within Triton for Hopper.
-  bool xla_gpu_enable_triton_hopper = 266;
+  reserved 266;  // was xla_gpu_enable_triton_hopper
 
   // Enable NCCL user buffers.
   bool xla_gpu_enable_nccl_user_buffers = 267;
