Skip to content

Commit

Permalink
[Codegen] Add pass to decompose pack unpack ops at dispatch boundaries (
Browse files Browse the repository at this point in the history
iree-org#18852)

This PR adds a wrapper pass around DecomposePackUnPackOps, which adds a
control function for decomposing only packs and unpacks whose reshapes
can be folded with dispatch tensor loads/stores.

This PR also removes the public pass constructor with a control
function, opting to use wrapper passes in place of constructing passes
with arbitrary control functions. This is better for creating simple bug
repros, since the control function is part of the pass.

---------

Signed-off-by: Max Dawkins <[email protected]>
  • Loading branch information
Max191 authored Oct 24, 2024
1 parent 9c5b57a commit e1469b2
Show file tree
Hide file tree
Showing 8 changed files with 411 additions and 77 deletions.
233 changes: 178 additions & 55 deletions compiler/src/iree/compiler/Codegen/Common/DecomposePackUnPackOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
Expand All @@ -16,21 +17,31 @@
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-codegen-decompose-pack-unpack-ops"

namespace mlir::iree_compiler {

#define GEN_PASS_DEF_DECOMPOSEPACKUNPACKOPSPASS
#define GEN_PASS_DEF_DECOMPOSEBOUNDARYPACKUNPACKOPSPASS
#include "iree/compiler/Codegen/Common/Passes.h.inc"

using PackUnPackControlFn = std::function<LogicalResult(Operation *)>;

namespace {

//===----------------------------------------------------------------------===//
// Shared rewrite patterns
//===----------------------------------------------------------------------===//

/// A wrapper pattern that calls linalg::lowerPack on tensor::PackOp. It lowers
/// a tensor.pack op to tensor.pad + tensor.expand_shape + linalg.transpose ops.
struct LowerPackPattern : public OpRewritePattern<tensor::PackOp> {
Expand Down Expand Up @@ -85,33 +96,14 @@ struct LowerUnPackPattern : public OpRewritePattern<tensor::UnPackOp> {
std::optional<PackUnPackControlFn> controlFn;
};

struct DecomposePackUnPackOpsPass final
: impl::DecomposePackUnPackOpsPassBase<DecomposePackUnPackOpsPass> {
using impl::DecomposePackUnPackOpsPassBase<
DecomposePackUnPackOpsPass>::DecomposePackUnPackOpsPassBase;
explicit DecomposePackUnPackOpsPass(
bool tileOuterToOne, bool useOnlyReshapes,
std::optional<PackUnPackControlFn> controlFn) {
this->tileOuterToOne = tileOuterToOne;
this->useOnlyReshapes = useOnlyReshapes;
this->controlFn = controlFn;
}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<linalg::LinalgDialect, arith::ArithDialect, scf::SCFDialect,
tensor::TensorDialect>();
}
//===----------------------------------------------------------------------===//
// Shared pass implementation
//===----------------------------------------------------------------------===//

void runOnOperation() override;

private:
std::optional<PackUnPackControlFn> controlFn;
};

} // namespace

void DecomposePackUnPackOpsPass::runOnOperation() {
MLIRContext *ctx = &getContext();
auto funcOp = getOperation();
static LogicalResult commonRunOnOperation(
MLIRContext *ctx, FunctionOpInterface funcOp, bool useOnlyReshapes,
bool tileOuterToOne,
std::optional<PackUnPackControlFn> controlFn = std::nullopt) {
// Generalization patterns for outer unit dims have higher priority because
// they do not generate reshape ops.
if (!useOnlyReshapes) {
Expand All @@ -122,7 +114,7 @@ void DecomposePackUnPackOpsPass::runOnOperation() {
funcOp.emitError(
"failed to apply generalization patterns on pack/unpack ops for "
"outer unit dims cases");
return signalPassFailure();
return failure();
}
}

Expand All @@ -135,7 +127,7 @@ void DecomposePackUnPackOpsPass::runOnOperation() {
funcOp.emitError(
"failed to apply generalization patterns on pack/unpack ops for "
"general cases.");
return signalPassFailure();
return failure();
}
}

Expand Down Expand Up @@ -163,17 +155,24 @@ void DecomposePackUnPackOpsPass::runOnOperation() {
builder.getIndexAttr(1));
return tileSizes;
}));
funcOp->walk([&](tensor::PackOp op) {
if (controlFn && failed(controlFn.value()(op))) {
return;
{
WalkResult status = funcOp->walk([&](tensor::PackOp op) {
if (controlFn && failed(controlFn.value()(op))) {
return WalkResult::advance();
}
FailureOr<scf::SCFTileAndFuseResult> tileAndFuseResult =
scf::tileConsumerAndFuseProducersUsingSCF(
rewriter, cast<TilingInterface>(op.getOperation()),
packOptions);
if (failed(tileAndFuseResult))
return WalkResult::interrupt();
rewriter.replaceOp(op, tileAndFuseResult->replacements[op.getResult()]);
return WalkResult::advance();
});
if (status.wasInterrupted()) {
return failure();
}
FailureOr<scf::SCFTileAndFuseResult> tileAndFuseResult =
scf::tileConsumerAndFuseProducersUsingSCF(
rewriter, cast<TilingInterface>(op.getOperation()), packOptions);
if (failed(tileAndFuseResult))
return signalPassFailure();
rewriter.replaceOp(op, tileAndFuseResult->replacements[op.getResult()]);
});
}

auto unpackTilingOptions =
scf::SCFTilingOptions().setTileSizeComputationFunction(
Expand All @@ -191,17 +190,23 @@ void DecomposePackUnPackOpsPass::runOnOperation() {
}
return tileSizes;
});
funcOp->walk([&](tensor::UnPackOp op) {
if (controlFn && failed(controlFn.value()(op))) {
return;
{
WalkResult status = funcOp->walk([&](tensor::UnPackOp op) {
if (controlFn && failed(controlFn.value()(op))) {
return WalkResult::advance();
}
FailureOr<scf::SCFTilingResult> tilingResult = scf::tileUsingSCF(
rewriter, cast<TilingInterface>(op.getOperation()),
unpackTilingOptions);
if (failed(tilingResult))
return WalkResult::interrupt();
rewriter.replaceOp(op, tilingResult->replacements);
return WalkResult::advance();
});
if (status.wasInterrupted()) {
return failure();
}
FailureOr<scf::SCFTilingResult> tilingResult =
scf::tileUsingSCF(rewriter, cast<TilingInterface>(op.getOperation()),
unpackTilingOptions);
if (failed(tilingResult))
return signalPassFailure();
rewriter.replaceOp(op, tilingResult->replacements);
});
}

LLVM_DEBUG({
llvm::dbgs()
Expand All @@ -219,7 +224,7 @@ void DecomposePackUnPackOpsPass::runOnOperation() {
ctx->getOrLoadDialect<tensor::TensorDialect>()->getCanonicalizationPatterns(
patterns);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
return failure();
}
}

Expand All @@ -238,16 +243,134 @@ void DecomposePackUnPackOpsPass::runOnOperation() {
linalg::GeneralizeOuterUnitDimsUnPackOpPattern>(ctx);
}
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
return failure();
}
}
return success();
}

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createDecomposePackUnPackOpsPass(bool tileOuterToOne, bool useOnlyReshapes,
std::optional<PackUnPackControlFn> controlFn) {
return std::make_unique<DecomposePackUnPackOpsPass>(
tileOuterToOne, useOnlyReshapes, controlFn);
//===----------------------------------------------------------------------===//
// DecomposePackUnPackOpsPass
//===----------------------------------------------------------------------===//

/// Pass that decomposes every tensor.pack/tensor.unpack op in the function.
/// Options (tileOuterToOne, useOnlyReshapes) come from the tablegen-generated
/// base class; the implementation is the shared commonRunOnOperation(), called
/// with no control function so all ops are decomposed.
struct DecomposePackUnPackOpsPass final
    : impl::DecomposePackUnPackOpsPassBase<DecomposePackUnPackOpsPass> {
  using impl::DecomposePackUnPackOpsPassBase<
      DecomposePackUnPackOpsPass>::DecomposePackUnPackOpsPassBase;

  void runOnOperation() override;
};

} // namespace

