Skip to content

Commit

Permalink
Update isothermal and power law profiles to play nice at (0, 0)
Browse files Browse the repository at this point in the history
Add small radial offset to these two profiles to remove the `nan` values at the center.

For the power law profile use the expansion of the `hyp2f1` function given by Tessore and Metcalf 2015 (eqn 29).
  • Loading branch information
CKrawczyk committed Sep 18, 2024
1 parent 5e30644 commit 3210f4f
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 14 deletions.
29 changes: 21 additions & 8 deletions autogalaxy/profiles/mass/total/isothermal.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
import numpy as np
import os

if os.environ.get("USE_JAX", "0") == "1":
USING_JAX = True
import jax.numpy as np
else:
USING_JAX = False
import numpy as np

from typing import Tuple

import autoarray as aa
Expand Down Expand Up @@ -30,14 +38,19 @@ def psi_from(grid, axis_ratio, core_radius):
The value of the Psi term.
"""
return np.sqrt(
np.add(
np.multiply(
axis_ratio**2.0, np.add(np.square(grid[:, 1]), core_radius**2.0)
),
np.square(grid[:, 0]),
if USING_JAX:
return np.sqrt(
(axis_ratio**2.0 * (grid[:, 1]**2.0 + core_radius**2.0)) + grid[:, 0]**2.0 + 1e-16
)
else:
return np.sqrt(
np.add(
np.multiply(
axis_ratio**2.0, np.add(np.square(grid[:, 1]), core_radius**2.0)
),
np.square(grid[:, 0]),
)
)
)


class Isothermal(PowerLaw):
Expand Down
47 changes: 47 additions & 0 deletions autogalaxy/profiles/mass/total/jax_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import jax
import jax.numpy as jnp
from jax.tree_util import Partial as partial


# A version of scan that will *not* re-compile partial functions when variables change
# taken from https://github.com/google/jax/issues/14743#issuecomment-1456900634
scan = jax.jit(jax.lax.scan, static_argnames=('length', 'reverse', 'unroll'))


def body_fun(carry, n, factor, ei2phi, slope):
omega_nm1, partial_sum = carry
two_n = 2 * n
two_minus_slope = 2 - slope
ratio = (two_n - two_minus_slope) / (two_n + two_minus_slope)
omega_n = -factor * ratio * ei2phi * omega_nm1
partial_sum = partial_sum + omega_n
return (omega_n, partial_sum), None


def omega(eiphi, slope, factor, n_terms=20):
'''JAX implementation of the numerical evaluation of the angular component of
the complex deflection angle for the elliptical power law profile as given as
given by Tessore and Metcalf 2015. Based on equation 29, and gives
omega (e.g. can be used as a drop in replacement for the exp(i * phi) * special.hyp2f1
term in equation 13).
Parameters
----------
eiphi:
`exp(i * phi)` where `phi` is the elliptical angle of the profile
slope:
The density slope of the power-law (lower value -> shallower profile, higher value
-> steeper profile).
factor:
The second flattening of and ellipse with axis ration q give by `f = (1 - q) / (1 + q)`
n_terms:
The number of terms to calculate for the series expansion, defaults to 20 (this should
be sufficient most of the time)
'''
# use modified scan with a partial'ed function to avoid re-compile
(_, partial_sum), _ = scan(
partial(body_fun, factor=factor, ei2phi=eiphi**2, slope=slope),
(eiphi, eiphi),
jnp.arange(1, n_terms)
)
return partial_sum
29 changes: 23 additions & 6 deletions autogalaxy/profiles/mass/total/power_law.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
import numpy as np
import os

if os.environ.get("USE_JAX", "0") == "1":
USING_JAX = True
import jax.numpy as np
from .jax_utils import omega
else:
USING_JAX = False
import numpy as np

from scipy import special
from typing import Tuple

Expand Down Expand Up @@ -79,20 +88,28 @@ def deflections_yx_2d_from(self, grid: aa.type.Grid2DLike, **kwargs):
angle = np.arctan2(
grid[:, 0], np.multiply(self.axis_ratio, grid[:, 1])
) # Note, this angle is not the position angle
R = np.sqrt(
np.add(np.multiply(self.axis_ratio**2, grid[:, 1] ** 2), grid[:, 0] ** 2)
)
z = np.add(
np.multiply(np.cos(angle), 1 + 0j), np.multiply(np.sin(angle), 0 + 1j)
)

if USING_JAX:
# offset radius so calculation is finite at (0, 0)
R = np.sqrt(
(self.axis_ratio * grid[:, 1])**2 + grid[:, 0]**2 + 1e-16
)
zh = omega(z, slope, factor, n_terms=20)
else:
R = np.sqrt(
np.add(np.multiply(self.axis_ratio**2, grid[:, 1] ** 2), grid[:, 0] ** 2)
)
zh = z * special.hyp2f1(1.0, 0.5 * slope, 2.0 - 0.5 * slope, -factor * z**2)

complex_angle = (
2.0
* b
/ (1.0 + self.axis_ratio)
* (b / R) ** (slope - 1.0)
* z
* special.hyp2f1(1.0, 0.5 * slope, 2.0 - 0.5 * slope, -factor * z**2)
* zh
)

deflection_y = complex_angle.imag
Expand Down

0 comments on commit 3210f4f

Please sign in to comment.