Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Add Scipy and Cupy as fft interfaces #8

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
python-version: ['3.10']
python-version: ['3.10', '3.11']
os: [macos-latest, ubuntu-latest, windows-latest]

steps:
Expand Down
7 changes: 3 additions & 4 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
sphinx==4.3.0
sphinxcontrib.bibtex>=2.0
sphinx_rtd_theme==1.0

sphinx
sphinxcontrib.bibtex
sphinx_rtd_theme
17 changes: 15 additions & 2 deletions docs/sec_code_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,34 @@ Fourier transform methods
=========================

.. _sec_code_fourier_numpy:

Numpy
-----
.. automodule:: qpretrieve.fourier.ff_numpy
:members:
:inherited-members:

.. _sec_code_fourier_pyfftw:

PyFFTW
------
.. automodule:: qpretrieve.fourier.ff_pyfftw
:members:
:inherited-members:

.. _sec_code_fourier_scipy:
Scipy
------
.. automodule:: qpretrieve.fourier.ff_scipy
:members:
:inherited-members:

.. _sec_code_fourier_cupy:
Cupy
----
.. automodule:: qpretrieve.fourier.ff_cupy
:members:
:inherited-members:


.. _sec_code_ifer:

Interference image analysis
Expand Down
2 changes: 2 additions & 0 deletions docs/sec_examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,5 @@ Examples
.. fancy_include:: filter_visualization.py

.. fancy_include:: fourier_scale.py

.. fancy_include:: fft_options.py
Binary file added examples/fft_options.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
73 changes: 73 additions & 0 deletions examples/fft_options.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""Fourier Transform interfaces available

This example visualizes the different backends and packages available to the
user for performing Fourier transforms.

- PyFFTW is initially slow, but over many FFTs is very quick.
- CuPy using CUDA can be very fast, but is currently limited because we are
transferring one image at a time to the GPU.

"""
import time
import matplotlib.pylab as plt
import numpy as np
import qpretrieve

# load the experimental data
edata = np.load("./data/hologram_cell.npz")

# get the available fft interfaces
interfaces_available = qpretrieve.fourier.get_available_interfaces()

n_transforms = 100

# one transform
results_1 = {}
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I need to clean this script up a bit to make it simpler

for fft_interface in interfaces_available:
t0 = time.time()
holo = qpretrieve.OffAxisHologram(data=edata["data"],
fft_interface=fft_interface)
holo.run_pipeline(filter_name="disk", filter_size=1 / 2)
bg = qpretrieve.OffAxisHologram(data=edata["bg_data"])
bg.process_like(holo)
t1 = time.time()
results_1[fft_interface.__name__] = t1 - t0
num_interfaces = len(results_1)

# multiple transforms (should see speed increase for PyFFTW)
results = {}
for fft_interface in interfaces_available:
t0 = time.time()
for _ in range(n_transforms):
holo = qpretrieve.OffAxisHologram(data=edata["data"],
fft_interface=fft_interface)
holo.run_pipeline(filter_name="disk", filter_size=1 / 2)
bg = qpretrieve.OffAxisHologram(data=edata["bg_data"])
bg.process_like(holo)
t1 = time.time()
results[fft_interface.__name__] = t1 - t0
num_interfaces = len(results)

fft_interfaces = list(results.keys())
speed_1 = list(results_1.values())
speed = list(results.values())

fig, axes = plt.subplots(1, 2, figsize=(8, 5))
ax1, ax2 = axes
labels = [fftstr[9:] for fftstr in fft_interfaces]

ax1.bar(range(num_interfaces), height=speed_1, color='lightseagreen')
ax1.set_xticks(range(num_interfaces), labels=labels,
rotation=45)
ax1.set_ylabel("Speed (s)")
ax1.set_title("1 Transform")

ax2.bar(range(num_interfaces), height=speed, color='lightseagreen')
ax2.set_xticks(range(num_interfaces), labels=labels,
rotation=45)
ax2.set_ylabel("Speed (s)")
ax2.set_title(f"{n_transforms} Transforms")

plt.suptitle("Speed of FFT Interfaces")
plt.tight_layout()
plt.show()
1 change: 1 addition & 0 deletions examples/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
matplotlib
8 changes: 5 additions & 3 deletions qpretrieve/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def get_filter_array(filter_name, filter_size, freq_pos, fft_shape):
and must be between 0 and `max(fft_shape)/2`
freq_pos: tuple of floats
The position of the filter in frequency coordinates as
returned by :func:`nunpy.fft.fftfreq`.
returned by :func:`numpy.fft.fftfreq`.
fft_shape: tuple of int
The shape of the Fourier transformed image for which the
filter will be applied. The shape must be squared (two
Expand Down Expand Up @@ -104,8 +104,10 @@ def get_filter_array(filter_name, filter_size, freq_pos, fft_shape):
# TODO: avoid the np.roll, instead use the indices directly
alpha = 0.1
rsize = int(min(fx.size, fy.size) * filter_size) * 2
tukey_window_x = signal.tukey(rsize, alpha=alpha).reshape(-1, 1)
tukey_window_y = signal.tukey(rsize, alpha=alpha).reshape(1, -1)
tukey_window_x = signal.windows.tukey(
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

scipy's new versions imports tukey from signal.window, not signal

rsize, alpha=alpha).reshape(-1, 1)
tukey_window_y = signal.windows.tukey(
rsize, alpha=alpha).reshape(1, -1)
tukey = tukey_window_x * tukey_window_y
base = np.zeros(fft_shape)
s1 = (np.array(fft_shape) - rsize) // 2
Expand Down
23 changes: 23 additions & 0 deletions qpretrieve/fourier/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,38 @@
import warnings

from .ff_numpy import FFTFilterNumpy
from .ff_scipy import FFTFilterScipy

try:
from .ff_pyfftw import FFTFilterPyFFTW
except ImportError:
FFTFilterPyFFTW = None

try:
from .ff_cupy import FFTFilterCupy
from .ff_cupy3D import FFTFilterCupy3D
except ImportError:
FFTFilterCupy = None

PREFERRED_INTERFACE = None


def get_available_interfaces():
"""Return a list of available FFT algorithms"""
interfaces = [
FFTFilterPyFFTW,
FFTFilterNumpy,
FFTFilterScipy,
FFTFilterCupy,
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't necessarily the "perferred order" we want, but I'd like to keep it as is due to old default pipelines.

FFTFilterCupy3D,
]
interfaces_available = []
for interface in interfaces:
if interface is not None and interface.is_available:
interfaces_available.append(interface)
return interfaces_available


def get_best_interface():
"""Return the fastest refocusing interface available

