Skip to content

Commit

Permalink
[Codegen][Tuner] Allow tuning specs in the LLVMGPU pipeline (#19359)
Browse files Browse the repository at this point in the history
This adds the `materialize-tuning-specs` pass to the LLVMGPU executable
configuration pipelines.

Add a test that shows that the tuning spec gets applied and picked up in
the ROCDL pipeline.

Also, replace the print-based checks in existing tests with op remarks
on transform strategy application in `materialize-user-configs`.
  • Loading branch information
kuhar authored Dec 4, 2024
1 parent 62bccc9 commit 9f8aad8
Show file tree
Hide file tree
Showing 13 changed files with 112 additions and 8 deletions.
1 change: 1 addition & 0 deletions compiler/plugins/target/CUDA/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ iree_compiler_cc_library(
],
deps = [
"//compiler/src/iree/compiler/Codegen",
"//compiler/src/iree/compiler/Codegen/Common",
"//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
"//compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils:KnownTargets",
"//compiler/src/iree/compiler/Codegen/LLVMGPU",
Expand Down
1 change: 1 addition & 0 deletions compiler/plugins/target/CUDA/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ iree_cc_library(
MLIRTransformDialect
iree::base::internal::flatcc::building
iree::compiler::Codegen
iree::compiler::Codegen::Common
iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
iree::compiler::Codegen::Dialect::GPU::TargetUtils::KnownTargets
iree::compiler::Codegen::LLVMGPU
Expand Down
3 changes: 3 additions & 0 deletions compiler/plugins/target/CUDA/CUDATarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "./SetBlockIdsRangePass.h"
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
#include "iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.h"
#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
Expand Down Expand Up @@ -448,6 +449,8 @@ class CUDATargetBackend final : public TargetBackend {
mlir::registerBuiltinDialectTranslation(registry);
mlir::registerLLVMDialectTranslation(registry);
mlir::registerNVVMDialectTranslation(registry);
// Configuration may load and manipulate transform dialect libraries.
registerTransformDialectTranslationDependentDialects(registry);
}

void
Expand Down
1 change: 1 addition & 0 deletions compiler/plugins/target/ROCM/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ iree_compiler_cc_library(
],
deps = [
"//compiler/plugins/target/ROCM/builtins/ukernel:iree_uk_amdgpu_bitcode",
"//compiler/src/iree/compiler/Codegen/Common",
"//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
"//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
"//compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils:KnownTargets",
Expand Down
1 change: 1 addition & 0 deletions compiler/plugins/target/ROCM/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ iree_cc_library(
MLIRROCDLToLLVMIRTranslation
MLIRSupport
MLIRTargetLLVMIRExport
iree::compiler::Codegen::Common
iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
iree::compiler::Codegen::Dialect::GPU::TargetUtils::KnownTargets
Expand Down
3 changes: 3 additions & 0 deletions compiler/plugins/target/ROCM/ROCMTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <cstdint>

#include "compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_bitcode.h"
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
Expand Down Expand Up @@ -268,6 +269,8 @@ class ROCMTargetBackend final : public TargetBackend {
registry.insert<IREE::Codegen::IREECodegenDialect>();
registry.insert<IREE::VectorExt::IREEVectorExtDialect>();
registry.insert<IREE::GPU::IREEGPUDialect>();
// Configuration may load and manipulate transform dialect libraries.
registerTransformDialectTranslationDependentDialects(registry);
}

void
Expand Down
4 changes: 4 additions & 0 deletions compiler/plugins/target/ROCM/test/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,13 @@ iree_lit_test_suite(
name = "lit",
srcs = [
"gpu_lower_to_ukernels.mlir",
"lowering_strategy_from_tuning_spec.mlir",
"ukernel_pipeline_transform.mlir",
],
cfg = "//compiler:lit.cfg.py",
data = [
"tuning_spec_mmt_tile_and_fuse.mlir",
],
tools = [
"//tools:iree-opt",
"@llvm-project//llvm:FileCheck",
Expand Down
3 changes: 3 additions & 0 deletions compiler/plugins/target/ROCM/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@ iree_lit_test_suite(
lit
SRCS
"gpu_lower_to_ukernels.mlir"
"lowering_strategy_from_tuning_spec.mlir"
"ukernel_pipeline_transform.mlir"
TOOLS
FileCheck
iree-opt
DATA
tuning_spec_mmt_tile_and_fuse.mlir
)

### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx942 \
// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-hal-configure-target-executable-variants{target=rocm})))" \
// RUN: --iree-codegen-tuning-spec-path=%p/tuning_spec_mmt_tile_and_fuse.mlir \
// RUN: --iree-codegen-notify-transform-strategy-application \
// RUN: --verify-diagnostics %s | FileCheck %s

// Make sure we can apply the lowering strategy from the specified tuning spec.

// CHECK: #translation = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [128, 1, 1] subgroup_size = 64>
// CHECK: func.func @matmul_transpose_b
// CHECK-SAME: translation_info = #translation
// CHECK: linalg.generic
// CHECK-SAME: __tuning_spec_applied__
// CHECK-SAME: lowering_config = #iree_gpu.lowering_config<

#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>
]>
hal.executable public @main {
hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb">) {
hal.executable.export public @matmul_transpose_b ordinal(0) layout(#pipeline_layout) {
^bb0(%arg0: !hal.device):
%x, %y, %z = flow.dispatch.workgroup_count_from_slice
hal.return %x, %y, %z : index, index, index
}
builtin.module {
// expected-remark@+1 {{Applied transform configuration strategy @iree_linked_tuning_spec::@__kernel_config}}
func.func @matmul_transpose_b() {
%cst = arith.constant 0.000000e+00 : f16
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2048x1280xf16>>
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<10240x1280xf16>>
%2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2048x10240xf32>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [2048, 1280], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<2048x1280xf16>> -> tensor<2048x1280xf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [10240, 1280], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<10240x1280xf16>> -> tensor<10240x1280xf16>
%5 = tensor.empty() : tensor<2048x10240xf32>
%6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<2048x10240xf32>) -> tensor<2048x10240xf32>
%7 = linalg.matmul_transpose_b
ins(%3, %4 : tensor<2048x1280xf16>, tensor<10240x1280xf16>)
outs(%6 : tensor<2048x10240xf32>) -> tensor<2048x10240xf32>
flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [2048, 10240], strides = [1, 1] : tensor<2048x10240xf32> -> !flow.dispatch.tensor<writeonly:tensor<2048x10240xf32>>
return
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// RUN: iree-opt %s

module @mmt_tile_and_fuse_spec attributes { transform.with_named_sequence } {
transform.named_sequence @main(%arg0: !transform.any_op {transform.readonly}) -> ()
attributes { iree_codegen.tuning_spec_entrypoint } {
%mmt = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
// transform.print %mmt {name="MMT"} : !transform.any_op
%config = transform.param.constant #iree_codegen.compilation_info<
lowering_config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0],
reduction = [0, 0, 4],
thread = [8, 4],
promote_operands = [0, 1]}>,
translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse
workgroup_size = [128, 1, 1] subgroup_size = 64>
> -> !transform.any_param
transform.annotate %mmt "compilation_info" = %config : !transform.any_op, !transform.any_param
// Add a dummy unit attribute to be sure that the tuning spec applied.
// Otherwise it would be difficult to tell if the lowering config attribute
// comes from our tuning spec or if the compiler heuristic happened to produce
// the same config as this script.
transform.annotate %mmt "__tuning_spec_applied__" : !transform.any_op
transform.yield
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ llvm::cl::opt<std::string> clCodegenTransformDialectLibraryFileName(
"this will default to `__kernel_config`."),
llvm::cl::init(""));

llvm::cl::opt<bool> clCodegenNotifyTransformDialectLibraryApplication(
"iree-codegen-notify-transform-strategy-application",
llvm::cl::desc(
"Emit a remark when a transform configuration strategy successfully "
"applies on a function. This is intended for testing/debuging."),
llvm::cl::init(false));

#define GEN_PASS_DEF_MATERIALIZEUSERCONFIGSPASS
#include "iree/compiler/Codegen/Common/Passes.h.inc"

Expand Down Expand Up @@ -194,6 +201,9 @@ struct MaterializeUserConfigsPass final
// ```
LDBG("MaterializeUserConfigsPass on function: " << funcOp);
if (succeeded(userTransformLibrary)) {
StringRef libraryModuleName =
userTransformLibrary->transformLibrary.getSymName().value_or(
"<unnamed>");
StringRef entrySequenceName = userTransformLibrary->entrypointName;
auto runResult = runTransformConfigurationStrategy(
funcOp, entrySequenceName, userTransformLibrary->transformLibrary);
Expand All @@ -207,6 +217,12 @@ struct MaterializeUserConfigsPass final
<< entrySequenceName << "` failed to apply";
return signalPassFailure();
}

if (clCodegenNotifyTransformDialectLibraryApplication) {
funcOp->emitRemark()
<< "Applied transform configuration strategy @"
<< libraryModuleName << "::@" << entrySequenceName;
}
}

/// Nothing to do if the export already has a config.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
// RUN: iree-opt --pass-pipeline='builtin.module(builtin.module(iree-codegen-materialize-tuning-specs,iree-codegen-materialize-user-configs))' \
// RUN: --iree-codegen-tuning-spec-path=%p/tuning_spec.mlir \
// RUN: --mlir-disable-threading --no-implicit-module %s | FileCheck %s
// RUN: --iree-codegen-notify-transform-strategy-application \
// RUN: --no-implicit-module --verify-diagnostics %s | FileCheck %s

// RUN: iree-opt --pass-pipeline='builtin.module(iree-codegen-materialize-tuning-specs,builtin.module(iree-codegen-materialize-user-configs))' \
// RUN: --iree-codegen-tuning-spec-path=%p/tuning_spec.mlir \
// RUN: --mlir-disable-threading --no-implicit-module %s | FileCheck %s --check-prefix=PARENT
// RUN: --iree-codegen-notify-transform-strategy-application \
// RUN: --no-implicit-module --verify-diagnostics %s | FileCheck %s --check-prefix=PARENT

// (1) We start by running the `Materialize Tuning Specs` pass to embed the
// transform dialect library into the module. Doing it by hand hand is not
Expand All @@ -13,9 +15,6 @@
// Check that the transform spec gets executed and that it does not remain as
// a module attribute after `Materialize User Configs`.

// CHECK-LABEL: [ IR printer: Hello Tuning Spec top-level ]
// CHECK-NEXT: func.func @main_0
//
// CHECK-LABEL: module @parent {
// CHECK-LABEL: module @child {
// CHECK: func.func @main_0
Expand All @@ -25,16 +24,14 @@
// (conservatively) only remove tuning spec from the module passed
// to the `materialize-user-configs` pass.

// PARENT-LABEL: [ IR printer: Hello Tuning Spec top-level ]
// PARENT-NEXT: func.func @main_0
//
// PARENT-LABEL: module @parent attributes {
// PARENT-SAME: iree_codegen.tuning_spec_mlirbc = dense<
// PARENT-LABEL: module @child {
// PARENT: func.func @main_0

module @parent {
module @child {
// expected-remark@+1 {{Applied transform configuration strategy @iree_linked_tuning_spec::@__kernel_config}}
func.func @main_0() {
return
}
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1178,6 +1178,7 @@ static void buildLLVMGPUCodegenConfigurationPassPipelineImpl(
funcPassManager.addPass(createConfigTrackingCanonicalizerPass);
funcPassManager.addPass(createCSEPass);
}
modulePassManager.addPass(createMaterializeTuningSpecsPass());
modulePassManager.addPass(createMaterializeUserConfigsPass());
modulePassManager.addPass(createLLVMGPUSelectLoweringStrategyPass());
}
Expand Down Expand Up @@ -1245,6 +1246,7 @@ static void buildROCDLCodegenConfigurationPassPipelineImpl(
funcPassManager.addPass(createGPUGeneralizeNamedOpsPass);
addCommonTargetExecutablePreprocessingPasses(funcPassManager);
}
modulePassManager.addPass(createMaterializeTuningSpecsPass());
modulePassManager.addPass(createMaterializeUserConfigsPass());

modulePassManager.addPass(createROCDLSelectLoweringStrategyPass());
Expand Down

0 comments on commit 9f8aad8

Please sign in to comment.