From dd478d290003e701632b3896d03644115731d101 Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Mon, 12 Aug 2024 21:43:57 -0400 Subject: [PATCH 01/10] getting pixelated convergence to remove artifacts --- src/caustics/lenses/func/__init__.py | 2 + .../lenses/func/pixelated_convergence.py | 47 +++++++++++++++---- src/caustics/lenses/pixelated_convergence.py | 12 ++--- 3 files changed, 47 insertions(+), 14 deletions(-) diff --git a/src/caustics/lenses/func/__init__.py b/src/caustics/lenses/func/__init__.py index f9bfc03c..553373d5 100644 --- a/src/caustics/lenses/func/__init__.py +++ b/src/caustics/lenses/func/__init__.py @@ -33,6 +33,7 @@ reduced_deflection_angle_pixelated_convergence, potential_pixelated_convergence, _fft2_padded, + _fft_size, build_kernels_pixelated_convergence, ) from .pseudo_jaffe import ( @@ -91,6 +92,7 @@ "reduced_deflection_angle_pixelated_convergence", "potential_pixelated_convergence", "_fft2_padded", + "_fft_size", "build_kernels_pixelated_convergence", "convergence_0_pseudo_jaffe", "potential_pseudo_jaffe", diff --git a/src/caustics/lenses/func/pixelated_convergence.py b/src/caustics/lenses/func/pixelated_convergence.py index 86e3f5fe..5d94e923 100644 --- a/src/caustics/lenses/func/pixelated_convergence.py +++ b/src/caustics/lenses/func/pixelated_convergence.py @@ -34,24 +34,47 @@ def build_kernels_pixelated_convergence(pixelscale, n_pix): *Unit: unitless* """ - x_mg, y_mg = meshgrid(pixelscale, 2 * n_pix) + x_mg, y_mg = meshgrid(pixelscale, 3 * n_pix) # Shift to center kernels within pixel at index n_pix - x_mg = x_mg - pixelscale / 2 - y_mg = y_mg - pixelscale / 2 + # x_mg = x_mg - pixelscale / 2 + # y_mg = y_mg - pixelscale / 2 d2 = x_mg**2 + y_mg**2 + print(torch.any(d2 == 0), n_pix, d2.shape) potential_kernel = safe_log(d2.sqrt()) ax_kernel = safe_divide(x_mg, d2) ay_kernel = safe_divide(y_mg, d2) + import matplotlib.pyplot as plt + import numpy as np + + print(np.unravel_index(torch.argmax(torch.abs(ax_kernel)), ax_kernel.shape)) + print(ax_kernel[63, 0], ax_kernel[63, -1], ax_kernel[0, 63], ax_kernel[-1, 63]) # Set centers of kernels to zero - potential_kernel[..., n_pix, n_pix] = 0 - ax_kernel[..., n_pix, n_pix] = 0 - ay_kernel[..., n_pix, n_pix] = 0 + plt.imshow( + torch.clip(3 * (x_mg.max() - x_mg.abs()) / x_mg.max(), 0, 1).detach().numpy() + ) + plt.colorbar() + plt.show() + ax_kernel = ax_kernel * torch.clip(3 * (x_mg.max() - x_mg.abs()) / x_mg.max(), 0, 1) + ay_kernel = ay_kernel * torch.clip(3 * (y_mg.max() - y_mg.abs()) / y_mg.max(), 0, 1) + # ax_kernel[: n_pix // 2] = 0 + # ax_kernel[-n_pix // 2 :] = 0 + # ax_kernel[:, : n_pix // 2] = 0 + # ax_kernel[:, -n_pix // 2 :] = 0 + # ay_kernel[:, : n_pix // 2] = 0 + # ay_kernel[:, -n_pix // 2 :] = 0 + # ay_kernel[: n_pix // 2] = 0 + # ay_kernel[-n_pix // 2 :] = 0 + plt.imshow(torch.log(torch.abs(ax_kernel)).detach().numpy()) + plt.show() + # potential_kernel[..., n_pix, n_pix] = 0 + # ax_kernel[..., n_pix, n_pix] = 0 + # ay_kernel[..., n_pix, n_pix] = 0 return ax_kernel, ay_kernel, potential_kernel def _fft_size(n_pix): - pad = 2 * n_pix + pad = 3 * n_pix pad = next_fast_len(pad) return pad, pad @@ -77,13 +100,21 @@ def _fft2_padded(x, n_pix, padding: str): if padding == "zero": pass - elif padding in ["reflect", "circular"]: + elif padding in ["reflect"]: + print(x.shape, n_pix) + x = F.pad( + x[None, None], (n_pix - 1, n_pix - 1, n_pix - 1, n_pix - 1), mode=padding + ).squeeze() + elif padding in ["circular"]: x = F.pad(x[None, None], (0, n_pix - 1, 0, n_pix - 1), mode=padding).squeeze() elif padding == 
"tile": x = torch.tile(x, (2, 2)) else: raise ValueError(f"Invalid padding type: {padding}") + import matplotlib.pyplot as plt + plt.imshow(x.detach().numpy()) + plt.show() return torch.fft.rfft2(x, _fft_size(n_pix)) diff --git a/src/caustics/lenses/pixelated_convergence.py b/src/caustics/lenses/pixelated_convergence.py index 1003f758..a3c1d06f 100644 --- a/src/caustics/lenses/pixelated_convergence.py +++ b/src/caustics/lenses/pixelated_convergence.py @@ -214,14 +214,14 @@ def convolution_mode(self, convolution_mode: str): """ if convolution_mode == "fft": # Create FFTs of kernels - self.potential_kernel_tilde = func._fft2_padded( - self.potential_kernel, self.n_pix, self.padding + self.potential_kernel_tilde = torch.fft.rfft2( + self.potential_kernel, func._fft_size(self.n_pix) ) - self.ax_kernel_tilde = func._fft2_padded( - self.ax_kernel, self.n_pix, self.padding + self.ax_kernel_tilde = torch.fft.rfft2( + self.ax_kernel, func._fft_size(self.n_pix) ) - self.ay_kernel_tilde = func._fft2_padded( - self.ay_kernel, self.n_pix, self.padding + self.ay_kernel_tilde = torch.fft.rfft2( + self.ay_kernel, func._fft_size(self.n_pix) ) elif convolution_mode == "conv2d": # Drop FFTs of kernels From d41e10ff76674fd3319c6d82801aa69d38ac9915 Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Mon, 12 Aug 2024 22:35:03 -0400 Subject: [PATCH 02/10] no more artifacts in pixelated convergence --- .../lenses/func/pixelated_convergence.py | 50 +++++-------------- 1 file changed, 12 insertions(+), 38 deletions(-) diff --git a/src/caustics/lenses/func/pixelated_convergence.py b/src/caustics/lenses/func/pixelated_convergence.py index 5d94e923..5554ecc1 100644 --- a/src/caustics/lenses/func/pixelated_convergence.py +++ b/src/caustics/lenses/func/pixelated_convergence.py @@ -34,47 +34,29 @@ def build_kernels_pixelated_convergence(pixelscale, n_pix): *Unit: unitless* """ - x_mg, y_mg = meshgrid(pixelscale, 3 * n_pix) + x_mg, y_mg = meshgrid(pixelscale, 2 * n_pix) # Shift to center kernels within pixel at index n_pix - # x_mg = x_mg - pixelscale / 2 - # y_mg = y_mg - pixelscale / 2 + x_mg = x_mg - pixelscale / 2 + y_mg = y_mg - pixelscale / 2 d2 = x_mg**2 + y_mg**2 - print(torch.any(d2 == 0), n_pix, d2.shape) potential_kernel = safe_log(d2.sqrt()) ax_kernel = safe_divide(x_mg, d2) ay_kernel = safe_divide(y_mg, d2) - import matplotlib.pyplot as plt - import numpy as np - print(np.unravel_index(torch.argmax(torch.abs(ax_kernel)), ax_kernel.shape)) - print(ax_kernel[63, 0], ax_kernel[63, -1], ax_kernel[0, 63], ax_kernel[-1, 63]) # Set centers of kernels to zero - plt.imshow( - torch.clip(3 * (x_mg.max() - x_mg.abs()) / x_mg.max(), 0, 1).detach().numpy() - ) - plt.colorbar() - plt.show() - ax_kernel = ax_kernel * torch.clip(3 * (x_mg.max() - x_mg.abs()) / x_mg.max(), 0, 1) - ay_kernel = ay_kernel * torch.clip(3 * (y_mg.max() - y_mg.abs()) / y_mg.max(), 0, 1) - # ax_kernel[: n_pix // 2] = 0 - # ax_kernel[-n_pix // 2 :] = 0 - # ax_kernel[:, : n_pix // 2] = 0 - # ax_kernel[:, -n_pix // 2 :] = 0 - # ay_kernel[:, : n_pix // 2] = 0 - # ay_kernel[:, -n_pix // 2 :] = 0 - # ay_kernel[: n_pix // 2] = 0 - # ay_kernel[-n_pix // 2 :] = 0 - plt.imshow(torch.log(torch.abs(ax_kernel)).detach().numpy()) - plt.show() - # potential_kernel[..., n_pix, n_pix] = 0 - # ax_kernel[..., n_pix, n_pix] = 0 - # ay_kernel[..., n_pix, n_pix] = 0 + potential_kernel[..., n_pix, n_pix] = 0 + ax_kernel[..., n_pix, n_pix] = 0 + ay_kernel[..., n_pix, n_pix] = 0 + + # Window the deflection angle kernels for stable FFT + ax_kernel = ax_kernel * torch.clip(8 * 
(x_mg.max() - x_mg.abs()) / x_mg.max(), 0, 1) + ay_kernel = ay_kernel * torch.clip(8 * (y_mg.max() - y_mg.abs()) / y_mg.max(), 0, 1) return ax_kernel, ay_kernel, potential_kernel def _fft_size(n_pix): - pad = 3 * n_pix + pad = 2 * n_pix pad = next_fast_len(pad) return pad, pad @@ -100,21 +82,13 @@ def _fft2_padded(x, n_pix, padding: str): if padding == "zero": pass - elif padding in ["reflect"]: - print(x.shape, n_pix) - x = F.pad( - x[None, None], (n_pix - 1, n_pix - 1, n_pix - 1, n_pix - 1), mode=padding - ).squeeze() - elif padding in ["circular"]: + elif padding in ["circular", "reflect"]: x = F.pad(x[None, None], (0, n_pix - 1, 0, n_pix - 1), mode=padding).squeeze() elif padding == "tile": x = torch.tile(x, (2, 2)) else: raise ValueError(f"Invalid padding type: {padding}") - import matplotlib.pyplot as plt - plt.imshow(x.detach().numpy()) - plt.show() return torch.fft.rfft2(x, _fft_size(n_pix)) From 9cf12a96f9ea19755fdb7b7b7b043a38c07f5e37 Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Wed, 14 Aug 2024 20:01:23 -0400 Subject: [PATCH 03/10] remove half pixel shift in kernel --- src/caustics/lenses/func/pixelated_convergence.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/caustics/lenses/func/pixelated_convergence.py b/src/caustics/lenses/func/pixelated_convergence.py index 5554ecc1..e9eb5a52 100644 --- a/src/caustics/lenses/func/pixelated_convergence.py +++ b/src/caustics/lenses/func/pixelated_convergence.py @@ -36,17 +36,18 @@ def build_kernels_pixelated_convergence(pixelscale, n_pix): """ x_mg, y_mg = meshgrid(pixelscale, 2 * n_pix) # Shift to center kernels within pixel at index n_pix - x_mg = x_mg - pixelscale / 2 - y_mg = y_mg - pixelscale / 2 + # x_mg = x_mg - pixelscale / 2 + # y_mg = y_mg - pixelscale / 2 + d2 = x_mg**2 + y_mg**2 potential_kernel = safe_log(d2.sqrt()) ax_kernel = safe_divide(x_mg, d2) ay_kernel = safe_divide(y_mg, d2) # Set centers of kernels to zero - potential_kernel[..., n_pix, n_pix] = 0 - ax_kernel[..., n_pix, n_pix] = 0 - ay_kernel[..., n_pix, n_pix] = 0 + # potential_kernel[..., n_pix, n_pix] = 0 + # ax_kernel[..., n_pix, n_pix] = 0 + # ay_kernel[..., n_pix, n_pix] = 0 # Window the deflection angle kernels for stable FFT ax_kernel = ax_kernel * torch.clip(8 * (x_mg.max() - x_mg.abs()) / x_mg.max(), 0, 1) From e0e8070927376087d0f203cbc2f445e6171545a8 Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Sat, 17 Aug 2024 11:53:21 -0400 Subject: [PATCH 04/10] Change kernel window to be round --- src/caustics/lenses/func/pixelated_convergence.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/caustics/lenses/func/pixelated_convergence.py b/src/caustics/lenses/func/pixelated_convergence.py index e9eb5a52..cf9a3697 100644 --- a/src/caustics/lenses/func/pixelated_convergence.py +++ b/src/caustics/lenses/func/pixelated_convergence.py @@ -50,8 +50,8 @@ def build_kernels_pixelated_convergence(pixelscale, n_pix): # ay_kernel[..., n_pix, n_pix] = 0 # Window the deflection angle kernels for stable FFT - ax_kernel = ax_kernel * torch.clip(8 * (x_mg.max() - x_mg.abs()) / x_mg.max(), 0, 1) - ay_kernel = ay_kernel * torch.clip(8 * (y_mg.max() - y_mg.abs()) / y_mg.max(), 0, 1) + ax_kernel = ax_kernel * torch.clip(8 * (x_mg.max() - d2.sqrt()) / x_mg.max(), 0, 1) + ay_kernel = ay_kernel * torch.clip(8 * (y_mg.max() - d2.sqrt()) / y_mg.max(), 0, 1) return ax_kernel, ay_kernel, potential_kernel From 35dd8acc38ac7d25e39e957b4f43b163b37372e6 Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Fri, 23 Aug 2024 15:47:14 
-0400 Subject: [PATCH 05/10] Set zero padding to mean overdensity instead of periodic boundary --- .../lenses/func/pixelated_convergence.py | 21 ++++++----- src/caustics/lenses/pixelated_convergence.py | 36 +++++++++---------- 2 files changed, 31 insertions(+), 26 deletions(-) diff --git a/src/caustics/lenses/func/pixelated_convergence.py b/src/caustics/lenses/func/pixelated_convergence.py index cf9a3697..0346b787 100644 --- a/src/caustics/lenses/func/pixelated_convergence.py +++ b/src/caustics/lenses/func/pixelated_convergence.py @@ -5,7 +5,7 @@ from ...utils import safe_divide, safe_log, meshgrid, interp2d -def build_kernels_pixelated_convergence(pixelscale, n_pix): +def build_kernels_pixelated_convergence(pixelscale, n_pix, window=0): """ Build the kernels for the pixelated convergence. @@ -36,8 +36,8 @@ def build_kernels_pixelated_convergence(pixelscale, n_pix): """ x_mg, y_mg = meshgrid(pixelscale, 2 * n_pix) # Shift to center kernels within pixel at index n_pix - # x_mg = x_mg - pixelscale / 2 - # y_mg = y_mg - pixelscale / 2 + x_mg = x_mg - pixelscale / 2 + y_mg = y_mg - pixelscale / 2 d2 = x_mg**2 + y_mg**2 potential_kernel = safe_log(d2.sqrt()) @@ -45,13 +45,18 @@ def build_kernels_pixelated_convergence(pixelscale, n_pix): ay_kernel = safe_divide(y_mg, d2) # Set centers of kernels to zero - # potential_kernel[..., n_pix, n_pix] = 0 - # ax_kernel[..., n_pix, n_pix] = 0 - # ay_kernel[..., n_pix, n_pix] = 0 + potential_kernel[..., n_pix, n_pix] = 0 + ax_kernel[..., n_pix, n_pix] = 0 + ay_kernel[..., n_pix, n_pix] = 0 # Window the deflection angle kernels for stable FFT - ax_kernel = ax_kernel * torch.clip(8 * (x_mg.max() - d2.sqrt()) / x_mg.max(), 0, 1) - ay_kernel = ay_kernel * torch.clip(8 * (y_mg.max() - d2.sqrt()) / y_mg.max(), 0, 1) + if window > 0: + ax_kernel = ax_kernel * torch.clip( + window * (x_mg.max() - d2.sqrt()) / x_mg.max(), 0, 1 + ) + ay_kernel = ay_kernel * torch.clip( + window * (y_mg.max() - d2.sqrt()) / y_mg.max(), 0, 1 + ) return ax_kernel, ay_kernel, potential_kernel diff --git a/src/caustics/lenses/pixelated_convergence.py b/src/caustics/lenses/pixelated_convergence.py index a3c1d06f..49aea749 100644 --- a/src/caustics/lenses/pixelated_convergence.py +++ b/src/caustics/lenses/pixelated_convergence.py @@ -61,11 +61,10 @@ def __init__( """Strong lensing with user provided kappa map PixelatedConvergence is a class for strong gravitational lensing with a - user-provided kappa map. It inherits from the ThinLens class. - This class enables the computation of deflection angles and - lensing potential by applying the user-provided kappa map to a - grid using either Fast Fourier Transform (FFT) or a 2D - convolution. + user-provided kappa map. It inherits from the ThinLens class. This class + enables the computation of deflection angles and lensing potential by + applying the user-provided kappa map to a grid using either Fast Fourier + Transform (FFT) or a 2D convolution. Attributes ---------- @@ -104,24 +103,23 @@ def __init__( The shape of the convergence map. convolution_mode: str, optional - The convolution mode for calculating deflection angles and lensing potential. - It can be either "fft" (Fast Fourier Transform) or "conv2d" (2D convolution). - Default is "fft". + The convolution mode for calculating deflection angles and lensing + potential. It can be either "fft" (Fast Fourier Transform) or + "conv2d" (2D convolution). Default is "fft". 
use_next_fast_len: bool, optional If True, adds additional padding to speed up the FFT by calling - `scipy.fft.next_fast_len`. - The speed boost can be substantial when `n_pix` is a multiple of a - small prime number. Default is True. + `scipy.fft.next_fast_len`. The speed boost can be substantial when + `n_pix` is a multiple of a small prime number. Default is True. padding: { "zero", "circular", "reflect", "tile" } - Specifies the type of padding to use: - "zero" will do zero padding, - "circular" will do cyclic boundaries. - "reflect" will do reflection padding. - "tile" will tile the image at 2x2 which - basically identical to circular padding, but is easier. + Specifies the type of padding to use: "zero" will do zero padding, + "circular" will do cyclic boundaries. "reflect" will do reflection + padding. "tile" will tile the image at 2x2 which basically identical + to circular padding, but is easier. Use zero padding to represent an + overdensity, the other padding schemes represent a mass distribution + embedded in a field of similar mass distributions. Generally you should use either "zero" or "tile". @@ -151,7 +149,9 @@ def __init__( # Construct kernels self.ax_kernel, self.ay_kernel, self.potential_kernel = ( - func.build_kernels_pixelated_convergence(self.pixelscale, self.n_pix) + func.build_kernels_pixelated_convergence( + self.pixelscale, self.n_pix, 0 if padding == "zero" else 8 + ) ) self.potential_kernel_tilde = None From dcd15f80dbe4091eb0b96de668995200f42ab2c5 Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Wed, 28 Aug 2024 16:51:41 -0400 Subject: [PATCH 06/10] return fixes to pixelated convergence --- src/caustics/lenses/func/pixelated_convergence.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/caustics/lenses/func/pixelated_convergence.py b/src/caustics/lenses/func/pixelated_convergence.py index 0346b787..e33a4241 100644 --- a/src/caustics/lenses/func/pixelated_convergence.py +++ b/src/caustics/lenses/func/pixelated_convergence.py @@ -36,8 +36,8 @@ def build_kernels_pixelated_convergence(pixelscale, n_pix, window=0): """ x_mg, y_mg = meshgrid(pixelscale, 2 * n_pix) # Shift to center kernels within pixel at index n_pix - x_mg = x_mg - pixelscale / 2 - y_mg = y_mg - pixelscale / 2 + # x_mg = x_mg - pixelscale / 2 + # y_mg = y_mg - pixelscale / 2 d2 = x_mg**2 + y_mg**2 potential_kernel = safe_log(d2.sqrt()) @@ -45,9 +45,9 @@ def build_kernels_pixelated_convergence(pixelscale, n_pix, window=0): ay_kernel = safe_divide(y_mg, d2) # Set centers of kernels to zero - potential_kernel[..., n_pix, n_pix] = 0 - ax_kernel[..., n_pix, n_pix] = 0 - ay_kernel[..., n_pix, n_pix] = 0 + # potential_kernel[..., n_pix, n_pix] = 0 + # ax_kernel[..., n_pix, n_pix] = 0 + # ay_kernel[..., n_pix, n_pix] = 0 # Window the deflection angle kernels for stable FFT if window > 0: From eb81bba669deef8cbc2172d11b57a0f466bd5027 Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Thu, 10 Oct 2024 14:33:10 -0400 Subject: [PATCH 07/10] separate pixelated kernel and windowing into different functions --- src/caustics/lenses/func/__init__.py | 2 + .../lenses/func/pixelated_convergence.py | 46 ++++++++++++------- src/caustics/lenses/pixelated_convergence.py | 17 +++++-- 3 files changed, 45 insertions(+), 20 deletions(-) diff --git a/src/caustics/lenses/func/__init__.py b/src/caustics/lenses/func/__init__.py index 0d6bd63d..d89cc12c 100644 --- a/src/caustics/lenses/func/__init__.py +++ b/src/caustics/lenses/func/__init__.py @@ -42,6 +42,7 @@ _fft2_padded, _fft_size, 
build_kernels_pixelated_convergence, + build_window_pixelated_convergence, ) from .pseudo_jaffe import ( convergence_0_pseudo_jaffe, @@ -108,6 +109,7 @@ "_fft2_padded", "_fft_size", "build_kernels_pixelated_convergence", + "build_window_pixelated_convergence", "convergence_0_pseudo_jaffe", "potential_pseudo_jaffe", "reduced_deflection_angle_pseudo_jaffe", diff --git a/src/caustics/lenses/func/pixelated_convergence.py b/src/caustics/lenses/func/pixelated_convergence.py index e33a4241..f072e357 100644 --- a/src/caustics/lenses/func/pixelated_convergence.py +++ b/src/caustics/lenses/func/pixelated_convergence.py @@ -5,7 +5,7 @@ from ...utils import safe_divide, safe_log, meshgrid, interp2d -def build_kernels_pixelated_convergence(pixelscale, n_pix, window=0): +def build_kernels_pixelated_convergence(pixelscale, n_pix): """ Build the kernels for the pixelated convergence. @@ -35,30 +35,42 @@ def build_kernels_pixelated_convergence(pixelscale, n_pix, window=0): """ x_mg, y_mg = meshgrid(pixelscale, 2 * n_pix) - # Shift to center kernels within pixel at index n_pix - # x_mg = x_mg - pixelscale / 2 - # y_mg = y_mg - pixelscale / 2 d2 = x_mg**2 + y_mg**2 potential_kernel = safe_log(d2.sqrt()) ax_kernel = safe_divide(x_mg, d2) ay_kernel = safe_divide(y_mg, d2) - # Set centers of kernels to zero - # potential_kernel[..., n_pix, n_pix] = 0 - # ax_kernel[..., n_pix, n_pix] = 0 - # ay_kernel[..., n_pix, n_pix] = 0 + return ax_kernel, ay_kernel, potential_kernel - # Window the deflection angle kernels for stable FFT - if window > 0: - ax_kernel = ax_kernel * torch.clip( - window * (x_mg.max() - d2.sqrt()) / x_mg.max(), 0, 1 - ) - ay_kernel = ay_kernel * torch.clip( - window * (y_mg.max() - d2.sqrt()) / y_mg.max(), 0, 1 - ) - return ax_kernel, ay_kernel, potential_kernel +def build_window_pixelated_convergence(window, kernel_shape): + """ + Window the kernel for stable FFT. + + Parameters + ---------- + window: float + The window to apply as a fraction of the image width. For example a + window of 1/4 will set the kernel to start decreasing at 1/4 of the + image width, and then linearly go to zero. + + kernel_shape: tuple + The shape of the kernel to be windowed. + + Returns + ------- + Tensor + The window to multiply with the kernel. + + """ + x, y = torch.meshgrid( + torch.linspace(-1, 1, kernel_shape[-1]), + torch.linspace(-1, 1, kernel_shape[-2]), + indexing="xy", + ) + r = (x**2 + y**2).sqrt() + return torch.clip((1 - r) / window, 0, 1) def _fft_size(n_pix): diff --git a/src/caustics/lenses/pixelated_convergence.py b/src/caustics/lenses/pixelated_convergence.py index 49aea749..b49e2ac6 100644 --- a/src/caustics/lenses/pixelated_convergence.py +++ b/src/caustics/lenses/pixelated_convergence.py @@ -56,6 +56,7 @@ def __init__( Literal["zero", "circular", "reflect", "tile"], "Specifies the type of padding", ] = "zero", + window_kernel: Annotated[float, "Amount of kernel to be windowed"] = 1.0 / 8.0, name: NameType = None, ): """Strong lensing with user provided kappa map @@ -123,6 +124,11 @@ def __init__( Generally you should use either "zero" or "tile". + window_kernel: float, optional + Amount of kernel to be windowed, specify the fraction of the kernel + size from which a linear window scaling will ensure the edges go to + zero for the purpose of FFT stability. Default is 1/8. 
+ """ super().__init__(cosmology, z_l, name=name) @@ -149,10 +155,15 @@ def __init__( # Construct kernels self.ax_kernel, self.ay_kernel, self.potential_kernel = ( - func.build_kernels_pixelated_convergence( - self.pixelscale, self.n_pix, 0 if padding == "zero" else 8 - ) + func.build_kernels_pixelated_convergence(self.pixelscale, self.n_pix) ) + # Window the kernels if needed + if padding != "zero" and convolution_mode == "fft": + window = func.build_window_pixelated_convergence( + window_kernel, self.ax_kernel.shape + ) + self.ax_kernel = self.ax_kernel * window + self.ay_kernel = self.ay_kernel * window self.potential_kernel_tilde = None self.ax_kernel_tilde = None From cf82ad6ca1314ff8bb4b1f5b4b4e7ddbdc48d767 Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Thu, 10 Oct 2024 14:48:32 -0400 Subject: [PATCH 08/10] tidying --- src/caustics/lenses/func/pixelated_convergence.py | 2 +- src/caustics/lenses/pixelated_convergence.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/caustics/lenses/func/pixelated_convergence.py b/src/caustics/lenses/func/pixelated_convergence.py index f072e357..6fdf07dc 100644 --- a/src/caustics/lenses/func/pixelated_convergence.py +++ b/src/caustics/lenses/func/pixelated_convergence.py @@ -100,7 +100,7 @@ def _fft2_padded(x, n_pix, padding: str): if padding == "zero": pass - elif padding in ["circular", "reflect"]: + elif padding in ["reflect", "circular"]: x = F.pad(x[None, None], (0, n_pix - 1, 0, n_pix - 1), mode=padding).squeeze() elif padding == "tile": x = torch.tile(x, (2, 2)) diff --git a/src/caustics/lenses/pixelated_convergence.py b/src/caustics/lenses/pixelated_convergence.py index b49e2ac6..c11529a8 100644 --- a/src/caustics/lenses/pixelated_convergence.py +++ b/src/caustics/lenses/pixelated_convergence.py @@ -127,7 +127,8 @@ def __init__( window_kernel: float, optional Amount of kernel to be windowed, specify the fraction of the kernel size from which a linear window scaling will ensure the edges go to - zero for the purpose of FFT stability. Default is 1/8. + zero for the purpose of FFT stability. Set to 0 for no windowing. + Default is 1/8. 
""" @@ -158,7 +159,7 @@ def __init__( func.build_kernels_pixelated_convergence(self.pixelscale, self.n_pix) ) # Window the kernels if needed - if padding != "zero" and convolution_mode == "fft": + if padding != "zero" and convolution_mode == "fft" and window_kernel > 0: window = func.build_window_pixelated_convergence( window_kernel, self.ax_kernel.shape ) From 7d0fb0647b185f608260caa008fe4dd5acdfef49 Mon Sep 17 00:00:00 2001 From: "Connor Stone, PhD" Date: Thu, 7 Nov 2024 06:31:14 -0800 Subject: [PATCH 09/10] feat: add shear calculation for all lenses (#280) * add shear function to base lens * add shear tests vs lenstronomy --- src/caustics/lenses/base.py | 54 ++++++++++++++++++------------------- tests/test_nfw.py | 15 +++++++++-- tests/test_tnfw.py | 2 ++ tests/utils/__init__.py | 31 ++++++++++++++++++++- 4 files changed, 71 insertions(+), 31 deletions(-) diff --git a/src/caustics/lenses/base.py b/src/caustics/lenses/base.py index ee819f2b..3eb9ebe4 100644 --- a/src/caustics/lenses/base.py +++ b/src/caustics/lenses/base.py @@ -79,6 +79,31 @@ def jacobian_lens_equation( else: raise ValueError("method should be one of: autograd, finitediff") + @unpack + def shear( + self, + x: Tensor, + y: Tensor, + z_s: Tensor, + *args, + params: Optional["Packed"] = None, + method="autograd", + pixelscale: Optional[Tensor] = None, + **kwargs, + ): + """ + General shear calculation for a lens model using the jacobian of the + lens equation. Individual lenses may implement more efficient methods. + """ + A = self.jacobian_lens_equation( + x, y, z_s, params=params, method=method, pixelscale=pixelscale + ) + I = torch.eye(2, device=A.device, dtype=A.dtype).reshape( # noqa E741 + *[1] * len(A.shape[:-2]), 2, 2 + ) + negPsi = 0.5 * (A[..., 0, 0] + A[..., 1, 1]).unsqueeze(-1).unsqueeze(-1) * I - A + return 0.5 * (negPsi[..., 0, 0] - negPsi[..., 1, 1]), negPsi[..., 0, 1] + @unpack def magnification( self, @@ -193,34 +218,7 @@ def forward_raytrace( x0 = torch.zeros((), device=bx.device, dtype=bx.dtype) if y0 is None: y0 = torch.zeros((), device=by.device, dtype=by.dtype) - # X = torch.stack((x0, y0)).repeat(4, 1) - # X[0] -= fov / 2 - # X[1][0] -= fov / 2 - # X[1][1] += fov / 2 - # X[2][0] += fov / 2 - # X[2][1] -= fov / 2 - # X[3] += fov / 2 - - # Sx, Sy = raytrace(X[..., 0], X[..., 1]) - # S = torch.stack((Sx, Sy)).T - # res1, ap1 = func.triangle_search( - # torch.stack((bx, by)), - # X[:3], - # S[:3], - # raytrace, - # epsilon, - # torch.zeros((0, 2)), - # ) - # res2, ap2 = func.triangle_search( - # torch.stack((bx, by)), - # X[1:], - # S[1:], - # raytrace, - # epsilon, - # torch.zeros((0, 2)), - # ) - # res = torch.cat((res1, res2), dim=0) - # return res[:, 0], res[:, 1], torch.cat((ap1, ap2), dim=0) + return func.forward_raytrace( torch.stack((bx, by)), raytrace, x0, y0, fov, divisions, epsilon ) diff --git a/tests/test_nfw.py b/tests/test_nfw.py index cd6c2984..a2a63d9d 100644 --- a/tests/test_nfw.py +++ b/tests/test_nfw.py @@ -40,6 +40,7 @@ def test_nfw(sim_source, device, lens_models, m, c): z_l: {float(z_l)} init_kwargs: cosmology: *cosmology + use_case: differentiable """ yaml_dict = yaml.safe_load(yaml_str.encode("utf-8")) mod = lens_models.get("NFW") @@ -47,7 +48,7 @@ def test_nfw(sim_source, device, lens_models, m, c): else: # Models cosmology = CausticFlatLambdaCDM(name="cosmo") - lens = NFW(name="nfw", cosmology=cosmology, z_l=z_l) + lens = NFW(name="nfw", cosmology=cosmology, z_l=z_l, use_case="differentiable") lens_model_list = ["NFW"] lens_ls = LensModel(lens_model_list=lens_model_list) @@ 
-72,7 +73,17 @@ def test_nfw(sim_source, device, lens_models, m, c): {"Rs": Rs_angle, "alpha_Rs": alpha_Rs, "center_x": thx0, "center_y": thy0} ] - lens_test_helper(lens, lens_ls, z_s, x, kwargs_ls, atol, rtol, device=device) + lens_test_helper( + lens, + lens_ls, + z_s, + x, + kwargs_ls, + atol, + rtol, + shear_egregious=True, # not why match is so bad + device=device, + ) def test_runs(sim_source, device, lens_models): diff --git a/tests/test_tnfw.py b/tests/test_tnfw.py index 46e1f6e1..2751c3d6 100644 --- a/tests/test_tnfw.py +++ b/tests/test_tnfw.py @@ -93,6 +93,8 @@ def test(sim_source, device, lens_models, m, c, t): test_alpha=True, test_Psi=False, test_kappa=True, + test_shear=True, + shear_egregious=True, # not sure why match is so bad device=device, ) diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py index b18b7605..b9da902f 100644 --- a/tests/utils/__init__.py +++ b/tests/utils/__init__.py @@ -212,7 +212,6 @@ def alpha_test_helper(lens, lens_ls, z_s, x, kwargs_ls, atol, rtol, device=None) thx, thy, thx_ls, thy_ls = setup_grids(device=device) alpha_x, alpha_y = lens.reduced_deflection_angle(thx, thy, z_s, x) alpha_x_ls, alpha_y_ls = lens_ls.alpha(thx_ls, thy_ls, kwargs_ls) - print(np.sum(np.abs(1 - alpha_x.cpu().numpy() / alpha_x_ls) > 1e-3)) assert np.allclose(alpha_x.cpu().numpy(), alpha_x_ls, rtol, atol) assert np.allclose(alpha_y.cpu().numpy(), alpha_y_ls, rtol, atol) @@ -234,6 +233,21 @@ def kappa_test_helper(lens, lens_ls, z_s, x, kwargs_ls, atol, rtol, device=None) assert np.allclose(kappa.cpu().numpy(), kappa_ls, rtol, atol) +def shear_test_helper( + lens, lens_ls, z_s, x, kwargs_ls, atol, rtol, just_egregious=False, device=None +): + thx, thy, thx_ls, thy_ls = setup_grids(device=device) + gamma1, gamma2 = lens.shear(thx, thy, z_s, x) + gamma1_ls, gamma2_ls = lens_ls.gamma(thx_ls, thy_ls, kwargs_ls) + if just_egregious: + print(np.sum(np.abs(np.log10(np.abs(1 - gamma1.cpu().numpy() / gamma1_ls))) < 1)) + assert np.sum(np.abs(np.log10(np.abs(1 - gamma1.cpu().numpy() / gamma1_ls))) < 1) < 1000 + assert np.sum(np.abs(np.log10(np.abs(1 - gamma2.cpu().numpy() / gamma2_ls))) < 1) < 1000 + else: + assert np.allclose(gamma1.cpu().numpy(), gamma1_ls, rtol, atol) + assert np.allclose(gamma2.cpu().numpy(), gamma2_ls, rtol, atol) + + def lens_test_helper( lens: Union[ThinLens, ThickLens], lens_ls: LensModel, @@ -245,6 +259,8 @@ def lens_test_helper( test_alpha=True, test_Psi=True, test_kappa=True, + test_shear=True, + shear_egregious=False, device=None, ): if device is not None: @@ -260,3 +276,16 @@ def lens_test_helper( if test_kappa: kappa_test_helper(lens, lens_ls, z_s, x, kwargs_ls, atol, rtol, device=device) + + if test_shear: + shear_test_helper( + lens, + lens_ls, + z_s, + x, + kwargs_ls, + atol, + rtol * 10, + just_egregious=shear_egregious, + device=device, + ) # shear seems less precise than other measurements From 55b7fba7615313814d8386f8fb17d4f553a43d41 Mon Sep 17 00:00:00 2001 From: "Connor Stone, PhD" Date: Thu, 7 Nov 2024 06:31:47 -0800 Subject: [PATCH 10/10] bug fix: conv2d in LensSource now properly reorients the kernel (#281) * add consistency test between fft and conv2d for LensSource * fix conv2d for LensSource * more difficult unit test --- src/caustics/sims/lens_source.py | 4 ++- tests/test_simulator_runs.py | 55 ++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 1 deletion(-) diff --git a/src/caustics/sims/lens_source.py b/src/caustics/sims/lens_source.py index 008859b2..8235ee2a 100644 --- a/src/caustics/sims/lens_source.py +++ 
b/src/caustics/sims/lens_source.py @@ -371,7 +371,9 @@ def forward( elif self.psf_mode == "conv2d": mu = ( conv2d( - mu[None, None], (psf.T / psf.sum())[None, None], padding="same" + mu[None, None], + (torch.flip(psf, (0, 1)) / psf.sum())[None, None], + padding="same", ) .squeeze(0) .squeeze(0) diff --git a/tests/test_simulator_runs.py b/tests/test_simulator_runs.py index 23816889..faf470e3 100644 --- a/tests/test_simulator_runs.py +++ b/tests/test_simulator_runs.py @@ -193,6 +193,61 @@ def test_simulator_runs(sim_source, device, mocker): assert torch.allclose(sim(), sim_q3(), rtol=1e-1) +def test_fft_vs_conv2d(): + # Model + cosmology = FlatLambdaCDM(name="cosmo") + lensmass = SIE( + name="lens", + cosmology=cosmology, + z_l=1.0, + x0=0.0, + y0=0.01, + q=0.5, + phi=pi / 3.0, + b=1.0, + ) + + source = Sersic( + name="source", x0=0.01, y0=-0.03, q=0.6, phi=-pi / 4, n=2.0, Re=0.5, Ie=1.0 + ) + lenslight = Sersic( + name="lenslight", x0=0.0, y0=0.01, q=0.7, phi=pi / 4, n=3.0, Re=0.7, Ie=1.0 + ) + + psf = gaussian(0.05, 11, 11, 0.2, upsample=2) + psf[3, 4] = 0.1 # make PSF asymmetric + psf /= psf.sum() + + sim_fft = LensSource( + name="simulatorfft", + lens=lensmass, + source=source, + pixelscale=0.05, + pixels_x=50, + lens_light=lenslight, + psf=psf, + psf_mode="fft", + z_s=2.0, + quad_level=3, + ) + + sim_conv2d = LensSource( + name="simulatorconv2d", + lens=lensmass, + source=source, + pixelscale=0.05, + pixels_x=50, + lens_light=lenslight, + psf=psf, + psf_mode="conv2d", + z_s=2.0, + quad_level=3, + ) + + print(torch.max(torch.abs((sim_fft() - sim_conv2d()) / sim_fft()))) + assert torch.allclose(sim_fft(), sim_conv2d(), rtol=1e-1) + + def test_microlens_simulator_runs(): cosmology = FlatLambdaCDM() sie = SIE(cosmology=cosmology, name="lens")
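
A note on the padding semantics described in PATCH 05/10: zero padding treats the convergence map as an isolated overdensity, while the "circular"/"tile" options implicitly embed it in a periodic field of similar mass distributions. Numerically this is the difference between linear and circular convolution. The 1-D toy below is only an illustration (the arrays are arbitrary, not caustics code): an FFT taken at the original size wraps mass around the boundary, while zero padding to twice the length does not.

import torch

kappa = torch.tensor([0.0, 1.0, 2.0, 1.0])     # toy 1-D "convergence" slice
kernel = torch.tensor([0.25, 0.5, 0.25, 0.0])  # toy kernel of the same length

# Circular convolution: FFT at the original size, so the map wraps around the edges,
# i.e. it is implicitly surrounded by periodic copies of itself.
circ = torch.fft.irfft(torch.fft.rfft(kappa) * torch.fft.rfft(kernel), n=4)

# Linear convolution: zero-pad to 2N before the FFT, so nothing wraps,
# i.e. the map is treated as an isolated overdensity in empty surroundings.
lin = torch.fft.irfft(torch.fft.rfft(kappa, 8) * torch.fft.rfft(kernel, 8), n=8)[:4]

print(circ)  # ~[1.00, 0.50, 1.00, 1.50]
print(lin)   # ~[0.00, 0.25, 1.00, 1.50] -- differs from circ only where mass wrapped around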
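
The kernel windowing that PATCH 02/10 through PATCH 07/10 converge on addresses a separate artifact: the x/r^2 and y/r^2 deflection kernels fall off slowly, so truncating them at the edge of the padded grid leaves a sharp step that rings after the FFT. Multiplying by a radial taper that goes linearly to zero over the outer fraction `window` of the grid removes that step. The sketch below uses plain torch in place of caustics' meshgrid/safe_divide helpers, with illustrative values for n_pix, pixelscale and the 1/8 window fraction:

import torch

n_pix, pixelscale, window = 64, 0.05, 1 / 8

# x / r^2 deflection kernel on a (2 * n_pix)^2 grid, with the r = 0 pixel set to zero
coords = (torch.arange(2 * n_pix) - n_pix) * pixelscale
x, y = torch.meshgrid(coords, coords, indexing="xy")
d2 = x**2 + y**2
ax_kernel = torch.where(d2 > 0, x / d2, torch.zeros_like(x))

# Radial linear taper, as in build_window_pixelated_convergence: equal to 1 inside
# r = 1 - window and falling linearly to 0 by r = 1 (r in units of the grid half-width)
u = torch.linspace(-1, 1, 2 * n_pix)
ux, uy = torch.meshgrid(u, u, indexing="xy")
taper = torch.clip((1 - torch.hypot(ux, uy)) / window, 0, 1)

ax_kernel = ax_kernel * taper  # kernel now reaches zero smoothly before the FFT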
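
PATCH 09/10 obtains the shear of any lens from the Jacobian of the lens equation, A = d(beta)/d(theta). With the usual decomposition A = [[1 - kappa - gamma1, -gamma2], [-gamma2, 1 - kappa + gamma1]], one reads off kappa = 1 - (A[0,0] + A[1,1]) / 2, gamma1 = (A[1,1] - A[0,0]) / 2 and gamma2 = -A[0,1], which is the convention the patch's `negPsi` expression encodes (sign conventions vary between references). A standalone check with a point-mass deflection field, for which the shear magnitude is analytically 1/|theta|^2 in units of the Einstein radius:

import torch

def raytrace(theta):
    # point-mass lens with Einstein radius 1: alpha = theta / |theta|^2
    alpha = theta / theta.square().sum()
    return theta - alpha  # lens equation: beta = theta - alpha

theta = torch.tensor([0.7, -1.1])
A = torch.autograd.functional.jacobian(raytrace, theta)  # A[i, j] = d beta_i / d theta_j

kappa = 1 - 0.5 * (A[0, 0] + A[1, 1])
gamma1 = 0.5 * (A[1, 1] - A[0, 0])
gamma2 = -A[0, 1]

r2 = theta.square().sum()
print(kappa)                                # ~0: a point mass has no convergence off-centre
print(torch.hypot(gamma1, gamma2), 1 / r2)  # both ~0.588: |gamma| = 1 / |theta|^2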
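
Finally, the LensSource fix in PATCH 10/10 rests on the fact that torch.nn.functional.conv2d computes a cross-correlation, so reproducing a true convolution (and hence the "fft" PSF mode) requires flipping the PSF along both axes; the previous `psf.T` only transposed it, which is not equivalent for an asymmetric PSF. A self-contained consistency check, independent of caustics (image size and the deliberately asymmetric PSF are arbitrary):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
img = torch.rand(32, 32)
psf = torch.rand(5, 5)
psf[1, 3] += 1.0       # make the PSF asymmetric so orientation actually matters
psf = psf / psf.sum()

# conv2d is a cross-correlation, so flip the kernel to get a convolution
conv = F.conv2d(img[None, None], torch.flip(psf, (0, 1))[None, None], padding="same")[0, 0]

# Reference: true convolution via zero-padded FFTs, cropped back to the "same" size
n = img.shape[0] + psf.shape[0] - 1
full = torch.fft.irfft2(torch.fft.rfft2(img, (n, n)) * torch.fft.rfft2(psf, (n, n)), (n, n))
lo = psf.shape[0] // 2
fft_conv = full[lo:lo + img.shape[0], lo:lo + img.shape[1]]

print(torch.allclose(conv, fft_conv, atol=1e-5))  # expected: True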