Skip to content

Commit

Permalink
transform dialect for pattern combination
Browse files Browse the repository at this point in the history
  • Loading branch information
ftynse committed Mar 20, 2024
1 parent 2661d0b commit 5f74a83
Show file tree
Hide file tree
Showing 14 changed files with 706 additions and 22 deletions.
13 changes: 13 additions & 0 deletions BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ cc_binary(
visibility = ["//visibility:public"],
deps = [
"//src/enzyme_ad/jax:XLADerivatives",
"//src/enzyme_ad/jax:TransformOps",
"@enzyme//:EnzymeMLIR",
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:AllPassesAndDialects",
Expand All @@ -47,6 +48,18 @@ cc_binary(
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:Transforms",
"@llvm-project//mlir:TransformDialect",
],
)

cc_binary(
name = "enzymexlamlir-tblgen",
srcs = ["//src/enzyme_ad/tools:enzymexlamlir-tblgen.cpp"],
visibility = ["//visibility:public"],
deps = [
"@llvm-project//llvm:Support",
"@llvm-project//llvm:TableGen",
"@llvm-project//llvm:config",
],
)

Expand Down
88 changes: 88 additions & 0 deletions src/enzyme_ad/jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,68 @@ symlink_inputs(
}},
)

td_library(
name = "TransformOpsTdFiles",
srcs = [
"TransformOps/TransformOps.td",
],
deps = [
"@llvm-project//mlir:TransformDialectTdFiles",
]
)

