Skip to content

Commit

Permalink
Transform interpreter runnable from within a function (#211)
Browse files Browse the repository at this point in the history
- Modify the TD generation pass to emit the transform module if requested by a
  flag.
- Introduce a separate interpreter pass that moves the transform script into a
  separate module to allow it to be applied to the pass anchor operation, and
  then remove it entirely.
  • Loading branch information
ftynse authored Jan 3, 2025
1 parent 288bbfa commit 7244109
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 4 deletions.
2 changes: 2 additions & 0 deletions src/enzyme_ad/jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,8 @@ cc_library(
"@llvm-project//mlir:AffineToStandard",
"@llvm-project//mlir:ArithToLLVM",
"@llvm-project//mlir:MemRefTransforms",
"@llvm-project//mlir:TransformDialect",
"@llvm-project//mlir:TransformDialectTransforms",
"@llvm-project//mlir:GPUToGPURuntimeTransforms",
"@llvm-project//mlir:GPUCommonTransforms",
"@llvm-project//mlir:GPUToNVVMTransforms",
Expand Down
59 changes: 59 additions & 0 deletions src/enzyme_ad/jax/Passes/ConsumingInterpreterPass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
//===- ConsumingInterpreterPass.cpp - Interpret and remove transforms -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
#include "mlir/Pass/Pass.h"
#include "src/enzyme_ad/jax/Passes/PassDetails.h"
#include "src/enzyme_ad/jax/Passes/Passes.h"

using namespace mlir;
using namespace mlir::enzyme;

namespace {
class ConsumingInterpreterPass
: public ConsumingInterpreterPassBase<ConsumingInterpreterPass> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConsumingInterpreterPass)

StringRef getArgument() const override {
return "enzyme-consuming-transform-interpreter";
}

void runOnOperation() override {
Operation *op = getOperation();
Operation *entryPoint =
transform::detail::findTransformEntryPoint(op, nullptr);
if (!entryPoint)
return signalPassFailure();

auto transformModule = dyn_cast<ModuleOp>(entryPoint->getParentOp());
if (!transformModule) {
emitError(entryPoint->getLoc())
<< "expected the transform entry point to be located in a module";
return signalPassFailure();
}

transformModule->remove();
OwningOpRef<ModuleOp> owningTransformModule(transformModule);

RaggedArray<transform::MappedValue> bindings;
bindings.push_back(ArrayRef<Operation *>{op});
if (failed(transform::applyTransformNamedSequence(
bindings, cast<transform::TransformOpInterface>(entryPoint),
*owningTransformModule,
transform::TransformOptions().enableExpensiveChecks(true))))
return signalPassFailure();
}
};
} // namespace

std::unique_ptr<Pass> mlir::enzyme::createConsumingInterpreterPass() {
return std::make_unique<ConsumingInterpreterPass>();
}
2 changes: 2 additions & 0 deletions src/enzyme_ad/jax/Passes/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class RewritePatternSet;
class DominanceInfo;
namespace enzyme {
std::unique_ptr<Pass> createArithRaisingPass();
std::unique_ptr<Pass> createConsumingInterpreterPass();
std::unique_ptr<Pass> createEnzymeHLOOptPass();
std::unique_ptr<Pass> createEnzymeHLOUnrollPass();
std::unique_ptr<Pass> createPrintPass();
Expand Down Expand Up @@ -111,5 +112,6 @@ static void regsiterenzymeXLAPasses() {
registerEnzymeHLOOptPass();
registerEnzymeHLOUnrollPass();
registerLowerKernelPass();
registerConsumingInterpreterPass();
}
#endif // ENZYMEXLA_PASSES_H
9 changes: 9 additions & 0 deletions src/enzyme_ad/jax/Passes/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@ def ArithRaisingPass : Pass<"arith-raise"> {
];
}

def ConsumingInterpreterPass : Pass<"enzyme-consuming-transform-interpreter"> {
let summary = "Run the transform interpreter and remove the script";
let constructor = "mlir::enzyme::createConsumingInterpreterPass()";
let description = [{
This pass isolates the transform script in a separate module, making it
possible to apply the script to the anchor operation of the pass.
}];
}

