Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Mar 8, 2024
1 parent 3aacc35 commit 38e6039
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 2 deletions.
46 changes: 45 additions & 1 deletion src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,50 @@ struct SliceSimplification final : OpRewritePattern<mlir::stablehlo::SliceOp> {
}
};

struct SliceConcat final : OpRewritePattern<mlir::stablehlo::SliceOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(mlir::stablehlo::SliceOp op,
PatternRewriter &rewriter) const override {
auto type = dyn_cast<RankedTensorType>(op.getType());
if (!type)
return failure();

auto concat = op.getOperand().getDefiningOp<stablehlo::ConcatenateOp>();
if (!concat) return failure();


auto dim = concat.getDimension();

if (op.getStrides()[dim] != 1) return failure();

SmallVector<Value> postConcat;
size_t curdim = 0;
for (auto v : concat.getInputs()) {
auto ty = v.getType().cast<RankedTensorType>();
auto nextdim = ty.getShape()[dim];
if (op.getStartIndices()[dim] < curdim) {
curdim += nextdim;
continue;
}
if (op.getLimitIndices()[dim] >= curdim) {
curdim += nextdim;
continue;
}
SmallVector<int64_t> nstart(op.getStartIndices().begin(), op.getStartIndices().end());
SmallVector<int64_t> nend(op.getStartIndices().begin(), op.getStartIndices().end());
nstart[dim] -= curdim;
if (nstart[dim] < 0) nstart[dim] = 0;
nend[dim] -= curdim;
if (nend[dim] > nextdim) nend[dim] = nextdim;
auto subslice = rewriter.create<stablehlo::SliceOp>(op.getLoc(), v, nstart, nend, op.getStrides());
postConcat.push_back(subslice);
}
rewriter.replaceOpWithNewOp<stablehlo::ConcatenateOp>(op, postConcat, dim);
return success();
}
};


DenseElementsAttr fromTensor(stablehlo::Tensor inp) {
auto type = inp.getType();
Expand Down Expand Up @@ -698,7 +742,7 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {
void runOnOperation() override {
auto context = getOperation()->getContext();
RewritePatternSet patterns(context);
patterns.add<AddPad, DotReshapeDot, ConcatConstProp, /*ScatterToPad, */BroadcastToReshape, SliceSimplification, CosSimplify, SinSimplify, SqrtSimplify, AddSimplify, SubSimplify, NegateSimplify, MulSimplify, DivSimplify, PowSimplify>(context);
patterns.add<AddPad, DotReshapeDot, ConcatConstProp, /*ScatterToPad, */BroadcastToReshape, SliceConcat, SliceSimplification, CosSimplify, SinSimplify, SqrtSimplify, AddSimplify, SubSimplify, NegateSimplify, MulSimplify, DivSimplify, PowSimplify>(context);
mlir::stablehlo::populateStablehloCanonicalizationPatterns(context,
&patterns);

Expand Down
2 changes: 1 addition & 1 deletion src/enzyme_ad/jax/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -1132,7 +1132,7 @@ def enzyme_vjp(shadow_rets, *prim_args, **kwargs):
newpasses = (
prev_passes
+ "print," + ad_pass
+ ",canonicalize, remove-unnecessary-enzyme-ops, enzyme-simplify-math, enzyme-hlo-opt, canonicalize, cse, print"
+ ",arith-raise{stablehlo=true},canonicalize, remove-unnecessary-enzyme-ops, enzyme-simplify-math, enzyme-hlo-opt, canonicalize, cse, print"
+ post_passes
)

Expand Down

0 comments on commit 38e6039

Please sign in to comment.