Skip to content

Revisit usage of overwrite_x parameter #185

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* NumPy interface `mkl_fft.interfaces.numpy_fft` is aligned with numpy-2.x.x [gh-139](https://github.com/IntelPython/mkl_fft/pull/139), [gh-157](https://github.com/IntelPython/mkl_fft/pull/157)
* Setting `mkl_fft` as the backend for SciPy is only possible through `mkl_fft.interfaces.scipy_fft` [gh-179](https://github.com/IntelPython/mkl_fft/pull/179)
* SciPy interface `mkl_fft.interfaces.scipy_fft` uses the same function from SciPy for handling `s` and `axes` for N-D FFTs [gh-181](https://github.com/IntelPython/mkl_fft/pull/181)
* Dropped support for `scipy.fftpack` interface [gh-185](https://github.com/IntelPython/mkl_fft/pull/185)
* Dropped support for `overwrite_x` parameter in `mkl_fft` [gh-185](https://github.com/IntelPython/mkl_fft/pull/185)

### Fixed
* Fixed an issue for calling `mkl_fft.interfaces.numpy.fftn` with an empty axes [gh-139](https://github.com/IntelPython/mkl_fft/pull/139)
* Fixed an issue for calling `mkl_fft.interfaces.numpy.fftn` with a zero-size array [gh-139](https://github.com/IntelPython/mkl_fft/pull/139)
* Fixed a bug in `mkl_fft.interfaces.numpy.fftn` when an empty tuple is passed for `axes` [gh-139](https://github.com/IntelPython/mkl_fft/pull/139)
* Fixed a bug for a case when a zero-size array is passed to `mkl_fft.interfaces.numpy.fftn` [gh-139](https://github.com/IntelPython/mkl_fft/pull/139)
* Fixed an inconsistency between the input and output array dtypes for the `irfft` function [gh-180](https://github.com/IntelPython/mkl_fft/pull/180)
* Fixed a bug for N-D FFTs when both `s` and `out` are given [gh-185](https://github.com/IntelPython/mkl_fft/pull/185)

## [1.3.14] (04/10/2025)

Expand Down
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,11 @@ While using these interfaces is the easiest way to leverage `mkl_fft`, one can al

### complex-to-complex (c2c) transforms:

`fft(x, n=None, axis=-1, overwrite_x=False, fwd_scale=1.0, out=None)` - 1D FFT, similar to `scipy.fft.fft`
`fft(x, n=None, axis=-1, fwd_scale=1.0, out=None)` - 1D FFT, similar to `numpy.fft.fft`

`fft2(x, s=None, axes=(-2, -1), overwrite_x=False, fwd_scale=1.0, out=None)` - 2D FFT, similar to `scipy.fft.fft2`
`fft2(x, s=None, axes=(-2, -1), fwd_scale=1.0, out=None)` - 2D FFT, similar to `numpy.fft.fft2`

`fftn(x, s=None, axes=None, overwrite_x=False, fwd_scale=1.0, out=None)` - ND FFT, similar to `scipy.fft.fftn`
`fftn(x, s=None, axes=None, fwd_scale=1.0, out=None)` - ND FFT, similar to `numpy.fft.fftn`

and similar inverse FFT (`ifft*`) functions.

Expand Down
3 changes: 0 additions & 3 deletions mkl_fft/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
rfft2,
rfftn,
)
from ._pydfti import irfftpack, rfftpack # pylint: disable=no-name-in-module
from ._version import __version__

import mkl_fft.interfaces # isort: skip
Expand All @@ -51,8 +50,6 @@
"ifft2",
"fftn",
"ifftn",
"rfftpack",
"irfftpack",
"rfft",
"irfft",
"rfft2",
Expand Down
149 changes: 94 additions & 55 deletions mkl_fft/_fft_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,23 +262,43 @@ def _iter_fftnd(
axes=None,
out=None,
direction=+1,
overwrite_x=False,
scale_function=lambda n, ind: 1.0,
scale_function=lambda ind: 1.0,
):
a = np.asarray(a)
s, axes = _init_nd_shape_and_axes(a, s, axes)
ovwr = overwrite_x
for ii in reversed(range(len(axes))):

# Combine the two, but in reverse, to end with the first axis given.
axes_and_s = list(zip(axes, s))[::-1]
# We try to use in-place calculations where possible, which is
# everywhere except when the size changes after the first FFT.
size_changes = [axis for axis, n in axes_and_s[1:] if a.shape[axis] != n]

# If there are any size changes, we cannot use out
res = None if size_changes else out
for ind, (axis, n) in enumerate(axes_and_s):
if axis in size_changes:
if axis == size_changes[-1]:
# Last size change, so any output should now be OK
# (an error will be raised if not), and if no output is
# required, we want a freshly allocated array of the right size.
res = out
elif res is not None and n < res.shape[axis]:
# For an intermediate step where we return fewer elements, we
# can use a smaller view of the previous array.
res = res[(slice(None),) * axis + (slice(n),)]
else:
# If we need more elements, we cannot use res.
res = None
a = _c2c_fft1d_impl(
a,
n=s[ii],
axis=axes[ii],
overwrite_x=ovwr,
n=n,
axis=axis,
direction=direction,
fsc=scale_function(s[ii], ii),
out=out,
fsc=scale_function(ind),
out=res,
)
ovwr = True
# Default output for next iteration.
res = a
return a


Expand Down Expand Up @@ -360,7 +380,6 @@ def _c2c_fftnd_impl(
x,
s=None,
axes=None,
overwrite_x=False,
direction=+1,
fsc=1.0,
out=None,
Expand All @@ -385,7 +404,6 @@ def _c2c_fftnd_impl(
if _direct:
return _direct_fftnd(
x,
overwrite_x=overwrite_x,
direction=direction,
fsc=fsc,
out=out,
Expand All @@ -403,11 +421,7 @@ def _c2c_fftnd_impl(
x,
axes,
_direct_fftnd,
{
"overwrite_x": overwrite_x,
"direction": direction,
"fsc": fsc,
},
{"direction": direction, "fsc": fsc},
res,
)
else:
Expand All @@ -418,97 +432,122 @@ def _c2c_fftnd_impl(
axes=axes,
out=out,
direction=direction,
overwrite_x=overwrite_x,
scale_function=lambda n, i: fsc if i == 0 else 1.0,
scale_function=lambda i: fsc if i == 0 else 1.0,
)


def _r2c_fftnd_impl(x, s=None, axes=None, fsc=1.0, out=None):
a = np.asarray(x)
no_trim = (s is None) and (axes is None)
s, axes = _cook_nd_args(a, s, axes)
axes = [ax + a.ndim if ax < 0 else ax for ax in axes]
la = axes[-1]

# trim array, so that rfft avoids doing unnecessary computations
if not no_trim:
a = _trim_array(a, s, axes)

# last axis is not included since we calculate r2c FFT separately
# and not in the loop
axes_and_s = list(zip(axes, s))[-2::-1]
size_changes = [axis for axis, n in axes_and_s if a.shape[axis] != n]
res = None if size_changes else out

# r2c along last axis
a = _r2c_fft1d_impl(a, n=s[-1], axis=la, fsc=fsc, out=out)
a = _r2c_fft1d_impl(a, n=s[-1], axis=la, fsc=fsc, out=res)
res = a
if len(s) > 1:
if not no_trim:
ss = list(s)
ss[-1] = a.shape[la]
a = _pad_array(a, tuple(ss), axes)

len_axes = len(axes)
if len(set(axes)) == len_axes and len_axes == a.ndim and len_axes > 2:
if not no_trim:
ss = list(s)
ss[-1] = a.shape[la]
a = _pad_array(a, tuple(ss), axes)
# a series of ND c2c FFTs along last axis
ss, aa = _remove_axis(s, axes, -1)
ind = [
slice(None, None, 1),
] * len(s)
ind = [slice(None, None, 1)] * len(s)
for ii in range(a.shape[la]):
ind[la] = ii
tind = tuple(ind)
a_inp = a[tind]
res = out[tind] if out is not None else None
a_res = _c2c_fftnd_impl(
a_inp, s=ss, axes=aa, overwrite_x=True, direction=1, out=res
)
if a_res is not a_inp:
a[tind] = a_res # copy in place
res = out[tind] if out is not None else a_inp
_ = _c2c_fftnd_impl(a_inp, s=ss, axes=aa, direction=1, out=res)
if out is not None:
a = out
else:
# another size_changes check is needed if the last axis is repeated in
# axes, since the FFT changes the shape along the last axis
size_changes = [
axis for axis, n in axes_and_s if a.shape[axis] != n
]

# a series of 1D c2c FFTs along all axes except last
for ii in range(len(axes) - 2, -1, -1):
a = _c2c_fft1d_impl(a, s[ii], axes[ii], overwrite_x=True)
for axis, n in axes_and_s:
if axis in size_changes:
if axis == size_changes[-1]:
res = out
elif res is not None and n < res.shape[axis]:
res = res[(slice(None),) * axis + (slice(n),)]
else:
res = None
a = _c2c_fft1d_impl(a, n, axis, out=res)
res = a
return a


def _c2r_fftnd_impl(x, s=None, axes=None, fsc=1.0, out=None):
a = np.asarray(x)
no_trim = (s is None) and (axes is None)
s, axes = _cook_nd_args(a, s, axes, invreal=True)
axes = [ax + a.ndim if ax < 0 else ax for ax in axes]
la = axes[-1]
if not no_trim:
a = _trim_array(a, s, axes)
if len(s) > 1:
if not no_trim:
a = _pad_array(a, s, axes)
ovr_x = True if _datacopied(a, x) else False
len_axes = len(axes)
if len(set(axes)) == len_axes and len_axes == a.ndim and len_axes > 2:
if not no_trim:
a = _pad_array(a, s, axes)
# a series of ND c2c FFTs along last axis
# due to need to write into a, we must copy
if not ovr_x:
a = a.copy()
ovr_x = True
a = a if _datacopied(a, x) else a.copy()
if not np.issubdtype(a.dtype, np.complexfloating):
# complex output will be copied to input, copy is needed
if a.dtype == np.float32:
a = a.astype(np.complex64)
else:
a = a.astype(np.complex128)
ovr_x = True
ss, aa = _remove_axis(s, axes, -1)
ind = [
slice(None, None, 1),
] * len(s)
ind = [slice(None, None, 1)] * len(s)
for ii in range(a.shape[la]):
ind[la] = ii
tind = tuple(ind)
a_inp = a[tind]
# out has real dtype and cannot be used in intermediate steps
a_res = _c2c_fftnd_impl(
a_inp, s=ss, axes=aa, overwrite_x=True, direction=-1
# ss and aa are reversed since np.irfftn uses forward order but
# np.ifftn uses reverse order; see numpy-gh-28950
_ = _c2c_fftnd_impl(
a_inp, s=ss[::-1], axes=aa[::-1], out=a_inp, direction=-1
)
if a_res is not a_inp:
a[tind] = a_res # copy in place
else:
# a series of 1D c2c FFTs along all axes except last
for ii in range(len(axes) - 1):
# out has real dtype and cannot be used in intermediate steps
a = _c2c_fft1d_impl(
a, s[ii], axes[ii], overwrite_x=ovr_x, direction=-1
)
ovr_x = True
# forward order, see numpy-gh-28950
axes_and_s = list(zip(axes, s))[:-1]
size_changes = [
axis for axis, n in axes_and_s[1:] if a.shape[axis] != n
]
# out has real dtype and cannot be used for intermediate steps
res = None
for axis, n in axes_and_s:
if axis in size_changes:
if res is not None and n < res.shape[axis]:
# pylint: disable=unsubscriptable-object
res = res[(slice(None),) * axis + (slice(n),)]
else:
res = None
a = _c2c_fft1d_impl(a, n, axis, out=res, direction=-1)
res = a
# c2r along last axis
a = _c2r_fft1d_impl(a, n=s[-1], axis=la, fsc=fsc, out=out)
return a
64 changes: 18 additions & 46 deletions mkl_fft/_mkl_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,85 +45,57 @@
]


def fft(x, n=None, axis=-1, out=None, overwrite_x=False, fwd_scale=1.0):
def fft(x, n=None, axis=-1, fwd_scale=1.0, out=None):
return _c2c_fft1d_impl(
x,
n=n,
axis=axis,
out=out,
overwrite_x=overwrite_x,
direction=+1,
fsc=fwd_scale,
x, n=n, axis=axis, out=out, direction=+1, fsc=fwd_scale
)


def ifft(x, n=None, axis=-1, out=None, overwrite_x=False, fwd_scale=1.0):
def ifft(x, n=None, axis=-1, fwd_scale=1.0, out=None):
return _c2c_fft1d_impl(
x,
n=n,
axis=axis,
out=out,
overwrite_x=overwrite_x,
direction=-1,
fsc=fwd_scale,
x, n=n, axis=axis, out=out, direction=-1, fsc=fwd_scale
)


def fft2(x, s=None, axes=(-2, -1), out=None, overwrite_x=False, fwd_scale=1.0):
return fftn(
x, s=s, axes=axes, out=out, overwrite_x=overwrite_x, fwd_scale=fwd_scale
)
def fft2(x, s=None, axes=(-2, -1), fwd_scale=1.0, out=None):
return fftn(x, s=s, axes=axes, out=out, fwd_scale=fwd_scale)


def ifft2(x, s=None, axes=(-2, -1), out=None, overwrite_x=False, fwd_scale=1.0):
return ifftn(
x, s=s, axes=axes, out=out, overwrite_x=overwrite_x, fwd_scale=fwd_scale
)
def ifft2(x, s=None, axes=(-2, -1), fwd_scale=1.0, out=None):
return ifftn(x, s=s, axes=axes, out=out, fwd_scale=fwd_scale)


def fftn(x, s=None, axes=None, out=None, overwrite_x=False, fwd_scale=1.0):
def fftn(x, s=None, axes=None, fwd_scale=1.0, out=None):
return _c2c_fftnd_impl(
x,
s=s,
axes=axes,
out=out,
overwrite_x=overwrite_x,
direction=+1,
fsc=fwd_scale,
x, s=s, axes=axes, out=out, direction=+1, fsc=fwd_scale
)


def ifftn(x, s=None, axes=None, out=None, overwrite_x=False, fwd_scale=1.0):
def ifftn(x, s=None, axes=None, fwd_scale=1.0, out=None):
return _c2c_fftnd_impl(
x,
s=s,
axes=axes,
out=out,
overwrite_x=overwrite_x,
direction=-1,
fsc=fwd_scale,
x, s=s, axes=axes, out=out, direction=-1, fsc=fwd_scale
)


def rfft(x, n=None, axis=-1, out=None, fwd_scale=1.0):
def rfft(x, n=None, axis=-1, fwd_scale=1.0, out=None):
return _r2c_fft1d_impl(x, n=n, axis=axis, out=out, fsc=fwd_scale)


def irfft(x, n=None, axis=-1, out=None, fwd_scale=1.0):
def irfft(x, n=None, axis=-1, fwd_scale=1.0, out=None):
return _c2r_fft1d_impl(x, n=n, axis=axis, out=out, fsc=fwd_scale)


def rfft2(x, s=None, axes=(-2, -1), out=None, fwd_scale=1.0):
def rfft2(x, s=None, axes=(-2, -1), fwd_scale=1.0, out=None):
return rfftn(x, s=s, axes=axes, out=out, fwd_scale=fwd_scale)


def irfft2(x, s=None, axes=(-2, -1), out=None, fwd_scale=1.0):
def irfft2(x, s=None, axes=(-2, -1), fwd_scale=1.0, out=None):
return irfftn(x, s=s, axes=axes, out=out, fwd_scale=fwd_scale)


def rfftn(x, s=None, axes=None, out=None, fwd_scale=1.0):
def rfftn(x, s=None, axes=None, fwd_scale=1.0, out=None):
return _r2c_fftnd_impl(x, s=s, axes=axes, out=out, fsc=fwd_scale)


def irfftn(x, s=None, axes=None, out=None, fwd_scale=1.0):
def irfftn(x, s=None, axes=None, fwd_scale=1.0, out=None):
return _c2r_fftnd_impl(x, s=s, axes=axes, out=out, fsc=fwd_scale)
Loading
Loading