-
Notifications
You must be signed in to change notification settings - Fork 618
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Codegen] Add pass to verify workgroup distribution (#19186)
While general verification is not possible, when using `scf.forall` for workgroup distribution we have the opportunity for basic verification that all writes are located within the distributed loop. In particular, if we have any workgroup level loops, any write to global memory outside is assumed to be illegal. This happens after bufferization because it is impossible to do this verification before determining the memory space of every tensor in the dispatch. This pass is relatively lightweight (two walks, both of which should be short) and so is on by default for every CPU and GPU pipeline.
- Loading branch information
Showing
10 changed files
with
152 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
89 changes: 89 additions & 0 deletions
89
compiler/src/iree/compiler/Codegen/Common/VerifyWorkgroupDistribution.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
// Copyright 2024 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
#include "iree/compiler/Codegen/Common/Passes.h" | ||
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" | ||
#include "iree/compiler/Codegen/Utils/GPUUtils.h" | ||
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h" | ||
#include "mlir/IR/Visitors.h" | ||
#include "mlir/Interfaces/FunctionInterfaces.h" | ||
#include "mlir/Interfaces/SideEffectInterfaces.h" | ||
|
||
namespace mlir::iree_compiler { | ||
|
||
#define GEN_PASS_DEF_VERIFYWORKGROUPDISTRIBUTIONPASS | ||
#include "iree/compiler/Codegen/Common/Passes.h.inc" | ||
|
||
namespace { | ||
|
||
struct VerifyWorkgroupDistributionPass final | ||
: impl::VerifyWorkgroupDistributionPassBase< | ||
VerifyWorkgroupDistributionPass> { | ||
|
||
void runOnOperation() override { | ||
FunctionOpInterface funcOp = getOperation(); | ||
|
||
WalkResult hasForall = funcOp.walk([&](scf::ForallOp forallOp) { | ||
if (forallOpHasMappingType<IREE::Codegen::WorkgroupMappingAttr>( | ||
forallOp)) { | ||
return WalkResult::interrupt(); | ||
} | ||
return WalkResult::advance(); | ||
}); | ||
|
||
// Without a workgroup level forall, either this is a single workgroup | ||
// dispatch, in which case no verification is needed, or this is already | ||
// distributed in which case verification is no longer possible. | ||
if (!hasForall.wasInterrupted()) { | ||
return; | ||
} | ||
|
||
auto globalAddressSpace = IREE::HAL::DescriptorTypeAttr::get( | ||
&getContext(), IREE::HAL::DescriptorType::StorageBuffer); | ||
|
||
// Walk in PreOrder so that parent operations are visited before children, | ||
// thus allowing all operations contained within workgroup foralls to be | ||
// skipped. | ||
WalkResult res = funcOp.walk<WalkOrder::PreOrder>([&](Operation *op) { | ||
if (auto forallOp = dyn_cast<scf::ForallOp>(op)) { | ||
// Skip ops contained within forall ops with workgroup mappings. | ||
if (forallOpHasMappingType<IREE::Codegen::WorkgroupMappingAttr>( | ||
forallOp)) { | ||
return WalkResult::skip(); | ||
} | ||
} | ||
if (auto memoryEffectOp = dyn_cast<MemoryEffectOpInterface>(op)) { | ||
for (Value operand : memoryEffectOp->getOperands()) { | ||
auto type = dyn_cast<MemRefType>(operand.getType()); | ||
if (!type || | ||
!memoryEffectOp.getEffectOnValue<MemoryEffects::Write>(operand)) { | ||
continue; | ||
} | ||
|
||
// Writes to non-global memory are fine. | ||
if (type.getMemorySpace() != globalAddressSpace) { | ||
continue; | ||
} | ||
|
||
op->emitOpError( | ||
"write affecting operations on global resources are restricted " | ||
"to workgroup distributed contexts."); | ||
return WalkResult::interrupt(); | ||
} | ||
} | ||
return WalkResult::advance(); | ||
}); | ||
|
||
if (res.wasInterrupted()) { | ||
funcOp.emitOpError("failed on workgroup distribution verification"); | ||
return signalPassFailure(); | ||
} | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
} // namespace mlir::iree_compiler |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
40 changes: 40 additions & 0 deletions
40
compiler/src/iree/compiler/Codegen/Common/test/verify_workgroup_distribution.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
// RUN: iree-opt %s --split-input-file --verify-diagnostics \ | ||
// RUN: --pass-pipeline="builtin.module(func.func(iree-codegen-verify-workgroup-distribution))" \ | ||
// RUN: | FileCheck %s | ||
|
||
// expected-error@+1 {{op failed on workgroup distribution verification}} | ||
func.func @write_outside_workgroup_forall(%i: i32, %out: memref<32xi32, #hal.descriptor_type<storage_buffer>>) { | ||
scf.forall (%arg0) in (32) { | ||
} {mapping = [#iree_codegen.workgroup_mapping<x>]} | ||
%c0 = arith.constant 0 : index | ||
// expected-error@+1 {{write affecting operations on global resources are restricted to workgroup distributed contexts.}} | ||
memref.store %i, %out[%c0] : memref<32xi32, #hal.descriptor_type<storage_buffer>> | ||
return | ||
} | ||
|
||
// ----- | ||
|
||
// CHECK: func @non_workgroup_write_outside_workgroup_forall | ||
func.func @non_workgroup_write_outside_workgroup_forall( | ||
%i: i32, %out: memref<32xi32, #hal.descriptor_type<storage_buffer>>, %out2: memref<32xi32>) { | ||
scf.forall (%arg0) in (32) { | ||
memref.store %i, %out[%arg0] : memref<32xi32, #hal.descriptor_type<storage_buffer>> | ||
} {mapping = [#iree_codegen.workgroup_mapping<x>]} | ||
%c0 = arith.constant 0 : index | ||
memref.store %i, %out2[%c0] : memref<32xi32> | ||
return | ||
} | ||
|
||
// ----- | ||
|
||
// expected-error@+1 {{op failed on workgroup distribution verification}} | ||
func.func @write_nested_in_other_forall(%i: i32, %out: memref<32xi32, #hal.descriptor_type<storage_buffer>>) { | ||
scf.forall (%arg0) in (32) { | ||
} {mapping = [#iree_codegen.workgroup_mapping<x>]} | ||
%c0 = arith.constant 0 : index | ||
scf.forall (%arg1) in (32) { | ||
// expected-error@+1 {{write affecting operations on global resources are restricted to workgroup distributed contexts.}} | ||
memref.store %i, %out[%arg1] : memref<32xi32, #hal.descriptor_type<storage_buffer>> | ||
} | ||
return | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters