From b6d157c6e28398cfdf1431be853e49b0924970d0 Mon Sep 17 00:00:00 2001 From: CKrawczyk Date: Fri, 25 Oct 2024 13:15:37 +0100 Subject: [PATCH] Update functions for Tracer to be jax safe Needed to make the `autolens.Tracer` example work. --- autogalaxy/convert.py | 24 ++++++++++++++++---- autogalaxy/profiles/mass/total/isothermal.py | 3 +++ 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/autogalaxy/convert.py b/autogalaxy/convert.py index 8ff432d7..4c1ae929 100644 --- a/autogalaxy/convert.py +++ b/autogalaxy/convert.py @@ -1,7 +1,10 @@ from astropy import units -from autofit.jax_wrapper import numpy as np +from autofit.jax_wrapper import numpy as np, use_jax from typing import Tuple +if use_jax: + import jax + def ell_comps_from(axis_ratio: float, angle: float) -> Tuple[float, float]: """ @@ -62,12 +65,23 @@ def axis_ratio_and_angle_from(ell_comps: Tuple[float, float]) -> Tuple[float, fl angle = np.arctan2(ell_comps[0], ell_comps[1]) / 2 angle *= 180.0 / np.pi - if abs(angle) > 45 and angle < 0: - angle += 180 + if use_jax: + angle = jax.lax.select( + angle < -45, + angle + 180, + angle + ) + else: + if abs(angle) > 45 and angle < 0: + angle += 180 fac = np.sqrt(ell_comps[1] ** 2 + ell_comps[0] ** 2) - if fac > 0.999: - fac = 0.999 # avoid unphysical solution + if use_jax: + fac = jax.lax.min(fac, 0.999) + else: + fac = min(fac, 0.999) + # if fac > 0.999: + # fac = 0.999 # avoid unphysical solution # if fac > 1: print('unphysical e1,e2') axis_ratio = (1 - fac) / (1 + fac) return axis_ratio, angle diff --git a/autogalaxy/profiles/mass/total/isothermal.py b/autogalaxy/profiles/mass/total/isothermal.py index 79f37893..f5073bd2 100644 --- a/autogalaxy/profiles/mass/total/isothermal.py +++ b/autogalaxy/profiles/mass/total/isothermal.py @@ -3,6 +3,7 @@ if os.environ.get("USE_JAX", "0") == "1": USING_JAX = True import jax.numpy as np + from jax.lax import min else: USING_JAX = False import numpy as np @@ -110,6 +111,8 @@ def deflections_yx_2d_from(self, grid: aa.type.Grid2DLike, **kwargs): * self.axis_ratio / np.sqrt(1 - self.axis_ratio**2) ) + if USING_JAX: + grid = grid.array psi = psi_from(grid=grid, axis_ratio=self.axis_ratio, core_radius=0.0)