Skip to content

Commit

Permalink
Merge pull request #282 from Ciela-Institute/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
ConnorStoneAstro authored Nov 7, 2024
2 parents c776167 + 3e6c45f commit 5e2fbb2
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 32 deletions.
54 changes: 26 additions & 28 deletions src/caustics/lenses/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,31 @@ def jacobian_lens_equation(
else:
raise ValueError("method should be one of: autograd, finitediff")

@unpack
def shear(
self,
x: Tensor,
y: Tensor,
z_s: Tensor,
*args,
params: Optional["Packed"] = None,
method="autograd",
pixelscale: Optional[Tensor] = None,
**kwargs,
):
"""
General shear calculation for a lens model using the jacobian of the
lens equation. Individual lenses may implement more efficient methods.
"""
A = self.jacobian_lens_equation(
x, y, z_s, params=params, method=method, pixelscale=pixelscale
)
I = torch.eye(2, device=A.device, dtype=A.dtype).reshape( # noqa E741
*[1] * len(A.shape[:-2]), 2, 2
)
negPsi = 0.5 * (A[..., 0, 0] + A[..., 1, 1]).unsqueeze(-1).unsqueeze(-1) * I - A
return 0.5 * (negPsi[..., 0, 0] - negPsi[..., 1, 1]), negPsi[..., 0, 1]

@unpack
def magnification(
self,
Expand Down Expand Up @@ -193,34 +218,7 @@ def forward_raytrace(
x0 = torch.zeros((), device=bx.device, dtype=bx.dtype)
if y0 is None:
y0 = torch.zeros((), device=by.device, dtype=by.dtype)
# X = torch.stack((x0, y0)).repeat(4, 1)
# X[0] -= fov / 2
# X[1][0] -= fov / 2
# X[1][1] += fov / 2
# X[2][0] += fov / 2
# X[2][1] -= fov / 2
# X[3] += fov / 2

# Sx, Sy = raytrace(X[..., 0], X[..., 1])
# S = torch.stack((Sx, Sy)).T
# res1, ap1 = func.triangle_search(
# torch.stack((bx, by)),
# X[:3],
# S[:3],
# raytrace,
# epsilon,
# torch.zeros((0, 2)),
# )
# res2, ap2 = func.triangle_search(
# torch.stack((bx, by)),
# X[1:],
# S[1:],
# raytrace,
# epsilon,
# torch.zeros((0, 2)),
# )
# res = torch.cat((res1, res2), dim=0)
# return res[:, 0], res[:, 1], torch.cat((ap1, ap2), dim=0)

return func.forward_raytrace(
torch.stack((bx, by)), raytrace, x0, y0, fov, divisions, epsilon
)
Expand Down
4 changes: 3 additions & 1 deletion src/caustics/sims/lens_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,9 @@ def forward(
elif self.psf_mode == "conv2d":
mu = (
conv2d(
mu[None, None], (psf.T / psf.sum())[None, None], padding="same"
mu[None, None],
(torch.flip(psf, (0, 1)) / psf.sum())[None, None],
padding="same",
)
.squeeze(0)
.squeeze(0)
Expand Down
15 changes: 13 additions & 2 deletions tests/test_nfw.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,15 @@ def test_nfw(sim_source, device, lens_models, m, c):
z_l: {float(z_l)}
init_kwargs:
cosmology: *cosmology
use_case: differentiable
"""
yaml_dict = yaml.safe_load(yaml_str.encode("utf-8"))
mod = lens_models.get("NFW")
lens = mod(**yaml_dict["lens"]).model_obj()
else:
# Models
cosmology = CausticFlatLambdaCDM(name="cosmo")
lens = NFW(name="nfw", cosmology=cosmology, z_l=z_l)
lens = NFW(name="nfw", cosmology=cosmology, z_l=z_l, use_case="differentiable")
lens_model_list = ["NFW"]
lens_ls = LensModel(lens_model_list=lens_model_list)

Expand All @@ -72,7 +73,17 @@ def test_nfw(sim_source, device, lens_models, m, c):
{"Rs": Rs_angle, "alpha_Rs": alpha_Rs, "center_x": thx0, "center_y": thy0}
]

lens_test_helper(lens, lens_ls, z_s, x, kwargs_ls, atol, rtol, device=device)
lens_test_helper(
lens,
lens_ls,
z_s,
x,
kwargs_ls,
atol,
rtol,
shear_egregious=True, # not why match is so bad
device=device,
)


def test_runs(sim_source, device, lens_models):
Expand Down
55 changes: 55 additions & 0 deletions tests/test_simulator_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,61 @@ def test_simulator_runs(sim_source, device, mocker):
assert torch.allclose(sim(), sim_q3(), rtol=1e-1)


def test_fft_vs_conv2d():
# Model
cosmology = FlatLambdaCDM(name="cosmo")
lensmass = SIE(
name="lens",
cosmology=cosmology,
z_l=1.0,
x0=0.0,
y0=0.01,
q=0.5,
phi=pi / 3.0,
b=1.0,
)

source = Sersic(
name="source", x0=0.01, y0=-0.03, q=0.6, phi=-pi / 4, n=2.0, Re=0.5, Ie=1.0
)
lenslight = Sersic(
name="lenslight", x0=0.0, y0=0.01, q=0.7, phi=pi / 4, n=3.0, Re=0.7, Ie=1.0
)

psf = gaussian(0.05, 11, 11, 0.2, upsample=2)
psf[3, 4] = 0.1 # make PSF asymmetric
psf /= psf.sum()

sim_fft = LensSource(
name="simulatorfft",
lens=lensmass,
source=source,
pixelscale=0.05,
pixels_x=50,
lens_light=lenslight,
psf=psf,
psf_mode="fft",
z_s=2.0,
quad_level=3,
)

sim_conv2d = LensSource(
name="simulatorconv2d",
lens=lensmass,
source=source,
pixelscale=0.05,
pixels_x=50,
lens_light=lenslight,
psf=psf,
psf_mode="conv2d",
z_s=2.0,
quad_level=3,
)

print(torch.max(torch.abs((sim_fft() - sim_conv2d()) / sim_fft())))
assert torch.allclose(sim_fft(), sim_conv2d(), rtol=1e-1)


def test_microlens_simulator_runs():
cosmology = FlatLambdaCDM()
sie = SIE(cosmology=cosmology, name="lens")
Expand Down
2 changes: 2 additions & 0 deletions tests/test_tnfw.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ def test(sim_source, device, lens_models, m, c, t):
test_alpha=True,
test_Psi=False,
test_kappa=True,
test_shear=True,
shear_egregious=True, # not sure why match is so bad
device=device,
)

Expand Down
31 changes: 30 additions & 1 deletion tests/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,6 @@ def alpha_test_helper(lens, lens_ls, z_s, x, kwargs_ls, atol, rtol, device=None)
thx, thy, thx_ls, thy_ls = setup_grids(device=device)
alpha_x, alpha_y = lens.reduced_deflection_angle(thx, thy, z_s, x)
alpha_x_ls, alpha_y_ls = lens_ls.alpha(thx_ls, thy_ls, kwargs_ls)
print(np.sum(np.abs(1 - alpha_x.cpu().numpy() / alpha_x_ls) > 1e-3))
assert np.allclose(alpha_x.cpu().numpy(), alpha_x_ls, rtol, atol)
assert np.allclose(alpha_y.cpu().numpy(), alpha_y_ls, rtol, atol)

Expand All @@ -234,6 +233,21 @@ def kappa_test_helper(lens, lens_ls, z_s, x, kwargs_ls, atol, rtol, device=None)
assert np.allclose(kappa.cpu().numpy(), kappa_ls, rtol, atol)


def shear_test_helper(
lens, lens_ls, z_s, x, kwargs_ls, atol, rtol, just_egregious=False, device=None
):
thx, thy, thx_ls, thy_ls = setup_grids(device=device)
gamma1, gamma2 = lens.shear(thx, thy, z_s, x)
gamma1_ls, gamma2_ls = lens_ls.gamma(thx_ls, thy_ls, kwargs_ls)
if just_egregious:
print(np.sum(np.abs(np.log10(np.abs(1 - gamma1.cpu().numpy() / gamma1_ls))) < 1))
assert np.sum(np.abs(np.log10(np.abs(1 - gamma1.cpu().numpy() / gamma1_ls))) < 1) < 1000
assert np.sum(np.abs(np.log10(np.abs(1 - gamma2.cpu().numpy() / gamma2_ls))) < 1) < 1000
else:
assert np.allclose(gamma1.cpu().numpy(), gamma1_ls, rtol, atol)
assert np.allclose(gamma2.cpu().numpy(), gamma2_ls, rtol, atol)


def lens_test_helper(
lens: Union[ThinLens, ThickLens],
lens_ls: LensModel,
Expand All @@ -245,6 +259,8 @@ def lens_test_helper(
test_alpha=True,
test_Psi=True,
test_kappa=True,
test_shear=True,
shear_egregious=False,
device=None,
):
if device is not None:
Expand All @@ -260,3 +276,16 @@ def lens_test_helper(

if test_kappa:
kappa_test_helper(lens, lens_ls, z_s, x, kwargs_ls, atol, rtol, device=device)

if test_shear:
shear_test_helper(
lens,
lens_ls,
z_s,
x,
kwargs_ls,
atol,
rtol * 10,
just_egregious=shear_egregious,
device=device,
) # shear seems less precise than other measurements

0 comments on commit 5e2fbb2

Please sign in to comment.