gentbl_cc_library(
name = "TransformOpsIncGen",
tbl_outs = [(
["-gen-op-decls"],
"TransformOps/TransformOps.h.inc",
), (
["-gen-op-defs"],
"TransformOps/TransformOps.cpp.inc",
), (
["-gen-op-interface-decls"],
"TransformOps/OpInterfaces.h.inc",
), (
["-gen-op-interface-defs"],
"TransformOps/OpInterfaces.cpp.inc",
),
],
td_file = "TransformOps/TransformOps.td",
deps = [
":TransformOpsTdFiles",
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
)

gentbl_cc_library(
name = "TransformOpsImplIncGen",
tbl_outs = [(
["-gen-populate-patterns-interface-impl"],
"TransformOps/TransformOpsImpl.cpp.inc"
)],
td_file = "TransformOps/TransformOps.td",
deps = [
":TransformOpsTdFiles",
],
tblgen = "//:enzymexlamlir-tblgen",
)

cc_library(
name = "TransformOps",
srcs = glob(["TransformOps/*.cpp"]),
hdrs = glob(["TransformOps/*.h"]),
deps = [
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgTransformOps",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:TransformDialect",
":TransformOpsIncGen",
":TransformOpsImplIncGen",
":XLADerivatives",
],
)

td_library(
name = "ImplementationsCommonTdFiles",
srcs = [
Expand Down Expand Up @@ -127,6 +189,23 @@ gentbl_cc_library(
deps = [":EnzymeXLAPassesTdFiles"],
)

gentbl_cc_library(
name = "EnzyeHLOPatternsIncGen",
tbl_outs = [
(
["-gen-populate-patterns-func-decls"],
"Passes/EnzymeHLOPatterns.h.inc",
), (
["-gen-populate-patterns-func-defs"],
"Passes/EnzymeHLOPatterns.cpp.inc",
)],
td_file = "TransformOps/TransformOps.td",
deps = [
":TransformOpsTdFiles",
],
tblgen = "//:enzymexlamlir-tblgen",
)

cc_library(
name = "XLADerivatives",
srcs = glob(
Expand All @@ -147,15 +226,19 @@ cc_library(
],
deps = [
":EnzymeXLAPassesIncGen",
":EnzyeHLOPatternsIncGen",
":mhlo-derivatives",
":stablehlo-derivatives",
"@enzyme//:EnzymeMLIR",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:CommonFolders",
"@llvm-project//mlir:ControlFlowInterfaces",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:FunctionInterfaces",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:LLVMCommonConversion",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:Transforms",
Expand All @@ -176,6 +259,7 @@ pybind_library(
]),
deps = [
":XLADerivatives",
":TransformOps",
# This is similar to xla_binary rule and is needed to make XLA client compile.
"@tsl//tsl/framework:allocator",
"@tsl//tsl/framework:allocator_registry_impl",
Expand All @@ -199,6 +283,7 @@ pybind_library(
"@xla//xla/client",
"@xla//xla/client:client_library",
"@xla//xla/client:executable_build_options",
"@xla//xla/client:local_client",
"@xla//xla/client:xla_computation",
"@xla//xla/service",
"@xla//xla/service:local_service",
Expand Down Expand Up @@ -240,9 +325,11 @@ pybind_library(
# MLIR dialects and parser.
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:FuncExtensions",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:LinalgTransformOps",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
"@xla//xla/mlir_hlo:all_passes",
Expand All @@ -264,6 +351,7 @@ pybind_extension(
deps = [
":clang_compile",
":compile_with_xla",
":TransformOps",
"@com_google_absl//absl/status:statusor",
"@enzyme//:EnzymeMLIR",
"@enzyme//:EnzymeStatic",
Expand Down
26 changes: 16 additions & 10 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,20 @@
// ops.
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"

#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h"
#include "src/enzyme_ad/jax/Passes/PassDetails.h"
#include "src/enzyme_ad/jax/Passes/Passes.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"

#include "stablehlo/dialect/StablehloOps.h"
#include "stablehlo/reference/Ops.h"
#include "stablehlo/transforms/Passes.h"

#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"

#define DEBUG_TYPE "enzyme"

Expand Down Expand Up @@ -3538,7 +3535,16 @@ struct ReshapeToSlice : public OpRewritePattern<stablehlo::SliceOp> {
return success();
}
};
} // namespace

#include "src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.cpp.inc"

void mlir::transform::addPadDotGeneral(RewritePatternSet &patterns,
bool postPad, MLIRContext &context) {
patterns.insert<PadDotGeneral>(postPad, &context);
}

namespace {
struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {

void runOnOperation() override {
Expand Down
11 changes: 11 additions & 0 deletions src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
namespace mlir {
class RewritePatternSet;
class MLIRContext;
} // namespace mlir

#include "src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h.inc"

namespace mlir::transform {
void addPadDotGeneral(RewritePatternSet &patterns, bool postPad,
MLIRContext &context);
}
145 changes: 145 additions & 0 deletions src/enzyme_ad/jax/TransformOps/GenerateApplyPatterns.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
//===- GenerateApplyPatterns.cpp - Generate transform scripts --------------- //
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Pass/Pass.h"
#include "src/enzyme_ad/jax/TransformOps/TransformOps.h"

using namespace mlir;

struct OpConfig {
OperationName name;
DictionaryAttr attrs;
};

void generatePatternGroup(OpBuilder &builder, Location loc, Value root,
ArrayRef<OpConfig> configurations,
llvm::APInt selectionBitmask) {
OpBuilder::InsertionGuard guard(builder);
auto apply = builder.create<transform::ApplyPatternsOp>(
loc, root, [](OpBuilder &builder, Location loc) {});
builder.setInsertionPointToStart(apply.getBody());
for (auto &&[i, opConfig] : llvm::enumerate(configurations)) {
if (selectionBitmask.extractBits(/*numBits=*/1, /*bitPosition=*/i).isZero())
continue;
OperationState state(loc, opConfig.name);
state.addAttributes(opConfig.attrs.getValue());
builder.create(state);
}
}

LogicalResult generateTransform(OpBuilder &builder, llvm::APInt version) {
auto loc = builder.getUnknownLoc();
auto namedSequence = builder.create<transform::NamedSequenceOp>(
loc, "__transform_main", builder.getType<transform::AnyOpType>(),
TypeRange(), [](OpBuilder &builder, Location loc, BlockArgument) {
builder.create<transform::YieldOp>(loc);
});

SmallVector<OpConfig> opConfigurations;
for (StringRef name : mlir::enzyme::getTransformOperationNames()) {
std::optional<RegisteredOperationName> opName =
RegisteredOperationName::lookup(name, builder.getContext());
if (!opName) {
return namedSequence->emitError() << "unregistered pattern op '" << name
<< "' listed for construction";
}
auto *concept =
opName->getInterface<SearchablePatternDescriptorOpInterface>();
for (DictionaryAttr attrs : concept->getPossibleAttrCombinations(builder)) {
opConfigurations.push_back(OpConfig{*opName, attrs});
}
}

builder.setInsertionPointToStart(&namedSequence.getBody().front());
auto match = builder.create<transform::MatchOp>(
loc, namedSequence.getBody().front().getArgument(0),
ArrayRef<StringRef>{func::FuncOp::getOperationName()});

auto configPow = llvm::APInt::getOneBitSet(opConfigurations.size() + 1,
opConfigurations.size());
do {
llvm::APInt configuration = version.srem(configPow);
generatePatternGroup(builder, loc, match, opConfigurations, configuration);
version = version.sdiv(configPow);
} while (!version.isZero());
return success();
}

namespace {
class GenerateApplyPatternsPass
: public PassWrapper<GenerateApplyPatternsPass, OperationPass<>> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GenerateApplyPatternsPass)

GenerateApplyPatternsPass() = default;
GenerateApplyPatternsPass(const GenerateApplyPatternsPass &other)
: PassWrapper<GenerateApplyPatternsPass, OperationPass<>>(other) {}

StringRef getArgument() const override { return "enzyme-hlo-generate-td"; }

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<transform::TransformDialect>();
}

void runOnOperation() override {
Operation *op = getOperation();
if (op->getNumRegions() != 1 || !llvm::hasSingleElement(op->getRegion(0))) {
op->emitError()
<< "can only run on a single-region single-block operation";
return signalPassFailure();
}

llvm::APInt version(
llvm::APInt::getSufficientBitsNeeded(flags.getValue(), radix),
flags.getValue(), radix);

OpBuilder builder(&getContext());
op->setAttr(transform::TransformDialect::kWithNamedSequenceAttrName,
builder.getUnitAttr());

builder.setInsertionPointToStart(&op->getRegion(0).front());
if (failed(generateTransform(builder, version)))
return signalPassFailure();
}

Option<std::string> flags{*this, "flags", llvm::cl::init("")};
Option<int> radix{*this, "radix", llvm::cl::init(10)};
};

class RemoveTransform : public PassWrapper<RemoveTransform, OperationPass<>> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RemoveTransform)

StringRef getArgument() const override {
return "enzyme-hlo-remove-transform";
}

void runOnOperation() override {
getOperation()->walk<WalkOrder::PreOrder>([&](Operation *op) {
if (isa<transform::TransformOpInterface>(op)) {
op->erase();
return WalkResult::skip();
}
return WalkResult::advance();
});
}
};
} // namespace

void mlir::enzyme::registerGenerateApplyPatternsPass() {
PassRegistration<GenerateApplyPatternsPass>();
}

void mlir::enzyme::registerRemoveTransformPass() {
PassRegistration<RemoveTransform>();
}
Loading

0 comments on commit 5f74a83

Please sign in to comment.