Skip to content

Commit

Permalink
feat: simplify binary transpose with one constant operand
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jan 1, 2025
1 parent af99cdf commit 721156d
Showing 1 changed file with 36 additions and 3 deletions.
39 changes: 36 additions & 3 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2550,12 +2550,45 @@ LogicalResult simplifyBinaryOpWithTranspose(OpType op,
auto rhsOp = op.getRhs().template getDefiningOp<stablehlo::TransposeOp>();
if ((lhsOp && rhsOp) && (lhsOp.getPermutation() == rhsOp.getPermutation()) &&
lhsOp->hasOneUse() && rhsOp->hasOneUse()) {
auto newOp = rewriter.create<OpType>(
op.getLoc(), op.getType(), lhsOp.getOperand(), rhsOp.getOperand());
rewriter.replaceOpWithNewOp<stablehlo::TransposeOp>(op, op.getType(), newOp,
auto newOp = rewriter.create<OpType>(op.getLoc(), lhsOp.getOperand(),
rhsOp.getOperand());
rewriter.replaceOpWithNewOp<stablehlo::TransposeOp>(op, newOp,
lhsOp.getPermutation());
return success();
}

if (lhsOp && lhsOp->hasOneUse()) {
auto rhsConstOp =
op.getRhs().template getDefiningOp<stablehlo::ConstantOp>();
if (rhsConstOp && rhsConstOp->hasOneUse()) {
// This will be eliminated by a transpose(constant) -> constant
// optimization
auto transposedConstOp = rewriter.create<stablehlo::TransposeOp>(
rhsConstOp.getLoc(), rhsConstOp, lhsOp.getPermutation());
auto newOp = rewriter.create<OpType>(op.getLoc(), lhsOp.getOperand(),
transposedConstOp);
rewriter.replaceOpWithNewOp<stablehlo::TransposeOp>(
op, newOp, lhsOp.getPermutation());
return success();
}
}

if (rhsOp && rhsOp->hasOneUse()) {
auto lhsConstOp =
op.getLhs().template getDefiningOp<stablehlo::ConstantOp>();
if (lhsConstOp && lhsConstOp->hasOneUse()) {
// This will be eliminated by a transpose(constant) -> constant
// optimization
auto transposedConstOp = rewriter.create<stablehlo::TransposeOp>(
lhsConstOp.getLoc(), lhsConstOp, rhsOp.getPermutation());
auto newOp = rewriter.create<OpType>(op.getLoc(), transposedConstOp,
rhsOp.getOperand());
rewriter.replaceOpWithNewOp<stablehlo::TransposeOp>(
op, newOp, rhsOp.getPermutation());
return success();
}
}

return failure();
}

Expand Down

0 comments on commit 721156d

Please sign in to comment.