Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Made post_step_hook more flexible. #586

Merged
merged 1 commit into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions docs/source/manual/advanced_usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -243,10 +243,10 @@ field classes.
While such an implementation is helpful for testing initial ideas, actual
computations should be performed with compiled PDEs as described below.


Another feature of custom PDE classes are a special function that is called after every
time step. This function allows direct manipulation of the state data and also abortion
of the simulation by raising :class:`StopIteration`.
Another feature of custom PDE classes is a special function that is called after every
time step. This function is defined by :meth:`~pde.pdes.PDEBase.make_post_step_hook` and
allows direct manipulation of the state data and also abortion of the simulation by
raising :class:`StopIteration`.

.. code-block:: python

Expand All @@ -255,14 +255,14 @@ of the simulation by raising :class:`StopIteration`.
def make_post_step_hook(self, state):
"""Create a hook function that is called after every time step."""

def post_step_hook(state_data, t):
def post_step_hook(state_data, t, post_step_data):
"""Limit state to [-1, 1] & abort when standard deviation exceeds 1."""
np.clip(state_data, -1, 1, out=state_data) # limit state
if state_data.std() > 1:
raise StopIteration # abort simulation
return 1 # count the number of times the hook was called
post_step_data += 1 # increment number of times hook was called

return post_step_hook
return post_step_hook, 0 # hook function and initial value for data

def evolution_rate(self, state, t=0):
"""Evaluate the right hand side of the evolution equation."""
Expand All @@ -272,7 +272,9 @@ We here use a simple constant evolution equation. The hook defined by the first
does two things: First, it limits the state to the interval `[-1, 1]` using
:func:`numpy.clip`. Second, it evaluates the standard deviation across the entire data,
aborting the simulation when the value exceeds one. Note that the hook always receives
the data always as a :class:`~numpy.ndarray` and not as a full field class.
the data always as a :class:`~numpy.ndarray` and not as a full field class. The hook can
also keep track of additional data via :code:`post_step_data`, which is a
:class:`~numpy.ndarray` that can be updated in place.


Low-level operators
Expand Down
41 changes: 41 additions & 0 deletions examples/post_step_hook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""
Post-step hook function
=======================

The hook function created by :meth:`~pde.pdes.PDEBase.make_post_step_hook` is called
after each time step. The function can modify the state, keep track of additional
information and abort the simulation.
"""

from pde import PDEBase, ScalarField, UnitGrid


class CustomPDE(PDEBase):

def make_post_step_hook(self, state):
"""Create a hook function that is called after every time step."""

def post_step_hook(state_data, t, post_step_data):
"""Limit state 1 and abort when standard deviation exceeds 1."""
i = state_data > 1 # get violating entries
overshoot = (state_data[i] - 1).sum() # get total correction
state_data[i] = 1 # limit data entries
post_step_data += overshoot # accumulate total correction
if post_step_data > 400:
# Abort simulation when correction exceeds 400
# Note that the `post_step_data` of the previous step will be returned.
raise StopIteration

return post_step_hook, 0.0 # hook function and initial value for data

def evolution_rate(self, state, t=0):
"""Evaluate the right hand side of the evolution equation."""
return state.__class__(state.grid, data=1) # constant growth


grid = UnitGrid([64, 64]) # generate grid
state = ScalarField.random_uniform(grid, 0.0, 0.5) # generate initial condition

eq = CustomPDE()
result = eq.solve(state, dt=0.1, t_range=1e4)
result.plot(title=f"Total correction={eq.diagnostics['solver']['post_step_data']}")
50 changes: 32 additions & 18 deletions pde/pdes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,36 +110,50 @@ def is_sde(self) -> bool:
# check for self.noise, in case __init__ is not called in a subclass
return hasattr(self, "noise") and np.any(self.noise != 0) # type: ignore

def make_post_step_hook(self, state: FieldBase) -> StepperHook:
def make_post_step_hook(self, state: FieldBase) -> tuple[StepperHook, Any]:
"""Returns a function that is called after each step.

