From 72441096536b14812432995fbd0c6e3d88842e8d Mon Sep 17 00:00:00 2001 From: "Oleksandr \"Alex\" Zinenko" Date: Fri, 3 Jan 2025 10:14:08 -0800 Subject: [PATCH] Transform interpreter runnable from within a function (#211) - Modify the TD generation pass to emit the transform module if requested by a flag. - Introduce a separate interpreter pass that moves the transform script into a separate module to allow it to be applied to the pass anchor operation, and then remove it entirely. --- src/enzyme_ad/jax/BUILD | 2 + .../jax/Passes/ConsumingInterpreterPass.cpp | 59 +++++++++++++++++++ src/enzyme_ad/jax/Passes/Passes.h | 2 + src/enzyme_ad/jax/Passes/Passes.td | 9 +++ .../TransformOps/GenerateApplyPatterns.cpp | 25 ++++++-- test/lit_tests/transform_in_function.mlir | 26 ++++++++ 6 files changed, 119 insertions(+), 4 deletions(-) create mode 100644 src/enzyme_ad/jax/Passes/ConsumingInterpreterPass.cpp create mode 100644 test/lit_tests/transform_in_function.mlir diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD index 4179f88e..435ac13e 100644 --- a/src/enzyme_ad/jax/BUILD +++ b/src/enzyme_ad/jax/BUILD @@ -326,6 +326,8 @@ cc_library( "@llvm-project//mlir:AffineToStandard", "@llvm-project//mlir:ArithToLLVM", "@llvm-project//mlir:MemRefTransforms", + "@llvm-project//mlir:TransformDialect", + "@llvm-project//mlir:TransformDialectTransforms", "@llvm-project//mlir:GPUToGPURuntimeTransforms", "@llvm-project//mlir:GPUCommonTransforms", "@llvm-project//mlir:GPUToNVVMTransforms", diff --git a/src/enzyme_ad/jax/Passes/ConsumingInterpreterPass.cpp b/src/enzyme_ad/jax/Passes/ConsumingInterpreterPass.cpp new file mode 100644 index 00000000..735ea847 --- /dev/null +++ b/src/enzyme_ad/jax/Passes/ConsumingInterpreterPass.cpp @@ -0,0 +1,59 @@ +//===- ConsumingInterpreterPass.cpp - Interpret and remove transforms -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
+#include "mlir/Pass/Pass.h"
+#include "src/enzyme_ad/jax/Passes/PassDetails.h"
+#include "src/enzyme_ad/jax/Passes/Passes.h"
+
+using namespace mlir;
+using namespace mlir::enzyme;
+
+namespace {
+class ConsumingInterpreterPass
+    : public ConsumingInterpreterPassBase<ConsumingInterpreterPass> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConsumingInterpreterPass)
+
+  StringRef getArgument() const override {
+    return "enzyme-consuming-transform-interpreter";
+  }
+  // Applies the transform script nested in the anchor op, then discards it.
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    Operation *entryPoint =
+        transform::detail::findTransformEntryPoint(op, nullptr);
+    if (!entryPoint)
+      return signalPassFailure();
+    // The module holding the script is erased below; reject the anchor op.
+    auto transformModule = dyn_cast<ModuleOp>(entryPoint->getParentOp());
+    if (!transformModule || transformModule.getOperation() == op) {
+      emitError(entryPoint->getLoc())
+          << "expected the transform entry point to be in a nested module";
+      return signalPassFailure();
+    }
+    // Detach the script from the payload; the owning ref erases it on return.
+    transformModule->remove();
+    OwningOpRef<ModuleOp> owningTransformModule(transformModule);
+    // The anchor operation is the sole payload bound to the entry point.
+    RaggedArray<transform::MappedValue> bindings;
+    bindings.push_back(ArrayRef<Operation *>{op});
+    if (failed(transform::applyTransformNamedSequence(
+            bindings, cast<transform::TransformOpInterface>(entryPoint),
+            *owningTransformModule,
+            transform::TransformOptions().enableExpensiveChecks(true))))
+      return signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::enzyme::createConsumingInterpreterPass() {
+  return std::make_unique<ConsumingInterpreterPass>();
+}
diff --git a/src/enzyme_ad/jax/Passes/Passes.h b/src/enzyme_ad/jax/Passes/Passes.h
index 8693981c..9bef7dda 100644
--- a/src/enzyme_ad/jax/Passes/Passes.h
+++ b/src/enzyme_ad/jax/Passes/Passes.h
@@ -18,6 +18,7 @@ class RewritePatternSet;
 class DominanceInfo;
 namespace enzyme {
 std::unique_ptr
createArithRaisingPass(); +std::unique_ptr createConsumingInterpreterPass(); std::unique_ptr createEnzymeHLOOptPass(); std::unique_ptr createEnzymeHLOUnrollPass(); std::unique_ptr createPrintPass(); @@ -111,5 +112,6 @@ static void regsiterenzymeXLAPasses() { registerEnzymeHLOOptPass(); registerEnzymeHLOUnrollPass(); registerLowerKernelPass(); + registerConsumingInterpreterPass(); } #endif // ENZYMEXLA_PASSES_H diff --git a/src/enzyme_ad/jax/Passes/Passes.td b/src/enzyme_ad/jax/Passes/Passes.td index d1ce5c9a..d82ba89f 100644 --- a/src/enzyme_ad/jax/Passes/Passes.td +++ b/src/enzyme_ad/jax/Passes/Passes.td @@ -32,6 +32,15 @@ def ArithRaisingPass : Pass<"arith-raise"> { ]; } +def ConsumingInterpreterPass : Pass<"enzyme-consuming-transform-interpreter"> { + let summary = "Run the transform interpreter and remove the script"; + let constructor = "mlir::enzyme::createConsumingInterpreterPass()"; + let description = [{ + This pass isolates the transform script in a separate module, making it + possible to apply the script to the anchor operation of the pass. + }]; +} + def EnzymeHLOOptPass : Pass<"enzyme-hlo-opt"> { let summary = "Optimize stablehlo"; let dependentDialects = [ diff --git a/src/enzyme_ad/jax/TransformOps/GenerateApplyPatterns.cpp b/src/enzyme_ad/jax/TransformOps/GenerateApplyPatterns.cpp index 5cc7d1f7..e2a0138a 100644 --- a/src/enzyme_ad/jax/TransformOps/GenerateApplyPatterns.cpp +++ b/src/enzyme_ad/jax/TransformOps/GenerateApplyPatterns.cpp @@ -126,8 +126,20 @@ LogicalResult parseTransform(OpBuilder &builder, Location loc, } } - OperationState state(loc, - "transform.apply_patterns.enzyme_hlo." + opName.str()); + std::string potentialOpName = + "transform.apply_patterns.enzyme_hlo." + opName.str(); + if (!RegisteredOperationName::lookup(potentialOpName, + builder.getContext())) { + potentialOpName = "transform.apply_patterns." 
+ opName.str(); + if (!RegisteredOperationName::lookup(potentialOpName, + builder.getContext())) { + return ::emitError(loc) + << "couldn't find a pattern operation corresponding to " + << opName; + } + } + + OperationState state(loc, potentialOpName); if (benefit != 1) state.addAttribute("benefit", builder.getI64IntegerAttr(benefit)); if (parameter != -1) @@ -166,11 +178,15 @@ class GenerateApplyPatternsPass } OpBuilder builder(&getContext()); + builder.setInsertionPointToStart(&op->getRegion(0).front()); + if (createModule) { + auto transformModule = builder.create(op->getLoc()); + op = transformModule; + builder.setInsertionPointToStart(&op->getRegion(0).front()); + } op->setAttr(transform::TransformDialect::kWithNamedSequenceAttrName, builder.getUnitAttr()); - builder.setInsertionPointToStart(&op->getRegion(0).front()); - if (!flags.empty()) { llvm::APInt version( llvm::APInt::getSufficientBitsNeeded(flags.getValue(), radix) + 1, @@ -186,6 +202,7 @@ class GenerateApplyPatternsPass Option flags{*this, "flags", llvm::cl::init("")}; Option radix{*this, "radix", llvm::cl::init(10)}; Option patterns{*this, "patterns", llvm::cl::init("")}; + Option createModule{*this, "create-module", llvm::cl::init(false)}; }; class RemoveTransform : public PassWrapper> { diff --git a/test/lit_tests/transform_in_function.mlir b/test/lit_tests/transform_in_function.mlir new file mode 100644 index 00000000..7406823d --- /dev/null +++ b/test/lit_tests/transform_in_function.mlir @@ -0,0 +1,26 @@ +// RUN: enzymexlamlir-opt --enzyme='postpasses=enzyme-hlo-generate-td{create-module=true patterns=canonicalization<1>}' %s | FileCheck %s --check-prefixes=CHECK,GENERATE +// RUN: enzymexlamlir-opt --enzyme='postpasses=enzyme-hlo-generate-td{create-module=true patterns=canonicalization<1>},enzyme-consuming-transform-interpreter' %s | FileCheck %s --check-prefixes=CHECK,INTERP + +// CHECK-LABEL: @square +func.func @square(%x: complex) -> complex { + %next = complex.mul %x, %x : complex + return 
%next : complex +} + +// CHECK-LABEL: @dsquare +func.func @dsquare(%x: complex, %dx: complex) -> complex { + // CHECK: call @fwddiffesquare + %r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme], ret_activity=[#enzyme] } : (complex, complex) -> complex + return %r : complex +} + +// CHECK: func private @fwddiffesquare +// GENERATE: module attributes {transform.with_named_sequence} +// GENERATE: transform.named_sequence @__transform_main +// GENERATE: transform.apply_patterns +// GENERATE: transform.apply_patterns.canonicalization +// CHECK: complex.mul +// CHECK: complex.mul +// CHECK: complex.add +// GENERATE: complex.mul +// INTERP-NOT: complex.mul