Skip to content

Commit

Permalink
[AMDAIEFusePackIntoLoop] Handle case where there is no extract_slice (#…
Browse files Browse the repository at this point in the history
…983)

This is very similar to #976 

When the AIE grid we're using is `m x n` where one or both of `m` and
`n` is `1`, for a matmul we get pack ops that produce operands for
matmuls directly. i.e. as opposed to `pack->extract_slice->matmul` we
have `pack->matmul`.
  • Loading branch information
newling authored Dec 12, 2024
1 parent ee61bdb commit a55ee84
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Iterators.h"
#include "mlir/Pass/Pass.h"

Expand All @@ -18,37 +19,46 @@ namespace mlir::iree_compiler::AMDAIE {

namespace {

/// A utility function specific to this pass which, given a value, would
/// traverse the def-chain till it either finds a tensor.extract_slice op or a
/// BlockArgument.
/// A utility function specific to this pass which, given a value `operand`,
/// traverses the def-chain till it finds a tensor.extract_slice. The 2 cases
/// where it successfully finds and returns an extract_slice (SLICE) are:
///
/// Case 1)
/// pack -> SLICE -> pack -> pack -> pack -> operand
/// ^^^^^^^^^^^^^^^^^^^^
/// any number (>= 0) of trailing packs
///
/// Case 2)
/// pack -> block arg -> SLICE -> pack -> pack -> pack -> operand
/// ^^^^^^^^^^^^^^^^^^^^
/// any number (>= 0) of trailing packs
///
/// Case 2 only matches where `block arg` is for a loop operation.
static FailureOr<tensor::ExtractSliceOp> getTensorExtractSliceDefiningOp(
Value operand) {
while (Operation *defOp = operand.getDefiningOp()) {
auto sliceOp = dyn_cast_if_present<tensor::ExtractSliceOp>(defOp);
if (sliceOp) {
// The producer of sliceOp should be a pack op.
if (isa_and_present<tensor::PackOp>(
sliceOp.getSource().getDefiningOp())) {
return sliceOp;
}
if (isa<BlockArgument>(sliceOp.getSource())) {
auto blkArg = dyn_cast<BlockArgument>(sliceOp.getSource());
for (Value blkOperand :
blkArg.getOwner()->getParentOp()->getOperands()) {
if (isa_and_present<tensor::PackOp>(blkOperand.getDefiningOp())) {
return sliceOp;
}
}
}
break;
}
// We perform further traversal only if we have tensor.pack op in the
// def-chain.
if (!isa<tensor::PackOp>(defOp)) {
break;
}
operand = defOp->getOperand(0);
// roll back through all the packs immediately preceding `operand`.
while (isa_and_present<tensor::PackOp>(operand.getDefiningOp())) {
operand = operand.getDefiningOp()->getOperand(0);
}

tensor::ExtractSliceOp sliceOp =
dyn_cast_if_present<tensor::ExtractSliceOp>(operand.getDefiningOp());
if (!sliceOp) return failure();

// Case 1 outlined above.
if (isa_and_present<tensor::PackOp>(sliceOp.getSource().getDefiningOp())) {
return sliceOp;
}

// Case 2 outlined above.
else if (auto blkArg = dyn_cast<BlockArgument>(sliceOp.getSource())) {
Operation *parent = blkArg.getOwner()->getParentOp();
LoopLikeOpInterface loop = dyn_cast<LoopLikeOpInterface>(parent);
if (!loop) return failure();
Operation *operandParent = loop.getTiedLoopInit(blkArg)->getOwner();
if (isa_and_present<tensor::PackOp>(operandParent)) return sliceOp;
}

return failure();
}

Expand Down Expand Up @@ -90,11 +100,6 @@ void AMDAIEFusePackIntoLoopPass::runOnOperation() {
return;
}

if (fusePackDepth < 1) {
funcOp->emitOpError("Invalid depth of pack ops for fusion.");
return signalPassFailure();
}

LoopLikeOpInterface loops = cast<LoopLikeOpInterface>(scfLoopOp);

// Based on the `fusePackDepth`, we would greedily fuse the producer
Expand Down Expand Up @@ -122,29 +127,31 @@ void AMDAIEFusePackIntoLoopPass::runOnOperation() {
return;
}

SmallVector<tensor::ExtractSliceOp> sliceOps;
for (auto [index, operand] : llvm::enumerate(genericOp.getOperands())) {
FailureOr<tensor::ExtractSliceOp> sliceOp =
// Materialize each slice of the producer in place.
for (Value operand : genericOp.getOperands()) {
FailureOr<tensor::ExtractSliceOp> maybeSliceOp =
getTensorExtractSliceDefiningOp(operand);
if (!failed(sliceOp)) {
sliceOps.push_back(sliceOp.value());
}
}

if (sliceOps.empty()) {
LLVM_DEBUG(llvm::dbgs() << "----- Pack ops are already fused or no slice "
"ops were found.-----\n");
return;
}
if (succeeded(maybeSliceOp)) {
tensor::ExtractSliceOp sliceOp = maybeSliceOp.value();
std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer =
scf::tileAndFuseProducerOfSlice(rewriter, sliceOp,
MutableArrayRef(&loops, 1));
if (!fusedProducer) {
funcOp->emitOpError("Failed to fuse pack ops into for loop.");
return signalPassFailure();
}
}

// Materialize each slice of the producer in place.
for (auto sliceOp : sliceOps) {
std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer =
scf::tileAndFuseProducerOfSlice(rewriter, sliceOp,
MutableArrayRef(&loops, 1));
if (!fusedProducer) {
funcOp->emitOpError("Failed to fuse pack ops into for loop.");
return signalPassFailure();
// Case where operand of generic is a pack op which is in a different
// block than the generic's block.
else if (auto parent = dyn_cast_if_present<tensor::PackOp>(
operand.getDefiningOp())) {
Block *genericBlock = genericOp->getBlock();
if (parent->getBlock() != genericBlock && parent->hasOneUse()) {
Operation *firstOpInBlock = &genericBlock->front();
rewriter.moveOpBefore(parent, firstOpInBlock);
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ def AMDAIEFusePackIntoLoop :
let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIEFusePackIntoLoopPass()";
let options = [
Option<"fusePackDepth", "fuse-pack-depth", "int64_t", /*default=*/"1",
"Set the depth until which we would keep fusing producer tensor.pack chain">,
"Set the depth until which we would keep fusing producer tensor.pack chain.">,
Option<"useSCFFor", "use-scf-for", "bool", /*default=*/"true",
"Set the innermost scf loop type to fuse tensor.pack ops into">,
Option<"targetElementwise", "target-elementwise", "bool", /*default=*/"false",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(func.func(iree-amdaie-fuse-pack-into-loop{fuse-pack-depth=2}))' %s | FileCheck %s --check-prefix=DEPTH-2
// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(func.func(iree-amdaie-fuse-pack-into-loop{fuse-pack-depth=2 use-scf-for=false}))' %s | FileCheck %s --check-prefix=FORALL-DEPTH-2

// -----

#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>
#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>
#map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>
Expand Down Expand Up @@ -219,3 +221,46 @@ func.func @fuse_multilevel_pack_into_forall(%arg0: tensor<2048x2048xi32>, %arg1:
// FORALL-DEPTH-2: linalg.generic {{.*}} ins(%[[PACK_1_DEPTH_1]], %[[PACK_2_DEPTH_1]] :
// FORALL-DEPTH-2: }
// FORALL-DEPTH-2: }


// -----

// A test with a linalg.generic which has a pack result as an operand.

#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>
#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>
#map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>
func.func @pack_without_slice(%arg0: tensor<1x1x32x512xi32>, %arg1: tensor<1x1x32x32xi32>) -> tensor<1x1x4x8x4x8xi32> {
%c4 = arith.constant 4 : index
%c64 = arith.constant 64 : index
%c0_i32 = arith.constant 0 : i32
%c0 = arith.constant 0 : index
%15 = tensor.empty() : tensor<1x1x64x8x4x8xi32>
%pack_8 = tensor.pack %arg0 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %15 : tensor<1x1x32x512xi32> -> tensor<1x1x64x8x4x8xi32>
%16 = tensor.empty() : tensor<1x1x4x4x8x8xi32>
%pack_10 = tensor.pack %arg1 outer_dims_perm = [0, 1, 2, 3] inner_dims_pos = [2, 3] inner_tiles = [8, 8] into %16 : tensor<1x1x32x32xi32> -> tensor<1x1x4x4x8x8xi32>

%17 = tensor.empty() : tensor<1x1x4x8x4x8xi32>
%18 = linalg.fill ins(%c0_i32 : i32) outs(%17 : tensor<1x1x4x8x4x8xi32>) -> tensor<1x1x4x8x4x8xi32>
%19 = scf.for %arg6 = %c0 to %c64 step %c4 iter_args(%arg7 = %18) -> (tensor<1x1x4x8x4x8xi32>) {
%extracted_slice_12 = tensor.extract_slice %pack_8[0, 0, %arg6, 0, 0, 0] [1, 1, 4, 8, 4, 8] [1, 1, 1, 1, 1, 1] : tensor<1x1x64x8x4x8xi32> to tensor<1x1x4x8x4x8xi32>
%20 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%extracted_slice_12, %pack_10 : tensor<1x1x4x8x4x8xi32>, tensor<1x1x4x4x8x8xi32>) outs(%arg7 : tensor<1x1x4x8x4x8xi32>) {
^bb0(%in: i32, %in_14: i32, %out: i32):
%21 = arith.muli %in, %in_14 : i32
%22 = arith.addi %out, %21 : i32
linalg.yield %22 : i32
} -> tensor<1x1x4x8x4x8xi32>
scf.yield %20 : tensor<1x1x4x8x4x8xi32>
}
return %19 : tensor<1x1x4x8x4x8xi32>
}

// DEPTH-1-LABEL: pack_without_slice
// DEPTH-1: scf.for
// DEPTH-1-DAG: %[[PACK_1:.*]] = tensor.pack %{{.*}} into %{{.*}} : tensor<1x1x32x32xi32> -> tensor<1x1x4x4x8x8xi32>
// DEPTH-1-DAG: %[[PACK_2:.*]] = tensor.pack %{{.*}} into %{{.*}} : tensor<1x1x32x32xi32> -> tensor<1x1x4x8x4x8xi32>
// DEPTH-1: linalg.generic
// DEPTH-1-SAME: ins(%[[PACK_2]], %[[PACK_1]]



0 comments on commit a55ee84

Please sign in to comment.