Skip to content

Commit

Permalink
add supports_parallel_solve property on BaseSolver
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Sep 18, 2024
1 parent 24c22aa commit 5087673
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ def supports_interp(self):
def root_method(self):
return self._root_method

@property
def supports_parallel_solve(self):
return False

@root_method.setter
def root_method(self, method):
if method == "casadi":
Expand Down Expand Up @@ -896,7 +900,7 @@ def solve(
pybamm.logger.verbose(
f"Calling solver for {t_eval[start_index]} < t < {t_eval[end_index - 1]}"
)
if isinstance(self, (pybamm.JaxSolver, pybamm.IDAKLUSolver)):
if self.supports_parallel_solve:
# Jax and IDAKLU solver can accept a list of inputs
new_solutions = self._integrate(
model,
Expand Down Expand Up @@ -1353,7 +1357,7 @@ def step(
timer.reset()

# API for _integrate is different for JaxSolver and IDAKLUSolver
if isinstance(self, (pybamm.JaxSolver, pybamm.IDAKLUSolver)):
if self.supports_parallel_solve:
solutions = self._integrate(model, t_eval, [model_inputs], t_interp)
solution = solutions[0]
else:
Expand Down
4 changes: 4 additions & 0 deletions src/pybamm/solvers/idaklu_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,10 @@ def _check_mlir_conversion(self, name, mlir: str):
def _demote_64_to_32(self, x: pybamm.EvaluatorJax):
return pybamm.EvaluatorJax._demote_64_to_32(x)

@property
def supports_parallel_solve(self):
return True

def _integrate(self, model, t_eval, inputs_list=None, t_interp=None):
"""
Solve a DAE model defined by residuals with initial conditions y0.
Expand Down
4 changes: 4 additions & 0 deletions src/pybamm/solvers/jax_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,10 @@ def solve_model_bdf(inputs):
else:
return jax.jit(solve_model_bdf)

@property
def supports_parallel_solve(self):
return True

def _integrate(self, model, t_eval, inputs=None, t_interp=None):
"""
Solve a model defined by dydt with initial conditions y0.
Expand Down

0 comments on commit 5087673

Please sign in to comment.