PreTask Conv2d_nhwc_hwcf · Pull Request #217

Open · wants to merge 13 commits into base: main
5 changes: 5 additions & 0 deletions examples/ConvOpt/conv2d_nhwc_hwcf.mlir
@@ -0,0 +1,5 @@
func.func @conv_2d_nhwc_hwcf(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?xf32>, %arg2: memref<?x?x?x?xf32>) {
linalg.conv_2d_nhwc_hwcf ins (%arg0, %arg1: memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
outs (%arg2: memref<?x?x?x?xf32>)
return
}
27 changes: 27 additions & 0 deletions midend/include/Utils/TileSizeSelection.h
@@ -0,0 +1,27 @@
#ifndef INCLUDE_UTILS_TILESIZESELECTION_H_
#define INCLUDE_UTILS_TILESIZESELECTION_H_

#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"

using namespace mlir;

namespace buddy {
// Map a range to a SmallVector with the element type deduced from the mapping.
template <unsigned Size, class ContainerTy, class FuncTy>
auto map_to_vec(ContainerTy &&C, FuncTy &&F) {
return llvm::to_vector<Size>(
llvm::map_range(std::forward<ContainerTy>(C), std::forward<FuncTy>(F)));
}

template <class ContainerTy, class FuncTy>
auto map_to_vec(ContainerTy &&C, FuncTy &&F) {
return llvm::to_vector(
llvm::map_range(std::forward<ContainerTy>(C), std::forward<FuncTy>(F)));
}
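
// Usage sketch (illustrative only; `builder` and `staticSizes` are
// hypothetical names, not part of this patch): convert static tile sizes
// into OpFoldResult attributes.
//   SmallVector<OpFoldResult> sizes = map_to_vec(
//       staticSizes, [&](int64_t s) -> OpFoldResult {
//         return builder.getIndexAttr(s);
//       });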

void setSCFTileSizes(scf::SCFTilingOptions &options, TilingInterface consumerOp,
SmallVector<int64_t> tileSizes,
SmallVector<bool> tileScalableFlags);
} // namespace buddy

#endif // INCLUDE_UTILS_TILESIZESELECTION_H_
1 change: 1 addition & 0 deletions midend/lib/Conversion/CMakeLists.txt
@@ -8,3 +8,4 @@ add_subdirectory(ConvOptimization)
add_subdirectory(LowerVectorExp)
add_subdirectory(LowerGemmini)
add_subdirectory(LowerLinalgToGemmini)
add_subdirectory(PolyhedralOptimization)
1 change: 1 addition & 0 deletions midend/lib/Conversion/ConvOptimization/CMakeLists.txt
@@ -1,3 +1,4 @@
add_mlir_library(ConvOptimization
ConvOptimize.cpp
ConvBroadcast.cpp
)
487 changes: 487 additions & 0 deletions midend/lib/Conversion/ConvOptimization/ConvBroadcast.cpp

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions midend/lib/Conversion/PolyhedralOptimization/CMakeLists.txt
@@ -0,0 +1,4 @@
add_mlir_library(PolyhedralOptimization
Tiling.cpp
TileAndFuse.cpp
)
252 changes: 252 additions & 0 deletions midend/lib/Conversion/PolyhedralOptimization/TileAndFuse.cpp
@@ -0,0 +1,252 @@
#include "llvm/Support/Debug.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Iterators.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "Utils/TileSizeSelection.h"

#define DEBUG_TYPE "polyhedral-tile-and-fuse"

using namespace mlir;

namespace {

/// Starting from `rootOp`, walk all operands backwards to find all potentially
/// fusible operations, i.e., operations that implement the `TilingInterface`.
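/// For example (illustrative IR, abbreviated):
///   %fill = linalg.fill ins(%cst) outs(%init)
///   %res  = linalg.matmul ins(%a, %b) outs(%fill)
/// Starting from the matmul, both operations are collected, since linalg.fill
/// also implements the TilingInterface.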
void collectTiledAndFusedOps(Operation *rootOp,
llvm::SmallDenseSet<Operation *> &result) {
SmallVector<Operation *> worklist;
worklist.push_back(rootOp);
result.insert(rootOp);
while (!worklist.empty()) {
Operation *current = worklist.pop_back_val();
for (OpOperand &operand : current->getOpOperands()) {
Operation *producer = operand.get().getDefiningOp();
if (!producer || !isa<TilingInterface>(producer) ||
result.count(producer))
continue;
worklist.push_back(producer);
result.insert(producer);
}
}
}

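/// Tiling a tensor.pad op may wrap the tiled pad in an scf.if (guarding the
/// empty-slice case). This helper folds that wrapper away: it inlines the
/// block containing the tiled pad before the scf.if, replaces the scf.if with
/// the operands of the block terminator, and returns the tiled pad. Fails if
/// the tiled pad is not nested directly in an scf.if.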
FailureOr<tensor::PadOp>
foldIfGeneratedFromPadding(RewriterBase &rewriter, tensor::PadOp untiledPadOp,
tensor::PadOp tiledPadOp) {
auto ifOp = dyn_cast<scf::IfOp>(tiledPadOp->getParentOp());
if (!ifOp)
return failure();
Block *block = tiledPadOp->getBlock();
Operation *terminator = block->getTerminator();
ValueRange results = terminator->getOperands();
rewriter.inlineBlockBefore(block, ifOp, {});
rewriter.replaceOp(ifOp, results);
rewriter.eraseOp(terminator);
return tiledPadOp;
}

struct TileAndFusePass : public PassWrapper<TileAndFusePass, OperationPass<func::FuncOp>> {
TileAndFusePass(int64_t tilingLevel = 0) {
this->tilingLevel = tilingLevel;
}

MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TileAndFusePass)
StringRef getArgument() const final { return "polyhedral-tile-and-fuse"; }
StringRef getDescription() const final { return "Tile and fuse"; }

void getDependentDialects(DialectRegistry &registry) const final {
registry.insert<arith::ArithDialect, affine::AffineDialect,
linalg::LinalgDialect, vector::VectorDialect, scf::SCFDialect>();
}
void runOnOperation() override;

int64_t tilingLevel;
};

LogicalResult applyTileAndFuse(RewriterBase &rewriter, Operation *rootOp,
DominanceInfo &dominanceInfo,
scf::SCFTilingOptions options) {
llvm::SmallDenseSet<Operation *> originTiledAndFuseOps;
collectTiledAndFusedOps(rootOp, originTiledAndFuseOps);
auto isIgnoredUser = [&](Operation *user, scf::ForOp outerMostTiledLoop) {
return originTiledAndFuseOps.count(user) || isa<tensor::DimOp>(user);
};

// 1. Tile the consumer.
SmallVector<OpResult> yieldedValuesToOrigValues;
SmallVector<Operation *> tiledOps;
FailureOr<scf::SCFTilingResult> tilingResult =
scf::tileUsingSCFForOp(rewriter, cast<TilingInterface>(rootOp), options);
if (failed(tilingResult)) {
return failure();
}
auto forLoops = llvm::to_vector(llvm::map_range(tilingResult->loops,
[](Operation *op) { return cast<scf::ForOp>(op); }));
yieldedValuesToOrigValues.append(rootOp->result_begin(),
rootOp->result_end());
// A map from an untiled value to the corresponding scf.for iter_arg. The
// iter_arg is used in place of the DPS init operand when they refer to the
// same init value.
llvm::DenseMap<Value, Value> mapToIterArg;

if (auto rootPadOp = dyn_cast<tensor::PadOp>(rootOp)) {
assert(tilingResult->tiledOps.size() == 1 &&
"Expecting only one tiled op for tensor::PadOp");
FailureOr<Operation *> replacementTiledOp = foldIfGeneratedFromPadding(
rewriter, rootPadOp, cast<tensor::PadOp>(tilingResult->tiledOps[0]));
if (!failed(replacementTiledOp)) {
tilingResult->tiledOps[0] = replacementTiledOp.value();
}
} else if (auto dpsOp = dyn_cast<DestinationStyleOpInterface>(rootOp)) {
for (auto [init, iterArg] : llvm::zip_equal(
dpsOp.getDpsInitOperands(),
cast<scf::ForOp>(forLoops.back()).getRegionIterArgs())) {
mapToIterArg[init->get()] = iterArg;
}
}
tiledOps.append(tilingResult->tiledOps);

// 2. Tiling each operation results in generation of slices. The source of
// these slices could be producers that can be fused into the tiled loops by
// computing the slices of these producers in-place. This results in more
// slices created for operands of the "fused producer". This opens up more
// opportunities for fusion. Use a worklist to fuse greedily.
auto addCandidateSlices =
[&](Operation *fusedOp, std::deque<tensor::ExtractSliceOp> &candidates) {
for (OpOperand &operand : fusedOp->getOpOperands()) {
auto sliceOp = operand.get().getDefiningOp<tensor::ExtractSliceOp>();
if (!sliceOp)
continue;
candidates.push_back(sliceOp);

auto dpsOp = dyn_cast<DestinationStyleOpInterface>(fusedOp);
if (!dpsOp)
continue;

if (dpsOp.isDpsInit(&operand) &&
mapToIterArg.contains(sliceOp.getSource())) {
rewriter.startRootUpdate(sliceOp);
sliceOp.getSourceMutable().assign(mapToIterArg[sliceOp.getSource()]);
rewriter.finalizeRootUpdate(sliceOp);
}
}
};

std::deque<tensor::ExtractSliceOp> candidates;
addCandidateSlices(tilingResult->tiledOps.back(), candidates);
OpBuilder::InsertionGuard g(rewriter);
while (!candidates.empty()) {
// Traverse the slices in BFS fashion.
tensor::ExtractSliceOp candidateSliceOp = candidates.front();
candidates.pop_front();

// Materialize the slice of the producer in place.
std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer =
scf::tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, forLoops);
if (!fusedProducer)
continue;

// Check if the fused producer has other uses that require the value
// to be yielded from within the tiled loop.
OpResult untiledProducer = fusedProducer->origProducer;
if (llvm::any_of(untiledProducer.getUsers(), [&](Operation *user) {
return !isIgnoredUser(user, forLoops.front()) &&
!forLoops.front()->isAncestor(user);
})) {
scf::yieldReplacementForFusedProducer(rewriter, candidateSliceOp,
fusedProducer.value(), forLoops);
yieldedValuesToOrigValues.push_back(untiledProducer);
}

// Add more fusion candidates to the worklist.
for (auto tiledOp : fusedProducer->tiledOps) {
addCandidateSlices(tiledOp, candidates);
tiledOps.push_back(tiledOp);
}
}

scf::ForOp outermostLoop = forLoops.front();
for (auto [index, origVal] : llvm::enumerate(yieldedValuesToOrigValues)) {
Value replacement = outermostLoop.getResult(index);
rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) {
return !isIgnoredUser(use.getOwner(), outermostLoop) &&
dominanceInfo.properlyDominates(outermostLoop, use.getOwner());
});
}

return success();
}

void TileAndFusePass::runOnOperation() {
MLIRContext *context = &getContext();
auto funcOp = getOperation();

TilingInterface consumerOp;
funcOp.walk<WalkOrder::PostOrder, ReverseIterator>([&](TilingInterface op) {
// If this op has no loops, keep walking to find the next candidate consumer.
if (op.getLoopIteratorTypes().empty())
return WalkResult::advance();
consumerOp = op;
return WalkResult::interrupt();
});
if (!consumerOp) {
LLVM_DEBUG(llvm::dbgs() << "No consumer op found, skip tiling\n");
return;
}

SmallVector<int64_t> tileSizes;
SmallVector<bool> tileScalableFlags;

// TODO: Configure tile sizes and tile scalable flags.
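// Illustrative placeholder (hypothetical values, not part of this patch):
// a fixed-size configuration for a 4-D convolution consumer might look like
//   tileSizes = {1, 8, 8, 16};
//   tileScalableFlags = {false, false, false, false};
// While tileSizes stays empty, the all-zero check below is trivially true and
// the pass bails out without tiling.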

if (llvm::all_of(tileSizes, [&](int64_t size) { return size == 0; })) {
LLVM_DEBUG(llvm::dbgs() << "All tile sizes are 0, skip tiling\n");
return;
}

scf::SCFTilingOptions options{};
buddy::setSCFTileSizes(options, consumerOp, std::move(tileSizes),
std::move(tileScalableFlags));

IRRewriter rewriter(context);
DominanceInfo dominanceInfo(funcOp);
if (failed(applyTileAndFuse(rewriter, consumerOp, dominanceInfo, options))) {
LLVM_DEBUG(llvm::dbgs() << "Failed to tile and fuse\n");
return signalPassFailure();
}

RewritePatternSet patterns =
linalg::getLinalgTilingCanonicalizationPatterns(context);
scf::populateSCFForLoopCanonicalizationPatterns(patterns);
tensor::populateFoldTensorEmptyPatterns(patterns);
memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
context->getLoadedDialect<tensor::TensorDialect>()
->getCanonicalizationPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
LLVM_DEBUG(llvm::dbgs() << "Failed to canonicalize\n");
return signalPassFailure();
}
}

} // namespace

namespace mlir {
namespace buddy {
void registerPolyhedralTileAndFusePass() {
PassRegistration<TileAndFusePass>();
}
} // namespace buddy
} // namespace mlir