From 9f8aad85b469b224bb9594d002f35bd8febebbf9 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Wed, 4 Dec 2024 13:09:28 -0500 Subject: [PATCH] [Codegen][Tuner] Allow tuning specs in the LLVMGPU pipeline (#19359) This adds the `materialize-tuning-specs` pass to the LLVMGPU executable configuration pipelines. Add a test that shows that the tuning spec gets applied and picked up in the ROCDL pipeline. Also, replace the print-based checks in existing tests with op remarks on transform strategy application in `materialize-user-configs`. --- compiler/plugins/target/CUDA/BUILD.bazel | 1 + compiler/plugins/target/CUDA/CMakeLists.txt | 1 + compiler/plugins/target/CUDA/CUDATarget.cpp | 3 ++ compiler/plugins/target/ROCM/BUILD.bazel | 1 + compiler/plugins/target/ROCM/CMakeLists.txt | 1 + compiler/plugins/target/ROCM/ROCMTarget.cpp | 3 ++ compiler/plugins/target/ROCM/test/BUILD.bazel | 4 ++ .../plugins/target/ROCM/test/CMakeLists.txt | 3 ++ .../lowering_strategy_from_tuning_spec.mlir | 48 +++++++++++++++++++ .../test/tuning_spec_mmt_tile_and_fuse.mlir | 24 ++++++++++ .../Codegen/Common/MaterializeUserConfigs.cpp | 16 +++++++ ...erialize_user_config_from_tuning_spec.mlir | 13 ++--- .../iree/compiler/Codegen/LLVMGPU/Passes.cpp | 2 + 13 files changed, 112 insertions(+), 8 deletions(-) create mode 100644 compiler/plugins/target/ROCM/test/lowering_strategy_from_tuning_spec.mlir create mode 100644 compiler/plugins/target/ROCM/test/tuning_spec_mmt_tile_and_fuse.mlir diff --git a/compiler/plugins/target/CUDA/BUILD.bazel b/compiler/plugins/target/CUDA/BUILD.bazel index b694187f7325..2af2c29883bc 100644 --- a/compiler/plugins/target/CUDA/BUILD.bazel +++ b/compiler/plugins/target/CUDA/BUILD.bazel @@ -28,6 +28,7 @@ iree_compiler_cc_library( ], deps = [ "//compiler/src/iree/compiler/Codegen", + "//compiler/src/iree/compiler/Codegen/Common", "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect", 
"//compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils:KnownTargets", "//compiler/src/iree/compiler/Codegen/LLVMGPU", diff --git a/compiler/plugins/target/CUDA/CMakeLists.txt b/compiler/plugins/target/CUDA/CMakeLists.txt index 70c6dc6b8a5b..e3e86c00e54f 100644 --- a/compiler/plugins/target/CUDA/CMakeLists.txt +++ b/compiler/plugins/target/CUDA/CMakeLists.txt @@ -52,6 +52,7 @@ iree_cc_library( MLIRTransformDialect iree::base::internal::flatcc::building iree::compiler::Codegen + iree::compiler::Codegen::Common iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect iree::compiler::Codegen::Dialect::GPU::TargetUtils::KnownTargets iree::compiler::Codegen::LLVMGPU diff --git a/compiler/plugins/target/CUDA/CUDATarget.cpp b/compiler/plugins/target/CUDA/CUDATarget.cpp index ffc49b57fa7d..fe41cb44f8f4 100644 --- a/compiler/plugins/target/CUDA/CUDATarget.cpp +++ b/compiler/plugins/target/CUDA/CUDATarget.cpp @@ -5,6 +5,7 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception #include "./SetBlockIdsRangePass.h" +#include "iree/compiler/Codegen/Common/Passes.h" #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h" #include "iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.h" #include "iree/compiler/Codegen/LLVMGPU/Passes.h" @@ -448,6 +449,8 @@ class CUDATargetBackend final : public TargetBackend { mlir::registerBuiltinDialectTranslation(registry); mlir::registerLLVMDialectTranslation(registry); mlir::registerNVVMDialectTranslation(registry); + // Configuration may load and manipulate transform dialect libraries. 
+ registerTransformDialectTranslationDependentDialects(registry); } void diff --git a/compiler/plugins/target/ROCM/BUILD.bazel b/compiler/plugins/target/ROCM/BUILD.bazel index 48dfeb3ff401..682806e23539 100644 --- a/compiler/plugins/target/ROCM/BUILD.bazel +++ b/compiler/plugins/target/ROCM/BUILD.bazel @@ -28,6 +28,7 @@ iree_compiler_cc_library( ], deps = [ "//compiler/plugins/target/ROCM/builtins/ukernel:iree_uk_amdgpu_bitcode", + "//compiler/src/iree/compiler/Codegen/Common", "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect", "//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect", "//compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils:KnownTargets", diff --git a/compiler/plugins/target/ROCM/CMakeLists.txt b/compiler/plugins/target/ROCM/CMakeLists.txt index 96c3305d936d..69204abd3d13 100644 --- a/compiler/plugins/target/ROCM/CMakeLists.txt +++ b/compiler/plugins/target/ROCM/CMakeLists.txt @@ -52,6 +52,7 @@ iree_cc_library( MLIRROCDLToLLVMIRTranslation MLIRSupport MLIRTargetLLVMIRExport + iree::compiler::Codegen::Common iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect iree::compiler::Codegen::Dialect::GPU::TargetUtils::KnownTargets diff --git a/compiler/plugins/target/ROCM/ROCMTarget.cpp b/compiler/plugins/target/ROCM/ROCMTarget.cpp index 48ef62e07220..c175ab02029a 100644 --- a/compiler/plugins/target/ROCM/ROCMTarget.cpp +++ b/compiler/plugins/target/ROCM/ROCMTarget.cpp @@ -9,6 +9,7 @@ #include #include "compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_bitcode.h" +#include "iree/compiler/Codegen/Common/Passes.h" #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h" #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h" #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h" @@ -268,6 +269,8 @@ class ROCMTargetBackend final : public TargetBackend { registry.insert(); registry.insert(); 
registry.insert(); + // Configuration may load and manipulate transform dialect libraries. + registerTransformDialectTranslationDependentDialects(registry); } void diff --git a/compiler/plugins/target/ROCM/test/BUILD.bazel b/compiler/plugins/target/ROCM/test/BUILD.bazel index bf9a18d582bd..ebf4dfd7463e 100644 --- a/compiler/plugins/target/ROCM/test/BUILD.bazel +++ b/compiler/plugins/target/ROCM/test/BUILD.bazel @@ -16,9 +16,13 @@ iree_lit_test_suite( name = "lit", srcs = [ "gpu_lower_to_ukernels.mlir", + "lowering_strategy_from_tuning_spec.mlir", "ukernel_pipeline_transform.mlir", ], cfg = "//compiler:lit.cfg.py", + data = [ + "tuning_spec_mmt_tile_and_fuse.mlir", + ], tools = [ "//tools:iree-opt", "@llvm-project//llvm:FileCheck", diff --git a/compiler/plugins/target/ROCM/test/CMakeLists.txt b/compiler/plugins/target/ROCM/test/CMakeLists.txt index 6d2199d8c4bb..38158aac8c5b 100644 --- a/compiler/plugins/target/ROCM/test/CMakeLists.txt +++ b/compiler/plugins/target/ROCM/test/CMakeLists.txt @@ -15,10 +15,13 @@ iree_lit_test_suite( lit SRCS "gpu_lower_to_ukernels.mlir" + "lowering_strategy_from_tuning_spec.mlir" "ukernel_pipeline_transform.mlir" TOOLS FileCheck iree-opt + DATA + tuning_spec_mmt_tile_and_fuse.mlir ) ### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/compiler/plugins/target/ROCM/test/lowering_strategy_from_tuning_spec.mlir b/compiler/plugins/target/ROCM/test/lowering_strategy_from_tuning_spec.mlir new file mode 100644 index 000000000000..6f7cf092242e --- /dev/null +++ b/compiler/plugins/target/ROCM/test/lowering_strategy_from_tuning_spec.mlir @@ -0,0 +1,48 @@ +// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx942 \ +// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-hal-configure-target-executable-variants{target=rocm})))" \ +// RUN: --iree-codegen-tuning-spec-path=%p/tuning_spec_mmt_tile_and_fuse.mlir \ +// RUN: --iree-codegen-notify-transform-strategy-application \ +// RUN: 
--verify-diagnostics %s | FileCheck %s + +// Make sure we can apply the lowering strategy from the specified tuning spec. + +// CHECK: #translation = #iree_codegen.translation_info +// CHECK: func.func @matmul_transpose_b +// CHECK-SAME: translation_info = #translation +// CHECK: linalg.generic +// CHECK-SAME: __tuning_spec_applied__ +// CHECK-SAME: lowering_config = #iree_gpu.lowering_config< + +#pipeline_layout = #hal.pipeline.layout, + #hal.pipeline.binding, + #hal.pipeline.binding +]> +hal.executable public @main { + hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb">) { + hal.executable.export public @matmul_transpose_b ordinal(0) layout(#pipeline_layout) { + ^bb0(%arg0: !hal.device): + %x, %y, %z = flow.dispatch.workgroup_count_from_slice + hal.return %x, %y, %z : index, index, index + } + builtin.module { + // expected-remark@+1 {{Applied transform configuration strategy @iree_linked_tuning_spec::@__kernel_config}} + func.func @matmul_transpose_b() { + %cst = arith.constant 0.000000e+00 : f16 + %c0 = arith.constant 0 : index + %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> + %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [2048, 1280], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<2048x1280xf16> + %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [10240, 1280], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<10240x1280xf16> + %5 = tensor.empty() : tensor<2048x10240xf32> + %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<2048x10240xf32>) -> tensor<2048x10240xf32> + %7 = linalg.matmul_transpose_b + ins(%3, %4 : tensor<2048x1280xf16>, 
tensor<10240x1280xf16>) + outs(%6 : tensor<2048x10240xf32>) -> tensor<2048x10240xf32> + flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [2048, 10240], strides = [1, 1] : tensor<2048x10240xf32> -> !flow.dispatch.tensor> + return + } + } + } +} diff --git a/compiler/plugins/target/ROCM/test/tuning_spec_mmt_tile_and_fuse.mlir b/compiler/plugins/target/ROCM/test/tuning_spec_mmt_tile_and_fuse.mlir new file mode 100644 index 000000000000..24f0c3a200ad --- /dev/null +++ b/compiler/plugins/target/ROCM/test/tuning_spec_mmt_tile_and_fuse.mlir @@ -0,0 +1,24 @@ +// RUN: iree-opt %s + +module @mmt_tile_and_fuse_spec attributes { transform.with_named_sequence } { + transform.named_sequence @main(%arg0: !transform.any_op {transform.readonly}) -> () + attributes { iree_codegen.tuning_spec_entrypoint } { + %mmt = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op + // transform.print %mmt {name="MMT"} : !transform.any_op + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], + reduction = [0, 0, 4], + thread = [8, 4], + promote_operands = [0, 1]}>, + translation_info = #iree_codegen.translation_info + > -> !transform.any_param + transform.annotate %mmt "compilation_info" = %config : !transform.any_op, !transform.any_param + // Add a dummy unit attribute to be sure that the tuning spec applied. + // Otherwise it would be difficult to tell if the lowering config attribute + // comes from our tuning spec or if the compiler heuristic happened to produce + // the same config as this script. 
+ transform.annotate %mmt "__tuning_spec_applied__" : !transform.any_op + transform.yield + } +} diff --git a/compiler/src/iree/compiler/Codegen/Common/MaterializeUserConfigs.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeUserConfigs.cpp index 21fee4a3f065..92e719b76dbd 100644 --- a/compiler/src/iree/compiler/Codegen/Common/MaterializeUserConfigs.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeUserConfigs.cpp @@ -32,6 +32,13 @@ llvm::cl::opt clCodegenTransformDialectLibraryFileName( "this will default to `__kernel_config`."), llvm::cl::init("")); +llvm::cl::opt<bool> clCodegenNotifyTransformDialectLibraryApplication( + "iree-codegen-notify-transform-strategy-application", + llvm::cl::desc( + "Emit a remark when a transform configuration strategy successfully " + "applies on a function. This is intended for testing/debugging."), + llvm::cl::init(false)); + #define GEN_PASS_DEF_MATERIALIZEUSERCONFIGSPASS #include "iree/compiler/Codegen/Common/Passes.h.inc" @@ -194,6 +201,9 @@ struct MaterializeUserConfigsPass final // ``` LDBG("MaterializeUserConfigsPass on function: " << funcOp); if (succeeded(userTransformLibrary)) { + StringRef libraryModuleName = + userTransformLibrary->transformLibrary.getSymName().value_or( + ""); StringRef entrySequenceName = userTransformLibrary->entrypointName; auto runResult = runTransformConfigurationStrategy( funcOp, entrySequenceName, userTransformLibrary->transformLibrary); @@ -207,6 +217,12 @@ struct MaterializeUserConfigsPass final << entrySequenceName << "` failed to apply"; return signalPassFailure(); } + + if (clCodegenNotifyTransformDialectLibraryApplication) { + funcOp->emitRemark() + << "Applied transform configuration strategy @" + << libraryModuleName << "::@" << entrySequenceName; + } } /// Nothing to do if the export already has a config. 
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/materialize_user_config_from_tuning_spec.mlir b/compiler/src/iree/compiler/Codegen/Common/test/materialize_user_config_from_tuning_spec.mlir index 08f52791de3f..4e4bba81f056 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/materialize_user_config_from_tuning_spec.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/materialize_user_config_from_tuning_spec.mlir @@ -1,10 +1,12 @@ // RUN: iree-opt --pass-pipeline='builtin.module(builtin.module(iree-codegen-materialize-tuning-specs,iree-codegen-materialize-user-configs))' \ // RUN: --iree-codegen-tuning-spec-path=%p/tuning_spec.mlir \ -// RUN: --mlir-disable-threading --no-implicit-module %s | FileCheck %s +// RUN: --iree-codegen-notify-transform-strategy-application \ +// RUN: --no-implicit-module --verify-diagnostics %s | FileCheck %s // RUN: iree-opt --pass-pipeline='builtin.module(iree-codegen-materialize-tuning-specs,builtin.module(iree-codegen-materialize-user-configs))' \ // RUN: --iree-codegen-tuning-spec-path=%p/tuning_spec.mlir \ -// RUN: --mlir-disable-threading --no-implicit-module %s | FileCheck %s --check-prefix=PARENT +// RUN: --iree-codegen-notify-transform-strategy-application \ +// RUN: --no-implicit-module --verify-diagnostics %s | FileCheck %s --check-prefix=PARENT // (1) We start by running the `Materialize Tuning Specs` pass to embed the // transform dialect library into the module. Doing it by hand hand is not @@ -13,9 +15,6 @@ // Check that the transform spec gets executed and that it does not remain as // a module attribute after `Materialize User Configs`. -// CHECK-LABEL: [ IR printer: Hello Tuning Spec top-level ] -// CHECK-NEXT: func.func @main_0 -// // CHECK-LABEL: module @parent { // CHECK-LABEL: module @child { // CHECK: func.func @main_0 @@ -25,9 +24,6 @@ // (conservatively) only remove tuning spec from the module passed // to the `materialize-user-configs` pass. 
-// PARENT-LABEL: [ IR printer: Hello Tuning Spec top-level ] -// PARENT-NEXT: func.func @main_0 -// // PARENT-LABEL: module @parent attributes { // PARENT-SAME: iree_codegen.tuning_spec_mlirbc = dense< // PARENT-LABEL: module @child { @@ -35,6 +31,7 @@ module @parent { module @child { + // expected-remark@+1 {{Applied transform configuration strategy @iree_linked_tuning_spec::@__kernel_config}} func.func @main_0() { return } diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp index b79abdd0eb19..53e49efbf66a 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp @@ -1178,6 +1178,7 @@ static void buildLLVMGPUCodegenConfigurationPassPipelineImpl( funcPassManager.addPass(createConfigTrackingCanonicalizerPass); funcPassManager.addPass(createCSEPass); } + modulePassManager.addPass(createMaterializeTuningSpecsPass()); modulePassManager.addPass(createMaterializeUserConfigsPass()); modulePassManager.addPass(createLLVMGPUSelectLoweringStrategyPass()); } @@ -1245,6 +1246,7 @@ static void buildROCDLCodegenConfigurationPassPipelineImpl( funcPassManager.addPass(createGPUGeneralizeNamedOpsPass); addCommonTargetExecutablePreprocessingPasses(funcPassManager); } + modulePassManager.addPass(createMaterializeTuningSpecsPass()); modulePassManager.addPass(createMaterializeUserConfigsPass()); modulePassManager.addPass(createROCDLSelectLoweringStrategyPass());