From 15172ac8bc90e495c4c0336f5ab48688f69c68c4 Mon Sep 17 00:00:00 2001
From: Ian Wood <ianwood2024@u.northwestern.edu>
Date: Thu, 5 Dec 2024 06:20:54 -0800
Subject: [PATCH] [DispatchCreation] Don't bubble reshapes through attention k2 dims

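Reshape fusion can expand the attention k2 dimensions (e.g. into unit
dims), which currently fails to compile (#19263). As a temporary,
conservative fix, wrap the bubble-up control function so that a
`tensor.collapse_shape` is not sunk into an `iree_linalg_ext.attention`
operand whose indexing map references the k2 dims.

Abridged from the new test, the collapse below is now left unfused
because the collapsed value operand indexes the k2 dimension (d3):

  %collapsed = tensor.collapse_shape %0 [[0], [1], [2, 3, 4]]
      : tensor<128x64x128x1x1xf16> into tensor<128x64x128xf16>
  %result = iree_linalg_ext.attention {indexing_maps = [...]}
      ins(%q, %k, %collapsed, %scale : ...) outs(%init : ...)
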
Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
---
 .../DispatchCreation/BubbleUpExpandShapes.cpp | 45 ++++++++++++++++++-
 .../test/attention_fuse_by_expansion.mlir     | 34 ++++++++++++++
 2 files changed, 77 insertions(+), 2 deletions(-)

diff --git a/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp b/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp
index 7ce4bddd5731..4fa84db415fd 100644
--- a/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp
@@ -13,15 +13,17 @@
 //===----------------------------------------------------------------------===//
 
 #include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
+#include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
 #include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
 #include "iree/compiler/DispatchCreation/FusionUtils.h"
 #include "iree/compiler/DispatchCreation/Passes.h"
-#include "llvm/Support/Debug.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 #define DEBUG_TYPE "iree-dispatch-creation-bubble-up-expand-shapes"
@@ -156,8 +158,47 @@ void BubbleUpExpandShapesPass::runOnOperation() {
       };
   linalg::populateFoldReshapeOpsByExpansionPatterns(bubbleExpandShapePatterns,
                                                     bubbleUpExpansionControlFn);
+
+  // TODO(#19263): Temporary fix to prevent compilation failures when the
+  // attention k2 dims get expanded into unit dimensions. This wraps
+  // `bubbleUpExpansionControlFn` with the additional, conservative constraint
+  // that reshape fusion must not expand operands that index the k2 dims.
+  linalg::ControlFusionFn linalgExtExpansionFn = [&](OpOperand *fusedOperand) {
+    if (!bubbleUpExpansionControlFn(fusedOperand)) {
+      return false;
+    }
+
+    // The `expand_shape` (consumer) case needs no handling: expansion is then
+    // driven by the attention result, whose indexing map has no k2 dims.
+    auto collapseOp =
+        fusedOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
+    auto attentionOp =
+        dyn_cast<IREE::LinalgExt::AttentionOp>(fusedOperand->getOwner());
+    if (!collapseOp || !attentionOp) {
+      return true;
+    }
+
+    SmallVector<ReassociationIndices> reassoc =
+        collapseOp.getReassociationIndices();
+    auto opDetail = IREE::LinalgExt::AttentionOpDetail::get(
+        attentionOp.getQueryMap(), attentionOp.getKeyMap(),
+        attentionOp.getValueMap(), attentionOp.getOutputMap());
+
+    // Don't sink the `collapse_shape` op if the collapsed operand's indexing
+    // map references any of the k2 dimensions.
+    AffineMap operandMap = attentionOp.getMatchingIndexingMap(fusedOperand);
+    for (auto dim : opDetail->getK2Dims()) {
+      auto dimExpr = getAffineDimExpr(dim, operandMap.getContext());
+      if (std::optional<int64_t> maybeDim =
+              operandMap.getResultPosition(dimExpr);
+          maybeDim && !reassoc[maybeDim.value()].empty()) {
+        return false;
+      }
+    }
+    return true;
+  };
   IREE::LinalgExt::populateFoldReshapeOpsByExpansionPatterns(
-      bubbleExpandShapePatterns, bubbleUpExpansionControlFn);
+      bubbleExpandShapePatterns, linalgExtExpansionFn);
 
   // Add patterns to do some additional cleanup (on top of canonicalizations
   // that can be done later) of reshape ops.
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir b/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir
index 79ea644c3176..25e15537a7f4 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir
@@ -446,3 +446,37 @@ util.func public @sink_single_collapse_masked(%0 : tensor<4x32x64x128xf16>, %1 :
 //  CHECK-SAME:       ins(%[[ARG0]], %[[EXPANDED1]], %[[EXPANDED2]], %[[ARG3]], %[[EXPANDED3]] :
 //       CHECK:   %[[RET:.+]] = tensor.collapse_shape %[[ATTENTION]] {{\[}}[0, 1], [2], [3]{{\]}} : tensor<4x32x64x128xf16> into tensor<128x64x128xf16>
 //       CHECK:   util.return %[[RET]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
+#map3 = affine_map<(d0, d1, d2, d3, d4) -> ()>
+#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
+
+util.func public @dont_sink_through_k2(%0 : tensor<128x64x128x1x1xf16>, %1 : tensor<128x64x128xf16>, %2 : tensor<128x64x128xf16>, %cst : f16) -> (tensor<128x64x128xf16>) {
+  %collapsed_12 = tensor.collapse_shape %0 [[0], [1], [2, 3, 4]] : tensor<128x64x128x1x1xf16> into tensor<128x64x128xf16>
+  %17 = tensor.empty() : tensor<128x64x128xf16>
+  %18 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4]} ins(%2, %1, %collapsed_12, %cst : tensor<128x64x128xf16>, tensor<128x64x128xf16>, tensor<128x64x128xf16>, f16) outs(%17 : tensor<128x64x128xf16>) {
+    ^bb0(%score: f16):
+      iree_linalg_ext.yield %score: f16
+  } -> tensor<128x64x128xf16>
+  util.return %18 : tensor<128x64x128xf16>
+}
+
+// CHECK-LABEL: util.func public @dont_sink_through_k2
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]:
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]:
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]:
+//  CHECK-SAME:     %[[ARG3:.+]]: f16
+//   CHECK-DAG:   %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]]
+//       CHECK:   %[[ATTENTION:.+]] = iree_linalg_ext.attention
+//  CHECK-SAME:       indexing_maps =
+//  CHECK-SAME:       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+//  CHECK-SAME:       affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
+//  CHECK-SAME:       affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
+//  CHECK-SAME:       affine_map<(d0, d1, d2, d3, d4) -> ()>
+//  CHECK-SAME:       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
+//  CHECK-SAME:       ins(%[[ARG2]], %[[ARG1]], %[[COLLAPSED]], %[[ARG3]] :
+//       CHECK:   util.return %[[ATTENTION]]