From 96b8efc8125ffc83244233dd54b7582a9b1b5d85 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Tue, 7 Jan 2025 04:51:15 +0100 Subject: [PATCH] Batched reverse mode (#2216) * fix for reversemode * fix test * fixup * fixup --------- Co-authored-by: William S. Moses --- .github/workflows/enzyme-mlir.yml | 2 +- enzyme/Enzyme/MLIR/Passes/CMakeLists.txt | 1 + enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp | 6 ++--- enzyme/Enzyme/MLIR/Passes/Passes.td | 1 + enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp | 1 - enzyme/Enzyme/MLIR/Passes/RemovalUtils.h | 2 +- .../MLIR/Passes/RemoveUnusedEnzymeOps.cpp | 8 +++++- .../test/MLIR/ForwardMode/batched_scalar.mlir | 2 +- .../test/MLIR/ReverseMode/batched_square.mlir | 27 +++++++++++++++++++ enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 21 ++++++++------- 10 files changed, 53 insertions(+), 18 deletions(-) create mode 100644 enzyme/test/MLIR/ReverseMode/batched_square.mlir diff --git a/.github/workflows/enzyme-mlir.yml b/.github/workflows/enzyme-mlir.yml index 89ae72957c7..16b3fe6e11e 100644 --- a/.github/workflows/enzyme-mlir.yml +++ b/.github/workflows/enzyme-mlir.yml @@ -36,7 +36,7 @@ jobs: - uses: actions/checkout@v4 with: repository: 'llvm/llvm-project' - ref: 'eaa7b385368fa7e3dad9b95411d04be55e71494e' + ref: 'ff24e9a19e3db330dd6412aac9d1d6c0b416697f' path: 'llvm-project' - name: Get MLIR commit hash diff --git a/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt b/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt index 99db4d80034..00b2cae1e38 100644 --- a/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt +++ b/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt @@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIREnzymeTransforms SimplifyMath.cpp AddToOpToIndexAndLoad.cpp AddToOpToSplit.cpp + RemovalUtils.cpp RemoveUnusedEnzymeOps.cpp SimplifyMemrefCache.cpp Utils.cpp diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp index c91f5400fef..972222f87ba 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp @@ -39,9 +39,9 @@ struct DifferentiatePass : public DifferentiatePassBase { pm.getDependentDialects(registry); } - registry - .insert(); + registry.insert(); } static std::vector mode_from_fn(FunctionOpInterface fn, diff --git a/enzyme/Enzyme/MLIR/Passes/Passes.td b/enzyme/Enzyme/MLIR/Passes/Passes.td index d3494956a12..ebe00135b9f 100644 --- a/enzyme/Enzyme/MLIR/Passes/Passes.td +++ b/enzyme/Enzyme/MLIR/Passes/Passes.td @@ -18,6 +18,7 @@ def DifferentiatePass : Pass<"enzyme"> { "complex::ComplexDialect", "cf::ControlFlowDialect", "tensor::TensorDialect", + "enzyme::EnzymeDialect", ]; let options = [ Option< diff --git a/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp b/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp index 002b11d6bc9..572fddd1cae 100644 --- a/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp +++ b/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp @@ -26,7 +26,6 @@ mlir::enzyme::CacheInfo::merge(mlir::enzyme::CacheInfo other) { other.initOp->erase(); } - enzyme::PushOp newPushOp = pushOp; other.pushOp->erase(); enzyme::PopOp newPopOp; diff --git a/enzyme/Enzyme/MLIR/Passes/RemovalUtils.h b/enzyme/Enzyme/MLIR/Passes/RemovalUtils.h index 32308ed1d6b..d56ce6018da 100644 --- a/enzyme/Enzyme/MLIR/Passes/RemovalUtils.h +++ b/enzyme/Enzyme/MLIR/Passes/RemovalUtils.h @@ -41,7 +41,7 @@ struct CacheInfo { Value pushedValue() { return pushOp.getValue(); } Type cachedType() { - return initOp.getResult().getType().cast().getType(); + return cast(initOp.getResult().getType()).getType(); } // Pushed values must be the same diff --git a/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp b/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp index 8ee77113e9a..cb25fa6fa8b 100644 --- a/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp +++ b/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp @@ -306,12 +306,18 @@ struct RemoveUnusedEnzymeOpsPass applyPatterns(op); + bool failed = false; op->walk([&](FunctionOpInterface func) { func->walk([&](enzyme::EnzymeOpsRemoverOpInterface iface) { - iface.removeEnzymeOps(); + auto result = iface.removeEnzymeOps(); + if (!result.succeeded()) + failed = true; }); }); + if (failed) + return signalPassFailure(); + applyPatterns(op); } }; diff --git a/enzyme/test/MLIR/ForwardMode/batched_scalar.mlir b/enzyme/test/MLIR/ForwardMode/batched_scalar.mlir index d384bdd0933..8acd131c169 100644 --- a/enzyme/test/MLIR/ForwardMode/batched_scalar.mlir +++ b/enzyme/test/MLIR/ForwardMode/batched_scalar.mlir @@ -21,6 +21,6 @@ module { // CHECK-NEXT: %[[s1:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array}> : (f64) -> tensor<2xf64> // CHECK-NEXT: %[[i1:.+]] = arith.mulf %[[arg1]], %[[s1]] : tensor<2xf64> // CHECK-NEXT: %[[i2:.+]] = arith.addf %[[i0]], %[[i1]] : tensor<2xf64> -// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] : tensor<2xf64> +// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] // CHECK-NEXT: return %[[i2]] : tensor<2xf64> // CHECK-NEXT: } diff --git a/enzyme/test/MLIR/ReverseMode/batched_square.mlir b/enzyme/test/MLIR/ReverseMode/batched_square.mlir new file mode 100644 index 00000000000..86c28670312 --- /dev/null +++ b/enzyme/test/MLIR/ReverseMode/batched_square.mlir @@ -0,0 +1,27 @@ +// RUN: %eopt --enzyme --canonicalize --remove-unnecessary-enzyme-ops --canonicalize --enzyme-simplify-math %s | FileCheck %s + +module { + func.func @square(%x: f64) -> f64 { + %next = arith.mulf %x, %x : f64 + return %next : f64 + } + + func.func @dsquare(%x: f64, %dr: tensor<2xf64>) -> tensor<2xf64> { + %r = enzyme.autodiff @square(%x, %dr) { activity=[#enzyme], ret_activity=[#enzyme], width=2 } : (f64, tensor<2xf64>) -> tensor<2xf64> + return %r : tensor<2xf64> + } +} + +// CHECK: func.func @dsquare(%arg0: f64, %arg1: tensor<2xf64>) -> tensor<2xf64> { +// CHECK-NEXT: %0 = call @diffe2square(%arg0, %arg1) : (f64, tensor<2xf64>) -> tensor<2xf64> +// CHECK-NEXT: return %0 : tensor<2xf64> +// CHECK-NEXT: } + +// CHECK: func.func private @diffe2square(%arg0: f64, %arg1: tensor<2xf64>) -> tensor<2xf64> { +// CHECK-NEXT: %0 = "enzyme.broadcast"(%arg0) <{shape = array}> : (f64) -> tensor<2xf64> +// CHECK-NEXT: %1 = arith.mulf %arg1, %0 : tensor<2xf64> +// CHECK-NEXT: %2 = "enzyme.broadcast"(%arg0) <{shape = array}> : (f64) -> tensor<2xf64> +// CHECK-NEXT: %3 = arith.mulf %arg1, %2 : tensor<2xf64> +// CHECK-NEXT: %4 = arith.addf %1, %3 : tensor<2xf64> +// CHECK-NEXT: return %4 : tensor<2xf64> +// CHECK-NEXT: } diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index dccbc7b7923..3f85a07548e 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -277,16 +277,17 @@ SmallVector prepareArgs(const Twine &curIndent, raw_ostream &os, if (!vecValue && !startsWith(ord, "local")) { if (newFromOriginal && (!lookup || intrinsic != MLIRDerivatives)) { os << ")"; - if (intrinsic == MLIRDerivatives) { - os << ";\n"; - os << "if (gutils->width != 1) {\n" - << " " << argName << "_" << (idx - 1) - << " = builder.create(\n" - << " op.getLoc(),\n" - << " " << argName << "_" << (idx - 1) << ",\n" - << " llvm::SmallVector({gutils->width}));\n" - << "}"; - } + } + if (intrinsic == MLIRDerivatives) { + os << ";\n"; + os << curIndent << "if (gutils->width != 1) {\n" + << curIndent << " " << argName << "_" << (idx - 1) + << " = builder.create(\n" + << curIndent << " op.getLoc(),\n" + << curIndent << " " << argName << "_" << (idx - 1) << ",\n" + << curIndent + << " llvm::SmallVector({gutils->width}));\n" + << curIndent << "}"; } if (lookup && intrinsic != MLIRDerivatives)