diff --git a/docs/source/manual/advanced_usage.rst b/docs/source/manual/advanced_usage.rst index b3e0c151..c2569d3c 100644 --- a/docs/source/manual/advanced_usage.rst +++ b/docs/source/manual/advanced_usage.rst @@ -207,7 +207,7 @@ A simple implementation for the Kuramoto–Sivashinsky equation could read class KuramotoSivashinskyPDE(PDEBase): def evolution_rate(self, state, t=0): - """ numpy implementation of the evolution equation """ + """Evaluate the right hand side of the evolution equation.""" state_lapacian = state.laplace(bc="auto_periodic_neumann") state_gradient = state.gradient(bc="auto_periodic_neumann") return (- state_lapacian.laplace(bc="auto_periodic_neumann") @@ -222,14 +222,13 @@ instance define the boundary conditions and the diffusivity: class KuramotoSivashinskyPDE(PDEBase): def __init__(self, diffusivity=1, bc="auto_periodic_neumann", bc_laplace="auto_periodic_neumann"): - """ initialize the class with a diffusivity and boundary conditions - for the actual field and its second derivative """ + """Initialize the class with a diffusivity and boundary conditions.""" self.diffusivity = diffusivity self.bc = bc self.bc_laplace = bc_laplace def evolution_rate(self, state, t=0): - """ numpy implementation of the evolution equation """ + """Evaluate the right hand side of the evolution equation.""" state_lapacian = state.laplace(bc=self.bc) state_gradient = state.gradient(bc=self.bc) return (- state_lapacian.laplace(bc=self.bc_laplace) @@ -245,6 +244,37 @@ While such an implementation is helpful for testing initial ideas, actual computations should be performed with compiled PDEs as described below. +Another feature of custom PDE classes are a special function that is called after every +time step. This function allows direct manipulation of the state data and also abortion +of the simulation by raising :class:`StopIteration`. + +.. code-block:: python + + class AbortEarlyPDE(PDEBase): + + def make_post_step_hook(self, state): + """Create a hook function that is called after every time step.""" + + def post_step_hook(state_data, t): + """Limit state to [-1, 1] & abort when standard deviation exceeds 1.""" + np.clip(state_data, -1, 1, out=state_data) # limit state + if state_data.std() > 1: + raise StopIteration # abort simulation + return 1 # count the number of times the hook was called + + return post_step_hook + + def evolution_rate(self, state, t=0): + """Evaluate the right hand side of the evolution equation.""" + return state + +We here use a simple constant evolution equation. The hook defined by the first method +does two things: First, it limits the state to the interval `[-1, 1]` using +:func:`numpy.clip`. Second, it evaluates the standard deviation across the entire data, +aborting the simulation when the value exceeds one. Note that the hook always receives +the data always as a :class:`~numpy.ndarray` and not as a full field class. + + Low-level operators """"""""""""""""""" This section explains how to use the low-level version of the field operators. diff --git a/pde/pdes/base.py b/pde/pdes/base.py index e26ec59e..dd7de0ee 100644 --- a/pde/pdes/base.py +++ b/pde/pdes/base.py @@ -16,7 +16,7 @@ from ..fields.base import FieldBase from ..fields.datafield_base import DataFieldBase from ..tools.numba import jit -from ..tools.typing import ArrayLike +from ..tools.typing import ArrayLike, StepperHook from ..trackers.base import TrackerCollectionDataType if TYPE_CHECKING: @@ -110,14 +110,14 @@ def is_sde(self) -> bool: # check for self.noise, in case __init__ is not called in a subclass return hasattr(self, "noise") and np.any(self.noise != 0) # type: ignore - def make_post_step_hook(self, state: FieldBase) -> Callable[[np.ndarray], float]: + def make_post_step_hook(self, state: FieldBase) -> StepperHook: """Returns a function that is called after each step. - This function receives the current state as a numpy array and can modify the - data in place. The function must return a float value, which normally indicates - how much the state was modified. This value does not affect the simulation, but - is accumulated over time and provided in the diagnostic information for - debugging. + This function receives the current state as a numpy array together with the + current time point. The function can modify the state data in place. The + function must return a float value, which normally indicates how much the state + was modified. This value does not affect the simulation, but is accumulated over + time and provided in the diagnostic information for debugging. The hook can also be used to abort the simulation when a user-defined condition is met by raising `StopIteration`. Note that this interrupts the inner-most loop @@ -135,7 +135,7 @@ def make_post_step_hook(self, state: FieldBase) -> Callable[[np.ndarray], float] measure for the corrections applied to the state """ - def post_step_hook(state_data: np.ndarray) -> float: + def post_step_hook(state_data: np.ndarray, t: float) -> float: """No-op function.""" return 0 diff --git a/pde/solvers/adams_bashforth.py b/pde/solvers/adams_bashforth.py index 072f11bc..ec9badcd 100644 --- a/pde/solvers/adams_bashforth.py +++ b/pde/solvers/adams_bashforth.py @@ -71,7 +71,7 @@ def fixed_stepper( # calculate the right hand side t = t_start + i * dt single_step(state_data, t, state_prev) - modifications += post_step_hook(state_data) + modifications += post_step_hook(state_data, t) return t + dt, modifications diff --git a/pde/solvers/base.py b/pde/solvers/base.py index 2145aee3..9e224c90 100644 --- a/pde/solvers/base.py +++ b/pde/solvers/base.py @@ -23,7 +23,7 @@ from ..tools.math import OnlineStatistics from ..tools.misc import classproperty from ..tools.numba import is_jitted, jit -from ..tools.typing import BackendType +from ..tools.typing import BackendType, StepperHook class ConvergenceError(RuntimeError): @@ -119,7 +119,7 @@ def _compiled(self) -> bool: self.backend == "numba" and not nb.config.DISABLE_JIT ) # @UndefinedVariable - def _make_post_step_hook(self, state: FieldBase) -> Callable[[np.ndarray], float]: + def _make_post_step_hook(self, state: FieldBase) -> StepperHook: """Create a function that modifies a state after each step. A noop function will be returned if `_post_step_hook` is `False`, @@ -133,23 +133,31 @@ def _make_post_step_hook(self, state: FieldBase) -> Callable[[np.ndarray], float if hasattr(self.pde, "make_modify_after_step"): # Deprecated on 2024-08-02 warnings.warn( - "`make_modify_after_step` has been renamed to `make_post_step_hook`", + "`make_modify_after_step` has been replaced by `make_post_step_hook`", DeprecationWarning, ) - post_step_hook = jit(self.pde.make_modify_after_step(state)) + modify_after_step = self.pde.make_modify_after_step(state) + if self._compiled: + sig_modify = (nb.typeof(state.data),) + modify_after_step = jit(sig_modify)(modify_after_step) + + def post_step_hook(state_data: np.ndarray, t: float) -> float: + """Wrap function to adjust signature.""" + return modify_after_step(state_data) # type: ignore + else: - post_step_hook = jit(self.pde.make_post_step_hook(state)) + post_step_hook = self.pde.make_post_step_hook(state) else: - def post_step_hook(state_data: np.ndarray) -> float: - return 0 + def post_step_hook(state_data: np.ndarray, t: float) -> float: + return 0.0 if self._compiled: - sig_modify = (nb.typeof(state.data),) - post_step_hook = jit(sig_modify)(post_step_hook) + sig_hook = (nb.typeof(state.data), nb.float64) + post_step_hook = jit(sig_hook)(post_step_hook) - return post_step_hook # type: ignore + return post_step_hook def _make_pde_rhs( self, state: FieldBase, backend: BackendType = "auto" @@ -266,7 +274,7 @@ def fixed_stepper( # calculate the right hand side t = t_start + i * dt single_step(state_data, t) - modifications += post_step_hook(state_data) + modifications += post_step_hook(state_data, t) return t + dt, modifications @@ -538,7 +546,7 @@ def adaptive_stepper( steps += 1 t += dt_step state_data[...] = new_state - modifications += post_step_hook(state_data) + modifications += post_step_hook(state_data, t) if dt_stats is not None: dt_stats.add(dt_step) diff --git a/pde/solvers/explicit.py b/pde/solvers/explicit.py index d575a5ba..567522f7 100644 --- a/pde/solvers/explicit.py +++ b/pde/solvers/explicit.py @@ -240,7 +240,7 @@ def adaptive_stepper( steps += 1 t += dt_step state_data[...] = step_small - modifications += post_step_hook(state_data) + modifications += post_step_hook(state_data, t) if dt_stats is not None: dt_stats.add(dt_step) diff --git a/pde/tools/typing.py b/pde/tools/typing.py index 7ea328ed..253dde8a 100644 --- a/pde/tools/typing.py +++ b/pde/tools/typing.py @@ -54,3 +54,8 @@ def __call__( class GhostCellSetter(Protocol): def __call__(self, data_full: np.ndarray, args=None) -> None: """Set the ghost cells.""" + + +class StepperHook(Protocol): + def __call__(self, state_data: np.ndarray, t: float) -> float: + """Function analyzing and potentially modifying the current state.""" diff --git a/tests/test_integration.py b/tests/test_integration.py index 3c7792ca..654b5245 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -134,7 +134,7 @@ def test_stop_iteration_hook(backend): class TestPDE(PDEBase): def make_post_step_hook(self, state): - def post_step_hook(state_data): + def post_step_hook(state_data, t): if state_data.sum() > 1: raise StopIteration return 1 # count the number of times the hook was called