Skip to content

Commit

Permalink
And/or simplification
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Mar 10, 2024
1 parent 14724d5 commit f170798
Showing 1 changed file with 54 additions and 2 deletions.
56 changes: 54 additions & 2 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -940,6 +940,58 @@ struct NegateSimplify : public OpRewritePattern<mlir::stablehlo::NegOp> {
}
};

struct AndSimplify : public OpRewritePattern<mlir::stablehlo::AndOp> {
using OpRewritePattern<mlir::stablehlo::AndOp>::OpRewritePattern;

LogicalResult matchAndRewrite(mlir::stablehlo::AndOp op,
PatternRewriter &rewriter) const final {

// false & x -> x
for (auto v : op.getOperands()) {
if (matchPattern(v, m_Zero())) {
rewriter.replaceOp(op, v);
return success();
}
}

// true & x -> x
for (int i = 0; i < 2; i++) {
if (matchPattern(op.getOperand(i), m_One())) {
rewriter.replaceOp(op, op.getOperand(1 - i));
return success();
}
}

return failure();
}
};

struct OrSimplify : public OpRewritePattern<mlir::stablehlo::OrOp> {
using OpRewritePattern<mlir::stablehlo::OrOp>::OpRewritePattern;

LogicalResult matchAndRewrite(mlir::stablehlo::OrOp op,
PatternRewriter &rewriter) const final {

// true | x -> x
for (auto v : op.getOperands()) {
if (matchPattern(v, m_One())) {
rewriter.replaceOp(op, v);
return success();
}
}

// false | x -> x
for (int i = 0; i < 2; i++) {
if (matchPattern(op.getOperand(i), m_Zero())) {
rewriter.replaceOp(op, op.getOperand(1 - i));
return success();
}
}

return failure();
}
};

struct MulSimplify : public OpRewritePattern<mlir::stablehlo::MulOp> {
using OpRewritePattern<mlir::stablehlo::MulOp>::OpRewritePattern;

Expand Down Expand Up @@ -1194,8 +1246,8 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {
/*ScatterToPad, */ BroadcastToReshape, ReduceToReshape,
ReduceConcat, SliceConcat, SliceSimplification, CosSimplify,
SinSimplify, SqrtSimplify, AddSimplify, SubSimplify,
NegateSimplify, MulSimplify, DivSimplify, PowSimplify,
BinBroadcastSplat<stablehlo::AddOp>,
AndSimplify, OrSimplify, NegateSimplify, MulSimplify,
DivSimplify, PowSimplify, BinBroadcastSplat<stablehlo::AddOp>,
BinBroadcastSplat<stablehlo::SubtractOp>,
BinBroadcastSplat<stablehlo::DivOp>,
BinBroadcastSplat<stablehlo::MulOp>>(context);
Expand Down

0 comments on commit f170798

Please sign in to comment.