Skip to content

Commit

Permalink
Improving linking support for ROCM and ukernels. (#19211)
Browse files Browse the repository at this point in the history
To support externally-defined ukernels on ROCM the ROCMTarget has been
brought in-line with LLVMCPU/CUDA by calling `linkBitcodeObjects`. To
make authoring passes that include object references
`#hal.executable.object` now allows any data type to be associated so
long as it is serializable allowing for external resource attrs and
other custom attributes that may serialize based on other information.
To allow patterns to attach object references all ops within an
executable variant can now declare a `hal.executable.objects` array that
will be hoisted and merged into the top-level variant objects after our
executable linking pass (before serialization where they are used).
  • Loading branch information
benvanik authored Nov 19, 2024
2 parents 4396bf1 + f510664 commit 82a89e3
Show file tree
Hide file tree
Showing 13 changed files with 191 additions and 14 deletions.
20 changes: 20 additions & 0 deletions compiler/plugins/target/ROCM/ROCMTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,26 @@ class ROCMTargetBackend final : public TargetBackend {
}
}

// Link bitcode (*.bc) object attrs specified by the input program.
// Note that this happens after the command-line files so that the command
// line ones override the symbols coming from the embedded files.
auto specializationCallback = [&](llvm::Module &userModule) {
// TODO: inject __nvvm_reflect-style functions/globals for
// bitcode specialization based on the targetMachine and configuration.
// These could use any information we have on the IREE side as well as
// the TargetMachine.
};
unsigned linkerFlags =
llvm::Linker::LinkOnlyNeeded | llvm::Linker::OverrideFromSrc;
if (failed(linkBitcodeObjects(variantOp.getLoc(), linker, linkerFlags,
*targetMachine, variantOp.getObjectsAttr(),
llvmModule->getContext(),
specializationCallback))) {
return mlir::emitError(variantOp.getLoc())
<< "failed linking in user objects for target triple '"
<< targetArch.str() << "'";
}

// Link module to HIP device library.
if (bitcodeDirectory.empty()) {
return variantOp.emitError()
Expand Down
17 changes: 10 additions & 7 deletions compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,8 @@ Attribute ExecutableObjectAttr::parse(AsmParser &p, Type type) {
}
auto pathAttr = llvm::dyn_cast_if_present<StringAttr>(dict.get("path"));
auto dataAttr =
llvm::dyn_cast_if_present<DenseIntElementsAttr>(dict.get("data"));
llvm::dyn_cast_if_present<IREE::Util::SerializableAttrInterface>(
dict.get("data"));
return get(p.getContext(), pathAttr, dataAttr);
}

