Skip to content

Commit

Permalink
Reland "[DispatchCreation] Run preprocessing before..." (#18939)
Browse files Browse the repository at this point in the history
Land #18934 after fixing mi250 regressions. Instead of
moving the whole pass (which "interchanges" the indexing maps and was
the root of the regressions), only move the single pattern
`GatherFusionPattern` into `ElementwiseOpFusion.cpp`.


closes: #19077

Signed-off-by: Ian Wood <[email protected]>
  • Loading branch information
IanWood1 authored Nov 18, 2024
1 parent 1ab3b49 commit 540cebf
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 151 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "iree/compiler/DispatchCreation/FusionUtils.h"
#include "iree/compiler/DispatchCreation/Passes.h"
#include "llvm/Support/Debug.h"
Expand All @@ -35,6 +36,69 @@ struct ElementwiseOpFusionPass final
void runOnOperation() override;
};

//===----------------------------------------------------------------------===//
// GatherFusionPattern
//===----------------------------------------------------------------------===//

// Specific case. The linalg generic implementation of "gather" cannot be
// fused by ordinary elementwise fusion because there is no producer-consumer
// relationship between the two generics: the consumer reads the producer's
// result through a `tensor.extract` whose indexing is not affine (the index
// values come from a tensor). Instead, fuse by inlining the producer's body
// into the consumer right after the extract.
struct GatherFusionPattern final : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // The extract must appear inside the body of a linalg.generic (the
    // consumer of the gathered values).
    auto consumerOp =
        dyn_cast_or_null<linalg::GenericOp>(extractOp->getParentOp());
    if (!consumerOp) {
      return rewriter.notifyMatchFailure(
          extractOp, "expected extract op to be inside a generic op");
    }

    // The tensor being indexed must itself be produced by a linalg.generic.
    auto producerOp = extractOp.getTensor().getDefiningOp<linalg::GenericOp>();
    if (!producerOp) {
      return rewriter.notifyMatchFailure(
          consumerOp, "expected extract operand to be a generic op");
    }

    // Only fuse cheap-to-recompute producers: a single-input, single-result,
    // elementwise bit-extend (e.g. f16 -> f32). Recomputing it per extracted
    // element avoids materializing the widened tensor.
    if (producerOp.getNumDpsInputs() != 1 || producerOp.getNumResults() != 1 ||
        !isElementwise(producerOp) ||
        !IREE::LinalgExt::isBitExtendOp(producerOp)) {
      return rewriter.notifyMatchFailure(producerOp,
                                         "producer op is not fusible");
    }

    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(extractOp);

    // Extract directly from the producer's input tensor at the same indices,
    // then clone the producerOp's body into the consumerOp so its computation
    // can be applied to the extracted scalar instead.
    auto newExtractOp = rewriter.create<tensor::ExtractOp>(
        extractOp.getLoc(), producerOp.getDpsInputOperand(0)->get(),
        extractOp.getIndices());
    rewriter.cloneRegionBefore(producerOp.getRegion(), consumerOp.getRegion(),
                               consumerOp.getRegion().begin());
    Block &clonedBlock = consumerOp.getRegion().front();
    auto producerTermOp = clonedBlock.getTerminator();

    // Inline the cloned block after the original extract (this erases the
    // block). Both block arguments of the producer body — the input element
    // and the output/init element — are replaced by the new extracted scalar.
    rewriter.inlineBlockBefore(
        &clonedBlock, extractOp->getNextNode(),
        {newExtractOp.getResult(), newExtractOp.getResult()});

    // Replace all references to the original extract result with the value
    // yielded by the inlined producer body, then clean up the dead ops.
    extractOp.getResult().replaceAllUsesWith(producerTermOp->getOperand(0));
    rewriter.eraseOp(producerTermOp);
    rewriter.eraseOp(extractOp);

    return success();
  }
};

} // namespace

