Skip to content

Commit

Permalink
Enable implicit solvers for complex inputs (#411)
Browse files Browse the repository at this point in the history
* Enable implicit solvers for complex inputs

* change version

* make pyright happy
  • Loading branch information
Randl authored and patrick-kidger committed May 19, 2024
1 parent 068b4b9 commit 9743eb8
Show file tree
Hide file tree
Showing 5 changed files with 6 additions and 19 deletions.
5 changes: 0 additions & 5 deletions diffrax/_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,11 +709,6 @@ def diffeqsolve(
eqx.is_array(xi) and jnp.iscomplexobj(xi)
for xi in jtu.tree_leaves((terms, y0, args))
):
if isinstance(solver, AbstractImplicitSolver):
raise ValueError(
"Implicit solvers in conjunction with complex dtypes is currently not "
"supported."
)
warnings.warn(
"Complex dtype support is work in progress, please read "
"https://github.com/patrick-kidger/diffrax/pull/197 and proceed carefully.",
Expand Down
2 changes: 1 addition & 1 deletion diffrax/_progress_meter.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def _step(_progress, _idx):
# Return the idx to thread the callbacks in the correct order.
return _idx

return jax.pure_callback(_step, idx, progress, idx, vectorized=True) # pyright: ignore
return jax.pure_callback(_step, idx, progress, idx, vectorized=True)

def close(self, close_bar: Callable[[Any], None], idx: IntScalarLike):
def _close(_idx):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ classifiers = [
"Topic :: Scientific/Engineering :: Mathematics",
]
urls = {repository = "https://github.com/patrick-kidger/diffrax" }
dependencies = ["jax>=0.4.23", "jaxtyping>=0.2.24", "typing_extensions>=4.5.0", "typeguard==2.13.3", "equinox>=0.11.2", "lineax>=0.0.5", "optimistix>=0.0.6"]
dependencies = ["jax>=0.4.23", "jaxtyping>=0.2.24", "typing_extensions>=4.5.0", "typeguard==2.13.3", "equinox>=0.11.2", "lineax>=0.0.5", "optimistix>=0.0.7"]

[build-system]
requires = ["hatchling"]
Expand Down
11 changes: 3 additions & 8 deletions test/test_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,10 @@ def test_basic(solver, t_dtype, y_dtype, treedef, stepsize_controller, getkey):
return

if jnp.iscomplexobj(y_dtype) and treedef != jtu.tree_structure(None):
if isinstance(solver, diffrax.AbstractImplicitSolver):
return
else:
complex_warn = pytest.warns(match="Complex dtype")
complex_warn = pytest.warns(match="Complex dtype")

def f(t, y, args):
return jtu.tree_map(lambda yi: -1j * yi, y)
def f(t, y, args):
return jtu.tree_map(lambda yi: -1j * yi, y)
else:
complex_warn = contextlib.nullcontext()

Expand Down Expand Up @@ -152,8 +149,6 @@ def test_ode_order(solver, dtype):

A = jr.normal(akey, (10, 10), dtype=dtype) * 0.5

if jnp.iscomplexobj(A) and isinstance(solver, diffrax.AbstractImplicitSolver):
return
if (
solver.term_structure
== diffrax.MultiTerm[tuple[diffrax.AbstractTerm, diffrax.AbstractTerm]]
Expand Down
5 changes: 1 addition & 4 deletions test/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,6 @@ def test_derivative(dtype, getkey):
solver = implicit_tol(solver)
y0 = jr.normal(getkey(), (3,), dtype=dtype)

if jnp.iscomplexobj(y0) and isinstance(solver, diffrax.AbstractImplicitSolver):
continue
solution = diffrax.diffeqsolve(
diffrax.ODETerm(lambda t, y, p: -y),
solver,
Expand All @@ -77,8 +75,7 @@ def test_derivative(dtype, getkey):
for solver in all_split_solvers:
solver = implicit_tol(solver)
y0 = jr.normal(getkey(), (3,), dtype=dtype)
if jnp.iscomplexobj(y0) and isinstance(solver, diffrax.AbstractImplicitSolver):
continue

solution = diffrax.diffeqsolve(
diffrax.MultiTerm(
diffrax.ODETerm(lambda t, y, p: -0.7 * y),
Expand Down

0 comments on commit 9743eb8

Please sign in to comment.