Expand Down Expand Up @@ -312,12 +313,14 @@ FailureOr<std::string> ExecutableObjectAttr::getAbsolutePath() {

std::optional<std::string> ExecutableObjectAttr::loadData() {
if (auto dataAttr = getData()) {
// This is shady but so is using this feature.
// TODO(benvanik): figure out a way to limit the attribute to signless int8.
// We could share the attribute -> byte array code with the VM constant
// serialization if we wanted.
auto rawData = dataAttr.getRawData();
return std::string(rawData.data(), rawData.size());
std::string buffer;
buffer.resize(dataAttr.getStorageSize());
if (failed(dataAttr.serializeToBuffer(
UnknownLoc::get(dataAttr.getContext()), llvm::endianness::native,
ArrayRef(buffer.data(), buffer.size())))) {
return std::nullopt;
}
return buffer;
} else if (auto pathAttr = getPath()) {
// Search for file and try to load it if found.
auto filePath =
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ def HAL_ExecutableObjectAttr : AttrDef<HAL_Dialect, "ExecutableObject"> {

let parameters = (ins
AttrParameter<"StringAttr", "">:$path,
OptionalParameter<"DenseIntElementsAttr", "">:$data
OptionalParameter<"IREE::Util::SerializableAttrInterface", "">:$data
);

let hasCustomAssemblyFormat = 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ iree_compiler_cc_library(
"DumpExecutableSources.cpp",
"ElideRedundantCommands.cpp",
"FixupLegacySync.cpp",
"HoistExecutableObjects.cpp",
"InitializeDevices.cpp",
"InlineMemoizeRegions.cpp",
"LinkExecutables.cpp",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ iree_cc_library(
"DumpExecutableSources.cpp"
"ElideRedundantCommands.cpp"
"FixupLegacySync.cpp"
"HoistExecutableObjects.cpp"
"InitializeDevices.cpp"
"InlineMemoizeRegions.cpp"
"LinkExecutables.cpp"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/Pass/Pass.h"

namespace mlir::iree_compiler::IREE::HAL {

#define GEN_PASS_DEF_HOISTEXECUTABLEOBJECTSPASS
#include "iree/compiler/Dialect/HAL/Transforms/Passes.h.inc"

namespace {

//===----------------------------------------------------------------------===//
// --iree-hal-hoist-executable-objects
//===----------------------------------------------------------------------===//

struct HoistExecutableObjectsPass
: public IREE::HAL::impl::HoistExecutableObjectsPassBase<
HoistExecutableObjectsPass> {
void runOnOperation() override {
// Note that some executables may be external and not have any contents.
if (getOperation().isExternal()) {
return;
}

auto objectsAttrName =
StringAttr::get(&getContext(), "hal.executable.objects");

// Seed with existing variant-level object attrs, if any present.
SetVector<Attribute> allObjectAttrs;
if (auto existingAttr = getOperation().getObjectsAttr()) {
allObjectAttrs.insert(existingAttr.begin(), existingAttr.end());
}

// Move all op-level attributes into a unique set. Note that order can be
// important so we use an ordered set.
//
// We could do this first as a gather step in parallel if this walk gets too
// expensive.
bool foundAnyAttrs = false;
getOperation().getInnerModule().walk([&](Operation *op) {
auto objectsAttr = op->getAttrOfType<ArrayAttr>(objectsAttrName);
if (objectsAttr) {
allObjectAttrs.insert(objectsAttr.begin(), objectsAttr.end());
op->removeAttr(objectsAttrName);
foundAnyAttrs = true;
}
});

// Update the variant if any changes were made.
if (foundAnyAttrs) {
getOperation().setObjectsAttr(
ArrayAttr::get(&getContext(), allObjectAttrs.getArrayRef()));
}
}
};

} // namespace

} // namespace mlir::iree_compiler::IREE::HAL
10 changes: 9 additions & 1 deletion compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -471,14 +471,22 @@ void buildHALTransformPassPipeline(OpPassManager &passManager,
// TODO(benvanik): move translation down to here.

// After all executables are translated and before resolving export
// ordinals, we allow the backends to link executables together. For
// ordinals we allow the backends to link executables together. For
// example, the LLVM AOT backend may combine all executable targets for the
// same architecture into a single executable and link it as a shared
// library.
if (transformOptions.linkExecutables) {
passManager.addPass(IREE::HAL::createLinkExecutablesPass({targetRegistry}));
}

// If any executable variants have external objects referenced within them
// we hoist them up to the top-level variant. This is done after linking so
// that we have the greatest chance of combining executables without different
// object attrs preventing the merging.
passManager.nest<IREE::HAL::ExecutableOp>()
.addNestedPass<IREE::HAL::ExecutableVariantOp>(
IREE::HAL::createHoistExecutableObjectsPass());

// Resolve export ordinals from nested symbol references prior to
// serialization. As this pass creates lookup ops it should run before
// MaterializeResourceCachesPass.
Expand Down
9 changes: 9 additions & 0 deletions compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,15 @@ def TranslateTargetExecutableVariantsPass :
];
}

def HoistExecutableObjectsPass :
Pass<"iree-hal-hoist-executable-objects", "IREE::HAL::ExecutableVariantOp"> {
let summary = "Hoists local executable object annotations to the parent `hal.executable.variant`.";
let description = [{
Finds all `hal.executable.objects` attrs on all ops within an executable
inner module and moves them to the parent `hal.executable.variant` op.
}];
}

def PruneExecutablesPass :
Pass<"iree-hal-prune-executables", "mlir::ModuleOp"> {
let summary = "Prunes executable variants and exports that are not referenced.";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,12 +173,13 @@ externalizeExecutableOp(IREE::HAL::ExecutableOp executableOp,
// objects in case there were any as this does entire executable replacement -
// there may have been microkernel libraries or something referenced by the
// existing module.
auto dataObjectAttr = builder.getAttr<IREE::HAL::ExecutableObjectAttr>(
builder.getStringAttr(llvm::sys::path::filename(filePath)),
DenseIntElementsAttr::get(
auto dataAttr =
cast<IREE::Util::SerializableAttrInterface>(DenseIntElementsAttr::get(
VectorType::get({static_cast<int64_t>(fileContents->size())},
builder.getI8Type()),
ArrayRef(fileContents->data(), fileContents->size())));
auto dataObjectAttr = builder.getAttr<IREE::HAL::ExecutableObjectAttr>(
builder.getStringAttr(llvm::sys::path::filename(filePath)), dataAttr);
variantOp.setObjectsAttr(builder.getArrayAttr({dataObjectAttr}));

// Drop the inner module if present (may already be external).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ iree_lit_test_suite(
"dump_executable_sources.mlir",
"elide_redundant_commands.mlir",
"fixup_legacy_sync.mlir",
"hoist_executable_objects.mlir",
"initialize_devices.mlir",
"inline_memoize_regions.mlir",
"materialize_dispatch_instrumentation.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ iree_lit_test_suite(
"dump_executable_sources.mlir"
"elide_redundant_commands.mlir"
"fixup_legacy_sync.mlir"
"hoist_executable_objects.mlir"
"initialize_devices.mlir"
"inline_memoize_regions.mlir"
"materialize_dispatch_instrumentation.mlir"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-hal-hoist-executable-objects)))" %s | FileCheck %s

// Tests that attributes on top-level ops and nested ops are all detected,
// deduplicated, and moved to the variant.

// CHECK: hal.executable public @executable
hal.executable public @executable {
// CHECK: hal.executable.variant public @backend
// CHECK-SAME: objects([
// CHECK-SAME: #hal.executable.object<{path = "existing_variant.obj"}>,
// CHECK-SAME: #hal.executable.object<{path = "extern_fn_common.obj"}>,
// CHECK-SAME: #hal.executable.object<{path = "extern_fn_a.obj"}>,
// CHECK-SAME: #hal.executable.object<{path = "extern_fn_b.obj"}>,
// CHECK-SAME: #hal.executable.object<{path = "nested_common.obj"}>,
// CHECK-SAME: #hal.executable.object<{path = "nested_a.obj"}>,
// CHECK-SAME: #hal.executable.object<{path = "nested_b.obj"}>
hal.executable.variant public @backend target(#hal.executable.target<"backend", "format">) objects([
#hal.executable.object<{path = "existing_variant.obj"}>
]) {
hal.executable.export @entry0 ordinal(0) layout(#hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>
]>)
builtin.module {
// CHECK: func.func private @extern_fn_a
// CHECK-NOT: hal.executable.objects
func.func private @extern_fn_a() attributes {
hal.executable.objects = [
#hal.executable.object<{path = "extern_fn_common.obj"}>,
#hal.executable.object<{path = "extern_fn_a.obj"}>
]
}
// CHECK: func.func private @extern_fn_b
// CHECK-NOT: hal.executable.objects
func.func private @extern_fn_b() attributes {
hal.executable.objects = [
#hal.executable.object<{path = "extern_fn_common.obj"}>,
#hal.executable.object<{path = "extern_fn_b.obj"}>
]
}
func.func @entry0() {
// CHECK: call @extern_fn_a
// CHECK-NOT: hal.executable.objects
call @extern_fn_a() {
hal.executable.objects = [
#hal.executable.object<{path = "nested_common.obj"}>,
#hal.executable.object<{path = "nested_a.obj"}>
]
} : () -> ()
call @extern_fn_b() {
// CHECK: call @extern_fn_b
// CHECK-NOT: hal.executable.objects
hal.executable.objects = [
#hal.executable.object<{path = "nested_common.obj"}>,
#hal.executable.object<{path = "nested_b.obj"}>
]
} : () -> ()
return
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,10 @@ convertPipelineLayout(IREE::Input::PipelineLayoutAttr src) {

static IREE::HAL::ExecutableObjectAttr
convertExecutableObject(IREE::Input::ExecutableObjectAttr src) {
return IREE::HAL::ExecutableObjectAttr::get(src.getContext(), src.getPath(),
src.getData());
return IREE::HAL::ExecutableObjectAttr::get(
src.getContext(), src.getPath(),
dyn_cast_if_present<IREE::Util::SerializableAttrInterface>(
src.getData()));
}

static IREE::HAL::ExecutableTargetAttr
Expand Down

0 comments on commit 82a89e3

Please sign in to comment.