
fix fft rev diff rule #176

Open: wants to merge 6 commits into main

@mofeing (Collaborator) commented Nov 29, 2024

CC @avik-pal

Should fix #170.

@wsmoses (Member) commented Dec 16, 2024

@mofeing is this the right math here?

@wsmoses marked this pull request as ready for review on December 16, 2024, 20:51
Comment on lines 706 to 721
auto RT = RankedTensorType::get({1}, resTy.getElementType());
auto zero_constant = builder.create<ConstantOp>(op.getLoc(), SplatElementsAttr::get(
    RT, FloatAttr::get(resTy.getElementType(), 0)));
auto end_constant = builder.create<ConstantOp>(op.getLoc(), SplatElementsAttr::get(
    RT, FloatAttr::get(resTy.getElementType(), lengths.back()-1)));

auto RT64 = RankedTensorType::get({1}, builder.getIntegerType(64));

Value start[] = {
    builder.create<stablehlo::ConstantOp>(op.getLoc(), SplatElementsAttr::get(RT64, op.getI64IntegerAttr(0)))
};
Value end[] = {
    builder.create<stablehlo::ConstantOp>(op.getLoc(), SplatElementsAttr::get(RT64, op.getI64IntegerAttr(lengths.size()-1)))
};
ret_constant = builder.create<stablehlo::DynamicUpdateSliceOp>(op.getLoc(), resTy, ret_constant, zero_constant, start);
ret_constant = builder.create<stablehlo::DynamicUpdateSliceOp>(op.getLoc(), resTy, ret_constant, end_constant, end);
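For comparison, here is a plain C++ sketch of `stablehlo.dynamic_update_slice` semantics as given in the StableHLO spec: one 0-d start index per operand dimension, an update with the same rank as the operand, and start indices clamped so the update stays in bounds. With a one-element update, as in the snippet above, only a single element is written. Setting whole edge slices of a rank-2 multiplier needs a column-shaped update. `Tensor2D`, `dynamic_update_slice` and `edge_corrected_multiplier` are hypothetical helpers on a row-major buffer, not the MLIR/StableHLO C++ API. Placing the edges along the last dimension is also an assumption.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical row-major rank-2 tensor, used only to illustrate semantics.
struct Tensor2D {
    std::size_t rows, cols;
    std::vector<double> data;
    double& at(std::size_t i, std::size_t j) { return data[i * cols + j]; }
    double at(std::size_t i, std::size_t j) const { return data[i * cols + j]; }
};

// Mimics stablehlo.dynamic_update_slice for rank 2: one start index per
// dimension, an update of the same rank, and start indices clamped to
// [0, dim(operand) - dim(update)] so the update stays in bounds.
Tensor2D dynamic_update_slice(Tensor2D operand, const Tensor2D& update,
                              long start0, long start1) {
    assert(update.rows <= operand.rows && update.cols <= operand.cols);
    const long max0 = static_cast<long>(operand.rows - update.rows);
    const long max1 = static_cast<long>(operand.cols - update.cols);
    const std::size_t s0 = static_cast<std::size_t>(std::min(std::max(start0, 0L), max0));
    const std::size_t s1 = static_cast<std::size_t>(std::min(std::max(start1, 0L), max1));
    for (std::size_t i = 0; i < update.rows; ++i)
        for (std::size_t j = 0; j < update.cols; ++j)
            operand.at(s0 + i, s1 + j) = update.at(i, j);
    return operand;
}

// Fill a rows x cols multiplier with `inner`, then overwrite the first and
// last slices along the last dimension with `edge`, using two column updates.
Tensor2D edge_corrected_multiplier(std::size_t rows, std::size_t cols,
                                   double inner, double edge) {
    Tensor2D m{rows, cols, std::vector<double>(rows * cols, inner)};
    Tensor2D column{rows, 1, std::vector<double>(rows, edge)};
    m = dynamic_update_slice(m, column, 0, 0);
    m = dynamic_update_slice(m, column, 0, static_cast<long>(cols) - 1);
    return m;
}
```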
@mofeing (Collaborator, Author):

I think this is not right, and I believe it's my fault for not explaining it properly.

  • RFFT => mult = N/2, except for i_n = 0 and i_n = dim(i_n) - 1, where the value is N
  • IRFFT => mult = 2/N, except for i_n = 0 and i_n = dim(i_n) - 1, where the value is 1/N
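A numerical sanity check of these two rules for the 1-D case, sketched in C++ with naive O(N^2) DFTs instead of a real FFT library. It assumes an even length N (for odd N there is no Nyquist bin, so the last bin would take the interior factor). It checks the adjoint identity <A x, y> = <x, A^T y> under the real inner product Re sum(conj(a) * b). `naive_rfft`, `naive_irfft`, `rfft_vjp` and `irfft_vjp` are hypothetical names used only for this sketch.

```cpp
#include <cassert>
#include <cmath>
#include <complex>
#include <cstddef>
#include <vector>

using cplx = std::complex<double>;
const double kPi = std::acos(-1.0);

// Naive unnormalized real-to-complex DFT: X_k = sum_n x_n e^{-2 pi i k n / N}, k = 0..N/2.
std::vector<cplx> naive_rfft(const std::vector<double>& x) {
    const std::size_t N = x.size();
    std::vector<cplx> X(N / 2 + 1);
    for (std::size_t k = 0; k <= N / 2; ++k)
        for (std::size_t n = 0; n < N; ++n)
            X[k] += x[n] * std::polar(1.0, -2.0 * kPi * double(k * n) / double(N));
    return X;
}

// Naive inverse including the 1/N, even N only: the Hermitian extension counts
// bins 1..N/2-1 twice and bins 0 and N/2 once (their imaginary parts are ignored).
std::vector<double> naive_irfft(const std::vector<cplx>& Y, std::size_t N) {
    assert(N % 2 == 0 && Y.size() == N / 2 + 1);
    std::vector<double> x(N);
    for (std::size_t n = 0; n < N; ++n) {
        double acc = 0.0;
        for (std::size_t k = 0; k <= N / 2; ++k) {
            const double c = (k == 0 || k == N / 2) ? 1.0 : 2.0;
            acc += c * (Y[k] * std::polar(1.0, 2.0 * kPi * double(k * n) / double(N))).real();
        }
        x[n] = acc / double(N);
    }
    return x;
}

// Multiplier: `inner` for interior bins, `edge` for the first and last bin.
std::vector<cplx> scale_bins(std::vector<cplx> Y, double inner, double edge) {
    for (std::size_t k = 0; k < Y.size(); ++k)
        Y[k] *= (k == 0 || k + 1 == Y.size()) ? edge : inner;
    return Y;
}

// RFFT reverse rule: x_bar = irfft(dX * m), m = N/2 inside and N at the edges.
std::vector<double> rfft_vjp(const std::vector<cplx>& dX, std::size_t N) {
    return naive_irfft(scale_bins(dX, N / 2.0, double(N)), N);
}

// IRFFT reverse rule: y_bar = rfft(dx) * m, m = 2/N inside and 1/N at the edges.
std::vector<cplx> irfft_vjp(const std::vector<double>& dx) {
    const double N = double(dx.size());
    return scale_bins(naive_rfft(dx), 2.0 / N, 1.0 / N);
}

// Real inner products used to state the adjoint identities.
double dot(const std::vector<double>& a, const std::vector<double>& b) {
    double s = 0.0;
    for (std::size_t i = 0; i < a.size(); ++i) s += a[i] * b[i];
    return s;
}
double dot(const std::vector<cplx>& a, const std::vector<cplx>& b) {
    double s = 0.0;
    for (std::size_t i = 0; i < a.size(); ++i) s += (std::conj(a[i]) * b[i]).real();
    return s;
}
```

If both identities hold for arbitrary inputs, the multipliers in the two bullets are the right reverse-mode factors for the 1-D, even-N case.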

What I meant by i_n = 0 is that the n-th index equals 0 and the rest are "colons", so I really meant [:,:,:,...,:,0].

But rechecking now, this is wrong: it should apply to the slices where i_0 = 0 and i_0 = dim(i_0) - 1. It's tricky because StableHLO has unusual semantics: for a 3-dim FFT, it performs the FFT on the 1st dimension and the last 2.

So, for RFFT and IRFFT, this Julia code should be equivalent:

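function fft_multiplier(kind::Symbol, N, input)
    # kind is one of :FFT, :IFFT, :RFFT, :IRFFT; N as in the rules above
    value = if kind === :FFT
        N
    elseif kind === :IFFT
        1/N
    elseif kind === :RFFT
        N/2
    else # :IRFFT
        2/N
    end

    multiplier = fill(value, size(input))

    if kind === :RFFT || kind === :IRFFT
        # first and last slices along dimension 1 get N (RFFT) or 1/N (IRFFT)
        edge = kind === :RFFT ? N : 1/N
        selectdim(multiplier, 1, 1) .= edge
        selectdim(multiplier, 1, size(input, 1)) .= edge
    end

    return multiplier
end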
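The same multiplier can be sketched in C++ on a flat row-major buffer. This sketch assumes that Julia's dimension 1 in `selectdim(multiplier, 1, ...)` corresponds to the last (innermost) dimension on the StableHLO side, which is the dimension RFFT halves. It also takes N as the product of the FFT lengths, which is an assumption for the multi-dimensional case. `FftKind` and `fft_multiplier` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

enum class FftKind { FFT, IFFT, RFFT, IRFFT };

// Row-major multiplier for the reverse FFT rule. `shape` is the shape the
// multiplier is applied to; `N` is the (assumed) product of the FFT lengths.
// Assumption: for RFFT/IRFFT the edge correction runs along the last dimension.
std::vector<double> fft_multiplier(FftKind kind, double N,
                                   const std::vector<std::size_t>& shape) {
    assert(!shape.empty());
    double value = 0.0;
    switch (kind) {
        case FftKind::FFT:   value = N;       break;
        case FftKind::IFFT:  value = 1.0 / N; break;
        case FftKind::RFFT:  value = N / 2.0; break;
        case FftKind::IRFFT: value = 2.0 / N; break;
    }
    const std::size_t total = std::accumulate(shape.begin(), shape.end(),
                                              std::size_t{1}, std::multiplies<>());
    std::vector<double> m(total, value);

    if (kind == FftKind::RFFT || kind == FftKind::IRFFT) {
        const double edge = (kind == FftKind::RFFT) ? N : 1.0 / N;
        const std::size_t last = shape.back();
        // In row-major order, the last-dimension index of flat element i is i % last.
        for (std::size_t i = 0; i < total; ++i) {
            const std::size_t j = i % last;
            if (j == 0 || j == last - 1) m[i] = edge;
        }
    }
    return m;
}
```

For example, an RFFT of length 8 gives a cotangent whose last dimension has 8/2 + 1 = 5 bins; with shape {2, 5} the first and last bin of each row get 8 and the interior bins get 4.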

Also note that because it's just a multiplier, maybe we could skip the dynamic_update_slice here: let the uniform factor multiply the FFT result, then call slice + multiply + dynamic_update_slice on the result to correct the edge slices. This has the advantage that no multiplier array is materialized, even after optimizations.
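The "multiply, then correct" alternative can be sketched the same way. First, scale everything by the interior factor using a single scalar. Then rescale only the first and last slices along the last dimension, which stands in for slice + multiply + dynamic_update_slice. This is illustrative only: the buffer here is real-valued and row-major (the actual cotangents are complex), and `scale_then_fix_edges` is a hypothetical helper.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Sketch of "multiply, then fix the edges" on a row-major buffer whose last
// dimension has length `last`. Step 1 scales everything by one scalar (no
// multiplier array); step 2 rescales only the first and last slices along the
// last dimension, standing in for slice + multiply + dynamic_update_slice.
void scale_then_fix_edges(std::vector<double>& data, std::size_t last,
                          double inner, double edge) {
    assert(last > 0 && data.size() % last == 0 && inner != 0.0);
    for (double& v : data) v *= inner;          // uniform multiply
    const double fix = edge / inner;            // correction factor for the edge slices
    for (std::size_t row = 0; row < data.size() / last; ++row) {
        data[row * last] *= fix;                // first slice
        if (last > 1) data[row * last + last - 1] *= fix;  // last slice
    }
}
```

For the two rules in this thread, the edge correction factor edge/inner is 2 for RFFT (N over N/2) and 1/2 for IRFFT ((1/N) over (2/N)).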

@wsmoses (Member):

Hm, perhaps I'll let you take it from here then. That said, it is much better to do it here: in batched mode the constant is generated once, whereas in the ops themselves we'd have to do it for each one.

Successfully merging this pull request may close these issues.

FFT Derivative is Broken for RFFT/IFFT