diff --git a/jax_cosmo/background.py b/jax_cosmo/background.py index ba7566c..c76f0c4 100644 --- a/jax_cosmo/background.py +++ b/jax_cosmo/background.py @@ -1,10 +1,10 @@ # This module implements various functions for the background COSMOLOGY import jax.numpy as np from jax import lax +from jax.experimental.ode import odeint import jax_cosmo.constants as const from jax_cosmo.scipy.interpolate import interp -from jax_cosmo.scipy.ode import odeint __all__ = [ "w", diff --git a/jax_cosmo/core.py b/jax_cosmo/core.py index badf7a3..d09c88d 100644 --- a/jax_cosmo/core.py +++ b/jax_cosmo/core.py @@ -1,5 +1,4 @@ import jax.numpy as np -from jax.experimental.ode import odeint from jax.tree_util import register_pytree_node_class import jax_cosmo.constants as const diff --git a/jax_cosmo/scipy/ode.py b/jax_cosmo/scipy/ode.py deleted file mode 100644 index c7a66d8..0000000 --- a/jax_cosmo/scipy/ode.py +++ /dev/null @@ -1,22 +0,0 @@ -# this module stores custom ode code -import jax -import jax.numpy as np - - -def odeint(fn, y0, t): - """ - My dead-simple rk4 ODE solver. with no custom gradients - """ - - def rk4(carry, t): - y, t_prev = carry - h = t - t_prev - k1 = fn(y, t_prev) - k2 = fn(y + h * k1 / 2, t_prev + h / 2) - k3 = fn(y + h * k2 / 2, t_prev + h / 2) - k4 = fn(y + h * k3, t) - y = y + 1.0 / 6.0 * h * (k1 + 2 * k2 + 2 * k3 + k4) - return (y, t), y - - (yf, _), y = jax.lax.scan(rk4, (y0, np.array(t[0])), t) - return y