Merge branch 'main' into bump-iree-12102024
yzhang93 authored Dec 11, 2024
2 parents 94da847 + db10c75 commit 573b750
Showing 12 changed files with 275 additions and 215 deletions.
@@ -54,10 +54,12 @@ class FoldDmaOpLinearDims
"expected a source and target memory space for hardware aware "
"linear dimension folding");
}
AMDAIE::DmaDimConfig dmaDimConfig(
deviceModel.value(), sourceMemSpace.value(), targetMemSpace.value());
maxSourceSizes = dmaDimConfig.getMaxSizes<CopyOpOperateOn::Source>();
maxTargetSizes = dmaDimConfig.getMaxSizes<CopyOpOperateOn::Target>();
DmaDimConfig sourceDmaDimConfig(deviceModel.value(),
sourceMemSpace.value());
maxSourceSizes = sourceDmaDimConfig.getMaxSizes();
DmaDimConfig targetDmaDimConfig(deviceModel.value(),
targetMemSpace.value());
maxTargetSizes = targetDmaDimConfig.getMaxSizes();
}
LogicalResult sourceRes = foldLinearDims(
op.getContext(), sourceOffsets, sourceSizes, sourceStrides,
@@ -83,10 +83,10 @@ struct CombineStridedOps
return rewriter.notifyMatchFailure(
nextStridedOp, "expected a source and target memory space");
}
AMDAIE::DmaDimConfig dmaDimConfig(deviceModel, sourceMemspaceInt.value(),
targetMemspaceInt.value());
size_t sourceMaxNbDims = dmaDimConfig.sourceMaxNbDims;
size_t targetMaxNbDims = dmaDimConfig.targetMaxNbDims;
DmaDimConfig sourceDmaDimConfig(deviceModel, sourceMemspaceInt.value());
size_t sourceMaxNbDims = sourceDmaDimConfig.maxNbDims;
DmaDimConfig targetDmaDimConfig(deviceModel, targetMemspaceInt.value());
size_t targetMaxNbDims = targetDmaDimConfig.maxNbDims;

SmallVector<OpFoldResult> sourceOffsetsA = op.getSourceMixedOffsets();
SmallVector<OpFoldResult> sourceSizesA = op.getSourceMixedSizes();
@@ -40,7 +40,8 @@ struct HalfDmaCpyNdToNpuConverter final
ArrayRef<OpFoldResult> strides) const {
uint8_t numIntraAddrDim = deviceModel.getDmaProp<uint8_t>(
tileType, AMDAIE::AMDAIEDmaProp::NumAddrDim);
uint8_t numAddrDim = numIntraAddrDim + kAMDAIEDmaNbInterDims;
uint8_t numAddrDim =
numIntraAddrDim + deviceModel.deviceConfig.dmaNbInterDims;
auto subspanOp = dyn_cast_if_present<IREE::HAL::InterfaceBindingSubspanOp>(
logicalObjFifo.getMemref().getDefiningOp());
if (!subspanOp) {
@@ -152,7 +152,8 @@ struct SubsumeLoopIntoDMA
/// operation.
LogicalResult rewriteWithLoopLikeOpParent(
AMDAIE::DoublyStridedOpInterface op, PatternRewriter &rewriter,
const AMDAIE::DmaDimConfig &dmaDimConfig,
const DmaDimConfig &sourceDmaDimConfig,
const DmaDimConfig &targetDmaDimConfig,
const SmallVector<int64_t> &lowerBounds,
const SmallVector<int64_t> &upperBounds,
const SmallVector<int64_t> &steps,
@@ -210,10 +211,10 @@ struct SubsumeLoopIntoDMA
if (nbIterations > 1) nbNonUnitIterations++;
}
if (newSourceOffsets.size() + nbNonUnitIterations >
dmaDimConfig.sourceMaxNbDims)
sourceDmaDimConfig.maxNbDims)
return failure();
if (newTargetOffsets.size() + nbNonUnitIterations >
dmaDimConfig.targetMaxNbDims)
targetDmaDimConfig.maxNbDims)
return failure();

// Fail if zero stride is only supported on the outer dimension and adding
@@ -309,10 +310,8 @@ struct SubsumeLoopIntoDMA
insertInFront(newSourceSizes, insertSourceSizes);
SmallVector<int64_t> newSourceStridesInt =
insertInFront(newSourceStrides, insertSourceStrides);
SmallVector<int64_t> maxSizes =
dmaDimConfig.getMaxSizes<CopyOpOperateOn::Source>();
SmallVector<int64_t> maxStrides =
dmaDimConfig.getMaxStrides<CopyOpOperateOn::Source>();
SmallVector<int64_t> maxSizes = sourceDmaDimConfig.getMaxSizes();
SmallVector<int64_t> maxStrides = sourceDmaDimConfig.getMaxStrides();
assert(maxSizes.size() >= newSourceSizesInt.size() &&
"Max number of dimensions exceeded");
size_t begin = maxSizes.size() - newSourceSizesInt.size();
@@ -335,10 +334,8 @@
insertInFront(newTargetSizes, insertTargetSizes);
SmallVector<int64_t> newTargetStridesInt =
insertInFront(newTargetStrides, insertTargetStrides);
SmallVector<int64_t> maxSizes =
dmaDimConfig.getMaxSizes<CopyOpOperateOn::Target>();
SmallVector<int64_t> maxStrides =
dmaDimConfig.getMaxStrides<CopyOpOperateOn::Target>();
SmallVector<int64_t> maxSizes = targetDmaDimConfig.getMaxSizes();
SmallVector<int64_t> maxStrides = targetDmaDimConfig.getMaxStrides();
assert(maxSizes.size() >= newTargetSizesInt.size() &&
"Max number of dimensions exceeded");
size_t begin = maxSizes.size() - newTargetSizesInt.size();
@@ -413,7 +410,8 @@ struct SubsumeLoopIntoDMA
/// optional `affine.apply` user for now.
LogicalResult rewriteWithForOpParent(
AMDAIE::DoublyStridedOpInterface op, PatternRewriter &rewriter,
const AMDAIE::DmaDimConfig &dmaDimConfig) const {
const DmaDimConfig &sourceDmaDimConfig,
const DmaDimConfig &targetDmaDimConfig) const {
auto forOp = dyn_cast<scf::ForOp>(op->getParentOp());
if (!forOp) return failure();

@@ -440,17 +438,18 @@
SmallVector<int64_t> upperBounds = {upperBound.value()};
SmallVector<int64_t> steps = {step.value()};
SmallVector<DenseSet<Value>> inductionValues = {curIvValues};
return rewriteWithLoopLikeOpParent(op, rewriter, dmaDimConfig, lowerBounds,
upperBounds, steps, inductionValues,
curIvValues);
return rewriteWithLoopLikeOpParent(
op, rewriter, sourceDmaDimConfig, targetDmaDimConfig, lowerBounds,
upperBounds, steps, inductionValues, curIvValues);
}

/// Main rewrite function for a doubly strided operation with a `scf.forall`
/// parent operation. Only handle loop induction variables with an
/// optional `affine.apply` user for now.
LogicalResult rewriteWithForallOpParent(
AMDAIE::DoublyStridedOpInterface op, PatternRewriter &rewriter,
const AMDAIE::DmaDimConfig &dmaDimConfig) const {
const DmaDimConfig &sourceDmaDimConfig,
const DmaDimConfig &targetDmaDimConfig) const {
auto forallOp = dyn_cast<scf::ForallOp>(op->getParentOp());
if (!forallOp) return failure();

@@ -481,9 +480,10 @@
}
inductionValues.push_back(curIvValues);
}
return rewriteWithLoopLikeOpParent(
op, rewriter, dmaDimConfig, lowerBounds.value(), upperBounds.value(),
steps.value(), inductionValues, allInductionValues);
return rewriteWithLoopLikeOpParent(op, rewriter, sourceDmaDimConfig,
targetDmaDimConfig, lowerBounds.value(),
upperBounds.value(), steps.value(),
inductionValues, allInductionValues);
}

LogicalResult matchAndRewrite(AMDAIE::DoublyStridedOpInterface op,
@@ -562,13 +562,15 @@ struct SubsumeLoopIntoDMA
return rewriter.notifyMatchFailure(
op, "expected a source and target memory space");
}
AMDAIE::DmaDimConfig dmaDimConfig(deviceModel, sourceMemspaceInt.value(),
targetMemspaceInt.value());
DmaDimConfig sourceDmaDimConfig(deviceModel, sourceMemspaceInt.value());
DmaDimConfig targetDmaDimConfig(deviceModel, targetMemspaceInt.value());

if (isa<scf::ForOp>(parentOp)) {
return rewriteWithForOpParent(op, rewriter, dmaDimConfig);
return rewriteWithForOpParent(op, rewriter, sourceDmaDimConfig,
targetDmaDimConfig);
} else if (isa<scf::ForallOp>(parentOp)) {
return rewriteWithForallOpParent(op, rewriter, dmaDimConfig);
return rewriteWithForallOpParent(op, rewriter, sourceDmaDimConfig,
targetDmaDimConfig);
} else {
return rewriter.notifyMatchFailure(
op, "Has parent operation of currently unsupported type");
@@ -7,6 +7,7 @@
#include "iree-amd-aie/Transforms/Passes.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Pass/Pass.h"

#define DEBUG_TYPE "iree-amdaie-fuse-fill-into-forall"
@@ -29,60 +30,87 @@ class AMDAIEFuseFillIntoForallPass

void AMDAIEFuseFillIntoForallPass::runOnOperation() {
MLIRContext *context = &getContext();
mlir::FunctionOpInterface funcOp = getOperation();
IRRewriter rewriter(context);

// Find the producer op, which in this case is linalg.fill.
TilingInterface tileableProducer;
funcOp->walk([&](TilingInterface op) {
if (isa<linalg::FillOp>(op)) {
tileableProducer = op;
return WalkResult::interrupt();
}
return WalkResult::advance();
});

if (!tileableProducer) {
LLVM_DEBUG(llvm::dbgs() << "There is no producer op to be fused.\n");
// Find a unique FillOp with a single output, or return.
SmallVector<linalg::FillOp> fillOps;
getOperation()->walk(
[&](linalg::FillOp fillOp) { fillOps.push_back(fillOp); });
if (fillOps.size() != 1) {
LLVM_DEBUG(llvm::dbgs() << "Expected exactly 1 fill op, but found "
<< fillOps.size() << ".\n");
return;
}

// Search the first use by a scf::ForallOp user.
scf::ForallOp forallOp;
auto itProducerUses =
llvm::find_if(tileableProducer->getUses(), [&](OpOperand &use) {
forallOp = dyn_cast<scf::ForallOp>(use.getOwner());
return forallOp;
});
linalg::FillOp fillOp = fillOps[0];
if (fillOp.getResults().size() != 1) {
LLVM_DEBUG(llvm::dbgs() << "Expected fill op to have exactly 1 result, but "
<< "found " << fillOp.getResults().size() << ".\n");

return;
}

// Confirm that there is a unique user that is a forall, and match
// the block argument that is used by the fill op, or return.
if (!fillOp->hasOneUse()) {
LLVM_DEBUG(llvm::dbgs()
<< "Expected exactly 1 use of fill op, but found 0 or 2+.");
return;
}
OpOperand &fillUse = *fillOp->getUses().begin();
auto forallOp = dyn_cast<scf::ForallOp>(fillUse.getOwner());
if (!forallOp) {
LLVM_DEBUG(llvm::dbgs() << "There is no forall Op.\n");
LLVM_DEBUG(llvm::dbgs() << "Expected fill op to be used by a forall op, "
<< "but unique user is "
<< fillUse.getOwner()->getName() << ".\n");
return;
}

// Search the producer slices accessed within the Forall op.
OpOperand *pUse = &(*itProducerUses);
BlockArgument bbArg = forallOp.getTiedBlockArgument(pUse);

auto itBBArgUsers = llvm::find_if(bbArg.getUsers(), [&](Operation *user) {
auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
return sliceOp;
});
if (itBBArgUsers == bbArg.getUsers().end()) {
funcOp->emitOpError("There is no extract tensor slice.");
return signalPassFailure();
BlockArgument bbArg = forallOp.getTiedBlockArgument(&fillUse);

// Find 0 or 1 ExtractSliceOps that use the fill result, or return.
tensor::ExtractSliceOp extractSliceOp;
for (Operation *user : bbArg.getUsers()) {
if (auto nxt = dyn_cast<tensor::ExtractSliceOp>(user)) {
if (extractSliceOp) {
LLVM_DEBUG(llvm::dbgs()
<< "Expected at most 1 extract_slice op, but found 2+.\n");
return;
}
extractSliceOp = nxt;
}
}
auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);

LoopLikeOpInterface loops =
cast<LoopLikeOpInterface>(forallOp.getOperation());

// Materialize the slice of the producer in place.
std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer =
scf::tileAndFuseProducerOfSlice(rewriter, sliceOpToTile,
MutableArrayRef(&loops, 1));
if (!fusedProducer) {
funcOp->emitOpError("Failed to fuse fill op into forall loop.");
return signalPassFailure();

if (extractSliceOp) {
LoopLikeOpInterface loops =
cast<LoopLikeOpInterface>(forallOp.getOperation());

// Materialize the slice of the producer in place.
std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer =
scf::tileAndFuseProducerOfSlice(rewriter, extractSliceOp,
MutableArrayRef(&loops, 1));
if (!fusedProducer) {
fillOp->emitOpError("could not be fused into forall");
return signalPassFailure();
}
} else {
// In the case where there are no extract_slice ops, we manually create the
// fill at the beginning of the forall body. This situation might arise
// if the extract_slice has been folded, for example if the forall is
// over a grid of size 1.
rewriter.setInsertionPointToStart(forallOp.getBody());
auto fusedFill =
rewriter.create<linalg::FillOp>(fillOp.getLoc(), fillOp.value(), bbArg);
rewriter.replaceUsesWithIf(
bbArg, fusedFill.getResult(0), [&](OpOperand &operand) {
Operation *owner = operand.getOwner();
if (owner == fusedFill || isa<tensor::ParallelInsertSliceOp>(owner)) {
return false;
}
return true;
});

// Do not use the result of the old fill.
rewriter.replaceAllUsesWith(fillOp.getResults()[0], fillOp.getOutputs()[0]);
}
}

@@ -229,9 +229,8 @@ void AIEDeviceBuilder::foldDims(const SmallVector<OpFoldResult> &offsets,
SmallVector<OpFoldResult> tmpStrides;
(void)foldUnitDims(rewriter.getContext(), offsets, sizes, strides, tmpOffsets,
tmpSizes, tmpStrides);
AMDAIE::DmaDimConfig dmaDimConfig(deviceModel, memSpace, memSpace);
SmallVector<int64_t> maxSizes =
dmaDimConfig.getMaxSizes<CopyOpOperateOn::Source>();
DmaDimConfig dmaDimConfig(deviceModel, memSpace);
SmallVector<int64_t> maxSizes = dmaDimConfig.getMaxSizes();
(void)foldLinearDims(rewriter.getContext(), tmpOffsets, tmpSizes, tmpStrides,
newOffsets, newSizes, newStrides, maxSizes);
(void)foldSingleDim(newOffsets, newSizes, newStrides);
@@ -483,4 +483,32 @@ LogicalResult moveNpuDmaSyncUsersAfterAncestorInSameBlock(
return success();
}

//===----------------------------------------------------------------------===//
// DmaDimConfig
//===----------------------------------------------------------------------===//

SmallVector<int64_t> DmaDimConfig::getMaxSizes() const {
uint32_t maxIntraSize = deviceModel.getDmaBdProp<uint16_t>(
tileType, 0, AMDAIE::AMDAIEDmaBdProp::WrapMax);
uint32_t maxInterSize = deviceModel.getDmaBdProp<uint8_t>(
tileType, 0, AMDAIE::AMDAIEDmaBdProp::IterWrapMax);
SmallVector<int64_t> maxSizes(maxNbDims, maxIntraSize);
std::fill_n(maxSizes.begin(), nbInterDims, maxInterSize);
// The outermost intra size doesn't have a limit in HW.
maxSizes[nbInterDims] = std::numeric_limits<int64_t>::max();
return maxSizes;
}

SmallVector<int64_t> DmaDimConfig::getMaxStrides() const {
uint32_t maxIntraStride = deviceModel.getDmaBdProp<uint32_t>(
tileType, 0, AMDAIE::AMDAIEDmaBdProp::StepSizeMax);
uint32_t maxInterStride = deviceModel.getDmaBdProp<uint32_t>(
tileType, 0, AMDAIE::AMDAIEDmaBdProp::IterStepSizeMax);
// +1 because values are encoded in HW BDs as (value - 1), so the range is
// [1:2^x].
SmallVector<int64_t> stepSizes(maxNbDims, maxIntraStride + 1);
std::fill_n(stepSizes.begin(), nbInterDims, maxInterStride + 1);
return stepSizes;
}

} // namespace mlir::iree_compiler::AMDAIE
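
As an editorial illustration only (not part of this commit), the sketch below shows one way a caller might use the new per-memory-space DmaDimConfig to check an innermost-aligned access pattern against the hardware limits returned by getMaxSizes() and getMaxStrides(). The helper name fitsDmaLimits, its signature, and the sizes/strides vectors are assumptions for the example; the limit comparison mirrors the checks already performed in SubsumeLoopIntoDMA above.

#include <cassert>
#include <cstdint>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
// Assumes the DmaDimConfig declaration from this repository's AMDAIE
// transform utilities is in scope.

// Hedged sketch: returns true if `sizes`/`strides` (innermost-aligned,
// equal length) fit within the DMA limits for the config's memory space.
static bool fitsDmaLimits(
    const mlir::iree_compiler::AMDAIE::DmaDimConfig &config,
    llvm::ArrayRef<int64_t> sizes, llvm::ArrayRef<int64_t> strides) {
  assert(sizes.size() == strides.size() && "sizes and strides must pair up");
  llvm::SmallVector<int64_t> maxSizes = config.getMaxSizes();
  llvm::SmallVector<int64_t> maxStrides = config.getMaxStrides();
  // More dimensions than the tile's DMA supports.
  if (sizes.size() > maxSizes.size()) return false;
  // Align the requested dimensions with the innermost hardware dimensions.
  size_t begin = maxSizes.size() - sizes.size();
  for (size_t i = 0; i < sizes.size(); ++i) {
    if (sizes[i] > maxSizes[begin + i]) return false;
    if (strides[i] > maxStrides[begin + i]) return false;
  }
  return true;
}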