Skip to content

Commit

Permalink
Added Adams-Bashforth stepper
Browse files Browse the repository at this point in the history
  • Loading branch information
david-zwicker committed Apr 3, 2024
1 parent b401b7d commit f91a7f2
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 1 deletion.
3 changes: 3 additions & 0 deletions pde/solvers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
~explicit_mpi.ExplicitMPISolver
~implicit.ImplicitSolver
~crank_nicolson.CrankNicolsonSolver
~adams_bashforth.AdamsBashforthSolver
~scipy.ScipySolver
~registered_solvers
Expand All @@ -17,6 +18,7 @@

from typing import List

from .adams_bashforth import AdamsBashforthSolver
from .controller import Controller
from .crank_nicolson import CrankNicolsonSolver
from .explicit import ExplicitSolver
Expand Down Expand Up @@ -46,6 +48,7 @@ def registered_solvers() -> list[str]:
"ExplicitSolver",
"ImplicitSolver",
"CrankNicolsonSolver",
"AdamsBashforthSolver",
"ScipySolver",
"registered_solvers",
]
83 changes: 83 additions & 0 deletions pde/solvers/adams_bashforth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
"""
Defines an explicit solver supporting various methods
.. codeauthor:: David Zwicker <[email protected]>
"""

from __future__ import annotations

from typing import Callable

import numba as nb
import numpy as np

from ..fields.base import FieldBase
from ..tools.numba import jit
from .base import SolverBase


class AdamsBashforthSolver(SolverBase):
"""solving partial differential equations using an Adams-Bashforth scheme"""

name = "adams–bashforth"

def _make_fixed_stepper(
self, state: FieldBase, dt: float
) -> Callable[[np.ndarray, float, int], tuple[float, float]]:
"""return a stepper function using an explicit scheme with fixed time steps
Args:
state (:class:`~pde.fields.base.FieldBase`):
An example for the state from which the grid and other information can
be extracted
dt (float):
Time step of the explicit stepping
"""
if self.pde.is_sde:
raise NotImplementedError

Check warning on line 37 in pde/solvers/adams_bashforth.py

View check run for this annotation

Codecov / codecov/patch

pde/solvers/adams_bashforth.py#L37

Added line #L37 was not covered by tests

rhs_pde = self._make_pde_rhs(state, backend=self.backend)
modify_state_after_step = self._modify_state_after_step
modify_after_step = self._make_modify_after_step(state)

def single_step(
state_data: np.ndarray, t: float, state_prev: np.ndarray
) -> None:
"""perform a single Adams-Bashforth step"""
rhs_prev = rhs_pde(state_prev, t - dt).copy()
rhs_cur = rhs_pde(state_data, t)
state_prev[:] = state_data # save the previous state
state_data += dt * (1.5 * rhs_cur - 0.5 * rhs_prev)

# allocate memory to store the state of the previous time step
state_prev = np.empty_like(state.data)
init_state_prev = True

if self._compiled:
sig_single_step = (nb.typeof(state.data), nb.double, nb.typeof(state_prev))
single_step = jit(sig_single_step)(single_step)

Check warning on line 58 in pde/solvers/adams_bashforth.py

View check run for this annotation

Codecov / codecov/patch

pde/solvers/adams_bashforth.py#L57-L58

Added lines #L57 - L58 were not covered by tests

def fixed_stepper(
state_data: np.ndarray, t_start: float, steps: int
) -> tuple[float, float]:
"""perform `steps` steps with fixed time steps"""
nonlocal state_prev, init_state_prev

if init_state_prev:
# initialize the state_prev with an estimate of the previous step
state_prev[:] = state_data - dt * rhs_pde(state_data, t_start)
init_state_prev = False

modifications = 0.0
for i in range(steps):
# calculate the right hand side
t = t_start + i * dt
single_step(state_data, t, state_prev)
if modify_state_after_step:
modifications += modify_after_step(state_data)

return t + dt, modifications

self._logger.info("Init explicit Adams-Bashforth stepper with dt=%g", dt)

return fixed_stepper
25 changes: 25 additions & 0 deletions tests/solvers/test_adams_bashforth_solver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""
.. codeauthor:: David Zwicker <[email protected]>
"""

import numpy as np

import pde


def test_adams_bashforth():
"""test the adams_bashforth method"""
eq = pde.PDE({"y": "y"})
state = pde.ScalarField(pde.UnitGrid([1]), 1)
storage = pde.MemoryStorage()
eq.solve(
state,
t_range=2.1,
dt=0.5,
solver="adams–bashforth",
tracker=storage.tracker(0.5),
)
np.testing.assert_allclose(
np.ravel([f.data for f in storage]),
[1, 13 / 8, 83 / 32, 529 / 128, 3371 / 512, 21481 / 2048],
)
9 changes: 8 additions & 1 deletion tests/solvers/test_generic_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from pde import PDE, DiffusionPDE, FieldCollection, MemoryStorage, ScalarField, UnitGrid
from pde.solvers import (
AdamsBashforthSolver,
Controller,
CrankNicolsonSolver,
ExplicitSolver,
Expand All @@ -16,7 +17,13 @@
)
from pde.solvers.base import AdaptiveSolverBase

SOLVER_CLASSES = [ExplicitSolver, ImplicitSolver, CrankNicolsonSolver, ScipySolver]
SOLVER_CLASSES = [
ExplicitSolver,
ImplicitSolver,
CrankNicolsonSolver,
AdamsBashforthSolver,
ScipySolver,
]


def test_solver_registration():
Expand Down

0 comments on commit f91a7f2

Please sign in to comment.