diff --git a/diffrax/_solver/foster_langevin_srk.py b/diffrax/_solver/foster_langevin_srk.py index 2a08f34e..7f3cc0cc 100644 --- a/diffrax/_solver/foster_langevin_srk.py +++ b/diffrax/_solver/foster_langevin_srk.py @@ -44,7 +44,13 @@ def _get_args_from_terms( terms: MultiTerm[tuple[AbstractTerm[Any, RealScalarLike], AbstractTerm]], -) -> tuple[PyTree, PyTree, Callable[[UnderdampedLangevinX], UnderdampedLangevinX]]: +) -> tuple[ + PyTree, + PyTree, + PyTree, + PyTree, + Callable[[UnderdampedLangevinX], UnderdampedLangevinX], +]: drift, diffusion = terms.terms if isinstance(drift, WrapTerm): assert isinstance(diffusion, WrapTerm) @@ -53,10 +59,12 @@ def _get_args_from_terms( assert isinstance(drift, UnderdampedLangevinDriftTerm) assert isinstance(diffusion, UnderdampedLangevinDiffusionTerm) - gamma = drift.gamma - u = drift.u + gamma_drift = drift.gamma + u_drift = drift.u f = drift.grad_f - return gamma, u, f + gamma_diff = diffusion.gamma + u_diff = diffusion.u + return gamma_drift, u_drift, gamma_diff, u_diff, f # CONCERNING COEFFICIENTS: @@ -248,23 +256,38 @@ def init( evaluation of grad_f. """ drift, diffusion = terms.terms - gamma, u, grad_f = _get_args_from_terms(terms) + gamma_drift, u_drift, gamma_diff, u_diff, grad_f = _get_args_from_terms(terms) h = drift.contr(t0, t1) x0, v0 = y0 - gamma = broadcast_underdamped_langevin_arg(gamma, x0, "gamma") - u = broadcast_underdamped_langevin_arg(u, x0, "u") + gamma = broadcast_underdamped_langevin_arg(gamma_drift, x0, "gamma") + u = broadcast_underdamped_langevin_arg(u_drift, x0, "u") + + # Check that drift and diffusion have the same arguments + gamma_diff = broadcast_underdamped_langevin_arg(gamma_diff, x0, "gamma") + u_diff = broadcast_underdamped_langevin_arg(u_diff, x0, "u") + + def compare_args_fun(arg1, arg2): + arg = eqx.error_if( + arg1, + jnp.any(arg1 != arg2), + "The arguments of the drift and diffusion terms must match.", + ) + return arg + + gamma = jtu.tree_map(compare_args_fun, gamma, gamma_diff) + u = jtu.tree_map(compare_args_fun, u, u_diff) try: grad_f_shape = jax.eval_shape(grad_f, x0) except ValueError: raise UnderdampedLangevinStructureError("grad_f") - def _shape_check_fun(_x, _g, _u, _fx): + def shape_check_fun(_x, _g, _u, _fx): return _x.shape == _g.shape == _u.shape == _fx.shape - if not jtu.tree_all(jtu.tree_map(_shape_check_fun, x0, gamma, u, grad_f_shape)): + if not jtu.tree_all(jtu.tree_map(shape_check_fun, x0, gamma, u, grad_f_shape)): raise UnderdampedLangevinStructureError(None) tay_coeffs = jtu.tree_map(self._tay_coeffs_single, gamma) @@ -342,7 +365,7 @@ def step( old_coeffs: _Coeffs = st.coeffs gamma, u, rho = st.gamma, st.u, st.rho - _, _, grad_f = _get_args_from_terms(terms) + _, _, _, _, grad_f = _get_args_from_terms(terms) # If h changed, recompute coefficients # Even when using constant step sizes, h can fluctuate by small amounts, diff --git a/test/test_underdamped_langevin.py b/test/test_underdamped_langevin.py index 68d3fed9..9c2f8cb3 100644 --- a/test/test_underdamped_langevin.py +++ b/test/test_underdamped_langevin.py @@ -3,6 +3,7 @@ import jax.numpy as jnp import jax.random as jr import jax.tree_util as jtu +import jaxlib.xla_extension import pytest from diffrax import diffeqsolve, make_underdamped_langevin_term, SaveAt @@ -238,3 +239,46 @@ def test_reverse_solve(solver_cls): error = path_l2_dist(sol.ys, ref_sol.ys) assert error < 0.1 + + +# Here we check that if the drift and diffusion term have different arguments, +# an error is thrown. +def test_different_args(): + x0 = (jnp.ones(2), jnp.zeros(2)) + v0 = (jnp.zeros(2), jnp.zeros(2)) + y0 = (x0, v0) + g1 = (jnp.array([1, 2]), jnp.array([1, 2])) + u1 = (jnp.array([1, 2]), 1) + g2 = (jnp.array([1, 2]), jnp.array([1, 3])) + u2 = (jnp.array([1, 2]), jnp.ones((2,))) + grad_f = lambda x: x + + w_shape = ( + jax.ShapeDtypeStruct((2,), jnp.float64), + jax.ShapeDtypeStruct((2,), jnp.float64), + ) + bm = diffrax.VirtualBrownianTree( + 0, + 1, + tol=0.05, + shape=w_shape, + key=jr.key(0), + levy_area=diffrax.SpaceTimeTimeLevyArea, + ) + + drift_term = diffrax.UnderdampedLangevinDriftTerm(g1, u1, grad_f) + + # This one should fail + diffusion_term_a = diffrax.UnderdampedLangevinDiffusionTerm(g2, u1, bm) + terms_a = diffrax.MultiTerm(drift_term, diffusion_term_a) + + # This one should not fail + diffusion_term_b = diffrax.UnderdampedLangevinDiffusionTerm(g1, u2, bm) + terms_b = diffrax.MultiTerm(drift_term, diffusion_term_b) + + solver = diffrax.ShOULD(0.01) + try: + diffeqsolve(terms_a, solver, 0, 1, 0.1, y0, args=None) + except jaxlib.xla_extension.XlaRuntimeError: + pass + diffeqsolve(terms_b, solver, 0, 1, 0.1, y0, args=None)