diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEFusePackIntoLoop.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEFusePackIntoLoop.cpp
index cead3226e..a5f314004 100644
--- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEFusePackIntoLoop.cpp
+++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEFusePackIntoLoop.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Iterators.h"
 #include "mlir/Pass/Pass.h"
 
@@ -18,37 +19,46 @@ namespace mlir::iree_compiler::AMDAIE {
 
 namespace {
 
-/// A utility function specific to this pass which, given a value, would
-/// traverse the def-chain till it either finds a tensor.extract_slice op or a
-/// BlockArgument.
+/// A utility function specific to this pass which, given a value `operand`,
+/// traverses the def-chain till it finds a tensor.extract_slice. The 2 cases
+/// where it successfully finds and returns an extract_slice (SLICE) are:
+///
+/// Case 1)
+/// pack -> SLICE -> pack -> pack -> pack -> operand
+///                  ^^^^^^^^^^^^^^^^^^^^
+///                  any number (>= 0) of trailing packs
+///
+/// Case 2)
+/// pack -> block arg -> SLICE -> pack -> pack -> pack -> operand
+///                               ^^^^^^^^^^^^^^^^^^^^
+///                               any number (>= 0) of trailing packs
+///
+/// Case 2 only matches when `block arg` is an iteration argument of a loop
+/// operation whose tied init value is produced by a pack.
 static FailureOr<tensor::ExtractSliceOp> getTensorExtractSliceDefiningOp(
     Value operand) {
-  while (Operation *defOp = operand.getDefiningOp()) {
-    auto sliceOp = dyn_cast_if_present<tensor::ExtractSliceOp>(defOp);
-    if (sliceOp) {
-      // The producer of sliceOp should be a pack op.
-      if (isa_and_present<tensor::PackOp>(
-              sliceOp.getSource().getDefiningOp())) {
-        return sliceOp;
-      }
-      if (isa<BlockArgument>(sliceOp.getSource())) {
-        auto blkArg = dyn_cast<BlockArgument>(sliceOp.getSource());
-        for (Value blkOperand :
-             blkArg.getOwner()->getParentOp()->getOperands()) {
-          if (isa_and_present<tensor::PackOp>(blkOperand.getDefiningOp())) {
-            return sliceOp;
-          }
-        }
-      }
-      break;
-    }
-    // We perform further traversal only if we have tensor.pack op in the
-    // def-chain.
-    if (!isa<tensor::PackOp>(defOp)) {
-      break;
-    }
-    operand = defOp->getOperand(0);
+  // Roll back through all the packs immediately preceding `operand`.
+  while (isa_and_present<tensor::PackOp>(operand.getDefiningOp())) {
+    operand = operand.getDefiningOp()->getOperand(0);
   }
+
+  tensor::ExtractSliceOp sliceOp =
+      dyn_cast_if_present<tensor::ExtractSliceOp>(operand.getDefiningOp());
+  if (!sliceOp) return failure();
+
+  // Case 1 outlined above.
+  if (isa_and_present<tensor::PackOp>(sliceOp.getSource().getDefiningOp())) {
+    return sliceOp;
+  }
+
+  // Case 2 outlined above.
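+  // For illustration, a hypothetical sketch of the IR shape Case 2 is
+  // meant to match (names are made up):
+  //   %init = tensor.pack ... into %dest
+  //   scf.for ... iter_args(%arg = %init) -> (...) {
+  //     %slice = tensor.extract_slice %arg ...  // SLICE
+  //     %p = tensor.pack %slice ...             // zero or more trailing packs
+  //     ...                                     // `operand` reaches here
+  //   }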
+  else if (auto blkArg = dyn_cast<BlockArgument>(sliceOp.getSource())) {
+    Operation *parent = blkArg.getOwner()->getParentOp();
+    LoopLikeOpInterface loop = dyn_cast<LoopLikeOpInterface>(parent);
+    if (!loop) return failure();
+    Operation *initDefOp = loop.getTiedLoopInit(blkArg)->get().getDefiningOp();
+    if (isa_and_present<tensor::PackOp>(initDefOp)) return sliceOp;
+  }
+  return failure();
 }
 
@@ -90,11 +100,6 @@ void AMDAIEFusePackIntoLoopPass::runOnOperation() {
     return;
   }
 
-  if (fusePackDepth < 1) {
-    funcOp->emitOpError("Invalid depth of pack ops for fusion.");
-    return signalPassFailure();
-  }
-
   LoopLikeOpInterface loops = cast<LoopLikeOpInterface>(scfLoopOp);
 
   // Based on the `fusePackDepth`, we would greedily fuse the producer
@@ -122,29 +127,31 @@ void AMDAIEFusePackIntoLoopPass::runOnOperation() {
       return;
     }
 
-    SmallVector<tensor::ExtractSliceOp> sliceOps;
-    for (auto [index, operand] : llvm::enumerate(genericOp.getOperands())) {
-      FailureOr<tensor::ExtractSliceOp> sliceOp =
+    // Materialize each slice of the producer in place.
+    for (Value operand : genericOp.getOperands()) {
+      FailureOr<tensor::ExtractSliceOp> maybeSliceOp =
           getTensorExtractSliceDefiningOp(operand);
-      if (!failed(sliceOp)) {
-        sliceOps.push_back(sliceOp.value());
-      }
-    }
-    if (sliceOps.empty()) {
-      LLVM_DEBUG(llvm::dbgs() << "----- Pack ops are already fused or no slice "
-                                 "ops were found.-----\n");
-      return;
-    }
+      if (succeeded(maybeSliceOp)) {
+        tensor::ExtractSliceOp sliceOp = maybeSliceOp.value();
+        std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer =
+            scf::tileAndFuseProducerOfSlice(rewriter, sliceOp,
+                                            MutableArrayRef(&loops, 1));
+        if (!fusedProducer) {
+          funcOp->emitOpError("Failed to fuse pack ops into for loop.");
+          return signalPassFailure();
+        }
+      }
 
-    // Materialize each slice of the producer in place.
-    for (auto sliceOp : sliceOps) {
-      std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer =
-          scf::tileAndFuseProducerOfSlice(rewriter, sliceOp,
-                                          MutableArrayRef(&loops, 1));
-      if (!fusedProducer) {
-        funcOp->emitOpError("Failed to fuse pack ops into for loop.");
-        return signalPassFailure();
+      // Case where operand of generic is a pack op which is in a different
+      // block than the generic's block.
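+      // There is no extract_slice to fuse through in that case; as a
+      // best-effort, and only when the pack has a single use, move the pack
+      // to the top of the generic's block so it ends up inside the loop.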
+      else if (auto parent = dyn_cast_if_present<tensor::PackOp>(
+                   operand.getDefiningOp())) {
+        Block *genericBlock = genericOp->getBlock();
+        if (parent->getBlock() != genericBlock && parent->hasOneUse()) {
+          Operation *firstOpInBlock = &genericBlock->front();
+          rewriter.moveOpBefore(parent, firstOpInBlock);
+        }
       }
     }
   }
 }
diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.td b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.td
index 96d80a184..09e2c43a4 100644
--- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.td
+++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.td
@@ -310,7 +310,7 @@ def AMDAIEFusePackIntoLoop :
   let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIEFusePackIntoLoopPass()";
   let options = [
     Option<"fusePackDepth", "fuse-pack-depth", "int64_t", /*default=*/"1",
-           "Set the depth until which we would keep fusing producer tensor.pack chain">,
+           "Set the depth until which we would keep fusing the producer tensor.pack chain.">,
     Option<"useSCFFor", "use-scf-for", "bool", /*default=*/"true",
            "Set the innermost scf loop type to fuse tensor.pack ops into">,
     Option<"targetElementwise", "target-elementwise", "bool", /*default=*/"false",
diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/fuse_pack_into_loop.mlir b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/fuse_pack_into_loop.mlir
index 3a8cfe5e8..3872aaf55 100644
--- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/fuse_pack_into_loop.mlir
+++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/fuse_pack_into_loop.mlir
@@ -3,6 +3,8 @@
 // RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(func.func(iree-amdaie-fuse-pack-into-loop{fuse-pack-depth=2}))' %s | FileCheck %s --check-prefix=DEPTH-2
 // RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(func.func(iree-amdaie-fuse-pack-into-loop{fuse-pack-depth=2 use-scf-for=false}))' %s | FileCheck %s --check-prefix=FORALL-DEPTH-2
 
+// -----
+
 #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>
 #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>
 #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>
@@ -219,3 +221,46 @@ func.func @fuse_multilevel_pack_into_forall(%arg0: tensor<2048x2048xi32>, %arg1:
 // FORALL-DEPTH-2: linalg.generic {{.*}} ins(%[[PACK_1_DEPTH_1]], %[[PACK_2_DEPTH_1]] :
 // FORALL-DEPTH-2: }
 // FORALL-DEPTH-2: }
+
+
+// -----
+
+// A test with a linalg.generic which has a pack result as an operand.
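+// Expected behavior (per the DEPTH-1 checks below): the pack that feeds the
+// generic directly, with no extract_slice in between, is moved into the
+// scf.for body next to the pack that is fused through the extract_slice.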
+ +#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)> +#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)> +#map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)> +func.func @pack_without_slice(%arg0: tensor<1x1x32x512xi32>, %arg1: tensor<1x1x32x32xi32>) -> tensor<1x1x4x8x4x8xi32> { + %c4 = arith.constant 4 : index + %c64 = arith.constant 64 : index + %c0_i32 = arith.constant 0 : i32 + %c0 = arith.constant 0 : index + %15 = tensor.empty() : tensor<1x1x64x8x4x8xi32> + %pack_8 = tensor.pack %arg0 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %15 : tensor<1x1x32x512xi32> -> tensor<1x1x64x8x4x8xi32> + %16 = tensor.empty() : tensor<1x1x4x4x8x8xi32> + %pack_10 = tensor.pack %arg1 outer_dims_perm = [0, 1, 2, 3] inner_dims_pos = [2, 3] inner_tiles = [8, 8] into %16 : tensor<1x1x32x32xi32> -> tensor<1x1x4x4x8x8xi32> + + %17 = tensor.empty() : tensor<1x1x4x8x4x8xi32> + %18 = linalg.fill ins(%c0_i32 : i32) outs(%17 : tensor<1x1x4x8x4x8xi32>) -> tensor<1x1x4x8x4x8xi32> + %19 = scf.for %arg6 = %c0 to %c64 step %c4 iter_args(%arg7 = %18) -> (tensor<1x1x4x8x4x8xi32>) { + %extracted_slice_12 = tensor.extract_slice %pack_8[0, 0, %arg6, 0, 0, 0] [1, 1, 4, 8, 4, 8] [1, 1, 1, 1, 1, 1] : tensor<1x1x64x8x4x8xi32> to tensor<1x1x4x8x4x8xi32> + %20 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%extracted_slice_12, %pack_10 : tensor<1x1x4x8x4x8xi32>, tensor<1x1x4x4x8x8xi32>) outs(%arg7 : tensor<1x1x4x8x4x8xi32>) { + ^bb0(%in: i32, %in_14: i32, %out: i32): + %21 = arith.muli %in, %in_14 : i32 + %22 = arith.addi %out, %21 : i32 + linalg.yield %22 : i32 + } -> tensor<1x1x4x8x4x8xi32> + scf.yield %20 : tensor<1x1x4x8x4x8xi32> + } + return %19 : tensor<1x1x4x8x4x8xi32> +} + +// DEPTH-1-LABEL: pack_without_slice +// DEPTH-1: scf.for +// DEPTH-1-DAG: %[[PACK_1:.*]] = tensor.pack %{{.*}} into %{{.*}} : tensor<1x1x32x32xi32> -> tensor<1x1x4x4x8x8xi32> +// DEPTH-1-DAG: %[[PACK_2:.*]] = tensor.pack %{{.*}} into %{{.*}} : tensor<1x1x32x32xi32> -> tensor<1x1x4x8x4x8xi32> +// DEPTH-1: linalg.generic +// DEPTH-1-SAME: ins(%[[PACK_2]], %[[PACK_1]] + + +