-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Transform interpreter runnable from within a function
- Modify the TD generation pass to emit the transform module if requested by a flag. - Introduce a separate interpreter pass that moves the transform script into a separate module to allow it to be applied to the pass anchor operation, and then remove it entirely.
- Loading branch information
Showing
6 changed files
with
119 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
//===- ConsumingInterpreterPass.cpp - Interpret and remove transforms -----===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/Dialect/Transform/IR/TransformDialect.h" | ||
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" | ||
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "src/enzyme_ad/jax/Passes/PassDetails.h" | ||
#include "src/enzyme_ad/jax/Passes/Passes.h" | ||
|
||
using namespace mlir; | ||
using namespace mlir::enzyme; | ||
|
||
namespace { | ||
class ConsumingInterpreterPass | ||
: public ConsumingInterpreterPassBase<ConsumingInterpreterPass> { | ||
public: | ||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConsumingInterpreterPass) | ||
|
||
StringRef getArgument() const override { | ||
return "enzyme-consuming-transform-interpreter"; | ||
} | ||
|
||
void runOnOperation() override { | ||
Operation *op = getOperation(); | ||
Operation *entryPoint = | ||
transform::detail::findTransformEntryPoint(op, nullptr); | ||
if (!entryPoint) | ||
return signalPassFailure(); | ||
|
||
auto transformModule = dyn_cast<ModuleOp>(entryPoint->getParentOp()); | ||
if (!transformModule) { | ||
emitError(entryPoint->getLoc()) | ||
<< "expected the transform entry point to be located in a module"; | ||
return signalPassFailure(); | ||
} | ||
|
||
transformModule->remove(); | ||
OwningOpRef<ModuleOp> owningTransformModule(transformModule); | ||
|
||
RaggedArray<transform::MappedValue> bindings; | ||
bindings.push_back(ArrayRef<Operation *>{op}); | ||
if (failed(transform::applyTransformNamedSequence( | ||
bindings, cast<transform::TransformOpInterface>(entryPoint), | ||
*owningTransformModule, | ||
transform::TransformOptions().enableExpensiveChecks(true)))) | ||
return signalPassFailure(); | ||
} | ||
}; | ||
} // namespace | ||
|
||
std::unique_ptr<Pass> mlir::enzyme::createConsumingInterpreterPass() { | ||
return std::make_unique<ConsumingInterpreterPass>(); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
// RUN: enzymexlamlir-opt --enzyme='postpasses=enzyme-hlo-generate-td{create-module=true patterns=canonicalization<1>}' %s | FileCheck %s --check-prefixes=CHECK,GENERATE | ||
// RUN: enzymexlamlir-opt --enzyme='postpasses=enzyme-hlo-generate-td{create-module=true patterns=canonicalization<1>},enzyme-consuming-transform-interpreter' %s | FileCheck %s --check-prefixes=CHECK,INTERP | ||
|
||
// CHECK-LABEL: @square | ||
func.func @square(%x: complex<f64>) -> complex<f64> { | ||
%next = complex.mul %x, %x : complex<f64> | ||
return %next : complex<f64> | ||
} | ||
|
||
// CHECK-LABEL: @dsquare | ||
func.func @dsquare(%x: complex<f64>, %dx: complex<f64>) -> complex<f64> { | ||
// CHECK: call @fwddiffesquare | ||
%r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>] } : (complex<f64>, complex<f64>) -> complex<f64> | ||
return %r : complex<f64> | ||
} | ||
|
||
// CHECK: func private @fwddiffesquare | ||
// GENERATE: module attributes {transform.with_named_sequence} | ||
// GENERATE: transform.named_sequence @__transform_main | ||
// GENERATE: transform.apply_patterns | ||
// GENERATE: transform.apply_patterns.canonicalization | ||
// CHECK: complex.mul | ||
// CHECK: complex.mul | ||
// CHECK: complex.add | ||
// GENERATE: complex.mul | ||
// INTERP-NOT: complex.mul |