Skip to content

Commit

Permalink
fixed an issue where jnp.sqrt(very small number) gives a nan due to f…
Browse files Browse the repository at this point in the history
…loat errors
  • Loading branch information
andyElking committed Dec 9, 2023
1 parent 695face commit e12a018
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions diffrax/brownian/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
# Foster, J. M. (2020). Numerical approximations for stochastic
# differential equations [PhD thesis]. University of Oxford.

# @dataclasses.dataclass

class _State(eqx.Module):
level: int
s: Scalar # starting time of the interval
Expand Down Expand Up @@ -159,7 +159,7 @@ def _denormalise_bm_inc(self, x: LevyVal) -> LevyVal:
def sqrt_mult(z):
# need to cast to dtype of each leaf in PyTree
dtype = jnp.dtype(z)
return (z * sqrt_len).astype(dtype)
return z * jnp.asarray(sqrt_len, dtype)

def mult(z):
dtype = jnp.dtype(z)
Expand Down Expand Up @@ -392,7 +392,7 @@ def _body_fun(_state: _State):
# BM only case
if self.levy_area == "":
z = jrandom.normal(key, shape, dtype)
w_sr = sr / su * w_su + jnp.sqrt(sr * ru / su) * z
w_sr = sr / su * w_su + jnp.sqrt(jnp.abs(sr * ru / su)) * z
w_r = w_s + w_sr
return LevyVal(dt=r, W=w_r, H=None, bar_H=None, K=None, bar_K=None)

Expand All @@ -405,15 +405,15 @@ def _body_fun(_state: _State):
x1 = jrandom.normal(key1, shape, dtype)
x2 = jrandom.normal(key2, shape, dtype)

sr_ru_half = jnp.sqrt(sr * ru)
d = jnp.sqrt(sr3 + ru3)
sr_ru_half = jnp.sqrt(jnp.abs(sr * ru))
d = jnp.sqrt(jnp.abs(sr3 + ru3))
d_prime = 1 / (2 * su * d)
a = d_prime * sr3 * sr_ru_half
b = d_prime * ru3 * sr_ru_half

w_sr = sr / su * w_su + 6 * sr * ru / su3 * bhh_su + 2 * (a + b) / su * x1
w_r = w_s + w_sr
c = jnp.sqrt(3 * sr3 * ru3) / (6 * d)
c = jnp.sqrt(jnp.abs(3 * sr3 * ru3)) / (6 * d)
bhh_sr = sr3 / su3 * bhh_su - a * x1 + c * x2
bhh_r = bhh_s + bhh_sr + 0.5 * (r * w_s - s * w_r)

Expand Down

0 comments on commit e12a018

Please sign in to comment.