Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: better use of autograd for base lens quantities #306

Merged
merged 13 commits into from
Jan 25, 2025
Merged
18 changes: 6 additions & 12 deletions src/caustics/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,9 @@
physical_deflection_angle_nfw,
potential_nfw,
convergence_nfw,
_f_batchable_nfw,
_f_differentiable_nfw,
_g_batchable_nfw,
_g_differentiable_nfw,
_h_batchable_nfw,
_h_differentiable_nfw,
_f_nfw,
_g_nfw,
_h_nfw,
reduced_deflection_angle_pixelated_convergence,
potential_pixelated_convergence,
_fft2_padded,
Expand Down Expand Up @@ -79,12 +76,9 @@
"physical_deflection_angle_nfw",
"potential_nfw",
"convergence_nfw",
"_f_batchable_nfw",
"_f_differentiable_nfw",
"_g_batchable_nfw",
"_g_differentiable_nfw",
"_h_batchable_nfw",
"_h_differentiable_nfw",
"_f_nfw",
"_g_nfw",
"_h_nfw",
"reduced_deflection_angle_pixelated_convergence",
"potential_pixelated_convergence",
"_fft2_padded",
Expand Down
156 changes: 31 additions & 125 deletions src/caustics/lenses/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@
self,
x: Tensor,
y: Tensor,
z_s: Annotated[Tensor, "Param"],
) -> Tensor:
"""
Compute the gravitational magnification at the given coordinates.
Expand Down Expand Up @@ -162,14 +161,6 @@

*Unit: arcsec*

z_s: Tensor
Tensor of source redshifts.

*Unit: unitless*

params: Packed, optional
Dynamic parameter container for the lens model. Defaults to None.

epsilon: Tensor
maximum distance between two images (arcsec) before they are
considered the same image.
Expand Down Expand Up @@ -289,14 +280,6 @@

*Unit: arcsec*

z_s: Tensor
Tensor of source redshifts.

*Unit: unitless*

params: Packed, optional
Dynamic parameter container for the lens model. Defaults to None.

"""
bx, by = self.raytrace(x, y, **kwargs)
return x - bx, y - by
Expand Down Expand Up @@ -325,14 +308,6 @@

*Unit: arcsec*

z_s: Tensor
Tensor of source redshifts.

*Unit: unitless*

params: Packed, optional
Dynamic parameter container for the lens model. Defaults to None.

Returns
-------
x_component: Tensor
Expand Down Expand Up @@ -492,40 +467,16 @@
Return the jacobian of the effective reduced deflection angle vector field.
This equates to a (2,2) matrix at each (x,y) point.
"""

# Build Jacobian
J = torch.zeros((*x.shape, 2, 2), device=x.device, dtype=x.dtype)

# Compute deflection angle gradients
dax_dx = torch.func.grad(
lambda *a: self.effective_reduced_deflection_angle(*a)[0], argnums=0
)
J[..., 0, 0] = torch.vmap(dax_dx, chunk_size=chunk_size)(
x.flatten(), y.flatten()
).reshape(x.shape)

dax_dy = torch.func.grad(
lambda *a: self.effective_reduced_deflection_angle(*a)[0], argnums=1
)
J[..., 0, 1] = torch.vmap(dax_dy, chunk_size=chunk_size)(
x.flatten(), y.flatten()
).reshape(x.shape)

day_dx = torch.func.grad(
lambda *a: self.effective_reduced_deflection_angle(*a)[1], argnums=0
)
J[..., 1, 0] = torch.vmap(day_dx, chunk_size=chunk_size)(
x.flatten(), y.flatten()
).reshape(x.shape)

day_dy = torch.func.grad(
lambda *a: self.effective_reduced_deflection_angle(*a)[1], argnums=1
)
J[..., 1, 1] = torch.vmap(day_dy, chunk_size=chunk_size)(
x.flatten(), y.flatten()
).reshape(x.shape)

return J.detach()
J = torch.vmap(
torch.func.jacfwd(
self.effective_reduced_deflection_angle,
argnums=(0, 1),
randomness="different",
),
chunk_size=chunk_size,
)(x.flatten(), y.flatten())
J = torch.stack([torch.stack(Jrow, dim=-1) for Jrow in J], dim=-2)
return J.reshape(*x.shape, 2, 2)

@forward
def jacobian_effective_deflection_angle(
Expand Down Expand Up @@ -670,8 +621,6 @@
self,
x: Tensor,
y: Tensor,
z_s: Annotated[Tensor, "Param"],
z_l: Annotated[Tensor, "Param"],
) -> tuple[Tensor, Tensor]:
"""
Computes the reduced deflection angle of the lens at given coordinates [arcsec].
Expand Down Expand Up @@ -701,12 +650,11 @@
*Unit: arcsec*

"""
d_s = self.cosmology.angular_diameter_distance(z_s)
d_ls = self.cosmology.angular_diameter_distance_z1z2(z_l, z_s)
deflection_angle_x, deflection_angle_y = self.physical_deflection_angle(x, y)
return func.reduced_from_physical_deflection_angle(
deflection_angle_x, deflection_angle_y, d_s, d_ls
)
ax, ay = torch.vmap(
torch.func.grad(self.potential, (0, 1)),
chunk_size=10000,
)(x.flatten(), y.flatten())
return ax.reshape(x.shape), ay.reshape(y.shape)

@forward
def physical_deflection_angle(
Expand Down Expand Up @@ -775,14 +723,6 @@

*Unit: arcsec*

z_s: Tensor
Tensor of source redshifts.

*Unit: unitless*

params: (Packed, optional)
Dynamic parameter container for the lens model. Defaults to None.

Returns
-------
Tensor
Expand All @@ -791,7 +731,13 @@
*Unit: unitless*

"""
...
Psi_H = torch.vmap(
torch.func.hessian(self.potential, (0, 1)),
chunk_size=10000,
)(x.flatten(), y.flatten())
Psi_H = torch.stack([torch.stack(Hrow, dim=-1) for Hrow in Psi_H], dim=-2)
Psi_H = Psi_H.reshape(*x.shape, 2, 2)
return 0.5 * (Psi_H[..., 0, 0] + Psi_H[..., 1, 1]).reshape(x.shape)

