diff --git a/src/enzyme_ad/jax/Implementations/HLODerivatives.td b/src/enzyme_ad/jax/Implementations/HLODerivatives.td index d48ab47cf..da8ac71ed 100644 --- a/src/enzyme_ad/jax/Implementations/HLODerivatives.td +++ b/src/enzyme_ad/jax/Implementations/HLODerivatives.td @@ -896,9 +896,43 @@ def FftMultiplier : GlobalExpr(op.getLoc(), builder.getDenseI64ArrayAttr(ArrayRef({N}))); - auto ret_broadcast = builder.create(op.getLoc(), op_type.clone(op_type.getShape(), builder.getI64Type()), ret_constant, builder.getI64VectorAttr(op_type.getShape())); - builder.create(op.getLoc(), op->getResult(0).getType(), ret_broadcast); + double value = N; + switch (op.getFftType()) { + case FftType::FFT: + break; + case FftType::IFFT: + value = 1 / value; + break; + case FftType::RFFT: + value /= 2; + break; + case FftType::IRFFT: + value = 2 / value; + break; + } + auto resTy = op->getResult(0).getType().cast(); + mlir::Value ret_constant = builder.create(op.getLoc(), SplatElementsAttr::get( + resTy, FloatAttr::get(resTy.getElementType(), value))); + + if (op.getFftType() == FftType::RFFT || op.getFftType() == FftType::IRFFT) { + auto RT = RankedTensorType::get({1}, resTy.getElementType()); + auto zero_constant = builder.create(op.getLoc(), SplatElementsAttr::get( + RT, FloatAttr::get(resTy.getElementType(), 0))); + auto end_constant = builder.create(op.getLoc(), SplatElementsAttr::get( + RT, FloatAttr::get(resTy.getElementType(), lengths.back()-1))); + + auto RT64 = RankedTensorType::get({1}, builder.getIntegerType(64)); + + Value start[] = { + builder.create(op.getLoc(), SplatElementsAttr::get(RT64, op.getI64IntegerAttr(0))) + }; + Value end[] = { + builder.create(op.getLoc(), SplatElementsAttr::get(RT64, op.getI64IntegerAttr(lengths.size()-1))) + }; + ret_constant = builder.create(op.getLoc(), resTy, ret_constant, zero_constant, start); + ret_constant = builder.create(op.getLoc(), resTy, ret_constant, end_constant, end); + } + ret_constant; }]>; def FftIsIRFFT : GlobalExpr