diff --git a/pde/solvers/controller.py b/pde/solvers/controller.py index 972f9127..82843c93 100644 --- a/pde/solvers/controller.py +++ b/pde/solvers/controller.py @@ -163,7 +163,7 @@ def _run_main_process(self, state: TState, dt: float | None = None) -> None: # initialize solver information self.info["t_start"] = t_start self.info["t_end"] = t_end - self.diagnostics["solver"] = self.solver.info + self.diagnostics["solver"] = getattr(self.solver, "info", {}) # initialize profilers jit_count_base = int(JIT_COUNT) @@ -187,7 +187,8 @@ def _run_main_process(self, state: TState, dt: float | None = None) -> None: self.info["solver_start"] = str(solver_start) if dt is None: - dt = self.solver.info.get("dt") + # use self.solver.info['dt'] if it is present + dt = self.diagnostics["solver"].get("dt") # add some tolerance to account for inaccurate float point math if dt is None: # self.solver.info['dt'] might be None atol = 1e-12 @@ -256,7 +257,7 @@ def _run_main_process(self, state: TState, dt: float | None = None) -> None: self.info["t_final"] = t self.info["jit_count"]["simulation"] = int(JIT_COUNT) - jit_count_after_init self.trackers.finalize(info=self.diagnostics) - if "dt_statistics" in self.solver.info: + if "dt_statistics" in getattr(self.solver, "info", {}): dt_statistics = dict(self.solver.info["dt_statistics"].to_dict()) self.solver.info["dt_statistics"] = dt_statistics @@ -396,7 +397,7 @@ def run(self, initial_state: TState, dt: float | None = None) -> TState | None: from ..tools import mpi # copy the initial state to not modify the supplied one - if hasattr(self.solver, "pde") and self.solver.pde.complex_valued: + if getattr(self.solver, "pde", None) and self.solver.pde.complex_valued: self._logger.info("Convert state to complex numbers") state: TState = initial_state.copy(dtype=complex) else: diff --git a/tests/solvers/test_controller.py b/tests/solvers/test_controller.py index 60c2073c..b79cd604 100644 --- a/tests/solvers/test_controller.py +++ b/tests/solvers/test_controller.py @@ -3,8 +3,10 @@ """ import pytest +import numpy as np from pde import PDEBase, ScalarField, UnitGrid +from pde.solvers import Controller def test_controller_abort(): @@ -27,3 +29,18 @@ def evolution_rate(self, state, t): assert eq.diagnostics["last_tracker_time"] >= 0 assert eq.diagnostics["last_state"] == field + + +def test_controller_foreign_solver(): + """Test whether the Controller can deal with a minimal foreign solver""" + + class MySolver: + def make_stepper(self, state, dt): + def stepper(state, t, t_break): + return t_break + + return stepper + + c = Controller(MySolver(), t_range=1) + res = c.run(np.arange(3)) + np.testing.assert_allclose(res, np.arange(3))