Skip to content

Commit

Permalink
AIRSplitL2Memref: Fixup a bug when trying to align dims (#670)
Browse files Browse the repository at this point in the history
* Fixup a bug when trying to align dims

* Clang format

* Fixup a bug in logic of unrolling scf.parallel, where invariant operands of each chan op should not be touched
  • Loading branch information
erwei-xilinx authored Jul 18, 2024
1 parent fbdcad3 commit e9a41bc
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 6 deletions.
26 changes: 22 additions & 4 deletions mlir/lib/Transform/AIRMiscPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1360,13 +1360,31 @@ void AIRSplitL2MemrefForBufferConstraintPass::runOnOperation() {
// different rank, then we check if ranks can be matched after leading
// singleton dimensions are removed. If the ranks still do not match,
// then the behaviour is unstable.
int numLeadingSingletonDims = 0;
for (auto memrefDim : memrefShape)
int numLeadingSingletonDimDiff = 0;
for (auto memrefDim : memrefShape) {
if (memrefDim == 1)
numLeadingSingletonDims++;
numLeadingSingletonDimDiff++;
else
break;
}
for (auto memrefDim :
air::getTensorShape(theOtherChanOp[0].getMemref().getType())) {
if (memrefDim == 1)
numLeadingSingletonDimDiff--;
else
break;
}
if (dim - numLeadingSingletonDimDiff < 0) {
chanUserOp->emitOpError(
"Failed to split data access pattern along dimension ")
<< std::to_string(dim)
<< " due to dimension misalignment with channel op at the other "
"side.";
return;
}
Value newWaitAll1 = tileChannelOpByFactor(
theOtherChanOp[0], targetColTilingFactor, memrefShape[dim],
dim - numLeadingSingletonDims, new_chan, loc, ctx);
dim - numLeadingSingletonDimDiff, new_chan, loc, ctx);

// Update dependency.
auto oldToken =
Expand Down
8 changes: 6 additions & 2 deletions mlir/lib/Util/Dependency.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -783,20 +783,24 @@ LogicalResult unrollAIRChannelPutGetInScfParallel(OpBuilder builder,
dyn_cast<mlir::affine::AffineApplyOp>(oper.getDefiningOp()))
position_apply = apply_op;
else if (auto exec = dyn_cast<air::ExecuteOp>(oper.getDefiningOp())) {
auto child_op = &exec.getBody().front().getOperations().front();
if (auto apply_op = dyn_cast<mlir::affine::AffineApplyOp>(child_op))
if (auto apply_op =
dyn_cast<mlir::affine::AffineApplyOp>(exec.getChildOp()))
position_apply = apply_op;
}
if (position_apply) {
bool positionApplyIsVariantWrtPar = false;
SmallVector<AffineExpr> const_syms;
for (unsigned i = 0; i < par.getInductionVars().size(); i++) {
for (auto map_o : position_apply.getMapOperands()) {
if (par.getInductionVars()[i] == map_o) {
const_syms.push_back(
getAffineConstantExpr(position[i], builder.getContext()));
positionApplyIsVariantWrtPar = true;
}
}
}
if (!positionApplyIsVariantWrtPar)
continue;
AffineExpr newC = position_apply.getAffineMap().getResult(0);
newC = newC.replaceSymbols(const_syms);
auto expr = dyn_cast<AffineConstantExpr>(simplifyAffineExpr(
Expand Down

0 comments on commit e9a41bc

Please sign in to comment.