Skip to content

Commit

Permalink
Merge pull request #257 from patrick-kidger/better-rk
Browse files Browse the repository at this point in the history
Better rk
  • Loading branch information
patrick-kidger authored May 22, 2023
2 parents f101e75 + 77b1a60 commit 27702dd
Show file tree
Hide file tree
Showing 51 changed files with 2,186 additions and 849 deletions.
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
[flake8]
max-line-length = 88
ignore = W291,W503,W504,E121,E123,E126,E203,E402,E701,E702,E731
ignore = W291,W503,W504,E121,E123,E126,E203,E402,E701,E702,E731,F722
per-file-ignores = __init__.py: F401
2 changes: 1 addition & 1 deletion .github/workflows/build_docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:
build:
strategy:
matrix:
python-version: [ 3.8 ]
python-version: [ 3.11 ]
os: [ ubuntu-latest ]
runs-on: ${{ matrix.os }}
steps:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
- name: Release
uses: patrick-kidger/action_update_python_project@v1
with:
python-version: "3.8"
python-version: "3.11"
test-script: |
python -m pip install pytest psutil jax jaxlib equinox scipy optax
cp -r ${{ github.workspace }}/test ./test
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ jobs:
run-tests:
strategy:
matrix:
python-version: [ 3.8, 3.9 ]
python-version: [ 3.9, 3.11 ]
os: [ ubuntu-latest ]
fail-fast: false
runs-on: ${{ matrix.os }}
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ _From a technical point of view, the internal structure of the library is pretty
pip install diffrax
```

Requires Python 3.8+, JAX 0.4.3+, and [Equinox](https://github.com/patrick-kidger/equinox) 0.10.0+.
Requires Python 3.9+, JAX 0.4.4+, and [Equinox](https://github.com/patrick-kidger/equinox) 0.10.4+.

## Documentation

Expand Down
69 changes: 0 additions & 69 deletions benchmarks/scan_stages.py

This file was deleted.

96 changes: 0 additions & 96 deletions benchmarks/scan_stages_cnf.py

This file was deleted.

6 changes: 6 additions & 0 deletions diffrax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from .misc import adjoint_rms_seminorm
from .nonlinear_solver import (
AbstractNonlinearSolver,
AffineNonlinearSolver,
NewtonNonlinearSolver,
NonlinearSolution,
)
Expand Down Expand Up @@ -60,14 +61,19 @@
Heun,
ImplicitEuler,
ItoMilstein,
KenCarp3,
KenCarp4,
KenCarp5,
Kvaerno3,
Kvaerno4,
Kvaerno5,
LeapfrogMidpoint,
Midpoint,
MultiButcherTableau,
Ralston,
ReversibleHeun,
SemiImplicitEuler,
Sil3,
StratonovichMilstein,
Tsit5,
)
Expand Down
15 changes: 14 additions & 1 deletion diffrax/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .ad import implicit_jvp
from .heuristics import is_sde, is_unsafe_sde
from .saveat import save_y, SaveAt, SubSaveAt
from .solver import AbstractItoSolver, AbstractStratonovichSolver
from .solver import AbstractItoSolver, AbstractRungeKutta, AbstractStratonovichSolver
from .term import AbstractTerm, AdjointTerm


Expand Down Expand Up @@ -332,6 +332,7 @@ class DirectAdjoint(AbstractAdjoint):
def loop(
self,
*,
solver,
max_steps,
terms,
throw,
Expand Down Expand Up @@ -362,10 +363,15 @@ def loop(
else:
kind = "bounded"
msg = None
# Support forward-mode autodiff.
# TODO: remove this hack once we can JVP through custom_vjps.
if isinstance(solver, AbstractRungeKutta) and solver.scan_kind is None:
solver = eqx.tree_at(lambda s: s.scan_kind, solver, "bounded")
inner_while_loop = ft.partial(_inner_loop, kind=kind)
outer_while_loop = ft.partial(_outer_loop, kind=kind)
final_state = self._loop(
**kwargs,
solver=solver,
max_steps=max_steps,
terms=terms,
inner_while_loop=inner_while_loop,
Expand Down Expand Up @@ -535,6 +541,8 @@ def _loop_backsolve_bwd(
zeros_like_diff_args = jtu.tree_map(jnp.zeros_like, diff_args)
zeros_like_diff_terms = jtu.tree_map(jnp.zeros_like, diff_terms)
del diff_args, diff_terms
# TODO: have this look inside MultiTerms? Need to think about the math. i.e.:
# is_leaf=lambda x: isinstance(x, AbstractTerm) and not isinstance(x, MultiTerm)
adjoint_terms = jtu.tree_map(
AdjointTerm, terms, is_leaf=lambda x: isinstance(x, AbstractTerm)
)
Expand Down Expand Up @@ -762,6 +770,11 @@ def loop(
"`BacksolveAdjoint` will only produce the correct solution for "
"Stratonovich SDEs."
)
if jtu.tree_structure(solver.term_structure) != jtu.tree_structure(0):
raise NotImplementedError(
"`diffrax.BacksolveAdjoint` is only compatible with solvers that take "
"a single term."
)

y = init_state.y
init_state = eqx.tree_at(lambda s: s.y, init_state, object())
Expand Down
4 changes: 3 additions & 1 deletion diffrax/custom_types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import inspect
import typing
from typing import Dict, Generic, Tuple, TypeVar, Union
from typing import Any, Dict, Generic, Tuple, TypeVar, Union

import equinox.internal as eqxi
import jax.tree_util as jtu


Expand Down Expand Up @@ -129,3 +130,4 @@ def __class_getitem__(cls, item):

DenseInfo = Dict[str, PyTree[Array]]
DenseInfos = Dict[str, PyTree[Array["times", ...]]] # noqa: F821
sentinel: Any = eqxi.doc_repr(object(), "sentinel")
Loading

0 comments on commit 27702dd

Please sign in to comment.