Skip to content

Commit

Permalink
minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
andyElking committed Dec 4, 2023
1 parent 07b0675 commit af38f9b
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions diffrax/solver/additive_srk.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Optional

import equinox as eqx
import jax
Expand All @@ -13,7 +13,7 @@
from ..custom_types import Array, Bool, DenseInfo, LevyVal, PyTree, Scalar
from ..local_interpolation import LocalLinearInterpolation
from ..solution import RESULTS
from ..term import _ControlTerm, AbstractTerm, MultiTerm, ODETerm
from ..term import AbstractTerm, MultiTerm, ODETerm
from .base import AbstractStratonovichSolver


Expand Down Expand Up @@ -53,8 +53,6 @@ def __post_init__(self):
assert np.allclose(sum(a_i), c_i)
assert np.allclose(sum(self.b_sol), 1.0)

# TODO: add checks for whether the method is FSAL


StochasticButcherTableau.__init__.__doc__ = """**Arguments:**
Expand Down Expand Up @@ -111,13 +109,13 @@ class AbstractAdditiveSRK(AbstractStratonovichSolver):
in the `StochasticButcherTableau`.
"""

term_structure = MultiTerm[Tuple[ODETerm, _ControlTerm]]
term_structure = MultiTerm[tuple[ODETerm, AbstractTerm]]
interpolation_cls = LocalLinearInterpolation
tableau: StochasticButcherTableau

def init(
self,
terms: MultiTerm[Tuple[ODETerm, _ControlTerm]],
terms: term_structure,
t0: Scalar,
t1: Scalar,
y0: PyTree,
Expand All @@ -132,7 +130,7 @@ def init(
if not path.levy_area == "space-time":
raise ValueError(
"The Brownian path controlling the diffusion "
"should be initialised with `compute_stla=True`"
"should be initialised with `levy_area='space-time'`"
)

# check that the vector field of the diffusion term is constant
Expand All @@ -157,14 +155,14 @@ def _embed_a_lower(self, dtype):

def step(
self,
terms: MultiTerm[Tuple[ODETerm, _ControlTerm]],
terms: term_structure,
t0: Scalar,
t1: Scalar,
y0: PyTree,
args: PyTree,
solver_state: _SolverState,
made_jump: Bool,
) -> Tuple[PyTree, _ErrorEstimate, DenseInfo, _SolverState, RESULTS]:
) -> tuple[PyTree, _ErrorEstimate, DenseInfo, _SolverState, RESULTS]:
del solver_state, made_jump

h = t1 - t0
Expand Down

0 comments on commit af38f9b

Please sign in to comment.