From 9743eb83056e697aa0117179ca1c6a723a8a2a20 Mon Sep 17 00:00:00 2001 From: Evgenii Zheltonozhskii Date: Mon, 13 May 2024 21:29:01 +0300 Subject: [PATCH] Enable implicit solvers for complex inputs (#411) * Enable implicit solvers for complex inputs * change version * make pyright happy --- diffrax/_integrate.py | 5 ----- diffrax/_progress_meter.py | 2 +- pyproject.toml | 2 +- test/test_integrate.py | 11 +++-------- test/test_interpolation.py | 5 +---- 5 files changed, 6 insertions(+), 19 deletions(-) diff --git a/diffrax/_integrate.py b/diffrax/_integrate.py index f052dcf0..91f327b0 100644 --- a/diffrax/_integrate.py +++ b/diffrax/_integrate.py @@ -709,11 +709,6 @@ def diffeqsolve( eqx.is_array(xi) and jnp.iscomplexobj(xi) for xi in jtu.tree_leaves((terms, y0, args)) ): - if isinstance(solver, AbstractImplicitSolver): - raise ValueError( - "Implicit solvers in conjunction with complex dtypes is currently not " - "supported." - ) warnings.warn( "Complex dtype support is work in progress, please read " "https://github.com/patrick-kidger/diffrax/pull/197 and proceed carefully.", diff --git a/diffrax/_progress_meter.py b/diffrax/_progress_meter.py index 9b8d0a1b..8a813be6 100644 --- a/diffrax/_progress_meter.py +++ b/diffrax/_progress_meter.py @@ -294,7 +294,7 @@ def _step(_progress, _idx): # Return the idx to thread the callbacks in the correct order. return _idx - return jax.pure_callback(_step, idx, progress, idx, vectorized=True) # pyright: ignore + return jax.pure_callback(_step, idx, progress, idx, vectorized=True) def close(self, close_bar: Callable[[Any], None], idx: IntScalarLike): def _close(_idx): diff --git a/pyproject.toml b/pyproject.toml index 6761ff74..f36b1ea1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ classifiers = [ "Topic :: Scientific/Engineering :: Mathematics", ] urls = {repository = "https://github.com/patrick-kidger/diffrax" } -dependencies = ["jax>=0.4.23", "jaxtyping>=0.2.24", "typing_extensions>=4.5.0", "typeguard==2.13.3", "equinox>=0.11.2", "lineax>=0.0.5", "optimistix>=0.0.6"] +dependencies = ["jax>=0.4.23", "jaxtyping>=0.2.24", "typing_extensions>=4.5.0", "typeguard==2.13.3", "equinox>=0.11.2", "lineax>=0.0.5", "optimistix>=0.0.7"] [build-system] requires = ["hatchling"] diff --git a/test/test_integrate.py b/test/test_integrate.py index cdbd3fab..984044d7 100644 --- a/test/test_integrate.py +++ b/test/test_integrate.py @@ -77,13 +77,10 @@ def test_basic(solver, t_dtype, y_dtype, treedef, stepsize_controller, getkey): return if jnp.iscomplexobj(y_dtype) and treedef != jtu.tree_structure(None): - if isinstance(solver, diffrax.AbstractImplicitSolver): - return - else: - complex_warn = pytest.warns(match="Complex dtype") + complex_warn = pytest.warns(match="Complex dtype") - def f(t, y, args): - return jtu.tree_map(lambda yi: -1j * yi, y) + def f(t, y, args): + return jtu.tree_map(lambda yi: -1j * yi, y) else: complex_warn = contextlib.nullcontext() @@ -152,8 +149,6 @@ def test_ode_order(solver, dtype): A = jr.normal(akey, (10, 10), dtype=dtype) * 0.5 - if jnp.iscomplexobj(A) and isinstance(solver, diffrax.AbstractImplicitSolver): - return if ( solver.term_structure == diffrax.MultiTerm[tuple[diffrax.AbstractTerm, diffrax.AbstractTerm]] diff --git a/test/test_interpolation.py b/test/test_interpolation.py index 9a085a9f..d299b090 100644 --- a/test/test_interpolation.py +++ b/test/test_interpolation.py @@ -60,8 +60,6 @@ def test_derivative(dtype, getkey): solver = implicit_tol(solver) y0 = jr.normal(getkey(), (3,), dtype=dtype) - if jnp.iscomplexobj(y0) and isinstance(solver, diffrax.AbstractImplicitSolver): - continue solution = diffrax.diffeqsolve( diffrax.ODETerm(lambda t, y, p: -y), solver, @@ -77,8 +75,7 @@ def test_derivative(dtype, getkey): for solver in all_split_solvers: solver = implicit_tol(solver) y0 = jr.normal(getkey(), (3,), dtype=dtype) - if jnp.iscomplexobj(y0) and isinstance(solver, diffrax.AbstractImplicitSolver): - continue + solution = diffrax.diffeqsolve( diffrax.MultiTerm( diffrax.ODETerm(lambda t, y, p: -0.7 * y),