Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tweaks to #337 #339

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ dist/
site/
.all_objects.cache
.pymon
.idea/
1 change: 1 addition & 0 deletions diffrax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
UnsafeBrownianPath as UnsafeBrownianPath,
VirtualBrownianTree as VirtualBrownianTree,
)
from ._custom_types import LevyVal as LevyVal
from ._event import (
AbstractDiscreteTerminatingEvent as AbstractDiscreteTerminatingEvent,
DiscreteTerminatingEvent as DiscreteTerminatingEvent,
Expand Down
29 changes: 19 additions & 10 deletions diffrax/_brownian/base.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,44 @@
import abc
from typing import Optional
from typing import Optional, Union

from jaxtyping import Array, PyTree

from .._custom_types import RealScalarLike
from .._custom_types import LevyVal, RealScalarLike
from .._path import AbstractPath


class AbstractBrownianPath(AbstractPath):
"Abstract base class for all Brownian paths."
"""Abstract base class for all Brownian paths."""

@abc.abstractmethod
def evaluate(
self, t0: RealScalarLike, t1: Optional[RealScalarLike] = None, left: bool = True
) -> PyTree[Array]:
self,
t0: RealScalarLike,
t1: Optional[RealScalarLike] = None,
left: bool = True,
use_levy: bool = False,
) -> Union[PyTree[Array], LevyVal]:
r"""Samples a Brownian increment $w(t_1) - w(t_0)$.

Each increment has distribution $\mathcal{N}(0, t_1 - t_0)$.

**Arguments:**

- `t0`: Start of interval.
- `t1`: End of interval.
- `t0`: Any point in $[t_0, t_1]$ to evaluate the path at.
- `t1`: If passed, then the increment from `t1` to `t0` is evaluated instead.
- `left`: Ignored. (This determines whether to treat the path as
left-continuous or right-continuous at any jump points, but Brownian
motion has no jump points.)
- `use_levy`: If True, the return type will be a `LevyVal`, which contains
PyTrees of Brownian increments and their Levy areas.

**Returns:**

A pytree of JAX arrays corresponding to the increment $w(t_1) - w(t_0)$.
If `t1` is not passed:

Some subclasses may allow `t1=None`, in which case just the value $w(t_0)$ is
returned.
The value of the Brownian motion at `t0`.

If `t1` is passed:

The increment of the Brownian motion between `t0` and `t1`.
"""
66 changes: 54 additions & 12 deletions diffrax/_brownian/path.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import cast, Optional, Union
import math
from typing import cast, Literal, Optional, Union

import equinox as eqx
import equinox.internal as eqxi
Expand All @@ -8,7 +9,7 @@
import jax.tree_util as jtu
from jaxtyping import Array, PRNGKeyArray, PyTree

from .._custom_types import RealScalarLike
from .._custom_types import levy_tree_transpose, LevyVal, RealScalarLike
from .._misc import (
default_floating_dtype,
force_bitcast_convert_type,
Expand Down Expand Up @@ -36,24 +37,32 @@ class UnsafeBrownianPath(AbstractBrownianPath):
interval, ignoring the correlation between samples exhibited in true Brownian
motion. Hence the restrictions above. (They describe the general case for which the
correlation structure isn't needed.)

Depending on the `levy_area` argument, this can also be used to generate Levy area.
"""

shape: PyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)
# Handled as a string because PRNGKey is actually a function, not a class, which
# makes it appearly badly in autogenerated documentation.
levy_area: Literal["", "space-time"] = eqx.field(static=True)
key: PRNGKeyArray

def __init__(
self,
shape: Union[tuple[int, ...], PyTree[jax.ShapeDtypeStruct]],
key: PRNGKeyArray,
levy_area: Literal["", "space-time"] = "",
):
self.shape = (
jax.ShapeDtypeStruct(shape, default_floating_dtype())
if is_tuple_of_ints(shape)
else shape
)
self.key = key
if levy_area not in ["", "space-time"]:
raise ValueError(
f"levy_area must be one of '', 'space-time', but got {levy_area}."
)
self.levy_area = levy_area

if any(
not jnp.issubdtype(x.dtype, jnp.inexact)
for x in jtu.tree_leaves(self.shape)
Expand All @@ -70,8 +79,12 @@ def t1(self):

@eqx.filter_jit
def evaluate(
self, t0: RealScalarLike, t1: Optional[RealScalarLike] = None, left: bool = True
) -> PyTree[Array]:
self,
t0: RealScalarLike,
t1: Optional[RealScalarLike] = None,
left: bool = True,
use_levy: bool = False,
) -> Union[PyTree[Array], LevyVal]:
del left
if t1 is None:
t1 = t0
Expand All @@ -84,16 +97,43 @@ def evaluate(
key = jr.fold_in(self.key, t0_)
key = jr.fold_in(key, t1_)
key = split_by_tree(key, self.shape)
return jtu.tree_map(
lambda key, shape: self._evaluate_leaf(t0, t1, key, shape), key, self.shape
out = jtu.tree_map(
lambda key, shape: self._evaluate_leaf(
t0, t1, key, shape, self.levy_area, use_levy
),
key,
self.shape,
)
if use_levy:
out = levy_tree_transpose(self.shape, self.levy_area, out)
assert isinstance(out, LevyVal)
return out

@staticmethod
def _evaluate_leaf(
self, t0: RealScalarLike, t1: RealScalarLike, key, shape: jax.ShapeDtypeStruct
t0: RealScalarLike,
t1: RealScalarLike,
key,
shape: jax.ShapeDtypeStruct,
levy_area: str,
use_levy: bool,
):
return jr.normal(key, shape.shape, shape.dtype) * jnp.sqrt(t1 - t0).astype(
shape.dtype
)
w_std = jnp.sqrt(t1 - t0).astype(shape.dtype)

if levy_area == "space-time":
key, key_hh = jr.split(key, 2)
hh_std = w_std / math.sqrt(12)
hh = jr.normal(key_hh, shape.shape, shape.dtype) * hh_std
elif levy_area == "":
hh = None
else:
assert False
w = jr.normal(key, shape.shape, shape.dtype) * w_std

if use_levy:
return LevyVal(dt=t1 - t0, W=w, H=hh, bar_H=None, K=None, bar_K=None)
else:
return w


UnsafeBrownianPath.__init__.__doc__ = """
Expand All @@ -104,4 +144,6 @@ def _evaluate_leaf(
be a tuple of integers, describing the shape of a single JAX array. In that case
the dtype is chosen to be the default floating-point dtype.
- `key`: A random key.
- `levy_area`: Whether to additionally generate Levy area. This is required by some SDE
solvers.
"""
Loading
Loading