From 15172ac8bc90e495c4c0336f5ab48688f69c68c4 Mon Sep 17 00:00:00 2001 From: Ian Wood <ianwood2024@u.northwestern.edu> Date: Thu, 5 Dec 2024 06:20:54 -0800 Subject: [PATCH] [Dispatch] don't bubble reshapes through k2 dims Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu> --- .../DispatchCreation/BubbleUpExpandShapes.cpp | 45 ++++++++++++++++++- .../test/attention_fuse_by_expansion.mlir | 35 +++++++++++++++ 2 files changed, 78 insertions(+), 2 deletions(-) diff --git a/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp b/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp index 7ce4bddd5731..4fa84db415fd 100644 --- a/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp @@ -13,15 +13,17 @@ //===----------------------------------------------------------------------===// #include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h" +#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h" #include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h" +#include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h" #include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h" #include "iree/compiler/DispatchCreation/FusionUtils.h" #include "iree/compiler/DispatchCreation/Passes.h" -#include "llvm/Support/Debug.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #define DEBUG_TYPE "iree-dispatch-creation-bubble-up-expand-shapes" @@ -156,8 +158,47 @@ void BubbleUpExpandShapesPass::runOnOperation() { }; linalg::populateFoldReshapeOpsByExpansionPatterns(bubbleExpandShapePatterns, bubbleUpExpansionControlFn); + + // TODO(#19263): Temporary fix to prevent compilation failures when the k2 + // dims get expanded to unit dimensions. This adds the constraint to + // `bubbleUpExpansionControlFn` that the k2 dimensions cannot be expanded by + // the reshape fusion. + linalg::ControlFusionFn linalgExtExpansionFn = [&](OpOperand *fusedOperand) { + if (!bubbleUpExpansionControlFn(fusedOperand)) { + return false; + } + + // There is no need to handle `expand_shape` ops because they would be the + // producer and therefore are unable to expand the k2 dims. + auto collapseOp = + dyn_cast<tensor::CollapseShapeOp>(fusedOperand->get().getDefiningOp()); + auto attentionOp = + dyn_cast<IREE::LinalgExt::AttentionOp>(fusedOperand->getOwner()); + if (!collapseOp || !attentionOp) { + return true; + } + + SmallVector<ReassociationIndices> reassoc = + collapseOp.getReassociationIndices(); + auto opDetail = IREE::LinalgExt::AttentionOpDetail::get( + attentionOp.getQueryMap(), attentionOp.getKeyMap(), + attentionOp.getValueMap(), attentionOp.getOutputMap()); + + // Don't sink the `collapse_shape` op if it is collapsing into any of the k2 + // dimensions. + AffineMap operandMap = attentionOp.getMatchingIndexingMap(fusedOperand); + for (auto dim : opDetail->getK2Dims()) { + auto dimExpr = getAffineDimExpr(dim, operandMap.getContext()); + if (std::optional<int64_t> maybeDim = + operandMap.getResultPosition(dimExpr); + maybeDim && !reassoc[maybeDim.value()].empty()) { + return false; + } + } + return true; + }; IREE::LinalgExt::populateFoldReshapeOpsByExpansionPatterns( - bubbleExpandShapePatterns, bubbleUpExpansionControlFn); + bubbleExpandShapePatterns, linalgExtExpansionFn); // Add patterns to do some additional cleanup (on top of canonicalizations // that can be done later) of reshape ops. diff --git a/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir b/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir index 79ea644c3176..25e15537a7f4 100644 --- a/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir +++ b/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir @@ -446,3 +446,38 @@ util.func public @sink_single_collapse_masked(%0 : tensor<4x32x64x128xf16>, %1 : // CHECK-SAME: ins(%[[ARG0]], %[[EXPANDED1]], %[[EXPANDED2]], %[[ARG3]], %[[EXPANDED3]] : // CHECK: %[[RET:.+]] = tensor.collapse_shape %[[ATTENTION]] {{\[}}[0, 1], [2], [3]{{\]}} : tensor<4x32x64x128xf16> into tensor<128x64x128xf16> // CHECK: util.return %[[RET]] + +// ----- + +#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> +#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)> +#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)> +#map3 = affine_map<(d0, d1, d2, d3, d4) -> ()> +#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> + +util.func public @dont_sink_through_k2(%0 : tensor<128x64x128x1x1xf16>, %1 : tensor<128x64x128xf16>, %2 : tensor<128x64x128xf16>, %cst : f16) -> (tensor<128x64x128xf16>) { + %13 = tensor.empty() : tensor<4x32x64x128xf16> + %collapsed_12 = tensor.collapse_shape %0 [[0], [1], [2, 3, 4]] : tensor<128x64x128x1x1xf16> into tensor<128x64x128xf16> + %17 = tensor.empty() : tensor<128x64x128xf16> + %18 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4]} ins(%2, %1, %collapsed_12, %cst : tensor<128x64x128xf16>, tensor<128x64x128xf16>, tensor<128x64x128xf16>, f16) outs(%17 : tensor<128x64x128xf16>) { + ^bb0(%score: f16): + iree_linalg_ext.yield %score: f16 + } -> tensor<128x64x128xf16> + util.return %18 : tensor<128x64x128xf16> +} + +// CHECK-LABEL: util.func public @dont_sink_through_k2 +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: +// CHECK-SAME: %[[ARG3:.+]]: f16 +// CHECK-DAG: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] +// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention +// CHECK-SAME: indexing_maps = +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4) -> ()> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> +// CHECK-SAME: ins(%[[ARG2]], %[[ARG1]], %[[COLLAPSED]], %[[ARG3]] : +// CHECK: util.return %[[ATTENTION]]