void DecomposePackUnPackOpsPass::runOnOperation() {
  // Delegate to the shared implementation with no control function, so every
  // pack/unpack op in the function body is decomposed.
  if (failed(commonRunOnOperation(&getContext(), getOperation(),
                                  useOnlyReshapes, tileOuterToOne))) {
    return signalPassFailure();
  }
}

//===----------------------------------------------------------------------===//
// DecomposeBoundaryPackUnPackOpsPass
//===----------------------------------------------------------------------===//

namespace {

/// Wrapper pass that decomposes only the tensor.pack/tensor.unpack ops sitting
/// at dispatch boundaries, i.e. ops admitted by the
/// isFoldableIntoInterfaceTensor control function (pack fed by a full-slice
/// dispatch tensor load / unpack consumed only by full-slice dispatch tensor
/// stores, with no padding).
struct DecomposeBoundaryPackUnPackOpsPass final
    : impl::DecomposeBoundaryPackUnPackOpsPassBase<
          DecomposeBoundaryPackUnPackOpsPass> {
  using impl::DecomposeBoundaryPackUnPackOpsPassBase<
      DecomposeBoundaryPackUnPackOpsPass>::
      DecomposeBoundaryPackUnPackOpsPassBase;

  void runOnOperation() override;
};

} // namespace

