From f1707987c8091c370ee78540724612c5412461f1 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 10 Mar 2024 03:58:33 -0400 Subject: [PATCH] And/or simplification --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 56 ++++++++++++++++++++++- 1 file changed, 54 insertions(+), 2 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index e6db9bc26..c1c4c3748 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -940,6 +940,58 @@ struct NegateSimplify : public OpRewritePattern { } }; +struct AndSimplify : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::stablehlo::AndOp op, + PatternRewriter &rewriter) const final { + + // false & x -> x + for (auto v : op.getOperands()) { + if (matchPattern(v, m_Zero())) { + rewriter.replaceOp(op, v); + return success(); + } + } + + // true & x -> x + for (int i = 0; i < 2; i++) { + if (matchPattern(op.getOperand(i), m_One())) { + rewriter.replaceOp(op, op.getOperand(1 - i)); + return success(); + } + } + + return failure(); + } +}; + +struct OrSimplify : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::stablehlo::OrOp op, + PatternRewriter &rewriter) const final { + + // true | x -> x + for (auto v : op.getOperands()) { + if (matchPattern(v, m_One())) { + rewriter.replaceOp(op, v); + return success(); + } + } + + // false | x -> x + for (int i = 0; i < 2; i++) { + if (matchPattern(op.getOperand(i), m_Zero())) { + rewriter.replaceOp(op, op.getOperand(1 - i)); + return success(); + } + } + + return failure(); + } +}; + struct MulSimplify : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -1194,8 +1246,8 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase { /*ScatterToPad, */ BroadcastToReshape, ReduceToReshape, ReduceConcat, SliceConcat, SliceSimplification, CosSimplify, SinSimplify, SqrtSimplify, AddSimplify, SubSimplify, - NegateSimplify, MulSimplify, DivSimplify, PowSimplify, - BinBroadcastSplat, + AndSimplify, OrSimplify, NegateSimplify, MulSimplify, + DivSimplify, PowSimplify, BinBroadcastSplat, BinBroadcastSplat, BinBroadcastSplat, BinBroadcastSplat>(context);