Skip to content

Commit

Permalink
Made the error_synchronizer available to all solvers
Browse files Browse the repository at this point in the history
  • Loading branch information
david-zwicker committed Aug 9, 2024
1 parent 04198ee commit 331e569
Showing 1 changed file with 30 additions and 19 deletions.
49 changes: 30 additions & 19 deletions pde/solvers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,36 @@ def _compiled(self) -> bool:
"""bool: indicates whether functions need to be compiled"""
return self.backend == "numba" and not nb.config.DISABLE_JIT

def _make_error_synchronizer(
self, operator: int | str = "MAX"
) -> Callable[[float], float]:
"""Return function that synchronizes errors between multiple processes.
Args:
operator (str or int):
Flag determining how the value from multiple nodes is combined.
Possible values include "MAX", "MIN", and "SUM".
Returns:
Function that can be used to synchronize errors across nodes
"""
if self._mpi_synchronization: # mpi.parallel_run:
# in a parallel run, we need to synchronize values
from ..tools.mpi import mpi_allreduce

@register_jitable
def synchronize_errors(error: float) -> float:
"""Return error synchronized accross all cores."""
return mpi_allreduce(error, operator=operator) # type: ignore

else:

@register_jitable
def synchronize_errors(value: float) -> float:
return value

return synchronize_errors # type: ignore

def _make_post_step_hook(self, state: FieldBase) -> StepperHook:
"""Create a function that calls the post-step hook of the PDE.
Expand Down Expand Up @@ -415,25 +445,6 @@ def __init__(
self.adaptive = adaptive
self.tolerance = tolerance

def _make_error_synchronizer(self) -> Callable[[float], float]:
"""Return function that synchronizes errors between multiple processes."""
if self._mpi_synchronization: # mpi.parallel_run:
# in a parallel run, we need to return the maximal error
from ..tools.mpi import mpi_allreduce

@register_jitable
def synchronize_errors(error: float) -> float:
"""Return maximal error accross all cores."""
return mpi_allreduce(error, operator="MAX") # type: ignore

else:

@register_jitable
def synchronize_errors(value: float) -> float:
return value

return synchronize_errors # type: ignore

def _make_dt_adjuster(self) -> Callable[[float, float], float]:
"""Return a function that can be used to adjust time steps."""
dt_min = self.dt_min
Expand Down

0 comments on commit 331e569

Please sign in to comment.