Skip to content

Commit

Permalink
Batched reverse mode (#2216)
Browse files Browse the repository at this point in the history
* fix for reversemode

* fix test

* fixup

* fixup

---------

Co-authored-by: William S. Moses <[email protected]>
  • Loading branch information
jumerckx and wsmoses authored Jan 7, 2025
1 parent 5b330a9 commit 96b8efc
Show file tree
Hide file tree
Showing 10 changed files with 53 additions and 18 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/enzyme-mlir.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
- uses: actions/checkout@v4
with:
repository: 'llvm/llvm-project'
ref: 'eaa7b385368fa7e3dad9b95411d04be55e71494e'
ref: 'ff24e9a19e3db330dd6412aac9d1d6c0b416697f'
path: 'llvm-project'

- name: Get MLIR commit hash
Expand Down
1 change: 1 addition & 0 deletions enzyme/Enzyme/MLIR/Passes/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIREnzymeTransforms
SimplifyMath.cpp
AddToOpToIndexAndLoad.cpp
AddToOpToSplit.cpp
RemovalUtils.cpp
RemoveUnusedEnzymeOps.cpp
SimplifyMemrefCache.cpp
Utils.cpp
Expand Down
6 changes: 3 additions & 3 deletions enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ struct DifferentiatePass : public DifferentiatePassBase<DifferentiatePass> {
pm.getDependentDialects(registry);
}

registry
.insert<mlir::arith::ArithDialect, mlir::complex::ComplexDialect,
mlir::cf::ControlFlowDialect, mlir::tensor::TensorDialect>();
registry.insert<mlir::arith::ArithDialect, mlir::complex::ComplexDialect,
mlir::cf::ControlFlowDialect, mlir::tensor::TensorDialect,
mlir::enzyme::EnzymeDialect>();
}

static std::vector<DIFFE_TYPE> mode_from_fn(FunctionOpInterface fn,
Expand Down
1 change: 1 addition & 0 deletions enzyme/Enzyme/MLIR/Passes/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def DifferentiatePass : Pass<"enzyme"> {
"complex::ComplexDialect",
"cf::ControlFlowDialect",
"tensor::TensorDialect",
"enzyme::EnzymeDialect",
];
let options = [
Option<
Expand Down
1 change: 0 additions & 1 deletion enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ mlir::enzyme::CacheInfo::merge(mlir::enzyme::CacheInfo other) {
other.initOp->erase();
}

enzyme::PushOp newPushOp = pushOp;
other.pushOp->erase();

enzyme::PopOp newPopOp;
Expand Down
2 changes: 1 addition & 1 deletion enzyme/Enzyme/MLIR/Passes/RemovalUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ struct CacheInfo {

Value pushedValue() { return pushOp.getValue(); }
Type cachedType() {
return initOp.getResult().getType().cast<enzyme::CacheType>().getType();
return cast<enzyme::CacheType>(initOp.getResult().getType()).getType();
}

// Pushed values must be the same
Expand Down
8 changes: 7 additions & 1 deletion enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,12 +306,18 @@ struct RemoveUnusedEnzymeOpsPass

applyPatterns(op);

bool failed = false;
op->walk([&](FunctionOpInterface func) {
func->walk([&](enzyme::EnzymeOpsRemoverOpInterface iface) {
iface.removeEnzymeOps();
auto result = iface.removeEnzymeOps();
if (!result.succeeded())
failed = true;
});
});

if (failed)
return signalPassFailure();

applyPatterns(op);
}
};
Expand Down
2 changes: 1 addition & 1 deletion enzyme/test/MLIR/ForwardMode/batched_scalar.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,6 @@ module {
// CHECK-NEXT: %[[s1:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array<i64: 2>}> : (f64) -> tensor<2xf64>
// CHECK-NEXT: %[[i1:.+]] = arith.mulf %[[arg1]], %[[s1]] : tensor<2xf64>
// CHECK-NEXT: %[[i2:.+]] = arith.addf %[[i0]], %[[i1]] : tensor<2xf64>
// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] : tensor<2xf64>
// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]]
// CHECK-NEXT: return %[[i2]] : tensor<2xf64>
// CHECK-NEXT: }
27 changes: 27 additions & 0 deletions enzyme/test/MLIR/ReverseMode/batched_square.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// RUN: %eopt --enzyme --canonicalize --remove-unnecessary-enzyme-ops --canonicalize --enzyme-simplify-math %s | FileCheck %s

module {
func.func @square(%x: f64) -> f64 {
%next = arith.mulf %x, %x : f64
return %next : f64
}

func.func @dsquare(%x: f64, %dr: tensor<2xf64>) -> tensor<2xf64> {
%r = enzyme.autodiff @square(%x, %dr) { activity=[#enzyme<activity enzyme_active>], ret_activity=[#enzyme<activity enzyme_activenoneed>], width=2 } : (f64, tensor<2xf64>) -> tensor<2xf64>
return %r : tensor<2xf64>
}
}

// CHECK: func.func @dsquare(%arg0: f64, %arg1: tensor<2xf64>) -> tensor<2xf64> {
// CHECK-NEXT: %0 = call @diffe2square(%arg0, %arg1) : (f64, tensor<2xf64>) -> tensor<2xf64>
// CHECK-NEXT: return %0 : tensor<2xf64>
// CHECK-NEXT: }

// CHECK: func.func private @diffe2square(%arg0: f64, %arg1: tensor<2xf64>) -> tensor<2xf64> {
// CHECK-NEXT: %0 = "enzyme.broadcast"(%arg0) <{shape = array<i64: 2>}> : (f64) -> tensor<2xf64>
// CHECK-NEXT: %1 = arith.mulf %arg1, %0 : tensor<2xf64>
// CHECK-NEXT: %2 = "enzyme.broadcast"(%arg0) <{shape = array<i64: 2>}> : (f64) -> tensor<2xf64>
// CHECK-NEXT: %3 = arith.mulf %arg1, %2 : tensor<2xf64>
// CHECK-NEXT: %4 = arith.addf %1, %3 : tensor<2xf64>
// CHECK-NEXT: return %4 : tensor<2xf64>
// CHECK-NEXT: }
21 changes: 11 additions & 10 deletions enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,16 +277,17 @@ SmallVector<bool, 1> prepareArgs(const Twine &curIndent, raw_ostream &os,
if (!vecValue && !startsWith(ord, "local")) {
if (newFromOriginal && (!lookup || intrinsic != MLIRDerivatives)) {
os << ")";
if (intrinsic == MLIRDerivatives) {
os << ";\n";
os << "if (gutils->width != 1) {\n"
<< " " << argName << "_" << (idx - 1)
<< " = builder.create<enzyme::BroadcastOp>(\n"
<< " op.getLoc(),\n"
<< " " << argName << "_" << (idx - 1) << ",\n"
<< " llvm::SmallVector<int64_t>({gutils->width}));\n"
<< "}";
}
}
if (intrinsic == MLIRDerivatives) {
os << ";\n";
os << curIndent << "if (gutils->width != 1) {\n"
<< curIndent << " " << argName << "_" << (idx - 1)
<< " = builder.create<enzyme::BroadcastOp>(\n"
<< curIndent << " op.getLoc(),\n"
<< curIndent << " " << argName << "_" << (idx - 1) << ",\n"
<< curIndent
<< " llvm::SmallVector<int64_t>({gutils->width}));\n"
<< curIndent << "}";
}

if (lookup && intrinsic != MLIRDerivatives)
Expand Down

0 comments on commit 96b8efc

Please sign in to comment.