Skip to content

Commit

Permalink
Made Controller more resilient
Browse files Browse the repository at this point in the history
  • Loading branch information
david-zwicker committed Aug 17, 2024
1 parent 90eb937 commit 4cc3e84
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 4 deletions.
9 changes: 5 additions & 4 deletions pde/solvers/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def _run_main_process(self, state: TState, dt: float | None = None) -> None:
# initialize solver information
self.info["t_start"] = t_start
self.info["t_end"] = t_end
self.diagnostics["solver"] = self.solver.info
self.diagnostics["solver"] = getattr(self.solver, "info", {})

# initialize profilers
jit_count_base = int(JIT_COUNT)
Expand All @@ -187,7 +187,8 @@ def _run_main_process(self, state: TState, dt: float | None = None) -> None:
self.info["solver_start"] = str(solver_start)

if dt is None:
dt = self.solver.info.get("dt")
# use self.solver.info['dt'] if it is present
dt = self.diagnostics["solver"].get("dt")
# add some tolerance to account for inaccurate float point math
if dt is None: # self.solver.info['dt'] might be None
atol = 1e-12
Expand Down Expand Up @@ -256,7 +257,7 @@ def _run_main_process(self, state: TState, dt: float | None = None) -> None:
self.info["t_final"] = t
self.info["jit_count"]["simulation"] = int(JIT_COUNT) - jit_count_after_init
self.trackers.finalize(info=self.diagnostics)
if "dt_statistics" in self.solver.info:
if "dt_statistics" in getattr(self.solver, "info", {}):
dt_statistics = dict(self.solver.info["dt_statistics"].to_dict())
self.solver.info["dt_statistics"] = dt_statistics

Expand Down Expand Up @@ -396,7 +397,7 @@ def run(self, initial_state: TState, dt: float | None = None) -> TState | None:
from ..tools import mpi

# copy the initial state to not modify the supplied one
if hasattr(self.solver, "pde") and self.solver.pde.complex_valued:
if getattr(self.solver, "pde", None) and self.solver.pde.complex_valued:
self._logger.info("Convert state to complex numbers")
state: TState = initial_state.copy(dtype=complex)
else:
Expand Down
17 changes: 17 additions & 0 deletions tests/solvers/test_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
"""

import pytest
import numpy as np

from pde import PDEBase, ScalarField, UnitGrid
from pde.solvers import Controller


def test_controller_abort():
Expand All @@ -27,3 +29,18 @@ def evolution_rate(self, state, t):

assert eq.diagnostics["last_tracker_time"] >= 0
assert eq.diagnostics["last_state"] == field


def test_controller_foreign_solver():
"""Test whether the Controller can deal with a minimal foreign solver"""

class MySolver:
def make_stepper(self, state, dt):
def stepper(state, t, t_break):
return t_break

return stepper

c = Controller(MySolver(), t_range=1)
res = c.run(np.arange(3))
np.testing.assert_allclose(res, np.arange(3))

0 comments on commit 4cc3e84

Please sign in to comment.