def EnzymeHLOOptPass : Pass<"enzyme-hlo-opt"> {
let summary = "Optimize stablehlo";
let dependentDialects = [
Expand Down
25 changes: 21 additions & 4 deletions src/enzyme_ad/jax/TransformOps/GenerateApplyPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,20 @@ LogicalResult parseTransform(OpBuilder &builder, Location loc,
}
}

OperationState state(loc,
"transform.apply_patterns.enzyme_hlo." + opName.str());
std::string potentialOpName =
"transform.apply_patterns.enzyme_hlo." + opName.str();
if (!RegisteredOperationName::lookup(potentialOpName,
builder.getContext())) {
potentialOpName = "transform.apply_patterns." + opName.str();
if (!RegisteredOperationName::lookup(potentialOpName,
builder.getContext())) {
return ::emitError(loc)
<< "couldn't find a pattern operation corresponding to "
<< opName;
}
}

OperationState state(loc, potentialOpName);
if (benefit != 1)
state.addAttribute("benefit", builder.getI64IntegerAttr(benefit));
if (parameter != -1)
Expand Down Expand Up @@ -166,11 +178,15 @@ class GenerateApplyPatternsPass
}

OpBuilder builder(&getContext());
builder.setInsertionPointToStart(&op->getRegion(0).front());
if (createModule) {
auto transformModule = builder.create<ModuleOp>(op->getLoc());
op = transformModule;
builder.setInsertionPointToStart(&op->getRegion(0).front());
}
op->setAttr(transform::TransformDialect::kWithNamedSequenceAttrName,
builder.getUnitAttr());

builder.setInsertionPointToStart(&op->getRegion(0).front());

if (!flags.empty()) {
llvm::APInt version(
llvm::APInt::getSufficientBitsNeeded(flags.getValue(), radix) + 1,
Expand All @@ -186,6 +202,7 @@ class GenerateApplyPatternsPass
Option<std::string> flags{*this, "flags", llvm::cl::init("")};
Option<int> radix{*this, "radix", llvm::cl::init(10)};
Option<std::string> patterns{*this, "patterns", llvm::cl::init("")};
Option<bool> createModule{*this, "create-module", llvm::cl::init(false)};
};

class RemoveTransform : public PassWrapper<RemoveTransform, OperationPass<>> {
Expand Down
26 changes: 26 additions & 0 deletions test/lit_tests/transform_in_function.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// RUN: enzymexlamlir-opt --enzyme='postpasses=enzyme-hlo-generate-td{create-module=true patterns=canonicalization<1>}' %s | FileCheck %s --check-prefixes=CHECK,GENERATE
// RUN: enzymexlamlir-opt --enzyme='postpasses=enzyme-hlo-generate-td{create-module=true patterns=canonicalization<1>},enzyme-consuming-transform-interpreter' %s | FileCheck %s --check-prefixes=CHECK,INTERP

// CHECK-LABEL: @square
func.func @square(%x: complex<f64>) -> complex<f64> {
%next = complex.mul %x, %x : complex<f64>
return %next : complex<f64>
}

// CHECK-LABEL: @dsquare
func.func @dsquare(%x: complex<f64>, %dx: complex<f64>) -> complex<f64> {
// CHECK: call @fwddiffesquare
%r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>] } : (complex<f64>, complex<f64>) -> complex<f64>
return %r : complex<f64>
}

// CHECK: func private @fwddiffesquare
// GENERATE: module attributes {transform.with_named_sequence}
// GENERATE: transform.named_sequence @__transform_main
// GENERATE: transform.apply_patterns
// GENERATE: transform.apply_patterns.canonicalization
// CHECK: complex.mul
// CHECK: complex.mul
// CHECK: complex.add
// GENERATE: complex.mul
// INTERP-NOT: complex.mul

0 comments on commit 7244109

Please sign in to comment.