Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Feb 16, 2024
1 parent 5a246c4 commit 8c74fa0
Show file tree
Hide file tree
Showing 8 changed files with 73 additions and 13 deletions.
4 changes: 2 additions & 2 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ load("@rules_python//python/pip_install:repositories.bzl", "pip_install_dependen

pip_install_dependencies()

ENZYME_COMMIT = "2c753c97fcb41623e9aca972edfc08202b23e04f"
ENZYME_SHA256 = "0af3843503e25b973ae82dfa958a843e372708e24b34d195f38b28ced17fcb84"
ENZYME_COMMIT = "4ccab29dc691cb43d250a7c5ca612c3ff9cd23e3"
ENZYME_SHA256 = ""

http_archive(
name = "enzyme",
Expand Down
7 changes: 7 additions & 0 deletions src/enzyme_ad/jax/Implementations/Common.td
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,12 @@ class Inst<string mnemonic, string dialect_> : Operation</*primal*/1, /*shadow*
string dialect = dialect_;
}


class ConstantFP<string val, string dialect_, string op_> : Operation</*primal*/0, /*shadow*/0> {
string value = val;
string dialect = dialect_;
string opName = op_;
}

def Op {
}
42 changes: 41 additions & 1 deletion src/enzyme_ad/jax/Implementations/HLODerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ def Neg : HLOInst<"NegOp">;
def Mul : HLOInst<"MulOp">;
def Div : HLOInst<"DivOp">;
def Rem : HLOInst<"RemainderOp">;
def Pow : HLOInst<"PowOp">;
def Log : HLOInst<"LogOp">;
def Cos : HLOInst<"CosineOp">;
def Sin : HLOInst<"SineOp">;
def Sqrt : HLOInst<"SqrtOp">;
def Exp : HLOInst<"ExpOp">;


def CheckedMul : HLOInst<"MulOp">;
Expand Down Expand Up @@ -42,7 +48,41 @@ def : HLODerivative<"DivOp", (Op $x, $y),
// (CheckedDiv (FSub (SelectIfActive $x, (FMul (Shadow $x), $y), (Zero $x)), (SelectIfActive $y, (FMul (Shadow $y), $x), (Zero $y))), (FMul $y, $y))
>;

def : HLODerivative<"PowOp", (Op $x, $y),
[
(CheckedMul (DiffeRet), (Mul $y, (Pow $x, (Sub $y, (HLOConstantFP<"1"> $y))))),
(CheckedMul (DiffeRet), (Mul (Pow $x, $y), (Log $x)
))
]
>;

def : HLODerivative<"CosineOp", (Op $x),
[
(CheckedMul (DiffeRet), (Neg (Sin $x)))
]
>;
def : HLODerivative<"ExpOp", (Op $x),
[
(CheckedMul (DiffeRet), (Exp $x))
]
>;
def : HLODerivative<"SineOp", (Op $x),
[
(CheckedMul (DiffeRet), (Cos $x))
]
>;
def : HLODerivative<"SqrtOp", (Op $x),
[
// (Select (FCmpUEQ $x, (ConstantFP<"0"> $x)), (ConstantFP<"0"> $x), (FDiv (DiffeRet), (FMul (ConstantFP<"2"> $x), (Call<(SameFunc), [ReadNone,NoUnwind]> $x))))
(Div (DiffeRet), (Mul (HLOConstantFP<"2"> $x), (Sqrt $x)))
]
>;

def : HLOReadOnlyIdentityOp<"ReshapeOp">;
def : HLOReadOnlyIdentityOp<"SliceOp">;
def : HLOReadOnlyIdentityOp<"BroadcastInDimOp">;

def : HLOReadOnlyIdentityOp<"ConcatenateOp">;
// convert
// cos
// sin
// sqrt
2 changes: 2 additions & 0 deletions src/enzyme_ad/jax/Implementations/MHLODerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,6 @@ class HLOReadOnlyIdentityOp<string opName_, list<int> ptrargs_ = [0]> : ReadOnly

class HLOControlFlowOp<string opName_, string impl_> : ControlFlowOp<"mhlo", opName_, impl_>;

class HLOConstantFP<string m> : ConstantFP<m, "mhlo", "ConstantOp">;

include "HLODerivatives.td"
1 change: 1 addition & 0 deletions src/enzyme_ad/jax/Implementations/StableHLODerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@ class HLOReadOnlyIdentityOp<string opName_, list<int> ptrargs_ = [0]> : ReadOnly

class HLOControlFlowOp<string opName_, string impl_> : ControlFlowOp<"stablehlo", opName_, impl_>;

class HLOConstantFP<string m> : ConstantFP<m, "stablehlo", "ConstantOp">;

include "HLODerivatives.td"
16 changes: 16 additions & 0 deletions src/enzyme_ad/jax/Implementations/XLADerivatives.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectRegistry.h"
#include "llvm/ADT/APFloat.h"

namespace mlir {
namespace enzyme {
Expand All @@ -12,3 +15,16 @@ registerXLAAutoDiffInterfaces(mlir::DialectRegistry &registry) {
}
} // namespace enzyme
} // namespace mlir

static inline mlir::DenseFPElementsAttr getTensorAttr(mlir::Type type,
llvm::StringRef value) {
using namespace mlir;
auto T = cast<TensorType>(type);
size_t num = 1;
for (auto sz : T.getShape())
num *= sz;
APFloat apvalue(T.getElementType().cast<FloatType>().getFloatSemantics(),
value);
SmallVector<APFloat> supportedValues(num, apvalue);
return DenseFPElementsAttr::get(type.cast<ShapedType>(), supportedValues);
}
3 changes: 3 additions & 0 deletions src/enzyme_ad/jax/compile_with_xla.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,12 @@
#include "mlir/CAPI/IR.h"
#include "mlir/lib/Bindings/Python/IRModule.h"

#include "mlir/Dialect/Func/Extensions/InlinerExtension.h"

void prepareRegistry(mlir::DialectRegistry &registry) {
mlir::enzyme::registerCoreDialectAutodiffInterfaces(registry);
mlir::enzyme::registerXLAAutoDiffInterfaces(registry);
mlir::func::registerInlinerExtension(registry);
}

/// Returns an unused symbol in `module` for `oldSymbolName` by trying numeric
Expand Down
11 changes: 1 addition & 10 deletions src/enzyme_ad/jax/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,20 +495,12 @@ def _enzyme_primal_lowering(
continue
seen[in_idx_map[i]] = i
orig_shapes.append(shape)
print("orig_shapes", orig_shapes)
print("seen", seen)
avals = [ctx.avals_in[seen[i]] for i in seen]
avals_in = jax.tree_util.tree_unflatten(in_tree, avals)
print("avals_in", avals_in)
print("in_idx_map", in_idx_map)
print("in_shapes", in_shapes)


lowered_func = jax.jit(mfunc).lower(*avals_in)
mhlo = lowered_func.compiler_ir(dialect="stablehlo")
source = str(mhlo)
kept = lowered_func.compile()._executable._kept_var_idx
print("kept", kept)
in_args = tuple(arg for (i, arg) in enumerate(in_args) if in_idx_map[i] in kept)
if len(kept) != len(orig_shapes):
post = ",".join(["enzyme_dup"] * len(kept))
Expand All @@ -518,8 +510,6 @@ def _enzyme_primal_lowering(
in_shapes = [
shape for (i, shape) in enumerate(in_shapes) if in_idx_map[i] in kept
]
print("post in_shapes", in_shapes)

if pipeline_options.stablehlo_inject():
fn = enzyme_call.run_pass_pipeline(source, pass_pipeline)
print(fn)
Expand Down Expand Up @@ -845,6 +835,7 @@ def make_zero(tan, prim):
if pipeline_options.mlir_ad() and kwargs["lang"] == LANG_MHLO:
act_tup = ",".join(["enzyme_dup" for a in arg_primals])
newpasses = (
"inline{default-pipeline=canonicalize max-iterations=4}," +
"func.func(stablehlo-aggressive-simplification),cse,print,enzyme-wrap{infn=main outfn= retTy=enzyme_dup argTys="
+ act_tup
+ " mode=ForwardMode},"
Expand Down

0 comments on commit 8c74fa0

Please sign in to comment.