Skip to content

Commit

Permalink
review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
newling committed Dec 12, 2024
1 parent 3ce1b5c commit 7aacd04
Showing 1 changed file with 21 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,40 +19,30 @@ namespace mlir::iree_compiler::AMDAIE {

namespace {

FailureOr<Operation *> getParentBeforeLoop(BlockArgument blkArg) {
Operation *parent = blkArg.getOwner()->getParentOp();
LoopLikeOpInterface loop = dyn_cast<LoopLikeOpInterface>(parent);
if (!loop) return failure();
Operation *operandParent = loop.getTiedLoopInit(blkArg)->getOwner();
return operandParent;
}

/// A utility function specific to this pass which, given a value `operand`,
/// traverses the def-chain till it finds a tensor.extract_slice. The 2 cases
/// where it successfully finds and returns an extract_slice (SLICE) are:
///
/// Case 1)
/// pack -> block arg -> SLICE -> pack -> pack -> pack -> operand
/// ^^^^^^^^^^^^^^^^^^^^
/// any number (>= 0) of trailing packs
///
/// Case 2)
/// pack -> SLICE -> pack -> pack -> pack -> operand
/// ^^^^^^^^^^^^^^^^^^^^
/// any number (>= 0) of trailing packs
///
/// Case 2)
/// pack -> block arg -> SLICE -> pack -> pack -> pack -> operand
/// ^^^^^^^^^^^^^^^^^^^^
/// any number (>= 0) of trailing packs
///
/// Case 2 only matches where `block arg` is for a loop operation.
static FailureOr<tensor::ExtractSliceOp> getTensorExtractSliceDefiningOp(
Value operand) {
// roll back through all the packs immediately preceding `operand`.
while (operand.getDefiningOp() &&
isa<tensor::PackOp>(operand.getDefiningOp())) {
while (isa_and_present<tensor::PackOp>(operand.getDefiningOp())) {
operand = operand.getDefiningOp()->getOperand(0);
}

// If the parent of `operand` is not an extract_slice, return failure.
Operation *defOp = operand.getDefiningOp();
if (!defOp) return failure();
tensor::ExtractSliceOp sliceOp = dyn_cast<tensor::ExtractSliceOp>(defOp);
tensor::ExtractSliceOp sliceOp =
dyn_cast_if_present<tensor::ExtractSliceOp>(operand.getDefiningOp());
if (!sliceOp) return failure();

// Case 1 outlined above.
Expand All @@ -62,9 +52,11 @@ static FailureOr<tensor::ExtractSliceOp> getTensorExtractSliceDefiningOp(

// Case 2 outlined above.
else if (auto blkArg = dyn_cast<BlockArgument>(sliceOp.getSource())) {
FailureOr<Operation *> operandParent = getParentBeforeLoop(blkArg);
if (failed(operandParent)) return failure();
if (isa_and_present<tensor::PackOp>(operandParent.value())) return sliceOp;
Operation *parent = blkArg.getOwner()->getParentOp();
LoopLikeOpInterface loop = dyn_cast<LoopLikeOpInterface>(parent);
if (!loop) return failure();
Operation *operandParent = loop.getTiedLoopInit(blkArg)->getOwner();
if (isa_and_present<tensor::PackOp>(operandParent)) return sliceOp;
}

return failure();
Expand Down Expand Up @@ -139,6 +131,7 @@ void AMDAIEFusePackIntoLoopPass::runOnOperation() {
for (Value operand : genericOp.getOperands()) {
FailureOr<tensor::ExtractSliceOp> maybeSliceOp =
getTensorExtractSliceDefiningOp(operand);

if (succeeded(maybeSliceOp)) {
tensor::ExtractSliceOp sliceOp = maybeSliceOp.value();
std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer =
Expand All @@ -152,14 +145,12 @@ void AMDAIEFusePackIntoLoopPass::runOnOperation() {

// Case where operand of generic is a pack op which is in a different
// block than the generic's block.
else {
if (auto parent =
dyn_cast_if_present<tensor::PackOp>(operand.getDefiningOp())) {
Block *genericBlock = genericOp->getBlock();
if (parent->getBlock() != genericBlock && parent->hasOneUse()) {
Operation *firstOpInBlock = &genericBlock->front();
rewriter.moveOpBefore(parent, firstOpInBlock);
}
else if (auto parent = dyn_cast_if_present<tensor::PackOp>(
operand.getDefiningOp())) {
Block *genericBlock = genericOp->getBlock();
if (parent->getBlock() != genericBlock && parent->hasOneUse()) {
Operation *firstOpInBlock = &genericBlock->front();
rewriter.moveOpBefore(parent, firstOpInBlock);
}
}
}
Expand Down

0 comments on commit 7aacd04

Please sign in to comment.