This function receives the current state as a numpy array together with the
current time point. The function can modify the state data in place. The
function must return a float value, which normally indicates how much the state
was modified. This value does not affect the simulation, but is accumulated over
time and provided in the diagnostic information for debugging.
This function receives three arugments: the current state as a numpy array, the
current time point, and a numpy array that can store data for the hook function.
The function can modify the state data in place. If the function makes use of
the data feature, it must replace the data in place.

The hook can also be used to abort the simulation when a user-defined condition
is met by raising `StopIteration`. Note that this interrupts the inner-most loop
and some of the final information (including the stop time and the total
modifications made by the hook) might be incorrect. These fields will usually
still reflect the values they assumed at the last tracker interrupt.
is met by raising `StopIteration`. Note that this interrupts the inner-most
loop, so that some final information might be still reflect the values they
assumed at the last tracker interrupt. Additional information (beside the
current state) should be returned by the 1post_step_data1.

Example:
The following code provides an example that creates a hook function that
limits the state to a maximal value of 1 and keeps track of the total
correction that is applied. This is achieved using `post_step_data`, which
is initialized with the second value (0) returned by the method and
incremented each time the hook is called.

.. code-block:: python

def make_post_step_hook(self, state):

def post_step_hook(state_data, t, post_step_data):
i = state_data > 1 # get violating entries
overshoot = (state_data[i] - 1).sum() # get total correction
state_data[i] = 1 # limit data entries
post_step_data += overshoot # accumulate total correction

return post_step_hook, 0. # hook function and initial value

Args:
state (:class:`~pde.fields.FieldBase`):
An example for the state from which the grid and other information can
be extracted

Returns:
Function that can be applied to a state to modify it and which returns a
measure for the corrections applied to the state
tuple: The first entry is the function that implements the hook. The second
entry gives the initial data that is used as auxiallary data in the hook.
This can be `None` if no data is used.
"""

def post_step_hook(state_data: np.ndarray, t: float) -> float:
"""No-op function."""
return 0

return post_step_hook
raise NotImplementedError

@abstractmethod
def evolution_rate(self, state: TState, t: float = 0) -> TState:
Expand Down
13 changes: 6 additions & 7 deletions pde/solvers/adams_bashforth.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from __future__ import annotations

from typing import Callable
from typing import Any, Callable

import numba as nb
import numpy as np
Expand All @@ -22,7 +22,7 @@ class AdamsBashforthSolver(SolverBase):

def _make_fixed_stepper(
self, state: FieldBase, dt: float
) -> Callable[[np.ndarray, float, int], tuple[float, float]]:
) -> Callable[[np.ndarray, float, int, Any], float]:
"""Return a stepper function using an explicit scheme with fixed time steps.

Args:
Expand Down Expand Up @@ -56,8 +56,8 @@ def single_step(
single_step = jit(sig_single_step)(single_step)

def fixed_stepper(
state_data: np.ndarray, t_start: float, steps: int
) -> tuple[float, float]:
state_data: np.ndarray, t_start: float, steps: int, post_step_data
) -> float:
"""Perform `steps` steps with fixed time steps."""
nonlocal state_prev, init_state_prev

Expand All @@ -66,14 +66,13 @@ def fixed_stepper(
state_prev[:] = state_data - dt * rhs_pde(state_data, t_start)
init_state_prev = False

modifications = 0.0
for i in range(steps):
# calculate the right hand side
t = t_start + i * dt
single_step(state_data, t, state_prev)
modifications += post_step_hook(state_data, t)
post_step_hook(state_data, t, post_step_data=post_step_data)

return t + dt, modifications
return t + dt

self._logger.info("Init explicit Adams-Bashforth stepper with dt=%g", dt)

Expand Down
Loading
Loading