diff --git a/src/caustics/lenses/utils.py b/src/caustics/lenses/utils.py index fab3e8aa..c33e84b8 100644 --- a/src/caustics/lenses/utils.py +++ b/src/caustics/lenses/utils.py @@ -3,7 +3,6 @@ import torch from torch import Tensor -from ..utils import vmap_n __all__ = ("get_pix_jacobian", "get_pix_magnification", "get_magnification") @@ -119,4 +118,9 @@ def get_magnification(raytrace, x, y, z_s) -> Tensor: *Unit: unitless* """ - return vmap_n(get_pix_magnification, 2, (None, 0, 0, None))(raytrace, x, y, z_s) + return torch.reshape( + torch.func.vmap(get_pix_magnification, in_dims=(None, 0, 0, None))( + raytrace, x.reshape(-1), y.reshape(-1), z_s + ), + x.shape, + ) diff --git a/tests/test_base.py b/tests/test_base.py index 02e8c776..94ad97d8 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -39,6 +39,45 @@ def test(device): assert torch.all((sp_y - by).abs() < 1e-3) +def test_magnification(device): + z_l = torch.tensor(0.5, dtype=torch.float32, device=device) + z_s = torch.tensor(1.5, dtype=torch.float32, device=device) + + # Model + cosmology = FlatLambdaCDM(name="cosmo") + lens = SIE( + name="sie", + cosmology=cosmology, + z_l=z_l, + x0=torch.tensor(0.0), + y0=torch.tensor(0.0), + q=torch.tensor(0.4), + phi=torch.tensor(np.pi / 5), + b=torch.tensor(1.0), + ) + # Send to device + lens = lens.to(device) + + # Point in image plane + x = torch.tensor(0.1, device=device) + y = torch.tensor(0.1, device=device) + + mag = lens.magnification(x, y, z_s) + + assert np.isfinite(mag.item()) + assert mag.item() > 0 + + # grid in image plane + x = torch.linspace(-0.1, 0.1, 10, device=device) + y = torch.linspace(-0.1, 0.1, 10, device=device) + x, y = torch.meshgrid(x, y, indexing="ij") + + mag = lens.magnification(x, y, z_s) + + assert np.all(np.isfinite(mag.detach().cpu().numpy())) + assert np.all(mag.detach().cpu().numpy() > 0) + + def test_quicktest(device): """ Quick test to check that the built-in `test` module is working