Skip to content

Commit

Permalink
[Codegen] Add pass to verify workgroup distribution (#19186)
Browse files Browse the repository at this point in the history
While general verification is not possible, when using `scf.forall` for
workgroup distribution we have the opportunity for basic verification
that all writes are located within the distributed loop. In particular,
if we have any workgroup level loops, any write to global memory outside
is assumed to be illegal.

This happens after bufferization because it is impossible to do this
verification before determining the memory space of every tensor in the
dispatch.

This pass is relatively lightweight (two walks, both of which should be
short) and so is on by default for every CPU and GPU pipeline.
  • Loading branch information
qedawkins authored Nov 19, 2024
1 parent 35b495b commit 47432c6
Show file tree
Hide file tree
Showing 10 changed files with 152 additions and 4 deletions.
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ iree_compiler_cc_library(
"UnrollAnnotatedLoops.cpp",
"UserConfig.cpp",
"VectorizeMemrefCopy.cpp",
"VerifyWorkgroupDistribution.cpp",
],
hdrs = [
"BufferizationAnalysis.h",
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ iree_cc_library(
"UnrollAnnotatedLoops.cpp"
"UserConfig.cpp"
"VectorizeMemrefCopy.cpp"
"VerifyWorkgroupDistribution.cpp"
DEPS
::PassHeaders
::PassesIncGen
Expand Down
11 changes: 11 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -633,4 +633,15 @@ def VectorizeMemrefCopyPass :
let summary = "Vectorizes memref copy operations.";
}

def VerifyWorkgroupDistributionPass :
InterfacePass<"iree-codegen-verify-workgroup-distribution", "mlir::FunctionOpInterface"> {
let summary = "Pass to verify proper distribution to workgroups.";
let description = [{
Pass to verify that all writes to global memory are explicitly mapped to
workgroups. This means that in cases where we use loops (scf.forall) to
manage distribution to workgroups, we require that all ops with write
side effects are contained within a workgroup distributed loop.
}];
}

#endif // IREE_CODEGEN_COMMON_PASSES
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

namespace mlir::iree_compiler {

#define GEN_PASS_DEF_VERIFYWORKGROUPDISTRIBUTIONPASS
#include "iree/compiler/Codegen/Common/Passes.h.inc"

namespace {

struct VerifyWorkgroupDistributionPass final
: impl::VerifyWorkgroupDistributionPassBase<
VerifyWorkgroupDistributionPass> {

void runOnOperation() override {
FunctionOpInterface funcOp = getOperation();

WalkResult hasForall = funcOp.walk([&](scf::ForallOp forallOp) {
if (forallOpHasMappingType<IREE::Codegen::WorkgroupMappingAttr>(
forallOp)) {
return WalkResult::interrupt();
}
return WalkResult::advance();
});

// Without a workgroup level forall, either this is a single workgroup
// dispatch, in which case no verification is needed, or this is already
// distributed in which case verification is no longer possible.
if (!hasForall.wasInterrupted()) {
return;
}

auto globalAddressSpace = IREE::HAL::DescriptorTypeAttr::get(
&getContext(), IREE::HAL::DescriptorType::StorageBuffer);

// Walk in PreOrder so that parent operations are visited before children,
// thus allowing all operations contained within workgroup foralls to be
// skipped.
WalkResult res = funcOp.walk<WalkOrder::PreOrder>([&](Operation *op) {
if (auto forallOp = dyn_cast<scf::ForallOp>(op)) {
// Skip ops contained within forall ops with workgroup mappings.
if (forallOpHasMappingType<IREE::Codegen::WorkgroupMappingAttr>(
forallOp)) {
return WalkResult::skip();
}
}
if (auto memoryEffectOp = dyn_cast<MemoryEffectOpInterface>(op)) {
for (Value operand : memoryEffectOp->getOperands()) {
auto type = dyn_cast<MemRefType>(operand.getType());
if (!type ||
!memoryEffectOp.getEffectOnValue<MemoryEffects::Write>(operand)) {
continue;
}

// Writes to non-global memory are fine.
if (type.getMemorySpace() != globalAddressSpace) {
continue;
}

op->emitOpError(
"write affecting operations on global resources are restricted "
"to workgroup distributed contexts.");
return WalkResult::interrupt();
}
}
return WalkResult::advance();
});

if (res.wasInterrupted()) {
funcOp.emitOpError("failed on workgroup distribution verification");
return signalPassFailure();
}
}
};

} // namespace

} // namespace mlir::iree_compiler
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ iree_lit_test_suite(
"vectorize_memref_copy.mlir",
"vectorize_tensor_pad.mlir",
"vector_layout_analysis.mlir",
"verify_workgroup_distribution.mlir",
],
include = ["*.mlir"],
exclude = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ iree_lit_test_suite(
"vector_layout_analysis.mlir"
"vectorize_memref_copy.mlir"
"vectorize_tensor_pad.mlir"
"verify_workgroup_distribution.mlir"
TOOLS
FileCheck
iree-opt
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// RUN: iree-opt %s --split-input-file --verify-diagnostics \
// RUN: --pass-pipeline="builtin.module(func.func(iree-codegen-verify-workgroup-distribution))" \
// RUN: | FileCheck %s

// expected-error@+1 {{op failed on workgroup distribution verification}}
func.func @write_outside_workgroup_forall(%i: i32, %out: memref<32xi32, #hal.descriptor_type<storage_buffer>>) {
scf.forall (%arg0) in (32) {
} {mapping = [#iree_codegen.workgroup_mapping<x>]}
%c0 = arith.constant 0 : index
// expected-error@+1 {{write affecting operations on global resources are restricted to workgroup distributed contexts.}}
memref.store %i, %out[%c0] : memref<32xi32, #hal.descriptor_type<storage_buffer>>
return
}

// -----

// CHECK: func @non_workgroup_write_outside_workgroup_forall
func.func @non_workgroup_write_outside_workgroup_forall(
%i: i32, %out: memref<32xi32, #hal.descriptor_type<storage_buffer>>, %out2: memref<32xi32>) {
scf.forall (%arg0) in (32) {
memref.store %i, %out[%arg0] : memref<32xi32, #hal.descriptor_type<storage_buffer>>
} {mapping = [#iree_codegen.workgroup_mapping<x>]}
%c0 = arith.constant 0 : index
memref.store %i, %out2[%c0] : memref<32xi32>
return
}

// -----

// expected-error@+1 {{op failed on workgroup distribution verification}}
func.func @write_nested_in_other_forall(%i: i32, %out: memref<32xi32, #hal.descriptor_type<storage_buffer>>) {
scf.forall (%arg0) in (32) {
} {mapping = [#iree_codegen.workgroup_mapping<x>]}
%c0 = arith.constant 0 : index
scf.forall (%arg1) in (32) {
// expected-error@+1 {{write affecting operations on global resources are restricted to workgroup distributed contexts.}}
memref.store %i, %out[%arg1] : memref<32xi32, #hal.descriptor_type<storage_buffer>>
}
return
}
3 changes: 2 additions & 1 deletion compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -813,7 +813,8 @@ void buildLLVMCPUCodegenPassPipeline(OpPassManager &variantPassManager,
OpPassManager &modulePassManager = variantPassManager.nest<ModuleOp>();
modulePassManager.addPass(createLowerExecutableUsingTransformDialectPass());
FunctionLikeNest(modulePassManager)
.addPass(createLLVMCPULowerExecutableTargetPass);
.addPass(createLLVMCPULowerExecutableTargetPass)
.addPass(createVerifyWorkgroupDistributionPass);
}

variantPassManager.addPass(createReconcileTranslationInfoPass());
Expand Down
6 changes: 4 additions & 2 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1184,7 +1184,8 @@ void buildLLVMGPUCodegenPassPipeline(OpPassManager &variantPassManager,
OpPassManager &modulePassManager = variantPassManager.nest<ModuleOp>();
modulePassManager.addPass(createLowerExecutableUsingTransformDialectPass());
FunctionLikeNest(modulePassManager)
.addPass(createLLVMGPULowerExecutableTargetPass);
.addPass(createLLVMGPULowerExecutableTargetPass)
.addPass(createVerifyWorkgroupDistributionPass);
}
variantPassManager.addPass(createReconcileTranslationInfoPass());

Expand Down Expand Up @@ -1250,7 +1251,8 @@ void buildROCDLCodegenPassPipeline(OpPassManager &variantPassManager) {
OpPassManager &modulePassManager = variantPassManager.nest<ModuleOp>();
modulePassManager.addPass(createLowerExecutableUsingTransformDialectPass());
FunctionLikeNest(modulePassManager)
.addPass(createROCDLLowerExecutableTargetPass);
.addPass(createROCDLLowerExecutableTargetPass)
.addPass(createVerifyWorkgroupDistributionPass);
}
variantPassManager.addPass(createReconcileTranslationInfoPass());
variantPassManager.addPass(IREE::Util::createDropCompilerHintsPass());
Expand Down
3 changes: 2 additions & 1 deletion compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,8 @@ void buildSPIRVCodegenPassPipeline(OpPassManager &variantPassManager) {
modulePassManager.addPass(
createSPIRVLowerExecutableUsingTransformDialectPass());
FunctionLikeNest(modulePassManager)
.addPass(createSPIRVLowerExecutableTargetPass);
.addPass(createSPIRVLowerExecutableTargetPass)
.addPass(createVerifyWorkgroupDistributionPass);
addMemRefLoweringPasses(modulePassManager);
}
variantPassManager.addPass(createReconcileTranslationInfoPass());
Expand Down

0 comments on commit 47432c6

Please sign in to comment.