Skip to content

Commit 10c278e

Browse files
committed
add an interface for dask.fft
1 parent 384fcb6 commit 10c278e

File tree

14 files changed

+295
-15
lines changed

14 files changed

+295
-15
lines changed

.github/workflows/build-with-clang.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,5 +73,5 @@ jobs:
7373
- name: Run mkl_fft tests
7474
run: |
7575
source ${{ env.ONEAPI_ROOT }}/setvars.sh
76-
pip install scipy mkl-service pytest
76+
pip install pytest mkl-service scipy dask
7777
pytest -s -v --pyargs mkl_fft

.github/workflows/conda-package-cf.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ jobs:
132132
- name: Install mkl_fft
133133
run: |
134134
CHANNELS="-c $GITHUB_WORKSPACE/channel ${{ env.CHANNELS }}"
135-
conda create -n ${{ env.TEST_ENV_NAME }} python=${{ matrix.python_ver }} ${{ matrix.numpy }} $PACKAGE_NAME pytest scipy $CHANNELS
135+
conda create -n ${{ env.TEST_ENV_NAME }} python=${{ matrix.python_ver }} ${{ matrix.numpy }} $PACKAGE_NAME pytest $CHANNELS
136136
# Test installed packages
137137
conda list -n ${{ env.TEST_ENV_NAME }}
138138
@@ -295,7 +295,7 @@ jobs:
295295
FOR /F "tokens=* USEBACKQ" %%F IN (`python -c "%SCRIPT%"`) DO (
296296
SET PACKAGE_VERSION=%%F
297297
)
298-
SET "TEST_DEPENDENCIES=pytest scipy"
298+
SET "TEST_DEPENDENCIES=pytest"
299299
conda install -n ${{ env.TEST_ENV_NAME }} ${{ env.PACKAGE_NAME }}=%PACKAGE_VERSION% %TEST_DEPENDENCIES% python=${{ matrix.python }} ${{ matrix.numpy }} -c ${{ env.workdir }}/channel ${{ env.CHANNELS }}
300300
301301
- name: Report content of test environment

