cleanup
wsmoses committed Mar 8, 2024
1 parent 1f7c32e commit ef65b9c
Showing 6 changed files with 383 additions and 296 deletions.
17 changes: 6 additions & 11 deletions WORKSPACE
@@ -60,21 +60,16 @@ load("@rules_python//python/pip_install:repositories.bzl", "pip_install_dependen

pip_install_dependencies()

-ENZYME_COMMIT = "97066352a40b3c66f9a1f41ec1802af255216c0c"
-ENZYME_SHA256 = ""
+ENZYME_COMMIT = "0a129ae7e45114a08f281e50632b9f967fae8396"
+ENZYME_SHA256 = "715982efd0a0ef8038e8ad35047e9c1941eb3f9cb038883342969b0bcc8915ad"

-local_repository(
+http_archive(
    name = "enzyme",
-    path = "../Enzyme/enzyme"
+    sha256 = ENZYME_SHA256,
+    strip_prefix = "Enzyme-" + ENZYME_COMMIT + "/enzyme",
+    urls = ["https://github.com/EnzymeAD/Enzyme/archive/{commit}.tar.gz".format(commit = ENZYME_COMMIT)],
)

-# http_archive(
-#     name = "enzyme",
-#     sha256 = ENZYME_SHA256,
-#     strip_prefix = "Enzyme-" + ENZYME_COMMIT + "/enzyme",
-#     urls = ["https://github.com/EnzymeAD/Enzyme/archive/{commit}.tar.gz".format(commit = ENZYME_COMMIT)],
-# )

JAX_COMMIT = "9a098e922aff62a3b49bd673b9518d97ee599248"
JAX_SHA256 = ""
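A side note on this change: the Enzyme dependency now comes from a pinned GitHub archive instead of a local checkout, so the sha256 above has to match the tarball for that exact commit. Below is a minimal sketch of one way such a checksum can be computed, assuming only the Python standard library; the helper name is made up for illustration and is not part of the repository's tooling (and JAX_SHA256 above is deliberately left empty in this commit).

import hashlib
import urllib.request

def github_archive_sha256(org: str, repo: str, commit: str) -> str:
    # Download the same archive URL the http_archive rule uses and hash it.
    url = f"https://github.com/{org}/{repo}/archive/{commit}.tar.gz"
    with urllib.request.urlopen(url) as resp:
        return hashlib.sha256(resp.read()).hexdigest()

# Example: the Enzyme commit pinned above.
print(github_archive_sha256("EnzymeAD", "Enzyme",
                            "0a129ae7e45114a08f281e50632b9f967fae8396"))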

2 changes: 2 additions & 0 deletions src/enzyme_ad/jax/BUILD
@@ -145,6 +145,7 @@ cc_library(
"@stablehlo//:reference_ops",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:FunctionInterfaces",
"@llvm-project//mlir:ControlFlowInterfaces",
@@ -225,6 +226,7 @@ pybind_library(
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:FuncExtensions",
"@llvm-project//mlir:TensorDialect",

"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
158 changes: 86 additions & 72 deletions src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp
@@ -192,8 +192,8 @@ class AutoDiffBroadcastInDimRev
AutoDiffBroadcastInDimRev, BroadcastInDimOp> {
public:
LogicalResult createReverseModeAdjoint(Operation *orig, OpBuilder &builder,
MGradientUtilsReverse *gutils,
SmallVector<Value> caches) const {
MGradientUtilsReverse *gutils,
SmallVector<Value> caches) const {
auto op = cast<BroadcastInDimOp>(orig);
auto inTy = op.getOperand().getType();
auto outTy = op.getType();
@@ -205,16 +205,16 @@ class AutoDiffBroadcastInDimRev

SmallVector<int64_t> newDims;
for (auto en : llvm::enumerate(outTy.getShape())) {
if (llvm::is_contained(bcastDims, en.index())) continue;
if (llvm::is_contained(bcastDims, en.index()))
continue;
newDims.push_back(en.index());
}

Value zero = gutils->getShadowType(inTy)
.cast<AutoDiffTypeInterface>()
.createNullValue(builder, op.getLoc());
.cast<AutoDiffTypeInterface>()
.createNullValue(builder, op.getLoc());

auto red = builder.create<ReduceOp>(op.getLoc(),
TypeRange(zero.getType()),
auto red = builder.create<ReduceOp>(op.getLoc(), TypeRange(zero.getType()),
inDiffe, zero, newDims);
red.getBody().push_back(new Block());
Block &body = red.getBody().front();
Expand All @@ -228,9 +228,9 @@ class AutoDiffBroadcastInDimRev
bodyBuilder.create<ReturnOp>(op.getLoc(), ValueRange(add));

Value res = red->getResult(0);
Type resTy = gutils->getShadowType(op.getOperand().getType());
Type resTy = gutils->getShadowType(op.getOperand().getType());
if (res.getType() != resTy)
res = builder.create<ReshapeOp>(op.getLoc(), resTy, res);
res = builder.create<ReshapeOp>(op.getLoc(), resTy, res);

gutils->addToDiffe(op.getOperand(), res, builder);
return success();
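For readers following the BroadcastInDimOp rule above: it add-reduces the incoming cotangent over every result dimension that did not come from the operand (newDims) and reshapes back to the operand's shadow type. The following is a minimal sketch of the same adjoint relation in JAX rather than the StableHLO builder API used here; the shape is illustrative, with dimension 0 being the newly broadcast one.

import jax
import jax.numpy as jnp

x = jnp.arange(3.0)                        # operand of shape (3,)
f = lambda v: jnp.broadcast_to(v, (4, 3))  # broadcast; result dim 0 is new
ct = jnp.ones((4, 3))                      # incoming cotangent

_, vjp = jax.vjp(f, x)
(adj,) = vjp(ct)

# The rule above does the equivalent: sum the cotangent over the broadcast
# dimensions, then reshape to the operand type.
manual = jnp.sum(ct, axis=0).reshape(x.shape)
assert jnp.allclose(adj, manual)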
@@ -250,8 +250,8 @@ class AutoDiffSliceRev
SliceOp> {
public:
LogicalResult createReverseModeAdjoint(Operation *orig, OpBuilder &builder,
MGradientUtilsReverse *gutils,
SmallVector<Value> caches) const {
MGradientUtilsReverse *gutils,
SmallVector<Value> caches) const {
auto op = cast<SliceOp>(orig);
auto inTy = op.getOperand().getType();
auto outTy = op.getType();
@@ -263,21 +263,25 @@ class AutoDiffSliceRev
SmallVector<int64_t> starts;
SmallVector<int64_t> edge_padding_high;
SmallVector<int64_t> interior_padding;
for (auto &&[start, limit, stride, dim] : llvm::zip(
op.getStartIndices(), op.getLimitIndices(), op.getStrides(), inTy.getShape())) {
for (auto &&[start, limit, stride, dim] :
llvm::zip(op.getStartIndices(), op.getLimitIndices(), op.getStrides(),
inTy.getShape())) {
starts.push_back(start);
edge_padding_high.push_back(dim - limit);
interior_padding.push_back(stride - 1);
}


auto zeroPad = RankedTensorType::get({}, inTy.getElementType()).cast<AutoDiffTypeInterface>().createNullValue(builder,
op.getLoc());
auto red = builder.create<stablehlo::PadOp>(op.getLoc(), inDiffe, zeroPad, builder.getDenseI64ArrayAttr(starts), builder.getDenseI64ArrayAttr(edge_padding_high), builder.getDenseI64ArrayAttr(interior_padding));
auto zeroPad = RankedTensorType::get({}, inTy.getElementType())
.cast<AutoDiffTypeInterface>()
.createNullValue(builder, op.getLoc());
auto red = builder.create<stablehlo::PadOp>(
op.getLoc(), inDiffe, zeroPad, builder.getDenseI64ArrayAttr(starts),
builder.getDenseI64ArrayAttr(edge_padding_high),
builder.getDenseI64ArrayAttr(interior_padding));

gutils->addToDiffe(op.getOperand(), red->getResult(0), builder);
return success();
#if 0
#if 0

Value idxs;
{
@@ -351,7 +355,7 @@ class AutoDiffSliceRev
// gutils->setDiffe(op.getOperand(), red->getResult(0), builder);

return success();
#endif
#endif
}
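The PadOp construction above encodes the usual slice adjoint: the cotangent of the slice is scattered back into the operand shape by padding it with zeros, using the start indices as low padding, dim - limit as high padding, and stride - 1 as interior padding. Here is a small JAX sketch of that relation, independent of this pass; the concrete sizes are illustrative and chosen so the padded shape matches the operand.

import jax
import jax.numpy as jnp

x = jnp.arange(10.0)
f = lambda v: jax.lax.slice(v, start_indices=(2,), limit_indices=(9,), strides=(2,))
ct = jnp.ones(f(x).shape)                  # cotangent of the slice result

_, vjp = jax.vjp(f, x)
(adj,) = vjp(ct)

# low = start, high = dim - limit, interior = stride - 1, padded with zero.
manual = jax.lax.pad(ct, jnp.array(0.0, ct.dtype), [(2, 10 - 9, 2 - 1)])
assert jnp.allclose(adj, manual)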

SmallVector<Value> cacheValues(Operation *orig,
@@ -368,26 +372,27 @@ class AutoDiffReduceRev
ReduceOp> {
public:
LogicalResult createReverseModeAdjoint(Operation *orig, OpBuilder &builder,
MGradientUtilsReverse *gutils,
SmallVector<Value> caches) const {
MGradientUtilsReverse *gutils,
SmallVector<Value> caches) const {
auto op = cast<ReduceOp>(orig);
if (!isEligibleForCompactPrint(op)) {
orig->emitError() << "Unsupported operation in reduction rev autodiff(1): "
<< *orig << "\n";
orig->emitError()
<< "Unsupported operation in reduction rev autodiff(1): " << *orig
<< "\n";
return failure();
}

Operation &innerOp = op.getBody().front().front();

auto inTy = op->getOperand(0).getType().cast<RankedTensorType>();
auto zero = inTy.cast<AutoDiffTypeInterface>().createNullValue(builder,
op.getLoc());
auto inDiffe = gutils->diffe(op->getResult(0), builder);
gutils->zeroDiffe(op->getResult(0), builder);
SmallVector<int64_t> toBroadcast;
{
size_t idx=0;

SmallVector<int64_t> toBroadcast;
{
size_t idx = 0;
for (auto en : llvm::enumerate(inTy.getShape())) {
if (llvm::is_contained(op.getDimensions(), en.index())) {
// reduced op
@@ -396,56 +401,62 @@ class AutoDiffReduceRev
toBroadcast.push_back(idx);
idx++;
}
}
}

if (isa<AddOp>(innerOp)) {
if (!gutils->isConstantValue(op.getInputs()[0])) {
if (!gutils->isConstantValue(op.getInputs()[0])) {
Value bcast;


bcast = builder.create<BroadcastInDimOp>(op.getLoc(), gutils->getShadowType(inTy), inDiffe, builder.getDenseI64ArrayAttr(toBroadcast));
bcast = builder.create<BroadcastInDimOp>(
op.getLoc(), gutils->getShadowType(inTy), inDiffe,
builder.getDenseI64ArrayAttr(toBroadcast));

gutils->addToDiffe(op.getInputs()[0], bcast, builder);
}
if (!gutils->isConstantValue(op.getInitValues()[0])) {
}
if (!gutils->isConstantValue(op.getInitValues()[0])) {
gutils->addToDiffe(op.getInitValues()[0], inDiffe, builder);
}
return success();
}
return success();
}

if (isa<MaxOp>(innerOp) || isa<MinOp>(innerOp)) {
// TODO: technically we should invert the order here to pick the last value (or divide by count) if multiple are the same as the
// result
// TODO: technically we should invert the order here to pick the last
// value (or divide by count) if multiple are the same as the result
auto ores = gutils->getNewFromOriginal(op->getResult(0));

if (!gutils->isConstantValue(op.getInputs()[0])) {
auto oprev = gutils->getNewFromOriginal(op.getInputs()[0]);
auto attr = builder.getDenseI64ArrayAttr(toBroadcast);
auto bc = builder.create<BroadcastInDimOp>(op.getLoc(), oprev.getType(), ores, attr);
auto bc = builder.create<BroadcastInDimOp>(op.getLoc(), oprev.getType(),
ores, attr);

auto cmp = builder.create<CompareOp>(op.getLoc(), bc, oprev, ComparisonDirection::EQ);
auto cmp = builder.create<CompareOp>(op.getLoc(), bc, oprev,
ComparisonDirection::EQ);

auto bc2 = builder.create<BroadcastInDimOp>(op.getLoc(), oprev.getType(), inDiffe, attr);
auto bc2 = builder.create<BroadcastInDimOp>(
op.getLoc(), oprev.getType(), inDiffe, attr);

auto res = builder.create<SelectOp>(op.getLoc(), cmp, bc2, zero);
gutils->addToDiffe(op.getInputs()[0], res, builder);
}
if (!gutils->isConstantValue(op.getInitValues()[0])) {
auto oprev = gutils->getNewFromOriginal(op.getInitValues()[0]);

auto zeroI = inDiffe.getType().cast<AutoDiffTypeInterface>().createNullValue(builder,
op.getLoc());
auto zeroI =
inDiffe.getType().cast<AutoDiffTypeInterface>().createNullValue(
builder, op.getLoc());

auto cmp = builder.create<CompareOp>(op.getLoc(), ores, oprev, ComparisonDirection::EQ);
auto cmp = builder.create<CompareOp>(op.getLoc(), ores, oprev,
ComparisonDirection::EQ);

auto res = builder.create<SelectOp>(op.getLoc(), cmp, inDiffe, zeroI);
gutils->addToDiffe(op.getInitValues()[0], res, builder);
}
return success();
}

orig->emitError() << "Unsupported operation in reduction rev autodiff(1): "
<< *orig << "\n";
<< *orig << "\n";
return failure();
}
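The ReduceOp rule above handles two reduction bodies: for an add-reduction the cotangent is simply broadcast back over the reduced dimensions, and for max/min it is routed only to positions equal to the reduced result via the compare-and-select pair (the init value is treated the same way). Below is a short JAX sketch of both relations on a tie-free example, so the tie-breaking question raised in the TODO comment above does not arise; shapes are illustrative.

import jax
import jax.numpy as jnp

x = jnp.array([[1.0, 5.0, 2.0],
               [4.0, 0.0, 3.0]])
ct = jnp.array([1.0, 10.0])                # cotangent of the reduced result

# Add-reduction: adjoint = cotangent broadcast over the reduced axis.
_, vjp_sum = jax.vjp(lambda v: jnp.sum(v, axis=1), x)
(adj_sum,) = vjp_sum(ct)
assert jnp.allclose(adj_sum, jnp.broadcast_to(ct[:, None], x.shape))

# Max-reduction: adjoint = cotangent masked to the positions equal to the max,
# mirroring the broadcast + compare + select sequence above.
res = jnp.max(x, axis=1)
_, vjp_max = jax.vjp(lambda v: jnp.max(v, axis=1), x)
(adj_max,) = vjp_max(ct)
assert jnp.allclose(adj_max, jnp.where(x == res[:, None], ct[:, None], 0.0))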

@@ -463,40 +474,43 @@ class AutoDiffConcatenateRev
ConcatenateOp> {
public:
LogicalResult createReverseModeAdjoint(Operation *orig, OpBuilder &builder,
MGradientUtilsReverse *gutils,
SmallVector<Value> caches) const {
MGradientUtilsReverse *gutils,
SmallVector<Value> caches) const {
auto op = cast<ConcatenateOp>(orig);

auto inDiffe = gutils->diffe(op->getResult(0), builder);
gutils->zeroDiffe(op->getResult(0), builder);

auto dim = op.getDimension();
size_t startDim = 0;
for (auto &ope : op->getOpOperands()) {
auto op = ope.get();
auto inTy = gutils->getShadowType(op.getType());
SmallVector<int64_t> start;
SmallVector<int64_t> limit;
SmallVector<int64_t> strides;
SmallVector<int64_t> tys;
auto RT = inTy.cast<RankedTensorType>();
for (auto i=0; i<RT.getShape().size(); i++) {
tys.push_back(RT.getShape()[i]);
if (i == dim) {
start.push_back(startDim);
limit.push_back(startDim + RT.getShape()[i]);
startDim += RT.getShape()[i];
strides.push_back(1);
continue;
}
start.push_back(0);
limit.push_back(RT.getShape()[i]);
strides.push_back(1);
for (auto &ope : op->getOpOperands()) {
auto op = ope.get();
auto inTy = gutils->getShadowType(op.getType());
SmallVector<int64_t> start;
SmallVector<int64_t> limit;
SmallVector<int64_t> strides;
SmallVector<int64_t> tys;
auto RT = inTy.cast<RankedTensorType>();
for (auto i = 0; i < RT.getShape().size(); i++) {
tys.push_back(RT.getShape()[i]);
if (i == dim) {
start.push_back(startDim);
limit.push_back(startDim + RT.getShape()[i]);
startDim += RT.getShape()[i];
strides.push_back(1);
continue;
}
if (gutils->isConstantValue(op)) continue;
auto res = builder.create<SliceOp>(op.getLoc(), RankedTensorType::get(tys, RT.getElementType()), inDiffe, start, limit, strides);
auto res2 = builder.create<ReshapeOp>(op.getLoc(), inTy, res);
gutils->addToDiffe(op, res2, builder);
start.push_back(0);
limit.push_back(RT.getShape()[i]);
strides.push_back(1);
}
if (gutils->isConstantValue(op))
continue;
auto res = builder.create<SliceOp>(
op.getLoc(), RankedTensorType::get(tys, RT.getElementType()), inDiffe,
start, limit, strides);
auto res2 = builder.create<ReshapeOp>(op.getLoc(), inTy, res);
gutils->addToDiffe(op, res2, builder);
}
return success();
}
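The ConcatenateOp rule above walks the operands in order and hands each one the slice of the result cotangent covering [startDim, startDim + size) along the concatenation dimension, reshaped to that operand's shadow type. A compact JAX sketch of the same bookkeeping, independent of this pass and with illustrative shapes:

import jax
import jax.numpy as jnp

a = jnp.ones((2, 3))
b = jnp.ones((2, 5))
ct = jnp.arange(16.0).reshape(2, 8)        # cotangent of the concatenated result

_, vjp = jax.vjp(lambda u, v: jnp.concatenate([u, v], axis=1), a, b)
adj_a, adj_b = vjp(ct)

# Each operand receives the contiguous slice of the cotangent it produced.
assert jnp.allclose(adj_a, ct[:, 0:3])
assert jnp.allclose(adj_b, ct[:, 3:8])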