@abstractmethod
@forward
Expand All @@ -817,14 +763,6 @@

*Unit: arcsec*

z_s: Tensor
Tensor of source redshifts.

*Unit: unitless*

params: (Packed, optional)
Dynamic parameter container for the lens model. Defaults to None.

Returns
-------
Tensor
Expand Down Expand Up @@ -858,14 +796,6 @@

*Unit: arcsec*

z_s: Tensor
Tensor of source redshifts.

*Unit: unitless*

params: (Packed, optional)
Dynamic parameter container for the lens model. Defaults to None.

Returns
-------
Tensor
Expand All @@ -875,7 +805,7 @@

"""
critical_surface_density = self.cosmology.critical_surface_density(z_l, z_s)
return self.convergence(x, y, z_s) * critical_surface_density # fmt: skip
return self.convergence(x, y) * critical_surface_density # fmt: skip

Check warning on line 808 in src/caustics/lenses/base.py

View check run for this annotation

Codecov / codecov/patch

src/caustics/lenses/base.py#L808

Added line #L808 was not covered by tests

@forward
def raytrace(
Expand Down Expand Up @@ -1028,39 +958,15 @@
Return the jacobian of the deflection angle vector.
This equates to a (2,2) matrix at each (x,y) point.
"""
# Build Jacobian
J = torch.zeros((*x.shape, 2, 2), device=x.device, dtype=x.dtype)

# Compute deflection angle gradients
dax_dx = torch.func.grad(
lambda *a: self.reduced_deflection_angle(*a)[0], argnums=0
)
J[..., 0, 0] = torch.vmap(dax_dx, chunk_size=chunk_size)(
x.flatten(), y.flatten()
).reshape(x.shape)

dax_dy = torch.func.grad(
lambda *a: self.reduced_deflection_angle(*a)[0], argnums=1
)
J[..., 0, 1] = torch.vmap(dax_dy, chunk_size=chunk_size)(
x.flatten(), y.flatten()
).reshape(x.shape)

day_dx = torch.func.grad(
lambda *a: self.reduced_deflection_angle(*a)[1], argnums=0
)
J[..., 1, 0] = torch.vmap(day_dx, chunk_size=chunk_size)(
x.flatten(), y.flatten()
).reshape(x.shape)

day_dy = torch.func.grad(
lambda *a: self.reduced_deflection_angle(*a)[1], argnums=1
)
J[..., 1, 1] = torch.vmap(day_dy, chunk_size=chunk_size)(
x.flatten(), y.flatten()
).reshape(x.shape)

return J.detach()
J = torch.vmap(
torch.func.jacfwd(
self.reduced_deflection_angle, argnums=(0, 1), randomness="different"
),
chunk_size=chunk_size,
)(x.flatten(), y.flatten())
J = torch.stack([torch.stack(Jrow, dim=-1) for Jrow in J], dim=-2)
return J.reshape(*x.shape, 2, 2)

@forward
def jacobian_deflection_angle(
Expand Down
15 changes: 14 additions & 1 deletion src/caustics/lenses/enclosed_mass.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
from caskade import forward, Param

from .base import ThinLens, CosmologyType, NameType, ZType
from .func import physical_deflection_angle_enclosed_mass, convergence_enclosed_mass
from .func import (
physical_deflection_angle_enclosed_mass,
convergence_enclosed_mass,
reduced_from_physical_deflection_angle,
)

__all__ = ("EnclosedMass",)

Expand Down Expand Up @@ -158,6 +162,15 @@ def physical_deflection_angle(
x0, y0, q, phi, lambda r: self.enclosed_mass(r, p), x, y, self.s
)

@forward
def reduced_deflection_angle(self, x, y, z_s, z_l):
d_s = self.cosmology.angular_diameter_distance(z_s)
d_ls = self.cosmology.angular_diameter_distance_z1z2(z_l, z_s)
deflection_angle_x, deflection_angle_y = self.physical_deflection_angle(x, y)
return reduced_from_physical_deflection_angle(
deflection_angle_x, deflection_angle_y, d_s, d_ls
)

@forward
def potential(
self,
Expand Down
18 changes: 6 additions & 12 deletions src/caustics/lenses/func/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,9 @@
convergence_nfw,
scale_radius_nfw,
scale_density_nfw,
_f_batchable_nfw,
_f_differentiable_nfw,
_g_batchable_nfw,
_g_differentiable_nfw,
_h_batchable_nfw,
_h_differentiable_nfw,
_f_nfw,
_g_nfw,
_h_nfw,
)
from .pixelated_convergence import (
reduced_deflection_angle_pixelated_convergence,
Expand Down Expand Up @@ -112,12 +109,9 @@
"convergence_nfw",
"scale_radius_nfw",
"scale_density_nfw",
"_f_batchable_nfw",
"_f_differentiable_nfw",
"_g_batchable_nfw",
"_g_differentiable_nfw",
"_h_batchable_nfw",
"_h_differentiable_nfw",
"_f_nfw",
"_g_nfw",
"_h_nfw",
"reduced_deflection_angle_pixelated_convergence",
"potential_pixelated_convergence",
"_fft2_padded",
Expand Down
Loading
Loading