forked from cruise-automation/mlir-tcp
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Cherry-Pick] Move bottom up fuser declaration to header file (cruise…
…-automation#7) Move bottom up fuser declaration to header file No testing required since it's a minor restructuring Cherry-pick from cruise-automation@c4c94fb --------- Co-authored-by: Muhammad Abubakar <[email protected]>
- Loading branch information
Showing
4 changed files
with
129 additions
and
78 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
//===------------------------------------------------------------*- C++ -*-===// | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// Also available under a BSD-style license. See LICENSE. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#pragma once | ||
|
||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
|
||
namespace mlir::tcp { | ||
|
||
class GenericBottomUpFuser : public RewritePattern { | ||
public: | ||
using CanFuseFuncType = std::function<bool(Operation *, Operation *)>; | ||
using PostProcessingFuncType = | ||
std::function<void(Operation *, PatternRewriter &rewriter)>; | ||
|
||
// A class for supporting generic bottom-up fusion | ||
// All fused operations will be placed in a single TCP group | ||
// canFuseCallback checks whether two operations can be fused | ||
// postFuncCallback is called on the new TCP group for | ||
// post-processing. It provides the group handle to the client pass | ||
// and, e.g. can be used to add ad-hoc attributes to the group op | ||
// auto addGroupAttr = [](Operation * groupOp, | ||
// PatternRewriter &rewriter) -> void { | ||
// groupOp->setAttr("group_type", rewriter.getStringAttr("xxx")); | ||
// }; | ||
|
||
GenericBottomUpFuser(MLIRContext *context, CanFuseFuncType canFuseCallback, | ||
PostProcessingFuncType postFuncCallback = nullptr) | ||
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context), | ||
canFuse(canFuseCallback), postFunc(postFuncCallback) {} | ||
|
||
LogicalResult matchAndRewrite(Operation *op, | ||
PatternRewriter &rewriter) const override; | ||
|
||
private: | ||
CanFuseFuncType canFuse; | ||
PostProcessingFuncType postFunc; | ||
}; | ||
} // namespace mlir::tcp |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
//===------------------------------------------------------------*- C++ -*-===// | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// Also available under a BSD-style license. See LICENSE. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir-tcp/Dialect/Transforms/FusionPatterns.h" | ||
#include "mlir-tcp/Dialect/IR/TcpDialect.h" | ||
#include "mlir-tcp/Dialect/IR/TcpOps.h" | ||
#include "mlir/Dialect/Func/IR/FuncOps.h" | ||
#include "mlir/IR/BuiltinOps.h" | ||
#include "mlir/IR/OpDefinition.h" | ||
|
||
namespace mlir::tcp { | ||
LogicalResult | ||
GenericBottomUpFuser::matchAndRewrite(Operation *op, | ||
PatternRewriter &rewriter) const { | ||
Operation *use = op; | ||
bool isChanged = false; | ||
for (auto operand : op->getOperands()) { | ||
if (operand.getDefiningOp()) { | ||
Operation *def = operand.getDefiningOp(); | ||
if (canFuse(def, use)) { | ||
|
||
// Currently we are only fusing ops at the top-level. | ||
// This is to avoid recursing inside a group and ending up with | ||
// nested groups that contain the same ops. | ||
// Since we are iterating bottom up in a block, we only need to check | ||
// if the def op has a func parent. | ||
// | ||
// TODO: Remove this restriction to allow fusing in nested regions. | ||
if (!isa<func::FuncOp>(def->getParentOp())) { | ||
continue; | ||
} | ||
|
||
// We only support fusing def ops that have exactly one use, for now. | ||
if (!def->hasOneUse()) { | ||
continue; | ||
} | ||
|
||
// Fuse the def and use ops into a group. | ||
|
||
// * If both the ops have the same parent region, they must be part | ||
// of the top-level func. So, we need to create a new group. | ||
// * The only other case is when the def op is part of the top-level | ||
// func and the use is already inside a group. | ||
isChanged = true; | ||
if (def->getParentRegion() == use->getParentRegion()) { | ||
auto groupOp = | ||
rewriter.create<GroupOp>(use->getLoc(), use->getResultTypes()); | ||
if (postFunc) { | ||
postFunc(groupOp, rewriter); | ||
} | ||
Block *groupBlock = new Block(); | ||
groupOp.getBody().push_back(groupBlock); | ||
for (unsigned num = 0; num < use->getNumResults(); ++num) { | ||
rewriter.replaceAllUsesWith(use->getResult(num), | ||
groupOp->getResult(num)); | ||
} | ||
{ | ||
OpBuilder::InsertionGuard guard(rewriter); | ||
rewriter.setInsertionPointToStart(groupBlock); | ||
auto yieldOp = | ||
rewriter.create<YieldOp>(use->getLoc(), use->getResults()); | ||
use->moveBefore(yieldOp); | ||
operand.getDefiningOp()->moveBefore(use); | ||
} | ||
} else if (auto groupOp = dyn_cast<GroupOp>(use->getParentOp())) { | ||
def->moveBefore(use); | ||
} else { | ||
llvm_unreachable("Unhandled case during fusion"); | ||
} | ||
} | ||
} | ||
} | ||
return isChanged ? success() : failure(); | ||
} | ||
} // namespace mlir::tcp |