diff --git a/swirl_dynamics/lib/solvers/ode.py b/swirl_dynamics/lib/solvers/ode.py index 4a0f561..4f19a03 100644 --- a/swirl_dynamics/lib/solvers/ode.py +++ b/swirl_dynamics/lib/solvers/ode.py @@ -14,6 +14,7 @@ """Solvers for ordinary differential equations (ODEs).""" +import dataclasses from typing import Any, Protocol import flax @@ -45,8 +46,16 @@ def __call__( ... +@dataclasses.dataclass class ScanOdeSolver: - """ODE solver based on `jax.lax.scan`.""" + """ODE solver based on `jax.lax.scan`. + + Attributes: + time_axis_pos: move the time axis to the specified position in the output + tensor (by default it is at the 0th position). + """ + + time_axis_pos: int = 0 def step( self, func: OdeDynamics, x0: Array, t0: Array, dt: Array, params: PyTree @@ -68,7 +77,10 @@ def scan_fun( return (x_next, t_next), x_next _, out = jax.lax.scan(scan_fun, (x0, tspan[0]), tspan[1:]) - return jnp.concatenate([x0[None], out], axis=0) + out = jnp.concatenate([x0[None], out], axis=0) + if self.time_axis_pos: + out = jnp.moveaxis(out, 0, self.time_axis_pos) + return out class ExplicitEuler(ScanOdeSolver): diff --git a/swirl_dynamics/lib/solvers/ode_test.py b/swirl_dynamics/lib/solvers/ode_test.py index e07dbcf..5701985 100644 --- a/swirl_dynamics/lib/solvers/ode_test.py +++ b/swirl_dynamics/lib/solvers/ode_test.py @@ -46,6 +46,20 @@ def test_output_shape_and_value(self, solver, backward): self.assertEqual(out.shape, (num_steps, x_dim)) np.testing.assert_allclose(out[-1], np.ones((x_dim,)) * tspan[-1]) + def test_move_time_axis_pos(self): + dt = 0.1 + num_steps = 10 + x_dim = 5 + batch_sz = 6 + tspan = jnp.arange(num_steps) * dt + out = ode.ExplicitEuler(time_axis_pos=1)( + dummy_ode_dynamics, jnp.zeros((batch_sz, x_dim)), tspan, {} + ) + self.assertEqual(out.shape, (batch_sz, num_steps, x_dim)) + np.testing.assert_allclose( + out[:, -1], np.ones((batch_sz, x_dim)) * tspan[-1] + ) + @parameterized.parameters((np.arange(10) * -1,), (np.zeros(10),)) def test_dopri45_backward_error(self, tspan): tspan = jnp.asarray(tspan)