Add support for customizing const inlining and ops with regions to IsolateGroupOps pass (#87)

A few small improvements to the `IsolateGroupOps` pass:
- Provide a hook to customize whether a const-like op used in the
`tcp.group` is cloned into the new isolated group or passed in as an
input argument (see the sketch after this list). The predicate used by
the main pass always returns true, so this is a non-functional change
for `mlir-tcp`.
- The previous version did not handle ops with nested regions and
block arguments (such as an `scf.forall` inside the `tcp.group`). These
are now handled correctly (see the updated lit test).
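
To make the hook concrete, here is a minimal sketch (not part of this commit) of a predicate a downstream project might pass to `populateIsolateGroupPatterns`; the splat-only policy is purely illustrative and assumes `mlir/IR/Matchers.h` is available:

// Hypothetical predicate: clone only splat constants into the isolated
// group; every other const-like value is passed in as a block argument.
auto shouldCopyConstPredicate = [](mlir::tcp::GroupOp, mlir::Value constVal) {
  mlir::Operation *def = constVal.getDefiningOp();
  if (!def)
    return false;
  mlir::DenseElementsAttr attr;
  // m_Constant matches ConstantLike ops and binds their value attribute.
  if (!mlir::matchPattern(def, mlir::m_Constant(&attr)))
    return false;
  return attr.isSplat();
};

The in-tree pass keeps the previous behavior by passing a predicate that always returns true.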

---------

Co-authored-by: Srinath Avadhanula <[email protected]>
srinathava and Srinath Avadhanula authored Jul 31, 2024
1 parent 7c50225 commit a145d84
Showing 3 changed files with 147 additions and 11 deletions.
11 changes: 11 additions & 0 deletions include/mlir-tcp/Dialect/Transforms/IsolateGroupOpsPass.h
@@ -9,6 +9,7 @@

#pragma once

#include "mlir-tcp/Dialect/IR/TcpOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include <memory>
@@ -18,4 +19,14 @@ namespace mlir::tcp {
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
createTcpIsolateGroupOpsPass();

// `createTcpIsolateGroupOpsPass` clones all const operations used inside
// a `tcp.group` into the new `tcp.isolated_group` it creates. To
// customize this behavior, use this function instead and pass a predicate
// that controls whether a const-like operation is cloned into the
// isolated group or added as an input argument to the isolated group.
void populateIsolateGroupPatterns(
RewritePatternSet &patterns,
std::function<bool(GroupOp, Value)> shouldCopyConstPredicate);

} // namespace mlir::tcp
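
As a usage note, a downstream project could wire these patterns into its own pass roughly as follows; this is a sketch only (the helper name and single-use policy are hypothetical), mirroring the in-tree `TcpIsolateGroupOpsPass` in the .cpp diff below and assuming `mlir/Transforms/GreedyPatternRewriteDriver.h`:

// Hypothetical helper applying the isolation patterns with a custom policy.
static mlir::LogicalResult isolateGroupsWithCustomPolicy(mlir::ModuleOp module) {
  mlir::RewritePatternSet patterns(module.getContext());
  // Illustrative policy: only clone constants that have a single use.
  auto shouldCopyConst = [](mlir::tcp::GroupOp, mlir::Value v) {
    return v.hasOneUse();
  };
  mlir::tcp::populateIsolateGroupPatterns(patterns, shouldCopyConst);
  return mlir::applyPatternsAndFoldGreedily(module, std::move(patterns));
}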
60 changes: 49 additions & 11 deletions lib/Dialect/Transforms/IsolateGroupOpsPass.cpp
@@ -29,9 +29,12 @@ namespace mlir::tcp {

namespace {

class IsolateGroups : public OpRewritePattern<tcp::GroupOp> {
class IsolateGroups : public OpRewritePattern<GroupOp> {
public:
using OpRewritePattern::OpRewritePattern;
IsolateGroups(MLIRContext *context,
std::function<bool(tcp::GroupOp, Value)> shouldInlineConst)
: OpRewritePattern<GroupOp>(context),
shouldInlineConst_(shouldInlineConst) {}

LogicalResult matchAndRewrite(tcp::GroupOp groupOp,
PatternRewriter &rewriter) const override {
@@ -41,33 +44,42 @@ class IsolateGroups : public OpRewritePattern<tcp::GroupOp> {
llvm::SmallVector<Value> inputs;
llvm::SmallDenseSet<Value> addedInputs;
llvm::SmallDenseSet<Value> consts;
llvm::SmallDenseSet<Value> defs;
for (auto &op : groupOp.getBody().front()) {
for (auto operand : op.getOperands()) {
if (defs.find(operand) == defs.end()) {

groupOp->walk([&](Operation *op) {
for (auto operand : op->getOperands()) {
// Find the operation defining this Value, or whose block argument
// this Value is.
auto operandDefiningOp = operand.getDefiningOp();
if (!operandDefiningOp) {
operandDefiningOp = operand.getParentBlock()->getParentOp();
}
// If that operation lives outside the group, we need to add it as
// an input to the newly created isolated group.
if (!groupOp->isProperAncestor(operandDefiningOp)) {
if (operand.getDefiningOp() &&
operand.getDefiningOp()->hasTrait<OpTrait::ConstantLike>()) {
operand.getDefiningOp()->hasTrait<OpTrait::ConstantLike>() &&
shouldInlineConst_(groupOp, operand)) {
consts.insert(operand);
} else if (!addedInputs.contains(operand)) {
inputs.push_back(operand);
addedInputs.insert(operand);
}
}
}
defs.insert(op.getResults().begin(), op.getResults().end());
}
});

auto isolatedGroupOp = rewriter.create<tcp::IsolatedGroupOp>(
groupOp.getLoc(), groupOp.getResultTypes(), inputs);
isolatedGroupOp->setAttrs(groupOp->getAttrs());

isolatedGroupOp.getBody().takeBody(groupOp.getBody());

auto &isolatedGroupBlock = isolatedGroupOp.getBody().front();
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&isolatedGroupBlock);
auto belongsToIsolatedGroup = [&](OpOperand &opOperand) {
return (opOperand.getOwner()->getParentOp() == isolatedGroupOp);
return (isolatedGroupOp->isProperAncestor(opOperand.getOwner()));
};

// Clone the constants at the start of the isolated group block.
@@ -91,6 +103,23 @@ class IsolateGroups : public OpRewritePattern<tcp::GroupOp> {
rewriter.eraseOp(groupOp);
return success();
}

private:
std::function<bool(tcp::GroupOp, Value)> shouldInlineConst_;
};

class DropSymbolicShapesInsideGroups
: public OpRewritePattern<tcp::BindSymbolicShapeOp> {
using OpRewritePattern<tcp::BindSymbolicShapeOp>::OpRewritePattern;

LogicalResult matchAndRewrite(tcp::BindSymbolicShapeOp shapeOp,
PatternRewriter &rewriter) const override {
if (isa<tcp::GroupOp>(shapeOp->getParentOp())) {
rewriter.eraseOp(shapeOp);
return success();
}
return failure();
}
};

class TcpIsolateGroupOpsPass
@@ -100,7 +129,8 @@ class TcpIsolateGroupOpsPass
MLIRContext *context = op->getContext();
RewritePatternSet patterns(context);

patterns.add<IsolateGroups>(context);
auto shouldCopyConstPredicate = [&](tcp::GroupOp, Value) { return true; };
populateIsolateGroupPatterns(patterns, shouldCopyConstPredicate);
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
return signalPassFailure();
}
@@ -112,4 +142,12 @@ std::unique_ptr<OperationPass<ModuleOp>> createTcpIsolateGroupOpsPass() {
return std::make_unique<TcpIsolateGroupOpsPass>();
}

void populateIsolateGroupPatterns(
RewritePatternSet &patterns,
std::function<bool(tcp::GroupOp, Value)> shouldCopyConstPredicate) {

patterns.add<IsolateGroups>(patterns.getContext(), shouldCopyConstPredicate);
patterns.add<DropSymbolicShapesInsideGroups>(patterns.getContext());
}

} // namespace mlir::tcp
87 changes: 87 additions & 0 deletions test/Dialect/tcp_isolate_groups.mlir
@@ -120,3 +120,90 @@ func.func @test_inputs_with_multiple_uses(%arg0 : tensor<5xi32>) -> tensor<5xi32
}) : () -> tensor<5xi32>
return %10 : tensor<5xi32>
}


// -----

// Isolate tcp.group ops in the presence of nested regions.

// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
// CHECK: module {
// CHECK: func.func @forward(%[[ARG0:.+]]: tensor<?x4096xf32>, %[[ARG1:.+]]: tensor<?x4096xf32>, %[[ARG2:.+]]: tensor<?x4096xf32>) -> tensor<?x4096xf32> {
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x4096xf32>
// CHECK: %[[V0:.+]] = tcp.isolated_group %[[DIM]], %[[ARG0]], %[[ARG1]] attributes {group_type = "codegen_group"} {
// CHECK: ^bb0(%[[ARG3:.+]]: index, %[[ARG4:.+]]: tensor<?x4096xf32>, %[[ARG5:.+]]: tensor<?x4096xf32>):
// CHECK: %[[V1:.+]] = tensor.empty(%[[ARG3]]) : tensor<?x4096xf32>
// CHECK: %[[V2:.+]] = scf.forall (%[[ARG6:.+]], %[[ARG7:.+]]) in (%[[ARG3]], 4096) shared_outs(%[[ARG8:.+]] = %[[V1]]) -> (tensor<?x4096xf32>) {
// CHECK: %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG4]][%[[ARG6]], %[[ARG7]]] [1, 1] [1, 1] : tensor<?x4096xf32> to tensor<1x1xf32>
// CHECK: %[[EXTRACTED_SLICE_0:.+]] = tensor.extract_slice %[[ARG5]][%[[ARG6]], %[[ARG7]]] [1, 1] [1, 1] : tensor<?x4096xf32> to tensor<1x1xf32>
// CHECK: %[[V3:.+]] = tensor.empty() : tensor<1x1xf32>
// CHECK: %[[V4:.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[EXTRACTED_SLICE]], %[[EXTRACTED_SLICE_0]] : tensor<1x1xf32>, tensor<1x1xf32>) outs(%[[V3]] : tensor<1x1xf32>) {
// CHECK: ^bb0(%[[IN:.+]]: f32, %[[IN_1:.+]]: f32, %[[OUT:.+]]: f32):
// CHECK: %[[V5:.+]] = arith.mulf %[[IN]], %[[IN_1]] : f32
// CHECK: linalg.yield %[[V5]] : f32
// CHECK: } -> tensor<1x1xf32>
// CHECK: scf.forall.in_parallel {
// CHECK: tensor.parallel_insert_slice %[[V4]] into %[[ARG8]][%[[ARG6]], %[[ARG7]]] [1, 1] [1, 1] : tensor<1x1xf32> into tensor<?x4096xf32>
// CHECK: }
// CHECK: }
// CHECK: tcp.yield %[[V2]] : tensor<?x4096xf32>
// CHECK: } : index, tensor<?x4096xf32>, tensor<?x4096xf32> -> tensor<?x4096xf32>
// CHECK: return %[[V0]] : tensor<?x4096xf32>
// CHECK: }
// CHECK: }
#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @forward(%arg0: tensor<?x4096xf32>, %arg1: tensor<?x4096xf32>, %arg2: tensor<?x4096xf32>) -> tensor<?x4096xf32> {
%c0 = arith.constant 0 : index
%dim = tensor.dim %arg0, %c0 : tensor<?x4096xf32>
%0 = tcp.group attributes {group_type = "codegen_group"} {
%1 = tensor.empty(%dim) : tensor<?x4096xf32>
%2 = scf.forall (%arg3, %arg4) in (%dim, 4096) shared_outs(%arg5 = %1) -> (tensor<?x4096xf32>) {
%extracted_slice = tensor.extract_slice %arg0[%arg3, %arg4] [1, 1] [1, 1] : tensor<?x4096xf32> to tensor<1x1xf32>
%extracted_slice_0 = tensor.extract_slice %arg1[%arg3, %arg4] [1, 1] [1, 1] : tensor<?x4096xf32> to tensor<1x1xf32>
%3 = tensor.empty() : tensor<1x1xf32>
%4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice, %extracted_slice_0 : tensor<1x1xf32>, tensor<1x1xf32>) outs(%3 : tensor<1x1xf32>) {
^bb0(%in: f32, %in_4: f32, %out: f32):
%8 = arith.mulf %in, %in_4 : f32
linalg.yield %8 : f32
} -> tensor<1x1xf32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %4 into %arg5[%arg3, %arg4] [1, 1] [1, 1] : tensor<1x1xf32> into tensor<?x4096xf32>
}
}
tcp.yield %2 : tensor<?x4096xf32>
} : tensor<?x4096xf32>
return %0 : tensor<?x4096xf32>
}

// -----

// Ensure that we correctly drop `tcp.bind_symbolic_shape` ops within the
// newly created tcp.isolated_group region.

// CHECK: func.func @test_symbolic_shape_ops(%[[ARG0:.+]]: tensor<?x3xf32>) -> tensor<?x3xf32> {
// CHECK: %[[V0:.+]] = tcp.symbolic_int "s0" {min_val = 2, max_val = 9223372036854775806} : i64
// CHECK: tcp.bind_symbolic_shape %[[ARG0]], [%[[V0]]], affine_map<()[s0] -> (s0, 3)> : tensor<?x3xf32>
// CHECK: %[[V1:.+]] = tcp.isolated_group %[[ARG0]] {
// CHECK: ^bb0(%[[ARG1:.+]]: tensor<?x3xf32>):
// CHECK: %[[V2:.+]] = tcp.add %[[ARG1]], %[[ARG1]] : tensor<?x3xf32>, tensor<?x3xf32> -> tensor<?x3xf32>
// CHECK-NOT: tcp.bind_symbolic_shape
// CHECK: %[[V3:.+]] = tcp.mul %[[V2]], %[[V2]] : tensor<?x3xf32>, tensor<?x3xf32> -> tensor<?x3xf32>
// CHECK: tcp.yield %[[V3]] : tensor<?x3xf32>
// CHECK: } : tensor<?x3xf32> -> tensor<?x3xf32>
// CHECK: tcp.bind_symbolic_shape %[[V1]], [%[[V0]]], affine_map<()[s0] -> (s0, 3)> : tensor<?x3xf32>
// CHECK: return %[[V1]] : tensor<?x3xf32>
// CHECK: }
func.func @test_symbolic_shape_ops(%arg0 : tensor<?x3xf32>) -> tensor<?x3xf32> {
%0 = tcp.symbolic_int "s0" {min_val = 2, max_val = 9223372036854775806} : i64
tcp.bind_symbolic_shape %arg0, [%0], affine_map<()[s0] -> (s0, 3)> : tensor<?x3xf32>
%10 = "tcp.group" () ({
^bb0() :
%2 = tcp.add %arg0, %arg0 : tensor<?x3xf32>, tensor<?x3xf32> -> tensor<?x3xf32>
tcp.bind_symbolic_shape %2, [%0], affine_map<()[s0] -> (s0, 3)> : tensor<?x3xf32>
%3 = tcp.mul %2, %2 : tensor<?x3xf32>, tensor<?x3xf32> -> tensor<?x3xf32>
tcp.yield %3 : tensor<?x3xf32>
}) : () -> tensor<?x3xf32>
tcp.bind_symbolic_shape %10, [%0], affine_map<()[s0] -> (s0, 3)> : tensor<?x3xf32>
return %10 : tensor<?x3xf32>
}
