Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix fft rev diff rule #176

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 52 additions & 8 deletions src/enzyme_ad/jax/Implementations/HLODerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -681,12 +681,51 @@ def FftLength : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{

def FftMultiplier : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
auto op_type = op->getResult(0).getType().cast<RankedTensorType>();
auto lengths = op.getFftLength();
auto lengths = op.getFftLengthAttr().getValues<int64_t>();
auto N = std::accumulate(lengths.begin(), lengths.end(), llvm::APInt(64, 1, true), std::multiplies{}).getSExtValue();

auto ret_constant = builder.create<ConstantOp>(op.getLoc(), builder.getDenseI64ArrayAttr(ArrayRef<int64_t>({N})));
auto ret_broadcast = builder.create<BroadcastInDimOp>(op.getLoc(), op_type.clone(op_type.getShape(), builder.getI64Type()), ret_constant, builder.getI64VectorAttr(op_type.getShape()));
builder.create<ConvertOp>(op.getLoc(), op->getResult(0).getType(), ret_broadcast);
double value = N;
switch (op.getFftType()) {
case FftType::FFT:
break;
case FftType::IFFT:
value = 1 / value;
break;
case FftType::RFFT:
value /= 2;
break;
case FftType::IRFFT:
value = 2 / value;
break;
}
auto resTy = op->getResult(0).getType().cast<RankedTensorType>();
mlir::Value ret_constant = builder.create<ConstantOp>(op.getLoc(), SplatElementsAttr::get(
resTy, FloatAttr::get(resTy.getElementType(), value)));

if (op.getFftType() == FftType::RFFT || op.getFftType() == FftType::IRFFT) {
auto RT = RankedTensorType::get({1}, resTy.getElementType());
auto zero_constant = builder.create<ConstantOp>(op.getLoc(), SplatElementsAttr::get(
RT, FloatAttr::get(resTy.getElementType(), 0)));
auto end_constant = builder.create<ConstantOp>(op.getLoc(), SplatElementsAttr::get(
RT, FloatAttr::get(resTy.getElementType(), lengths[lengths.size()-1]-1)));

auto RT64 = RankedTensorType::get({1}, builder.getIntegerType(64));

Value start[] = {
builder.create<stablehlo::ConstantOp>(op.getLoc(), SplatElementsAttr::get(RT64, rewriter.getI64IntegerAttr(0)))
};
Value end[] = {
builder.create<stablehlo::ConstantOp>(op.getLoc(), SplatElementsAttr::get(RT64, rewriter.getI64IntegerAttr(lengths.size()-1)))
};
ret_constant = builder.create<stablehlo::DynamicUpdateSliceOp>(op.getLoc(), resTy, ret_constant, zero_constant, start);
ret_constant = builder.create<stablehlo::DynamicUpdateSliceOp>(op.getLoc(), resTy, ret_constant, end_constant, end);
}
ret_constant;
}]>;

def FftIsIRFFT : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
auto cond = op.getFftType() == FftType::IRFFT;
builder.create<ConstantOp>(op.getLoc(), builder.getDenseBoolArrayAttr(ArrayRef<bool>({cond})));
}]>;

// Derivative rules
Expand Down Expand Up @@ -785,12 +824,17 @@ def : HLODerivative<"ExpOp", (Op $x), [(CheckedMul (DiffeRet), (Exp $x))]>;

def : HLODerivative<"Expm1Op", (Op $x), [(CheckedMul (DiffeRet), (Exp $x))]>;

// TODO fix `rfft` and `irfft` derivatives:
// - `rfft` => divide `DiffeRet` elems by 2 except 1st elem, and last elem if `FftLength` is even
// - `irfft` =>
def : HLODerivative<"FftOp", (Op $x),
[
(Fft (DiffeRet), (FftType), (FftLength)) // TODO maybe we need to conjugate? or inverse fft + multiply by N?
(Mul
(FftMultiplier), // TODO fix this
(Fft
(Select
(FftIsIRFFT), // if IRFFT
(Real (DiffeRet)), // call real(diff)
(DiffeRet),
(FftTypeInverse),
(FftLength))))
],
(Fft (Shadow $x), (FftType), (FftLength))
>;
Expand Down
Loading