From 044febb5dfe06a4b08b8cbd02c77de59d6f5218e Mon Sep 17 00:00:00 2001 From: CKrawczyk Date: Fri, 28 Jun 2024 10:25:26 +0100 Subject: [PATCH] When using JAX don't import Numpy Switches out `numpy` for `jax.numpy`. --- autogalaxy/profiles/geometry_profiles.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/autogalaxy/profiles/geometry_profiles.py b/autogalaxy/profiles/geometry_profiles.py index c75b8cdb..d036b8ce 100644 --- a/autogalaxy/profiles/geometry_profiles.py +++ b/autogalaxy/profiles/geometry_profiles.py @@ -1,6 +1,11 @@ +import os + from typing import Optional, Tuple, Type -import numpy as np +if os.environ.get("USE_JAX", "0") == "1": + import jax.numpy as np +else: + import numpy as np import autoarray as aa