From 496c4c08b03f60d318ed1850349700bb16d78871 Mon Sep 17 00:00:00 2001 From: Hanchen Ye Date: Sun, 14 Jan 2024 16:37:16 -0600 Subject: [PATCH] Add scf-to-affine pass --- include/scalehls/Transforms/Passes.h | 1 + include/scalehls/Transforms/Passes.td | 5 + lib/Transforms/CMakeLists.txt | 1 + lib/Transforms/RaiseSCFToAffine.cpp | 1030 +++++++++++++++++++++++++ 4 files changed, 1037 insertions(+) create mode 100644 lib/Transforms/RaiseSCFToAffine.cpp diff --git a/include/scalehls/Transforms/Passes.h b/include/scalehls/Transforms/Passes.h index 82dcc29b..141023a7 100644 --- a/include/scalehls/Transforms/Passes.h +++ b/include/scalehls/Transforms/Passes.h @@ -28,6 +28,7 @@ namespace scalehls { std::unique_ptr createConvertLinalgToDataflowPass(); std::unique_ptr createConvertDataflowToFuncPass(); +std::unique_ptr createRaiseSCFToAffinePass(); //===----------------------------------------------------------------------===// // Transform Passes diff --git a/include/scalehls/Transforms/Passes.td b/include/scalehls/Transforms/Passes.td index fa9837b4..3159030c 100644 --- a/include/scalehls/Transforms/Passes.td +++ b/include/scalehls/Transforms/Passes.td @@ -25,6 +25,11 @@ def ConvertDataflowToFunc : let constructor = "mlir::scalehls::createConvertDataflowToFuncPass()"; } +def RaiseSCFToAffine : Pass<"scalehls-raise-scf-to-affine"> { + let summary = "Raise SCF to affine"; + let constructor = "mlir::scalehls::createRaiseSCFToAffinePass()"; +} + //===----------------------------------------------------------------------===// // Transform Passes //===----------------------------------------------------------------------===// diff --git a/lib/Transforms/CMakeLists.txt b/lib/Transforms/CMakeLists.txt index 0995e04f..3000df56 100644 --- a/lib/Transforms/CMakeLists.txt +++ b/lib/Transforms/CMakeLists.txt @@ -5,6 +5,7 @@ add_mlir_library(MLIRScaleHLSTransforms GenerateRuntimeFunc.cpp LowerCopyToAffineLoops.cpp Pipelines.cpp + RaiseSCFToAffine.cpp DEPENDS MLIRScaleHLSTransformsIncGen diff --git a/lib/Transforms/RaiseSCFToAffine.cpp b/lib/Transforms/RaiseSCFToAffine.cpp new file mode 100644 index 00000000..2f69e93e --- /dev/null +++ b/lib/Transforms/RaiseSCFToAffine.cpp @@ -0,0 +1,1030 @@ +//===----------------------------------------------------------------------===// +// +// Copyright 2020-2021 The ScaleHLS Authors. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Affine/Passes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/Passes.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/IntegerSet.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "scalehls/Transforms/Passes.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "raise-scf-to-affine" + +using namespace mlir; +using namespace mlir::arith; +using namespace scalehls; +using namespace affine; + +static bool isValidSymbolInt(Value value, bool recur = true); +static bool isValidSymbolInt(Operation *defOp, bool recur) { + Attribute operandCst; + if (matchPattern(defOp, m_Constant(&operandCst))) + return true; + + if (recur) { + if (isa(defOp)) + if (llvm::all_of(defOp->getOperands(), [&](Value v) { + bool b = isValidSymbolInt(v, recur); + // if (!b) + // LLVM_DEBUG(llvm::dbgs() << "illegal isValidSymbolInt: " + //<< value << " due to " << v << "\n"); + return b; + })) + return true; + if (auto ifOp = dyn_cast(defOp)) { + if (isValidSymbolInt(ifOp.getCondition(), recur)) { + if (llvm::all_of( + ifOp.thenBlock()->without_terminator(), + [&](Operation &o) { return isValidSymbolInt(&o, recur); }) && + llvm::all_of( + ifOp.elseBlock()->without_terminator(), + [&](Operation &o) { return isValidSymbolInt(&o, recur); })) + return true; + } + } + if (auto ifOp = dyn_cast(defOp)) { + if (llvm::all_of(ifOp.getOperands(), + [&](Value o) { return isValidSymbolInt(o, recur); })) + if (llvm::all_of( + ifOp.getThenBlock()->without_terminator(), + [&](Operation &o) { return isValidSymbolInt(&o, recur); }) && + llvm::all_of( + ifOp.getElseBlock()->without_terminator(), + [&](Operation &o) { return isValidSymbolInt(&o, recur); })) + return true; + } + } + return false; +} + +// isValidSymbol, even if not index +static bool isValidSymbolInt(Value value, bool recur) { + // Check that the value is a top level value. + if (affine::isTopLevelValue(value)) + return true; + + if (auto *defOp = value.getDefiningOp()) { + if (isValidSymbolInt(defOp, recur)) + return true; + return affine::isValidSymbol(value, affine::getAffineScope(defOp)); + } + + return false; +} + +static bool isValidIndex(Value val) { + if (isValidSymbolInt(val)) + return true; + + if (auto cast = val.getDefiningOp()) + return isValidIndex(cast.getOperand()); + + if (auto cast = val.getDefiningOp()) + return isValidIndex(cast.getOperand()); + + if (auto cast = val.getDefiningOp()) + return isValidIndex(cast.getOperand()); + + if (auto bop = val.getDefiningOp()) + return isValidIndex(bop.getOperand(0)) && isValidIndex(bop.getOperand(1)); + + if (auto bop = val.getDefiningOp()) + return (isValidIndex(bop.getOperand(0)) && + isValidSymbolInt(bop.getOperand(1))) || + (isValidIndex(bop.getOperand(1)) && + isValidSymbolInt(bop.getOperand(0))); + + if (auto bop = val.getDefiningOp()) + return (isValidIndex(bop.getOperand(0)) && + isValidSymbolInt(bop.getOperand(1))); + + if (auto bop = val.getDefiningOp()) + return (isValidIndex(bop.getOperand(0)) && + isValidSymbolInt(bop.getOperand(1))); + + if (auto bop = val.getDefiningOp()) { + return (isValidIndex(bop.getOperand(0)) && + bop.getOperand(1).getDefiningOp()); + } + + if (auto bop = val.getDefiningOp()) + return (isValidIndex(bop.getOperand(0)) && + bop.getOperand(1).getDefiningOp()); + + if (auto bop = val.getDefiningOp()) + return isValidIndex(bop.getOperand(0)) && isValidIndex(bop.getOperand(1)); + + if (val.getDefiningOp()) + return true; + + if (val.getDefiningOp()) + return true; + + if (auto ba = val.dyn_cast()) { + auto *owner = ba.getOwner(); + assert(owner); + + auto *parentOp = owner->getParentOp(); + if (!parentOp) { + owner->dump(); + llvm::errs() << " ba: " << ba << "\n"; + } + assert(parentOp); + if (isa(parentOp)) + return true; + if (auto af = dyn_cast(parentOp)) + return af.getInductionVar() == ba; + + // TODO ensure not a reduced var + if (isa(parentOp)) + return true; + + if (isa(parentOp)) + return true; + } + + LLVM_DEBUG(llvm::dbgs() << "illegal isValidIndex: " << val << "\n"); + return false; +} + +static bool isReadOnly(Operation *op) { + bool hasRecursiveEffects = op->hasTrait(); + if (hasRecursiveEffects) { + for (Region ®ion : op->getRegions()) { + for (auto &block : region) { + for (auto &nestedOp : block) + if (!isReadOnly(&nestedOp)) + return false; + } + } + return true; + } + + // If the op has memory effects, try to characterize them to see if the op + // is trivially dead here. + if (auto effectInterface = dyn_cast(op)) { + // Check to see if this op either has no effects, or only allocates/reads + // memory. + SmallVector effects; + effectInterface.getEffects(effects); + if (!llvm::all_of(effects, [](const MemoryEffects::EffectInstance &it) { + return isa(it.getEffect()); + })) { + return false; + } + return true; + } + return false; +} + +namespace { +struct AffineApplyNormalizer { + AffineApplyNormalizer(AffineMap map, ArrayRef operands, + PatternRewriter &rewriter, DominanceInfo &DI); + + /// Returns the AffineMap resulting from normalization. + AffineMap getAffineMap() { return affineMap; } + + SmallVector getOperands() { + SmallVector res(reorderedDims); + res.append(concatenatedSymbols.begin(), concatenatedSymbols.end()); + return res; + } + +private: + /// Helper function to insert `v` into the coordinate system of the current + /// AffineApplyNormalizer. Returns the AffineDimExpr with the corresponding + /// renumbered position. + AffineDimExpr renumberOneDim(Value v); + + /// Maps of Value to position in `affineMap`. + DenseMap dimValueToPosition; + + /// Ordered dims and symbols matching positional dims and symbols in + /// `affineMap`. + SmallVector reorderedDims; + SmallVector concatenatedSymbols; + + AffineMap affineMap; +}; +} // namespace + +static bool isAffineForArg(Value val) { + if (!val.isa()) + return false; + Operation *parentOp = val.cast().getOwner()->getParentOp(); + return ( + isa_and_nonnull(parentOp)); +} + +static bool legalCondition(Value en, bool dim = false) { + if (en.getDefiningOp()) + return true; + + if (!dim && !isValidSymbolInt(en, /*recur*/ false)) { + if (isValidIndex(en) || isValidSymbolInt(en, /*recur*/ true)) { + return true; + } + } + + while (auto ic = en.getDefiningOp()) + en = ic.getIn(); + + if ((en.getDefiningOp() || en.getDefiningOp() || + en.getDefiningOp() || en.getDefiningOp() || + en.getDefiningOp()) && + (en.getDefiningOp()->getOperand(1).getDefiningOp() || + en.getDefiningOp()->getOperand(1).getDefiningOp())) + return true; + // if (auto IC = dyn_cast_or_null(en.getDefiningOp())) { + // if (!outer || legalCondition(IC.getOperand(), false)) return true; + //} + if (!dim) + if (auto BA = en.dyn_cast()) { + if (isa( + BA.getOwner()->getParentOp())) + return true; + } + return false; +} + +/// The AffineNormalizer composes AffineApplyOp recursively. Its purpose is to +/// keep a correspondence between the mathematical `map` and the `operands` of +/// a given affine::AffineApplyOp. This correspondence is maintained by +/// iterating over the operands and forming an `auxiliaryMap` that can be +/// composed mathematically with `map`. To keep this correspondence in cases +/// where symbols are produced by affine.apply operations, we perform a local +/// rewrite of symbols as dims. +/// +/// Rationale for locally rewriting symbols as dims: +/// ================================================ +/// The mathematical composition of AffineMap must always concatenate symbols +/// because it does not have enough information to do otherwise. For example, +/// composing `(d0)[s0] -> (d0 + s0)` with itself must produce +/// `(d0)[s0, s1] -> (d0 + s0 + s1)`. +/// +/// The result is only equivalent to `(d0)[s0] -> (d0 + 2 * s0)` when +/// applied to the same mlir::Value for both s0 and s1. +/// As a consequence mathematical composition of AffineMap always concatenates +/// symbols. +/// +/// When AffineMaps are used in affine::AffineApplyOp however, they may specify +/// composition via symbols, which is ambiguous mathematically. This corner case +/// is handled by locally rewriting such symbols that come from +/// affine::AffineApplyOp into dims and composing through dims. +/// TODO: Composition via symbols comes at a significant code +/// complexity. Alternatively we should investigate whether we want to +/// explicitly disallow symbols coming from affine.apply and instead force the +/// user to compose symbols beforehand. The annoyances may be small (i.e. 1 or 2 +/// extra API calls for such uses, which haven't popped up until now) and the +/// benefit potentially big: simpler and more maintainable code for a +/// non-trivial, recursive, procedure. +AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map, + ArrayRef operands, + PatternRewriter &rewriter, + DominanceInfo &DI) { + assert(map.getNumInputs() == operands.size() && + "number of operands does not match the number of map inputs"); + + LLVM_DEBUG(map.print(llvm::dbgs() << "\nInput map: ")); + + SmallVector addedValues; + + llvm::SmallSet symbolsToPromote; + + unsigned numDims = map.getNumDims(); + // unsigned numSymbols = map.getNumSymbols(); + + SmallVector dimReplacements; + SmallVector symReplacements; + + SmallVector *> opsTodos; + auto replaceOp = [&](Operation *oldOp, Operation *newOp) { + for (auto [oldV, newV] : + llvm::zip(oldOp->getResults(), newOp->getResults())) + for (auto ops : opsTodos) + for (auto &op : *ops) + if (op == oldV) + op = newV; + }; + + std::function fix = [&](Value v, + bool index) -> Value /*legal*/ { + if (isValidSymbolInt(v, /*recur*/ false)) + return v; + if (index && isAffineForArg(v)) + return v; + auto *op = v.getDefiningOp(); + if (!op) + return nullptr; + if (!op) + llvm::errs() << v << "\n"; + assert(op); + if (isa(op) || isa(op)) + return v; + if (!isReadOnly(op)) { + return nullptr; + } + Operation *front = nullptr; + SmallVector ops; + opsTodos.push_back(&ops); + std::function getAllOps = [&](Operation *todo) { + for (auto v : todo->getOperands()) { + if (llvm::all_of(op->getRegions(), [&](Region &r) { + return !r.isAncestor(v.getParentRegion()); + })) + ops.push_back(v); + } + for (auto &r : todo->getRegions()) { + for (auto &b : r.getBlocks()) + for (auto &o2 : b.without_terminator()) + getAllOps(&o2); + } + }; + getAllOps(op); + for (auto o : ops) { + Operation *next; + if (auto *op = o.getDefiningOp()) { + if (Value nv = fix(o, index)) { + op = nv.getDefiningOp(); + } else { + return nullptr; + } + next = op->getNextNode(); + } else { + auto BA = o.cast(); + if (index && isAffineForArg(BA)) { + } else if (!isValidSymbolInt(o, /*recur*/ false)) { + return nullptr; + } + next = &BA.getOwner()->front(); + } + if (front == nullptr) + front = next; + else if (DI.dominates(front, next)) + front = next; + } + opsTodos.pop_back(); + if (!front) + op->dump(); + assert(front); + PatternRewriter::InsertionGuard B(rewriter); + rewriter.setInsertionPoint(front); + auto cloned = rewriter.clone(*op); + replaceOp(op, cloned); + rewriter.replaceOp(op, cloned->getResults()); + return cloned->getResult(0); + }; + auto renumberOneSymbol = [&](Value v) { + for (auto i : llvm::enumerate(addedValues)) { + if (i.value() == v) + return getAffineSymbolExpr(i.index(), map.getContext()); + } + auto expr = getAffineSymbolExpr(addedValues.size(), map.getContext()); + addedValues.push_back(v); + return expr; + }; + + // 2. Compose affine::AffineApplyOps and dispatch dims or symbols. + for (unsigned i = 0, e = operands.size(); i < e; ++i) { + auto t = operands[i]; + auto decast = t; + while (true) { + if (auto idx = decast.getDefiningOp()) { + decast = idx.getIn(); + continue; + } + if (auto idx = decast.getDefiningOp()) { + decast = idx.getIn(); + continue; + } + if (auto idx = decast.getDefiningOp()) { + decast = idx.getIn(); + continue; + } + break; + } + + if (!isValidSymbolInt(t, /*recur*/ false)) { + t = decast; + } + + // Only promote one at a time, lest we end up with two dimensions + // multiplying each other. + + if (((!isValidSymbolInt(t, /*recur*/ false) && + (t.getDefiningOp() || t.getDefiningOp() || + (t.getDefiningOp() && + ((isValidIndex(t.getDefiningOp()->getOperand(0)) && + isValidSymbolInt(t.getDefiningOp()->getOperand(1))) || + (isValidIndex(t.getDefiningOp()->getOperand(1)) && + isValidSymbolInt(t.getDefiningOp()->getOperand(0)))) && + !(fix(t.getDefiningOp()->getOperand(0), false) && + fix(t.getDefiningOp()->getOperand(1), false)) + + ) || + ((t.getDefiningOp() || t.getDefiningOp()) && + (isValidIndex(t.getDefiningOp()->getOperand(0)) && + isValidSymbolInt(t.getDefiningOp()->getOperand(1))) && + (!(fix(t.getDefiningOp()->getOperand(0), false) && + fix(t.getDefiningOp()->getOperand(1), false)))) || + (t.getDefiningOp() && + (isValidIndex(t.getDefiningOp()->getOperand(0)) && + isValidSymbolInt(t.getDefiningOp()->getOperand(1)))) || + (t.getDefiningOp() && + (isValidIndex(t.getDefiningOp()->getOperand(0)) && + isValidSymbolInt(t.getDefiningOp()->getOperand(1)))) || + (t.getDefiningOp() && + (isValidIndex(t.getDefiningOp()->getOperand(0)) && + isValidSymbolInt(t.getDefiningOp()->getOperand(1)))) || + t.getDefiningOp() || + t.getDefiningOp())) || + ((decast.getDefiningOp() || decast.getDefiningOp() || + decast.getDefiningOp() || decast.getDefiningOp() || + decast.getDefiningOp()) && + (decast.getDefiningOp() + ->getOperand(1) + .getDefiningOp() || + decast.getDefiningOp() + ->getOperand(1) + .getDefiningOp())))) { + t = decast; + LLVM_DEBUG(llvm::dbgs() << " Replacing: " << t << "\n"); + + AffineMap affineApplyMap; + SmallVector affineApplyOperands; + + // llvm::dbgs() << "\nop to start: " << t << "\n"; + + if (auto op = t.getDefiningOp()) { + affineApplyMap = + AffineMap::get(0, 2, + getAffineSymbolExpr(0, op.getContext()) + + getAffineSymbolExpr(1, op.getContext())); + affineApplyOperands.push_back(op.getLhs()); + affineApplyOperands.push_back(op.getRhs()); + } else if (auto op = t.getDefiningOp()) { + affineApplyMap = + AffineMap::get(0, 2, + getAffineSymbolExpr(0, op.getContext()) - + getAffineSymbolExpr(1, op.getContext())); + affineApplyOperands.push_back(op.getLhs()); + affineApplyOperands.push_back(op.getRhs()); + } else if (auto op = t.getDefiningOp()) { + if (auto ci = op.getRhs().getDefiningOp()) { + affineApplyMap = AffineMap::get( + 0, 1, getAffineSymbolExpr(0, op.getContext()) * ci.value()); + affineApplyOperands.push_back(op.getLhs()); + } else if (auto ci = op.getRhs().getDefiningOp()) { + affineApplyMap = AffineMap::get( + 0, 1, getAffineSymbolExpr(0, op.getContext()) * ci.value()); + affineApplyOperands.push_back(op.getLhs()); + } else { + affineApplyMap = + AffineMap::get(0, 2, + getAffineSymbolExpr(0, op.getContext()) * + getAffineSymbolExpr(1, op.getContext())); + affineApplyOperands.push_back(op.getLhs()); + affineApplyOperands.push_back(op.getRhs()); + } + } else if (auto op = t.getDefiningOp()) { + if (auto ci = op.getRhs().getDefiningOp()) { + affineApplyMap = AffineMap::get( + 0, 1, + getAffineSymbolExpr(0, op.getContext()).floorDiv(ci.value())); + affineApplyOperands.push_back(op.getLhs()); + } else if (auto ci = op.getRhs().getDefiningOp()) { + affineApplyMap = AffineMap::get( + 0, 1, + getAffineSymbolExpr(0, op.getContext()).floorDiv(ci.value())); + affineApplyOperands.push_back(op.getLhs()); + } else { + affineApplyMap = AffineMap::get( + 0, 2, + getAffineSymbolExpr(0, op.getContext()) + .floorDiv(getAffineSymbolExpr(1, op.getContext()))); + affineApplyOperands.push_back(op.getLhs()); + affineApplyOperands.push_back(op.getRhs()); + } + } else if (auto op = t.getDefiningOp()) { + if (auto ci = op.getRhs().getDefiningOp()) { + affineApplyMap = AffineMap::get( + 0, 1, + getAffineSymbolExpr(0, op.getContext()).floorDiv(ci.value())); + affineApplyOperands.push_back(op.getLhs()); + } else if (auto ci = op.getRhs().getDefiningOp()) { + affineApplyMap = AffineMap::get( + 0, 1, + getAffineSymbolExpr(0, op.getContext()).floorDiv(ci.value())); + affineApplyOperands.push_back(op.getLhs()); + } else { + affineApplyMap = AffineMap::get( + 0, 2, + getAffineSymbolExpr(0, op.getContext()) + .floorDiv(getAffineSymbolExpr(1, op.getContext()))); + affineApplyOperands.push_back(op.getLhs()); + affineApplyOperands.push_back(op.getRhs()); + } + } else if (auto op = t.getDefiningOp()) { + if (auto ci = op.getRhs().getDefiningOp()) { + affineApplyMap = AffineMap::get( + 0, 1, getAffineSymbolExpr(0, op.getContext()) % ci.value()); + affineApplyOperands.push_back(op.getLhs()); + } else if (auto ci = op.getRhs().getDefiningOp()) { + affineApplyMap = AffineMap::get( + 0, 1, getAffineSymbolExpr(0, op.getContext()) % ci.value()); + affineApplyOperands.push_back(op.getLhs()); + } else { + affineApplyMap = + AffineMap::get(0, 2, + getAffineSymbolExpr(0, op.getContext()) % + getAffineSymbolExpr(1, op.getContext())); + affineApplyOperands.push_back(op.getLhs()); + affineApplyOperands.push_back(op.getRhs()); + } + } else if (auto op = t.getDefiningOp()) { + if (auto ci = op.getRhs().getDefiningOp()) { + affineApplyMap = AffineMap::get( + 0, 1, getAffineSymbolExpr(0, op.getContext()) % ci.value()); + affineApplyOperands.push_back(op.getLhs()); + } else if (auto ci = op.getRhs().getDefiningOp()) { + affineApplyMap = AffineMap::get( + 0, 1, getAffineSymbolExpr(0, op.getContext()) % ci.value()); + affineApplyOperands.push_back(op.getLhs()); + } else { + affineApplyMap = + AffineMap::get(0, 2, + getAffineSymbolExpr(0, op.getContext()) % + getAffineSymbolExpr(1, op.getContext())); + affineApplyOperands.push_back(op.getLhs()); + affineApplyOperands.push_back(op.getRhs()); + } + } else if (auto op = t.getDefiningOp()) { + affineApplyMap = AffineMap::get( + 0, 0, getAffineConstantExpr(op.value(), op.getContext())); + } else if (auto op = t.getDefiningOp()) { + affineApplyMap = AffineMap::get( + 0, 0, getAffineConstantExpr(op.value(), op.getContext())); + } else { + llvm_unreachable(""); + } + + SmallVector dimRemapping; + unsigned numOtherSymbols = affineApplyOperands.size(); + SmallVector symRemapping(numOtherSymbols); + for (unsigned idx = 0; idx < numOtherSymbols; ++idx) { + symRemapping[idx] = renumberOneSymbol(affineApplyOperands[idx]); + } + affineApplyMap = affineApplyMap.replaceDimsAndSymbols( + dimRemapping, symRemapping, reorderedDims.size(), addedValues.size()); + + LLVM_DEBUG(affineApplyMap.print( + llvm::dbgs() << "\nRenumber into current normalizer: ")); + + if (i >= numDims) + symReplacements.push_back(affineApplyMap.getResult(0)); + else + dimReplacements.push_back(affineApplyMap.getResult(0)); + + } else if (isAffineForArg(t)) { + if (i >= numDims) + symReplacements.push_back(renumberOneDim(t)); + else + dimReplacements.push_back(renumberOneDim(t)); + } else if (t.getDefiningOp()) { + auto affineApply = t.getDefiningOp(); + // a. Compose affine.apply operations. + LLVM_DEBUG(affineApply->print( + llvm::dbgs() << "\nCompose affine::AffineApplyOp recursively: ")); + AffineMap affineApplyMap = affineApply.getAffineMap(); + SmallVector affineApplyOperands( + affineApply.getOperands().begin(), affineApply.getOperands().end()); + + SmallVector dimRemapping(affineApplyMap.getNumDims()); + + for (size_t i = 0; i < affineApplyMap.getNumDims(); ++i) { + assert(i < affineApplyOperands.size()); + dimRemapping[i] = renumberOneDim(affineApplyOperands[i]); + } + unsigned numOtherSymbols = affineApplyOperands.size(); + SmallVector symRemapping(numOtherSymbols - + affineApplyMap.getNumDims()); + for (unsigned idx = 0; idx < symRemapping.size(); ++idx) { + symRemapping[idx] = renumberOneSymbol( + affineApplyOperands[idx + affineApplyMap.getNumDims()]); + } + affineApplyMap = affineApplyMap.replaceDimsAndSymbols( + dimRemapping, symRemapping, reorderedDims.size(), addedValues.size()); + + LLVM_DEBUG( + affineApplyMap.print(llvm::dbgs() << "\nAffine apply fixup map: ")); + + if (i >= numDims) + symReplacements.push_back(affineApplyMap.getResult(0)); + else + dimReplacements.push_back(affineApplyMap.getResult(0)); + } else { + if (!isValidSymbolInt(t, /*recur*/ false)) { + if (t.getDefiningOp()) { + if ((t = fix(t, false))) { + assert(isValidSymbolInt(t, /*recur*/ false)); + } else + assert(0 && "cannot move"); + } else + assert(0 && "cannot move2"); + } + if (i < numDims) { + // b. The mathematical composition of AffineMap composes dims. + dimReplacements.push_back(renumberOneDim(t)); + } else { + // c. The mathematical composition of AffineMap concatenates symbols. + // Note that the map composition will put symbols already present + // in the map before any symbols coming from the auxiliary map, so + // we insert them before any symbols that are due to renumbering, + // and after the proper symbols we have seen already. + symReplacements.push_back(renumberOneSymbol(t)); + } + } + } + for (auto v : addedValues) + concatenatedSymbols.push_back(v); + + // Create the new map by replacing each symbol at pos by the next new dim. + unsigned numNewDims = reorderedDims.size(); + unsigned numNewSymbols = addedValues.size(); + assert(dimReplacements.size() == map.getNumDims()); + assert(symReplacements.size() == map.getNumSymbols()); + auto auxillaryMap = map.replaceDimsAndSymbols( + dimReplacements, symReplacements, numNewDims, numNewSymbols); + LLVM_DEBUG(auxillaryMap.print(llvm::dbgs() << "\nRewritten map: ")); + + affineMap = auxillaryMap; // simplifyAffineMap(auxillaryMap); + + LLVM_DEBUG(affineMap.print(llvm::dbgs() << "\nSimplified result: ")); + LLVM_DEBUG(llvm::dbgs() << "\n"); +} + +AffineDimExpr AffineApplyNormalizer::renumberOneDim(Value v) { + DenseMap::iterator iterPos; + bool inserted = false; + std::tie(iterPos, inserted) = + dimValueToPosition.insert(std::make_pair(v, dimValueToPosition.size())); + if (inserted) { + reorderedDims.push_back(v); + } + return cast(getAffineDimExpr(iterPos->second, v.getContext())); +} + +static void composeAffineMapAndOperands(AffineMap *map, + SmallVectorImpl *operands, + PatternRewriter &rewriter, + DominanceInfo &DI) { + AffineApplyNormalizer normalizer(*map, *operands, rewriter, DI); + auto normalizedMap = normalizer.getAffineMap(); + auto normalizedOperands = normalizer.getOperands(); + affine::canonicalizeMapAndOperands(&normalizedMap, &normalizedOperands); + *map = normalizedMap; + *operands = normalizedOperands; + assert(*map); +} + +static bool need(AffineMap *map, SmallVectorImpl *operands) { + assert(map->getNumInputs() == operands->size()); + for (size_t i = 0; i < map->getNumInputs(); ++i) { + auto v = (*operands)[i]; + if (legalCondition(v, i < map->getNumDims())) + return true; + } + return false; +} + +static void fully2ComposeAffineMapAndOperands(PatternRewriter &builder, + AffineMap *map, + SmallVectorImpl *operands, + DominanceInfo &DI) { + IRMapping indexMap; + for (auto op : *operands) { + SmallVector attempt; + auto idx0 = op.getDefiningOp(); + attempt.push_back(idx0); + if (!idx0) + continue; + + for (auto &u : idx0.getIn().getUses()) { + if (auto idx = dyn_cast(u.getOwner())) + if (DI.dominates((Operation *)idx, &*builder.getInsertionPoint())) + attempt.push_back(idx); + } + + for (auto idx : attempt) { + if (affine::isValidSymbol(idx)) { + indexMap.map(idx.getIn(), idx); + break; + } + } + } + assert(map->getNumInputs() == operands->size()); + while (need(map, operands)) { + composeAffineMapAndOperands(map, operands, builder, DI); + assert(map->getNumInputs() == operands->size()); + } + *map = simplifyAffineMap(*map); + for (auto &op : *operands) { + if (!op.getType().isIndex()) { + Operation *toInsert; + if (auto *o = op.getDefiningOp()) + toInsert = o->getNextNode(); + else { + auto BA = op.cast(); + toInsert = &BA.getOwner()->front(); + } + + if (auto v = indexMap.lookupOrNull(op)) + op = v; + else { + PatternRewriter::InsertionGuard B(builder); + builder.setInsertionPoint(toInsert); + op = builder.create(op.getLoc(), builder.getIndexType(), + op); + } + } + } +} + +struct ForOpRaising : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + // TODO: remove me or rename me. + bool isAffine(scf::ForOp loop) const { + // return true; + // enforce step to be a ConstantIndexOp (maybe too restrictive). + return affine::isValidSymbol(loop.getStep()); + } + + int64_t getStep(mlir::Value value) const { + ConstantIndexOp cstOp = value.getDefiningOp(); + if (cstOp) + return cstOp.value(); + else + return 1; + } + + AffineMap getMultiSymbolIdentity(Builder &B, unsigned rank) const { + SmallVector dimExprs; + dimExprs.reserve(rank); + for (unsigned i = 0; i < rank; ++i) + dimExprs.push_back(B.getAffineSymbolExpr(i)); + return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/rank, dimExprs, + B.getContext()); + } + LogicalResult matchAndRewrite(scf::ForOp loop, + PatternRewriter &rewriter) const final { + if (isAffine(loop)) { + OpBuilder builder(loop); + + SmallVector lbs; + { + SmallVector todo = {loop.getLowerBound()}; + while (todo.size()) { + auto cur = todo.back(); + todo.pop_back(); + if (isValidIndex(cur)) { + lbs.push_back(cur); + continue; + } else if (auto selOp = cur.getDefiningOp()) { + // LB only has max of operands + if (auto cmp = selOp.getCondition().getDefiningOp()) { + if (cmp.getLhs() == selOp.getTrueValue() && + cmp.getRhs() == selOp.getFalseValue() && + cmp.getPredicate() == CmpIPredicate::sge) { + todo.push_back(cmp.getLhs()); + todo.push_back(cmp.getRhs()); + continue; + } + } + } + return failure(); + } + } + + SmallVector ubs; + { + SmallVector todo = {loop.getUpperBound()}; + while (todo.size()) { + auto cur = todo.back(); + todo.pop_back(); + if (isValidIndex(cur)) { + ubs.push_back(cur); + continue; + } else if (auto selOp = cur.getDefiningOp()) { + // UB only has min of operands + if (auto cmp = selOp.getCondition().getDefiningOp()) { + if (cmp.getLhs() == selOp.getTrueValue() && + cmp.getRhs() == selOp.getFalseValue() && + cmp.getPredicate() == CmpIPredicate::sle) { + todo.push_back(cmp.getLhs()); + todo.push_back(cmp.getRhs()); + continue; + } + } + } + return failure(); + } + } + + bool rewrittenStep = false; + if (!loop.getStep().getDefiningOp()) { + if (ubs.size() != 1 || lbs.size() != 1) + return failure(); + ubs[0] = rewriter.create( + loop.getLoc(), + rewriter.create( + loop.getLoc(), + rewriter.create( + loop.getLoc(), loop.getStep(), + rewriter.create(loop.getLoc(), 1)), + rewriter.create(loop.getLoc(), loop.getUpperBound(), + loop.getLowerBound())), + loop.getStep()); + lbs[0] = rewriter.create(loop.getLoc(), 0); + rewrittenStep = true; + } + + auto *scope = affine::getAffineScope(loop)->getParentOp(); + DominanceInfo DI(scope); + + AffineMap lbMap = getMultiSymbolIdentity(builder, lbs.size()); + { + fully2ComposeAffineMapAndOperands(rewriter, &lbMap, &lbs, DI); + affine::canonicalizeMapAndOperands(&lbMap, &lbs); + lbMap = removeDuplicateExprs(lbMap); + } + AffineMap ubMap = getMultiSymbolIdentity(builder, ubs.size()); + { + fully2ComposeAffineMapAndOperands(rewriter, &ubMap, &ubs, DI); + affine::canonicalizeMapAndOperands(&ubMap, &ubs); + ubMap = removeDuplicateExprs(ubMap); + } + + affine::AffineForOp affineLoop = rewriter.create( + loop.getLoc(), lbs, lbMap, ubs, ubMap, getStep(loop.getStep()), + loop.getInits()); + + auto mergedYieldOp = + cast(loop.getRegion().front().getTerminator()); + + Block &newBlock = affineLoop.getRegion().front(); + + // The terminator is added if the iterator args are not provided. + // see the ::build method. + if (affineLoop.getNumIterOperands() == 0) { + auto *affineYieldOp = newBlock.getTerminator(); + rewriter.eraseOp(affineYieldOp); + } + + SmallVector vals; + rewriter.setInsertionPointToStart(&affineLoop.getRegion().front()); + for (Value arg : affineLoop.getRegion().front().getArguments()) { + if (rewrittenStep && arg == affineLoop.getInductionVar()) { + arg = rewriter.create( + loop.getLoc(), loop.getLowerBound(), + rewriter.create(loop.getLoc(), arg, loop.getStep())); + } + vals.push_back(arg); + } + assert(vals.size() == loop.getRegion().front().getNumArguments()); + rewriter.mergeBlocks(&loop.getRegion().front(), + &affineLoop.getRegion().front(), vals); + + rewriter.setInsertionPoint(mergedYieldOp); + rewriter.create(mergedYieldOp.getLoc(), + mergedYieldOp.getOperands()); + rewriter.eraseOp(mergedYieldOp); + + rewriter.replaceOp(loop, affineLoop.getResults()); + + return success(); + } + return failure(); + } +}; + +struct ParallelOpRaising : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + void canonicalizeLoopBounds(PatternRewriter &rewriter, + affine::AffineParallelOp forOp) const { + SmallVector lbOperands(forOp.getLowerBoundsOperands()); + SmallVector ubOperands(forOp.getUpperBoundsOperands()); + + auto lbMap = forOp.getLowerBoundsMap(); + auto ubMap = forOp.getUpperBoundsMap(); + + auto *scope = affine::getAffineScope(forOp)->getParentOp(); + DominanceInfo DI(scope); + + fully2ComposeAffineMapAndOperands(rewriter, &lbMap, &lbOperands, DI); + affine::canonicalizeMapAndOperands(&lbMap, &lbOperands); + + fully2ComposeAffineMapAndOperands(rewriter, &ubMap, &ubOperands, DI); + affine::canonicalizeMapAndOperands(&ubMap, &ubOperands); + + forOp.setLowerBounds(lbOperands, lbMap); + forOp.setUpperBounds(ubOperands, ubMap); + } + + LogicalResult matchAndRewrite(scf::ParallelOp loop, + PatternRewriter &rewriter) const final { + OpBuilder builder(loop); + + if (loop.getResults().size()) + return failure(); + + if (!llvm::all_of(loop.getLowerBound(), isValidIndex)) { + return failure(); + } + + if (!llvm::all_of(loop.getUpperBound(), isValidIndex)) { + return failure(); + } + + SmallVector steps; + for (auto step : loop.getStep()) + if (auto cst = step.getDefiningOp()) + steps.push_back(cst.value()); + else + return failure(); + + ArrayRef reductions; + SmallVector bounds; + for (size_t i = 0; i < loop.getLowerBound().size(); i++) + bounds.push_back(AffineMap::get( + /*dimCount=*/0, /*symbolCount=*/loop.getLowerBound().size(), + builder.getAffineSymbolExpr(i))); + affine::AffineParallelOp affineLoop = + rewriter.create( + loop.getLoc(), loop.getResultTypes(), reductions, bounds, + loop.getLowerBound(), bounds, loop.getUpperBound(), + steps); //, loop.getInitVals()); + + canonicalizeLoopBounds(rewriter, affineLoop); + + auto mergedYieldOp = + cast(loop.getRegion().front().getTerminator()); + + Block &newBlock = affineLoop.getRegion().front(); + + // The terminator is added if the iterator args are not provided. + // see the ::build method. + if (affineLoop.getResults().size() == 0) { + auto *affineYieldOp = newBlock.getTerminator(); + rewriter.eraseOp(affineYieldOp); + } + + SmallVector vals; + for (Value arg : affineLoop.getRegion().front().getArguments()) { + vals.push_back(arg); + } + rewriter.mergeBlocks(&loop.getRegion().front(), + &affineLoop.getRegion().front(), vals); + + rewriter.setInsertionPoint(mergedYieldOp); + rewriter.create(mergedYieldOp.getLoc(), + mergedYieldOp.getOperands()); + rewriter.eraseOp(mergedYieldOp); + + rewriter.replaceOp(loop, affineLoop.getResults()); + + return success(); + } +}; + +namespace { +struct RaiseSCFToAffine : public RaiseSCFToAffineBase { + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + patterns.insert(&getContext()); + + GreedyRewriteConfig config; + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), + config); + } +}; +} // namespace + +std::unique_ptr scalehls::createRaiseSCFToAffinePass() { + return std::make_unique(); +}