diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 00000000..13566b81 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/diffrax_STLA.iml b/.idea/diffrax_STLA.iml new file mode 100644 index 00000000..76f8ed3e --- /dev/null +++ b/.idea/diffrax_STLA.iml @@ -0,0 +1,15 @@ + + + + + + + + + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 00000000..105ce2da --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 00000000..928ba795 --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,7 @@ + + + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 00000000..9a53c7d2 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 00000000..35eb1ddf --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/diffrax/brownian/base.py b/diffrax/brownian/base.py index ee90ba60..4d887696 100644 --- a/diffrax/brownian/base.py +++ b/diffrax/brownian/base.py @@ -1,14 +1,21 @@ import abc +from typing import Optional, Union -from ..custom_types import Array, PyTree, Scalar +from ..custom_types import Array, LevyVal, PyTree, Scalar from ..path import AbstractPath class AbstractBrownianPath(AbstractPath): - "Abstract base class for all Brownian paths." + """Abstract base class for all Brownian paths.""" @abc.abstractmethod - def evaluate(self, t0: Scalar, t1: Scalar, left: bool = True) -> PyTree[Array]: + def evaluate( + self, + t0: Scalar, + t1: Optional[Scalar] = None, + left: bool = True, + use_levy: bool = False, + ) -> Union[PyTree[Array], LevyVal]: r"""Samples a Brownian increment $w(t_1) - w(t_0)$. Each increment has distribution $\mathcal{N}(0, t_1 - t_0)$. @@ -20,6 +27,8 @@ def evaluate(self, t0: Scalar, t1: Scalar, left: bool = True) -> PyTree[Array]: - `left`: Ignored. (This determines whether to treat the path as left-continuous or right-continuous at any jump points, but Brownian motion has no jump points.) + - `use_levy`: If True, the return type will be a `LevyVal`, which contains + PyTrees of Brownian increments and their Levy areas. **Returns:** diff --git a/diffrax/brownian/path.py b/diffrax/brownian/path.py index a844450a..215f18e3 100644 --- a/diffrax/brownian/path.py +++ b/diffrax/brownian/path.py @@ -1,4 +1,4 @@ -from typing import Tuple, Union +from typing import Literal, Tuple, Union import equinox as eqx import equinox.internal as eqxi @@ -7,7 +7,7 @@ import jax.random as jrandom import jax.tree_util as jtu -from ..custom_types import Array, PyTree, Scalar +from ..custom_types import Array, levy_tree_transpose, LevyVal, PyTree, Scalar from ..misc import force_bitcast_convert_type, is_tuple_of_ints, split_by_tree from .base import AbstractBrownianPath @@ -30,9 +30,14 @@ class UnsafeBrownianPath(AbstractBrownianPath): interval, ignoring the correlation between samples exhibited in true Brownian motion. Hence the restrictions above. (They describe the general case for which the correlation structure isn't needed.) + + Depending on the `levy_area` argument, this can also be used to generate Levy area. + `levy_area` can be "" or "space-time". + """ shape: PyTree[jax.ShapeDtypeStruct] = eqx.field(static=True) + levy_area: Literal["", "space-time"] = eqx.field(static=True) # Handled as a string because PRNGKey is actually a function, not a class, which # makes it appearly badly in autogenerated documentation. key: "jax.random.PRNGKey" # noqa: F821 @@ -41,6 +46,7 @@ def __init__( self, shape: Union[Tuple[int, ...], PyTree[jax.ShapeDtypeStruct]], key: "jax.random.PRNGKey", + levy_area: Literal["", "space-time"] = "", ): self.shape = ( jax.ShapeDtypeStruct(shape, jax.dtypes.canonicalize_dtype(None)) @@ -48,6 +54,12 @@ def __init__( else shape ) self.key = key + if levy_area not in ["", "space-time"]: + raise ValueError( + f"levy_area must be one of '', 'space-time', but got {levy_area}." + ) + self.levy_area = levy_area + if any( not jnp.issubdtype(x.dtype, jnp.inexact) for x in jtu.tree_leaves(self.shape) @@ -63,7 +75,13 @@ def t1(self): return None @eqx.filter_jit - def evaluate(self, t0: Scalar, t1: Scalar, left: bool = True) -> PyTree[Array]: + def evaluate( + self, + t0: Scalar, + t1: Scalar, + left: bool = True, + use_levy: bool = False, + ) -> PyTree[Array]: del left t0 = eqxi.nondifferentiable(t0, name="t0") t1 = eqxi.nondifferentiable(t1, name="t1") @@ -72,14 +90,42 @@ def evaluate(self, t0: Scalar, t1: Scalar, left: bool = True) -> PyTree[Array]: key = jrandom.fold_in(self.key, t0_) key = jrandom.fold_in(key, t1_) key = split_by_tree(key, self.shape) - return jtu.tree_map( - lambda key, shape: self._evaluate_leaf(t0, t1, key, shape), key, self.shape + out = jtu.tree_map( + lambda key, shape: self._evaluate_leaf( + t0, t1, key, shape, self.levy_area, use_levy + ), + key, + self.shape, ) + if use_levy: + out = levy_tree_transpose(self.shape, self.levy_area, out) + assert isinstance(out, LevyVal) + return out + + @staticmethod + def _evaluate_leaf( + t0: Scalar, + t1: Scalar, + key, + shape: jax.ShapeDtypeStruct, + levy_area: str, + use_levy: bool, + ): + w_std = jnp.sqrt(t1 - t0).astype(shape.dtype) - def _evaluate_leaf(self, t0: Scalar, t1: Scalar, key, shape: jax.ShapeDtypeStruct): - return jrandom.normal(key, shape.shape, shape.dtype) * jnp.sqrt(t1 - t0).astype( - shape.dtype - ) + if levy_area == "space-time": + key_w, key_hh = jrandom.split(key, 2) + w = jrandom.normal(key_w, shape.shape, shape.dtype) * w_std + hh_std = jnp.sqrt((t1 - t0) / 12).astype(shape.dtype) + hh = jrandom.normal(key_hh, shape.shape, shape.dtype) * hh_std + else: + hh = None + w = jrandom.normal(key, shape.shape, shape.dtype) * w_std + + if use_levy: + return LevyVal(dt=t1 - t0, W=w, H=hh, bar_H=None, K=None, bar_K=None) + else: + return w UnsafeBrownianPath.__init__.__doc__ = """ @@ -89,5 +135,6 @@ def _evaluate_leaf(self, t0: Scalar, t1: Scalar, key, shape: jax.ShapeDtypeStruc dtype, and PyTree structure of the output. For simplicity, `shape` can also just be a tuple of integers, describing the shape of a single JAX array. In that case the dtype is chosen to be the default floating-point dtype. + - `key`: A random key. """ diff --git a/diffrax/brownian/tree.py b/diffrax/brownian/tree.py index 577092ce..46b1fbce 100644 --- a/diffrax/brownian/tree.py +++ b/diffrax/brownian/tree.py @@ -1,5 +1,5 @@ from dataclasses import field -from typing import Optional, Tuple, Union +from typing import Literal, Optional, Tuple, Union import equinox as eqx import equinox.internal as eqxi @@ -9,8 +9,8 @@ import jax.random as jrandom import jax.tree_util as jtu -from ..custom_types import Array, PyTree, Scalar -from ..misc import is_tuple_of_ints, split_by_tree +from ..custom_types import levy_tree_transpose, LevyVal, PyTree, Scalar +from ..misc import is_tuple_of_ints, linear_rescale, split_by_tree from .base import AbstractBrownianPath @@ -25,21 +25,75 @@ # } # +# We define +# H_{s,t} = 1/(t-s) ( \int_s^t ( W_u - (u-s)/(t-s) W_{s,t} ) du ). +# bhh_t = t * H_{0,t} +# For more details see Definition 4.2.1 and Theorem 6.1.4 of +# +# Foster, J. M. (2020). Numerical approximations for stochastic +# differential equations [PhD thesis]. University of Oxford. + class _State(eqx.Module): - s: Scalar - t: Scalar - u: Scalar - w_s: Scalar - w_t: Scalar - w_u: Scalar + level: int + stu: tuple[Scalar, Scalar, Scalar] # s, t, u + w_stu: tuple[Scalar, Scalar, Scalar] # W at times s, t, u + w_st_tu: tuple[Scalar, Scalar] # W_{s,t} and W_{t,u} key: "jax.random.PRNGKey" + bhh_stu: Optional[Tuple[Scalar, Scalar, Scalar]] # \bar{H} at times s, t, u + bhh_st_tu: Optional[Tuple[Scalar, Scalar]] # \bar{H}_{s,t} and \bar{H}_{t,u} + bkk_stu: Optional[Tuple[Scalar, Scalar, Scalar]] # \bar{K} at times s, t, u + bkk_st_tu: Optional[Tuple[Scalar, Scalar]] # \bar{K}_{s,t} and \bar{K}_{t,u} + + +def _levy_diff(x0: LevyVal, x1: LevyVal) -> LevyVal: + r"""Computes $(W_{s,u}, H_{s,u})$ from $(W_s, \bar{H}_{s,u})$ and + $(W_u, \bar{H}_u)$, where $\bar{H}_u = u * H_u$. + + **Arguments:** + + - `x0`: `LevyVal` at time `s` + + - `x1`: `LevyVal` at time `u` + + **Returns:** + + `LevyVal(W_su, H_su)` + """ + + su = (x1.dt - x0.dt).astype(x0.W.dtype) + w_su = x1.W - x0.W + if x0.H is None or x1.H is None: # BM only case + return LevyVal(dt=su, W=w_su, H=None, bar_H=None, K=None, bar_K=None) + + # levy_area == "space-time" + _su = jnp.where(jnp.abs(su) < jnp.finfo(su).eps, jnp.inf, su) + inverse_su = 1 / _su + u_bb_s = x1.dt * x0.W - x0.dt * x1.W + bhh_su = x1.bar_H - x0.bar_H - 0.5 * u_bb_s # bhh_su = H_{s,u} * (u-s) + hh_su = inverse_su * bhh_su + + return LevyVal(dt=su, W=w_su, H=hh_su, bar_H=None, K=None, bar_K=None) + + +def split_interval(_cond, x_stu, x_st_tu): + x_s, x_t, x_u = x_stu + x_st, x_tu = x_st_tu + x_s = jnp.where(_cond, x_t, x_s) + x_u = jnp.where(_cond, x_u, x_t) + x_su = jnp.where(_cond, x_tu, x_st) + return x_s, x_u, x_su class VirtualBrownianTree(AbstractBrownianPath): """Brownian simulation that discretises the interval `[t0, t1]` to tolerance `tol`, and is piecewise quadratic at that discretisation. + Can be initialised with `levy_area` set to `""`, or `"space-time"`. + If `levy_area=="space_time"`, then it also computes space-time Lévy area `H`. + This will impact the Brownian path, so even with the same key, the trajectory will + be different depending on the value of `levy_area`. + ??? cite "Reference" ```bibtex @@ -52,15 +106,15 @@ class VirtualBrownianTree(AbstractBrownianPath): } ``` - (The implementation here is a slight improvement on the reference - implementation, by being piecwise quadratic rather than piecewise linear. This - corrects a small bias in the generated samples.) + (The implementation here is a slight improvement on the reference implementation + by using an interpolation method which ensures all the 2nd moments are correct.) """ t0: Scalar = field(init=True) t1: Scalar = field(init=True) # override init=False in AbstractPath tol: Scalar shape: PyTree[jax.ShapeDtypeStruct] = eqx.field(static=True) + levy_area: Literal["", "space-time"] = eqx.field(static=True) key: "jax.random.PRNGKey" # noqa: F821 def __init__( @@ -70,10 +124,20 @@ def __init__( tol: Scalar, shape: Union[Tuple[int, ...], PyTree[jax.ShapeDtypeStruct]], key: "jax.random.PRNGKey", + levy_area: Literal["", "space-time"] = "", ): + (t0, t1) = eqx.error_if((t0, t1), t0 >= t1, "t0 must be strictly less than t1") self.t0 = t0 self.t1 = t1 - self.tol = tol + # Since we rescale the interval to [0,1], + # we need to rescale the tolerance too. + self.tol = tol / (self.t1 - self.t0) + + if levy_area not in ["", "space-time"]: + raise ValueError( + f"levy_area must be one of '', 'space-time', but got {levy_area}." + ) + self.levy_area = levy_area self.shape = ( jax.ShapeDtypeStruct(shape, jax.dtypes.canonicalize_dtype(None)) if is_tuple_of_ints(shape) @@ -88,70 +152,224 @@ def __init__( ) self.key = split_by_tree(key, self.shape) + def _denormalise_bm_inc(self, x: LevyVal) -> LevyVal: + # TODO: demonstrate rescaling actually helps + + # Rescaling back from [0, 1] to the original interval [t0, t1]. + interval_len = self.t1 - self.t0 # can be any dtype + sqrt_len = jnp.sqrt(interval_len) + + def sqrt_mult(z): + # need to cast to dtype of each leaf in PyTree + dtype = jnp.dtype(z) + return (z * sqrt_len).astype(dtype) + + def mult(z): + dtype = jnp.dtype(z) + return (interval_len * z).astype(dtype) + + return LevyVal( + dt=jtu.tree_map(mult, x.dt), + W=jtu.tree_map(sqrt_mult, x.W), + H=jtu.tree_map(sqrt_mult, x.H), + bar_H=None, + K=jtu.tree_map(sqrt_mult, x.K), + bar_K=None, + ) + @eqx.filter_jit def evaluate( - self, t0: Scalar, t1: Optional[Scalar] = None, left: bool = True - ) -> PyTree[Array]: - del left + self, + t0: Scalar, + t1: Optional[Scalar] = None, + left: bool = True, + use_levy: bool = False, + ) -> LevyVal: + def _is_levy_val(obj): + return isinstance(obj, LevyVal) + t0 = eqxi.nondifferentiable(t0, name="t0") + # map the interval [self.t0, self.t1] onto [0,1] + t0 = linear_rescale(self.t0, t0, self.t1) + levy_0 = self._evaluate(t0) if t1 is None: - return self._evaluate(t0) + levy_out = levy_0 + else: t1 = eqxi.nondifferentiable(t1, name="t1") - return jtu.tree_map( - lambda x, y: x - y, - self._evaluate(t1), - self._evaluate(t0), - ) + # map the interval [self.t0, self.t1] onto [0,1] + t1 = linear_rescale(self.t0, t1, self.t1) + levy_1 = self._evaluate(t1) + + levy_out = jtu.tree_map(_levy_diff, levy_0, levy_1, is_leaf=_is_levy_val) - def _evaluate(self, τ: Scalar) -> PyTree[Array]: - map_func = lambda key, shape: self._evaluate_leaf(key, τ, shape) + levy_out = levy_tree_transpose(self.shape, self.levy_area, levy_out) + # now map [0,1] back onto [self.t0, self.t1] + levy_out = self._denormalise_bm_inc(levy_out) + assert isinstance(levy_out, LevyVal) + return levy_out if use_levy else levy_out.W + + def _evaluate(self, r: Scalar) -> PyTree[LevyVal]: + """Maps the _evaluate_leaf function at time τ using self.key onto self.shape""" + r = eqxi.error_if( + r, + r < 0, + "Cannot evaluate VirtualBrownianTree outside of its range [t0, t1].", + ) + r = eqxi.error_if( + r, + r > 1, + "Cannot evaluate VirtualBrownianTree outside of its range [t0, t1].", + ) + # Clip because otherwise the while loop below won't terminate, and the above + # errors are only raised after everything has finished executing. + r = jnp.clip(r, 0, 1) + map_func = lambda key, shape: self._evaluate_leaf(key, r, shape) return jtu.tree_map(map_func, self.key, self.shape) - def _brownian_bridge(self, s, t, u, w_s, w_u, key, shape, dtype): - mean = w_s + (w_u - w_s) * ((t - s) / (u - s)) - var = (u - t) * (t - s) / (u - s) - std = jnp.sqrt(var) - return mean + std * jrandom.normal(key, shape, dtype) + def _brownian_arch( + self, + level: int, + s: Scalar, + u: Scalar, + w: Tuple[Scalar, Scalar, Scalar], + key, + shape, + dtype, + bhh: Optional[Tuple[Scalar, Scalar, Scalar]], + bkk: Optional[Tuple[Scalar, Scalar, Scalar]], + ): + r"""For `t = (s+u)/2` evaluates `w_t` and (optionally) `bhh_t` + conditioned on `w_s`, `w_u`, `bhh_s`, `bhh_u`, where + `bhh_st` represents $\bar{H}_{s,t} \coloneqq (t-s) H_{s,t}$. + To avoid cancellation errors, requires an input of `w_su`, `bhh_su` + and also returns `w_st` and `w_tu` in addition to just `w_t`. Same for `bhh` + if it is not None. + Note that the inputs and outputs already contain `bkk`. These values are + there for the sake of a future extension with "space-time-time" Levy area + and should be None for now. + + **Arguments:** + - `s`: start time + - `u`: end time + - `w_s`: value of BM at s + - `w_u`: value of BM at u + - `w_su`: $W_{s,u}$ + - `key`: + - `shape`: + - `dtype`: + - `bhh`: (optional) $(\bar{H}_s, \bar{H}_u, \bar{H}_{s,u})$ + - `bkk`: (optional) $(\bar{K}_s, \bar{K}_u, \bar{K}_{s,u})$ + + **Returns:** + - `t`: midpoint time + - `w_t`: value of BM at t + - `w_st_tu`: $(W_{s,t}, W_{t,u})$ + - `bhh`: (optional) $(\bar{H}_s, \bar{H}_t, \bar{H}_u)$ + - `bhh_st_tu`: (optional) $(\bar{H}_{s,t}, \bar{H}_{t,u})$ + - `bkk_t`: (optional) $(\bar{K}_s, \bar{K}_t, \bar{K}_u)$ + - `bkk_st_tu`: (optional) $(\bar{K}_{s,t}, \bar{K}_{t,u})$ + + """ + + su = jnp.power(jnp.asarray(2.0, dtype), -level) + st = su / 2 + t = s + st + u_minus_s = u - s + su = eqxi.error_if( + su, + jnp.abs(u_minus_s - su) > 1e-17, + "VirtualBrownianTree: u-s is not 2^(-tree_level)", + ) + root_su = jnp.sqrt(su) + + w_s, w_u, w_su = w + + if self.levy_area == "space-time": + assert bhh is not None + assert bkk is None + bhh_s, bhh_u, bhh_su = bhh + + z1_key, z2_key = jrandom.split(key, 2) + z1 = jrandom.normal(z1_key, shape, dtype) + z2 = jrandom.normal(z2_key, shape, dtype) + z = z1 * (root_su / 4) + n = z2 * jnp.sqrt(su / 12) + + w_term1 = w_su / 2 + w_term2 = 3 / (2 * su) * bhh_su + z + w_st = w_term1 + w_term2 + w_tu = w_term1 - w_term2 + w_st_tu = (w_st, w_tu) + + bhh_term1 = bhh_su / 8 - su / 4 * z + bhh_term2 = su / 4 * n + bhh_st = bhh_term1 + bhh_term2 + bhh_tu = bhh_term1 - bhh_term2 + bhh_st_tu = (bhh_st, bhh_tu) + + w_t = w_s + w_st + bhh_t = bhh_s + bhh_st + 0.5 * (t * w_s - s * w_t) + bhh = (bhh_s, bhh_t, bhh_u) + bkk = None + bkk_st_tu = None + + else: + assert bhh is None + assert bkk is None + mean = 0.5 * w_su + w_term2 = root_su / 2 * jrandom.normal(key, shape, dtype) + w_st = mean + w_term2 + w_tu = mean - w_term2 + w_st_tu = (w_st, w_tu) + w_t = w_s + w_st + bhh, bhh_st_tu, bkk, bkk_st_tu = None, None, None, None + return t, w_t, w_st_tu, bhh, bhh_st_tu, bkk, bkk_st_tu def _evaluate_leaf( self, key, - τ: Scalar, + r: Scalar, shape: jax.ShapeDtypeStruct, - ) -> Array: + ) -> LevyVal: shape, dtype = shape.shape, shape.dtype - cond = self.t0 < self.t1 - t0 = jnp.where(cond, self.t0, self.t1).astype(dtype) - t1 = jnp.where(cond, self.t1, self.t0).astype(dtype) + t0 = jnp.zeros((), dtype) + t1 = jnp.ones((), dtype) + r = jnp.asarray(r, dtype) - t0 = eqxi.error_if( - t0, - τ < t0, - "Cannot evaluate VirtualBrownianTree outside of its range [t0, t1].", - ) - t1 = eqxi.error_if( - t1, - τ > t1, - "Cannot evaluate VirtualBrownianTree outside of its range [t0, t1].", - ) - # Clip because otherwise the while loop below won't terminate, and the above - # errors are only raised after everything has finished executing. - τ = jnp.clip(τ, t0, t1).astype(dtype) + w_0 = jnp.zeros(shape, dtype) + + if self.levy_area == "space-time": + key, init_key_w, init_key_la, midpoint_key = jrandom.split(key, 4) + w_1 = jrandom.normal(init_key_w, shape, dtype) - key, init_key = jrandom.split(key, 2) - thalf = t0 + 0.5 * (t1 - t0) - w_t1 = jrandom.normal(init_key, shape, dtype) * jnp.sqrt(t1 - t0) - w_thalf = self._brownian_bridge(t0, thalf, t1, 0, w_t1, key, shape, dtype) + bhh_1 = jnp.sqrt(1 / 12) * jrandom.normal(init_key_la, shape, dtype) + bhh_0 = jnp.zeros_like(bhh_1) + bhh = (bhh_0, bhh_1, bhh_1) + bkk = None + + else: + key, init_key_w, midpoint_key = jrandom.split(key, 3) + w_1 = jrandom.normal(init_key_w, shape, dtype) + bhh = None + bkk = None + + w = (w_0, w_1, w_1) + + half, w_half, w_inc, bhh, bhh_inc, bkk, bkk_inc = self._brownian_arch( + 0, t0, t1, w, key, shape, dtype, bhh, bkk + ) init_state = _State( - s=t0, - t=thalf, - u=t1, - w_s=jnp.zeros_like(w_t1), - w_t=w_thalf, - w_u=w_t1, + level=0, + stu=(t0, half, t1), + w_stu=(w_0, w_half, w_1), + w_st_tu=w_inc, key=key, + bhh_stu=bhh, + bhh_st_tu=bhh_inc, + bkk_stu=bkk, + bkk_st_tu=bkk_inc, ) def _cond_fun(_state): @@ -161,67 +379,105 @@ def _cond_fun(_state): # jnp.abs(τ - state.s) > self.tol # Here, because we use quadratic splines to get better samples, we always # iterate down to the level of the spline. - return (_state.u - _state.s) > self.tol + _s, _t, _u = _state.stu + return (_u - _s) > 2 * self.tol - def _body_fun(_state): + def _body_fun(_state: _State): + """Single-step of binary search for τ.""" + _level = _state.level + 1 _key1, _key2 = jrandom.split(_state.key, 2) - _cond = τ > _state.t - _s = jnp.where(_cond, _state.t, _state.s) - _u = jnp.where(_cond, _state.u, _state.t) - _w_s = jnp.where(_cond, _state.w_t, _state.w_s) - _w_u = jnp.where(_cond, _state.w_u, _state.w_t) + _s, _t, _u = _state.stu + _cond = r > _t + _s = jnp.where(_cond, _t, _s) + _u = jnp.where(_cond, _u, _t) + + _w = split_interval(_cond, _state.w_stu, _state.w_st_tu) + if self.levy_area in ["space-time", "space-time-time"]: + _bhh = split_interval(_cond, _state.bhh_stu, _state.bhh_st_tu) + _bkk = None + else: + _bhh = None + _bkk = None + _key = jnp.where(_cond, _key1, _key2) - _t = _s + 0.5 * (_u - _s) - _w_t = self._brownian_bridge(_s, _t, _u, _w_s, _w_u, _key, shape, dtype) - return _State(s=_s, t=_t, u=_u, w_s=_w_s, w_t=_w_t, w_u=_w_u, key=_key) + _key, _midpoint_key = jrandom.split(_key, 2) + + _t, _w_t, _w_inc, _bhh, _bhh_inc, _bkk, _bkk_inc = self._brownian_arch( + _level, _s, _u, _w, _midpoint_key, shape, dtype, _bhh, _bkk + ) + + return _State( + level=_level, + stu=(_s, _t, _u), + w_stu=(_w[0], _w_t, _w[2]), + w_st_tu=_w_inc, + key=_key, + bhh_stu=_bhh, + bhh_st_tu=_bhh_inc, + bkk_stu=_bkk, + bkk_st_tu=_bkk_inc, + ) final_state = lax.while_loop(_cond_fun, _body_fun, init_state) - # Quadratic interpolation. - # We have w_s, w_t, w_u available to us, and interpolate with the unique - # parabola passing through them. - # Why quadratic and not just "linear from w_s to w_t to w_u"? Because linear - # only gets the conditional mean correct at every point, but not the - # conditional variance. This means that the Virtual Brownian Tree will pass - # statistical tests comparing w(t)|(w(s),w(u)) against the true Brownian - # bridge. (Provided s, t, u are greater than the discretisation level `tol`.) - # (If you just do linear then you find that the variance is *ever so slightly* - # too small.) - s = final_state.s - u = final_state.u - w_s = final_state.w_s - w_t = final_state.w_t - w_u = final_state.w_u - rescaled_τ = (τ - s) / (u - s) - # Fit polynomial as usual. - # The polynomial ax^2 + bx + c is found by solving - # [s^2 s 1][a] [w_s] - # [t^2 t 1][b] = [w_t] - # [u^2 u 1][c] [w_u] - # - # `A` is the inverse of the above matrix, with s=0, t=0.5, u=1. - A = jnp.array([[2, -4, 2], [-3, 4, -1], [1, 0, 0]]) - coeffs = jnp.tensordot(A, jnp.stack([w_s, w_t, w_u]), axes=1) - return jnp.polyval(coeffs, rescaled_τ) - - -VirtualBrownianTree.__init__.__doc__ = """ -**Arguments:** - -- `t0`: The start of the interval the Brownian motion is defined over. -- `t1`: The start of the interval the Brownian motion is defined over. -- `tol`: The discretisation that `[t0, t1]` is discretised to. -- `shape`: Should be a PyTree of `jax.ShapeDtypeStruct`s, representing the shape, - dtype, and PyTree structure of the output. For simplicity, `shape` can also just - be a tuple of integers, describing the shape of a single JAX array. In that case - the dtype is chosen to be the default floating-point dtype. -- `key`: A random key. - -!!! info - - If using this as part of an SDE solver, and you know (or have an estimate of) the - step sizes made in the solver, then you can optimise the computational efficiency - of the Virtual Brownian Tree by setting `tol` to be just slightly smaller than the - step size of the solver. - -The Brownian motion is defined to equal 0 at `t0`. -""" + + level = final_state.level + 1 + s, t, u = final_state.stu + + # Split the interval in half one last time depending on whether r < t or r > t + # but this time complete the step with the general interpolation, rather + # than the midpoint rule (as given by _brownian_arch). + + cond = r > t + s = jnp.where(cond, t, s) + u = jnp.where(cond, u, t) + su = jnp.power(jnp.asarray(2.0, dtype), -level) + su = eqxi.error_if( + su, + jnp.abs(u - s - su) > 1e-17, + "VirtualBrownianTree: u-s is not 2^(-tree_level) in final step.", + ) + + sr = r - s + ru = u - r # make sure su = sr + ru regardless of cancellation error + + w_s, w_u, w_su = split_interval(cond, final_state.w_stu, final_state.w_st_tu) + key1, key2 = jrandom.split(final_state.key, 2) + key = jnp.where(cond, key1, key2) + + # BM only case + if self.levy_area == "": + z = jrandom.normal(key, shape, dtype) + w_sr = sr / su * w_su + jnp.sqrt(sr * ru / su) * z + w_r = w_s + w_sr + return LevyVal(dt=r, W=w_r, H=None, bar_H=None, K=None, bar_K=None) + + elif self.levy_area == "space-time": + bhh_s, bhh_u, bhh_su = split_interval( + cond, final_state.bhh_stu, final_state.bhh_st_tu + ) + sr3 = jnp.power(sr, 3) + ru3 = jnp.power(ru, 3) + su3 = jnp.power(su, 3) + key1, key2 = jrandom.split(key, 2) + x1 = jrandom.normal(key1, shape, dtype) + x2 = jrandom.normal(key2, shape, dtype) + + sr_ru_half = jnp.sqrt(sr * ru) + d = jnp.sqrt(sr3 + ru3) + d_prime = 1 / (2 * su * d) + a = d_prime * sr3 * sr_ru_half + b = d_prime * ru3 * sr_ru_half + + w_sr = sr / su * w_su + 6 * sr * ru / su3 * bhh_su + 2 * (a + b) / su * x1 + w_r = w_s + w_sr + c = jnp.sqrt(3 * sr3 * ru3) / (6 * d) + bhh_sr = sr3 / su3 * bhh_su - a * x1 + c * x2 + bhh_r = bhh_s + bhh_sr + 0.5 * (r * w_s - s * w_r) + + inverse_r = 1 / jnp.where(jnp.abs(r) < jnp.finfo(r).eps, jnp.inf, r) + hh_r = inverse_r * bhh_r + + else: + assert False + + return LevyVal(dt=r, W=w_r, H=hh_r, bar_H=bhh_r, K=None, bar_K=None) diff --git a/diffrax/custom_types.py b/diffrax/custom_types.py index 93e818b5..ff1a1fe0 100644 --- a/diffrax/custom_types.py +++ b/diffrax/custom_types.py @@ -1,8 +1,10 @@ import inspect import typing -from typing import Any, Dict, Generic, Tuple, TypeVar, Union +from typing import Any, Dict, Generic, Optional, Tuple, TypeVar, Union +import equinox as eqx import equinox.internal as eqxi +import jax import jax.tree_util as jtu @@ -131,3 +133,52 @@ def __class_getitem__(cls, item): DenseInfo = Dict[str, PyTree[Array]] DenseInfos = Dict[str, PyTree[Array["times", ...]]] # noqa: F821 sentinel: Any = eqxi.doc_repr(object(), "sentinel") + + +class LevyVal(eqx.Module): + dt: Scalar + W: PyTree[Array] + H: Optional[PyTree[Array]] + bar_H: Optional[PyTree[Array]] + K: Optional[PyTree[Array]] + bar_K: Optional[PyTree[Array]] + + +def levy_tree_transpose(tree_shape, levy_area, tree): + """Helper that takes a PyTree of LevyVals and transposes + into a LevyVal of PyTrees. + + **Arguments:** + - `tree_shape`: Corresponds to `outer_treedef` in `jax.tree_transpose`. + + - `levy_area`: can be "", "space-time" or "space-time-time", which indicates + which fields of the LevyVal will have values. + + - `tree`: the PyTree of LevyVals to transpose. + + **Returns:** + A `LevyVal` of PyTrees. + """ + if levy_area in ["space-time", "space-time-time"]: + hh_default_val = 0.0 + if levy_area == "space-time-time": + kk_default_val = 0.0 + else: + kk_default_val = None + else: + hh_default_val = None + kk_default_val = None + return jtu.tree_transpose( + outer_treedef=jax.tree_structure(tree_shape), + inner_treedef=jax.tree_structure( + LevyVal( + dt=0.0, + W=0.0, + H=hh_default_val, + bar_H=None, + K=kk_default_val, + bar_K=None, + ) + ), + pytree_to_transpose=tree, + ) diff --git a/test/test_brownian.py b/test/test_brownian.py index 4e6b8389..2af20239 100644 --- a/test/test_brownian.py +++ b/test/test_brownian.py @@ -7,8 +7,11 @@ import jax.tree_util as jtu import pytest import scipy.stats as stats +from jax import config +config.update("jax_enable_x64", True) + _vals = { int: [0, 2], float: [0.0, 2.0], @@ -120,7 +123,7 @@ def _eval(key): def test_conditional_statistics(): - key = jrandom.PRNGKey(5678) + key = jrandom.PRNGKey(5679) bm_key, sample_key, permute_key = jrandom.split(key, 3) # Get >80 randomly selected points; not too close to avoid discretisation error. @@ -128,8 +131,8 @@ def test_conditional_statistics(): t1 = 8.7 ts = jrandom.uniform(sample_key, shape=(100,), minval=t0, maxval=t1) sorted_ts = jnp.sort(ts) - ts = [] prev_ti = sorted_ts[0] + ts = [prev_ti] for ti in sorted_ts[1:]: if ti < prev_ti + 2**-10: continue @@ -143,7 +146,7 @@ def test_conditional_statistics(): bm_keys = jrandom.split(bm_key, 100000) path = jax.vmap( lambda k: diffrax.VirtualBrownianTree( - t0=t0, t1=t1, shape=(), tol=2**-12, key=k + t0=t0, t1=t1, shape=_make_struct((), jnp.float64), tol=2**-12, key=k ) )(bm_keys) @@ -155,7 +158,7 @@ def test_conditional_statistics(): out = sorted(out, key=lambda x: x[0]) # Test their conditional statistics - for i in range(1, 98): + for i in range(1, len(ts) - 2): prev_t, prev_vals = out[i - 1] this_t, this_vals = out[i] next_t, next_vals = out[i + 1] diff --git a/test/test_stla.py b/test/test_stla.py new file mode 100644 index 00000000..25ca960b --- /dev/null +++ b/test/test_stla.py @@ -0,0 +1,231 @@ +import math + +import diffrax +import jax +import jax.numpy as jnp +import jax.random as jrandom +import jax.tree_util as jtu +import pytest +import scipy.stats as stats + + +_vals = { + int: [0, 2], + float: [0.0, 2.0], + jnp.int32: [jnp.array(0, dtype=jnp.int32), jnp.array(2, dtype=jnp.int32)], + jnp.float32: [jnp.array(0.0, dtype=jnp.float32), jnp.array(2.0, dtype=jnp.float32)], +} + + +def _make_struct(shape, dtype): + dtype = jax.dtypes.canonicalize_dtype(dtype) + return jax.ShapeDtypeStruct(shape, dtype) + + +@pytest.mark.parametrize( + "ctr", [diffrax.UnsafeBrownianPath, diffrax.VirtualBrownianTree] +) +def test_shape_and_dtype(ctr, getkey): + t0 = 0 + t1 = 2 + + shapes = ( + (), + (0,), + (1, 0), + (2,), + (3, 4), + (1, 2, 3, 4), + { + "a": (1,), + "b": (2, 3), + }, + ( + (1, 2), + ( + (3, 4), + (5, 6), + ), + ), + ) + + dtypes = ( + None, + None, + None, + jnp.float16, + jnp.float32, + jnp.float64, + {"a": None, "b": jnp.float64}, + (jnp.float16, (jnp.float32, jnp.float64)), + ) + + def is_tuple_of_ints(obj): + return isinstance(obj, tuple) and all(isinstance(x, int) for x in obj) + + for shape, dtype in zip(shapes, dtypes): + # Shape to pass as input + if dtype is not None: + shape = jtu.tree_map(_make_struct, shape, dtype, is_leaf=is_tuple_of_ints) + + if ctr is diffrax.UnsafeBrownianPath: + path = ctr(shape, getkey(), levy_area="space-time") + assert path.t0 is None + assert path.t1 is None + elif ctr is diffrax.VirtualBrownianTree: + tol = 2**-5 + path = ctr(t0, t1, tol, shape, getkey(), levy_area="space-time") + assert path.t0 == 0 + assert path.t1 == 2 + else: + assert False + + # Expected output shape + if dtype is None: + shape = jtu.tree_map(_make_struct, shape, dtype, is_leaf=is_tuple_of_ints) + + for _t0 in _vals.values(): + for _t1 in _vals.values(): + t0, _ = _t0 + _, t1 = _t1 + bm = path.evaluate(t0, t1, use_levy=True) + out_w = bm.W + out_hh = bm.H + out_w_shape = jtu.tree_map( + lambda leaf: jax.ShapeDtypeStruct(leaf.shape, leaf.dtype), out_w + ) + out_hh_shape = jtu.tree_map( + lambda leaf: jax.ShapeDtypeStruct(leaf.shape, leaf.dtype), out_hh + ) + assert out_hh_shape == shape + assert out_w_shape == shape + + +@pytest.mark.parametrize( + "ctr", [diffrax.UnsafeBrownianPath, diffrax.VirtualBrownianTree] +) +def test_statistics(ctr): + # Deterministic key for this test; not using getkey() + key = jrandom.PRNGKey(5678) + keys = jrandom.split(key, 10000) + + def _eval(key): + if ctr is diffrax.UnsafeBrownianPath: + path = ctr(shape=(), key=key, levy_area="space-time") + elif ctr is diffrax.VirtualBrownianTree: + path = ctr( + t0=0, t1=5, tol=2**-5, shape=(), key=key, levy_area="space-time" + ) + else: + assert False + return path.evaluate(0, 5, use_levy=True) + + bm_inc = jax.vmap(_eval)(keys) + values_w = bm_inc.W + values_h = bm_inc.H + assert values_w.shape == (10000,) and values_h.shape == (10000,) + ref_dist_w = stats.norm(loc=0, scale=math.sqrt(5)) + _, pval_w = stats.kstest(values_w, ref_dist_w.cdf) + ref_dist_h = stats.norm(loc=0, scale=math.sqrt(5 / 12)) + _, pval_h = stats.kstest(values_h, ref_dist_h.cdf) + assert pval_w > 0.1 + assert pval_h > 0.1 + + +def test_conditional_statistics(): + key = jrandom.PRNGKey(5678) + bm_key, sample_key, permute_key = jrandom.split(key, 3) + + # Get >80 randomly selected points; not too close to avoid discretisation error. + t0 = 0.3 + t1 = 8.7 + ts = jrandom.uniform(sample_key, shape=(100,), minval=t0, maxval=t1) + sorted_ts = jnp.sort(ts) + ts = [] + prev_ti = sorted_ts[0] + for ti in sorted_ts[1:]: + if ti < prev_ti + 2**-9: + continue + prev_ti = ti + ts.append(ti) + ts = jnp.stack(ts) + assert len(ts) > 80 + ts = jrandom.permutation(permute_key, ts) + + # Get some random paths + bm_keys = jrandom.split(bm_key, 100000) + path = jax.vmap( + lambda k: diffrax.VirtualBrownianTree( + t0=t0, t1=t1, shape=(), tol=2**-12, key=k, levy_area="space-time" + ) + )(bm_keys) + + # Sample some points + out = [] + for ti in ts: + vals = jax.vmap(lambda p: p.evaluate(t0, ti, use_levy=True))(path) + out.append((ti, vals)) + out = sorted(out, key=lambda x: x[0]) + + # Test their conditional statistics + for i in range(1, len(ts) - 1): + s, bm_s = out[i - 1] + r, bm_r = out[i] + u, bm_u = out[i + 1] + + w_s, hh_s = bm_s.W, bm_s.H + w_r, hh_r = bm_r.W, bm_r.H + w_u, hh_u = bm_u.W, bm_u.H + + s = s - t0 + r = r - t0 + u = u - t0 + su = u - s + sr = r - s + ru = u - r + d = jnp.sqrt(jnp.power(sr, 3) + jnp.power(ru, 3)) + a = (1 / (2 * su * d)) * jnp.power(sr, 7 / 2) * jnp.sqrt(ru) + b = (1 / (2 * su * d)) * jnp.power(ru, 7 / 2) * jnp.sqrt(sr) + c = (1.0 / (jnp.sqrt(12) * d)) * jnp.power(sr, 3 / 2) * jnp.power(ru, 3 / 2) + + hh_su = (1.0 / su) * (u * hh_u - s * hh_s - u / 2 * w_s + s / 2 * w_u) + + w_mean = w_s + (sr / su) * (w_u - w_s) + (6 * sr * ru / jnp.square(su)) * hh_su + w_std = 2 * (a + b) / su + normalised_w = (w_r - w_mean) / w_std + hh_mean = ( + (s / r) * hh_s + + (jnp.power(sr, 3) / (r * jnp.square(su))) * hh_su + + 0.5 * w_s + - s / (2 * r) * w_mean + ) + hh_var = jnp.square(c / r) + jnp.square((a * u + s * b) / (r * su)) + hh_std = jnp.sqrt(hh_var) + normalised_hh = (hh_r - hh_mean) / hh_std + + _, pval_w = stats.kstest(normalised_w, stats.norm.cdf) + _, pval_hh = stats.kstest(normalised_hh, stats.norm.cdf) + + # Raise if the failure is statistically significant at 10%, subject to + # multiple-testing correction. + assert pval_w > 0.001 + assert pval_hh > 0.001 + + +def test_reverse_time(): + key = jrandom.PRNGKey(5678) + bm_key, sample_key = jrandom.split(key, 2) + bm = diffrax.VirtualBrownianTree( + t0=0, t1=5, tol=2**-5, shape=(), key=bm_key, levy_area="space-time" + ) + + ts = jrandom.uniform(sample_key, shape=(100,), minval=0, maxval=5) + + vec_eval = jax.vmap(lambda t_prev, t: bm.evaluate(t_prev, t)) + + fwd_increments = vec_eval(ts[:-1], ts[1:]) + back_increments = vec_eval(ts[1:], ts[:-1]) + + assert jtu.tree_map( + lambda fwd, bck: jnp.allclose(fwd, -bck), fwd_increments, back_increments + )