void ElementwiseOpFusionPass::runOnOperation() {
Expand Down Expand Up @@ -82,6 +146,7 @@ void ElementwiseOpFusionPass::runOnOperation() {
};
IREE::LinalgExt::populateFuseLinalgExtOpsWithTransposes(
fusionPatterns, foldTransposeControlFn);
fusionPatterns.insert<GatherFusionPattern>(context);

GreedyRewriteConfig rewriteConfig;
rewriteConfig.maxIterations = GreedyRewriteConfig::kNoLimit;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,76 +149,12 @@ struct FoldSuccessiveTensorInsertSliceOps final
}
};

//===----------------------------------------------------------------------===//
// GatherFusionPattern
//===----------------------------------------------------------------------===//

// Specific case. The linalg generic implementation of "gather" cannot be
// fused by ordinary elementwise fusion because there is no producer-consumer
// relationship between the two generics: the consumer reads the producer's
// result through a `tensor.extract` whose indexing is not affine (the index
// values come from a tensor). Instead, fuse by inlining the producer's body
// into the consumer right after the extract.
struct GatherFusionPattern final : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // The extract must appear inside the body of a linalg.generic (the
    // consumer of the gathered values).
    auto consumerOp =
        dyn_cast_or_null<linalg::GenericOp>(extractOp->getParentOp());
    if (!consumerOp) {
      return rewriter.notifyMatchFailure(
          extractOp, "expected extract op to be inside a generic op");
    }

    // The tensor being indexed must itself be produced by a linalg.generic.
    auto producerOp = extractOp.getTensor().getDefiningOp<linalg::GenericOp>();
    if (!producerOp) {
      return rewriter.notifyMatchFailure(
          consumerOp, "expected extract operand to be a generic op");
    }

    // Only fuse cheap-to-recompute producers: a single-input, single-result,
    // elementwise bit-extend (e.g. f16 -> f32). Recomputing it per extracted
    // element avoids materializing the widened tensor.
    if (producerOp.getNumDpsInputs() != 1 || producerOp.getNumResults() != 1 ||
        !isElementwise(producerOp) ||
        !IREE::LinalgExt::isBitExtendOp(producerOp)) {
      return rewriter.notifyMatchFailure(producerOp,
                                         "producer op is not fusible");
    }

    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(extractOp);

    // Extract directly from the producer's input tensor at the same indices,
    // then clone the producerOp's body into the consumerOp so its computation
    // can be applied to the extracted scalar instead.
    auto newExtractOp = rewriter.create<tensor::ExtractOp>(
        extractOp.getLoc(), producerOp.getDpsInputOperand(0)->get(),
        extractOp.getIndices());
    rewriter.cloneRegionBefore(producerOp.getRegion(), consumerOp.getRegion(),
                               consumerOp.getRegion().begin());
    Block &clonedBlock = consumerOp.getRegion().front();
    auto producerTermOp = clonedBlock.getTerminator();

    // Inline the cloned block after the original extract (this erases the
    // block). Both block arguments of the producer body — the input element
    // and the output/init element — are replaced by the new extracted scalar.
    rewriter.inlineBlockBefore(
        &clonedBlock, extractOp->getNextNode(),
        {newExtractOp.getResult(), newExtractOp.getResult()});

    // Replace all references to the original extract result with the value
    // yielded by the inlined producer body, then clean up the dead ops.
    extractOp.getResult().replaceAllUsesWith(producerTermOp->getOperand(0));
    rewriter.eraseOp(producerTermOp);
    rewriter.eraseOp(extractOp);

    return success();
  }
};

