Skip to content

Commit

Permalink
refactor: Fix magnification input shape to allow any shape (#220)
Browse files Browse the repository at this point in the history
* fix magnification to take any shape

* test for magnification

* reshape magnification works for all cases now
  • Loading branch information
ConnorStoneAstro authored Jun 28, 2024
1 parent 73c5527 commit 2b97f57
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/caustics/lenses/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import torch
from torch import Tensor

from ..utils import vmap_n

__all__ = ("get_pix_jacobian", "get_pix_magnification", "get_magnification")

Expand Down Expand Up @@ -119,4 +118,9 @@ def get_magnification(raytrace, x, y, z_s) -> Tensor:
*Unit: unitless*
"""
return vmap_n(get_pix_magnification, 2, (None, 0, 0, None))(raytrace, x, y, z_s)
return torch.reshape(
torch.func.vmap(get_pix_magnification, in_dims=(None, 0, 0, None))(
raytrace, x.reshape(-1), y.reshape(-1), z_s
),
x.shape,
)
39 changes: 39 additions & 0 deletions tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,45 @@ def test(device):
assert torch.all((sp_y - by).abs() < 1e-3)


def test_magnification(device):
z_l = torch.tensor(0.5, dtype=torch.float32, device=device)
z_s = torch.tensor(1.5, dtype=torch.float32, device=device)

# Model
cosmology = FlatLambdaCDM(name="cosmo")
lens = SIE(
name="sie",
cosmology=cosmology,
z_l=z_l,
x0=torch.tensor(0.0),
y0=torch.tensor(0.0),
q=torch.tensor(0.4),
phi=torch.tensor(np.pi / 5),
b=torch.tensor(1.0),
)
# Send to device
lens = lens.to(device)

# Point in image plane
x = torch.tensor(0.1, device=device)
y = torch.tensor(0.1, device=device)

mag = lens.magnification(x, y, z_s)

assert np.isfinite(mag.item())
assert mag.item() > 0

# grid in image plane
x = torch.linspace(-0.1, 0.1, 10, device=device)
y = torch.linspace(-0.1, 0.1, 10, device=device)
x, y = torch.meshgrid(x, y, indexing="ij")

mag = lens.magnification(x, y, z_s)

assert np.all(np.isfinite(mag.detach().cpu().numpy()))
assert np.all(mag.detach().cpu().numpy() > 0)


def test_quicktest(device):
"""
Quick test to check that the built-in `test` module is working
Expand Down

0 comments on commit 2b97f57

Please sign in to comment.