[AMDAIEFuseFillIntoForall] Handle case where fill output is not sliced (#976)

For 2x2 or 4x4 tiling, the chain of ops after the fill looks like:

```
 %9 = linalg.fill ins(%cst : f32) ...
 ... 
 %12 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %9)  ...
 ...
 %extracted_slice_19 = tensor.extract_slice %arg5 ... 
 ...
 %19 = linalg.generic  ... outs(%extracted_slice_19 ... )
 ```
 
i.e. the filled value enters an extract_slice inside the scf.forall. But for 1x1 tiling, it looks like:
 
 
 ```
  %9 = linalg.fill ins(%cst : f32)
  ...
  %14 = scf.forall (%arg3, %arg4) in (1, 1) shared_outs(%arg5 = %9)
  ...
  %19 = linalg.generic  ... outs(%arg5... )
```


i.e. there is no intermediate extract_slice. 


Before this PR, the logic was hardcoded to look for an extract_slice; this PR relaxes that requirement.
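
With this change, in the 1x1 case the fill is instead recreated at the start of the scf.forall body, and uses of the corresponding block argument (other than the terminator's tensor.parallel_insert_slice) are redirected to the new fill. Roughly, reusing the illustrative SSA names of the example above (%init and %fill_new are placeholder names, with %init standing for the fill's original output tensor), the result is:

```
 %14 = scf.forall (%arg3, %arg4) in (1, 1) shared_outs(%arg5 = %init)
 ...
 %fill_new = linalg.fill ins(%cst : f32) outs(%arg5 ...)
 ...
 %19 = linalg.generic  ... outs(%fill_new ... )
```

The new lit test @fuse_without_slice added below checks this pattern.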
newling authored Dec 11, 2024
1 parent 8ae709e commit db10c75
Showing 2 changed files with 109 additions and 46 deletions.
@@ -7,6 +7,7 @@
#include "iree-amd-aie/Transforms/Passes.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Pass/Pass.h"

#define DEBUG_TYPE "iree-amdaie-fuse-fill-into-forall"
@@ -29,60 +30,87 @@ class AMDAIEFuseFillIntoForallPass

void AMDAIEFuseFillIntoForallPass::runOnOperation() {
MLIRContext *context = &getContext();
mlir::FunctionOpInterface funcOp = getOperation();
IRRewriter rewriter(context);

// Find the producer op, which in this case is linalg.fill.
TilingInterface tileableProducer;
funcOp->walk([&](TilingInterface op) {
if (isa<linalg::FillOp>(op)) {
tileableProducer = op;
return WalkResult::interrupt();
}
return WalkResult::advance();
});

if (!tileableProducer) {
LLVM_DEBUG(llvm::dbgs() << "There is no producer op to be fused.\n");
// Find a unique FillOp with a single output, or return.
SmallVector<linalg::FillOp> fillOps;
getOperation()->walk(
[&](linalg::FillOp fillOp) { fillOps.push_back(fillOp); });
if (fillOps.size() != 1) {
LLVM_DEBUG(llvm::dbgs() << "Expected exactly 1 fill op, but found "
<< fillOps.size() << ".\n");
return;
}

// Search the first use by a scf::ForallOp user.
scf::ForallOp forallOp;
auto itProducerUses =
llvm::find_if(tileableProducer->getUses(), [&](OpOperand &use) {
forallOp = dyn_cast<scf::ForallOp>(use.getOwner());
return forallOp;
});
linalg::FillOp fillOp = fillOps[0];
if (fillOp.getResults().size() != 1) {
LLVM_DEBUG(llvm::dbgs() << "Expected fill op to have exactly 1 result, but "
<< "found " << fillOp.getResults().size() << ".\n");
return;
}

// Confirm that there is a unique user that is a forall, and match
// the block argument that is used by the fill op, or return.
if (!fillOp->hasOneUse()) {
LLVM_DEBUG(llvm::dbgs()
<< "Expected exactly 1 use of fill op, but found 0 or 2+.");
return;
}
OpOperand &fillUse = *fillOp->getUses().begin();
auto forallOp = dyn_cast<scf::ForallOp>(fillUse.getOwner());
if (!forallOp) {
LLVM_DEBUG(llvm::dbgs() << "There is no forall Op.\n");
LLVM_DEBUG(llvm::dbgs() << "Expected fill op to be used by a forall op, "
<< "but unique user is "
<< fillUse.getOwner()->getName() << ".\n");
return;
}

// Search the producer slices accessed within the Forall op.
OpOperand *pUse = &(*itProducerUses);
BlockArgument bbArg = forallOp.getTiedBlockArgument(pUse);

auto itBBArgUsers = llvm::find_if(bbArg.getUsers(), [&](Operation *user) {
auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
return sliceOp;
});
if (itBBArgUsers == bbArg.getUsers().end()) {
funcOp->emitOpError("There is no extract tensor slice.");
return signalPassFailure();
BlockArgument bbArg = forallOp.getTiedBlockArgument(&fillUse);

// Find 0 or 1 ExtractSliceOps that use the fill result, or return.
tensor::ExtractSliceOp extractSliceOp;
for (Operation *user : bbArg.getUsers()) {
if (auto nxt = dyn_cast<tensor::ExtractSliceOp>(user)) {
if (extractSliceOp) {
LLVM_DEBUG(llvm::dbgs()
<< "Expected at most 1 extract_slice op, but found 2+.\n");
return;
}
extractSliceOp = nxt;
}
}
auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);

LoopLikeOpInterface loops =
cast<LoopLikeOpInterface>(forallOp.getOperation());

// Materialize the slice of the producer in place.
std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer =
scf::tileAndFuseProducerOfSlice(rewriter, sliceOpToTile,
MutableArrayRef(&loops, 1));
if (!fusedProducer) {
funcOp->emitOpError("Failed to fuse fill op into forall loop.");
return signalPassFailure();

if (extractSliceOp) {
LoopLikeOpInterface loops =
cast<LoopLikeOpInterface>(forallOp.getOperation());

// Materialize the slice of the producer in place.
std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer =
scf::tileAndFuseProducerOfSlice(rewriter, extractSliceOp,
MutableArrayRef(&loops, 1));
if (!fusedProducer) {
fillOp->emitOpError("could not be fused into forall");
return signalPassFailure();
}
} else {
// In the case where there are no extract_slice ops, we manually create the
// fill at the beginning of the forall body. This situation might arise
// if the extract_slice has been folded, for example if the forall is
// over a grid of size 1.
rewriter.setInsertionPointToStart(forallOp.getBody());
auto fusedFill =
rewriter.create<linalg::FillOp>(fillOp.getLoc(), fillOp.value(), bbArg);
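// Redirect uses of the block argument to the new fill, except in the new
// fill itself and in the terminator's parallel_insert_slice, which must
// keep writing to the shared_outs block argument.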
rewriter.replaceUsesWithIf(
bbArg, fusedFill.getResult(0), [&](OpOperand &operand) {
Operation *owner = operand.getOwner();
if (owner == fusedFill || isa<tensor::ParallelInsertSliceOp>(owner)) {
return false;
}
return true;
});

// Do not use the result of the old fill.
rewriter.replaceAllUsesWith(fillOp.getResults()[0], fillOp.getOutputs()[0]);
}
}

@@ -1,4 +1,4 @@
// RUN: iree-opt --pass-pipeline='builtin.module(func.func(iree-amdaie-fuse-fill-into-forall))' %s | FileCheck %s
// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(func.func(iree-amdaie-fuse-fill-into-forall))' %s | FileCheck %s

#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
#map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d5, d4)>
@@ -29,3 +29,38 @@ func.func @fuse_fill_into_forall(%arg0: tensor<1x4x16x64xi8>, %arg1 : tensor<4x1
// CHECK: linalg.fill
// CHECK: linalg.generic
// CHECK: }

// -----

#map = affine_map<(d0) -> (d0)>
func.func @fuse_without_slice(%arg0: tensor<8xi8>) -> tensor<8xi8> {
  %c7_i8 = arith.constant 7 : i8
  %c3_i8 = arith.constant 3 : i8
  %0 = linalg.fill ins(%c7_i8 : i8) outs(%arg0 : tensor<8xi8>) -> tensor<8xi8>
  %1 = tensor.empty() : tensor<8xi8>
  %2 = scf.forall (%arg1) in (1) shared_outs(%arg2 = %0) -> (tensor<8xi8>) {
    %3 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<8xi8>) outs(%1 : tensor<8xi8>) {
    ^bb0(%in: i8, %out: i8):
      %4 = arith.addi %in, %c3_i8 : i8
      linalg.yield %4 : i8
    } -> tensor<8xi8>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %3 into %arg2[0] [8] [1] : tensor<8xi8> into tensor<8xi8>
    }
  } {mapping = [#gpu.thread<y>]}
  return %2 : tensor<8xi8>
}

// CHECK: @fuse_without_slice(%[[FUNCARG:.*]]: tensor<8xi8>) -> tensor<8xi8> {
// check that the operand of scf.forall is not the filled tensor, because the
// fill will take place inside the scf.forall:
// CHECK: %[[FORALL:.*]] = scf.forall (%[[ARG1:.*]]) in (1)
// CHECK-SAME: shared_outs(%[[ARG2:.*]] = %[[FUNCARG]])
// check for the new fill:
// CHECK: %[[NEWFILL:.*]] = linalg.fill
// CHECK-SAME: outs(%[[ARG2]] : tensor<8xi8>) -> tensor<8xi8>
// CHECK: linalg.generic
// check that the parallel_insert_slice still happens on arg2, not the filled
// tensor. This is because it must match the shared_outs of the scf.forall:
// CHECK: tensor.parallel_insert_slice
// CHECK-SAME: into %[[ARG2]]
