Skip to content

Commit

Permalink
Fix constprop on integers
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Mar 10, 2024
1 parent 514c9e9 commit 15ae277
Showing 1 changed file with 103 additions and 48 deletions.
151 changes: 103 additions & 48 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -857,18 +857,32 @@ struct AddSimplify : public OpRewritePattern<mlir::stablehlo::AddOp> {
for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
matchPattern(op->getOperand(i), m_Constant(&constants[i]));

if (auto res =
constFoldBinaryOpConditional<FloatAttr, FloatAttr::ValueType, void>(
constants,
[](const APFloat &a,
const APFloat &b) -> std::optional<APFloat> {
APFloat res2(a);
res2.add(b, llvm::RoundingMode::NearestTiesToEven);
return res2;
})) {
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
op, op.getType(), res.cast<ElementsAttr>());
return success();
if (op.getType().getElementType().isa<FloatType>()) {
if (auto res = constFoldBinaryOpConditional<FloatAttr,
FloatAttr::ValueType, void>(
constants,
[](const APFloat &a, const APFloat &b) -> std::optional<APFloat> {
APFloat res2(a);
res2.add(b, llvm::RoundingMode::NearestTiesToEven);
return res2;
})) {
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
op, op.getType(), res.cast<ElementsAttr>());
return success();
}
} else {
if (auto res = constFoldBinaryOpConditional<IntegerAttr,
IntegerAttr::ValueType, void>(
constants,
[](const APInt &a, const APInt &b) -> std::optional<APInt> {
APInt res2(a);
res2 += b;
return res2;
})) {
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
op, op.getType(), res.cast<ElementsAttr>());
return success();
}
}

return failure();
Expand Down Expand Up @@ -896,18 +910,32 @@ struct SubSimplify : public OpRewritePattern<mlir::stablehlo::SubtractOp> {
for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
matchPattern(op->getOperand(i), m_Constant(&constants[i]));

if (auto res =
constFoldBinaryOpConditional<FloatAttr, FloatAttr::ValueType, void>(
constants,
[](const APFloat &a,
const APFloat &b) -> std::optional<APFloat> {
APFloat res2(a);
res2.subtract(b, llvm::RoundingMode::NearestTiesToEven);
return res2;
})) {
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
op, op.getType(), res.cast<ElementsAttr>());
return success();
if (op.getType().getElementType().isa<FloatType>()) {
if (auto res = constFoldBinaryOpConditional<FloatAttr,
FloatAttr::ValueType, void>(
constants,
[](const APFloat &a, const APFloat &b) -> std::optional<APFloat> {
APFloat res2(a);
res2.subtract(b, llvm::RoundingMode::NearestTiesToEven);
return res2;
})) {
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
op, op.getType(), res.cast<ElementsAttr>());
return success();
}
} else {
if (auto res = constFoldBinaryOpConditional<IntegerAttr,
IntegerAttr::ValueType, void>(
constants,
[](const APInt &a, const APInt &b) -> std::optional<APInt> {
APInt res2(a);
res2 -= b;
return res2;
})) {
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
op, op.getType(), res.cast<ElementsAttr>());
return success();
}
}

return failure();
Expand Down Expand Up @@ -1026,18 +1054,32 @@ struct MulSimplify : public OpRewritePattern<mlir::stablehlo::MulOp> {
for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
matchPattern(op->getOperand(i), m_Constant(&constants[i]));

if (auto res =
constFoldBinaryOpConditional<FloatAttr, FloatAttr::ValueType, void>(
constants,
[](const APFloat &a,
const APFloat &b) -> std::optional<APFloat> {
APFloat res2(a);
res2.multiply(b, llvm::RoundingMode::NearestTiesToEven);
return res2;
})) {
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
op, op.getType(), res.cast<ElementsAttr>());
return success();
if (op.getType().getElementType().isa<FloatType>()) {
if (auto res = constFoldBinaryOpConditional<FloatAttr,
FloatAttr::ValueType, void>(
constants,
[](const APFloat &a, const APFloat &b) -> std::optional<APFloat> {
APFloat res2(a);
res2.multiply(b, llvm::RoundingMode::NearestTiesToEven);
return res2;
})) {
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
op, op.getType(), res.cast<ElementsAttr>());
return success();
}
} else {
if (auto res = constFoldBinaryOpConditional<IntegerAttr,
IntegerAttr::ValueType, void>(
constants,
[](const APInt &a, const APInt &b) -> std::optional<APInt> {
APInt res2(a);
res2 *= b;
return res2;
})) {
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
op, op.getType(), res.cast<ElementsAttr>());
return success();
}
}

return failure();
Expand Down Expand Up @@ -1067,18 +1109,31 @@ struct DivSimplify : public OpRewritePattern<mlir::stablehlo::DivOp> {
for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
matchPattern(op->getOperand(i), m_Constant(&constants[i]));

if (auto res =
constFoldBinaryOpConditional<FloatAttr, FloatAttr::ValueType, void>(
constants,
[](const APFloat &a,
const APFloat &b) -> std::optional<APFloat> {
APFloat res2(a);
res2.divide(b, llvm::RoundingMode::NearestTiesToEven);
return res2;
})) {
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
op, op.getType(), res.cast<ElementsAttr>());
return success();
if (op.getType().getElementType().isa<FloatType>()) {
if (auto res = constFoldBinaryOpConditional<FloatAttr,
FloatAttr::ValueType, void>(
constants,
[](const APFloat &a, const APFloat &b) -> std::optional<APFloat> {
APFloat res2(a);
res2.divide(b, llvm::RoundingMode::NearestTiesToEven);
return res2;
})) {
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
op, op.getType(), res.cast<ElementsAttr>());
return success();
}
} else {
if (auto res = constFoldBinaryOpConditional<IntegerAttr,
IntegerAttr::ValueType, void>(
constants,
[](const APInt &a, const APInt &b) -> std::optional<APInt> {
APInt res2(a);
return res2.sdiv(b);
})) {
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
op, op.getType(), res.cast<ElementsAttr>());
return success();
}
}

return failure();
Expand Down

0 comments on commit 15ae277

Please sign in to comment.