Expand Down
8 changes: 6 additions & 2 deletions qpretrieve/fourier/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ def __init__(self, data, subtract_mean=True, padding=2, copy=True):
else:
# convert integer-arrays to floating point arrays
dtype = float
if not copy:
# numpy v2.x behaviour requires asarray with copy=False
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

numpy has made copying a bit awkward. When an array can't be copied (with either np.array or np.asarray), and the user has copy=False, then it throws an error!

copy = None
data_ed = np.array(data, dtype=dtype, copy=copy)
#: original data (with subtracted mean)
self.origin = data_ed
Expand Down Expand Up @@ -175,7 +178,7 @@ def filter(self, filter_name: str, filter_size: float,
and must be between 0 and `max(fft_shape)/2`
freq_pos: tuple of floats
The position of the filter in frequency coordinates as
returned by :func:`nunpy.fft.fftfreq`.
returned by :func:`numpy.fft.fftfreq`.
scale_to_filter: bool or float
Crop the image in Fourier space after applying the filter,
effectively removing surplus (zero-padding) data and
Expand Down Expand Up @@ -220,7 +223,8 @@ def filter(self, filter_name: str, filter_size: float,
filter_name=filter_name,
filter_size=filter_size,
freq_pos=freq_pos,
fft_shape=self.fft_origin.shape)
# only take shape of a single fft
fft_shape=self.fft_origin.shape[-2:])
fft_filtered = self.fft_origin * filt_array
px = int(freq_pos[0] * self.shape[0])
py = int(freq_pos[1] * self.shape[1])
Expand Down
40 changes: 40 additions & 0 deletions qpretrieve/fourier/ff_cupy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import scipy as sp
import cupy as cp
import cupyx.scipy.fft as cufft

from .base import FFTFilter


class FFTFilterCupy(FFTFilter):
"""Wraps the cupy Fourier transform and uses it via the scipy backend
"""
is_available = True
# sp.fft.set_backend(cufft)

def _init_fft(self, data):
"""Perform initial Fourier transform of the input data

Parameters
----------
data: 2d real-valued np.ndarray
Input field to be refocused

Returns
-------
fft_fdata: 2d complex-valued ndarray
Fourier transform `data`
"""
data_gpu = cp.asarray(data)
# likely an inefficiency here, could use `set_global_backend`
with sp.fft.set_backend(cufft):
fft_gpu = sp.fft.fft2(data_gpu)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not ideal, this is something to work on.

fft_cpu = fft_gpu.get()
return fft_cpu

def _ifft(self, data):
"""Perform inverse Fourier transform"""
data_gpu = cp.asarray(data)
with sp.fft.set_backend(cufft):
ifft_gpu = sp.fft.ifft2(data_gpu)
ifft_cpu = ifft_gpu.get()
return ifft_cpu
40 changes: 40 additions & 0 deletions qpretrieve/fourier/ff_cupy3D.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import scipy as sp
import cupy as cp
import cupyx.scipy.fft as cufft

from .base import FFTFilter


class FFTFilterCupy3D(FFTFilter):
"""Wraps the cupy Fourier transform and uses it via the scipy backend
"""
is_available = True
# sp.fft.set_backend(cufft)

def _init_fft(self, data):
"""Perform initial Fourier transform of the input data

Parameters
----------
data: 2d real-valued np.ndarray
Input field to be refocused

Returns
-------
fft_fdata: 2d complex-valued ndarray
Fourier transform `data`
"""
data_gpu = cp.asarray(data)
# likely an inefficiency here, could use `set_global_backend`
with sp.fft.set_backend(cufft):
fft_gpu = sp.fft.fft2(data_gpu)
fft_cpu = fft_gpu.get()
return fft_cpu

def _ifft(self, data):
"""Perform inverse Fourier transform"""
data_gpu = cp.asarray(data)
with sp.fft.set_backend(cufft):
ifft_gpu = sp.fft.ifft2(data_gpu)
ifft_cpu = ifft_gpu.get()
return ifft_cpu
30 changes: 30 additions & 0 deletions qpretrieve/fourier/ff_scipy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import scipy as sp


from .base import FFTFilter


class FFTFilterScipy(FFTFilter):
"""Wraps the scipy Fourier transform
"""
# always available, because scipy is a dependency
is_available = True

def _init_fft(self, data):
"""Perform initial Fourier transform of the input data

Parameters
----------
data: 2d real-valued np.ndarray
Input field to be refocused

Returns
-------
fft_fdata: 2d complex-valued ndarray
Fourier transform `data`
"""
return sp.fft.fft2(data)

def _ifft(self, data):
"""Perform inverse Fourier transform"""
return sp.fft.ifft2(data)
Loading
Loading