diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
index d49394118..5d55ea1cc 100644
--- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
+++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
@@ -3656,6 +3656,18 @@ struct TransposeTranspose
   }
 };
 
+// Returns the bit width of `ty`. A complex type is counted as twice the
+// width of its element type (computed recursively); every other type is
+// forwarded to getIntOrFloatBitWidth().
+// NOTE(review): presumably only integer, float, or complex types reach
+// this helper -- confirm against callers.
+size_t getBitWidth(mlir::Type ty) {
+  if (auto CT = dyn_cast<ComplexType>(ty)) {
+    return 2 * getBitWidth(CT.getElementType());
+  }
+  return ty.getIntOrFloatBitWidth();
+}
+
 struct TransposeConvert : public OpRewritePattern {
   using OpRewritePattern::OpRewritePattern;
 
@@ -3665,8 +3677,10 @@ struct TransposeConvert : public OpRewritePattern {
     auto operandType = op.getOperand().getType().cast();
     if (!resultType.hasStaticShape() || !operandType.hasStaticShape())
       return failure();
-    if (resultType.getNumElements() * resultType.getElementTypeBitWidth() >=
-        operandType.getNumElements() * operandType.getElementTypeBitWidth())
+    if (resultType.getNumElements() *
+            getBitWidth(resultType.getElementType()) >=
+        operandType.getNumElements() *
+            getBitWidth(operandType.getElementType()))
       return failure();
 
     auto transpose =