[Codegen][Tuner] Make default and user-provided specs work (iree-org#19449)

This is a fixup to the tuning spec materialization that makes default
and user-provided specs work e2e. As an example, a working spec for
`linalg.matmul_transpose_b` is provided for gfx942.

* Allow for tuning spec entry points that consume their argument op.
This is so that tuning specs can use `transform.foreach_match`.
* Require all tuning spec entry points to return `any_op`, so that we
  can chain includes. This works for both consumed and readonly args
  (see the sketch after this list).
* Add a test to show that user-provided tuning specs take precedence
over default ones.
* Work around a transform interpreter bug when multiple named sequences
  across different modules share the same symbol name.
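
A minimal sketch of the entry-point shape these rules expect; the matcher and
action names below are placeholders, not taken from this change:

transform.named_sequence @__kernel_config(%variant_op: !transform.any_op {transform.consumed}) -> !transform.any_op
    attributes { iree_codegen.tuning_spec_entrypoint } {
  %res = transform.foreach_match in %variant_op
      @match_my_op -> @apply_my_config
    : (!transform.any_op) -> !transform.any_op
  transform.yield %res : !transform.any_op
}

A readonly argument (`transform.readonly`) remains supported, as long as the
entry point still returns `!transform.any_op`.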

Issue: iree-org#19214

---------

Signed-off-by: Jakub Kuderski <[email protected]>
ScottTodd authored Dec 11, 2024
2 parents 274977c + e3b56d7 commit 6b7ca46
Showing 7 changed files with 240 additions and 70 deletions.
@@ -7,9 +7,68 @@

module @iree_default_tuning_spec_gfx942 attributes { transform.with_named_sequence } {

transform.named_sequence @__kernel_config(%variant_op: !transform.any_op {transform.readonly}) -> ()
attributes { iree_codegen.tuning_spec_entrypoint } {
transform.named_sequence @apply_op_config(%op: !transform.any_op {transform.readonly},
%config: !transform.any_param {transform.readonly}) {
// transform.print %op {name="Apply on"} : !transform.any_op
transform.annotate %op "compilation_info" = %config : !transform.any_op, !transform.any_param
// Add a dummy unit attribute to be sure that the tuning spec applied.
// Otherwise it would be difficult to tell if the lowering config attribute
// comes from our tuning spec or if the compiler heuristic happened to produce
// the same config as this script.
transform.annotate %op "__tuning_spec_applied__" : !transform.any_op
transform.yield
}

transform.named_sequence @match_mmt_f16_f16_f32(%root: !transform.any_op {transform.readonly}) -> !transform.any_op {
transform.match.operation_name %root ["linalg.generic"] : !transform.any_op
// transform.print %root {name = "Generic"} : !transform.any_op
%ins, %outs = transform.iree.match.cast_compatible_dag_from_root %root {
^bb0(%lhs: tensor<?x?xf16>, %rhs: tensor<?x?xf16>, %out: tensor<?x?xf32>):
%7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
affine_map<(d0, d1, d2) -> (d1, d2)>,
affine_map<(d0, d1, d2) -> (d0, d1)>],
iterator_types = ["parallel", "parallel", "reduction"]}
ins(%lhs, %rhs : tensor<?x?xf16>, tensor<?x?xf16>) outs(%out : tensor<?x?xf32>) {
^bb0(%in: f16, %in_0: f16, %acc: f32):
%8 = arith.extf %in : f16 to f32
%9 = arith.extf %in_0 : f16 to f32
%10 = arith.mulf %8, %9 : f32
%11 = arith.addf %acc, %10 : f32
linalg.yield %11 : f32
} -> tensor<?x?xf32>
} : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
transform.yield %root : !transform.any_op
}

transform.named_sequence
@match_mmt_2048x1280x5120_f16_f16_f32(%matmul: !transform.any_op {transform.readonly})
-> (!transform.any_op, !transform.any_param) {
%mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul)
: (!transform.any_op) -> !transform.any_op
%lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value
%rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value
transform.iree.match.cast_compatible_type %lhs = tensor<2048x5120xf16> : !transform.any_value
transform.iree.match.cast_compatible_type %rhs = tensor<1280x5120xf16> : !transform.any_value
%config = transform.param.constant #iree_codegen.compilation_info<
lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1],
mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
subgroup_m_count = 2, subgroup_n_count = 2,
reduction = [0, 0, 64],
workgroup = [64, 128, 0]}>,
translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute
workgroup_size = [256, 1, 1] subgroup_size = 64,
{gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true>}>
> -> !transform.any_param
transform.yield %matmul, %config : !transform.any_op, !transform.any_param
}

transform.named_sequence
@__kernel_config(%variant_op: !transform.any_op {transform.consumed}) -> !transform.any_op
attributes { iree_codegen.tuning_spec_entrypoint } {
%res = transform.foreach_match in %variant_op
@match_mmt_2048x1280x5120_f16_f16_f32 -> @apply_op_config
: (!transform.any_op) -> !transform.any_op
transform.yield %res : !transform.any_op
}

}
24 changes: 21 additions & 3 deletions compiler/plugins/target/ROCM/builtins/tuning/test/spec_gfx942.mlir
@@ -4,7 +4,12 @@
// RUN: --iree-codegen-notify-transform-strategy-application \
// RUN: --verify-diagnostics %s | FileCheck %s

// CHECK-LABEL: func.func @placeholder
// Check that the default configuration for mmt_2048x1280x5120_f16_f16_f32
// applies to the `linalg.matmul_transpose_b` below.

// CHECK-LABEL: func.func @mmt_2048x1280x5120_f16_f16_f32
// CHECK: linalg.generic
// CHECK-SAME: __tuning_spec_applied__

#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
@@ -13,14 +18,27 @@
]>
hal.executable public @main {
hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb">) {
hal.executable.export public @placeholder ordinal(0) layout(#pipeline_layout) {
hal.executable.export public @matmul_transpose_b ordinal(0) layout(#pipeline_layout) {
^bb0(%arg0: !hal.device):
%x, %y, %z = flow.dispatch.workgroup_count_from_slice
hal.return %x, %y, %z : index, index, index
}
builtin.module {
// expected-remark@+1 {{Applied transform configuration strategy @iree_default_tuning_spec_gfx942::@__kernel_config}}
func.func @placeholder() {
func.func @mmt_2048x1280x5120_f16_f16_f32() {
%cst = arith.constant 0.000000e+00 : f16
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2048x5120xf16>>
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<1280x5120xf16>>
%2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2048x1280xf32>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [2048, 5120], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<2048x5120xf16>> -> tensor<2048x5120xf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [1280, 5120], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<1280x5120xf16>> -> tensor<1280x5120xf16>
%5 = tensor.empty() : tensor<2048x1280xf32>
%6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>
%7 = linalg.matmul_transpose_b
ins(%3, %4 : tensor<2048x5120xf16>, tensor<1280x5120xf16>)
outs(%6 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>
flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [2048, 1280], strides = [1, 1] : tensor<2048x1280xf32> -> !flow.dispatch.tensor<writeonly:tensor<2048x1280xf32>>
return
}
}
@@ -4,13 +4,20 @@
// RUN: --iree-codegen-notify-transform-strategy-application \
// RUN: --verify-diagnostics %s | FileCheck %s

// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx942 \
// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-hal-configure-target-executable-variants{target=rocm})))" \
// RUN: --iree-codegen-tuning-spec-path=%p/tuning_spec_mmt_tile_and_fuse.mlir \
// RUN: --iree-codegen-enable-default-tuning-specs \
// RUN: --iree-codegen-notify-transform-strategy-application \
// RUN: --verify-diagnostics %s | FileCheck %s

// Make sure we can apply the lowering strategy from the specified tuning spec.

// CHECK: #translation = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [128, 1, 1] subgroup_size = 64>
// CHECK: func.func @matmul_transpose_b
// CHECK-SAME: translation_info = #translation
// CHECK: linalg.generic
// CHECK-SAME: __tuning_spec_applied__
// CHECK-SAME: __custom_tuning_spec_applied__
// CHECK-SAME: lowering_config = #iree_gpu.lowering_config<

#pipeline_layout = #hal.pipeline.layout<bindings = [
@@ -1,24 +1,33 @@
// RUN: iree-opt %s

module @mmt_tile_and_fuse_spec attributes { transform.with_named_sequence } {
transform.named_sequence @main(%arg0: !transform.any_op {transform.readonly}) -> ()
attributes { iree_codegen.tuning_spec_entrypoint } {
%mmt = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
// transform.print %mmt {name="MMT"} : !transform.any_op
%config = transform.param.constant #iree_codegen.compilation_info<
lowering_config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0],
reduction = [0, 0, 4],
thread = [8, 4],
promote_operands = [0, 1]}>,
translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse
workgroup_size = [128, 1, 1] subgroup_size = 64>
> -> !transform.any_param
transform.annotate %mmt "compilation_info" = %config : !transform.any_op, !transform.any_param
// Add a dummy unit attribute to be sure that the tuning spec applied.
// Otherwise it would be difficult to tell if the lowering config attribute
// comes from our tuning spec or if the compiler heuristic happened to produce
// the same config as this script.
transform.annotate %mmt "__tuning_spec_applied__" : !transform.any_op
transform.yield
}
transform.named_sequence @apply_op_config(%op: !transform.any_op {transform.readonly},
%config: !transform.any_param {transform.readonly}) {
transform.annotate %op "compilation_info" = %config : !transform.any_op, !transform.any_param
transform.annotate %op "__custom_tuning_spec_applied__" : !transform.any_op
transform.yield
}

transform.named_sequence @match_mmt(%matmul: !transform.any_op {transform.readonly})
-> (!transform.any_op, !transform.any_param) {
transform.match.operation_name %matmul ["linalg.generic"] : !transform.any_op
%config = transform.param.constant #iree_codegen.compilation_info<
lowering_config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0],
reduction = [0, 0, 4],
thread = [8, 4],
promote_operands = [0, 1]}>,
translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse
workgroup_size = [128, 1, 1] subgroup_size = 64>
> -> !transform.any_param
transform.yield %matmul, %config : !transform.any_op, !transform.any_param
}

transform.named_sequence @main(%variant_op: !transform.any_op {transform.consumed}) -> (!transform.any_op)
attributes { iree_codegen.tuning_spec_entrypoint } {
transform.print %variant_op {name="Custom spec"} : !transform.any_op
%res = transform.foreach_match in %variant_op
@match_mmt -> @apply_op_config
: (!transform.any_op) -> !transform.any_op
transform.yield %res : !transform.any_op
}
}
79 changes: 59 additions & 20 deletions compiler/src/iree/compiler/Codegen/Common/LinkTuningSpecsPass.cpp
@@ -8,8 +8,8 @@
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
@@ -31,6 +31,10 @@ namespace mlir::iree_compiler {
namespace {

using mlir::transform::NamedSequenceOp;
constexpr StringLiteral kArgConsumedAttrName =
mlir::transform::TransformDialect::kArgConsumedAttrName;
constexpr StringLiteral kArgReadOnlyAttrName =
mlir::transform::TransformDialect::kArgReadOnlyAttrName;

static SmallVector<ModuleOp>
findNestedModulesWithNamedSequences(ModuleOp module) {
@@ -49,50 +53,67 @@ static SmallVector<NamedSequenceOp> findTuningSpecs(ModuleOp module) {
});
}

// Returns true iff the entrypoint has the following signature:
// ```
// transform.named_sequence @name(%arg0: !transform.any_op) ->
// (!transform.any_op)
// ```
static LogicalResult validateTuningSpec(NamedSequenceOp op) {
if (!op.getResultTypes().empty()) {
op->emitWarning() << "Tuning spec expected to have no results";
return failure();
ArrayRef<Type> resTypes = op.getFunctionType().getResults();
if (resTypes.size() != 1 || !isa<transform::AnyOpType>(resTypes[0])) {
return op.emitWarning()
<< "Tuning spec entry point expected to return any_op";
}

ArrayRef<Type> argTypes = op.getArgumentTypes();
if (argTypes.size() != 1 || !isa<transform::AnyOpType>(argTypes[0])) {
op->emitWarning() << "Tuning spec expected to have one argument of type "
"'!transform.any_op'";
return failure();
}

if (!op.getArgAttr(0, transform::TransformDialect::kArgReadOnlyAttrName)) {
op->emitWarning() << "Tuning spec expected to have one readonly argument";
return failure();
return op.emitWarning() << "Tuning spec entry point expected to have a "
"single any_op argument";
}

return success();
}

static bool consumesInputOp(NamedSequenceOp op) {
if (op.getArgAttr(0, kArgConsumedAttrName)) {
return true;
}
return false;
}

static NamedSequenceOp
emitLinkedTuningSpec(ModuleOp module, ArrayRef<NamedSequenceOp> specsToLink) {
OpBuilder builder(module->getContext());
builder.setInsertionPointToEnd(module.getBody());

const bool hasConsumedSequences = llvm::any_of(specsToLink, consumesInputOp);
Location loc = builder.getFusedLoc(llvm::map_to_vector(
specsToLink, [](NamedSequenceOp op) { return op->getLoc(); }));
FunctionType specType = builder.getFunctionType(
TypeRange{builder.getType<transform::AnyOpType>()}, TypeRange{});
Type anyOpType = builder.getType<transform::AnyOpType>();
FunctionType specType =
builder.getFunctionType(TypeRange{anyOpType}, TypeRange{anyOpType});
auto newSpec = builder.create<NamedSequenceOp>(
loc, kKernelConfigSpecName, TypeAttr::get(specType),
/*sym_visibility=*/StringAttr{},
/*arg_attrs=*/ArrayAttr{},
/*res_attrs*/ ArrayAttr{});
newSpec.setArgAttr(0, transform::TransformDialect::kArgReadOnlyAttrName,
builder.getUnitAttr());
newSpec.setArgAttr(
0, hasConsumedSequences ? kArgConsumedAttrName : kArgReadOnlyAttrName,
builder.getUnitAttr());
newSpec->setAttr(kTuningSpecEntrypointAttrName, builder.getUnitAttr());

Region &region = newSpec.getRegion();
Block *body = builder.createBlock(&region, region.begin(),
newSpec.getArgumentTypes(), loc);
builder.setInsertionPointToStart(body);

// Make sure spec names are unique to work around a transform dialect
// interpreter bug (`transform.include` does not handle name collisions
// correctly): https://github.com/llvm/llvm-project/issues/119578.
llvm::StringMap<unsigned> specNameCounts;
// Reserve the name for the outermost entrypoint.
specNameCounts[kKernelConfigSpecName] = 1;

// Emit one `transform.include` op per child tuning spec. In the future,
// we may want to switch to a custom transform op for this to perform
// 'short-circuiting' and apply at most one tuning spec.
@@ -102,17 +123,27 @@ emitLinkedTuningSpec(ModuleOp module, ArrayRef<NamedSequenceOp> specsToLink) {
assert(parentModule);
StringAttr parentSymbol = parentModule.getSymNameAttr();
assert(parentSymbol);
StringRef specName = spec.getSymName();
unsigned specNameSeenCount = specNameCounts[specName]++;
if (specNameSeenCount > 0) {
spec.setSymName(
llvm::formatv("{}_{}", specName, specNameSeenCount).str());
}

auto symbol = SymbolRefAttr::get(
parentSymbol, FlatSymbolRefAttr::get(spec.getSymNameAttr()));

// Suppress silenceable errors so that failures to match in child tuning
// specs can be ignored.
builder.create<transform::IncludeOp>(
loc, TypeRange{}, symbol, transform::FailurePropagationMode::Suppress,
operand);
operand = builder
.create<transform::IncludeOp>(
loc, anyOpType, symbol,
transform::FailurePropagationMode::Suppress, operand)
.getResults()
.front();
}

builder.create<transform::YieldOp>(loc);
builder.create<transform::YieldOp>(loc, operand);
return newSpec;
}
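
For illustration only (not part of this file): with a user-provided spec and
the default gfx942 spec linked, and assuming both child specs consume their
argument, the entry point emitted by this function looks roughly like the
following. The module and sequence names are placeholders; the user spec is
included first so that it takes precedence, and `@__kernel_config_1` reflects
the symbol-collision workaround described above.

transform.named_sequence @__kernel_config(%arg0: !transform.any_op {transform.consumed}) -> !transform.any_op
    attributes { iree_codegen.tuning_spec_entrypoint } {
  %0 = transform.include @user_spec::@main failures(suppress) (%arg0)
      : (!transform.any_op) -> !transform.any_op
  %1 = transform.include @iree_default_tuning_spec_gfx942::@__kernel_config_1 failures(suppress) (%0)
      : (!transform.any_op) -> !transform.any_op
  transform.yield %1 : !transform.any_op
}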

@@ -145,6 +176,14 @@ FailureOr<NamedSequenceOp> linkTuningSpecs(ModuleOp module) {
}
}

size_t numConsumedSpecs = llvm::count_if(tuningSpecs, consumesInputOp);
if (numConsumedSpecs > 0 && numConsumedSpecs != tuningSpecs.size()) {
LDBG("Only " << numConsumedSpecs << " tuning specs out of "
<< tuningSpecs.size() << " total consume the input op");
return module.emitWarning() << "Expected the argument in all tuning specs "
"to be consistently readonly or consumed";
}

if (tuningSpecs.empty()) {
LDBG("No tuning specs found, exiting without linking");
return NamedSequenceOp{};
