From 5f74a835a5c895cf93b7a9626d22564fdfbb1592 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Tue, 19 Mar 2024 22:25:51 +0000 Subject: [PATCH] transform dialect for pattern combination --- BUILD | 13 ++ src/enzyme_ad/jax/BUILD | 88 ++++++++++ src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 26 +-- src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h | 11 ++ .../TransformOps/GenerateApplyPatterns.cpp | 145 ++++++++++++++++ .../jax/TransformOps/TransformOps.cpp | 63 +++++++ src/enzyme_ad/jax/TransformOps/TransformOps.h | 26 +++ .../jax/TransformOps/TransformOps.td | 155 ++++++++++++++++++ src/enzyme_ad/jax/compile_with_xla.cc | 4 + src/enzyme_ad/jax/enzyme_call.cc | 6 +- src/enzyme_ad/jax/enzymexlamlir-opt.cpp | 34 ++-- src/enzyme_ad/tools/BUILD | 8 + src/enzyme_ad/tools/enzymexlamlir-tblgen.cpp | 110 +++++++++++++ test/lit_tests/generate_td.mlir | 39 +++++ 14 files changed, 706 insertions(+), 22 deletions(-) create mode 100644 src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h create mode 100644 src/enzyme_ad/jax/TransformOps/GenerateApplyPatterns.cpp create mode 100644 src/enzyme_ad/jax/TransformOps/TransformOps.cpp create mode 100644 src/enzyme_ad/jax/TransformOps/TransformOps.h create mode 100644 src/enzyme_ad/jax/TransformOps/TransformOps.td create mode 100644 src/enzyme_ad/tools/BUILD create mode 100644 src/enzyme_ad/tools/enzymexlamlir-tblgen.cpp create mode 100644 test/lit_tests/generate_td.mlir diff --git a/BUILD b/BUILD index b6c12783f..789e4394c 100644 --- a/BUILD +++ b/BUILD @@ -27,6 +27,7 @@ cc_binary( visibility = ["//visibility:public"], deps = [ "//src/enzyme_ad/jax:XLADerivatives", + "//src/enzyme_ad/jax:TransformOps", "@enzyme//:EnzymeMLIR", "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:AllPassesAndDialects", @@ -47,6 +48,18 @@ cc_binary( "@llvm-project//mlir:Pass", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:TransformDialect", + ], +) + +cc_binary( + name = "enzymexlamlir-tblgen", + srcs = ["//src/enzyme_ad/tools:enzymexlamlir-tblgen.cpp"], + visibility = ["//visibility:public"], + deps = [ + "@llvm-project//llvm:Support", + "@llvm-project//llvm:TableGen", + "@llvm-project//llvm:config", ], ) diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD index a403b9f29..7653004c9 100644 --- a/src/enzyme_ad/jax/BUILD +++ b/src/enzyme_ad/jax/BUILD @@ -56,6 +56,68 @@ symlink_inputs( }}, ) +td_library( + name = "TransformOpsTdFiles", + srcs = [ + "TransformOps/TransformOps.td", + ], + deps = [ + "@llvm-project//mlir:TransformDialectTdFiles", + ] +) + +gentbl_cc_library( + name = "TransformOpsIncGen", + tbl_outs = [( + ["-gen-op-decls"], + "TransformOps/TransformOps.h.inc", + ), ( + ["-gen-op-defs"], + "TransformOps/TransformOps.cpp.inc", + ), ( + ["-gen-op-interface-decls"], + "TransformOps/OpInterfaces.h.inc", + ), ( + ["-gen-op-interface-defs"], + "TransformOps/OpInterfaces.cpp.inc", + ), + ], + td_file = "TransformOps/TransformOps.td", + deps = [ + ":TransformOpsTdFiles", + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", +) + +gentbl_cc_library( + name = "TransformOpsImplIncGen", + tbl_outs = [( + ["-gen-populate-patterns-interface-impl"], + "TransformOps/TransformOpsImpl.cpp.inc" + )], + td_file = "TransformOps/TransformOps.td", + deps = [ + ":TransformOpsTdFiles", + ], + tblgen = "//:enzymexlamlir-tblgen", +) + +cc_library( + name = "TransformOps", + srcs = glob(["TransformOps/*.cpp"]), + hdrs = glob(["TransformOps/*.h"]), + deps = [ + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LinalgTransformOps", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:TransformDialect", + ":TransformOpsIncGen", + ":TransformOpsImplIncGen", + ":XLADerivatives", + ], +) + td_library( name = "ImplementationsCommonTdFiles", srcs = [ @@ -127,6 +189,23 @@ gentbl_cc_library( deps = [":EnzymeXLAPassesTdFiles"], ) +gentbl_cc_library( + name = "EnzyeHLOPatternsIncGen", + tbl_outs = [ + ( + ["-gen-populate-patterns-func-decls"], + "Passes/EnzymeHLOPatterns.h.inc", + ), ( + ["-gen-populate-patterns-func-defs"], + "Passes/EnzymeHLOPatterns.cpp.inc", + )], + td_file = "TransformOps/TransformOps.td", + deps = [ + ":TransformOpsTdFiles", + ], + tblgen = "//:enzymexlamlir-tblgen", +) + cc_library( name = "XLADerivatives", srcs = glob( @@ -147,15 +226,19 @@ cc_library( ], deps = [ ":EnzymeXLAPassesIncGen", + ":EnzyeHLOPatternsIncGen", ":mhlo-derivatives", ":stablehlo-derivatives", "@enzyme//:EnzymeMLIR", + "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:CommonFolders", "@llvm-project//mlir:ControlFlowInterfaces", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FunctionInterfaces", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:LLVMCommonConversion", "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:Transforms", @@ -176,6 +259,7 @@ pybind_library( ]), deps = [ ":XLADerivatives", + ":TransformOps", # This is similar to xla_binary rule and is needed to make XLA client compile. "@tsl//tsl/framework:allocator", "@tsl//tsl/framework:allocator_registry_impl", @@ -199,6 +283,7 @@ pybind_library( "@xla//xla/client", "@xla//xla/client:client_library", "@xla//xla/client:executable_build_options", + "@xla//xla/client:local_client", "@xla//xla/client:xla_computation", "@xla//xla/service", "@xla//xla/service:local_service", @@ -240,9 +325,11 @@ pybind_library( # MLIR dialects and parser. "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FuncExtensions", "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:LinalgTransformOps", "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", "@xla//xla/mlir_hlo:all_passes", @@ -264,6 +351,7 @@ pybind_extension( deps = [ ":clang_compile", ":compile_with_xla", + ":TransformOps", "@com_google_absl//absl/status:statusor", "@enzyme//:EnzymeMLIR", "@enzyme//:EnzymeStatic", diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index 486eb6087..1ae64fd95 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -10,23 +10,20 @@ // ops. //===----------------------------------------------------------------------===// +#include "mlir/Dialect/CommonFolders.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Builders.h" #include "mlir/IR/IRMapping.h" - +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h" #include "src/enzyme_ad/jax/Passes/PassDetails.h" #include "src/enzyme_ad/jax/Passes/Passes.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" - #include "stablehlo/dialect/StablehloOps.h" #include "stablehlo/reference/Ops.h" #include "stablehlo/transforms/Passes.h" - -#include "mlir/IR/Matchers.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -#include "mlir/Dialect/CommonFolders.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #define DEBUG_TYPE "enzyme" @@ -3538,7 +3535,16 @@ struct ReshapeToSlice : public OpRewritePattern { return success(); } }; +} // namespace + +#include "src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.cpp.inc" +void mlir::transform::addPadDotGeneral(RewritePatternSet &patterns, + bool postPad, MLIRContext &context) { + patterns.insert(postPad, &context); +} + +namespace { struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase { void runOnOperation() override { diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h b/src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h new file mode 100644 index 000000000..1426aa34b --- /dev/null +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h @@ -0,0 +1,11 @@ +namespace mlir { +class RewritePatternSet; +class MLIRContext; +} // namespace mlir + +#include "src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h.inc" + +namespace mlir::transform { +void addPadDotGeneral(RewritePatternSet &patterns, bool postPad, + MLIRContext &context); +} diff --git a/src/enzyme_ad/jax/TransformOps/GenerateApplyPatterns.cpp b/src/enzyme_ad/jax/TransformOps/GenerateApplyPatterns.cpp new file mode 100644 index 000000000..993b7c239 --- /dev/null +++ b/src/enzyme_ad/jax/TransformOps/GenerateApplyPatterns.cpp @@ -0,0 +1,145 @@ +//===- GenerateApplyPatterns.cpp - Generate transform scripts --------------- // +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/IR/TransformOps.h" +#include "mlir/Dialect/Transform/IR/TransformTypes.h" +#include "mlir/Pass/Pass.h" +#include "src/enzyme_ad/jax/TransformOps/TransformOps.h" + +using namespace mlir; + +struct OpConfig { + OperationName name; + DictionaryAttr attrs; +}; + +void generatePatternGroup(OpBuilder &builder, Location loc, Value root, + ArrayRef configurations, + llvm::APInt selectionBitmask) { + OpBuilder::InsertionGuard guard(builder); + auto apply = builder.create( + loc, root, [](OpBuilder &builder, Location loc) {}); + builder.setInsertionPointToStart(apply.getBody()); + for (auto &&[i, opConfig] : llvm::enumerate(configurations)) { + if (selectionBitmask.extractBits(/*numBits=*/1, /*bitPosition=*/i).isZero()) + continue; + OperationState state(loc, opConfig.name); + state.addAttributes(opConfig.attrs.getValue()); + builder.create(state); + } +} + +LogicalResult generateTransform(OpBuilder &builder, llvm::APInt version) { + auto loc = builder.getUnknownLoc(); + auto namedSequence = builder.create( + loc, "__transform_main", builder.getType(), + TypeRange(), [](OpBuilder &builder, Location loc, BlockArgument) { + builder.create(loc); + }); + + SmallVector opConfigurations; + for (StringRef name : mlir::enzyme::getTransformOperationNames()) { + std::optional opName = + RegisteredOperationName::lookup(name, builder.getContext()); + if (!opName) { + return namedSequence->emitError() << "unregistered pattern op '" << name + << "' listed for construction"; + } + auto *concept = + opName->getInterface(); + for (DictionaryAttr attrs : concept->getPossibleAttrCombinations(builder)) { + opConfigurations.push_back(OpConfig{*opName, attrs}); + } + } + + builder.setInsertionPointToStart(&namedSequence.getBody().front()); + auto match = builder.create( + loc, namedSequence.getBody().front().getArgument(0), + ArrayRef{func::FuncOp::getOperationName()}); + + auto configPow = llvm::APInt::getOneBitSet(opConfigurations.size() + 1, + opConfigurations.size()); + do { + llvm::APInt configuration = version.srem(configPow); + generatePatternGroup(builder, loc, match, opConfigurations, configuration); + version = version.sdiv(configPow); + } while (!version.isZero()); + return success(); +} + +namespace { +class GenerateApplyPatternsPass + : public PassWrapper> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GenerateApplyPatternsPass) + + GenerateApplyPatternsPass() = default; + GenerateApplyPatternsPass(const GenerateApplyPatternsPass &other) + : PassWrapper>(other) {} + + StringRef getArgument() const override { return "enzyme-hlo-generate-td"; } + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + Operation *op = getOperation(); + if (op->getNumRegions() != 1 || !llvm::hasSingleElement(op->getRegion(0))) { + op->emitError() + << "can only run on a single-region single-block operation"; + return signalPassFailure(); + } + + llvm::APInt version( + llvm::APInt::getSufficientBitsNeeded(flags.getValue(), radix), + flags.getValue(), radix); + + OpBuilder builder(&getContext()); + op->setAttr(transform::TransformDialect::kWithNamedSequenceAttrName, + builder.getUnitAttr()); + + builder.setInsertionPointToStart(&op->getRegion(0).front()); + if (failed(generateTransform(builder, version))) + return signalPassFailure(); + } + + Option flags{*this, "flags", llvm::cl::init("")}; + Option radix{*this, "radix", llvm::cl::init(10)}; +}; + +class RemoveTransform : public PassWrapper> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RemoveTransform) + + StringRef getArgument() const override { + return "enzyme-hlo-remove-transform"; + } + + void runOnOperation() override { + getOperation()->walk([&](Operation *op) { + if (isa(op)) { + op->erase(); + return WalkResult::skip(); + } + return WalkResult::advance(); + }); + } +}; +} // namespace + +void mlir::enzyme::registerGenerateApplyPatternsPass() { + PassRegistration(); +} + +void mlir::enzyme::registerRemoveTransformPass() { + PassRegistration(); +} diff --git a/src/enzyme_ad/jax/TransformOps/TransformOps.cpp b/src/enzyme_ad/jax/TransformOps/TransformOps.cpp new file mode 100644 index 000000000..a891f04a8 --- /dev/null +++ b/src/enzyme_ad/jax/TransformOps/TransformOps.cpp @@ -0,0 +1,63 @@ +//===- TransformOps.cpp - Definition of transform extension ---------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "src/enzyme_ad/jax/TransformOps/TransformOps.h" + +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/PatternMatch.h" +#include "src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h" +#include "src/enzyme_ad/jax/TransformOps/OpInterfaces.cpp.inc" + +#define GET_OP_CLASSES +#include "src/enzyme_ad/jax/TransformOps/TransformOps.cpp.inc" +#include "src/enzyme_ad/jax/TransformOps/TransformOpsImpl.cpp.inc" + +using namespace mlir; +using namespace mlir::enzyme; + +namespace mlir { +namespace transform { + +void ApplyPadDotGeneralPatterns::populatePatterns(RewritePatternSet &patterns) { + addPadDotGeneral(patterns, getPostPad(), *getContext()); +} + +} // namespace transform +} // namespace mlir + +namespace { +class EnzymeJaxTransformExtension + : public transform::TransformDialectExtension { +public: + using Base::Base; + + void init() { + registerTransformOps< +#define GET_OP_LIST +#include "src/enzyme_ad/jax/TransformOps/TransformOps.cpp.inc" + >(); + } +}; +} // namespace + +void mlir::enzyme::registerEnzymeJaxTransformExtension( + DialectRegistry ®istry) { + registry.addExtensions(); +} + +template static SmallVector extractNames() { + return {OpType::getOperationName()...}; +} + +SmallVector mlir::enzyme::getTransformOperationNames() { + return extractNames< +#define GET_OP_LIST +#include "src/enzyme_ad/jax/TransformOps/TransformOps.cpp.inc" + >(); +} diff --git a/src/enzyme_ad/jax/TransformOps/TransformOps.h b/src/enzyme_ad/jax/TransformOps/TransformOps.h new file mode 100644 index 000000000..f134b1d94 --- /dev/null +++ b/src/enzyme_ad/jax/TransformOps/TransformOps.h @@ -0,0 +1,26 @@ +//===- TransformOps.h - Declarations of Transform extension -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "src/enzyme_ad/jax/TransformOps/OpInterfaces.h.inc" + +#define GET_OP_CLASSES +#include "src/enzyme_ad/jax/TransformOps/TransformOps.h.inc" + +namespace mlir { +namespace enzyme { +void registerEnzymeJaxTransformExtension(mlir::DialectRegistry ®istry); + +SmallVector getTransformOperationNames(); + +void registerGenerateApplyPatternsPass(); +void registerRemoveTransformPass(); +} // namespace enzyme +} // namespace mlir diff --git a/src/enzyme_ad/jax/TransformOps/TransformOps.td b/src/enzyme_ad/jax/TransformOps/TransformOps.td new file mode 100644 index 000000000..606be8bf7 --- /dev/null +++ b/src/enzyme_ad/jax/TransformOps/TransformOps.td @@ -0,0 +1,155 @@ +include "mlir/Dialect/Transform/IR/TransformInterfaces.td" +include "mlir/Dialect/Transform/IR/TransformDialect.td" + +def SearchablePatternDescriptorOpInterface : + OpInterface<"SearchablePatternDescriptorOpInterface", + [PatternDescriptorOpInterface]> { + let methods = [ + StaticInterfaceMethod< + [{Generates possible attribute combinations for opaque op construction.}], + "::llvm::SmallVector<::mlir::DictionaryAttr>", + "getPossibleAttrCombinations", + (ins "::mlir::Builder &":$builder), + /*methodBody=*/[{}], + /*defaultImplementation=*/[{ + return {builder.getDictionaryAttr({})}; + }]>, + ]; +} + +class EnzymeHLOPatternOp traits = []> + : Op, + DeclareOpInterfaceMethods], + traits)> { + list patterns = []; + let assemblyFormat = "attr-dict"; +} + +class EnzymeHLOParameterizedPatternOp traits = []> + : Op], + traits)> { +} + +def ApplyTransposeConcatReshapePatterns : EnzymeHLOPatternOp< + "transpose_concat_reshape"> { + let patterns = [ + "TransposeDotReorder", + "DotTranspose", + "ConvertConvertFloat", + "ConcatToPad", + "ConcatAppendingReshape", + "ReshapeIota", + "ReshapePad", + ]; +} + +def ApplyFullReduceReshapeOrTransposePatterns : EnzymeHLOPatternOp< + "full_reduce_reshape_or_transpose"> { + let patterns = [ + "FullReduceReshapeOrTranspose", + ]; +} + +def ApplySliceTransposeOrBroadcastPatterns : EnzymeHLOPatternOp< + "slice_transpose_or_broadcast"> { + let patterns = [ + "SliceTranspose", + "SliceBroadcast", + ]; +} + +def ApplyReducePadPatterns : EnzymeHLOPatternOp< + "reduce_pad"> { + let patterns = [ + "ReducePad", + ]; +} + +def ApplyZeroPadPatterns : EnzymeHLOPatternOp< + "zero_pad"> { + let patterns = [ + "MulZeroPad", + "DivZeroPad", + ]; +} + +def ApplyBinopConstPadPatterns : EnzymeHLOPatternOp< + "binop_const_pad"> { + let patterns = [ + "BinopConstPad", + "BinopConstPad", + "BinopConstPad", + "BinopConstPad", + ]; +} + +def ApplyBinopBinopPadPatterns : EnzymeHLOPatternOp< + "binop_binop_pad"> { + let patterns = [ + "BinopBinopPadPad", + "BinopBinopPadPad", + ]; +} + +def ApplyUnaryPushPadPatterns : EnzymeHLOPatternOp< + "unary_push_pad"> { + let patterns = [ + "UnaryPadPush", + "UnaryPadPush", + "UnaryPadPush", + ]; +} + +def ApplyTransposePadPatterns : EnzymeHLOPatternOp< + "transpose_pad"> { + let patterns = [ + "TransposePad", + ]; +} + +def ApplyTransposeConvertPatterns : EnzymeHLOPatternOp< + "transpose_convert"> { + let patterns = [ + "TransposeConvert", + ]; +} + +def ApplyTransposeTransposePatterns : EnzymeHLOPatternOp< + "transpose_transpose"> { + let patterns = [ + "TransposeTranspose", + ]; +} + +def ApplyBroadcastReducePatterns : EnzymeHLOPatternOp< + "broadcast_reduce"> { + let patterns = [ + "BroadcastReduce", + ]; +} + +def ApplyPadDotGeneralPatterns : EnzymeHLOParameterizedPatternOp< + "pad_dot_general"> { + let arguments = (ins BoolAttr:$postPad); + let assemblyFormat = "`postPad` `=` $postPad attr-dict"; + // TODO: the following can be automated by tablegen or some sort of + // lighter-weight introspection of searchable attributes. + let extraClassDeclaration = [{ + ::llvm::SmallVector<::mlir::DictionaryAttr> + static getPossibleAttrCombinations(::mlir::Builder &builder) { + return {builder.getDictionaryAttr( + builder.getNamedAttr("postPad", builder.getBoolAttr(true))), + builder.getDictionaryAttr( + builder.getNamedAttr("postPad", builder.getBoolAttr(false)))}; + } + }]; +} diff --git a/src/enzyme_ad/jax/compile_with_xla.cc b/src/enzyme_ad/jax/compile_with_xla.cc index ca7220149..7404a9321 100644 --- a/src/enzyme_ad/jax/compile_with_xla.cc +++ b/src/enzyme_ad/jax/compile_with_xla.cc @@ -18,6 +18,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Parser/Parser.h" #include "mlir/Pass/PassManager.h" @@ -43,6 +44,7 @@ #include "Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h" #include "Implementations/XLADerivatives.h" +#include "TransformOps/TransformOps.h" #include "mlir/Dialect/Func/Extensions/InlinerExtension.h" @@ -51,6 +53,8 @@ void prepareRegistry(mlir::DialectRegistry ®istry) { mlir::enzyme::registerCoreDialectAutodiffInterfaces(registry); mlir::enzyme::registerXLAAutoDiffInterfaces(registry); + mlir::linalg::registerTransformDialectExtension(registry); + mlir::enzyme::registerEnzymeJaxTransformExtension(registry); mlir::func::registerInlinerExtension(registry); } diff --git a/src/enzyme_ad/jax/enzyme_call.cc b/src/enzyme_ad/jax/enzyme_call.cc index 190c104a7..8afaf985b 100644 --- a/src/enzyme_ad/jax/enzyme_call.cc +++ b/src/enzyme_ad/jax/enzyme_call.cc @@ -6,6 +6,9 @@ // //===----------------------------------------------------------------------===// +#include "src/enzyme_ad/jax/Passes/Passes.h" +#include "src/enzyme_ad/jax/TransformOps/TransformOps.h" + #include #include #include @@ -56,7 +59,6 @@ #include "Enzyme/FunctionUtils.h" #include "Enzyme/MLIR/Passes/Passes.h" -#include "src/enzyme_ad/jax/Passes/Passes.h" #include "stablehlo/transforms/Passes.h" enum class ABI { Primal, Forward, Augmented, Reverse, Tape }; @@ -1027,6 +1029,8 @@ PYBIND11_MODULE(enzyme_call, m) { mlir::memref::registerMemRefPasses(); mlir::registerenzymePasses(); regsiterenzymeXLAPasses(); + mlir::enzyme::registerGenerateApplyPatternsPass(); + mlir::enzyme::registerRemoveTransformPass(); mlir::stablehlo::registerPasses(); pybind11::enum_(m, "Language") diff --git a/src/enzyme_ad/jax/enzymexlamlir-opt.cpp b/src/enzyme_ad/jax/enzymexlamlir-opt.cpp index 7d748d5b1..57f11e11f 100644 --- a/src/enzyme_ad/jax/enzymexlamlir-opt.cpp +++ b/src/enzyme_ad/jax/enzymexlamlir-opt.cpp @@ -11,6 +11,12 @@ // //===----------------------------------------------------------------------===// +#include "Enzyme/MLIR/Dialect/Dialect.h" +#include "Enzyme/MLIR/Dialect/Ops.h" +#include "Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h" +#include "Enzyme/MLIR/Passes/Passes.h" +#include "Implementations/XLADerivatives.h" +#include "Passes/Passes.h" #include "mlir/Conversion/Passes.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -23,30 +29,29 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Transform/Transforms/Passes.h" #include "mlir/InitAllPasses.h" #include "mlir/Pass/PassRegistry.h" #include "mlir/Tools/mlir-opt/MlirOptMain.h" #include "mlir/Transforms/Passes.h" - -#include "Enzyme/MLIR/Dialect/Dialect.h" -#include "Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h" -#include "Enzyme/MLIR/Passes/Passes.h" - -#include "Enzyme/MLIR/Dialect/Ops.h" - -#include "Implementations/XLADerivatives.h" -#include "Passes/Passes.h" - -#include "mlir/Dialect/Func/IR/FuncOps.h" #include "stablehlo/dialect/StablehloOps.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" using namespace mlir; +namespace mlir { +namespace enzyme { +void registerEnzymeJaxTransformExtension(mlir::DialectRegistry ®istry); +void registerGenerateApplyPatternsPass(); +void registerRemoveTransformPass(); +} // namespace enzyme +} // namespace mlir + class MemRefInsider : public mlir::MemRefElementTypeInterface::FallbackModel {}; @@ -114,6 +119,13 @@ int main(int argc, char **argv) { // Register the autodiff interface implementations for upstream dialects. enzyme::registerCoreDialectAutodiffInterfaces(registry); + // Transform dialect and extensions. + mlir::transform::registerInterpreterPass(); + mlir::linalg::registerTransformDialectExtension(registry); + mlir::enzyme::registerGenerateApplyPatternsPass(); + mlir::enzyme::registerRemoveTransformPass(); + mlir::enzyme::registerEnzymeJaxTransformExtension(registry); + return mlir::asMainReturnCode(mlir::MlirOptMain( argc, argv, "Enzyme modular optimizer driver", registry)); } diff --git a/src/enzyme_ad/tools/BUILD b/src/enzyme_ad/tools/BUILD new file mode 100644 index 000000000..4ac5e4b26 --- /dev/null +++ b/src/enzyme_ad/tools/BUILD @@ -0,0 +1,8 @@ +exports_files(["enzymexlamlir-tblgen.cpp"]) + +licenses(["notice"]) + +package( + default_visibility = ["//:__subpackages__"], + features = ["layering_check"], +) diff --git a/src/enzyme_ad/tools/enzymexlamlir-tblgen.cpp b/src/enzyme_ad/tools/enzymexlamlir-tblgen.cpp new file mode 100644 index 000000000..53bda3e2a --- /dev/null +++ b/src/enzyme_ad/tools/enzymexlamlir-tblgen.cpp @@ -0,0 +1,110 @@ +//===- enzymexlamlir-tblgen.cpp - Tablegen backend for EnzymeJAX ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/PrettyStackTrace.h" +#include "llvm/Support/Signals.h" +#include "llvm/TableGen/Main.h" +#include "llvm/TableGen/Record.h" +#include "llvm/TableGen/TableGenBackend.h" + +enum ActionType { + GenPopulatePatternsFuncDecl, + GenPopulatePatternsFuncDef, + GenPopulatePatternsInterfaceImpl, +}; + +static llvm::cl::opt action( + llvm::cl::desc("action to perform"), + llvm::cl::values(clEnumValN(GenPopulatePatternsFuncDecl, + "gen-populate-patterns-func-decls", "")), + llvm::cl::values(clEnumValN(GenPopulatePatternsFuncDef, + "gen-populate-patterns-func-defs", "")), + llvm::cl::values(clEnumValN(GenPopulatePatternsInterfaceImpl, + "gen-populate-patterns-interface-impl", ""))); + +llvm::StringRef getPopulateFunctionNameSuffix(llvm::Record *rec) { + return rec->getName().ends_with("Op") ? rec->getName().drop_back(2) + : rec->getName(); +} + +static bool emitPopulatePatterns(llvm::raw_ostream &os, + llvm::RecordKeeper &records) { + for (llvm::Record *rec : + records.getAllDerivedDefinitions("EnzymeHLOPatternOp")) { + os << "void "; + llvm::StringRef ns = rec->getValueAsString("cppNamespace"); + if (!ns.empty()) + os << ns << "::"; + os << rec->getName() + << "::populatePatterns(::mlir::RewritePatternSet &patterns) {\n"; + os << " " << ns << "::populate" << getPopulateFunctionNameSuffix(rec) + << "(patterns, *getContext());\n"; + os << "}\n\n"; + } + return false; +} + +static bool emitPopulatePatternsFuncDecls(llvm::raw_ostream &os, + llvm::RecordKeeper &records) { + for (llvm::Record *rec : + records.getAllDerivedDefinitions("EnzymeHLOPatternOp")) { + llvm::StringRef ns = rec->getValueAsString("cppNamespace"); + if (ns.starts_with("::")) + ns = ns.drop_front(2); + os << "namespace " << ns << " {\n"; + os << "void populate" << getPopulateFunctionNameSuffix(rec) + << "(::mlir::RewritePatternSet &patterns, ::mlir::MLIRContext " + "&context);\n"; + os << "} // namespace " << ns << "\n\n"; + } + return false; +} + +static bool emitPopulatePatternsFuncDefs(llvm::raw_ostream &os, + llvm::RecordKeeper &records) { + for (llvm::Record *rec : + records.getAllDerivedDefinitions("EnzymeHLOPatternOp")) { + os << "void "; + llvm::StringRef ns = rec->getValueAsString("cppNamespace"); + if (!ns.empty()) + os << ns; + os << "::populate" << getPopulateFunctionNameSuffix(rec) + << "(::mlir::RewritePatternSet &patterns,\n" + << " ::mlir::MLIRContext &context) {\n"; + + for (llvm::StringRef pattern : rec->getValueAsListOfStrings("patterns")) { + os << " patterns.add<" << pattern << ">(&context);\n"; + } + os << "}\n\n"; + } + return false; +} + +static bool tablegenMain(llvm::raw_ostream &os, llvm::RecordKeeper &records) { + switch (action) { + case GenPopulatePatternsFuncDecl: + return emitPopulatePatternsFuncDecls(os, records); + case GenPopulatePatternsFuncDef: + return emitPopulatePatternsFuncDefs(os, records); + case GenPopulatePatternsInterfaceImpl: + return emitPopulatePatterns(os, records); + default: + llvm::report_fatal_error("unknown action"); + return true; + } +} + +int main(int argc, char **argv) { + llvm::sys::PrintStackTraceOnErrorSignal(argv[0]); + llvm::PrettyStackTraceProgram X(argc, argv); + llvm::cl::ParseCommandLineOptions(argc, argv); + + llvm::llvm_shutdown_obj Y; + return TableGenMain(argv[0], &tablegenMain); +} diff --git a/test/lit_tests/generate_td.mlir b/test/lit_tests/generate_td.mlir new file mode 100644 index 000000000..c947768ad --- /dev/null +++ b/test/lit_tests/generate_td.mlir @@ -0,0 +1,39 @@ +// RUN: enzymexlamlir-opt %s --enzyme-hlo-generate-td='flags=4' | FileCheck %s --check-prefixes=TD,FL4 +// RUN: enzymexlamlir-opt %s --enzyme-hlo-generate-td='flags=64' | FileCheck %s --check-prefixes=TD,FL64 +// RUN: enzymexlamlir-opt %s --enzyme-hlo-generate-td='flags=68' | FileCheck %s --check-prefixes=TD,FL4,FL64 +// RUN: enzymexlamlir-opt %s --enzyme-hlo-generate-td='flags=4' --transform-interpreter | FileCheck %s --check-prefixes=INTERPCOMMON,INTERP4 +// RUN: enzymexlamlir-opt %s --enzyme-hlo-generate-td='flags=64' --transform-interpreter | FileCheck %s --check-prefixes=INTERPCOMMON,INTERP64 +// RUN: enzymexlamlir-opt %s --enzyme-hlo-generate-td='flags=68' --transform-interpreter | FileCheck %s --check-prefixes=INTERPCOMMON,INTERP4,INTERP64 +// RUN: enzymexlamlir-opt %s --enzyme-hlo-generate-td='flags=68' --transform-interpreter --enzyme-hlo-remove-transform | FileCheck %s --check-prefixes=CLEAN + +// TD: module attributes {transform.with_named_sequence} { +// TD: transform.named_sequence @__transform_main(%[[ROOT:.+]]: !transform.any_op) { +// CLEAN-NOT: transform.named_sequence +// TD: %[[FUNC:.+]] = transform.structured.match ops{["func.func"]} in %[[ROOT]] : (!transform.any_op) -> !transform.any_op +// TD: transform.apply_patterns to %[[FUNC]] { +// FL4: transform.apply_patterns.enzyme_hlo.broadcast_reduce +// FL64: transform.apply_patterns.enzyme_hlo.reduce_pad +// TD: } : !transform.any_op +// TD: transform.yield +// TD: } + +// INTERPCOMMON-LABEL: @broadcastreduce +// INTERP4: reduce +// INTERP4: convert +// INTERP4: multiply +func.func @broadcastreduce(%154: tensor<1x3072xf32>, %151: tensor) -> tensor { + %211 = stablehlo.broadcast_in_dim %154, dims = [0, 1] : (tensor<1x3072xf32>) -> tensor<1x3072x32xf32> + %212 = stablehlo.reduce(%211 init: %151) applies stablehlo.add across dimensions = [0, 1, 2] : (tensor<1x3072x32xf32>, tensor) -> tensor + return %212 : tensor +} + +// INTERPCOMMON-LABEL: @reducepad +// INTERP64: reduce +// INTERP64: pad +func.func @reducepad(%a : tensor<2x3x1xf32>, %b : tensor) -> tensor<6x1xf32> { + %pv = stablehlo.constant dense<0.000000e+00> : tensor + %pad = stablehlo.pad %a, %pv, low = [1, 2, 0], high = [3, 4, 0], interior = [0, 1, 0] : (tensor<2x3x1xf32>, tensor) -> tensor<6x11x1xf32> + %conv = stablehlo.reduce(%pad init: %b) applies stablehlo.add across dimensions = [1] : (tensor<6x11x1xf32>, tensor) -> tensor<6x1xf32> + return %conv : tensor<6x1xf32> +} +