From 721156defae49656aa07cf683627d2e1f4c0e086 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 1 Jan 2025 13:30:36 -0500 Subject: [PATCH] feat: simplify binary transpose with one constant operand --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 39 +++++++++++++++++++++-- 1 file changed, 36 insertions(+), 3 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index 5804d09d..97eb2b7c 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -2550,12 +2550,45 @@ LogicalResult simplifyBinaryOpWithTranspose(OpType op, auto rhsOp = op.getRhs().template getDefiningOp(); if ((lhsOp && rhsOp) && (lhsOp.getPermutation() == rhsOp.getPermutation()) && lhsOp->hasOneUse() && rhsOp->hasOneUse()) { - auto newOp = rewriter.create( - op.getLoc(), op.getType(), lhsOp.getOperand(), rhsOp.getOperand()); - rewriter.replaceOpWithNewOp(op, op.getType(), newOp, + auto newOp = rewriter.create(op.getLoc(), lhsOp.getOperand(), + rhsOp.getOperand()); + rewriter.replaceOpWithNewOp(op, newOp, lhsOp.getPermutation()); return success(); } + + if (lhsOp && lhsOp->hasOneUse()) { + auto rhsConstOp = + op.getRhs().template getDefiningOp(); + if (rhsConstOp && rhsConstOp->hasOneUse()) { + // This will be eliminated by a transpose(constant) -> constant + // optimization + auto transposedConstOp = rewriter.create( + rhsConstOp.getLoc(), rhsConstOp, lhsOp.getPermutation()); + auto newOp = rewriter.create(op.getLoc(), lhsOp.getOperand(), + transposedConstOp); + rewriter.replaceOpWithNewOp( + op, newOp, lhsOp.getPermutation()); + return success(); + } + } + + if (rhsOp && rhsOp->hasOneUse()) { + auto lhsConstOp = + op.getLhs().template getDefiningOp(); + if (lhsConstOp && lhsConstOp->hasOneUse()) { + // This will be eliminated by a transpose(constant) -> constant + // optimization + auto transposedConstOp = rewriter.create( + lhsConstOp.getLoc(), lhsConstOp, rhsOp.getPermutation()); + auto newOp = rewriter.create(op.getLoc(), transposedConstOp, + rhsOp.getOperand()); + rewriter.replaceOpWithNewOp( + op, newOp, rhsOp.getPermutation()); + return success(); + } + } + return failure(); }