From 30861b8b5e46a2338fe42c5ceecf7dd6884b93ac Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 16 Dec 2024 09:32:45 -0600 Subject: [PATCH] Fix transpose conert on complex types (#188) --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index d49394118..5d55ea1cc 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -3656,6 +3656,13 @@ struct TransposeTranspose } }; +size_t getBitWidth(mlir::Type ty) { + if (auto CT = dyn_cast(ty)) { + return 2 * getBitWidth(CT.getElementType()); + } + return ty.getIntOrFloatBitWidth(); +} + struct TransposeConvert : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -3665,8 +3672,10 @@ struct TransposeConvert : public OpRewritePattern { auto operandType = op.getOperand().getType().cast(); if (!resultType.hasStaticShape() || !operandType.hasStaticShape()) return failure(); - if (resultType.getNumElements() * resultType.getElementTypeBitWidth() >= - operandType.getNumElements() * operandType.getElementTypeBitWidth()) + if (resultType.getNumElements() * + getBitWidth(resultType.getElementType()) >= + operandType.getNumElements() * + getBitWidth(operandType.getElementType())) return failure(); auto transpose =