Skip to content

Commit

Permalink
sliceslice
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Mar 8, 2024
1 parent c2d3c1f commit 8d7987e
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
30 changes: 29 additions & 1 deletion src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,34 @@ struct SliceSimplification final : OpRewritePattern<mlir::stablehlo::SliceOp> {
}
};

struct SliceSlice final : OpRewritePattern<mlir::stablehlo::SliceOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(mlir::stablehlo::SliceOp op,
PatternRewriter &rewriter) const override {
auto type = dyn_cast<RankedTensorType>(op.getType());
if (!type)
return failure();

auto prev = op.getOperand().getDefiningOp<stablehlo::SliceOp>();
if (!prev) return failure();

SmallVector<int64_t> start;
SmallVector<int64_t> end;
SmallVector<int64_t> step;

for (auto && [pstart, pend, pstep, nstart, nend, nstep] : llvm::zip(prev.getStartIndices(), prev.getLimitIndices(), prev.getStrides(),
op.getStartIndices(), op.getLimitIndices(), op.getStrides()
)) {
start.push_back(pstart + pstep * nstart);
step.push_back(pstep * nstep);
end.push_back(pstart + pstep * nstep * (nend - nstart));
}
rewriter.replaceOpWithNewOp<stablehlo::SliceOp>(op, prev.getOperand(), start, end, step);
return failure();
}
};

struct SliceConcat final : OpRewritePattern<mlir::stablehlo::SliceOp> {
using OpRewritePattern::OpRewritePattern;

Expand Down Expand Up @@ -766,7 +794,7 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {
void runOnOperation() override {
auto context = getOperation()->getContext();
RewritePatternSet patterns(context);
patterns.add<AddPad, DotReshapeDot, ConcatConstProp, /*ScatterToPad, */BroadcastToReshape, SliceConcat, SliceSimplification, CosSimplify, SinSimplify, SqrtSimplify, AddSimplify, SubSimplify, NegateSimplify, MulSimplify, DivSimplify, PowSimplify>(context);
patterns.add<SliceSlice, AddPad, DotReshapeDot, ConcatConstProp, /*ScatterToPad, */BroadcastToReshape, SliceConcat, SliceSimplification, CosSimplify, SinSimplify, SqrtSimplify, AddSimplify, SubSimplify, NegateSimplify, MulSimplify, DivSimplify, PowSimplify>(context);
mlir::stablehlo::populateStablehloCanonicalizationPatterns(context,
&patterns);

Expand Down
2 changes: 1 addition & 1 deletion test/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ def erev(x, weights, kc, vc, dx, dkc, dvc):
print("Jax rev", jres)

jrev2 = enzyme_jax.enzyme_jax_ir(argv=argv, pipeline_options=enzyme_jax.JaXPipeline("inline{default-pipeline=canonicalize max-iterations=4},"
+ "canonicalize,cse,print,enzyme-hlo-opt,cse,print"))(jrev)
+ "canonicalize,cse,enzyme-hlo-opt,cse"))(jrev)

jres2 = jrev2(x, weights, key_cache, value_cache, dx, dkc, dvc)
print("Jax2 rev", jres2)
Expand Down

0 comments on commit 8d7987e

Please sign in to comment.