diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 00000000..13566b81
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/diffrax_STLA.iml b/.idea/diffrax_STLA.iml
new file mode 100644
index 00000000..76f8ed3e
--- /dev/null
+++ b/.idea/diffrax_STLA.iml
@@ -0,0 +1,15 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 00000000..105ce2da
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 00000000..928ba795
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,7 @@
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 00000000..9a53c7d2
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 00000000..35eb1ddf
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/diffrax/brownian/base.py b/diffrax/brownian/base.py
index ee90ba60..4d887696 100644
--- a/diffrax/brownian/base.py
+++ b/diffrax/brownian/base.py
@@ -1,14 +1,21 @@
import abc
+from typing import Optional, Union
-from ..custom_types import Array, PyTree, Scalar
+from ..custom_types import Array, LevyVal, PyTree, Scalar
from ..path import AbstractPath
class AbstractBrownianPath(AbstractPath):
- "Abstract base class for all Brownian paths."
+ """Abstract base class for all Brownian paths."""
@abc.abstractmethod
- def evaluate(self, t0: Scalar, t1: Scalar, left: bool = True) -> PyTree[Array]:
+ def evaluate(
+ self,
+ t0: Scalar,
+ t1: Optional[Scalar] = None,
+ left: bool = True,
+ use_levy: bool = False,
+ ) -> Union[PyTree[Array], LevyVal]:
r"""Samples a Brownian increment $w(t_1) - w(t_0)$.
Each increment has distribution $\mathcal{N}(0, t_1 - t_0)$.
@@ -20,6 +27,8 @@ def evaluate(self, t0: Scalar, t1: Scalar, left: bool = True) -> PyTree[Array]:
- `left`: Ignored. (This determines whether to treat the path as
left-continuous or right-continuous at any jump points, but Brownian
motion has no jump points.)
+ - `use_levy`: If True, the return type will be a `LevyVal`, which contains
+ PyTrees of Brownian increments and their Levy areas.
**Returns:**
diff --git a/diffrax/brownian/path.py b/diffrax/brownian/path.py
index a844450a..215f18e3 100644
--- a/diffrax/brownian/path.py
+++ b/diffrax/brownian/path.py
@@ -1,4 +1,4 @@
-from typing import Tuple, Union
+from typing import Literal, Tuple, Union
import equinox as eqx
import equinox.internal as eqxi
@@ -7,7 +7,7 @@
import jax.random as jrandom
import jax.tree_util as jtu
-from ..custom_types import Array, PyTree, Scalar
+from ..custom_types import Array, levy_tree_transpose, LevyVal, PyTree, Scalar
from ..misc import force_bitcast_convert_type, is_tuple_of_ints, split_by_tree
from .base import AbstractBrownianPath
@@ -30,9 +30,14 @@ class UnsafeBrownianPath(AbstractBrownianPath):
interval, ignoring the correlation between samples exhibited in true Brownian
motion. Hence the restrictions above. (They describe the general case for which the
correlation structure isn't needed.)
+
+ Depending on the `levy_area` argument, this can also be used to generate Levy area.
+ `levy_area` can be "" or "space-time".
+
"""
shape: PyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)
+ levy_area: Literal["", "space-time"] = eqx.field(static=True)
# Handled as a string because PRNGKey is actually a function, not a class, which
# makes it appearly badly in autogenerated documentation.
key: "jax.random.PRNGKey" # noqa: F821
@@ -41,6 +46,7 @@ def __init__(
self,
shape: Union[Tuple[int, ...], PyTree[jax.ShapeDtypeStruct]],
key: "jax.random.PRNGKey",
+ levy_area: Literal["", "space-time"] = "",
):
self.shape = (
jax.ShapeDtypeStruct(shape, jax.dtypes.canonicalize_dtype(None))
@@ -48,6 +54,12 @@ def __init__(
else shape
)
self.key = key
+ if levy_area not in ["", "space-time"]:
+ raise ValueError(
+ f"levy_area must be one of '', 'space-time', but got {levy_area}."
+ )
+ self.levy_area = levy_area
+
if any(
not jnp.issubdtype(x.dtype, jnp.inexact)
for x in jtu.tree_leaves(self.shape)
@@ -63,7 +75,13 @@ def t1(self):
return None
@eqx.filter_jit
- def evaluate(self, t0: Scalar, t1: Scalar, left: bool = True) -> PyTree[Array]:
+ def evaluate(
+ self,
+ t0: Scalar,
+ t1: Scalar,
+ left: bool = True,
+ use_levy: bool = False,
+ ) -> PyTree[Array]:
del left
t0 = eqxi.nondifferentiable(t0, name="t0")
t1 = eqxi.nondifferentiable(t1, name="t1")
@@ -72,14 +90,42 @@ def evaluate(self, t0: Scalar, t1: Scalar, left: bool = True) -> PyTree[Array]:
key = jrandom.fold_in(self.key, t0_)
key = jrandom.fold_in(key, t1_)
key = split_by_tree(key, self.shape)
- return jtu.tree_map(
- lambda key, shape: self._evaluate_leaf(t0, t1, key, shape), key, self.shape
+ out = jtu.tree_map(
+ lambda key, shape: self._evaluate_leaf(
+ t0, t1, key, shape, self.levy_area, use_levy
+ ),
+ key,
+ self.shape,
)
+ if use_levy:
+ out = levy_tree_transpose(self.shape, self.levy_area, out)
+ assert isinstance(out, LevyVal)
+ return out
+
+ @staticmethod
+ def _evaluate_leaf(
+ t0: Scalar,
+ t1: Scalar,
+ key,
+ shape: jax.ShapeDtypeStruct,
+ levy_area: str,
+ use_levy: bool,
+ ):
+ w_std = jnp.sqrt(t1 - t0).astype(shape.dtype)
- def _evaluate_leaf(self, t0: Scalar, t1: Scalar, key, shape: jax.ShapeDtypeStruct):
- return jrandom.normal(key, shape.shape, shape.dtype) * jnp.sqrt(t1 - t0).astype(
- shape.dtype
- )
+ if levy_area == "space-time":
+ key_w, key_hh = jrandom.split(key, 2)
+ w = jrandom.normal(key_w, shape.shape, shape.dtype) * w_std
+ hh_std = jnp.sqrt((t1 - t0) / 12).astype(shape.dtype)
+ hh = jrandom.normal(key_hh, shape.shape, shape.dtype) * hh_std
+ else:
+ hh = None
+ w = jrandom.normal(key, shape.shape, shape.dtype) * w_std
+
+ if use_levy:
+ return LevyVal(dt=t1 - t0, W=w, H=hh, bar_H=None, K=None, bar_K=None)
+ else:
+ return w
UnsafeBrownianPath.__init__.__doc__ = """
@@ -89,5 +135,6 @@ def _evaluate_leaf(self, t0: Scalar, t1: Scalar, key, shape: jax.ShapeDtypeStruc
dtype, and PyTree structure of the output. For simplicity, `shape` can also just
be a tuple of integers, describing the shape of a single JAX array. In that case
the dtype is chosen to be the default floating-point dtype.
+
- `key`: A random key.
"""
diff --git a/diffrax/brownian/tree.py b/diffrax/brownian/tree.py
index 577092ce..46b1fbce 100644
--- a/diffrax/brownian/tree.py
+++ b/diffrax/brownian/tree.py
@@ -1,5 +1,5 @@
from dataclasses import field
-from typing import Optional, Tuple, Union
+from typing import Literal, Optional, Tuple, Union
import equinox as eqx
import equinox.internal as eqxi
@@ -9,8 +9,8 @@
import jax.random as jrandom
import jax.tree_util as jtu
-from ..custom_types import Array, PyTree, Scalar
-from ..misc import is_tuple_of_ints, split_by_tree
+from ..custom_types import levy_tree_transpose, LevyVal, PyTree, Scalar
+from ..misc import is_tuple_of_ints, linear_rescale, split_by_tree
from .base import AbstractBrownianPath
@@ -25,21 +25,75 @@
# }
#
+# We define
+# H_{s,t} = 1/(t-s) ( \int_s^t ( W_u - (u-s)/(t-s) W_{s,t} ) du ).
+# bhh_t = t * H_{0,t}
+# For more details see Definition 4.2.1 and Theorem 6.1.4 of
+#
+# Foster, J. M. (2020). Numerical approximations for stochastic
+# differential equations [PhD thesis]. University of Oxford.
+
class _State(eqx.Module):
- s: Scalar
- t: Scalar
- u: Scalar
- w_s: Scalar
- w_t: Scalar
- w_u: Scalar
+ level: int
+ stu: tuple[Scalar, Scalar, Scalar] # s, t, u
+ w_stu: tuple[Scalar, Scalar, Scalar] # W at times s, t, u
+ w_st_tu: tuple[Scalar, Scalar] # W_{s,t} and W_{t,u}
key: "jax.random.PRNGKey"
+ bhh_stu: Optional[Tuple[Scalar, Scalar, Scalar]] # \bar{H} at times s, t, u
+ bhh_st_tu: Optional[Tuple[Scalar, Scalar]] # \bar{H}_{s,t} and \bar{H}_{t,u}
+ bkk_stu: Optional[Tuple[Scalar, Scalar, Scalar]] # \bar{K} at times s, t, u
+ bkk_st_tu: Optional[Tuple[Scalar, Scalar]] # \bar{K}_{s,t} and \bar{K}_{t,u}
+
+
+def _levy_diff(x0: LevyVal, x1: LevyVal) -> LevyVal:
+ r"""Computes $(W_{s,u}, H_{s,u})$ from $(W_s, \bar{H}_{s,u})$ and
+ $(W_u, \bar{H}_u)$, where $\bar{H}_u = u * H_u$.
+
+ **Arguments:**
+
+ - `x0`: `LevyVal` at time `s`
+
+ - `x1`: `LevyVal` at time `u`
+
+ **Returns:**
+
+ `LevyVal(W_su, H_su)`
+ """
+
+ su = (x1.dt - x0.dt).astype(x0.W.dtype)
+ w_su = x1.W - x0.W
+ if x0.H is None or x1.H is None: # BM only case
+ return LevyVal(dt=su, W=w_su, H=None, bar_H=None, K=None, bar_K=None)
+
+ # levy_area == "space-time"
+ _su = jnp.where(jnp.abs(su) < jnp.finfo(su).eps, jnp.inf, su)
+ inverse_su = 1 / _su
+ u_bb_s = x1.dt * x0.W - x0.dt * x1.W
+ bhh_su = x1.bar_H - x0.bar_H - 0.5 * u_bb_s # bhh_su = H_{s,u} * (u-s)
+ hh_su = inverse_su * bhh_su
+
+ return LevyVal(dt=su, W=w_su, H=hh_su, bar_H=None, K=None, bar_K=None)
+
+
+def split_interval(_cond, x_stu, x_st_tu):
+ x_s, x_t, x_u = x_stu
+ x_st, x_tu = x_st_tu
+ x_s = jnp.where(_cond, x_t, x_s)
+ x_u = jnp.where(_cond, x_u, x_t)
+ x_su = jnp.where(_cond, x_tu, x_st)
+ return x_s, x_u, x_su
class VirtualBrownianTree(AbstractBrownianPath):
"""Brownian simulation that discretises the interval `[t0, t1]` to tolerance `tol`,
and is piecewise quadratic at that discretisation.
+ Can be initialised with `levy_area` set to `""`, or `"space-time"`.
+ If `levy_area=="space_time"`, then it also computes space-time Lévy area `H`.
+ This will impact the Brownian path, so even with the same key, the trajectory will
+ be different depending on the value of `levy_area`.
+
??? cite "Reference"
```bibtex
@@ -52,15 +106,15 @@ class VirtualBrownianTree(AbstractBrownianPath):
}
```
- (The implementation here is a slight improvement on the reference
- implementation, by being piecwise quadratic rather than piecewise linear. This
- corrects a small bias in the generated samples.)
+ (The implementation here is a slight improvement on the reference implementation
+ by using an interpolation method which ensures all the 2nd moments are correct.)
"""
t0: Scalar = field(init=True)
t1: Scalar = field(init=True) # override init=False in AbstractPath
tol: Scalar
shape: PyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)
+ levy_area: Literal["", "space-time"] = eqx.field(static=True)
key: "jax.random.PRNGKey" # noqa: F821
def __init__(
@@ -70,10 +124,20 @@ def __init__(
tol: Scalar,
shape: Union[Tuple[int, ...], PyTree[jax.ShapeDtypeStruct]],
key: "jax.random.PRNGKey",
+ levy_area: Literal["", "space-time"] = "",
):
+ (t0, t1) = eqx.error_if((t0, t1), t0 >= t1, "t0 must be strictly less than t1")
self.t0 = t0
self.t1 = t1
- self.tol = tol
+ # Since we rescale the interval to [0,1],
+ # we need to rescale the tolerance too.
+ self.tol = tol / (self.t1 - self.t0)
+
+ if levy_area not in ["", "space-time"]:
+ raise ValueError(
+ f"levy_area must be one of '', 'space-time', but got {levy_area}."
+ )
+ self.levy_area = levy_area
self.shape = (
jax.ShapeDtypeStruct(shape, jax.dtypes.canonicalize_dtype(None))
if is_tuple_of_ints(shape)
@@ -88,70 +152,224 @@ def __init__(
)
self.key = split_by_tree(key, self.shape)
+ def _denormalise_bm_inc(self, x: LevyVal) -> LevyVal:
+ # TODO: demonstrate rescaling actually helps
+
+ # Rescaling back from [0, 1] to the original interval [t0, t1].
+ interval_len = self.t1 - self.t0 # can be any dtype
+ sqrt_len = jnp.sqrt(interval_len)
+
+ def sqrt_mult(z):
+ # need to cast to dtype of each leaf in PyTree
+ dtype = jnp.dtype(z)
+ return (z * sqrt_len).astype(dtype)
+
+ def mult(z):
+ dtype = jnp.dtype(z)
+ return (interval_len * z).astype(dtype)
+
+ return LevyVal(
+ dt=jtu.tree_map(mult, x.dt),
+ W=jtu.tree_map(sqrt_mult, x.W),
+ H=jtu.tree_map(sqrt_mult, x.H),
+ bar_H=None,
+ K=jtu.tree_map(sqrt_mult, x.K),
+ bar_K=None,
+ )
+
@eqx.filter_jit
def evaluate(
- self, t0: Scalar, t1: Optional[Scalar] = None, left: bool = True
- ) -> PyTree[Array]:
- del left
+ self,
+ t0: Scalar,
+ t1: Optional[Scalar] = None,
+ left: bool = True,
+ use_levy: bool = False,
+ ) -> LevyVal:
+ def _is_levy_val(obj):
+ return isinstance(obj, LevyVal)
+
t0 = eqxi.nondifferentiable(t0, name="t0")
+ # map the interval [self.t0, self.t1] onto [0,1]
+ t0 = linear_rescale(self.t0, t0, self.t1)
+ levy_0 = self._evaluate(t0)
if t1 is None:
- return self._evaluate(t0)
+ levy_out = levy_0
+
else:
t1 = eqxi.nondifferentiable(t1, name="t1")
- return jtu.tree_map(
- lambda x, y: x - y,
- self._evaluate(t1),
- self._evaluate(t0),
- )
+ # map the interval [self.t0, self.t1] onto [0,1]
+ t1 = linear_rescale(self.t0, t1, self.t1)
+ levy_1 = self._evaluate(t1)
+
+ levy_out = jtu.tree_map(_levy_diff, levy_0, levy_1, is_leaf=_is_levy_val)
- def _evaluate(self, τ: Scalar) -> PyTree[Array]:
- map_func = lambda key, shape: self._evaluate_leaf(key, τ, shape)
+ levy_out = levy_tree_transpose(self.shape, self.levy_area, levy_out)
+ # now map [0,1] back onto [self.t0, self.t1]
+ levy_out = self._denormalise_bm_inc(levy_out)
+ assert isinstance(levy_out, LevyVal)
+ return levy_out if use_levy else levy_out.W
+
+ def _evaluate(self, r: Scalar) -> PyTree[LevyVal]:
+ """Maps the _evaluate_leaf function at time τ using self.key onto self.shape"""
+ r = eqxi.error_if(
+ r,
+ r < 0,
+ "Cannot evaluate VirtualBrownianTree outside of its range [t0, t1].",
+ )
+ r = eqxi.error_if(
+ r,
+ r > 1,
+ "Cannot evaluate VirtualBrownianTree outside of its range [t0, t1].",
+ )
+ # Clip because otherwise the while loop below won't terminate, and the above
+ # errors are only raised after everything has finished executing.
+ r = jnp.clip(r, 0, 1)
+ map_func = lambda key, shape: self._evaluate_leaf(key, r, shape)
return jtu.tree_map(map_func, self.key, self.shape)
- def _brownian_bridge(self, s, t, u, w_s, w_u, key, shape, dtype):
- mean = w_s + (w_u - w_s) * ((t - s) / (u - s))
- var = (u - t) * (t - s) / (u - s)
- std = jnp.sqrt(var)
- return mean + std * jrandom.normal(key, shape, dtype)
+ def _brownian_arch(
+ self,
+ level: int,
+ s: Scalar,
+ u: Scalar,
+ w: Tuple[Scalar, Scalar, Scalar],
+ key,
+ shape,
+ dtype,
+ bhh: Optional[Tuple[Scalar, Scalar, Scalar]],
+ bkk: Optional[Tuple[Scalar, Scalar, Scalar]],
+ ):
+ r"""For `t = (s+u)/2` evaluates `w_t` and (optionally) `bhh_t`
+ conditioned on `w_s`, `w_u`, `bhh_s`, `bhh_u`, where
+ `bhh_st` represents $\bar{H}_{s,t} \coloneqq (t-s) H_{s,t}$.
+ To avoid cancellation errors, requires an input of `w_su`, `bhh_su`
+ and also returns `w_st` and `w_tu` in addition to just `w_t`. Same for `bhh`
+ if it is not None.
+ Note that the inputs and outputs already contain `bkk`. These values are
+ there for the sake of a future extension with "space-time-time" Levy area
+ and should be None for now.
+
+ **Arguments:**
+ - `s`: start time
+ - `u`: end time
+ - `w_s`: value of BM at s
+ - `w_u`: value of BM at u
+ - `w_su`: $W_{s,u}$
+ - `key`:
+ - `shape`:
+ - `dtype`:
+ - `bhh`: (optional) $(\bar{H}_s, \bar{H}_u, \bar{H}_{s,u})$
+ - `bkk`: (optional) $(\bar{K}_s, \bar{K}_u, \bar{K}_{s,u})$
+
+ **Returns:**
+ - `t`: midpoint time
+ - `w_t`: value of BM at t
+ - `w_st_tu`: $(W_{s,t}, W_{t,u})$
+ - `bhh`: (optional) $(\bar{H}_s, \bar{H}_t, \bar{H}_u)$
+ - `bhh_st_tu`: (optional) $(\bar{H}_{s,t}, \bar{H}_{t,u})$
+ - `bkk_t`: (optional) $(\bar{K}_s, \bar{K}_t, \bar{K}_u)$
+ - `bkk_st_tu`: (optional) $(\bar{K}_{s,t}, \bar{K}_{t,u})$
+
+ """
+
+ su = jnp.power(jnp.asarray(2.0, dtype), -level)
+ st = su / 2
+ t = s + st
+ u_minus_s = u - s
+ su = eqxi.error_if(
+ su,
+ jnp.abs(u_minus_s - su) > 1e-17,
+ "VirtualBrownianTree: u-s is not 2^(-tree_level)",
+ )
+ root_su = jnp.sqrt(su)
+
+ w_s, w_u, w_su = w
+
+ if self.levy_area == "space-time":
+ assert bhh is not None
+ assert bkk is None
+ bhh_s, bhh_u, bhh_su = bhh
+
+ z1_key, z2_key = jrandom.split(key, 2)
+ z1 = jrandom.normal(z1_key, shape, dtype)
+ z2 = jrandom.normal(z2_key, shape, dtype)
+ z = z1 * (root_su / 4)
+ n = z2 * jnp.sqrt(su / 12)
+
+ w_term1 = w_su / 2
+ w_term2 = 3 / (2 * su) * bhh_su + z
+ w_st = w_term1 + w_term2
+ w_tu = w_term1 - w_term2
+ w_st_tu = (w_st, w_tu)
+
+ bhh_term1 = bhh_su / 8 - su / 4 * z
+ bhh_term2 = su / 4 * n
+ bhh_st = bhh_term1 + bhh_term2
+ bhh_tu = bhh_term1 - bhh_term2
+ bhh_st_tu = (bhh_st, bhh_tu)
+
+ w_t = w_s + w_st
+ bhh_t = bhh_s + bhh_st + 0.5 * (t * w_s - s * w_t)
+ bhh = (bhh_s, bhh_t, bhh_u)
+ bkk = None
+ bkk_st_tu = None
+
+ else:
+ assert bhh is None
+ assert bkk is None
+ mean = 0.5 * w_su
+ w_term2 = root_su / 2 * jrandom.normal(key, shape, dtype)
+ w_st = mean + w_term2
+ w_tu = mean - w_term2
+ w_st_tu = (w_st, w_tu)
+ w_t = w_s + w_st
+ bhh, bhh_st_tu, bkk, bkk_st_tu = None, None, None, None
+ return t, w_t, w_st_tu, bhh, bhh_st_tu, bkk, bkk_st_tu
def _evaluate_leaf(
self,
key,
- τ: Scalar,
+ r: Scalar,
shape: jax.ShapeDtypeStruct,
- ) -> Array:
+ ) -> LevyVal:
shape, dtype = shape.shape, shape.dtype
- cond = self.t0 < self.t1
- t0 = jnp.where(cond, self.t0, self.t1).astype(dtype)
- t1 = jnp.where(cond, self.t1, self.t0).astype(dtype)
+ t0 = jnp.zeros((), dtype)
+ t1 = jnp.ones((), dtype)
+ r = jnp.asarray(r, dtype)
- t0 = eqxi.error_if(
- t0,
- τ < t0,
- "Cannot evaluate VirtualBrownianTree outside of its range [t0, t1].",
- )
- t1 = eqxi.error_if(
- t1,
- τ > t1,
- "Cannot evaluate VirtualBrownianTree outside of its range [t0, t1].",
- )
- # Clip because otherwise the while loop below won't terminate, and the above
- # errors are only raised after everything has finished executing.
- τ = jnp.clip(τ, t0, t1).astype(dtype)
+ w_0 = jnp.zeros(shape, dtype)
+
+ if self.levy_area == "space-time":
+ key, init_key_w, init_key_la, midpoint_key = jrandom.split(key, 4)
+ w_1 = jrandom.normal(init_key_w, shape, dtype)
- key, init_key = jrandom.split(key, 2)
- thalf = t0 + 0.5 * (t1 - t0)
- w_t1 = jrandom.normal(init_key, shape, dtype) * jnp.sqrt(t1 - t0)
- w_thalf = self._brownian_bridge(t0, thalf, t1, 0, w_t1, key, shape, dtype)
+ bhh_1 = jnp.sqrt(1 / 12) * jrandom.normal(init_key_la, shape, dtype)
+ bhh_0 = jnp.zeros_like(bhh_1)
+ bhh = (bhh_0, bhh_1, bhh_1)
+ bkk = None
+
+ else:
+ key, init_key_w, midpoint_key = jrandom.split(key, 3)
+ w_1 = jrandom.normal(init_key_w, shape, dtype)
+ bhh = None
+ bkk = None
+
+ w = (w_0, w_1, w_1)
+
+ half, w_half, w_inc, bhh, bhh_inc, bkk, bkk_inc = self._brownian_arch(
+ 0, t0, t1, w, key, shape, dtype, bhh, bkk
+ )
init_state = _State(
- s=t0,
- t=thalf,
- u=t1,
- w_s=jnp.zeros_like(w_t1),
- w_t=w_thalf,
- w_u=w_t1,
+ level=0,
+ stu=(t0, half, t1),
+ w_stu=(w_0, w_half, w_1),
+ w_st_tu=w_inc,
key=key,
+ bhh_stu=bhh,
+ bhh_st_tu=bhh_inc,
+ bkk_stu=bkk,
+ bkk_st_tu=bkk_inc,
)
def _cond_fun(_state):
@@ -161,67 +379,105 @@ def _cond_fun(_state):
# jnp.abs(τ - state.s) > self.tol
# Here, because we use quadratic splines to get better samples, we always
# iterate down to the level of the spline.
- return (_state.u - _state.s) > self.tol
+ _s, _t, _u = _state.stu
+ return (_u - _s) > 2 * self.tol
- def _body_fun(_state):
+ def _body_fun(_state: _State):
+ """Single-step of binary search for τ."""
+ _level = _state.level + 1
_key1, _key2 = jrandom.split(_state.key, 2)
- _cond = τ > _state.t
- _s = jnp.where(_cond, _state.t, _state.s)
- _u = jnp.where(_cond, _state.u, _state.t)
- _w_s = jnp.where(_cond, _state.w_t, _state.w_s)
- _w_u = jnp.where(_cond, _state.w_u, _state.w_t)
+ _s, _t, _u = _state.stu
+ _cond = r > _t
+ _s = jnp.where(_cond, _t, _s)
+ _u = jnp.where(_cond, _u, _t)
+
+ _w = split_interval(_cond, _state.w_stu, _state.w_st_tu)
+ if self.levy_area in ["space-time", "space-time-time"]:
+ _bhh = split_interval(_cond, _state.bhh_stu, _state.bhh_st_tu)
+ _bkk = None
+ else:
+ _bhh = None
+ _bkk = None
+
_key = jnp.where(_cond, _key1, _key2)
- _t = _s + 0.5 * (_u - _s)
- _w_t = self._brownian_bridge(_s, _t, _u, _w_s, _w_u, _key, shape, dtype)
- return _State(s=_s, t=_t, u=_u, w_s=_w_s, w_t=_w_t, w_u=_w_u, key=_key)
+ _key, _midpoint_key = jrandom.split(_key, 2)
+
+ _t, _w_t, _w_inc, _bhh, _bhh_inc, _bkk, _bkk_inc = self._brownian_arch(
+ _level, _s, _u, _w, _midpoint_key, shape, dtype, _bhh, _bkk
+ )
+
+ return _State(
+ level=_level,
+ stu=(_s, _t, _u),
+ w_stu=(_w[0], _w_t, _w[2]),
+ w_st_tu=_w_inc,
+ key=_key,
+ bhh_stu=_bhh,
+ bhh_st_tu=_bhh_inc,
+ bkk_stu=_bkk,
+ bkk_st_tu=_bkk_inc,
+ )
final_state = lax.while_loop(_cond_fun, _body_fun, init_state)
- # Quadratic interpolation.
- # We have w_s, w_t, w_u available to us, and interpolate with the unique
- # parabola passing through them.
- # Why quadratic and not just "linear from w_s to w_t to w_u"? Because linear
- # only gets the conditional mean correct at every point, but not the
- # conditional variance. This means that the Virtual Brownian Tree will pass
- # statistical tests comparing w(t)|(w(s),w(u)) against the true Brownian
- # bridge. (Provided s, t, u are greater than the discretisation level `tol`.)
- # (If you just do linear then you find that the variance is *ever so slightly*
- # too small.)
- s = final_state.s
- u = final_state.u
- w_s = final_state.w_s
- w_t = final_state.w_t
- w_u = final_state.w_u
- rescaled_τ = (τ - s) / (u - s)
- # Fit polynomial as usual.
- # The polynomial ax^2 + bx + c is found by solving
- # [s^2 s 1][a] [w_s]
- # [t^2 t 1][b] = [w_t]
- # [u^2 u 1][c] [w_u]
- #
- # `A` is the inverse of the above matrix, with s=0, t=0.5, u=1.
- A = jnp.array([[2, -4, 2], [-3, 4, -1], [1, 0, 0]])
- coeffs = jnp.tensordot(A, jnp.stack([w_s, w_t, w_u]), axes=1)
- return jnp.polyval(coeffs, rescaled_τ)
-
-
-VirtualBrownianTree.__init__.__doc__ = """
-**Arguments:**
-
-- `t0`: The start of the interval the Brownian motion is defined over.
-- `t1`: The start of the interval the Brownian motion is defined over.
-- `tol`: The discretisation that `[t0, t1]` is discretised to.
-- `shape`: Should be a PyTree of `jax.ShapeDtypeStruct`s, representing the shape,
- dtype, and PyTree structure of the output. For simplicity, `shape` can also just
- be a tuple of integers, describing the shape of a single JAX array. In that case
- the dtype is chosen to be the default floating-point dtype.
-- `key`: A random key.
-
-!!! info
-
- If using this as part of an SDE solver, and you know (or have an estimate of) the
- step sizes made in the solver, then you can optimise the computational efficiency
- of the Virtual Brownian Tree by setting `tol` to be just slightly smaller than the
- step size of the solver.
-
-The Brownian motion is defined to equal 0 at `t0`.
-"""
+
+ level = final_state.level + 1
+ s, t, u = final_state.stu
+
+ # Split the interval in half one last time depending on whether r < t or r > t
+ # but this time complete the step with the general interpolation, rather
+ # than the midpoint rule (as given by _brownian_arch).
+
+ cond = r > t
+ s = jnp.where(cond, t, s)
+ u = jnp.where(cond, u, t)
+ su = jnp.power(jnp.asarray(2.0, dtype), -level)
+ su = eqxi.error_if(
+ su,
+ jnp.abs(u - s - su) > 1e-17,
+ "VirtualBrownianTree: u-s is not 2^(-tree_level) in final step.",
+ )
+
+ sr = r - s
+ ru = u - r # make sure su = sr + ru regardless of cancellation error
+
+ w_s, w_u, w_su = split_interval(cond, final_state.w_stu, final_state.w_st_tu)
+ key1, key2 = jrandom.split(final_state.key, 2)
+ key = jnp.where(cond, key1, key2)
+
+ # BM only case
+ if self.levy_area == "":
+ z = jrandom.normal(key, shape, dtype)
+ w_sr = sr / su * w_su + jnp.sqrt(sr * ru / su) * z
+ w_r = w_s + w_sr
+ return LevyVal(dt=r, W=w_r, H=None, bar_H=None, K=None, bar_K=None)
+
+ elif self.levy_area == "space-time":
+ bhh_s, bhh_u, bhh_su = split_interval(
+ cond, final_state.bhh_stu, final_state.bhh_st_tu
+ )
+ sr3 = jnp.power(sr, 3)
+ ru3 = jnp.power(ru, 3)
+ su3 = jnp.power(su, 3)
+ key1, key2 = jrandom.split(key, 2)
+ x1 = jrandom.normal(key1, shape, dtype)
+ x2 = jrandom.normal(key2, shape, dtype)
+
+ sr_ru_half = jnp.sqrt(sr * ru)
+ d = jnp.sqrt(sr3 + ru3)
+ d_prime = 1 / (2 * su * d)
+ a = d_prime * sr3 * sr_ru_half
+ b = d_prime * ru3 * sr_ru_half
+
+ w_sr = sr / su * w_su + 6 * sr * ru / su3 * bhh_su + 2 * (a + b) / su * x1
+ w_r = w_s + w_sr
+ c = jnp.sqrt(3 * sr3 * ru3) / (6 * d)
+ bhh_sr = sr3 / su3 * bhh_su - a * x1 + c * x2
+ bhh_r = bhh_s + bhh_sr + 0.5 * (r * w_s - s * w_r)
+
+ inverse_r = 1 / jnp.where(jnp.abs(r) < jnp.finfo(r).eps, jnp.inf, r)
+ hh_r = inverse_r * bhh_r
+
+ else:
+ assert False
+
+ return LevyVal(dt=r, W=w_r, H=hh_r, bar_H=bhh_r, K=None, bar_K=None)
diff --git a/diffrax/custom_types.py b/diffrax/custom_types.py
index 93e818b5..ff1a1fe0 100644
--- a/diffrax/custom_types.py
+++ b/diffrax/custom_types.py
@@ -1,8 +1,10 @@
import inspect
import typing
-from typing import Any, Dict, Generic, Tuple, TypeVar, Union
+from typing import Any, Dict, Generic, Optional, Tuple, TypeVar, Union
+import equinox as eqx
import equinox.internal as eqxi
+import jax
import jax.tree_util as jtu
@@ -131,3 +133,52 @@ def __class_getitem__(cls, item):
DenseInfo = Dict[str, PyTree[Array]]
DenseInfos = Dict[str, PyTree[Array["times", ...]]] # noqa: F821
sentinel: Any = eqxi.doc_repr(object(), "sentinel")
+
+
+class LevyVal(eqx.Module):
+ dt: Scalar
+ W: PyTree[Array]
+ H: Optional[PyTree[Array]]
+ bar_H: Optional[PyTree[Array]]
+ K: Optional[PyTree[Array]]
+ bar_K: Optional[PyTree[Array]]
+
+
+def levy_tree_transpose(tree_shape, levy_area, tree):
+ """Helper that takes a PyTree of LevyVals and transposes
+ into a LevyVal of PyTrees.
+
+ **Arguments:**
+ - `tree_shape`: Corresponds to `outer_treedef` in `jax.tree_transpose`.
+
+ - `levy_area`: can be "", "space-time" or "space-time-time", which indicates
+ which fields of the LevyVal will have values.
+
+ - `tree`: the PyTree of LevyVals to transpose.
+
+ **Returns:**
+ A `LevyVal` of PyTrees.
+ """
+ if levy_area in ["space-time", "space-time-time"]:
+ hh_default_val = 0.0
+ if levy_area == "space-time-time":
+ kk_default_val = 0.0
+ else:
+ kk_default_val = None
+ else:
+ hh_default_val = None
+ kk_default_val = None
+ return jtu.tree_transpose(
+ outer_treedef=jax.tree_structure(tree_shape),
+ inner_treedef=jax.tree_structure(
+ LevyVal(
+ dt=0.0,
+ W=0.0,
+ H=hh_default_val,
+ bar_H=None,
+ K=kk_default_val,
+ bar_K=None,
+ )
+ ),
+ pytree_to_transpose=tree,
+ )
diff --git a/test/test_brownian.py b/test/test_brownian.py
index 4e6b8389..2af20239 100644
--- a/test/test_brownian.py
+++ b/test/test_brownian.py
@@ -7,8 +7,11 @@
import jax.tree_util as jtu
import pytest
import scipy.stats as stats
+from jax import config
+config.update("jax_enable_x64", True)
+
_vals = {
int: [0, 2],
float: [0.0, 2.0],
@@ -120,7 +123,7 @@ def _eval(key):
def test_conditional_statistics():
- key = jrandom.PRNGKey(5678)
+ key = jrandom.PRNGKey(5679)
bm_key, sample_key, permute_key = jrandom.split(key, 3)
# Get >80 randomly selected points; not too close to avoid discretisation error.
@@ -128,8 +131,8 @@ def test_conditional_statistics():
t1 = 8.7
ts = jrandom.uniform(sample_key, shape=(100,), minval=t0, maxval=t1)
sorted_ts = jnp.sort(ts)
- ts = []
prev_ti = sorted_ts[0]
+ ts = [prev_ti]
for ti in sorted_ts[1:]:
if ti < prev_ti + 2**-10:
continue
@@ -143,7 +146,7 @@ def test_conditional_statistics():
bm_keys = jrandom.split(bm_key, 100000)
path = jax.vmap(
lambda k: diffrax.VirtualBrownianTree(
- t0=t0, t1=t1, shape=(), tol=2**-12, key=k
+ t0=t0, t1=t1, shape=_make_struct((), jnp.float64), tol=2**-12, key=k
)
)(bm_keys)
@@ -155,7 +158,7 @@ def test_conditional_statistics():
out = sorted(out, key=lambda x: x[0])
# Test their conditional statistics
- for i in range(1, 98):
+ for i in range(1, len(ts) - 2):
prev_t, prev_vals = out[i - 1]
this_t, this_vals = out[i]
next_t, next_vals = out[i + 1]
diff --git a/test/test_stla.py b/test/test_stla.py
new file mode 100644
index 00000000..25ca960b
--- /dev/null
+++ b/test/test_stla.py
@@ -0,0 +1,231 @@
+import math
+
+import diffrax
+import jax
+import jax.numpy as jnp
+import jax.random as jrandom
+import jax.tree_util as jtu
+import pytest
+import scipy.stats as stats
+
+
+_vals = {
+ int: [0, 2],
+ float: [0.0, 2.0],
+ jnp.int32: [jnp.array(0, dtype=jnp.int32), jnp.array(2, dtype=jnp.int32)],
+ jnp.float32: [jnp.array(0.0, dtype=jnp.float32), jnp.array(2.0, dtype=jnp.float32)],
+}
+
+
+def _make_struct(shape, dtype):
+ dtype = jax.dtypes.canonicalize_dtype(dtype)
+ return jax.ShapeDtypeStruct(shape, dtype)
+
+
+@pytest.mark.parametrize(
+ "ctr", [diffrax.UnsafeBrownianPath, diffrax.VirtualBrownianTree]
+)
+def test_shape_and_dtype(ctr, getkey):
+ t0 = 0
+ t1 = 2
+
+ shapes = (
+ (),
+ (0,),
+ (1, 0),
+ (2,),
+ (3, 4),
+ (1, 2, 3, 4),
+ {
+ "a": (1,),
+ "b": (2, 3),
+ },
+ (
+ (1, 2),
+ (
+ (3, 4),
+ (5, 6),
+ ),
+ ),
+ )
+
+ dtypes = (
+ None,
+ None,
+ None,
+ jnp.float16,
+ jnp.float32,
+ jnp.float64,
+ {"a": None, "b": jnp.float64},
+ (jnp.float16, (jnp.float32, jnp.float64)),
+ )
+
+ def is_tuple_of_ints(obj):
+ return isinstance(obj, tuple) and all(isinstance(x, int) for x in obj)
+
+ for shape, dtype in zip(shapes, dtypes):
+ # Shape to pass as input
+ if dtype is not None:
+ shape = jtu.tree_map(_make_struct, shape, dtype, is_leaf=is_tuple_of_ints)
+
+ if ctr is diffrax.UnsafeBrownianPath:
+ path = ctr(shape, getkey(), levy_area="space-time")
+ assert path.t0 is None
+ assert path.t1 is None
+ elif ctr is diffrax.VirtualBrownianTree:
+ tol = 2**-5
+ path = ctr(t0, t1, tol, shape, getkey(), levy_area="space-time")
+ assert path.t0 == 0
+ assert path.t1 == 2
+ else:
+ assert False
+
+ # Expected output shape
+ if dtype is None:
+ shape = jtu.tree_map(_make_struct, shape, dtype, is_leaf=is_tuple_of_ints)
+
+ for _t0 in _vals.values():
+ for _t1 in _vals.values():
+ t0, _ = _t0
+ _, t1 = _t1
+ bm = path.evaluate(t0, t1, use_levy=True)
+ out_w = bm.W
+ out_hh = bm.H
+ out_w_shape = jtu.tree_map(
+ lambda leaf: jax.ShapeDtypeStruct(leaf.shape, leaf.dtype), out_w
+ )
+ out_hh_shape = jtu.tree_map(
+ lambda leaf: jax.ShapeDtypeStruct(leaf.shape, leaf.dtype), out_hh
+ )
+ assert out_hh_shape == shape
+ assert out_w_shape == shape
+
+
+@pytest.mark.parametrize(
+ "ctr", [diffrax.UnsafeBrownianPath, diffrax.VirtualBrownianTree]
+)
+def test_statistics(ctr):
+ # Deterministic key for this test; not using getkey()
+ key = jrandom.PRNGKey(5678)
+ keys = jrandom.split(key, 10000)
+
+ def _eval(key):
+ if ctr is diffrax.UnsafeBrownianPath:
+ path = ctr(shape=(), key=key, levy_area="space-time")
+ elif ctr is diffrax.VirtualBrownianTree:
+ path = ctr(
+ t0=0, t1=5, tol=2**-5, shape=(), key=key, levy_area="space-time"
+ )
+ else:
+ assert False
+ return path.evaluate(0, 5, use_levy=True)
+
+ bm_inc = jax.vmap(_eval)(keys)
+ values_w = bm_inc.W
+ values_h = bm_inc.H
+ assert values_w.shape == (10000,) and values_h.shape == (10000,)
+ ref_dist_w = stats.norm(loc=0, scale=math.sqrt(5))
+ _, pval_w = stats.kstest(values_w, ref_dist_w.cdf)
+ ref_dist_h = stats.norm(loc=0, scale=math.sqrt(5 / 12))
+ _, pval_h = stats.kstest(values_h, ref_dist_h.cdf)
+ assert pval_w > 0.1
+ assert pval_h > 0.1
+
+
+def test_conditional_statistics():
+ key = jrandom.PRNGKey(5678)
+ bm_key, sample_key, permute_key = jrandom.split(key, 3)
+
+ # Get >80 randomly selected points; not too close to avoid discretisation error.
+ t0 = 0.3
+ t1 = 8.7
+ ts = jrandom.uniform(sample_key, shape=(100,), minval=t0, maxval=t1)
+ sorted_ts = jnp.sort(ts)
+ ts = []
+ prev_ti = sorted_ts[0]
+ for ti in sorted_ts[1:]:
+ if ti < prev_ti + 2**-9:
+ continue
+ prev_ti = ti
+ ts.append(ti)
+ ts = jnp.stack(ts)
+ assert len(ts) > 80
+ ts = jrandom.permutation(permute_key, ts)
+
+ # Get some random paths
+ bm_keys = jrandom.split(bm_key, 100000)
+ path = jax.vmap(
+ lambda k: diffrax.VirtualBrownianTree(
+ t0=t0, t1=t1, shape=(), tol=2**-12, key=k, levy_area="space-time"
+ )
+ )(bm_keys)
+
+ # Sample some points
+ out = []
+ for ti in ts:
+ vals = jax.vmap(lambda p: p.evaluate(t0, ti, use_levy=True))(path)
+ out.append((ti, vals))
+ out = sorted(out, key=lambda x: x[0])
+
+ # Test their conditional statistics
+ for i in range(1, len(ts) - 1):
+ s, bm_s = out[i - 1]
+ r, bm_r = out[i]
+ u, bm_u = out[i + 1]
+
+ w_s, hh_s = bm_s.W, bm_s.H
+ w_r, hh_r = bm_r.W, bm_r.H
+ w_u, hh_u = bm_u.W, bm_u.H
+
+ s = s - t0
+ r = r - t0
+ u = u - t0
+ su = u - s
+ sr = r - s
+ ru = u - r
+ d = jnp.sqrt(jnp.power(sr, 3) + jnp.power(ru, 3))
+ a = (1 / (2 * su * d)) * jnp.power(sr, 7 / 2) * jnp.sqrt(ru)
+ b = (1 / (2 * su * d)) * jnp.power(ru, 7 / 2) * jnp.sqrt(sr)
+ c = (1.0 / (jnp.sqrt(12) * d)) * jnp.power(sr, 3 / 2) * jnp.power(ru, 3 / 2)
+
+ hh_su = (1.0 / su) * (u * hh_u - s * hh_s - u / 2 * w_s + s / 2 * w_u)
+
+ w_mean = w_s + (sr / su) * (w_u - w_s) + (6 * sr * ru / jnp.square(su)) * hh_su
+ w_std = 2 * (a + b) / su
+ normalised_w = (w_r - w_mean) / w_std
+ hh_mean = (
+ (s / r) * hh_s
+ + (jnp.power(sr, 3) / (r * jnp.square(su))) * hh_su
+ + 0.5 * w_s
+ - s / (2 * r) * w_mean
+ )
+ hh_var = jnp.square(c / r) + jnp.square((a * u + s * b) / (r * su))
+ hh_std = jnp.sqrt(hh_var)
+ normalised_hh = (hh_r - hh_mean) / hh_std
+
+ _, pval_w = stats.kstest(normalised_w, stats.norm.cdf)
+ _, pval_hh = stats.kstest(normalised_hh, stats.norm.cdf)
+
+ # Raise if the failure is statistically significant at 10%, subject to
+ # multiple-testing correction.
+ assert pval_w > 0.001
+ assert pval_hh > 0.001
+
+
+def test_reverse_time():
+ key = jrandom.PRNGKey(5678)
+ bm_key, sample_key = jrandom.split(key, 2)
+ bm = diffrax.VirtualBrownianTree(
+ t0=0, t1=5, tol=2**-5, shape=(), key=bm_key, levy_area="space-time"
+ )
+
+ ts = jrandom.uniform(sample_key, shape=(100,), minval=0, maxval=5)
+
+ vec_eval = jax.vmap(lambda t_prev, t: bm.evaluate(t_prev, t))
+
+ fwd_increments = vec_eval(ts[:-1], ts[1:])
+ back_increments = vec_eval(ts[1:], ts[:-1])
+
+ assert jtu.tree_map(
+ lambda fwd, bck: jnp.allclose(fwd, -bck), fwd_increments, back_increments
+ )