From a58b01209994ba73037540d498959a0d027e82d5 Mon Sep 17 00:00:00 2001 From: David Zwicker Date: Tue, 20 Aug 2024 14:28:21 +0200 Subject: [PATCH] Added test for PDEs with constant fields in MPI run (#602) * Added test for PDEs with constant fields in MPI run * Removed deprecated support for testing jupyter notebooks version 6 * Improved error message when numba-backend is not implemented. --- pde/pdes/base.py | 7 ++++- tests/pdes/test_pde_class.py | 1 + tests/pdes/test_pdes_mpi.py | 48 ++++++++++++++++++++++++++++---- tests/solvers/test_controller.py | 4 +-- tests/test_examples.py | 31 ++++----------------- 5 files changed, 57 insertions(+), 34 deletions(-) diff --git a/pde/pdes/base.py b/pde/pdes/base.py index a8cccf30..a91bdb2d 100644 --- a/pde/pdes/base.py +++ b/pde/pdes/base.py @@ -180,7 +180,12 @@ def _make_pde_rhs_numba( self, state: FieldBase, **kwargs ) -> Callable[[np.ndarray, float], np.ndarray]: """Create a compiled function for evaluating the right hand side.""" - raise NotImplementedError("No backend `numba`") + raise NotImplementedError( + "The right-hand side of the PDE is not implemented using the `numba` " + "backend. To add the implementation, provide the method " + "`_make_pde_rhs_numba`, which should return a numba-compiled function " + "calculating the right-hand side using numpy arrays as input and output." + ) def check_rhs_consistency( self, diff --git a/tests/pdes/test_pde_class.py b/tests/pdes/test_pde_class.py index 4796d26a..f064fd2b 100644 --- a/tests/pdes/test_pde_class.py +++ b/tests/pdes/test_pde_class.py @@ -4,6 +4,7 @@ import logging import os + import numba as nb import numpy as np import pytest diff --git a/tests/pdes/test_pdes_mpi.py b/tests/pdes/test_pdes_mpi.py index c65a4fd9..d500a18f 100644 --- a/tests/pdes/test_pdes_mpi.py +++ b/tests/pdes/test_pdes_mpi.py @@ -7,7 +7,9 @@ from pde import PDE, DiffusionPDE, grids from pde.fields import ScalarField, VectorField +from pde.pdes.base import PDEBase from pde.tools import mpi +from pde.tools.numba import jit @pytest.mark.multiprocessing @@ -31,6 +33,8 @@ def test_pde_complex_bcs_mpi(dim, backend, rng): res_exp = eq.solve(backend="numpy", solver="explicit", **args) res_exp.assert_field_compatible(res) np.testing.assert_allclose(res_exp.data, res.data) + else: + assert res is None @pytest.mark.multiprocessing @@ -87,6 +91,9 @@ def test_pde_complex_mpi(rng): assert res2.is_complex np.testing.assert_allclose(res2.data, expect.data) assert info2["solver"]["steps"] == 11 + else: + assert res1 is None + assert res2 is None @pytest.mark.multiprocessing @@ -96,15 +103,44 @@ def test_pde_const_mpi(backend): grid = grids.UnitGrid([8]) eq = PDE({"u": "k"}, consts={"k": ScalarField.from_expression(grid, "x")}) - args = { - "state": ScalarField(grid), - "t_range": 1, - "dt": 0.01, - "tracker": None, - } + args = {"state": ScalarField(grid), "t_range": 1, "dt": 0.01, "tracker": None} + res_a = eq.solve(backend="numpy", solver="explicit", **args) + res_b = eq.solve(backend=backend, solver="explicit_mpi", **args) + + if mpi.is_main: + res_a.assert_field_compatible(res_b) + np.testing.assert_allclose(res_a.data, res_b.data) + else: + assert res_b is None + + +@pytest.mark.multiprocessing +@pytest.mark.parametrize("backend", ["numpy", "numba"]) +def test_pde_const_mpi_class(backend): + """Test PDE with a field constant using multiprocessing.""" + grid = grids.UnitGrid([8]) + + class ExplicitFieldPDE(PDEBase): + def evolution_rate(self, state, t): + return ScalarField(state.grid, state.grid.axes_coords[0]) + + def _make_pde_rhs_numba(self, state): + x = state.grid.axes_coords[0] + + @jit + def pde_rhs(state_data, t): + return x + + return pde_rhs + + eq = ExplicitFieldPDE() + + args = {"state": ScalarField(grid), "t_range": 1, "dt": 0.01, "tracker": None} res_a = eq.solve(backend="numpy", solver="explicit", **args) res_b = eq.solve(backend=backend, solver="explicit_mpi", **args) if mpi.is_main: res_a.assert_field_compatible(res_b) np.testing.assert_allclose(res_a.data, res_b.data) + else: + assert res_b is None diff --git a/tests/solvers/test_controller.py b/tests/solvers/test_controller.py index b79cd604..bda37f20 100644 --- a/tests/solvers/test_controller.py +++ b/tests/solvers/test_controller.py @@ -2,8 +2,8 @@ .. codeauthor:: David Zwicker """ -import pytest import numpy as np +import pytest from pde import PDEBase, ScalarField, UnitGrid from pde.solvers import Controller @@ -32,7 +32,7 @@ def evolution_rate(self, state, t): def test_controller_foreign_solver(): - """Test whether the Controller can deal with a minimal foreign solver""" + """Test whether the Controller can deal with a minimal foreign solver.""" class MySolver: def make_stepper(self, state, dt): diff --git a/tests/test_examples.py b/tests/test_examples.py index 8d9cbab0..9e530121 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -79,34 +79,15 @@ def test_jupyter_notebooks(path, tmp_path): """Run the jupyter notebooks.""" import notebook as jupyter_notebook + if int(jupyter_notebook.__version__.split(".")[0]) < 7: + raise RuntimeError("Jupyter notebooks must be at least version 7") + if path.name.startswith("_"): - pytest.skip("skip examples starting with an underscore") + pytest.skip("Skip examples starting with an underscore") # adjust python environment my_env = os.environ.copy() my_env["PYTHONPATH"] = str(PACKAGE_PATH) + ":" + my_env.get("PYTHONPATH", "") - outfile = tmp_path / path.name - if jupyter_notebook.__version__.startswith("6"): - # older version of running jypyter notebook - # deprecated on 2023-07-31 - # in the future, the `notebook` package should be at least version 7 - sp.check_call( - [ - sys.executable, - "-m", - "jupyter", - "nbconvert", - "--ExecutePreprocessor.timeout=600", - "--to", - "notebook", - "--output", - outfile, - "--execute", - path, - ], - env=my_env, - ) - else: - # run the notebook - sp.check_call([sys.executable, "-m", "jupyter", "execute", path], env=my_env) + # run the notebook + sp.check_call([sys.executable, "-m", "jupyter", "execute", path], env=my_env)