From a145d84fa5bf3367552182d96e88792ac6329f48 Mon Sep 17 00:00:00 2001 From: Srinath Avadhanula Date: Wed, 31 Jul 2024 11:41:56 -0400 Subject: [PATCH] Add support for customizing const inlining and ops with regions to IsolateGroupOps pass (#87) A few small improvements to the `IsolateGroupOps` pass: - Provide a hook to customize whether a const like op used in the `tcp.group` will get copied into the group or whether it will be passed in as an input argument. The main pass does always returns true so this is a non-functional change in `mlir-tcp`. - The previous version did not handle ops with contained regions and block arguments (such as an `scf.forall` inside the `tcp.group`). We now handle this (see updated lit test). --------- Co-authored-by: Srinath Avadhanula --- .../Dialect/Transforms/IsolateGroupOpsPass.h | 11 +++ .../Transforms/IsolateGroupOpsPass.cpp | 60 ++++++++++--- test/Dialect/tcp_isolate_groups.mlir | 87 +++++++++++++++++++ 3 files changed, 147 insertions(+), 11 deletions(-) diff --git a/include/mlir-tcp/Dialect/Transforms/IsolateGroupOpsPass.h b/include/mlir-tcp/Dialect/Transforms/IsolateGroupOpsPass.h index 5e444092..bdd98f0a 100644 --- a/include/mlir-tcp/Dialect/Transforms/IsolateGroupOpsPass.h +++ b/include/mlir-tcp/Dialect/Transforms/IsolateGroupOpsPass.h @@ -9,6 +9,7 @@ #pragma once +#include "mlir-tcp/Dialect/IR/TcpOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" #include @@ -18,4 +19,14 @@ namespace mlir::tcp { std::unique_ptr> createTcpIsolateGroupOpsPass(); +// `createTcpIsolateGroupOpsPass` will clone all const operations used +// inside a `tcp.group` into the new `tcp.isolated_group` it creates. If +// you want to customize this behavior, you can use this instead to +// pass a predicate function to control when a `const-like` operation +// should be cloned into the isolated group or whether it should be added +// as an argument to the isolated group. +void populateIsolateGroupPatterns( + RewritePatternSet &patterns, + std::function shouldCopyConstPredicate); + } // namespace mlir::tcp diff --git a/lib/Dialect/Transforms/IsolateGroupOpsPass.cpp b/lib/Dialect/Transforms/IsolateGroupOpsPass.cpp index df242812..44e5377c 100644 --- a/lib/Dialect/Transforms/IsolateGroupOpsPass.cpp +++ b/lib/Dialect/Transforms/IsolateGroupOpsPass.cpp @@ -29,9 +29,12 @@ namespace mlir::tcp { namespace { -class IsolateGroups : public OpRewritePattern { +class IsolateGroups : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + IsolateGroups(MLIRContext *context, + std::function shouldInlineConst) + : OpRewritePattern(context), + shouldInlineConst_(shouldInlineConst) {} LogicalResult matchAndRewrite(tcp::GroupOp groupOp, PatternRewriter &rewriter) const override { @@ -41,12 +44,21 @@ class IsolateGroups : public OpRewritePattern { llvm::SmallVector inputs; llvm::SmallDenseSet addedInputs; llvm::SmallDenseSet consts; - llvm::SmallDenseSet defs; - for (auto &op : groupOp.getBody().front()) { - for (auto operand : op.getOperands()) { - if (defs.find(operand) == defs.end()) { + + groupOp->walk([&](Operation *op) { + for (auto operand : op->getOperands()) { + // Find the operation defining this Value, or whose block argument + // this Value is. + auto operandDefiningOp = operand.getDefiningOp(); + if (!operandDefiningOp) { + operandDefiningOp = operand.getParentBlock()->getParentOp(); + } + // If that operation lives outside the group, we need to add it as + // an input to the newly created isolated group. + if (!groupOp->isProperAncestor(operandDefiningOp)) { if (operand.getDefiningOp() && - operand.getDefiningOp()->hasTrait()) { + operand.getDefiningOp()->hasTrait() && + shouldInlineConst_(groupOp, operand)) { consts.insert(operand); } else if (!addedInputs.contains(operand)) { inputs.push_back(operand); @@ -54,20 +66,20 @@ class IsolateGroups : public OpRewritePattern { } } } - defs.insert(op.getResults().begin(), op.getResults().end()); - } + }); auto isolatedGroupOp = rewriter.create( groupOp.getLoc(), groupOp.getResultTypes(), inputs); isolatedGroupOp->setAttrs(groupOp->getAttrs()); isolatedGroupOp.getBody().takeBody(groupOp.getBody()); + auto &isolatedGroupBlock = isolatedGroupOp.getBody().front(); { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&isolatedGroupBlock); auto belongsToIsolatedGroup = [&](OpOperand &opOperand) { - return (opOperand.getOwner()->getParentOp() == isolatedGroupOp); + return (isolatedGroupOp->isProperAncestor(opOperand.getOwner())); }; // Clone the constants at the start of the isolated group block. @@ -91,6 +103,23 @@ class IsolateGroups : public OpRewritePattern { rewriter.eraseOp(groupOp); return success(); } + +private: + std::function shouldInlineConst_; +}; + +class DropSymbolicShapesInsideGroups + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tcp::BindSymbolicShapeOp shapeOp, + PatternRewriter &rewriter) const override { + if (isa(shapeOp->getParentOp())) { + rewriter.eraseOp(shapeOp); + return success(); + } + return failure(); + } }; class TcpIsolateGroupOpsPass @@ -100,7 +129,8 @@ class TcpIsolateGroupOpsPass MLIRContext *context = op->getContext(); RewritePatternSet patterns(context); - patterns.add(context); + auto shouldCopyConstPredicate = [&](tcp::GroupOp, Value) { return true; }; + populateIsolateGroupPatterns(patterns, shouldCopyConstPredicate); if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) return signalPassFailure(); } @@ -112,4 +142,12 @@ std::unique_ptr> createTcpIsolateGroupOpsPass() { return std::make_unique(); } +void populateIsolateGroupPatterns( + RewritePatternSet &patterns, + std::function shouldCopyConstPredicate) { + + patterns.add(patterns.getContext(), shouldCopyConstPredicate); + patterns.add(patterns.getContext()); +} + } // namespace mlir::tcp diff --git a/test/Dialect/tcp_isolate_groups.mlir b/test/Dialect/tcp_isolate_groups.mlir index 221407f8..d11c72a6 100644 --- a/test/Dialect/tcp_isolate_groups.mlir +++ b/test/Dialect/tcp_isolate_groups.mlir @@ -120,3 +120,90 @@ func.func @test_inputs_with_multiple_uses(%arg0 : tensor<5xi32>) -> tensor<5xi32 }) : () -> tensor<5xi32> return %10 : tensor<5xi32> } + + +// ----- + +// isolate tcp.group ops in the presence of nested regions. + +// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)> +// CHECK: module { +// CHECK: func.func @forward(%[[ARG0:.+]]: tensor, %[[ARG1:.+]]: tensor, %[[ARG2:.+]]: tensor) -> tensor { +// CHECK: %[[C0:.+]] = arith.constant 0 : index +// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor +// CHECK: %[[V0:.+]] = tcp.isolated_group %[[DIM]], %[[ARG0]], %[[ARG1]] attributes {group_type = "codegen_group"} { +// CHECK: ^bb0(%[[ARG3:.+]]: index, %[[ARG4:.+]]: tensor, %[[ARG5:.+]]: tensor): +// CHECK: %[[V1:.+]] = tensor.empty(%[[ARG3]]) : tensor +// CHECK: %[[V2:.+]] = scf.forall (%[[ARG6:.+]], %[[ARG7:.+]]) in (%[[ARG3]], 4096) shared_outs(%[[ARG8:.+]] = %[[V1]]) -> (tensor) { +// CHECK: %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG4]][%[[ARG6]], %[[ARG7]]] [1, 1] [1, 1] : tensor to tensor<1x1xf32> +// CHECK: %[[EXTRACTED_SLICE_0:.+]] = tensor.extract_slice %[[ARG5]][%[[ARG6]], %[[ARG7]]] [1, 1] [1, 1] : tensor to tensor<1x1xf32> +// CHECK: %[[V3:.+]] = tensor.empty() : tensor<1x1xf32> +// CHECK: %[[V4:.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[EXTRACTED_SLICE]], %[[EXTRACTED_SLICE_0]] : tensor<1x1xf32>, tensor<1x1xf32>) outs(%[[V3]] : tensor<1x1xf32>) { +// CHECK: ^bb0(%[[IN:.+]]: f32, %[[IN_1:.+]]: f32, %[[OUT:.+]]: f32): +// CHECK: %[[V5:.+]] = arith.mulf %[[IN]], %[[IN_1]] : f32 +// CHECK: linalg.yield %[[V5]] : f32 +// CHECK: } -> tensor<1x1xf32> +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[V4]] into %[[ARG8]][%[[ARG6]], %[[ARG7]]] [1, 1] [1, 1] : tensor<1x1xf32> into tensor +// CHECK: } +// CHECK: } +// CHECK: tcp.yield %[[V2]] : tensor +// CHECK: } : index, tensor, tensor -> tensor +// CHECK: return %[[V0]] : tensor +// CHECK: } +// CHECK: } +#map = affine_map<(d0, d1) -> (d0, d1)> +func.func @forward(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + %c0 = arith.constant 0 : index + %dim = tensor.dim %arg0, %c0 : tensor + %0 = tcp.group attributes {group_type = "codegen_group"} { + %1 = tensor.empty(%dim) : tensor + %2 = scf.forall (%arg3, %arg4) in (%dim, 4096) shared_outs(%arg5 = %1) -> (tensor) { + %extracted_slice = tensor.extract_slice %arg0[%arg3, %arg4] [1, 1] [1, 1] : tensor to tensor<1x1xf32> + %extracted_slice_0 = tensor.extract_slice %arg1[%arg3, %arg4] [1, 1] [1, 1] : tensor to tensor<1x1xf32> + %3 = tensor.empty() : tensor<1x1xf32> + %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice, %extracted_slice_0 : tensor<1x1xf32>, tensor<1x1xf32>) outs(%3 : tensor<1x1xf32>) { + ^bb0(%in: f32, %in_4: f32, %out: f32): + %8 = arith.mulf %in, %in_4 : f32 + linalg.yield %8 : f32 + } -> tensor<1x1xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %4 into %arg5[%arg3, %arg4] [1, 1] [1, 1] : tensor<1x1xf32> into tensor + } + } + tcp.yield %2 : tensor + } : tensor + return %0 : tensor +} + +// ----- + +// Ensure that we correctly drop `tcp.bind_symbolic_shape` ops within the +// newly created tcp.isolated_group region. + +// CHECK: func.func @test_symbolic_shape_ops(%[[ARG0:.+]]: tensor) -> tensor { +// CHECK: %[[V0:.+]] = tcp.symbolic_int "s0" {min_val = 2, max_val = 9223372036854775806} : i64 +// CHECK: tcp.bind_symbolic_shape %[[ARG0]], [%[[V0]]], affine_map<()[s0] -> (s0, 3)> : tensor +// CHECK: %[[V1:.+]] = tcp.isolated_group %[[ARG0]] { +// CHECK: ^bb0(%[[ARG1:.+]]: tensor): +// CHECK: %[[V2:.+]] = tcp.add %[[ARG1]], %[[ARG1]] : tensor, tensor -> tensor +// CHECK-NOT: tcp.bind_symbolic_shape +// CHECK: %[[V3:.+]] = tcp.mul %[[V2]], %[[V2]] : tensor, tensor -> tensor +// CHECK: tcp.yield %[[V3]] : tensor +// CHECK: } : tensor -> tensor +// CHECK: tcp.bind_symbolic_shape %[[V1]], [%[[V0]]], affine_map<()[s0] -> (s0, 3)> : tensor +// CHECK: return %[[V1]] : tensor +// CHECK: } +func.func @test_symbolic_shape_ops(%arg0 : tensor) -> tensor { + %0 = tcp.symbolic_int "s0" {min_val = 2, max_val = 9223372036854775806} : i64 + tcp.bind_symbolic_shape %arg0, [%0], affine_map<()[s0] -> (s0, 3)> : tensor + %10 = "tcp.group" () ({ + ^bb0() : + %2 = tcp.add %arg0, %arg0 : tensor, tensor -> tensor + tcp.bind_symbolic_shape %2, [%0], affine_map<()[s0] -> (s0, 3)> : tensor + %3 = tcp.mul %2, %2 : tensor, tensor -> tensor + tcp.yield %3 : tensor + }) : () -> tensor + tcp.bind_symbolic_shape %10, [%0], affine_map<()[s0] -> (s0, 3)> : tensor + return %10 : tensor +}