diff --git a/src/caustics/lenses/base.py b/src/caustics/lenses/base.py index ee819f2b..3eb9ebe4 100644 --- a/src/caustics/lenses/base.py +++ b/src/caustics/lenses/base.py @@ -79,6 +79,31 @@ def jacobian_lens_equation( else: raise ValueError("method should be one of: autograd, finitediff") + @unpack + def shear( + self, + x: Tensor, + y: Tensor, + z_s: Tensor, + *args, + params: Optional["Packed"] = None, + method="autograd", + pixelscale: Optional[Tensor] = None, + **kwargs, + ): + """ + General shear calculation for a lens model using the jacobian of the + lens equation. Individual lenses may implement more efficient methods. + """ + A = self.jacobian_lens_equation( + x, y, z_s, params=params, method=method, pixelscale=pixelscale + ) + I = torch.eye(2, device=A.device, dtype=A.dtype).reshape( # noqa E741 + *[1] * len(A.shape[:-2]), 2, 2 + ) + negPsi = 0.5 * (A[..., 0, 0] + A[..., 1, 1]).unsqueeze(-1).unsqueeze(-1) * I - A + return 0.5 * (negPsi[..., 0, 0] - negPsi[..., 1, 1]), negPsi[..., 0, 1] + @unpack def magnification( self, @@ -193,34 +218,7 @@ def forward_raytrace( x0 = torch.zeros((), device=bx.device, dtype=bx.dtype) if y0 is None: y0 = torch.zeros((), device=by.device, dtype=by.dtype) - # X = torch.stack((x0, y0)).repeat(4, 1) - # X[0] -= fov / 2 - # X[1][0] -= fov / 2 - # X[1][1] += fov / 2 - # X[2][0] += fov / 2 - # X[2][1] -= fov / 2 - # X[3] += fov / 2 - - # Sx, Sy = raytrace(X[..., 0], X[..., 1]) - # S = torch.stack((Sx, Sy)).T - # res1, ap1 = func.triangle_search( - # torch.stack((bx, by)), - # X[:3], - # S[:3], - # raytrace, - # epsilon, - # torch.zeros((0, 2)), - # ) - # res2, ap2 = func.triangle_search( - # torch.stack((bx, by)), - # X[1:], - # S[1:], - # raytrace, - # epsilon, - # torch.zeros((0, 2)), - # ) - # res = torch.cat((res1, res2), dim=0) - # return res[:, 0], res[:, 1], torch.cat((ap1, ap2), dim=0) + return func.forward_raytrace( torch.stack((bx, by)), raytrace, x0, y0, fov, divisions, epsilon ) diff --git a/src/caustics/sims/lens_source.py b/src/caustics/sims/lens_source.py index 008859b2..8235ee2a 100644 --- a/src/caustics/sims/lens_source.py +++ b/src/caustics/sims/lens_source.py @@ -371,7 +371,9 @@ def forward( elif self.psf_mode == "conv2d": mu = ( conv2d( - mu[None, None], (psf.T / psf.sum())[None, None], padding="same" + mu[None, None], + (torch.flip(psf, (0, 1)) / psf.sum())[None, None], + padding="same", ) .squeeze(0) .squeeze(0) diff --git a/tests/test_nfw.py b/tests/test_nfw.py index cd6c2984..a2a63d9d 100644 --- a/tests/test_nfw.py +++ b/tests/test_nfw.py @@ -40,6 +40,7 @@ def test_nfw(sim_source, device, lens_models, m, c): z_l: {float(z_l)} init_kwargs: cosmology: *cosmology + use_case: differentiable """ yaml_dict = yaml.safe_load(yaml_str.encode("utf-8")) mod = lens_models.get("NFW") @@ -47,7 +48,7 @@ def test_nfw(sim_source, device, lens_models, m, c): else: # Models cosmology = CausticFlatLambdaCDM(name="cosmo") - lens = NFW(name="nfw", cosmology=cosmology, z_l=z_l) + lens = NFW(name="nfw", cosmology=cosmology, z_l=z_l, use_case="differentiable") lens_model_list = ["NFW"] lens_ls = LensModel(lens_model_list=lens_model_list) @@ -72,7 +73,17 @@ def test_nfw(sim_source, device, lens_models, m, c): {"Rs": Rs_angle, "alpha_Rs": alpha_Rs, "center_x": thx0, "center_y": thy0} ] - lens_test_helper(lens, lens_ls, z_s, x, kwargs_ls, atol, rtol, device=device) + lens_test_helper( + lens, + lens_ls, + z_s, + x, + kwargs_ls, + atol, + rtol, + shear_egregious=True, # not why match is so bad + device=device, + ) def test_runs(sim_source, device, lens_models): diff --git a/tests/test_simulator_runs.py b/tests/test_simulator_runs.py index 23816889..faf470e3 100644 --- a/tests/test_simulator_runs.py +++ b/tests/test_simulator_runs.py @@ -193,6 +193,61 @@ def test_simulator_runs(sim_source, device, mocker): assert torch.allclose(sim(), sim_q3(), rtol=1e-1) +def test_fft_vs_conv2d(): + # Model + cosmology = FlatLambdaCDM(name="cosmo") + lensmass = SIE( + name="lens", + cosmology=cosmology, + z_l=1.0, + x0=0.0, + y0=0.01, + q=0.5, + phi=pi / 3.0, + b=1.0, + ) + + source = Sersic( + name="source", x0=0.01, y0=-0.03, q=0.6, phi=-pi / 4, n=2.0, Re=0.5, Ie=1.0 + ) + lenslight = Sersic( + name="lenslight", x0=0.0, y0=0.01, q=0.7, phi=pi / 4, n=3.0, Re=0.7, Ie=1.0 + ) + + psf = gaussian(0.05, 11, 11, 0.2, upsample=2) + psf[3, 4] = 0.1 # make PSF asymmetric + psf /= psf.sum() + + sim_fft = LensSource( + name="simulatorfft", + lens=lensmass, + source=source, + pixelscale=0.05, + pixels_x=50, + lens_light=lenslight, + psf=psf, + psf_mode="fft", + z_s=2.0, + quad_level=3, + ) + + sim_conv2d = LensSource( + name="simulatorconv2d", + lens=lensmass, + source=source, + pixelscale=0.05, + pixels_x=50, + lens_light=lenslight, + psf=psf, + psf_mode="conv2d", + z_s=2.0, + quad_level=3, + ) + + print(torch.max(torch.abs((sim_fft() - sim_conv2d()) / sim_fft()))) + assert torch.allclose(sim_fft(), sim_conv2d(), rtol=1e-1) + + def test_microlens_simulator_runs(): cosmology = FlatLambdaCDM() sie = SIE(cosmology=cosmology, name="lens") diff --git a/tests/test_tnfw.py b/tests/test_tnfw.py index 46e1f6e1..2751c3d6 100644 --- a/tests/test_tnfw.py +++ b/tests/test_tnfw.py @@ -93,6 +93,8 @@ def test(sim_source, device, lens_models, m, c, t): test_alpha=True, test_Psi=False, test_kappa=True, + test_shear=True, + shear_egregious=True, # not sure why match is so bad device=device, ) diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py index b18b7605..b9da902f 100644 --- a/tests/utils/__init__.py +++ b/tests/utils/__init__.py @@ -212,7 +212,6 @@ def alpha_test_helper(lens, lens_ls, z_s, x, kwargs_ls, atol, rtol, device=None) thx, thy, thx_ls, thy_ls = setup_grids(device=device) alpha_x, alpha_y = lens.reduced_deflection_angle(thx, thy, z_s, x) alpha_x_ls, alpha_y_ls = lens_ls.alpha(thx_ls, thy_ls, kwargs_ls) - print(np.sum(np.abs(1 - alpha_x.cpu().numpy() / alpha_x_ls) > 1e-3)) assert np.allclose(alpha_x.cpu().numpy(), alpha_x_ls, rtol, atol) assert np.allclose(alpha_y.cpu().numpy(), alpha_y_ls, rtol, atol) @@ -234,6 +233,21 @@ def kappa_test_helper(lens, lens_ls, z_s, x, kwargs_ls, atol, rtol, device=None) assert np.allclose(kappa.cpu().numpy(), kappa_ls, rtol, atol) +def shear_test_helper( + lens, lens_ls, z_s, x, kwargs_ls, atol, rtol, just_egregious=False, device=None +): + thx, thy, thx_ls, thy_ls = setup_grids(device=device) + gamma1, gamma2 = lens.shear(thx, thy, z_s, x) + gamma1_ls, gamma2_ls = lens_ls.gamma(thx_ls, thy_ls, kwargs_ls) + if just_egregious: + print(np.sum(np.abs(np.log10(np.abs(1 - gamma1.cpu().numpy() / gamma1_ls))) < 1)) + assert np.sum(np.abs(np.log10(np.abs(1 - gamma1.cpu().numpy() / gamma1_ls))) < 1) < 1000 + assert np.sum(np.abs(np.log10(np.abs(1 - gamma2.cpu().numpy() / gamma2_ls))) < 1) < 1000 + else: + assert np.allclose(gamma1.cpu().numpy(), gamma1_ls, rtol, atol) + assert np.allclose(gamma2.cpu().numpy(), gamma2_ls, rtol, atol) + + def lens_test_helper( lens: Union[ThinLens, ThickLens], lens_ls: LensModel, @@ -245,6 +259,8 @@ def lens_test_helper( test_alpha=True, test_Psi=True, test_kappa=True, + test_shear=True, + shear_egregious=False, device=None, ): if device is not None: @@ -260,3 +276,16 @@ def lens_test_helper( if test_kappa: kappa_test_helper(lens, lens_ls, z_s, x, kwargs_ls, atol, rtol, device=device) + + if test_shear: + shear_test_helper( + lens, + lens_ls, + z_s, + x, + kwargs_ls, + atol, + rtol * 10, + just_egregious=shear_egregious, + device=device, + ) # shear seems less precise than other measurements