Skip to content

Commit

Permalink
Added test for PDEs with constant fields in MPI run (#602)
Browse files Browse the repository at this point in the history
* Added test for PDEs with constant fields in MPI run
* Removed deprecated support for testing jupyter notebooks version 6
* Improved error message when numba-backend is not implemented.
  • Loading branch information
david-zwicker authored Aug 20, 2024
1 parent 9983abb commit a58b012
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 34 deletions.
7 changes: 6 additions & 1 deletion pde/pdes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,12 @@ def _make_pde_rhs_numba(
self, state: FieldBase, **kwargs
) -> Callable[[np.ndarray, float], np.ndarray]:
"""Create a compiled function for evaluating the right hand side."""
raise NotImplementedError("No backend `numba`")
raise NotImplementedError(
"The right-hand side of the PDE is not implemented using the `numba` "
"backend. To add the implementation, provide the method "
"`_make_pde_rhs_numba`, which should return a numba-compiled function "
"calculating the right-hand side using numpy arrays as input and output."
)

def check_rhs_consistency(
self,
Expand Down
1 change: 1 addition & 0 deletions tests/pdes/test_pde_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import logging
import os

import numba as nb
import numpy as np
import pytest
Expand Down
48 changes: 42 additions & 6 deletions tests/pdes/test_pdes_mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@

from pde import PDE, DiffusionPDE, grids
from pde.fields import ScalarField, VectorField
from pde.pdes.base import PDEBase
from pde.tools import mpi
from pde.tools.numba import jit


@pytest.mark.multiprocessing
Expand All @@ -31,6 +33,8 @@ def test_pde_complex_bcs_mpi(dim, backend, rng):
res_exp = eq.solve(backend="numpy", solver="explicit", **args)
res_exp.assert_field_compatible(res)
np.testing.assert_allclose(res_exp.data, res.data)
else:
assert res is None


@pytest.mark.multiprocessing
Expand Down Expand Up @@ -87,6 +91,9 @@ def test_pde_complex_mpi(rng):
assert res2.is_complex
np.testing.assert_allclose(res2.data, expect.data)
assert info2["solver"]["steps"] == 11
else:
assert res1 is None
assert res2 is None


@pytest.mark.multiprocessing
Expand All @@ -96,15 +103,44 @@ def test_pde_const_mpi(backend):
grid = grids.UnitGrid([8])
eq = PDE({"u": "k"}, consts={"k": ScalarField.from_expression(grid, "x")})

args = {
"state": ScalarField(grid),
"t_range": 1,
"dt": 0.01,
"tracker": None,
}
args = {"state": ScalarField(grid), "t_range": 1, "dt": 0.01, "tracker": None}
res_a = eq.solve(backend="numpy", solver="explicit", **args)
res_b = eq.solve(backend=backend, solver="explicit_mpi", **args)

if mpi.is_main:
res_a.assert_field_compatible(res_b)
np.testing.assert_allclose(res_a.data, res_b.data)
else:
assert res_b is None


@pytest.mark.multiprocessing
@pytest.mark.parametrize("backend", ["numpy", "numba"])
def test_pde_const_mpi_class(backend):
"""Test PDE with a field constant using multiprocessing."""
grid = grids.UnitGrid([8])

class ExplicitFieldPDE(PDEBase):
def evolution_rate(self, state, t):
return ScalarField(state.grid, state.grid.axes_coords[0])

def _make_pde_rhs_numba(self, state):
x = state.grid.axes_coords[0]

@jit
def pde_rhs(state_data, t):
return x

return pde_rhs

eq = ExplicitFieldPDE()

args = {"state": ScalarField(grid), "t_range": 1, "dt": 0.01, "tracker": None}
res_a = eq.solve(backend="numpy", solver="explicit", **args)
res_b = eq.solve(backend=backend, solver="explicit_mpi", **args)

if mpi.is_main:
res_a.assert_field_compatible(res_b)
np.testing.assert_allclose(res_a.data, res_b.data)
else:
assert res_b is None
4 changes: 2 additions & 2 deletions tests/solvers/test_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
.. codeauthor:: David Zwicker <[email protected]>
"""

import pytest
import numpy as np
import pytest

from pde import PDEBase, ScalarField, UnitGrid
from pde.solvers import Controller
Expand Down Expand Up @@ -32,7 +32,7 @@ def evolution_rate(self, state, t):


def test_controller_foreign_solver():
"""Test whether the Controller can deal with a minimal foreign solver"""
"""Test whether the Controller can deal with a minimal foreign solver."""

class MySolver:
def make_stepper(self, state, dt):
Expand Down
31 changes: 6 additions & 25 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,34 +79,15 @@ def test_jupyter_notebooks(path, tmp_path):
"""Run the jupyter notebooks."""
import notebook as jupyter_notebook

if int(jupyter_notebook.__version__.split(".")[0]) < 7:
raise RuntimeError("Jupyter notebooks must be at least version 7")

if path.name.startswith("_"):
pytest.skip("skip examples starting with an underscore")
pytest.skip("Skip examples starting with an underscore")

# adjust python environment
my_env = os.environ.copy()
my_env["PYTHONPATH"] = str(PACKAGE_PATH) + ":" + my_env.get("PYTHONPATH", "")

outfile = tmp_path / path.name
if jupyter_notebook.__version__.startswith("6"):
# older version of running jypyter notebook
# deprecated on 2023-07-31
# in the future, the `notebook` package should be at least version 7
sp.check_call(
[
sys.executable,
"-m",
"jupyter",
"nbconvert",
"--ExecutePreprocessor.timeout=600",
"--to",
"notebook",
"--output",
outfile,
"--execute",
path,
],
env=my_env,
)
else:
# run the notebook
sp.check_call([sys.executable, "-m", "jupyter", "execute", path], env=my_env)
# run the notebook
sp.check_call([sys.executable, "-m", "jupyter", "execute", path], env=my_env)

0 comments on commit a58b012

Please sign in to comment.