struct FusionPreprocessingPass final
: public impl::FusionPreprocessingPassBase<FusionPreprocessingPass> {
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
patterns.add<ElementwiseOpInterchangePattern,
FoldSuccessiveTensorInsertSliceOps, GatherFusionPattern>(
&getContext());
FoldSuccessiveTensorInsertSliceOps>(&getContext());

// Fold away `tensor.dim` operations that can be resolved in terms of its
// operand shapes.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ iree_lit_test_suite(
"collapse_linalg_generic_on_tensors.mlir",
"collapse_reduction.mlir",
"attention_fuse_by_expansion.mlir",
"fold_transpose.mlir",
"elementwise_op_fusion.mlir",
"dispatch_linalg_transform_dialect.mlir",
"dispatch_region_formation_preprocessing.mlir",
"fold_unit_dims.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ iree_lit_test_suite(
"dispatch_linalg_on_tensors_fusion_with_transpose.mlir"
"dispatch_linalg_transform_dialect.mlir"
"dispatch_region_formation_preprocessing.mlir"
"fold_transpose.mlir"
"elementwise_op_fusion.mlir"
"fold_unit_dims.mlir"
"form_dispatch_regions.mlir"
"form_dispatch_workgroups.mlir"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,86 @@ util.func public @transpose_matmul(%arg0 : tensor<100x100xf16>, %arg1 : tensor<1
// CHECK-SAME: affine_map<(d0, d1, d2) -> (d2, d1)>
// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]]

// -----

// Gather fusion: %15 bit-extends (f16 -> f32) the whole 128256x4096 table,
// and %16 gathers from that result via a non-affine tensor.extract (indices
// come from %12). GatherFusionPattern should inline the extf into the gather
// body so the extract reads the original f16 tensor directly.
util.func public @fuse_generic_gather(
  %11 :tensor<128256x4096xf16>, %12 : tensor<4x?xi64>,
  %13 : tensor<4x?x4096xf32>, %14 : tensor<128256x4096xf32>)
  -> tensor<4x?x4096xf32>{

  %15 = linalg.generic {
      indexing_maps = [ affine_map<(d0, d1) -> (d0, d1)>,
                        affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%11 : tensor<128256x4096xf16>)
      outs(%14 : tensor<128256x4096xf32>) {
    ^bb0(%in: f16, %out: f32):
      %17 = arith.extf %in : f16 to f32
      linalg.yield %17 : f32
  } -> tensor<128256x4096xf32>
  %16 = linalg.generic {
      indexing_maps = [ affine_map<(d0, d1, d2) -> (d0, d1)>,
                        affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
      iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%12 : tensor<4x?xi64>)
      outs(%13 : tensor<4x?x4096xf32>) {
    ^bb0(%in: i64, %out: f32):
      %17 = arith.index_cast %in : i64 to index
      %18 = linalg.index 2 : index
      %extracted = tensor.extract %15[%17, %18] : tensor<128256x4096xf32>
      linalg.yield %extracted : f32
  } -> tensor<4x?x4096xf32>
  util.return %16 : tensor<4x?x4096xf32>
}

// CHECK: %[[INDEX0:[a-zA-Z0-9]+]] = arith.index_cast %in : i64 to index
// CHECK: %[[INDEX1:[a-zA-Z0-9]+]] = linalg.index 2 : index
// CHECK-NEXT: %[[EXTRACTED:.*]] = tensor.extract %[[TENSOR0:.+]][%[[INDEX0]], %[[INDEX1]]] : tensor<128256x4096xf16>
// CHECK-NEXT: %[[RES:[a-zA-Z0-9]+]] = arith.extf %[[EXTRACTED]] : f16 to f32
// CHECK-NEXT: linalg.yield %[[RES]] : f32


// -----

// Same as @fuse_generic_gather, but the extracted value has multiple uses
// (add/mul/add); all of them must be rewired to the inlined extf result.
util.func public @fuse_generic_gather2(
  %11 :tensor<128256x4096xf16>, %12 : tensor<4x?xi64>,
  %13 : tensor<4x?x4096xf32>, %14 : tensor<128256x4096xf32>)
  -> tensor<4x?x4096xf32>{

  %15 = linalg.generic {
      indexing_maps = [ affine_map<(d0, d1) -> (d0, d1)>,
                        affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%11 : tensor<128256x4096xf16>)
      outs(%14 : tensor<128256x4096xf32>) {
    ^bb0(%in: f16, %out: f32):
      %17 = arith.extf %in : f16 to f32
      linalg.yield %17 : f32
  } -> tensor<128256x4096xf32>
  %16 = linalg.generic {
      indexing_maps = [ affine_map<(d0, d1, d2) -> (d0, d1)>,
                        affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
      iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%12 : tensor<4x?xi64>)
      outs(%13 : tensor<4x?x4096xf32>) {
    ^bb0(%in: i64, %out: f32):
      %17 = arith.index_cast %in : i64 to index
      %18 = linalg.index 2 : index
      %extracted = tensor.extract %15[%17, %18] : tensor<128256x4096xf32>
      %result = arith.addf %extracted, %extracted : f32
      %result2 = arith.mulf %extracted, %extracted : f32
      %final = arith.addf %result, %result2 : f32
      linalg.yield %final: f32
  } -> tensor<4x?x4096xf32>
  util.return %16 : tensor<4x?x4096xf32>
}

// CHECK: %[[INDEX0:[a-zA-Z0-9]+]] = arith.index_cast %in : i64 to index
// CHECK: %[[INDEX1:[a-zA-Z0-9]+]] = linalg.index 2 : index
// CHECK-NEXT: %[[EXTRACTED:.*]] = tensor.extract %[[TENSOR0:.+]][%[[INDEX0]], %[[INDEX1]]] : tensor<128256x4096xf16>
// CHECK-NEXT: %[[RES:[a-zA-Z0-9]+]] = arith.extf %[[EXTRACTED]] : f16 to f32
// CHECK-NEXT: %[[RES2:[a-zA-Z0-9]+]] = arith.addf %[[RES]], %[[RES]] : f32
// CHECK-NEXT: %[[RES3:[a-zA-Z0-9]+]] = arith.mulf %[[RES]], %[[RES]] : f32
// CHECK-NEXT: %[[RES4:[a-zA-Z0-9]+]] = arith.addf %[[RES2]], %[[RES3]] : f32
// CHECK-NEXT: linalg.yield %[[RES4]] : f32
Original file line number Diff line number Diff line change
Expand Up @@ -31,90 +31,6 @@ util.func public @fold_insert_slices(%source : tensor<?x?xf32>,
// CHECK-SAME: [%[[NEW_OFFSET0]], %[[NEW_OFFSET1]]] [%[[SIZE0]], %[[SIZE1]]]
// CHECK: util.return %[[RETURN]]


// -----

// Gather fusion: %15 bit-extends (f16 -> f32) the whole 128256x4096 table,
// and %16 gathers from that result via a non-affine tensor.extract (indices
// come from %12). GatherFusionPattern should inline the extf into the gather
// body so the extract reads the original f16 tensor directly.
util.func public @fuse_generic_gather(
  %11 :tensor<128256x4096xf16>, %12 : tensor<4x?xi64>,
  %13 : tensor<4x?x4096xf32>, %14 : tensor<128256x4096xf32>)
  -> tensor<4x?x4096xf32>{

  %15 = linalg.generic {
      indexing_maps = [ affine_map<(d0, d1) -> (d0, d1)>,
                        affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%11 : tensor<128256x4096xf16>)
      outs(%14 : tensor<128256x4096xf32>) {
    ^bb0(%in: f16, %out: f32):
      %17 = arith.extf %in : f16 to f32
      linalg.yield %17 : f32
  } -> tensor<128256x4096xf32>
  %16 = linalg.generic {
      indexing_maps = [ affine_map<(d0, d1, d2) -> (d0, d1)>,
                        affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
      iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%12 : tensor<4x?xi64>)
      outs(%13 : tensor<4x?x4096xf32>) {
    ^bb0(%in: i64, %out: f32):
      %17 = arith.index_cast %in : i64 to index
      %18 = linalg.index 2 : index
      %extracted = tensor.extract %15[%17, %18] : tensor<128256x4096xf32>
      linalg.yield %extracted : f32
  } -> tensor<4x?x4096xf32>
  util.return %16 : tensor<4x?x4096xf32>
}

// CHECK: %[[INDEX0:[a-zA-Z0-9]+]] = arith.index_cast %in : i64 to index
// CHECK: %[[INDEX1:[a-zA-Z0-9]+]] = linalg.index 2 : index
// CHECK-NEXT: %[[EXTRACTED:.*]] = tensor.extract %[[TENSOR0:.+]][%[[INDEX0]], %[[INDEX1]]] : tensor<128256x4096xf16>
// CHECK-NEXT: %[[RES:[a-zA-Z0-9]+]] = arith.extf %[[EXTRACTED]] : f16 to f32
// CHECK-NEXT: linalg.yield %[[RES]] : f32


// -----

// Same as @fuse_generic_gather, but the extracted value has multiple uses
// (add/mul/add); all of them must be rewired to the inlined extf result.
util.func public @fuse_generic_gather2(
  %11 :tensor<128256x4096xf16>, %12 : tensor<4x?xi64>,
  %13 : tensor<4x?x4096xf32>, %14 : tensor<128256x4096xf32>)
  -> tensor<4x?x4096xf32>{

  %15 = linalg.generic {
      indexing_maps = [ affine_map<(d0, d1) -> (d0, d1)>,
                        affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%11 : tensor<128256x4096xf16>)
      outs(%14 : tensor<128256x4096xf32>) {
    ^bb0(%in: f16, %out: f32):
      %17 = arith.extf %in : f16 to f32
      linalg.yield %17 : f32
  } -> tensor<128256x4096xf32>
  %16 = linalg.generic {
      indexing_maps = [ affine_map<(d0, d1, d2) -> (d0, d1)>,
                        affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
      iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%12 : tensor<4x?xi64>)
      outs(%13 : tensor<4x?x4096xf32>) {
    ^bb0(%in: i64, %out: f32):
      %17 = arith.index_cast %in : i64 to index
      %18 = linalg.index 2 : index
      %extracted = tensor.extract %15[%17, %18] : tensor<128256x4096xf32>
      %result = arith.addf %extracted, %extracted : f32
      %result2 = arith.mulf %extracted, %extracted : f32
      %final = arith.addf %result, %result2 : f32
      linalg.yield %final: f32
  } -> tensor<4x?x4096xf32>
  util.return %16 : tensor<4x?x4096xf32>
}

// CHECK: %[[INDEX0:[a-zA-Z0-9]+]] = arith.index_cast %in : i64 to index
// CHECK: %[[INDEX1:[a-zA-Z0-9]+]] = linalg.index 2 : index
// CHECK-NEXT: %[[EXTRACTED:.*]] = tensor.extract %[[TENSOR0:.+]][%[[INDEX0]], %[[INDEX1]]] : tensor<128256x4096xf16>
// CHECK-NEXT: %[[RES:[a-zA-Z0-9]+]] = arith.extf %[[EXTRACTED]] : f16 to f32
// CHECK-NEXT: %[[RES2:[a-zA-Z0-9]+]] = arith.addf %[[RES]], %[[RES]] : f32
// CHECK-NEXT: %[[RES3:[a-zA-Z0-9]+]] = arith.mulf %[[RES]], %[[RES]] : f32
// CHECK-NEXT: %[[RES4:[a-zA-Z0-9]+]] = arith.addf %[[RES2]], %[[RES3]] : f32
// CHECK-NEXT: linalg.yield %[[RES4]] : f32

// -----

#ident = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
Expand Down

0 comments on commit 540cebf

Please sign in to comment.