Skip to content

Commit

Permalink
transform dialect for pattern combination
Browse files Browse the repository at this point in the history
  • Loading branch information
ftynse committed Mar 27, 2024
1 parent cdcab89 commit ff65836
Show file tree
Hide file tree
Showing 14 changed files with 709 additions and 23 deletions.
17 changes: 15 additions & 2 deletions BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ cc_binary(
srcs = ["//src/enzyme_ad/jax:enzymexlamlir-opt.cpp"],
visibility = ["//visibility:public"],
deps = [
"//src/enzyme_ad/jax:XLADerivatives",
"@enzyme//:EnzymeMLIR",
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:AllPassesAndDialects",
Expand All @@ -37,20 +36,34 @@ cc_binary(
"@llvm-project//mlir:DLTIDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:LLVMDialect",
"@llvm-project//mlir:LinalgDialect",
"@llvm-project//mlir:LLVMDialect",
"@llvm-project//mlir:MathDialect",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:MlirOptLib",
"@llvm-project//mlir:NVVMDialect",
"@llvm-project//mlir:OpenMPDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:TransformDialect",
"@llvm-project//mlir:Transforms",
"//src/enzyme_ad/jax:TransformOps",
"//src/enzyme_ad/jax:XLADerivatives",
"@stablehlo//:chlo_ops",
],
)

cc_binary(
name = "enzymexlamlir-tblgen",
srcs = ["//src/enzyme_ad/tools:enzymexlamlir-tblgen.cpp"],
visibility = ["//visibility:public"],
deps = [
"@llvm-project//llvm:Support",
"@llvm-project//llvm:TableGen",
"@llvm-project//llvm:config",
],
)

py_wheel(
name = "enzyme_ad",
author = "Enzyme Authors",
Expand Down
88 changes: 88 additions & 0 deletions src/enzyme_ad/jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,68 @@ symlink_inputs(
}},
)

td_library(
name = "TransformOpsTdFiles",
srcs = [
"TransformOps/TransformOps.td",
],
deps = [
"@llvm-project//mlir:TransformDialectTdFiles",
]
)

gentbl_cc_library(
name = "TransformOpsIncGen",
tbl_outs = [(
["-gen-op-decls"],
"TransformOps/TransformOps.h.inc",
), (
["-gen-op-defs"],
"TransformOps/TransformOps.cpp.inc",
), (
["-gen-op-interface-decls"],
"TransformOps/OpInterfaces.h.inc",
), (
["-gen-op-interface-defs"],
"TransformOps/OpInterfaces.cpp.inc",
),
],
td_file = "TransformOps/TransformOps.td",
deps = [
":TransformOpsTdFiles",
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
)

gentbl_cc_library(
name = "TransformOpsImplIncGen",
tbl_outs = [(
["-gen-populate-patterns-interface-impl"],
"TransformOps/TransformOpsImpl.cpp.inc"
)],
td_file = "TransformOps/TransformOps.td",
deps = [
":TransformOpsTdFiles",
],
tblgen = "//:enzymexlamlir-tblgen",
)

cc_library(
name = "TransformOps",
srcs = glob(["TransformOps/*.cpp"]),
hdrs = glob(["TransformOps/*.h"]),
deps = [
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgTransformOps",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:TransformDialect",
":TransformOpsIncGen",
":TransformOpsImplIncGen",
":XLADerivatives",
],
)

td_library(
name = "ImplementationsCommonTdFiles",
srcs = [
Expand Down Expand Up @@ -127,6 +189,23 @@ gentbl_cc_library(
deps = [":EnzymeXLAPassesTdFiles"],
)

gentbl_cc_library(
name = "EnzyeHLOPatternsIncGen",
tbl_outs = [
(
["-gen-populate-patterns-func-decls"],
"Passes/EnzymeHLOPatterns.h.inc",
), (
["-gen-populate-patterns-func-defs"],
"Passes/EnzymeHLOPatterns.cpp.inc",
)],
td_file = "TransformOps/TransformOps.td",
deps = [
":TransformOpsTdFiles",
],
tblgen = "//:enzymexlamlir-tblgen",
)

cc_library(
name = "XLADerivatives",
srcs = glob(
Expand All @@ -147,15 +226,19 @@ cc_library(
],
deps = [
":EnzymeXLAPassesIncGen",
":EnzyeHLOPatternsIncGen",
":mhlo-derivatives",
":stablehlo-derivatives",
"@enzyme//:EnzymeMLIR",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:CommonFolders",
"@llvm-project//mlir:ControlFlowInterfaces",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:FunctionInterfaces",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:LLVMCommonConversion",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:Transforms",
Expand All @@ -177,6 +260,7 @@ pybind_library(
]),
deps = [
":XLADerivatives",
":TransformOps",
# This is similar to xla_binary rule and is needed to make XLA client compile.
"@tsl//tsl/framework:allocator",
"@tsl//tsl/framework:allocator_registry_impl",
Expand All @@ -200,6 +284,7 @@ pybind_library(
"@xla//xla/client",
"@xla//xla/client:client_library",
"@xla//xla/client:executable_build_options",
"@xla//xla/client:local_client",
"@xla//xla/client:xla_computation",
"@xla//xla/service",
"@xla//xla/service:local_service",
Expand Down Expand Up @@ -242,9 +327,11 @@ pybind_library(
# MLIR dialects and parser.
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:FuncExtensions",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:LinalgTransformOps",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
"@xla//xla/mlir_hlo:all_passes",
Expand All @@ -266,6 +353,7 @@ pybind_extension(
deps = [
":clang_compile",
":compile_with_xla",
":TransformOps",
"@com_google_absl//absl/status:statusor",
"@enzyme//:EnzymeMLIR",
"@enzyme//:EnzymeStatic",
Expand Down
26 changes: 16 additions & 10 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,20 @@
// ops.
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"

#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h"
#include "src/enzyme_ad/jax/Passes/PassDetails.h"
#include "src/enzyme_ad/jax/Passes/Passes.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"

#include "stablehlo/dialect/StablehloOps.h"
#include "stablehlo/reference/Ops.h"
#include "stablehlo/transforms/Passes.h"

#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"

#include "stablehlo/dialect/TypeInference.h"

Expand Down Expand Up @@ -3733,6 +3730,7 @@ struct SliceReshape : public OpRewritePattern<stablehlo::SliceOp> {
return success();
}
};
} // namespace

// Rewritten from
// https://github.com/openxla/stablehlo/blob/4f180d3c2236a15f82f29aad1b47f6ea2c14fc52/stablehlo/reference/Ops.cpp#L1381
Expand Down Expand Up @@ -4009,6 +4007,14 @@ template <typename T> struct CSE final : OpRewritePattern<T> {
}
};

#include "src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.cpp.inc"

void mlir::transform::addPadDotGeneral(RewritePatternSet &patterns,
bool postPad, MLIRContext &context) {
patterns.insert<PadDotGeneral>(postPad, &context);
}

namespace {
struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {

void runOnOperation() override {
Expand Down
11 changes: 11 additions & 0 deletions src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
namespace mlir {
class RewritePatternSet;
class MLIRContext;
} // namespace mlir

#include "src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h.inc"

namespace mlir::transform {
void addPadDotGeneral(RewritePatternSet &patterns, bool postPad,
MLIRContext &context);
}
Loading

0 comments on commit ff65836

Please sign in to comment.