Skip to content

Commit

Permalink
added space-time levy area (for PR)
Browse files Browse the repository at this point in the history
  • Loading branch information
andyElking committed Dec 7, 2023
1 parent 6192f62 commit 872f476
Show file tree
Hide file tree
Showing 12 changed files with 780 additions and 133 deletions.
8 changes: 8 additions & 0 deletions .idea/.gitignore

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 15 additions & 0 deletions .idea/diffrax_STLA.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions .idea/inspectionProfiles/profiles_settings.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions .idea/misc.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions .idea/modules.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions .idea/vcs.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 12 additions & 3 deletions diffrax/brownian/base.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
import abc
from typing import Optional, Union

from ..custom_types import Array, PyTree, Scalar
from ..custom_types import Array, LevyVal, PyTree, Scalar
from ..path import AbstractPath


class AbstractBrownianPath(AbstractPath):
"Abstract base class for all Brownian paths."
"""Abstract base class for all Brownian paths."""

@abc.abstractmethod
def evaluate(self, t0: Scalar, t1: Scalar, left: bool = True) -> PyTree[Array]:
def evaluate(
self,
t0: Scalar,
t1: Optional[Scalar] = None,
left: bool = True,
use_levy: bool = False,
) -> Union[PyTree[Array], LevyVal]:
r"""Samples a Brownian increment $w(t_1) - w(t_0)$.
Each increment has distribution $\mathcal{N}(0, t_1 - t_0)$.
Expand All @@ -20,6 +27,8 @@ def evaluate(self, t0: Scalar, t1: Scalar, left: bool = True) -> PyTree[Array]:
- `left`: Ignored. (This determines whether to treat the path as
left-continuous or right-continuous at any jump points, but Brownian
motion has no jump points.)
- `use_levy`: If True, the return type will be a `LevyVal`, which contains
PyTrees of Brownian increments and their Levy areas.
**Returns:**
Expand Down
65 changes: 56 additions & 9 deletions diffrax/brownian/path.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Tuple, Union
from typing import Literal, Tuple, Union

import equinox as eqx
import equinox.internal as eqxi
Expand All @@ -7,7 +7,7 @@
import jax.random as jrandom
import jax.tree_util as jtu

from ..custom_types import Array, PyTree, Scalar
from ..custom_types import Array, levy_tree_transpose, LevyVal, PyTree, Scalar
from ..misc import force_bitcast_convert_type, is_tuple_of_ints, split_by_tree
from .base import AbstractBrownianPath

Expand All @@ -30,9 +30,14 @@ class UnsafeBrownianPath(AbstractBrownianPath):
interval, ignoring the correlation between samples exhibited in true Brownian
motion. Hence the restrictions above. (They describe the general case for which the
correlation structure isn't needed.)
Depending on the `levy_area` argument, this can also be used to generate Levy area.
`levy_area` can be "" or "space-time".
"""

shape: PyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)
levy_area: Literal["", "space-time"] = eqx.field(static=True)
# Handled as a string because PRNGKey is actually a function, not a class, which
# makes it appearly badly in autogenerated documentation.
key: "jax.random.PRNGKey" # noqa: F821
Expand All @@ -41,13 +46,20 @@ def __init__(
self,
shape: Union[Tuple[int, ...], PyTree[jax.ShapeDtypeStruct]],
key: "jax.random.PRNGKey",
levy_area: Literal["", "space-time"] = "",
):
self.shape = (
jax.ShapeDtypeStruct(shape, jax.dtypes.canonicalize_dtype(None))
if is_tuple_of_ints(shape)
else shape
)
self.key = key
if levy_area not in ["", "space-time"]:
raise ValueError(
f"levy_area must be one of '', 'space-time', but got {levy_area}."
)
self.levy_area = levy_area

if any(
not jnp.issubdtype(x.dtype, jnp.inexact)
for x in jtu.tree_leaves(self.shape)
Expand All @@ -63,7 +75,13 @@ def t1(self):
return None

@eqx.filter_jit
def evaluate(self, t0: Scalar, t1: Scalar, left: bool = True) -> PyTree[Array]:
def evaluate(
self,
t0: Scalar,
t1: Scalar,
left: bool = True,
use_levy: bool = False,
) -> PyTree[Array]:
del left
t0 = eqxi.nondifferentiable(t0, name="t0")
t1 = eqxi.nondifferentiable(t1, name="t1")
Expand All @@ -72,14 +90,42 @@ def evaluate(self, t0: Scalar, t1: Scalar, left: bool = True) -> PyTree[Array]:
key = jrandom.fold_in(self.key, t0_)
key = jrandom.fold_in(key, t1_)
key = split_by_tree(key, self.shape)
return jtu.tree_map(
lambda key, shape: self._evaluate_leaf(t0, t1, key, shape), key, self.shape
out = jtu.tree_map(
lambda key, shape: self._evaluate_leaf(
t0, t1, key, shape, self.levy_area, use_levy
),
key,
self.shape,
)
if use_levy:
out = levy_tree_transpose(self.shape, self.levy_area, out)
assert isinstance(out, LevyVal)
return out

@staticmethod
def _evaluate_leaf(
t0: Scalar,
t1: Scalar,
key,
shape: jax.ShapeDtypeStruct,
levy_area: str,
use_levy: bool,
):
w_std = jnp.sqrt(t1 - t0).astype(shape.dtype)

def _evaluate_leaf(self, t0: Scalar, t1: Scalar, key, shape: jax.ShapeDtypeStruct):
return jrandom.normal(key, shape.shape, shape.dtype) * jnp.sqrt(t1 - t0).astype(
shape.dtype
)
if levy_area == "space-time":
key_w, key_hh = jrandom.split(key, 2)
w = jrandom.normal(key_w, shape.shape, shape.dtype) * w_std
hh_std = jnp.sqrt((t1 - t0) / 12).astype(shape.dtype)
hh = jrandom.normal(key_hh, shape.shape, shape.dtype) * hh_std
else:
hh = None
w = jrandom.normal(key, shape.shape, shape.dtype) * w_std

if use_levy:
return LevyVal(dt=t1 - t0, W=w, H=hh, bar_H=None, K=None, bar_K=None)
else:
return w


UnsafeBrownianPath.__init__.__doc__ = """
Expand All @@ -89,5 +135,6 @@ def _evaluate_leaf(self, t0: Scalar, t1: Scalar, key, shape: jax.ShapeDtypeStruc
dtype, and PyTree structure of the output. For simplicity, `shape` can also just
be a tuple of integers, describing the shape of a single JAX array. In that case
the dtype is chosen to be the default floating-point dtype.
- `key`: A random key.
"""
Loading

0 comments on commit 872f476

Please sign in to comment.