Skip to content

Commit

Permalink
feat: a - a => 0 if no_nan
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 31, 2024
1 parent da5e47a commit 945c13e
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 13 deletions.
42 changes: 30 additions & 12 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2679,6 +2679,22 @@ struct SubSimplify : public OpRewritePattern<mlir::stablehlo::SubtractOp> {
}
};

struct NoNanSelfSubSimplify
: public OpRewritePattern<mlir::stablehlo::SubtractOp> {
using OpRewritePattern<mlir::stablehlo::SubtractOp>::OpRewritePattern;

LogicalResult matchAndRewrite(mlir::stablehlo::SubtractOp op,
PatternRewriter &rewriter) const final {

if (isa<FloatType>(op.getType().getElementType()) &&
op.getLhs() == op.getRhs()) {
rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(
op, rewriter.getZeroAttr(op.getType()));
return success();
}
}
};

struct NegateSimplify : public OpRewritePattern<mlir::stablehlo::NegOp> {
using OpRewritePattern<mlir::stablehlo::NegOp>::OpRewritePattern;

Expand Down Expand Up @@ -6806,16 +6822,16 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {
auto context = getOperation()->getContext();
RewritePatternSet patterns(context);
patterns
.add<AddSimplify, ReplaceNegAddWithSubtract, SubSimplify, AndSimplify,
MaxSimplify, MinSimplify, OrSimplify, NegateSimplify, MulSimplify,
DivSimplify, RemSimplify, PowSimplify, SqrtSimplify, CosSimplify,
SinSimplify, NoopSlice, NoopReverse, SliceSlice, PadSimplify,
ShiftRightLogicalSimplify, NegativePadToSlice, TanhSimplify,
ExpSimplify, SliceSimplify, ConvertSimplify, TransposeSimplify,
DotGeneralSimplify, DynamicSliceToStatic, DynamicUpdateSliceElim,
ReduceToReshape, BroadcastToReshape, GatherSimplify,
ReshapeEmptyBroadcast, BroadcastReshape, ConstPropThroughBarrier>(
context, PatternBenefit(65000));
.add<AddSimplify, SubSimplify, AndSimplify, MaxSimplify, MinSimplify,
OrSimplify, NegateSimplify, MulSimplify, DivSimplify, RemSimplify,
PowSimplify, SqrtSimplify, CosSimplify, SinSimplify, NoopSlice,
NoopReverse, SliceSlice, PadSimplify, ShiftRightLogicalSimplify,
NegativePadToSlice, TanhSimplify, ExpSimplify, SliceSimplify,
ConvertSimplify, TransposeSimplify, DotGeneralSimplify,
DynamicSliceToStatic, DynamicUpdateSliceElim, ReduceToReshape,
BroadcastToReshape, GatherSimplify, ReshapeEmptyBroadcast,
BroadcastReshape, ConstPropThroughBarrier>(context,
PatternBenefit(65000));

patterns.add<IotaSimplify, BroadcastInDimSimplify>(
max_constant_expansion, context, PatternBenefit(65000));
Expand Down Expand Up @@ -6912,8 +6928,10 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {

if (all_finite)
patterns.add<AllFinite>(context);
if (no_nan || all_finite)
patterns.add<NoNan>(context);
if (no_nan || all_finite) {
patterns.add<NoNan, ReplaceNegAddWithSubtract, NoNanSelfSubSimplify>(
context);
}

patterns.add<CompareOpCanon, BroadcastInDimOpCanon, ConvertOpCanon,
DynamicBroadcastInDimOpNotActuallyDynamic,
Expand Down
4 changes: 4 additions & 0 deletions src/enzyme_ad/jax/TransformOps/TransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,10 @@ def ApplyNoNanPatterns : EnzymeHLOPatternOp<
"no_nan"> {
let patterns = ["NoNan"];
}
def ApplyNoNanSelfSubSimplify : EnzymeHLOPatternOp<
"no_nan_self_sub_simplify"> {
let patterns = ["NoNanSelfSubSimplify"];
}

def ApplyConcatPushBinopAddPatterns : EnzymeHLOPatternOp<
"concat_push_binop_add"> {
Expand Down
1 change: 0 additions & 1 deletion src/enzyme_ad/jax/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@ def hlo_opts():
concatenate_op_canon<16>(1024);
select_op_canon<16>(1024);
add_simplify<16>;
replace_neg_add_with_subtract<16>;
sub_simplify<16>;
and_simplify<16>;
max_simplify<16>;
Expand Down

0 comments on commit 945c13e

Please sign in to comment.