Skip to content

Commit

Permalink
Review comments 2.0 + add another lit test + Handle edge cases via li…
Browse files Browse the repository at this point in the history
…t test
  • Loading branch information
Abhishek-Varma committed Aug 28, 2024
1 parent 0a5664c commit 3fb3fb5
Show file tree
Hide file tree
Showing 2 changed files with 797 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,25 @@ namespace mlir::iree_compiler::AMDAIE {

namespace {

/// Collects the DmaCpyNd ops that are candidates for splitting.
///
/// Every `AMDAIE::CoreOp` reachable from `moduleOp` is visited. For a core
/// that has exactly three input DMAs, the op defining the third input (index
/// 2) is recorded; that input is asserted to be defined by an
/// `AMDAIE::DmaCpyNdOp`. Cores with any other number of input DMAs contribute
/// nothing to the result.
///
/// Returns the recorded ops in the order the walk encountered their cores.
///
/// TODO(avarma): Move this as a reusable component for other transformations ?
/// TODO: Only the third input DMA is considered for now; generalize later.
static SmallVector<AMDAIE::DmaCpyNdOp> fetchDmaCpyNdOpsToSplit(
    ModuleOp moduleOp) {
  SmallVector<AMDAIE::DmaCpyNdOp> dmaOpsToSplit;
  moduleOp.walk([&dmaOpsToSplit](AMDAIE::CoreOp core) -> WalkResult {
    SmallVector<Value> coreInputDmas = core.getInputDmas();
    if (coreInputDmas.size() == 3) {
      // The split candidate is the op producing the third input DMA.
      Value thirdInputDma = coreInputDmas[2];
      auto splitCandidate =
          thirdInputDma.getDefiningOp<AMDAIE::DmaCpyNdOp>();
      assert(splitCandidate && "expected an amdaie.dma_cpy_nd op");
      dmaOpsToSplit.push_back(splitCandidate);
      return WalkResult::advance();
    }
    // Not the three-input shape this utility currently understands.
    return WalkResult::skip();
  });
  return dmaOpsToSplit;
}

class AMDAIESplitLogicalObjectFifosPass
: public impl::AMDAIESplitLogicalObjectFifosBase<
AMDAIESplitLogicalObjectFifosPass> {
Expand All @@ -33,28 +52,33 @@ void AMDAIESplitLogicalObjectFifosPass::runOnOperation() {
MLIRContext *context = &getContext();
IRRewriter rewriter(context);

SmallVector<AMDAIE::DmaCpyNdOp> l2ToL1DmaOps;
// We are currently walking through CoreOps gathering 3rd Input DmaOp (if
// applicable) from them.
// TODO: We will generalize this later.
moduleOp.walk([&](AMDAIE::CoreOp coreOp) {
SmallVector<Value> inputDmas = coreOp.getInputDmas();
if (inputDmas.size() != 3) return WalkResult::skip();
auto dmaCpyNdOp = inputDmas[2].getDefiningOp<AMDAIE::DmaCpyNdOp>();
assert(dmaCpyNdOp && "expected an amdaie.dma_cpy_nd op");
l2ToL1DmaOps.push_back(dmaCpyNdOp);
return WalkResult::advance();
});
SmallVector<AMDAIE::DmaCpyNdOp> l2ToL1DmaOps =
fetchDmaCpyNdOpsToSplit(moduleOp);

if (l2ToL1DmaOps.size() == 0) return;

SmallVector<OpFoldResult> baseSourceOffsets =
l2ToL1DmaOps[0].getSourceMixedOffsets();
LogicalObjectFifoFromMemrefOp sourceObjectFifo =
l2ToL1DmaOps[0].getSourceObjectFifo();
auto sourceAllocOp =
sourceObjectFifo.getMemref().getDefiningOp<memref::AllocOp>();
if (!sourceAllocOp) {
sourceObjectFifo->emitOpError()
<< "expected alloc op as the defining op of source "
"logicalobjectfifo.from_memref";
return signalPassFailure();
}
// We will now capture those dimensions where L2 memory was split. The way we
// do this is by checking all L2->L1 DmaOps' source offset and marking those
// dimensions which are not equal to at least one of the source offsets.
DenseSet<unsigned> splitDimensionsSet;
for (unsigned i = 1, n = l2ToL1DmaOps.size(); i < n; i++) {
if (l2ToL1DmaOps[i].getSourceObjectFifo() != sourceObjectFifo) {
l2ToL1DmaOps[i]->emitRemark() << "has different source objectfifo";
sourceObjectFifo->emitRemark() << "is the expected source objectfifo";
return signalPassFailure();
}
SmallVector<OpFoldResult> sourceOffsets =
l2ToL1DmaOps[i].getSourceMixedOffsets();
for (unsigned j = 0, m = baseSourceOffsets.size(); j < m; j++) {
Expand All @@ -63,34 +87,28 @@ void AMDAIESplitLogicalObjectFifosPass::runOnOperation() {
}
}
}
// Fetch the L3 -> L2 Dma Op corresponding to the L2 buffer as target.
AMDAIE::DmaCpyNdOp l3ToL2DmaOp;
DenseSet<Operation *> toBeErased;
for (Operation *objFifoUserOp : sourceObjectFifo->getUsers()) {
if (auto dmaOp = dyn_cast<AMDAIE::DmaCpyNdOp>(objFifoUserOp);
dmaOp.getTargetObjectFifo() == sourceObjectFifo) {
l3ToL2DmaOp = dmaOp;
toBeErased.insert(dmaOp);
break;
}
}
toBeErased.insert(sourceAllocOp);
toBeErased.insert(sourceObjectFifo);

OpFoldResult zeroVal = getAsIndexOpFoldResult(context, 0);
OpFoldResult oneVal = getAsIndexOpFoldResult(context, 1);
DenseSet<Operation *> toBeErased;
// Traverse each L2->L1 DmaCpyNd op and split them.
for (AMDAIE::DmaCpyNdOp l2ToL1DmaOp : l2ToL1DmaOps) {
LogicalObjectFifoFromMemrefOp sourceObjectFifo =
l2ToL1DmaOp.getSourceObjectFifo();
auto sourceAllocOp =
sourceObjectFifo.getMemref().getDefiningOp<memref::AllocOp>();
uint64_t sourceMemrefSpace = sourceObjectFifo.getMemorySpaceAsUInt();
if (!sourceAllocOp || sourceMemrefSpace != 1) continue;
LogicalObjectFifoFromMemrefOp targetObjectFifo =
l2ToL1DmaOp.getTargetObjectFifo();
Value targetAllocOp = targetObjectFifo.getMemref();

// Fetch the L3 -> L2 Dma Op corresponding to the L2 buffer as target.
AMDAIE::DmaCpyNdOp l3ToL2DmaOp;
for (Operation *objFifoUserOp : sourceObjectFifo->getUsers()) {
if (auto dmaOp = dyn_cast<AMDAIE::DmaCpyNdOp>(objFifoUserOp);
dmaOp.getTargetObjectFifo() == sourceObjectFifo) {
l3ToL2DmaOp = dmaOp;
toBeErased.insert(dmaOp);
break;
}
}
toBeErased.insert(sourceAllocOp);
toBeErased.insert(sourceObjectFifo);

SmallVector<OpFoldResult, 6> staticL2AsSourceOffsets =
l2ToL1DmaOp.getSourceMixedOffsets();
SmallVector<OpFoldResult, 6> staticL2AsSourceSizes =
Expand Down Expand Up @@ -125,13 +143,20 @@ void AMDAIESplitLogicalObjectFifosPass::runOnOperation() {
std::optional<int64_t> constantOffset =
getConstantIntValue(staticL2AsSourceOffsets[dim]);
if (!constantOffset) {
l2ToL1DmaOp->emitOpError("found a non-constant value for offset");
l2ToL1DmaOp->emitRemark()
<< "found a non-constant value for source offset at dim " << dim;
return signalPassFailure();
}
std::optional<int64_t> constantSize =
getConstantIntValue(staticL2AsTargetSizes[dim + 2]);
if (!constantSize) {
l3ToL2DmaOp->emitRemark()
<< "found a non-constant value for target size at dim "
<< (dim + 2);
return signalPassFailure();
}
dimToOffsetMapForL3AsSource.insert(
{dim + 2,
constantOffset.value() *
(getConstantIntValue(staticL2AsTargetSizes[dim + 2]).value())});
{dim + 2, constantOffset.value() * constantSize.value()});
staticL2AsSourceOffsets[dim] = zeroVal;
staticL2AsSourceSizes[dim] = oneVal;
staticL2AsTargetOffsets[dim] = zeroVal;
Expand Down Expand Up @@ -161,22 +186,83 @@ void AMDAIESplitLogicalObjectFifosPass::runOnOperation() {

SmallVector<OpFoldResult, 4> staticL3AsSourceOffsets =
l3ToL2DmaOp.getSourceMixedOffsets();
// We now traverse the map : DIM -> CONST_OFFSET_TO_ADD we created earlier
// to update extraction offsets while splitting L3->L2.
/*
For L3 -> L2 DmaCpyNd :-
From offset (0,0) we are extracting one 4x4 memref.
_______
|. . . .|
|. . . .|
|. . . .|
|. . . .|
---------
After split we will extract four 2x2 memrefs.
So, the corresponding offsets will be :-
1. Offset (0,0) - extract 2x2 memref
___
|. .|. .
|. .|. .
-----
. . . .
. . . .
2. Offset (0,2) - extract 2x2 memref
___
. .|. .|
. .|. .|
-----
. . . .
. . . .
3. Offset (2,0) - extract 2x2 memref
. . . .
. . . .
___
|. .|. .
|. .|. .
-----
4. Offset (2,2) - extract 2x2 memref
. . . .
. . . .
___
. .|. .|
. .|. .|
-----
The following logic performs this computation of offsets for L3 source.
*/
for (auto [dim, offsetToAdd] : dimToOffsetMapForL3AsSource) {
auto applyOp = cast<affine::AffineApplyOp>(
cast<Value>(staticL3AsSourceOffsets[dim]).getDefiningOp());
AffineMap affineMap = applyOp.getAffineMap();
AffineExpr affineExpr = affineMap.getResult(0);
AffineExpr newAffineExpr = affineExpr + offsetToAdd;
auto newAffineMap = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
{newAffineExpr}, context);
IRRewriter::InsertPoint oldInsertionPoint = rewriter.saveInsertionPoint();
rewriter.setInsertionPoint(applyOp);
auto newAffineApplyOp = rewriter.create<affine::AffineApplyOp>(
applyOp.getLoc(), newAffineMap, applyOp.getMapOperands());
rewriter.restoreInsertionPoint(oldInsertionPoint);
staticL3AsSourceOffsets[dim] = newAffineApplyOp.getResult();
auto l3SourceOffsetVal = cast<Value>(staticL3AsSourceOffsets[dim]);
Operation *defOpOfL3SourceOffset = l3SourceOffsetVal.getDefiningOp();
if (!defOpOfL3SourceOffset) {
// TODO: Handle this case better later.
l3ToL2DmaOp->emitRemark()
<< "source offset at dim " << dim << " is a block argument";
return signalPassFailure();
}
Location loc = defOpOfL3SourceOffset->getLoc();
rewriter.setInsertionPoint(defOpOfL3SourceOffset);
OpBuilder::InsertionGuard guard(rewriter);
Value newL3AsSourceOffsetVal;
if (auto applyOp =
dyn_cast<affine::AffineApplyOp>(defOpOfL3SourceOffset)) {
AffineExpr affineExpr = applyOp.getAffineMap().getResult(0);
AffineExpr newAffineExpr = affineExpr + offsetToAdd;
auto newAffineMap = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
{newAffineExpr}, context);
newL3AsSourceOffsetVal = rewriter.create<affine::AffineApplyOp>(
loc, newAffineMap, applyOp.getMapOperands());
} else if (auto constantOffset = getConstantIntValue(l3SourceOffsetVal)) {
int64_t newOffset = *constantOffset + offsetToAdd;
newL3AsSourceOffsetVal = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIndexAttr(newOffset));
} else {
// TODO: Ideally we should be able to handle even +, -, *, /, etc.
// But handle this later (if at all!) as such cases aren't going
// to arise.
l3ToL2DmaOp->emitRemark()
<< "Unhandled expression for source offset at dim " << dim;
return signalPassFailure();
}
staticL3AsSourceOffsets[dim] = newL3AsSourceOffsetVal;
}
// Create new L3 -> L2 Dma Op.
rewriter.setInsertionPoint(l3ToL2DmaOp);
Expand Down
Loading

0 comments on commit 3fb3fb5

Please sign in to comment.