Skip to content

Commit

Permalink
correct impl
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored and mofeing committed Jan 10, 2025
1 parent 61ea441 commit 08a05ee
Showing 1 changed file with 37 additions and 3 deletions.
40 changes: 37 additions & 3 deletions src/enzyme_ad/jax/Implementations/HLODerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -896,9 +896,43 @@ def FftMultiplier : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
auto lengths = op.getFftLength();
auto N = std::accumulate(lengths.begin(), lengths.end(), llvm::APInt(64, 1, true), std::multiplies{}).getSExtValue();

auto ret_constant = builder.create<ConstantOp>(op.getLoc(), builder.getDenseI64ArrayAttr(ArrayRef<int64_t>({N})));
auto ret_broadcast = builder.create<BroadcastInDimOp>(op.getLoc(), op_type.clone(op_type.getShape(), builder.getI64Type()), ret_constant, builder.getI64VectorAttr(op_type.getShape()));
builder.create<ConvertOp>(op.getLoc(), op->getResult(0).getType(), ret_broadcast);
double value = N;
switch (op.getFftType()) {
case FftType::FFT:
break;
case FftType::IFFT:
value = 1 / value;
break;
case FftType::RFFT:
value /= 2;
break;
case FftType::IRFFT:
value = 2 / value;
break;
}
auto resTy = op->getResult(0).getType().cast<RankedTensorType>();
mlir::Value ret_constant = builder.create<ConstantOp>(op.getLoc(), SplatElementsAttr::get(
resTy, FloatAttr::get(resTy.getElementType(), value)));

if (op.getFftType() == FftType::RFFT || op.getFftType() == FftType::IRFFT) {
auto RT = RankedTensorType::get({1}, resTy.getElementType());
auto zero_constant = builder.create<ConstantOp>(op.getLoc(), SplatElementsAttr::get(
RT, FloatAttr::get(resTy.getElementType(), 0)));
auto end_constant = builder.create<ConstantOp>(op.getLoc(), SplatElementsAttr::get(
RT, FloatAttr::get(resTy.getElementType(), lengths.back()-1)));

auto RT64 = RankedTensorType::get({1}, builder.getIntegerType(64));

Value start[] = {
builder.create<stablehlo::ConstantOp>(op.getLoc(), SplatElementsAttr::get(RT64, op.getI64IntegerAttr(0)))
};
Value end[] = {
builder.create<stablehlo::ConstantOp>(op.getLoc(), SplatElementsAttr::get(RT64, op.getI64IntegerAttr(lengths.size()-1)))
};
ret_constant = builder.create<stablehlo::DynamicUpdateSliceOp>(op.getLoc(), resTy, ret_constant, zero_constant, start);
ret_constant = builder.create<stablehlo::DynamicUpdateSliceOp>(op.getLoc(), resTy, ret_constant, end_constant, end);
}
ret_constant;
}]>;

def FftIsIRFFT : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
Expand Down

0 comments on commit 08a05ee

Please sign in to comment.