.github/workflows/conda-package.yml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,7 @@ jobs:
131131
- name: Install mkl_fft
132132
run: |
133133
CHANNELS="-c $GITHUB_WORKSPACE/channel ${{ env.CHANNELS }}"
134-
conda create -n ${{ env.TEST_ENV_NAME }} python=${{ matrix.python }} "scipy>=1.10" $CHANNELS
135-
conda install -n ${{ env.TEST_ENV_NAME }} $PACKAGE_NAME pytest $CHANNELS
134+
conda create -n ${{ env.TEST_ENV_NAME }} $PACKAGE_NAME python=${{ matrix.python }} pytest $CHANNELS
136135
# Test installed packages
137136
conda list -n ${{ env.TEST_ENV_NAME }}
138137
@@ -296,7 +295,7 @@ jobs:
296295
FOR /F "tokens=* USEBACKQ" %%F IN (`python -c "%SCRIPT%"`) DO (
297296
SET PACKAGE_VERSION=%%F
298297
)
299-
SET "TEST_DEPENDENCIES=pytest scipy"
298+
SET "TEST_DEPENDENCIES=pytest"
300299
conda install -n ${{ env.TEST_ENV_NAME }} ${{ env.PACKAGE_NAME }}=%PACKAGE_VERSION% %TEST_DEPENDENCIES% python=${{ matrix.python }} -c ${{ env.workdir }}/channel ${{ env.CHANNELS }}
301300
302301
- name: Report content of test environment

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1010
* Added Hermitian FFT functions to SciPy interface `mkl_fft.interfaces.scipy_fft`: `hfft`, `ihfft`, `hfftn`, `ihfftn`, `hfft2`, and `ihfft2` [gh-161](https://github.com/IntelPython/mkl_fft/pull/161)
1111
* Added support for `out` kwarg to all FFT functions in `mkl_fft` and `mkl_fft.interfaces.numpy_fft` [gh-157](https://github.com/IntelPython/mkl_fft/pull/157)
1212
* Added `fftfreq`, `fftshift`, `ifftshift`, and `rfftfreq` to both NumPy and SciPy interfaces [gh-179](https://github.com/IntelPython/mkl_fft/pull/179)
13+
* Added a new interface for FFT module of Dask accessible through `mkl_fft.interfaces.dask_fft` [gh-184](https://github.com/IntelPython/mkl_fft/pull/184)
1314

1415
### Changed
1516
* NumPy interface `mkl_fft.interfaces.numpy_fft` is aligned with numpy-2.x.x [gh-139](https://github.com/IntelPython/mkl_fft/pull/139), [gh-157](https://github.com/IntelPython/mkl_fft/pull/157)

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ More details can be found in [SciPy 2017 conference proceedings](https://github.
4747

4848
---
4949

50-
The `mkl_fft` package offers interfaces that act as drop-in replacements for equivalent functions in NumPy and SciPy. Learn more about these interfaces [here](https://github.com/IntelPython/mkl_fft/blob/master/mkl_fft/interfaces/README.md).
50+
The `mkl_fft` package offers interfaces that act as drop-in replacements for equivalent functions in NumPy, SciPy, and Dask. Learn more about these interfaces [here](https://github.com/IntelPython/mkl_fft/blob/master/mkl_fft/interfaces/README.md).
5151

5252
While using these interfaces is the easiest way to leverage `mk_fft`, one can also use `mkl_fft` directly with the following FFT functions:
5353

conda-recipe-cf/meta.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,18 +26,20 @@ requirements:
2626
- python
2727
- mkl-service
2828
- numpy
29+
- scipy >=1.10.1
30+
- dask
2931

3032
test:
3133
commands:
3234
- pytest -v --pyargs mkl_fft
3335
requires:
3436
- pytest
35-
- scipy
3637
imports:
3738
- mkl_fft
3839
- mkl_fft.interfaces
3940
- mkl_fft.interfaces.numpy_fft
4041
- mkl_fft.interfaces.scipy_fft
42+
- mkl_fft.interfaces.dask_fft
4143

4244
about:
4345
home: http://github.com/IntelPython/mkl_fft

conda-recipe/meta.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,18 +26,20 @@ requirements:
2626
- python
2727
- mkl-service
2828
- {{ pin_compatible('numpy') }}
29+
- scipy >=1.10.1
30+
- dask
2931

3032
test:
3133
commands:
3234
- pytest -v --pyargs mkl_fft
3335
requires:
3436
- pytest
35-
- scipy
3637
imports:
3738
- mkl_fft
3839
- mkl_fft.interfaces
3940
- mkl_fft.interfaces.numpy_fft
4041
- mkl_fft.interfaces.scipy_fft
42+
- mkl_fft.interfaces.dask_fft
4143

4244
about:
4345
home: http://github.com/IntelPython/mkl_fft

mkl_fft/interfaces/README.md

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Interfaces
2-
The `mkl_fft` package provides interfaces that serve as drop-in replacements for equivalent functions in NumPy and SciPy.
2+
The `mkl_fft` package provides interfaces that serve as drop-in replacements for equivalent functions in NumPy, SciPy, and Dask.
33

44
---
55

@@ -125,3 +125,43 @@ with mkl_fft.set_workers(4):
125125
y = scipy.signal.fftconvolve(a, a) # Note that Nthr:4
126126
# MKL_VERBOSE FFT(dcbo256x128,input_strides:{0,128,1},output_strides:{0,128,1},bScale:3.05176e-05,tLim:4,unaligned_output,desc:0x563aefe86180) 187.37us CNR:OFF Dyn:1 FastMM:1 TID:0 NThr:4
127127
```
128+
129+
---
130+
131+
## Dask interface - `mkl_fft.interfaces.dask_fft`
132+
133+
This interface is a drop-in replacement for the [`dask.fft`](https://dask.pydata.org/en/latest/array-api.html#fast-fourier-transforms) module and includes **all** the functions available there:
134+
135+
* complex-to-complex FFTs: `fft`, `ifft`, `fft2`, `ifft2`, `fftn`, `ifftn`.
136+
137+
* real-to-complex and complex-to-real FFTs: `rfft`, `irfft`, `rfft2`, `irfft2`, `rfftn`, `irfftn`.
138+
139+
* Hermitian FFTs: `hfft`, `ihfft`.
140+
141+
* Helper routines: `fft_wrap`, `fftfreq`, `rfftfreq`, `fftshift`, `ifftshift`. These routines serve as a fallback to the Dask implementation and are included for completeness.
142+
143+
The following example shows how to use this interface for calculating a 2D FFT.
144+
145+
```python
146+
import numpy, dask
147+
import mkl_fft.interfaces.dask_fft as dask_fft
148+
149+
a = numpy.random.randn(128, 64) + 1j*numpy.random.randn(128, 64)
150+
x = dask.array.from_array(a, chunks=(64, 64))
151+
lazy_res = dask_fft.fft(x)
152+
mkl_res = lazy_res.compute()
153+
np_res = numpy.fft.fft(a)
154+
numpy.allclose(mkl_res, np_res)
155+
# True
156+
157+
# There are two chunks in this example based on the size of input array (128, 64) and chunk size (64, 64)
158+
# to confirm that MKL FFT is called twice, turn on verbosity
159+
import mkl
160+
mkl.verbose(1)
161+
# True
162+
163+
mkl_res = lazy_res.compute() # MKL_VERBOSE FFT is shown twice below which means MKL FFT is called twice
164+
# MKL_VERBOSE oneMKL 2024.0 Update 2 Patch 2 Product build 20240823 for Intel(R) 64 architecture Intel(R) Advanced Vector Extensions 512 (Intel(R) AVX-512) with support for INT8, BF16, FP16 (limited) instructions, and Intel(R) Advanced Matrix Extensions (Intel(R) AMX) with INT8 and BF16, Lnx 3.80GHz intel_thread
165+
# MKL_VERBOSE FFT(dcfo64*64,input_strides:{0,1},output_strides:{0,1},input_distance:64,output_distance:64,bScale:0.015625,tLim:32,unaligned_input,desc:0x7fd000010e40) 432.84us CNR:OFF Dyn:1 FastMM:1 TID:0 NThr:112
166+
# MKL_VERBOSE FFT(dcfo64*64,input_strides:{0,1},output_strides:{0,1},input_distance:64,output_distance:64,bScale:0.015625,tLim:32,unaligned_input,desc:0x7fd480011300) 499.00us CNR:OFF Dyn:1 FastMM:1 TID:0 NThr:112
167+
```

mkl_fft/interfaces/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,5 @@
2424
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2525

2626
from . import numpy_fft, scipy_fft
27+
28+
from . import dask_fft # isort: skip

mkl_fft/interfaces/dask_fft.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
#!/usr/bin/env python
2+
# Copyright (c) 2025, Intel Corporation
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions are met:
6+
#
7+
# * Redistributions of source code must retain the above copyright notice,
8+
# this list of conditions and the following disclaimer.
9+
# * Redistributions in binary form must reproduce the above copyright
10+
# notice, this list of conditions and the following disclaimer in the
11+
# documentation and/or other materials provided with the distribution.
12+
# * Neither the name of Intel Corporation nor the names of its contributors
13+
# may be used to endorse or promote products derived from this software
14+
# without specific prior written permission.
15+
#
16+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
20+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26+
27+
from dask.array.fft import fft_wrap, fftfreq, fftshift, ifftshift, rfftfreq
28+
29+
from . import numpy_fft as _numpy_fft
30+
31+
__all__ = [
32+
"fft",
33+
"ifft",
34+
"fft2",
35+
"ifft2",
36+
"fftn",
37+
"ifftn",
38+
"rfft",
39+
"irfft",
40+
"rfft2",
41+
"irfft2",
42+
"rfftn",
43+
"irfftn",
44+
"hfft",
45+
"ihfft",
46+
"fftshift",
47+
"ifftshift",
48+
"fftfreq",
49+
"rfftfreq",
50+
"fft_wrap",
51+
]
52+
53+
54+
fft = fft_wrap(_numpy_fft.fft)
55+
ifft = fft_wrap(_numpy_fft.ifft)
56+
fft2 = fft_wrap(_numpy_fft.fft2)
57+
ifft2 = fft_wrap(_numpy_fft.ifft2)
58+
fftn = fft_wrap(_numpy_fft.fftn)
59+
ifftn = fft_wrap(_numpy_fft.ifftn)
60+
rfft = fft_wrap(_numpy_fft.rfft)
61+
irfft = fft_wrap(_numpy_fft.irfft)
62+
rfft2 = fft_wrap(_numpy_fft.rfft2)
63+
irfft2 = fft_wrap(_numpy_fft.irfft2)
64+
rfftn = fft_wrap(_numpy_fft.rfftn)
65+
irfftn = fft_wrap(_numpy_fft.irfftn)
66+
hfft = fft_wrap(_numpy_fft.hfft)
67+
ihfft = fft_wrap(_numpy_fft.ihfft)

mkl_fft/tests/test_interfaces.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,10 +167,16 @@ def test_axes(func):
167167

168168

169169
@pytest.mark.parametrize(
170-
"interface", [mfi.scipy_fft, mfi.numpy_fft], ids=["scipy", "numpy"]
170+
"interface",
171+
[mfi.scipy_fft, mfi.numpy_fft, mfi.dask_fft],
172+
ids=["scipy", "numpy", "dask"],
171173
)
172174
@pytest.mark.parametrize(
173175
"func", ["fftshift", "ifftshift", "fftfreq", "rfftfreq"]
174176
)
175177
def test_interface_helper_functions(interface, func):
176178
assert hasattr(interface, func)
179+
180+
181+
def test_dask_fftwrap():
182+
assert hasattr(mfi.dask_fft, "fft_wrap")

0 commit comments

Comments
 (0)