Skip to content

Commit

Permalink
Add LLVM to CF raising pass
Browse files Browse the repository at this point in the history
  • Loading branch information
BuildKite committed Jan 27, 2025
1 parent 30afdff commit 41cc6a5
Show file tree
Hide file tree
Showing 5 changed files with 163 additions and 0 deletions.
107 changes: 107 additions & 0 deletions src/enzyme_ad/jax/Passes/LLVMToControlFlow.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
//===- LLVMToControlFlow.cpp - ControlFlow to LLVM dialect conversion -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert MLIR standard and builtin dialects
// into the LLVM IR dialect.
//
//===----------------------------------------------------------------------===//

#include "Passes.h"

#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <functional>

namespace mlir {
namespace enzyme {
#define GEN_PASS_DEF_CONVERTLLVMTOCONTROLFLOWPASS
#include "src/enzyme_ad/jax/Passes/Passes.h.inc"
} // namespace enzyme
} // namespace mlir

using namespace mlir;
using namespace mlir::enzyme;

#define PASS_NAME "convert-llvm-to-cf"

namespace {

struct BranchOpLifting : public OpRewritePattern<LLVM::BrOp> {
using OpRewritePattern<LLVM::BrOp>::OpRewritePattern;
LogicalResult matchAndRewrite(LLVM::BrOp op,
PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<cf::BranchOp>(op, op.getDest(),
op.getDestOperands());
return success();
}
};

struct CondBranchOpLifting : public OpRewritePattern<LLVM::CondBrOp> {
using OpRewritePattern<LLVM::CondBrOp>::OpRewritePattern;
LogicalResult matchAndRewrite(LLVM::CondBrOp op,
PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<cf::CondBranchOp>(
op, op.getCondition(), op.getTrueDest(), op.getTrueDestOperands(),
op.getFalseDest(), op.getFalseDestOperands());
return success();
}
};

struct SwitchOpLifting : public OpRewritePattern<LLVM::SwitchOp> {
using OpRewritePattern<LLVM::SwitchOp>::OpRewritePattern;
LogicalResult matchAndRewrite(LLVM::SwitchOp op,
PatternRewriter &rewriter) const override {
SmallVector<APInt> caseValues;
SmallVector<ValueRange> caseOperands;
if (auto cvs = op.getCaseValues())
for (auto val : *cvs)
caseValues.push_back(val);
for (auto val : op.getCaseOperands())
caseOperands.push_back(val);
rewriter.replaceOpWithNewOp<cf::SwitchOp>(
op, op.getValue(), op.getDefaultDestination(), op.getDefaultOperands(),
caseValues, op.getCaseDestinations(), caseOperands);
return success();
}
};

} // namespace

void mlir::cf::populateLLVMToControlFlowConversionPatterns(
RewritePatternSet &patterns) {
// clang-format off
patterns.add<
BranchOpLifting,
CondBranchOpLifting,
SwitchOpLifting>(patterns.getContext());
// clang-format on
}

//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//

namespace {
/// A pass converting MLIR operations into the LLVM IR dialect.
struct ConvertLLVMToControlFlow
: public enzyme::impl::ConvertLLVMToControlFlowPassBase<
ConvertLLVMToControlFlow> {
using ConvertLLVMToControlFlowPassBase::ConvertLLVMToControlFlowPassBase;
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
mlir::cf::populateLLVMToControlFlowConversionPatterns(patterns);
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
signalPassFailure();
}
};
} // namespace
5 changes: 5 additions & 0 deletions src/enzyme_ad/jax/Passes/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ void populateLibDeviceFuncsToOpsPatterns(MLIRContext *context,
RewritePatternSet &patterns);

} // namespace enzyme

namespace cf {
void populateLLVMToControlFlowConversionPatterns(RewritePatternSet &patterns);
} // namespace cf

} // end namespace mlir

#endif // ENZYMEXLA_PASSES_H
9 changes: 9 additions & 0 deletions src/enzyme_ad/jax/Passes/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -295,4 +295,13 @@ def LowerKernelPass : Pass<"lower-kernel"> {
];
}

//===----------------------------------------------------------------------===//
// LLVMToControlFlow
//===----------------------------------------------------------------------===//

def ConvertLLVMToControlFlowPass : Pass<"convert-llvm-to-cf"> {
let summary = "Convert LLVM cf operations to the ControlFlow dialect";
let dependentDialects = ["cf::ControlFlowDialect"];
}

#endif
1 change: 1 addition & 0 deletions src/enzyme_ad/jax/enzymexlamlir-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ int main(int argc, char **argv) {
mlir::registerConvertSCFToOpenMPPass();
mlir::affine::registerAffinePasses();
mlir::registerReconcileUnrealizedCasts();
mlir::enzyme::registerConvertLLVMToControlFlowPass();

registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) {
LLVM::LLVMFunctionType::attachInterface<MemRefInsider>(*ctx);
Expand Down
41 changes: 41 additions & 0 deletions test/lit_tests/raising/llvm_to_cf.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// RUN: enzymexlamlir-opt %s --pass-pipeline="builtin.module(convert-llvm-to-cf)" | FileCheck %s

// CHECK-LABEL: func @test_br
func.func @test_br() {
// CHECK: cf.br ^bb1
llvm.br ^bb1
^bb1:
return
}

// CHECK-LABEL: func @test_cond_br
func.func @test_cond_br(%cond: i1) {
// CHECK: cf.cond_br %arg0, ^bb1, ^bb2
llvm.cond_br %cond, ^bb1, ^bb2
^bb1:
return
^bb2:
llvm.unreachable
}

llvm.func internal unnamed_addr fastcc @throw_boundserror_2676() attributes {dso_local, no_inline, sym_visibility = "private"} {
llvm.unreachable
}
// CHECK-LABEL: func @test_switch
func.func @test_switch(%val: i32) {
// CHECK: cf.switch %arg0 : i32, [
// CHECK-NEXT: default: ^bb3,
// CHECK-NEXT: 0: ^bb1,
// CHECK-NEXT: 1: ^bb2
llvm.switch %val : i32, ^bb3 [
0: ^bb1,
1: ^bb2
]
^bb1:
return
^bb2:
llvm.unreachable
^bb3:
llvm.call fastcc @throw_boundserror_2676() : () -> ()
llvm.unreachable
}

0 comments on commit 41cc6a5

Please sign in to comment.