Skip to content

Commit

Permalink
Reverts 67325a0
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 675043690
  • Loading branch information
derdrdirk authored and Google-ML-Automation committed Sep 16, 2024
1 parent e80c23e commit 1765b13
Show file tree
Hide file tree
Showing 10 changed files with 97 additions and 362 deletions.
7 changes: 1 addition & 6 deletions xla/autotuning.proto
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,6 @@ message AutotuneResult {
int64 num_ctas = 7;
}

message CustomKernelFusionKey {
int64 kernel_index = 1;
}

int64 scratch_bytes = 8;
google.protobuf.Duration run_time = 9;

Expand All @@ -97,11 +93,10 @@ message AutotuneResult {
GemmKey gemm = 6;
TritonGemmKey triton = 17;
CudaConvPlanKey cuda_conv_plan = 15;
CustomKernelFusionKey custom_kernel_fusion = 18;
stream_executor.dnn.AlgorithmProto algorithm = 16;
}

// Next ID: 19
// Next ID: 17
}

message AutotuningLog {
Expand Down
1 change: 0 additions & 1 deletion xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1627,7 +1627,6 @@ xla_test(
"//xla/tests:filecheck",
"//xla/tests:hlo_test_base",
"//xla/tests:literal_test_util",
"//xla/tests:verified_hlo_module",
"//xla/tests:xla_internal_test_main",
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/log",
Expand Down
9 changes: 3 additions & 6 deletions xla/service/gpu/autotuning/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,12 @@ cc_library(
"//xla/service/gpu:backend_configs_cc",
"//xla/service/gpu:buffer_comparator",
"//xla/service/gpu:gpu_float_support",
"//xla/service/gpu:hlo_traversal",
"//xla/service/gpu:ir_emission_utils",
"//xla/service/gpu:matmul_utils",
"//xla/service/gpu:split_k_gemm_rewriter",
"//xla/service/gpu:stream_executor_util",
"//xla/service/gpu/kernels:custom_kernel",
"//xla/service/gpu/kernels:custom_kernel_fusion",
"//xla/service/gpu/kernels:custom_kernel_fusion_pattern",
"//xla/service/gpu/transforms:cudnn_fusion_compiler",
"//xla/service/gpu/transforms:custom_kernel_fusion_rewriter",
"//xla/service/gpu/transforms:fusion_wrapper",
"//xla/service/gpu/transforms:gemm_rewriter",
"//xla/service/gpu/transforms:priority_fusion",
Expand All @@ -75,9 +72,11 @@ cc_library(
"//xla/stream_executor:device_memory",
"//xla/stream_executor:semantic_version",
"//xla/stream_executor:stream_executor_memory_allocator",
"//xla/stream_executor/gpu:redzone_allocator",
"//xla/tools:hlo_decomposer_lib",
"//xla/tsl/lib/core:bits",
"//xla/tsl/util/proto:proto_utils",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log",
Expand Down Expand Up @@ -138,8 +137,6 @@ xla_test(
"//xla/stream_executor:device_description",
"//xla/stream_executor:device_description_proto_cc",
"//xla/stream_executor:semantic_version",
"//xla/stream_executor:stream_executor_h",
"//xla/stream_executor/gpu:gpu_executor_header",
"//xla/tests:filecheck",
"//xla/tests:hlo_test_base",
"//xla/tests:test_utils",
Expand Down
18 changes: 4 additions & 14 deletions xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ absl::StatusOr<bool> AutotuneCustomKernelFusion(
return previous_kernel_index != fastest_kernel_index;
}

bool IsCutlassCustomFusion(const HloComputation* computation) {
bool IsCustomFusion(const HloComputation* computation) {
if (!computation->IsFusionComputation()) {
return false;
}
Expand All @@ -212,18 +212,8 @@ bool IsCutlassCustomFusion(const HloComputation* computation) {
return false;
}

if (gpu_backend_config->fusion_backend_config().kind() != kCustomFusionKind) {
return false;
}

if (gpu_backend_config->fusion_backend_config()
.custom_fusion_config()
.name()
.rfind("cutlass", 0) != 0) {
return false;
}

return true;
return gpu_backend_config->fusion_backend_config().kind() ==
kCustomFusionKind;
}
} // namespace

Expand All @@ -241,7 +231,7 @@ absl::StatusOr<bool> CustomKernelFusionAutotuner::Run(

bool hlo_changed = false;
for (const HloComputation* computation : module->computations()) {
if (IsCutlassCustomFusion(computation)) {
if (IsCustomFusion(computation)) {
TF_ASSIGN_OR_RETURN(
bool instruction_changed,
AutotuneCustomKernelFusion(computation->FusionInstruction(), config_,
Expand Down
Loading

0 comments on commit 1765b13

Please sign in to comment.