diff --git a/autogalaxy/profiles/geometry_profiles.py b/autogalaxy/profiles/geometry_profiles.py index c75b8cdb..d036b8ce 100644 --- a/autogalaxy/profiles/geometry_profiles.py +++ b/autogalaxy/profiles/geometry_profiles.py @@ -1,6 +1,11 @@ +import os + from typing import Optional, Tuple, Type -import numpy as np +if os.environ.get("USE_JAX", "0") == "1": + import jax.numpy as np +else: + import numpy as np import autoarray as aa