Skip to content

Commit

Permalink
check langevin drift term and diffusion term have same args
Browse files Browse the repository at this point in the history
  • Loading branch information
andyElking committed Sep 1, 2024
1 parent d5c4e59 commit 8e8e454
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 10 deletions.
43 changes: 33 additions & 10 deletions diffrax/_solver/foster_langevin_srk.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,13 @@

def _get_args_from_terms(
terms: MultiTerm[tuple[AbstractTerm[Any, RealScalarLike], AbstractTerm]],
) -> tuple[PyTree, PyTree, Callable[[UnderdampedLangevinX], UnderdampedLangevinX]]:
) -> tuple[
PyTree,
PyTree,
PyTree,
PyTree,
Callable[[UnderdampedLangevinX], UnderdampedLangevinX],
]:
drift, diffusion = terms.terms
if isinstance(drift, WrapTerm):
assert isinstance(diffusion, WrapTerm)
Expand All @@ -53,10 +59,12 @@ def _get_args_from_terms(

assert isinstance(drift, UnderdampedLangevinDriftTerm)
assert isinstance(diffusion, UnderdampedLangevinDiffusionTerm)
gamma = drift.gamma
u = drift.u
gamma_drift = drift.gamma
u_drift = drift.u
f = drift.grad_f
return gamma, u, f
gamma_diff = diffusion.gamma
u_diff = diffusion.u
return gamma_drift, u_drift, gamma_diff, u_diff, f


# CONCERNING COEFFICIENTS:
Expand Down Expand Up @@ -248,23 +256,38 @@ def init(
evaluation of grad_f.
"""
drift, diffusion = terms.terms
gamma, u, grad_f = _get_args_from_terms(terms)
gamma_drift, u_drift, gamma_diff, u_diff, grad_f = _get_args_from_terms(terms)

h = drift.contr(t0, t1)
x0, v0 = y0

gamma = broadcast_underdamped_langevin_arg(gamma, x0, "gamma")
u = broadcast_underdamped_langevin_arg(u, x0, "u")
gamma = broadcast_underdamped_langevin_arg(gamma_drift, x0, "gamma")
u = broadcast_underdamped_langevin_arg(u_drift, x0, "u")

# Check that drift and diffusion have the same arguments
gamma_diff = broadcast_underdamped_langevin_arg(gamma_diff, x0, "gamma")
u_diff = broadcast_underdamped_langevin_arg(u_diff, x0, "u")

def compare_args_fun(arg1, arg2):
arg = eqx.error_if(
arg1,
jnp.any(arg1 != arg2),
"The arguments of the drift and diffusion terms must match.",
)
return arg

gamma = jtu.tree_map(compare_args_fun, gamma, gamma_diff)
u = jtu.tree_map(compare_args_fun, u, u_diff)

try:
grad_f_shape = jax.eval_shape(grad_f, x0)
except ValueError:
raise UnderdampedLangevinStructureError("grad_f")

def _shape_check_fun(_x, _g, _u, _fx):
def shape_check_fun(_x, _g, _u, _fx):
return _x.shape == _g.shape == _u.shape == _fx.shape

if not jtu.tree_all(jtu.tree_map(_shape_check_fun, x0, gamma, u, grad_f_shape)):
if not jtu.tree_all(jtu.tree_map(shape_check_fun, x0, gamma, u, grad_f_shape)):
raise UnderdampedLangevinStructureError(None)

tay_coeffs = jtu.tree_map(self._tay_coeffs_single, gamma)
Expand Down Expand Up @@ -342,7 +365,7 @@ def step(
old_coeffs: _Coeffs = st.coeffs

gamma, u, rho = st.gamma, st.u, st.rho
_, _, grad_f = _get_args_from_terms(terms)
_, _, _, _, grad_f = _get_args_from_terms(terms)

# If h changed, recompute coefficients
# Even when using constant step sizes, h can fluctuate by small amounts,
Expand Down
44 changes: 44 additions & 0 deletions test/test_underdamped_langevin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
import jaxlib.xla_extension
import pytest
from diffrax import diffeqsolve, make_underdamped_langevin_term, SaveAt

Expand Down Expand Up @@ -238,3 +239,46 @@ def test_reverse_solve(solver_cls):

error = path_l2_dist(sol.ys, ref_sol.ys)
assert error < 0.1


# Here we check that if the drift and diffusion term have different arguments,
# an error is thrown.
def test_different_args():
x0 = (jnp.ones(2), jnp.zeros(2))
v0 = (jnp.zeros(2), jnp.zeros(2))
y0 = (x0, v0)
g1 = (jnp.array([1, 2]), jnp.array([1, 2]))
u1 = (jnp.array([1, 2]), 1)
g2 = (jnp.array([1, 2]), jnp.array([1, 3]))
u2 = (jnp.array([1, 2]), jnp.ones((2,)))
grad_f = lambda x: x

w_shape = (
jax.ShapeDtypeStruct((2,), jnp.float64),
jax.ShapeDtypeStruct((2,), jnp.float64),
)
bm = diffrax.VirtualBrownianTree(
0,
1,
tol=0.05,
shape=w_shape,
key=jr.key(0),
levy_area=diffrax.SpaceTimeTimeLevyArea,
)

drift_term = diffrax.UnderdampedLangevinDriftTerm(g1, u1, grad_f)

# This one should fail
diffusion_term_a = diffrax.UnderdampedLangevinDiffusionTerm(g2, u1, bm)
terms_a = diffrax.MultiTerm(drift_term, diffusion_term_a)

# This one should not fail
diffusion_term_b = diffrax.UnderdampedLangevinDiffusionTerm(g1, u2, bm)
terms_b = diffrax.MultiTerm(drift_term, diffusion_term_b)

solver = diffrax.ShOULD(0.01)
try:
diffeqsolve(terms_a, solver, 0, 1, 0.1, y0, args=None)
except jaxlib.xla_extension.XlaRuntimeError:
pass
diffeqsolve(terms_b, solver, 0, 1, 0.1, y0, args=None)

0 comments on commit 8e8e454

Please sign in to comment.