/// Check if the given `op` is a pack or unpack op with padding.
static bool hasPadding(Operation *op) {
auto needsPad = [](ShapedType unpackedType, ArrayRef<int64_t> innerDimPos,
ArrayRef<int64_t> staticInnerTiles) {
for (auto [dimPos, tile] : llvm::zip_equal(innerDimPos, staticInnerTiles)) {
if (unpackedType.isDynamicDim(dimPos) || ShapedType::isDynamic(tile) ||
unpackedType.getDimSize(dimPos) % tile != 0) {
return true;
}
}
return false;
};
auto packOp = dyn_cast<tensor::PackOp>(op);
if (packOp && needsPad(packOp.getSourceType(), packOp.getInnerDimsPos(),
packOp.getStaticInnerTiles())) {
return true;
}
auto unPackOp = dyn_cast<tensor::UnPackOp>(op);
if (unPackOp && needsPad(unPackOp.getDestType(), unPackOp.getInnerDimsPos(),
unPackOp.getStaticInnerTiles())) {
return true;
}
return false;
}

/// Control function for decomposing pack and unpack ops. Returns success if
/// the op is a pack or unpack op, and its reshapes can be folded with a
/// producer or consumer interface tensor op. To be foldable, the following
/// conditions must be met:
///
/// 1. The PackOp or UnPackOp must have no padding.
/// 2. If the op is a PackOp, then its producer must be a dispatch tensor load.
/// 3. If the op is an UnPackOp, then all of its consumers must be dispatch
///    tensor stores.
/// 4. Any dispatch tensor load producers or dispatch tensor store consumers
///    must be full slices.
static LogicalResult isFoldableIntoInterfaceTensor(Operation *op) {
  // Full slice means zero offsets, unit strides, and sizes matching the full
  // tensor shape.
  auto isFullSlice =
      [](ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
         ArrayRef<OpFoldResult> strides, ArrayRef<int64_t> fullTensorShape) {
        return areAllConstantIntValue(offsets, 0) &&
               areAllConstantIntValue(strides, 1) &&
               areConstantIntValues(sizes, fullTensorShape);
      };
  if (!isa<tensor::PackOp, tensor::UnPackOp>(op)) {
    return failure();
  }
  if (hasPadding(op)) {
    return failure();
  }

  // A PackOp is foldable if its source is produced by a full slice dispatch
  // tensor load. The source may be a block argument, in which case
  // getDefiningOp() returns null, so use dyn_cast_if_present to avoid a null
  // dereference.
  if (isa<tensor::PackOp>(op)) {
    auto load = dyn_cast_if_present<IREE::Flow::DispatchTensorLoadOp>(
        op->getOperand(0).getDefiningOp());
    if (load &&
        isFullSlice(load.getMixedOffsets(), load.getMixedSizes(),
                    load.getMixedStrides(), load.getSourceType().getShape())) {
      return success();
    }
    return failure();
  }
  // An UnPackOp is foldable if all of its consumers are full slice dispatch
  // tensor stores.
  if (llvm::all_of(op->getUsers(), [&](Operation *user) {
        auto store = dyn_cast<IREE::Flow::DispatchTensorStoreOp>(user);
        return store &&
               isFullSlice(store.getMixedOffsets(), store.getMixedSizes(),
                           store.getMixedStrides(),
                           store.getTargetType().getShape());
      })) {
    return success();
  }
  return failure();
}

