Skip to content

Commit

Permalink
Fix transpose conert on complex types (#188)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Dec 16, 2024
1 parent 561732e commit 30861b8
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3656,6 +3656,13 @@ struct TransposeTranspose
}
};

size_t getBitWidth(mlir::Type ty) {
if (auto CT = dyn_cast<ComplexType>(ty)) {
return 2 * getBitWidth(CT.getElementType());
}
return ty.getIntOrFloatBitWidth();
}

struct TransposeConvert : public OpRewritePattern<mlir::stablehlo::ConvertOp> {
using OpRewritePattern<mlir::stablehlo::ConvertOp>::OpRewritePattern;

Expand All @@ -3665,8 +3672,10 @@ struct TransposeConvert : public OpRewritePattern<mlir::stablehlo::ConvertOp> {
auto operandType = op.getOperand().getType().cast<TensorType>();
if (!resultType.hasStaticShape() || !operandType.hasStaticShape())
return failure();
if (resultType.getNumElements() * resultType.getElementTypeBitWidth() >=
operandType.getNumElements() * operandType.getElementTypeBitWidth())
if (resultType.getNumElements() *
getBitWidth(resultType.getElementType()) >=
operandType.getNumElements() *
getBitWidth(operandType.getElementType()))
return failure();

auto transpose =
Expand Down

0 comments on commit 30861b8

Please sign in to comment.