Skip to content

Commit

Permalink
fix: account for dynamic sizes while tiling
Browse files Browse the repository at this point in the history
This is not super safe, when upstreaming we should get feedback here.
Also not sure how to test?
  • Loading branch information
maxbartel authored and DavidGinten committed Jan 22, 2025
1 parent 0bdbc1c commit 3ba205e
Showing 1 changed file with 26 additions and 4 deletions.
30 changes: 26 additions & 4 deletions mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1188,10 +1188,32 @@ mlir::scf::tileAndFuseProducerOfSlice(
clonedProducerOp->getResult(resultNumber));
if (failed(tileAndFuseResult))
return std::nullopt;
// Note: Do not delete the candidateSliceOp, since its passed in from the
// caller.
rewriter.replaceAllUsesWith(candidateSliceOp,
tileAndFuseResult->tiledValues[0]);

// Check if the types are the same. If possible insert a cast. Fail otherwise.
if (tileAndFuseResult->tiledValues[0].getType() !=
candidateSliceOp.getResult().getType()) {
auto tileAndFuseResultType =
cast<RankedTensorType>(tileAndFuseResult->tiledValues[0].getType());
auto candidateSliceOpType =
cast<RankedTensorType>(candidateSliceOp.getResult().getType());
// We can only cast if the tileAndFuseResultType has a static shape and
// canidateSliceOp has a dynamic shape. Might be expanded in the future.
if (!tileAndFuseResultType.hasStaticShape() ||
candidateSliceOpType.hasStaticShape()) {
return std::nullopt;
}

auto castOp = rewriter.create<tensor::CastOp>(
candidateSliceOp->getLoc(), candidateSliceOpType, tileAndFuseResult->tiledValues[0]);
// Note: Do not delete the candidateSliceOp, since its passed in from the
// caller.
rewriter.replaceAllUsesWith(candidateSliceOp, castOp);
} else {
// Note: Do not delete the candidateSliceOp, since its passed in from the
// caller.
rewriter.replaceAllUsesWith(candidateSliceOp,
tileAndFuseResult->tiledValues[0]);
}
rewriter.eraseOp(clonedCandidateSliceOp);
rewriter.eraseOp(clonedProducerOp);

Expand Down

0 comments on commit 3ba205e

Please sign in to comment.