Skip to content

Commit

Permalink
[Cherry-Pick] Move bottom up fuser declaration to header file (cruise…
Browse files Browse the repository at this point in the history
…-automation#7)

Move bottom up fuser declaration to header file

No testing required since it's a minor restructuring

Cherry-pick from
cruise-automation@c4c94fb
---------

Co-authored-by: Muhammad Abubakar <[email protected]>
  • Loading branch information
Muhammad Abubakar authored and GitHub Enterprise committed Dec 7, 2023
2 parents 052ba71 + eea9c34 commit 9c19ae3
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 78 deletions.
2 changes: 2 additions & 0 deletions BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -135,13 +135,15 @@ cc_library(
name = "TcpDialectPasses",
srcs = [
"lib/Dialect/Transforms/FuseTcpOpsPass.cpp",
"lib/Dialect/Transforms/FusionPatterns.cpp",
"lib/Dialect/Transforms/IsolateGroupOpsPass.cpp",
"lib/Dialect/Transforms/PassDetail.h",
"lib/Dialect/Transforms/Passes.cpp",
"lib/Dialect/Transforms/VerifyTcpBackendContractPass.cpp",
],
hdrs = [
"include/mlir-tcp/Dialect/Transforms/FuseTcpOpsPass.h",
"include/mlir-tcp/Dialect/Transforms/FusionPatterns.h",
"include/mlir-tcp/Dialect/Transforms/IsolateGroupOpsPass.h",
"include/mlir-tcp/Dialect/Transforms/Passes.h",
"include/mlir-tcp/Dialect/Transforms/VerifyTcpBackendContractPass.h",
Expand Down
45 changes: 45 additions & 0 deletions include/mlir-tcp/Dialect/Transforms/FusionPatterns.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#pragma once

#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir::tcp {

class GenericBottomUpFuser : public RewritePattern {
public:
using CanFuseFuncType = std::function<bool(Operation *, Operation *)>;
using PostProcessingFuncType =
std::function<void(Operation *, PatternRewriter &rewriter)>;

// A class for supporting generic bottom-up fusion
// All fused operations will be placed in a single TCP group
// canFuseCallback checks whether two operations can be fused
// postFuncCallback is called on the new TCP group for
// post-processing. It provides the group handle to the client pass
// and, e.g. can be used to add ad-hoc attributes to the group op
// auto addGroupAttr = [](Operation * groupOp,
// PatternRewriter &rewriter) -> void {
// groupOp->setAttr("group_type", rewriter.getStringAttr("xxx"));
// };

GenericBottomUpFuser(MLIRContext *context, CanFuseFuncType canFuseCallback,
PostProcessingFuncType postFuncCallback = nullptr)
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
canFuse(canFuseCallback), postFunc(postFuncCallback) {}

LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override;

private:
CanFuseFuncType canFuse;
PostProcessingFuncType postFunc;
};
} // namespace mlir::tcp
79 changes: 1 addition & 78 deletions lib/Dialect/Transforms/FuseTcpOpsPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,95 +8,18 @@
//===----------------------------------------------------------------------===//

#include "mlir-tcp/Dialect/Transforms/FuseTcpOpsPass.h"
#include "mlir-tcp/Dialect/IR/TcpDialect.h"
#include "mlir-tcp/Dialect/IR/TcpOps.h"
#include "mlir-tcp/Dialect/Transforms/FusionPatterns.h"
#include "mlir-tcp/Dialect/Transforms/Passes.h"

#include "./PassDetail.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

namespace mlir::tcp {

namespace {

class GenericBottomUpFuser : public RewritePattern {
public:
using CanFuseFuncType = std::function<bool(Operation *, Operation *)>;

GenericBottomUpFuser(MLIRContext *context, CanFuseFuncType cf)
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
canFuse(cf) {}

LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
Operation *use = op;
for (auto operand : op->getOperands()) {
if (operand.getDefiningOp()) {
Operation *def = operand.getDefiningOp();
if (canFuse(def, use)) {

// Currently we are only fusing ops at the top-level.
// This is to avoid recursing inside a group and ending up with
// nested groups that contain the same ops.
// Since we are iterating bottom up in a block, we only need to check
// if the def op has a func parent.
//
// TODO: Remove this restriction to allow fusing in nested regions.
if (!isa<func::FuncOp>(def->getParentOp())) {
continue;
}

// We only support fusing def ops that have exactly one use, for now.
if (!def->hasOneUse()) {
continue;
}

// Fuse the def and use ops into a group.

// * If both the ops have the same parent region, they must be part
// of the top-level func. So, we need to create a new group.
// * The only other case is when the def op is part of the top-level
// func and the use is already inside a group.
if (def->getParentRegion() == use->getParentRegion()) {
auto groupOp =
rewriter.create<GroupOp>(use->getLoc(), use->getResultTypes());
Block *groupBlock = new Block();
groupOp.getBody().push_back(groupBlock);
for (unsigned num = 0; num < use->getNumResults(); ++num) {
rewriter.replaceAllUsesWith(use->getResult(num),
groupOp->getResult(num));
}
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(groupBlock);
auto yieldOp =
rewriter.create<YieldOp>(use->getLoc(), use->getResults());
use->moveBefore(yieldOp);
operand.getDefiningOp()->moveBefore(use);
}
} else if (auto groupOp = dyn_cast<GroupOp>(use->getParentOp())) {
def->moveBefore(use);
} else {
llvm_unreachable("Unhandled case during fusion");
}
}
}
}
return success();
}

private:
CanFuseFuncType canFuse;
};

class TcpFuseElementwiseOpsPass
: public TcpFuseElementwiseOpsBase<TcpFuseElementwiseOpsPass> {
void runOnOperation() override {
Expand Down
81 changes: 81 additions & 0 deletions lib/Dialect/Transforms/FusionPatterns.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "mlir-tcp/Dialect/Transforms/FusionPatterns.h"
#include "mlir-tcp/Dialect/IR/TcpDialect.h"
#include "mlir-tcp/Dialect/IR/TcpOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OpDefinition.h"

namespace mlir::tcp {
LogicalResult
GenericBottomUpFuser::matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
Operation *use = op;
bool isChanged = false;
for (auto operand : op->getOperands()) {
if (operand.getDefiningOp()) {
Operation *def = operand.getDefiningOp();
if (canFuse(def, use)) {

// Currently we are only fusing ops at the top-level.
// This is to avoid recursing inside a group and ending up with
// nested groups that contain the same ops.
// Since we are iterating bottom up in a block, we only need to check
// if the def op has a func parent.
//
// TODO: Remove this restriction to allow fusing in nested regions.
if (!isa<func::FuncOp>(def->getParentOp())) {
continue;
}

// We only support fusing def ops that have exactly one use, for now.
if (!def->hasOneUse()) {
continue;
}

// Fuse the def and use ops into a group.

// * If both the ops have the same parent region, they must be part
// of the top-level func. So, we need to create a new group.
// * The only other case is when the def op is part of the top-level
// func and the use is already inside a group.
isChanged = true;
if (def->getParentRegion() == use->getParentRegion()) {
auto groupOp =
rewriter.create<GroupOp>(use->getLoc(), use->getResultTypes());
if (postFunc) {
postFunc(groupOp, rewriter);
}
Block *groupBlock = new Block();
groupOp.getBody().push_back(groupBlock);
for (unsigned num = 0; num < use->getNumResults(); ++num) {
rewriter.replaceAllUsesWith(use->getResult(num),
groupOp->getResult(num));
}
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(groupBlock);
auto yieldOp =
rewriter.create<YieldOp>(use->getLoc(), use->getResults());
use->moveBefore(yieldOp);
operand.getDefiningOp()->moveBefore(use);
}
} else if (auto groupOp = dyn_cast<GroupOp>(use->getParentOp())) {
def->moveBefore(use);
} else {
llvm_unreachable("Unhandled case during fusion");
}
}
}
}
return isChanged ? success() : failure();
}
} // namespace mlir::tcp

0 comments on commit 9c19ae3

Please sign in to comment.