Skip to content

Commit

Permalink
Clean up debug printouts as op attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
erwei-xilinx committed Feb 5, 2025
1 parent 17a5d25 commit 06cdb35
Showing 1 changed file with 7 additions and 108 deletions.
115 changes: 7 additions & 108 deletions mlir/lib/Transform/AIRMiscPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1047,56 +1047,16 @@ FailureOr<Value> tileChannelOpByFactor(
auto newGetOp = rewriter.create<air::ChannelGetOp>(
loc, tys, deps, newChanOp.getSymName(), newIndices,
originalChanOp.getMemref(), newOffsets, newWraps, newStrides);
newGetOp->setAttr("id",
mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 32),
originalChanOp.getId()));
newGetOp->setAttrs(originalChanOp->getDiscardableAttrDictionary());
tokens.push_back(newGetOp.getAsyncToken());
opToSplitInfoMap[newGetOp] = splitInfoVec[i];
newGetOp->setAttr(
"split_dim",
mlir::IntegerAttr::get(IntegerType::get(ctx, 32), splitDimOnOffsets));
if (splitInfoAffineMap)
newGetOp->setAttr("affine_map",
mlir::AffineMapAttr::get(splitInfoAffineMap));
if (splitInfoSplitOffset)
newGetOp->setAttr("split_offset",
mlir::IntegerAttr::get(IntegerType::get(ctx, 32),
*splitInfoSplitOffset));
if (splitInfoSplitSize)
newGetOp->setAttr("split_size",
mlir::IntegerAttr::get(IntegerType::get(ctx, 32),
*splitInfoSplitSize));
if (splitInfoSplitStrideFactor)
newGetOp->setAttr("split_stride_factor",
mlir::IntegerAttr::get(IntegerType::get(ctx, 32),
*splitInfoSplitStrideFactor));
} else {
auto newPutOp = rewriter.create<air::ChannelPutOp>(
loc, tys, deps, newChanOp.getSymName(), newIndices,
originalChanOp.getMemref(), newOffsets, newWraps, newStrides);
newPutOp->setAttr("id",
mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 32),
originalChanOp.getId()));
newPutOp->setAttrs(originalChanOp->getDiscardableAttrDictionary());
tokens.push_back(newPutOp.getAsyncToken());
opToSplitInfoMap[newPutOp] = splitInfoVec[i];
newPutOp->setAttr(
"split_dim",
mlir::IntegerAttr::get(IntegerType::get(ctx, 32), splitDimOnOffsets));
if (splitInfoAffineMap)
newPutOp->setAttr("affine_map",
mlir::AffineMapAttr::get(splitInfoAffineMap));
if (splitInfoSplitOffset)
newPutOp->setAttr("split_offset",
mlir::IntegerAttr::get(IntegerType::get(ctx, 32),
*splitInfoSplitOffset));
if (splitInfoSplitSize)
newPutOp->setAttr("split_size",
mlir::IntegerAttr::get(IntegerType::get(ctx, 32),
*splitInfoSplitSize));
if (splitInfoSplitStrideFactor)
newPutOp->setAttr("split_stride_factor",
mlir::IntegerAttr::get(IntegerType::get(ctx, 32),
*splitInfoSplitStrideFactor));
}
}
auto newWaitAll = rewriter.create<air::WaitAllOp>(
Expand Down Expand Up @@ -1207,26 +1167,18 @@ void AIRSplitL2MemrefForBufferConstraintPass::partitionMemref(
return;
push_back_if_unique<int>(keys, offset_key);
chanOpPartitions[offset_key].push_back(op);
op->setAttr("partition_key",
IntegerAttr::get(IntegerType::get(ctx, 32), offset_key));
};

for (auto op : puts) {
op->setAttr("partitioning", BoolAttr::get(ctx, true));
if (!opToSplitInfoMap.count(op)) {
op->setAttr("opNotInOpToSplitInfoMap", BoolAttr::get(ctx, true));
if (!opToSplitInfoMap.count(op))
continue;
}
auto &[splitInfoDimOnOffsets, splitAffineMap, splitOffset, splitSize,
splitStride] = opToSplitInfoMap[op];
getChanOpPartitionsMap(chanOpPartitions, keys, splitInfoDimOnOffsets, op);
}
for (auto op : gets) {
op->setAttr("partitioning", BoolAttr::get(ctx, true));
if (!opToSplitInfoMap.count(op)) {
op->setAttr("opNotInOpToSplitInfoMap", BoolAttr::get(ctx, true));
if (!opToSplitInfoMap.count(op))
continue;
}
auto &[splitInfoDimOnOffsets, splitAffineMap, splitOffset, splitSize,
splitStride] = opToSplitInfoMap[op];
getChanOpPartitionsMap(chanOpPartitions, keys, splitInfoDimOnOffsets, op);
Expand Down Expand Up @@ -1495,15 +1447,6 @@ AIRSplitL2MemrefForBufferConstraintPass::getTargetMemrefAllocs(
int tilingFactor =
std::max(getChanCount(MM2SChannels), getChanCount(S2MMChannels));

// Keep debug log in alloc op's attributes. TODO: clean up.
allocOp->setAttr("split", BoolAttr::get(func.getContext(), true));
allocOp->setAttr("tilingFactor",
IntegerAttr::get(IntegerType::get(ctx, 32), tilingFactor));
if (getChanCount(MM2SChannels) > 1) {
allocOp->setAttr("split_type", StringAttr::get(ctx, "MM2SChannels"));
} else {
allocOp->setAttr("split_type", StringAttr::get(ctx, "S2MMChannels"));
}
llvm::MapVector<int, SmallVector<infoEntryTy>> infoEntryMap;
std::optional<int> splitDimOffset = std::nullopt;
std::optional<int> splitDimSize = std::nullopt;
Expand All @@ -1529,18 +1472,6 @@ AIRSplitL2MemrefForBufferConstraintPass::getTargetMemrefAllocs(
"memref splitting analysis failed to get the split dimension.");
return failure();
}
if (allocOp->hasAttr("split_dim"))
assert(allocOp->getAttrOfType<IntegerAttr>("split_dim").getInt() ==
*splitDim &&
"L2 memref tiled inconsistently across multiple data access "
"patterns. Cannot infer L2 memref tiling strategy.");
else {
if (*splitDim >= 0)
allocOp->setAttr(
"split_dim",
IntegerAttr::get(IntegerType::get(ctx, 32), *splitDim));
assert(*splitDim >= 0 && "failed to get split dimension");
}

// Methods to get root offset/size/stride from air.channel's operands, where
// root is either a constant, or a loop's induction variable.
Expand All @@ -1554,13 +1485,6 @@ AIRSplitL2MemrefForBufferConstraintPass::getTargetMemrefAllocs(
};
auto getRootSize = [&](Value offsetVal, Value sizeVal) {
std::optional<int> rootSize = std::nullopt;
// if (auto constSize = getConstantIntValue(sizeVal)){
// // splitDimSize = *constSize;
// // putgets[i]->setAttr(
// // "split_dim_size",
// // IntegerAttr::get(IntegerType::get(ctx, 32), *splitDimSize));
// }
// else
if (auto forOp = getScfForFromVal(offsetVal)) {
if (auto trip_count = air::getStaticScfForTripCountAsInt(forOp))
rootSize = *getConstantIntValue(sizeVal) * (*trip_count);
Expand All @@ -1582,39 +1506,21 @@ AIRSplitL2MemrefForBufferConstraintPass::getTargetMemrefAllocs(
auto offsetDimOpt =
air::getOffsetDimFromMemrefDim(*splitDim, putgets[i].getStrides(),
air::getTensorShape(memref.getType()));
if (offsetDimOpt)
putgets[i]->setAttr(
"split_dim",
IntegerAttr::get(IntegerType::get(ctx, 32), *offsetDimOpt));
// Infer offset at splitDim.
if (auto rootOffset =
getRootOffset(putgets[i].getOffsets()[*offsetDimOpt])) {
getRootOffset(putgets[i].getOffsets()[*offsetDimOpt]))
splitDimOffset = *rootOffset;
putgets[i]->setAttr(
"split_dim_offset",
IntegerAttr::get(IntegerType::get(ctx, 32), *splitDimOffset));
}
// Infer size at splitDim.
if (auto rootSize = getRootSize(putgets[i].getOffsets()[*offsetDimOpt],
putgets[i].getSizes()[*offsetDimOpt])) {
putgets[i].getSizes()[*offsetDimOpt]))
splitDimSize = *rootSize;
allocOp->setAttr(
"split_dim_size",
IntegerAttr::get(IntegerType::get(ctx, 32), *splitDimSize));
putgets[i]->setAttr(
"split_dim_size",
IntegerAttr::get(IntegerType::get(ctx, 32), *splitDimSize));
}
// Infer stride (factor) at splitDim. If the root comes from an scf.for
// loop, and if the loop has non-unit step size, then that multiplier
// should be applied to other split channe put/get ops.
if (auto rootStrideFactor =
getRootStrideFactor(putgets[i].getOffsets()[*offsetDimOpt],
putgets[i].getStrides()[*offsetDimOpt])) {
splitDimStrideFactor = *rootStrideFactor;
putgets[i]->setAttr(
"split_dim_stride_factor",
IntegerAttr::get(IntegerType::get(ctx, 32), *splitDimStrideFactor));
// Cancel out the non-unit step size on the for loop, to get contiguous
// access pattern on memrefs after split.
if (auto forOp =
Expand All @@ -1625,11 +1531,8 @@ AIRSplitL2MemrefForBufferConstraintPass::getTargetMemrefAllocs(
}
AffineMap applyMap;
auto apply = getAffineMapOnMemrefSplitDim(putgets[i], *offsetDimOpt);
if (apply) {
if (apply)
applyMap = apply.getAffineMap();
allocOp->setAttr("affine_map", AffineMapAttr::get(applyMap));
putgets[i]->setAttr("affine_map", AffineMapAttr::get(applyMap));
}

infoEntryTy newEntry = {*offsetDimOpt, applyMap, splitDimOffset,
splitDimSize, splitDimStrideFactor};
Expand Down Expand Up @@ -1942,10 +1845,6 @@ void AIRSplitL2MemrefForBufferConstraintPass::runOnOperation() {
signalPassFailure();
erased.insert(par);
}
for (auto &[old, news] : parUnrollMap) {
for (auto newOp : news)
newOp->setAttr("unrolled", BoolAttr::get(ctx, true));
}
// Update map after loop unrolling.
for (auto &[oldOp, splitInfo] : opToSplitInfoMap) {
Operation *o = oldOp;
Expand Down

0 comments on commit 06cdb35

Please sign in to comment.