Skip to content

Commit

Permalink
get the phase in
Browse files Browse the repository at this point in the history
  • Loading branch information
dkazanc committed May 1, 2024
1 parent d2ff482 commit 768cbc3
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 10 deletions.
2 changes: 1 addition & 1 deletion httomolibgpu/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from httomolibgpu.misc.rescale import rescale_to_int
from httomolibgpu.prep.alignment import distortion_correction_proj_discorpy
from httomolibgpu.prep.normalize import normalize
#from httomolibgpu.prep.phase import paganin_filter_savu, paganin_filter_tomopy
from httomolibgpu.prep.phase import paganin_filter_savu, paganin_filter_tomopy
# from httomolibgpu.prep.stripe import (
# remove_stripe_based_sorting,
# remove_stripe_ti,
Expand Down
19 changes: 10 additions & 9 deletions httomolibgpu/prep/phase.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"""Modules for phase retrieval and phase-contrast enhancement"""

cupy_run = False
import numpy as np
try:
import cupy as xp

Expand All @@ -35,11 +36,13 @@
import numpy as np

from numpy import float32
import numpy as np
import nvtx

if cupy_run:
from httomolibgpu.cuda_kernels import load_cuda_module
from cupyx.scipy.fft import fft2, ifft2, fftshift
else:
from scipy.fft import fft2, ifft2, fftshift

__all__ = [
"paganin_filter_savu",
Expand Down Expand Up @@ -104,8 +107,7 @@ def paganin_filter_savu(
-------
cp.ndarray
The stack of filtered projections.
"""
import cupyx
"""

# Check the input data is valid
if data.ndim != 3:
Expand Down Expand Up @@ -170,7 +172,7 @@ def paganin_filter_savu(

# avoid normalising in both directions - we include multiplier in the post_kernel
data = xp.asarray(data, dtype=xp.complex64)
data = cupyx.scipy.fft.fft2(data, axes=(-2, -1), overwrite_x=True, norm="backward")
data = fft2(data, axes=(-2, -1), overwrite_x=True, norm="backward")

# prepare filter here, while the GPU is busy with the FFT
filtercomplex = xp.empty((height1, width1), dtype=np.complex64)
Expand All @@ -193,7 +195,7 @@ def paganin_filter_savu(
)
data *= filtercomplex

data = cupyx.scipy.fft.ifft2(data, axes=(-2, -1), overwrite_x=True, norm="forward")
data = ifft2(data, axes=(-2, -1), overwrite_x=True, norm="forward")

post_kernel = xp.ElementwiseKernel(
"C pci1, raw float32 increment, raw float32 ratio, raw float32 fft_scale",
Expand Down Expand Up @@ -318,7 +320,6 @@ def paganin_filter_tomopy(
cp.ndarray
The 3D array of Paganin phase-filtered projection images.
"""
import cupyx

# Check the input data is valid
if tomo.ndim != 3:
Expand All @@ -337,20 +338,20 @@ def paganin_filter_tomopy(

# 3D FFT of tomo data
padded_tomo = xp.asarray(padded_tomo, dtype=xp.complex64)
fft_tomo = cupyx.scipy.fft.fft2(padded_tomo, axes=(-2, -1), overwrite_x=True)
fft_tomo = fft2(padded_tomo, axes=(-2, -1), overwrite_x=True)

# Compute the reciprocal grid.
w2 = _reciprocal_grid(pixel_size, (dy, dx))

# Build filter in the Fourier space.
phase_filter = cupyx.scipy.fft.fftshift(
phase_filter = fftshift(
_paganin_filter_factor2(energy, dist, alpha, w2)
)
phase_filter = phase_filter / phase_filter.max() # normalisation

# Apply filter and take inverse FFT
ifft_filtered_tomo = (
cupyx.scipy.fft.ifft2(phase_filter * fft_tomo, axes=(-2, -1), overwrite_x=True)
ifft2(phase_filter * fft_tomo, axes=(-2, -1), overwrite_x=True)
).real

# slicing indices for cropping
Expand Down

0 comments on commit 768cbc3

Please sign in to comment.