From db10c754062e614680fb23e28c7e251ec44a3ad0 Mon Sep 17 00:00:00 2001 From: James Newling Date: Wed, 11 Dec 2024 13:20:32 -0800 Subject: [PATCH] [AMDAIEFuseFillIntoForall] Handle case where fill output is not sliced (#976) For 2x2 or 4x4 tiling the chain of ops after the fill looks like ``` %9 = linalg.fill ins(%cst : f32) ... ... %12 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %9) ... ... %extracted_slice_19 = tensor.extract_slice %arg5 ... ... %19 = linalg.generic ... outs(%extracted_slice_19 ... ) ``` i.e. the filled value enters an extract_slice inside the scf.forall. But for 1x1 tiling, it looks like ``` %9 = linalg.fill ins(%cst : f32) ... %14 = scf.forall (%arg3, %arg4) in (1, 1) shared_outs(%arg5 = %9) ... %19 = linalg.generic ... outs(%arg5... ) ``` i.e. there is no intermediate extract_slice. Before this PR, the logic was hardcoded to look for an extract_slice; this PR relaxes this. --- .../Transforms/AMDAIEFuseFillIntoForall.cpp | 118 +++++++++++------- .../test/fuse_fill_into_forall.mlir | 37 +++++- 2 files changed, 109 insertions(+), 46 deletions(-) diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEFuseFillIntoForall.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEFuseFillIntoForall.cpp index 1b771f3ef..30762e929 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEFuseFillIntoForall.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEFuseFillIntoForall.cpp @@ -7,6 +7,7 @@ #include "iree-amd-aie/Transforms/Passes.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Pass/Pass.h" #define DEBUG_TYPE "iree-amdaie-fuse-fill-into-forall" @@ -29,60 +30,87 @@ class AMDAIEFuseFillIntoForallPass void AMDAIEFuseFillIntoForallPass::runOnOperation() { MLIRContext *context = &getContext(); - mlir::FunctionOpInterface funcOp = getOperation(); IRRewriter
rewriter(context); - // Find the producer op, in this case is linalg.fill. - TilingInterface tileableProducer; - funcOp->walk([&](TilingInterface op) { - if (isa(op)) { - tileableProducer = op; - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - - if (!tileableProducer) { - LLVM_DEBUG(llvm::dbgs() << "There is no producer op to be fused.\n"); + // Find a unique FillOp with a single output, or return. + SmallVector fillOps; + getOperation()->walk( + [&](linalg::FillOp fillOp) { fillOps.push_back(fillOp); }); + if (fillOps.size() != 1) { + LLVM_DEBUG(llvm::dbgs() << "Expected exactly 1 fill op, but found " + << fillOps.size() << ".\n"); return; } - // Search the first use by a scf::ForallOp user. - scf::ForallOp forallOp; - auto itProducerUses = - llvm::find_if(tileableProducer->getUses(), [&](OpOperand &use) { - forallOp = dyn_cast(use.getOwner()); - return forallOp; - }); + linalg::FillOp fillOp = fillOps[0]; + if (fillOp.getResults().size() != 1) { + LLVM_DEBUG(llvm::dbgs() << "Expected fill op to have exactly 1 result, but " + << "found " << fillOp.getResults().size() << ".\n"); + + return; + }; + + // Confirm that there is a unique user that is a forall, and match + // the block argument that is used by the fill op, or return. + if (!fillOp->hasOneUse()) { + LLVM_DEBUG(llvm::dbgs() + << "Expected exactly 1 use of fill op, but found 0 or 2+."); + return; + } + OpOperand &fillUse = *fillOp->getUses().begin(); + auto forallOp = dyn_cast(fillUse.getOwner()); if (!forallOp) { - LLVM_DEBUG(llvm::dbgs() << "There is no forall Op.\n"); + LLVM_DEBUG(llvm::dbgs() << "Expected fill op to be used by a forall op, " + << "but unique user is " + << fillUse.getOwner()->getName() << ".\n"); return; } - - // Search the producer slices accessed within the Forall op. 
- OpOperand *pUse = &(*itProducerUses); - BlockArgument bbArg = forallOp.getTiedBlockArgument(pUse); - - auto itBBArgUsers = llvm::find_if(bbArg.getUsers(), [&](Operation *user) { - auto sliceOp = dyn_cast(user); - return sliceOp; - }); - if (itBBArgUsers == bbArg.getUsers().end()) { - funcOp->emitOpError("There is no extract tensor slice."); - return signalPassFailure(); + BlockArgument bbArg = forallOp.getTiedBlockArgument(&fillUse); + + // Find 0 or 1 ExtractSliceOps that use the fill result, or return. + tensor::ExtractSliceOp extractSliceOp; + for (Operation *user : bbArg.getUsers()) { + if (auto nxt = dyn_cast(user)) { + if (extractSliceOp) { + LLVM_DEBUG(llvm::dbgs() + << "Expected at most 1 extract_slice op, but found 2+.\n"); + return; + } + extractSliceOp = nxt; + } } - auto sliceOpToTile = cast(*itBBArgUsers); - - LoopLikeOpInterface loops = - cast(forallOp.getOperation()); - - // Materialize the slice of the producer in place. - std::optional fusedProducer = - scf::tileAndFuseProducerOfSlice(rewriter, sliceOpToTile, - MutableArrayRef(&loops, 1)); - if (!fusedProducer) { - funcOp->emitOpError("Failed to fuse fill op into forall loop."); - return signalPassFailure(); + + if (extractSliceOp) { + LoopLikeOpInterface loops = + cast(forallOp.getOperation()); + + // Materialize the slice of the producer in place. + std::optional fusedProducer = + scf::tileAndFuseProducerOfSlice(rewriter, extractSliceOp, + MutableArrayRef(&loops, 1)); + if (!fusedProducer) { + fillOp->emitOpError("could not be fused into forall"); + return signalPassFailure(); + } + } else { + // In the case where there are no extract_slice ops, we manually create the + // fill at the beginning of the forall body. This situation might arise + // if the extract_slice has been folded, for example if the forall is + // over a grid of size 1. 
+ rewriter.setInsertionPointToStart(forallOp.getBody()); + auto fusedFill = + rewriter.create(fillOp.getLoc(), fillOp.value(), bbArg); + rewriter.replaceUsesWithIf( + bbArg, fusedFill.getResult(0), [&](OpOperand &operand) { + Operation *owner = operand.getOwner(); + if (owner == fusedFill || isa(owner)) { + return false; + } + return true; + }); + + // Do not use the result of the old fill. + rewriter.replaceAllUsesWith(fillOp.getResults()[0], fillOp.getOutputs()[0]); } } diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/fuse_fill_into_forall.mlir b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/fuse_fill_into_forall.mlir index 6eda0882a..407f534ed 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/fuse_fill_into_forall.mlir +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/fuse_fill_into_forall.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt --pass-pipeline='builtin.module(func.func(iree-amdaie-fuse-fill-into-forall))' %s | FileCheck %s +// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(func.func(iree-amdaie-fuse-fill-into-forall))' %s | FileCheck %s #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)> #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d5, d4)> @@ -29,3 +29,38 @@ func.func @fuse_fill_into_forall(%arg0: tensor<1x4x16x64xi8>, %arg1 : tensor<4x1 // CHECK: linalg.fill // CHECK: linalg.generic // CHECK: } + +// ----- + +#map = affine_map<(d0) -> (d0)> +func.func @fuse_without_slice(%arg0: tensor<8xi8>) -> tensor<8xi8> { + %c7_i8 = arith.constant 7 : i8 + %c3_i8 = arith.constant 3 : i8 + %0 = linalg.fill ins(%c7_i8 : i8) outs(%arg0 : tensor<8xi8>) -> tensor<8xi8> + %1 = tensor.empty() : tensor<8xi8> + %2 = scf.forall (%arg1) in (1) shared_outs(%arg2 = %0) -> (tensor<8xi8>) { + %3 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<8xi8>) outs(%1 : tensor<8xi8>) { + ^bb0(%in: i8, %out: i8): + %4 = 
arith.addi %in, %c3_i8 : i8 + linalg.yield %4 : i8 + } -> tensor<8xi8> + scf.forall.in_parallel { + tensor.parallel_insert_slice %3 into %arg2[0] [8] [1] : tensor<8xi8> into tensor<8xi8> + } + } {mapping = [#gpu.thread]} + return %2 : tensor<8xi8> +} + +// CHECK: @fuse_without_slice(%[[FUNCARG:.*]]: tensor<8xi8>) -> tensor<8xi8> { +// check that the operand of scf.forall is not the filled tensor, because the +// fill will take place inside the scf.forall: +// CHECK: %[[FORALL:.*]] = scf.forall (%[[ARG1:.*]]) in (1) +// CHECK-SAME: shared_outs(%[[ARG2:.*]] = %[[FUNCARG]]) +// check for the new fill: +// CHECK: %[[NEWFILL:.*]] = linalg.fill +// CHECK-SAME: outs(%[[ARG2]] : tensor<8xi8>) -> tensor<8xi8> +// CHECK: linalg.generic +// check that the parallel_insert_slice still happens on arg2, not the filled
// tensor. This is because it must match the shared_outs of the scf.forall: +// CHECK: tensor.parallel_insert_slice +// CHECK-SAME: into %[[ARG2]]