Skip to content

Commit

Permalink
add foster
Browse files Browse the repository at this point in the history
  • Loading branch information
lockwo committed Aug 23, 2024
1 parent 68b750b commit 8f4b4cc
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 16 deletions.
3 changes: 2 additions & 1 deletion diffrax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@
AbstractSpaceTimeTimeLevyArea as AbstractSpaceTimeTimeLevyArea,
AbstractWeakSpaceSpaceLevyArea as AbstractWeakSpaceSpaceLevyArea,
BrownianIncrement as BrownianIncrement,
DavieFosterWeakSpaceSpaceLevyArea as DavieFosterWeakSpaceSpaceLevyArea,
DavieWeakSpaceSpaceLevyArea as DavieWeakSpaceSpaceLevyArea,
SpaceTimeLevyArea as SpaceTimeLevyArea,
SpaceTimeTimeLevyArea as SpaceTimeTimeLevyArea,
WeakSpaceSpaceLevyArea as WeakSpaceSpaceLevyArea,
)
from ._event import (
# Deliberately not provided with `X as X` as these are now deprecated, so we'd like
Expand Down
64 changes: 52 additions & 12 deletions diffrax/_brownian/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@
from .._custom_types import (
AbstractBrownianIncrement,
BrownianIncrement,
DavieFosterWeakSpaceSpaceLevyArea,
DavieWeakSpaceSpaceLevyArea,
levy_tree_transpose,
RealScalarLike,
SpaceTimeLevyArea,
SpaceTimeTimeLevyArea,
WeakSpaceSpaceLevyArea,
)
from .._misc import (
force_bitcast_convert_type,
Expand All @@ -29,7 +30,11 @@


_Levy_Areas = Union[
BrownianIncrement, SpaceTimeLevyArea, SpaceTimeTimeLevyArea, WeakSpaceSpaceLevyArea
BrownianIncrement,
SpaceTimeLevyArea,
SpaceTimeTimeLevyArea,
DavieWeakSpaceSpaceLevyArea,
DavieFosterWeakSpaceSpaceLevyArea,
]


Expand Down Expand Up @@ -158,21 +163,56 @@ def _evaluate_leaf(
kk = jr.normal(key_kk, shape.shape, shape.dtype) * kk_std
levy_val = SpaceTimeTimeLevyArea(dt=dt, W=w, H=hh, K=kk)

elif levy_area is WeakSpaceSpaceLevyArea:
elif levy_area is DavieWeakSpaceSpaceLevyArea:
key_w, key_hh, key_b = jr.split(key, 3)
w = jr.normal(key_w, shape.shape, shape.dtype) * w_std
hh_std = w_std / math.sqrt(12)
hh = jr.normal(key_hh, shape.shape, shape.dtype) * hh_std
if w.ndim == 0 or w.ndim == 1:
a = jnp.zeros_like(w, dtype=shape.dtype)
levy_val = DavieWeakSpaceSpaceLevyArea(dt=dt, W=w, H=hh, A=a)
else:
b_std = (dt / jnp.sqrt(12)).astype(shape.dtype)
b = (
jr.normal(key_b, shape.shape + shape.shape[-1:], shape.dtype)
* b_std
)
b = b - b.transpose(*range(b.ndim - 2), -1, -2)
a = jnp.expand_dims(hh, -1) * jnp.expand_dims(w, -2) - jnp.expand_dims(
w, -1
) * jnp.expand_dims(hh, -2)
a += b
levy_val = DavieWeakSpaceSpaceLevyArea(dt=dt, W=w, H=hh, A=a)

elif levy_area is DavieFosterWeakSpaceSpaceLevyArea:
key_w, key_hh, key_b = jr.split(key, 3)
w = jr.normal(key_w, shape.shape, shape.dtype) * w_std
hh_std = w_std / math.sqrt(12)
hh = jr.normal(key_hh, shape.shape, shape.dtype) * hh_std
b_std = dt / jnp.sqrt(12)
b = jr.normal(key_b, shape.shape + shape.shape, shape.dtype) * b_std
if b.ndim == 0 or b.size == 1:
b = jnp.zeros(shape=shape.shape + shape.shape, dtype=shape.dtype)
if w.ndim == 0 or w.ndim == 1:
a = jnp.zeros_like(w, dtype=shape.dtype)
levy_val = DavieFosterWeakSpaceSpaceLevyArea(dt=dt, W=w, H=hh, A=a)
else:
# TODO: generalize to tensors?
assert b.ndim == 2
b = jnp.tril(b) - jnp.tril(b).T
a = jnp.tensordot(hh, w, axes=0) - jnp.tensordot(w, hh, axes=0) + b
levy_val = WeakSpaceSpaceLevyArea(dt=dt, W=w, H=hh, A=a)
tenth_dt = (0.1 * dt).astype(shape.dtype)
hh_squared = hh**2
b_std = jnp.sqrt(
tenth_dt
* (
tenth_dt
+ jnp.expand_dims(hh_squared, -1)
+ jnp.expand_dims(hh_squared, -2)
)
).astype(shape.dtype)
b = (
jr.normal(key_b, shape.shape + shape.shape[-1:], shape.dtype)
* b_std
)
b = b - b.transpose(*range(b.ndim - 2), -1, -2)
a = jnp.expand_dims(hh, -1) * jnp.expand_dims(w, -2) - jnp.expand_dims(
w, -1
) * jnp.expand_dims(hh, -2)
a += b
levy_val = DavieFosterWeakSpaceSpaceLevyArea(dt=dt, W=w, H=hh, A=a)

elif levy_area is SpaceTimeLevyArea:
key_w, key_hh = jr.split(key, 2)
Expand Down
14 changes: 13 additions & 1 deletion diffrax/_custom_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class AbstractWeakSpaceSpaceLevyArea(AbstractBrownianIncrement):
A: eqx.AbstractVar[BM]


class WeakSpaceSpaceLevyArea(AbstractWeakSpaceSpaceLevyArea):
class DavieWeakSpaceSpaceLevyArea(AbstractWeakSpaceSpaceLevyArea):
"""
Davie's approximation to weak Space Space Levy Areas.
See (7.4.1) of Foster's thesis.
Expand All @@ -94,6 +94,18 @@ class WeakSpaceSpaceLevyArea(AbstractWeakSpaceSpaceLevyArea):
A: Area


class DavieFosterWeakSpaceSpaceLevyArea(AbstractWeakSpaceSpaceLevyArea):
"""
Davie's approximation to weak Space Space Levy Areas.
See (7.4.2) of Foster's thesis.
"""

dt: PyTree[FloatScalarLike, "BM"]
W: BM
H: BM
A: Area


class AbstractSpaceTimeTimeLevyArea(AbstractSpaceTimeLevyArea):
"""
Abstract base class for all Space Time Time Levy Areas.
Expand Down
2 changes: 1 addition & 1 deletion diffrax/_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1054,7 +1054,7 @@ def _promote(yi):
if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController):
# Specific check to not work even if using HalfSolver(Euler())
if isinstance(solver, Euler):
raise ValueError(
warnings.warn(
"An SDE should not be solved with adaptive step sizes with Euler's "
"method, as it may not converge to the correct solution."
)
Expand Down
16 changes: 15 additions & 1 deletion test/test_brownian.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import jax


jax.config.update("jax_enable_x64", True)
import contextlib
import math
from typing import Literal
Expand Down Expand Up @@ -36,12 +40,22 @@ def _make_struct(shape, dtype):
@pytest.mark.parametrize(
"ctr", [diffrax.UnsafeBrownianPath, diffrax.VirtualBrownianTree]
)
@pytest.mark.parametrize("levy_area", _levy_areas)
@pytest.mark.parametrize(
"levy_area",
_levy_areas
+ (diffrax.DavieWeakSpaceSpaceLevyArea, diffrax.DavieFosterWeakSpaceSpaceLevyArea),
)
@pytest.mark.parametrize("use_levy", (False, True))
def test_shape_and_dtype(ctr, levy_area, use_levy, getkey):
t0 = 0.0
t1 = 2.0

if (
issubclass(levy_area, diffrax.AbstractWeakSpaceSpaceLevyArea)
and ctr is diffrax.VirtualBrownianTree
):
return

shapes_dtypes1 = (
((), None),
((0,), None),
Expand Down

0 comments on commit 8f4b4cc

Please sign in to comment.