From 15ae277e711fb97340a17395e8b4c807130a9c8b Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 10 Mar 2024 04:44:09 -0400 Subject: [PATCH] Fix constprop on integers --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 151 +++++++++++++++------- 1 file changed, 103 insertions(+), 48 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index c1c4c3748..9e11fc5f2 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -857,18 +857,32 @@ struct AddSimplify : public OpRewritePattern { for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) matchPattern(op->getOperand(i), m_Constant(&constants[i])); - if (auto res = - constFoldBinaryOpConditional( - constants, - [](const APFloat &a, - const APFloat &b) -> std::optional { - APFloat res2(a); - res2.add(b, llvm::RoundingMode::NearestTiesToEven); - return res2; - })) { - rewriter.replaceOpWithNewOp( - op, op.getType(), res.cast()); - return success(); + if (op.getType().getElementType().isa()) { + if (auto res = constFoldBinaryOpConditional( + constants, + [](const APFloat &a, const APFloat &b) -> std::optional { + APFloat res2(a); + res2.add(b, llvm::RoundingMode::NearestTiesToEven); + return res2; + })) { + rewriter.replaceOpWithNewOp( + op, op.getType(), res.cast()); + return success(); + } + } else { + if (auto res = constFoldBinaryOpConditional( + constants, + [](const APInt &a, const APInt &b) -> std::optional { + APInt res2(a); + res2 += b; + return res2; + })) { + rewriter.replaceOpWithNewOp( + op, op.getType(), res.cast()); + return success(); + } } return failure(); @@ -896,18 +910,32 @@ struct SubSimplify : public OpRewritePattern { for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) matchPattern(op->getOperand(i), m_Constant(&constants[i])); - if (auto res = - constFoldBinaryOpConditional( - constants, - [](const APFloat &a, - const APFloat &b) -> std::optional { - APFloat res2(a); - res2.subtract(b, llvm::RoundingMode::NearestTiesToEven); - return res2; - })) { - rewriter.replaceOpWithNewOp( - op, op.getType(), res.cast()); - return success(); + if (op.getType().getElementType().isa()) { + if (auto res = constFoldBinaryOpConditional( + constants, + [](const APFloat &a, const APFloat &b) -> std::optional { + APFloat res2(a); + res2.subtract(b, llvm::RoundingMode::NearestTiesToEven); + return res2; + })) { + rewriter.replaceOpWithNewOp( + op, op.getType(), res.cast()); + return success(); + } + } else { + if (auto res = constFoldBinaryOpConditional( + constants, + [](const APInt &a, const APInt &b) -> std::optional { + APInt res2(a); + res2 -= b; + return res2; + })) { + rewriter.replaceOpWithNewOp( + op, op.getType(), res.cast()); + return success(); + } } return failure(); @@ -1026,18 +1054,32 @@ struct MulSimplify : public OpRewritePattern { for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) matchPattern(op->getOperand(i), m_Constant(&constants[i])); - if (auto res = - constFoldBinaryOpConditional( - constants, - [](const APFloat &a, - const APFloat &b) -> std::optional { - APFloat res2(a); - res2.multiply(b, llvm::RoundingMode::NearestTiesToEven); - return res2; - })) { - rewriter.replaceOpWithNewOp( - op, op.getType(), res.cast()); - return success(); + if (op.getType().getElementType().isa()) { + if (auto res = constFoldBinaryOpConditional( + constants, + [](const APFloat &a, const APFloat &b) -> std::optional { + APFloat res2(a); + res2.multiply(b, llvm::RoundingMode::NearestTiesToEven); + return res2; + })) { + rewriter.replaceOpWithNewOp( + op, op.getType(), res.cast()); + return success(); + } + } else { + if (auto res = constFoldBinaryOpConditional( + constants, + [](const APInt &a, const APInt &b) -> std::optional { + APInt res2(a); + res2 *= b; + return res2; + })) { + rewriter.replaceOpWithNewOp( + op, op.getType(), res.cast()); + return success(); + } } return failure(); @@ -1067,18 +1109,31 @@ struct DivSimplify : public OpRewritePattern { for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) matchPattern(op->getOperand(i), m_Constant(&constants[i])); - if (auto res = - constFoldBinaryOpConditional( - constants, - [](const APFloat &a, - const APFloat &b) -> std::optional { - APFloat res2(a); - res2.divide(b, llvm::RoundingMode::NearestTiesToEven); - return res2; - })) { - rewriter.replaceOpWithNewOp( - op, op.getType(), res.cast()); - return success(); + if (op.getType().getElementType().isa()) { + if (auto res = constFoldBinaryOpConditional( + constants, + [](const APFloat &a, const APFloat &b) -> std::optional { + APFloat res2(a); + res2.divide(b, llvm::RoundingMode::NearestTiesToEven); + return res2; + })) { + rewriter.replaceOpWithNewOp( + op, op.getType(), res.cast()); + return success(); + } + } else { + if (auto res = constFoldBinaryOpConditional( + constants, + [](const APInt &a, const APInt &b) -> std::optional { + APInt res2(a); + return res2.sdiv(b); + })) { + rewriter.replaceOpWithNewOp( + op, op.getType(), res.cast()); + return success(); + } } return failure();