Skip to content

Commit

Permalink
Merge pull request #248 from Ciela-Institute/fixmultiplane
Browse files Browse the repository at this point in the history
bug fix: ensure pixelated convergence kernel FFT convolution is valid
  • Loading branch information
RonanLegin authored Nov 5, 2024
2 parents 2f333c5 + 7338960 commit e67d4cc
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 30 deletions.
4 changes: 4 additions & 0 deletions src/caustics/lenses/func/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@
reduced_deflection_angle_pixelated_convergence,
potential_pixelated_convergence,
_fft2_padded,
_fft_size,
build_kernels_pixelated_convergence,
build_window_pixelated_convergence,
)
from .pseudo_jaffe import (
convergence_0_pseudo_jaffe,
Expand Down Expand Up @@ -111,7 +113,9 @@
"reduced_deflection_angle_pixelated_convergence",
"potential_pixelated_convergence",
"_fft2_padded",
"_fft_size",
"build_kernels_pixelated_convergence",
"build_window_pixelated_convergence",
"convergence_0_pseudo_jaffe",
"potential_pseudo_jaffe",
"reduced_deflection_angle_pseudo_jaffe",
Expand Down
37 changes: 30 additions & 7 deletions src/caustics/lenses/func/pixelated_convergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,21 +35,44 @@ def build_kernels_pixelated_convergence(pixelscale, n_pix):
"""
x_mg, y_mg = meshgrid(pixelscale, 2 * n_pix)
# Shift to center kernels within pixel at index n_pix
x_mg = x_mg - pixelscale / 2
y_mg = y_mg - pixelscale / 2

d2 = x_mg**2 + y_mg**2
potential_kernel = safe_log(d2.sqrt())
ax_kernel = safe_divide(x_mg, d2)
ay_kernel = safe_divide(y_mg, d2)
# Set centers of kernels to zero
potential_kernel[..., n_pix, n_pix] = 0
ax_kernel[..., n_pix, n_pix] = 0
ay_kernel[..., n_pix, n_pix] = 0

return ax_kernel, ay_kernel, potential_kernel


def build_window_pixelated_convergence(window, kernel_shape):
"""
Window the kernel for stable FFT.
Parameters
----------
window: float
The window to apply as a fraction of the image width. For example a
window of 1/4 will set the kernel to start decreasing at 1/4 of the
image width, and then linearly go to zero.
kernel_shape: tuple
The shape of the kernel to be windowed.
Returns
-------
Tensor
The window to multiply with the kernel.
"""
x, y = torch.meshgrid(
torch.linspace(-1, 1, kernel_shape[-1]),
torch.linspace(-1, 1, kernel_shape[-2]),
indexing="xy",
)
r = (x**2 + y**2).sqrt()
return torch.clip((1 - r) / window, 0, 1)


def _fft_size(n_pix):
pad = 2 * n_pix
pad = next_fast_len(pad)
Expand Down
58 changes: 35 additions & 23 deletions src/caustics/lenses/pixelated_convergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,16 @@ def __init__(
Literal["zero", "circular", "reflect", "tile"],
"Specifies the type of padding",
] = "zero",
window_kernel: Annotated[float, "Amount of kernel to be windowed"] = 1.0 / 8.0,
name: NameType = None,
):
"""Strong lensing with user provided kappa map
PixelatedConvergence is a class for strong gravitational lensing with a
user-provided kappa map. It inherits from the ThinLens class.
This class enables the computation of deflection angles and
lensing potential by applying the user-provided kappa map to a
grid using either Fast Fourier Transform (FFT) or a 2D
convolution.
user-provided kappa map. It inherits from the ThinLens class. This class
enables the computation of deflection angles and lensing potential by
applying the user-provided kappa map to a grid using either Fast Fourier
Transform (FFT) or a 2D convolution.
Attributes
----------
Expand Down Expand Up @@ -104,27 +104,32 @@ def __init__(
The shape of the convergence map.
convolution_mode: str, optional
The convolution mode for calculating deflection angles and lensing potential.
It can be either "fft" (Fast Fourier Transform) or "conv2d" (2D convolution).
Default is "fft".
The convolution mode for calculating deflection angles and lensing
potential. It can be either "fft" (Fast Fourier Transform) or
"conv2d" (2D convolution). Default is "fft".
use_next_fast_len: bool, optional
If True, adds additional padding to speed up the FFT by calling
`scipy.fft.next_fast_len`.
The speed boost can be substantial when `n_pix` is a multiple of a
small prime number. Default is True.
`scipy.fft.next_fast_len`. The speed boost can be substantial when
`n_pix` is a multiple of a small prime number. Default is True.
padding: { "zero", "circular", "reflect", "tile" }
Specifies the type of padding to use:
"zero" will do zero padding,
"circular" will do cyclic boundaries.
"reflect" will do reflection padding.
"tile" will tile the image at 2x2 which
basically identical to circular padding, but is easier.
Specifies the type of padding to use: "zero" will do zero padding,
"circular" will do cyclic boundaries. "reflect" will do reflection
padding. "tile" will tile the image at 2x2 which basically identical
to circular padding, but is easier. Use zero padding to represent an
overdensity, the other padding schemes represent a mass distribution
embedded in a field of similar mass distributions.
Generally you should use either "zero" or "tile".
window_kernel: float, optional
Amount of kernel to be windowed, specify the fraction of the kernel
size from which a linear window scaling will ensure the edges go to
zero for the purpose of FFT stability. Set to 0 for no windowing.
Default is 1/8.
"""

super().__init__(cosmology, z_l, name=name)
Expand Down Expand Up @@ -153,6 +158,13 @@ def __init__(
self.ax_kernel, self.ay_kernel, self.potential_kernel = (
func.build_kernels_pixelated_convergence(self.pixelscale, self.n_pix)
)
# Window the kernels if needed
if padding != "zero" and convolution_mode == "fft" and window_kernel > 0:
window = func.build_window_pixelated_convergence(
window_kernel, self.ax_kernel.shape
)
self.ax_kernel = self.ax_kernel * window
self.ay_kernel = self.ay_kernel * window

self.potential_kernel_tilde = None
self.ax_kernel_tilde = None
Expand Down Expand Up @@ -214,14 +226,14 @@ def convolution_mode(self, convolution_mode: str):
"""
if convolution_mode == "fft":
# Create FFTs of kernels
self.potential_kernel_tilde = func._fft2_padded(
self.potential_kernel, self.n_pix, self.padding
self.potential_kernel_tilde = torch.fft.rfft2(
self.potential_kernel, func._fft_size(self.n_pix)
)
self.ax_kernel_tilde = func._fft2_padded(
self.ax_kernel, self.n_pix, self.padding
self.ax_kernel_tilde = torch.fft.rfft2(
self.ax_kernel, func._fft_size(self.n_pix)
)
self.ay_kernel_tilde = func._fft2_padded(
self.ay_kernel, self.n_pix, self.padding
self.ay_kernel_tilde = torch.fft.rfft2(
self.ay_kernel, func._fft_size(self.n_pix)
)
elif convolution_mode == "conv2d":
# Drop FFTs of kernels
Expand Down

0 comments on commit e67d4cc

Please sign in to comment.