Skip to content

Commit

Permalink
Add custom CF to SCF raising pass
Browse files Browse the repository at this point in the history
  • Loading branch information
BuildKite committed Jan 27, 2025
1 parent 6103fee commit 90d8cce
Show file tree
Hide file tree
Showing 4 changed files with 818 additions and 0 deletions.
80 changes: 80 additions & 0 deletions src/enzyme_ad/jax/Passes/ControlFlowToSCF.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
//===- ControlFlowToSCF.h - ControlFlow to SCF -------------*- C++ ------*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Define conversions from the ControlFlow dialect to the SCF dialect.
//
//===----------------------------------------------------------------------===//

#include "Passes.h"

#include "mlir/Conversion/ControlFlowToSCF/ControlFlowToSCF.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/CFGToSCF.h"

namespace mlir {
namespace enzyme {
#define GEN_PASS_DEF_ENZYMELIFTCONTROLFLOWTOSCFPASS
#include "src/enzyme_ad/jax/Passes/Passes.h.inc"
} // namespace enzyme
} // namespace mlir

using namespace mlir;

namespace {

struct EnzymeLiftControlFlowToSCF
: public enzyme::impl::EnzymeLiftControlFlowToSCFPassBase<
EnzymeLiftControlFlowToSCF> {

using EnzymeLiftControlFlowToSCFPassBase::EnzymeLiftControlFlowToSCFPassBase;

void runOnOperation() override {
ControlFlowToSCFTransformation transformation;

bool changed = false;
Operation *op = getOperation();
WalkResult result = op->walk([&](Region *region) {
if (region->empty())
return WalkResult::advance();

Operation *regionParent = region->getParentOp();
auto &domInfo = regionParent != op
? getChildAnalysis<DominanceInfo>(regionParent)
: getAnalysis<DominanceInfo>();

auto visitor = [&](Operation *innerOp) -> WalkResult {
for (Region &reg : innerOp->getRegions()) {
FailureOr<bool> changedFunc =
transformCFGToSCF(reg, transformation, domInfo);
if (failed(changedFunc))
return WalkResult::interrupt();

changed |= *changedFunc;
}
return WalkResult::advance();
};

if (region->walk<WalkOrder::PostOrder>(visitor).wasInterrupted())
return WalkResult::interrupt();

return WalkResult::advance();
});
if (result.wasInterrupted())
return signalPassFailure();

if (!changed)
markAllAnalysesPreserved();
}
};
} // namespace
31 changes: 31 additions & 0 deletions src/enzyme_ad/jax/Passes/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -320,4 +320,35 @@ def ConvertLLVMToControlFlowPass : Pass<"convert-llvm-to-cf"> {
let dependentDialects = ["cf::ControlFlowDialect"];
}

//===----------------------------------------------------------------------===//
// ControlFlowToSCF
//===----------------------------------------------------------------------===//

def EnzymeLiftControlFlowToSCFPass : Pass<"enzyme-lift-cf-to-scf"> {
let summary = "Lift ControlFlow dialect to SCF dialect";
let description = [{
Lifts ControlFlow operations to SCF dialect operations.

This pass is prefixed with "lift" instead of "convert" as it is not always
guaranteed to replace all ControlFlow ops.
If a region contains only a single kind of return-like operation, all
ControlFlow operations will be replaced successfully.
Otherwise a single ControlFlow switch branching to one block per return-like
operation kind remains.

This pass may need to create unreachable terminators in case of infinite
loops, which is only supported for 'func.func' for now. If you potentially
have infinite loops inside CFG regions not belonging to 'func.func',
consider using `transformCFGToSCF` function directly with corresponding
`CFGToSCFInterface::createUnreachableTerminator` implementation.
}];

let dependentDialects = ["scf::SCFDialect",
"arith::ArithDialect",
"ub::UBDialect",
// TODO: This is only necessary until we have a
// ub.unreachable op.
"func::FuncDialect"];
}

#endif
1 change: 1 addition & 0 deletions src/enzyme_ad/jax/enzymexlamlir-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ int main(int argc, char **argv) {
mlir::affine::registerAffinePasses();
mlir::registerReconcileUnrealizedCasts();
mlir::enzyme::registerConvertLLVMToControlFlowPass();
mlir::enzyme::registerEnzymeLiftControlFlowToSCFPass();

registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) {
LLVM::LLVMFunctionType::attachInterface<MemRefInsider>(*ctx);
Expand Down
Loading

0 comments on commit 90d8cce

Please sign in to comment.