Skip to content

Commit

Permalink
Jax version of MGE Sersic profile
Browse files Browse the repository at this point in the history
This adds the MGE code needed for the Sersic profile.
  • Loading branch information
CKrawczyk committed Sep 18, 2024
1 parent 044febb commit 5e30644
Show file tree
Hide file tree
Showing 6 changed files with 576 additions and 285 deletions.
8 changes: 4 additions & 4 deletions autogalaxy/profiles/geometry_profiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def __init__(self, centre: Tuple[float, float] = (0.0, 0.0)):
def radial_grid_from(self, grid: aa.type.Grid2DLike, **kwargs) -> np.ndarray:
"""
Convert a grid of (y, x) coordinates, to their radial distances from the profile
centre (e.g. :math: r = x**2 + y**2).
centre (e.g. :math: r = sqrt(x**2 + y**2)).
Parameters
----------
Expand Down Expand Up @@ -311,7 +311,7 @@ def elliptical_radii_grid_from(
"""
return np.sqrt(
np.add(
np.square(grid[:, 1]), np.square(np.divide(grid[:, 0], self.axis_ratio))
np.square(grid.array[:, 1]), np.square(np.divide(grid.array[:, 0], self.axis_ratio))
)
)

Expand All @@ -334,9 +334,9 @@ def eccentric_radii_grid_from(
The (y, x) coordinates in the reference frame of the elliptical profile.
"""

grid_radii = self.elliptical_radii_grid_from(grid=grid, **kwargs)
grid_radii = self.elliptical_radii_grid_from(grid=grid, **kwargs).array

return np.multiply(np.sqrt(self.axis_ratio), grid_radii).view(np.ndarray)
return np.multiply(np.sqrt(self.axis_ratio), grid_radii)#.view(np.ndarray)

@aa.grid_dec.to_grid
def transformed_to_reference_frame_grid_from(
Expand Down
119 changes: 119 additions & 0 deletions autogalaxy/profiles/mass/abstract/jax_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import jax
import jax.numpy as jnp
import numpy as np

from jax import custom_jvp
from jax.scipy.special import gammaln


def reg1(z, _, i_sqrt_pi):
return i_sqrt_pi / z


def reg2(z, _, i_sqrt_pi):
z2 = z**2
return i_sqrt_pi * z / (z2 - 0.5)


def reg3(z, _, i_sqrt_pi):
z2 = z**2
return (i_sqrt_pi / z) * (1 + 0.5 / (z2 - 1.5))


def reg4(z, _, i_sqrt_pi):
z2 = z**2
return (i_sqrt_pi * z) * (z2 - 2.5) / (z2 * (z2 - 3.0) + 0.75)


def reg5(z, sqrt_pi, _):
mz2 = -z**2
f1 = sqrt_pi
f2 = 1.0
s1 = [1.320522, 35.7668, 219.031, 1540.787, 3321.99, 36183.31]
s2 = [1.841439, 61.57037, 364.2191, 2186.181, 9022.228, 24322.84, 32066.6]

for s in s1:
f1 = s - f1 * mz2
for s in s2:
f2 = s - f2 * mz2

return jnp.exp(mz2) + 1j * z * f1 / f2


def reg6(z, sqrt_pi, _):
miz = -1j * z
f1 = sqrt_pi
f2 = 1
s1 = [5.9126262, 30.180142, 93.15558, 181.92853, 214.38239, 122.60793]
s2 = [10.479857, 53.992907, 170.35400, 348.70392, 457.33448, 352.73063, 122.60793]

for s in s1:
f1 = s + f1 * miz
for s in s2:
f2 = s + f2 * miz

return f1 / f2


@custom_jvp
def w_f_approx(z):
"""Compute the Faddeeva function :math:`w_{\\mathrm F}(z)` using the
approximation given in Zaghloul (2017).
:param z: complex number
:type z: ``complex`` or ``numpy.array(dtype=complex)``
:return: :math:`w_\\mathrm{F}(z)`
:rtype: ``complex``
# This function is a JAX conversion of
# "https://github.com/sibirrer/lenstronomy/tree/master/lenstronomy/LensModel/Profiles"
# original function written by Anowar J. Shajib (see 1906.08263)
# JAX conversion written by Coleman M. Krawczyk
"""
sqrt_pi = 1 / jnp.sqrt(jnp.pi)
i_sqrt_pi = 1j * sqrt_pi

z_imag2 = z.imag**2
abs_z2 = z.real**2 + z_imag2

r1 = abs_z2 >= 38000.0
r2 = (abs_z2 >= 256.0) & (abs_z2 < 38000.0)
r3 = (abs_z2 >= 62.0) & (abs_z2 < 256.0)
r4 = (abs_z2 >= 30.0) & (abs_z2 < 62.0) & (z_imag2 >= 1e-13)
# region bounds for 5 taken directly from Zaghloul (2017)
# https://dl.acm.org/doi/pdf/10.1145/3119904
r5_1 = (abs_z2 >= 30.0) & (abs_z2 < 62.0) & (z_imag2 < 1e-13)
r5_2 = (abs_z2 >= 2.5) & (abs_z2 < 30.0) & (z_imag2 < 0.072)
r5 = r5_1 | r5_2
r6 = (abs_z2 < 30.0) & jnp.logical_not(r5)

args = (z, sqrt_pi, i_sqrt_pi)
wz = jnp.empty_like(z)
wz = jnp.where(r1, reg1(*args), wz)
wz = jnp.where(r2, reg2(*args), wz)
wz = jnp.where(r3, reg3(*args), wz)
wz = jnp.where(r4, reg4(*args), wz)
wz = jnp.where(r5, reg5(*args), wz)
wz = jnp.where(r6, reg6(*args), wz)
return wz


@w_f_approx.defjvp
def w_f_approx_jvp(primals, tangents):
# define a custom jvp to avoid the issue using `jnp.where` with `jax.grad`
# also the derivative is defined analytically for this function so bypass
# auto diffing over the complex functions above.
z, = primals
z_dot, = tangents
primal_out = w_f_approx(z)
i_sqrt_pi = 1j / jnp.sqrt(jnp.pi)
tangent_out = z_dot * 2 * (i_sqrt_pi - z * primal_out)
return primal_out, tangent_out


def comb(x: int, y: int) -> int:
# use the gamma function definition as that is the only
# JAX friendly way to do this (internally the factorial function
# uses this method as well). Round to closest int at the end of the
# calculation as we only use this for int inputs anyways.
return jnp.exp(gammaln(x + 1) - gammaln(y + 1) - gammaln(x - y + 1)).round(1)
Loading

0 comments on commit 5e30644

Please sign in to comment.