void DecomposeBoundaryPackUnPackOpsPass::runOnOperation() {
  // Force useOnlyReshapes=true: the control function admits only ops whose
  // reshapes fold with dispatch tensor loads/stores, so decomposition must go
  // through reshape ops rather than the outer-unit-dim generalization
  // patterns.
  if (failed(commonRunOnOperation(&getContext(), getOperation(),
                                  /*useOnlyReshapes=*/true, tileOuterToOne,
                                  isFoldableIntoInterfaceTensor))) {
    return signalPassFailure();
  }
}

} // namespace mlir::iree_compiler
11 changes: 0 additions & 11 deletions compiler/src/iree/compiler/Codegen/Common/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,17 +67,6 @@ using ConfigFn =
std::unique_ptr<InterfacePass<FunctionOpInterface>>
createConvolutionToIGEMMPass(ConfigFn configFn);

using PackUnPackControlFn = std::function<LogicalResult(Operation *)>;
/// Pass to decompose pack and unpack ops into pad/extract_slice and reshape
/// ops. If specified, `controlFn` controls which ops get decomposed. The
/// `controlFn` should be used with `useOnlyReshapes` set to true.
/// TODO(Max191): Add a controlFn upstream for `GeneralizeOuterUnitDim*`
/// patterns and remove the need to have `useOnlyReshapes = true` when using
/// `controlFn`.
std::unique_ptr<InterfacePass<FunctionOpInterface>>
createDecomposePackUnPackOpsPass(bool tileOuterToOne, bool useOnlyReshapes,
std::optional<PackUnPackControlFn> controlFn);

std::unique_ptr<Pass> createDecomposeSoftmaxPass(bool useFusion);

/// Pass to perform linalg on tensor bufferization. The function passed into
Expand Down
21 changes: 21 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,27 @@ def DecomposePackUnPackOpsPass :
Option<"useOnlyReshapes", "use-only-reshapes", "bool", "false",
"Use decomposition into reshape ops, even when packing unit dimensions.">
];
let dependentDialects = [
"arith::ArithDialect",
"linalg::LinalgDialect",
"scf::SCFDialect",
"tensor::TensorDialect"
];
}

def DecomposeBoundaryPackUnPackOpsPass :
InterfacePass<"iree-codegen-decompose-boundary-pack-unpack-ops", "mlir::FunctionOpInterface"> {
let summary = "Wrapper for DecomposePackUnPackOpsPass to decompose ops at function boundaries";
let options = [
Option<"tileOuterToOne", "tile-outer-to-one", "bool", "false",
"Always apply tiling to make outer dimension be ones">
];
let dependentDialects = [
"arith::ArithDialect",
"linalg::LinalgDialect",
"scf::SCFDialect",
"tensor::TensorDialect"
];
}

def DecomposeSoftmaxPass :
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ iree_lit_test_suite(
"convolutions.mlir",
"erase_dead_alloc_and_stores.mlir",
"decompose_affine_ops.mlir",
"decompose_boundary_pack_unpack_ops.mlir",
"decompose_conv2d.mlir",
"decompose_linalg_generic.mlir",
"decompose_pack_unpack_ops.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ iree_lit_test_suite(
"convolution_to_igemm.mlir"
"convolutions.mlir"
"decompose_affine_ops.mlir"
"decompose_boundary_pack_unpack_ops.mlir"
"decompose_conv2d.mlir"
"decompose_linalg_generic.mlir"
"decompose_pack_unpack_ops.mlir"
Expand Down
Loading

0 comments on commit e1469b2

Please sign in to comment.