From f91a7f28adbd5cc3ade478c33358bad322195bcd Mon Sep 17 00:00:00 2001 From: David Zwicker Date: Wed, 3 Apr 2024 17:45:10 +0200 Subject: [PATCH] Added Adams-Bashforth stepper --- pde/solvers/__init__.py | 3 + pde/solvers/adams_bashforth.py | 83 ++++++++++++++++++++ tests/solvers/test_adams_bashforth_solver.py | 25 ++++++ tests/solvers/test_generic_solvers.py | 9 ++- 4 files changed, 119 insertions(+), 1 deletion(-) create mode 100644 pde/solvers/adams_bashforth.py create mode 100644 tests/solvers/test_adams_bashforth_solver.py diff --git a/pde/solvers/__init__.py b/pde/solvers/__init__.py index be12029a..cf699ad1 100644 --- a/pde/solvers/__init__.py +++ b/pde/solvers/__init__.py @@ -9,6 +9,7 @@ ~explicit_mpi.ExplicitMPISolver ~implicit.ImplicitSolver ~crank_nicolson.CrankNicolsonSolver + ~adams_bashforth.AdamsBashforthSolver ~scipy.ScipySolver ~registered_solvers @@ -17,6 +18,7 @@ from typing import List +from .adams_bashforth import AdamsBashforthSolver from .controller import Controller from .crank_nicolson import CrankNicolsonSolver from .explicit import ExplicitSolver @@ -46,6 +48,7 @@ def registered_solvers() -> list[str]: "ExplicitSolver", "ImplicitSolver", "CrankNicolsonSolver", + "AdamsBashforthSolver", "ScipySolver", "registered_solvers", ] diff --git a/pde/solvers/adams_bashforth.py b/pde/solvers/adams_bashforth.py new file mode 100644 index 00000000..83ddc879 --- /dev/null +++ b/pde/solvers/adams_bashforth.py @@ -0,0 +1,83 @@ +""" +Defines an explicit solver supporting various methods + +.. codeauthor:: David Zwicker +""" + +from __future__ import annotations + +from typing import Callable + +import numba as nb +import numpy as np + +from ..fields.base import FieldBase +from ..tools.numba import jit +from .base import SolverBase + + +class AdamsBashforthSolver(SolverBase): + """solving partial differential equations using an Adams-Bashforth scheme""" + + name = "adams–bashforth" + + def _make_fixed_stepper( + self, state: FieldBase, dt: float + ) -> Callable[[np.ndarray, float, int], tuple[float, float]]: + """return a stepper function using an explicit scheme with fixed time steps + + Args: + state (:class:`~pde.fields.base.FieldBase`): + An example for the state from which the grid and other information can + be extracted + dt (float): + Time step of the explicit stepping + """ + if self.pde.is_sde: + raise NotImplementedError + + rhs_pde = self._make_pde_rhs(state, backend=self.backend) + modify_state_after_step = self._modify_state_after_step + modify_after_step = self._make_modify_after_step(state) + + def single_step( + state_data: np.ndarray, t: float, state_prev: np.ndarray + ) -> None: + """perform a single Adams-Bashforth step""" + rhs_prev = rhs_pde(state_prev, t - dt).copy() + rhs_cur = rhs_pde(state_data, t) + state_prev[:] = state_data # save the previous state + state_data += dt * (1.5 * rhs_cur - 0.5 * rhs_prev) + + # allocate memory to store the state of the previous time step + state_prev = np.empty_like(state.data) + init_state_prev = True + + if self._compiled: + sig_single_step = (nb.typeof(state.data), nb.double, nb.typeof(state_prev)) + single_step = jit(sig_single_step)(single_step) + + def fixed_stepper( + state_data: np.ndarray, t_start: float, steps: int + ) -> tuple[float, float]: + """perform `steps` steps with fixed time steps""" + nonlocal state_prev, init_state_prev + + if init_state_prev: + # initialize the state_prev with an estimate of the previous step + state_prev[:] = state_data - dt * rhs_pde(state_data, t_start) + init_state_prev = False + + modifications = 0.0 + for i in range(steps): + # calculate the right hand side + t = t_start + i * dt + single_step(state_data, t, state_prev) + if modify_state_after_step: + modifications += modify_after_step(state_data) + + return t + dt, modifications + + self._logger.info("Init explicit Adams-Bashforth stepper with dt=%g", dt) + + return fixed_stepper diff --git a/tests/solvers/test_adams_bashforth_solver.py b/tests/solvers/test_adams_bashforth_solver.py new file mode 100644 index 00000000..e1d7c8c6 --- /dev/null +++ b/tests/solvers/test_adams_bashforth_solver.py @@ -0,0 +1,25 @@ +""" +.. codeauthor:: David Zwicker +""" + +import numpy as np + +import pde + + +def test_adams_bashforth(): + """test the adams_bashforth method""" + eq = pde.PDE({"y": "y"}) + state = pde.ScalarField(pde.UnitGrid([1]), 1) + storage = pde.MemoryStorage() + eq.solve( + state, + t_range=2.1, + dt=0.5, + solver="adams–bashforth", + tracker=storage.tracker(0.5), + ) + np.testing.assert_allclose( + np.ravel([f.data for f in storage]), + [1, 13 / 8, 83 / 32, 529 / 128, 3371 / 512, 21481 / 2048], + ) diff --git a/tests/solvers/test_generic_solvers.py b/tests/solvers/test_generic_solvers.py index 91e5120b..d5d1303d 100644 --- a/tests/solvers/test_generic_solvers.py +++ b/tests/solvers/test_generic_solvers.py @@ -7,6 +7,7 @@ from pde import PDE, DiffusionPDE, FieldCollection, MemoryStorage, ScalarField, UnitGrid from pde.solvers import ( + AdamsBashforthSolver, Controller, CrankNicolsonSolver, ExplicitSolver, @@ -16,7 +17,13 @@ ) from pde.solvers.base import AdaptiveSolverBase -SOLVER_CLASSES = [ExplicitSolver, ImplicitSolver, CrankNicolsonSolver, ScipySolver] +SOLVER_CLASSES = [ + ExplicitSolver, + ImplicitSolver, + CrankNicolsonSolver, + AdamsBashforthSolver, + ScipySolver, +] def test_solver_registration():