Skip to content

Commit

Permalink
Merge pull request #202 from Jammy2211/feature/jax_tracer
Browse files Browse the repository at this point in the history
Update functions for Tracer to be jax safe
  • Loading branch information
CKrawczyk authored Oct 25, 2024
2 parents 8e4d556 + b6d157c commit e591a26
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 5 deletions.
24 changes: 19 additions & 5 deletions autogalaxy/convert.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from astropy import units
from autofit.jax_wrapper import numpy as np
from autofit.jax_wrapper import numpy as np, use_jax
from typing import Tuple

if use_jax:
import jax


def ell_comps_from(axis_ratio: float, angle: float) -> Tuple[float, float]:
"""
Expand Down Expand Up @@ -62,12 +65,23 @@ def axis_ratio_and_angle_from(ell_comps: Tuple[float, float]) -> Tuple[float, fl
angle = np.arctan2(ell_comps[0], ell_comps[1]) / 2
angle *= 180.0 / np.pi

if abs(angle) > 45 and angle < 0:
angle += 180
if use_jax:
angle = jax.lax.select(
angle < -45,
angle + 180,
angle
)
else:
if abs(angle) > 45 and angle < 0:
angle += 180

fac = np.sqrt(ell_comps[1] ** 2 + ell_comps[0] ** 2)
if fac > 0.999:
fac = 0.999 # avoid unphysical solution
if use_jax:
fac = jax.lax.min(fac, 0.999)
else:
fac = min(fac, 0.999)
# if fac > 0.999:
# fac = 0.999 # avoid unphysical solution
# if fac > 1: print('unphysical e1,e2')
axis_ratio = (1 - fac) / (1 + fac)
return axis_ratio, angle
Expand Down
3 changes: 3 additions & 0 deletions autogalaxy/profiles/mass/total/isothermal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
if os.environ.get("USE_JAX", "0") == "1":
USING_JAX = True
import jax.numpy as np
from jax.lax import min
else:
USING_JAX = False
import numpy as np
Expand Down Expand Up @@ -110,6 +111,8 @@ def deflections_yx_2d_from(self, grid: aa.type.Grid2DLike, **kwargs):
* self.axis_ratio
/ np.sqrt(1 - self.axis_ratio**2)
)
if USING_JAX:
grid = grid.array

psi = psi_from(grid=grid, axis_ratio=self.axis_ratio, core_radius=0.0)

Expand Down

0 comments on commit e591a26

Please sign in to comment.