Add custom kernel fusion to gemm fusion autotuner.
The GemmFusionAutotuner currently takes a fusion and compares its runtime across different backends (Triton, cuBLAS, and cuDNN). We add CustomKernelFusions (mostly CUTLASS kernels) as additional candidates for the autotuner to compare.

PiperOrigin-RevId: 674266098
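For orientation, the decision the autotuner makes can be sketched as follows: each candidate backend (Triton, cuBLAS, cuDNN, and now each CUTLASS custom kernel, identified by a kernel index) is benchmarked on the fusion, and the fastest candidate wins. The snippet below is a deliberately simplified, self-contained model; the struct, names, and timings are illustrative and are not XLA's actual API.

// Hypothetical model of the autotuning decision; not XLA code.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

struct Candidate {
  std::string backend;   // "triton", "cublas", "cudnn", or "cutlass"
  int64_t kernel_index;  // only meaningful for "cutlass" candidates
  double run_time_us;    // measured runtime of the fused GEMM
};

int main() {
  std::vector<Candidate> candidates = {
      {"triton", -1, 105.0},
      {"cublas", -1, 98.0},
      {"cutlass", 0, 110.0},
      {"cutlass", 1, 91.0},  // fastest CUTLASS kernel variant
  };
  // Keep the candidate with the smallest measured runtime.
  const Candidate fastest = *std::min_element(
      candidates.begin(), candidates.end(),
      [](const Candidate& a, const Candidate& b) {
        return a.run_time_us < b.run_time_us;
      });
  std::cout << "winner: " << fastest.backend
            << " (kernel_index=" << fastest.kernel_index << ")\n";
  return 0;
}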
derdrdirk authored and Google-ML-Automation committed Sep 13, 2024
1 parent c1ef7f8 commit 67325a0
Showing 10 changed files with 362 additions and 97 deletions.
7 changes: 6 additions & 1 deletion xla/autotuning.proto
@@ -83,6 +83,10 @@ message AutotuneResult {
    int64 num_ctas = 7;
  }

  message CustomKernelFusionKey {
    int64 kernel_index = 1;
  }

  int64 scratch_bytes = 8;
  google.protobuf.Duration run_time = 9;

@@ -93,10 +97,11 @@ message AutotuneResult {
    GemmKey gemm = 6;
    TritonGemmKey triton = 17;
    CudaConvPlanKey cuda_conv_plan = 15;
    CustomKernelFusionKey custom_kernel_fusion = 18;
    stream_executor.dnn.AlgorithmProto algorithm = 16;
  }

  // Next ID: 17
  // Next ID: 19
}

message AutotuningLog {
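As a usage sketch of the new custom_kernel_fusion key (assuming the usual generated C++ proto API, the xla package name, and a generated header at xla/autotuning.pb.h; the header path and package are assumptions, not shown in this diff), a winning kernel index could be recorded and read back like this:

// Sketch only; assumes the generated header path and xla package name.
#include <iostream>

#include "xla/autotuning.pb.h"

int main() {
  xla::AutotuneResult result;
  // Record that the fastest candidate was a custom kernel fusion (e.g. a
  // CUTLASS kernel), identified by its index in the fusion's kernel list.
  result.mutable_custom_kernel_fusion()->set_kernel_index(1);
  result.mutable_run_time()->set_nanos(91000);  // 91 microseconds

  // Consumers can check which key of the oneof is set before reading it.
  if (result.has_custom_kernel_fusion()) {
    std::cout << "custom kernel fusion, kernel_index="
              << result.custom_kernel_fusion().kernel_index() << "\n";
  }
  return 0;
}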
1 change: 1 addition & 0 deletions xla/service/gpu/BUILD
@@ -1627,6 +1627,7 @@ xla_test(
"//xla/tests:filecheck",
"//xla/tests:hlo_test_base",
"//xla/tests:literal_test_util",
"//xla/tests:verified_hlo_module",
"//xla/tests:xla_internal_test_main",
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/log",
9 changes: 6 additions & 3 deletions xla/service/gpu/autotuning/BUILD
@@ -58,12 +58,15 @@ cc_library(
"//xla/service/gpu:backend_configs_cc",
"//xla/service/gpu:buffer_comparator",
"//xla/service/gpu:gpu_float_support",
"//xla/service/gpu:hlo_traversal",
"//xla/service/gpu:ir_emission_utils",
"//xla/service/gpu:matmul_utils",
"//xla/service/gpu:split_k_gemm_rewriter",
"//xla/service/gpu:stream_executor_util",
"//xla/service/gpu/kernels:custom_kernel",
"//xla/service/gpu/kernels:custom_kernel_fusion",
"//xla/service/gpu/kernels:custom_kernel_fusion_pattern",
"//xla/service/gpu/transforms:cudnn_fusion_compiler",
"//xla/service/gpu/transforms:custom_kernel_fusion_rewriter",
"//xla/service/gpu/transforms:fusion_wrapper",
"//xla/service/gpu/transforms:gemm_rewriter",
"//xla/service/gpu/transforms:priority_fusion",
@@ -72,11 +75,9 @@ cc_library(
"//xla/stream_executor:device_memory",
"//xla/stream_executor:semantic_version",
"//xla/stream_executor:stream_executor_memory_allocator",
"//xla/stream_executor/gpu:redzone_allocator",
"//xla/tools:hlo_decomposer_lib",
"//xla/tsl/lib/core:bits",
"//xla/tsl/util/proto:proto_utils",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log",
@@ -137,6 +138,8 @@ xla_test(
"//xla/stream_executor:device_description",
"//xla/stream_executor:device_description_proto_cc",
"//xla/stream_executor:semantic_version",
"//xla/stream_executor:stream_executor_h",
"//xla/stream_executor/gpu:gpu_executor_header",
"//xla/tests:filecheck",
"//xla/tests:hlo_test_base",
"//xla/tests:test_utils",
18 changes: 14 additions & 4 deletions xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.cc
@@ -192,7 +192,7 @@ absl::StatusOr<bool> AutotuneCustomKernelFusion(
  return previous_kernel_index != fastest_kernel_index;
}

bool IsCustomFusion(const HloComputation* computation) {
bool IsCutlassCustomFusion(const HloComputation* computation) {
  if (!computation->IsFusionComputation()) {
    return false;
  }
@@ -212,8 +212,18 @@ bool IsCustomFusion(const HloComputation* computation) {
    return false;
  }

  return gpu_backend_config->fusion_backend_config().kind() ==
         kCustomFusionKind;
  if (gpu_backend_config->fusion_backend_config().kind() != kCustomFusionKind) {
    return false;
  }

  if (gpu_backend_config->fusion_backend_config()
          .custom_fusion_config()
          .name()
          .rfind("cutlass", 0) != 0) {
    return false;
  }

  return true;
}
} // namespace
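A note on the prefix check above: std::string::rfind(prefix, 0) searches backwards starting at position 0, so it can only match at index 0, which makes comparing the result against 0 the usual pre-C++20 "starts with" test. A standalone illustration with made-up fusion names:

#include <cassert>
#include <string>

// rfind(prefix, 0) can only match at index 0, so "== 0" means "starts with".
bool StartsWithCutlass(const std::string& name) {
  return name.rfind("cutlass", 0) == 0;
}

int main() {
  assert(StartsWithCutlass("cutlass_gemm"));      // illustrative names only
  assert(!StartsWithCutlass("triton_gemm"));
  assert(!StartsWithCutlass("my_cutlass_gemm"));  // "cutlass" not at index 0
  return 0;
}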

@@ -231,7 +241,7 @@ absl::StatusOr<bool> CustomKernelFusionAutotuner::Run(

  bool hlo_changed = false;
  for (const HloComputation* computation : module->computations()) {
    if (IsCustomFusion(computation)) {
    if (IsCutlassCustomFusion(computation)) {
      TF_ASSIGN_OR_RETURN(
          bool instruction_changed,
          AutotuneCustomKernelFusion(computation->FusionInstruction(), config_,