Skip to content

Commit

Permalink
Changed post_step_hook such that it also receives the current time (#585
Browse files Browse the repository at this point in the history
)

* Changed post_step_hook such that it also receives the current time
* Improved documentation
  • Loading branch information
david-zwicker authored Aug 2, 2024
1 parent 1756d65 commit dbbc5be
Show file tree
Hide file tree
Showing 7 changed files with 70 additions and 27 deletions.
38 changes: 34 additions & 4 deletions docs/source/manual/advanced_usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ A simple implementation for the Kuramoto–Sivashinsky equation could read
class KuramotoSivashinskyPDE(PDEBase):
def evolution_rate(self, state, t=0):
""" numpy implementation of the evolution equation """
"""Evaluate the right hand side of the evolution equation."""
state_lapacian = state.laplace(bc="auto_periodic_neumann")
state_gradient = state.gradient(bc="auto_periodic_neumann")
return (- state_lapacian.laplace(bc="auto_periodic_neumann")
Expand All @@ -222,14 +222,13 @@ instance define the boundary conditions and the diffusivity:
class KuramotoSivashinskyPDE(PDEBase):
def __init__(self, diffusivity=1, bc="auto_periodic_neumann", bc_laplace="auto_periodic_neumann"):
""" initialize the class with a diffusivity and boundary conditions
for the actual field and its second derivative """
"""Initialize the class with a diffusivity and boundary conditions."""
self.diffusivity = diffusivity
self.bc = bc
self.bc_laplace = bc_laplace
def evolution_rate(self, state, t=0):
""" numpy implementation of the evolution equation """
"""Evaluate the right hand side of the evolution equation."""
state_lapacian = state.laplace(bc=self.bc)
state_gradient = state.gradient(bc=self.bc)
return (- state_lapacian.laplace(bc=self.bc_laplace)
Expand All @@ -245,6 +244,37 @@ While such an implementation is helpful for testing initial ideas, actual
computations should be performed with compiled PDEs as described below.


Another feature of custom PDE classes are a special function that is called after every
time step. This function allows direct manipulation of the state data and also abortion
of the simulation by raising :class:`StopIteration`.

.. code-block:: python
class AbortEarlyPDE(PDEBase):
def make_post_step_hook(self, state):
"""Create a hook function that is called after every time step."""
def post_step_hook(state_data, t):
"""Limit state to [-1, 1] & abort when standard deviation exceeds 1."""
np.clip(state_data, -1, 1, out=state_data) # limit state
if state_data.std() > 1:
raise StopIteration # abort simulation
return 1 # count the number of times the hook was called
return post_step_hook
def evolution_rate(self, state, t=0):
"""Evaluate the right hand side of the evolution equation."""
return state
We here use a simple constant evolution equation. The hook defined by the first method
does two things: First, it limits the state to the interval `[-1, 1]` using
:func:`numpy.clip`. Second, it evaluates the standard deviation across the entire data,
aborting the simulation when the value exceeds one. Note that the hook always receives
the data always as a :class:`~numpy.ndarray` and not as a full field class.


Low-level operators
"""""""""""""""""""
This section explains how to use the low-level version of the field operators.
Expand Down
16 changes: 8 additions & 8 deletions pde/pdes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ..fields.base import FieldBase
from ..fields.datafield_base import DataFieldBase
from ..tools.numba import jit
from ..tools.typing import ArrayLike
from ..tools.typing import ArrayLike, StepperHook
from ..trackers.base import TrackerCollectionDataType

if TYPE_CHECKING:
Expand Down Expand Up @@ -110,14 +110,14 @@ def is_sde(self) -> bool:
# check for self.noise, in case __init__ is not called in a subclass
return hasattr(self, "noise") and np.any(self.noise != 0) # type: ignore

def make_post_step_hook(self, state: FieldBase) -> Callable[[np.ndarray], float]:
def make_post_step_hook(self, state: FieldBase) -> StepperHook:
"""Returns a function that is called after each step.
This function receives the current state as a numpy array and can modify the
data in place. The function must return a float value, which normally indicates
how much the state was modified. This value does not affect the simulation, but
is accumulated over time and provided in the diagnostic information for
debugging.
This function receives the current state as a numpy array together with the
current time point. The function can modify the state data in place. The
function must return a float value, which normally indicates how much the state
was modified. This value does not affect the simulation, but is accumulated over
time and provided in the diagnostic information for debugging.
The hook can also be used to abort the simulation when a user-defined condition
is met by raising `StopIteration`. Note that this interrupts the inner-most loop
Expand All @@ -135,7 +135,7 @@ def make_post_step_hook(self, state: FieldBase) -> Callable[[np.ndarray], float]
measure for the corrections applied to the state
"""

def post_step_hook(state_data: np.ndarray) -> float:
def post_step_hook(state_data: np.ndarray, t: float) -> float:
"""No-op function."""
return 0

Expand Down
2 changes: 1 addition & 1 deletion pde/solvers/adams_bashforth.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def fixed_stepper(
# calculate the right hand side
t = t_start + i * dt
single_step(state_data, t, state_prev)
modifications += post_step_hook(state_data)
modifications += post_step_hook(state_data, t)

return t + dt, modifications

Expand Down
32 changes: 20 additions & 12 deletions pde/solvers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ..tools.math import OnlineStatistics
from ..tools.misc import classproperty
from ..tools.numba import is_jitted, jit
from ..tools.typing import BackendType
from ..tools.typing import BackendType, StepperHook


class ConvergenceError(RuntimeError):
Expand Down Expand Up @@ -119,7 +119,7 @@ def _compiled(self) -> bool:
self.backend == "numba" and not nb.config.DISABLE_JIT
) # @UndefinedVariable

def _make_post_step_hook(self, state: FieldBase) -> Callable[[np.ndarray], float]:
def _make_post_step_hook(self, state: FieldBase) -> StepperHook:
"""Create a function that modifies a state after each step.
A noop function will be returned if `_post_step_hook` is `False`,
Expand All @@ -133,23 +133,31 @@ def _make_post_step_hook(self, state: FieldBase) -> Callable[[np.ndarray], float
if hasattr(self.pde, "make_modify_after_step"):
# Deprecated on 2024-08-02
warnings.warn(
"`make_modify_after_step` has been renamed to `make_post_step_hook`",
"`make_modify_after_step` has been replaced by `make_post_step_hook`",
DeprecationWarning,
)
post_step_hook = jit(self.pde.make_modify_after_step(state))
modify_after_step = self.pde.make_modify_after_step(state)
if self._compiled:
sig_modify = (nb.typeof(state.data),)
modify_after_step = jit(sig_modify)(modify_after_step)

def post_step_hook(state_data: np.ndarray, t: float) -> float:
"""Wrap function to adjust signature."""
return modify_after_step(state_data) # type: ignore

else:
post_step_hook = jit(self.pde.make_post_step_hook(state))
post_step_hook = self.pde.make_post_step_hook(state)

else:

def post_step_hook(state_data: np.ndarray) -> float:
return 0
def post_step_hook(state_data: np.ndarray, t: float) -> float:
return 0.0

if self._compiled:
sig_modify = (nb.typeof(state.data),)
post_step_hook = jit(sig_modify)(post_step_hook)
sig_hook = (nb.typeof(state.data), nb.float64)
post_step_hook = jit(sig_hook)(post_step_hook)

return post_step_hook # type: ignore
return post_step_hook

def _make_pde_rhs(
self, state: FieldBase, backend: BackendType = "auto"
Expand Down Expand Up @@ -266,7 +274,7 @@ def fixed_stepper(
# calculate the right hand side
t = t_start + i * dt
single_step(state_data, t)
modifications += post_step_hook(state_data)
modifications += post_step_hook(state_data, t)

return t + dt, modifications

Expand Down Expand Up @@ -538,7 +546,7 @@ def adaptive_stepper(
steps += 1
t += dt_step
state_data[...] = new_state
modifications += post_step_hook(state_data)
modifications += post_step_hook(state_data, t)
if dt_stats is not None:
dt_stats.add(dt_step)

Expand Down
2 changes: 1 addition & 1 deletion pde/solvers/explicit.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def adaptive_stepper(
steps += 1
t += dt_step
state_data[...] = step_small
modifications += post_step_hook(state_data)
modifications += post_step_hook(state_data, t)
if dt_stats is not None:
dt_stats.add(dt_step)

Expand Down
5 changes: 5 additions & 0 deletions pde/tools/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,8 @@ def __call__(
class GhostCellSetter(Protocol):
def __call__(self, data_full: np.ndarray, args=None) -> None:
"""Set the ghost cells."""


class StepperHook(Protocol):
def __call__(self, state_data: np.ndarray, t: float) -> float:
"""Function analyzing and potentially modifying the current state."""
2 changes: 1 addition & 1 deletion tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def test_stop_iteration_hook(backend):

class TestPDE(PDEBase):
def make_post_step_hook(self, state):
def post_step_hook(state_data):
def post_step_hook(state_data, t):
if state_data.sum() > 1:
raise StopIteration
return 1 # count the number of times the hook was called
Expand Down

0 comments on commit dbbc5be

Please sign in to comment.