diff --git a/docs/source/user_guide/installation/gnu-linux-mac.rst b/docs/source/user_guide/installation/gnu-linux-mac.rst index e8d941a299..6340631fd3 100644 --- a/docs/source/user_guide/installation/gnu-linux-mac.rst +++ b/docs/source/user_guide/installation/gnu-linux-mac.rst @@ -102,20 +102,6 @@ Users can install ``jax`` and ``jaxlib`` to use the Jax solver. The ``pip install "pybamm[jax]"`` command automatically downloads and installs ``pybamm`` and the compatible versions of ``jax`` and ``jaxlib`` on your system. PyBaMM's full `conda-forge distribution `_ (``pybamm``) includes ``jax`` and ``jaxlib`` by default. - -.. _optional-iree-mlir-support: - -Optional - IREE / MLIR support -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Users can install ``iree`` (for MLIR just-in-time compilation) to use for main expression evaluation in the IDAKLU solver. Requires ``jax``. - -.. code:: bash - - pip install "pybamm[iree,jax]" - -The ``pip install "pybamm[iree,jax]"`` command automatically downloads and installs ``pybamm`` and the compatible versions of ``jax`` and ``iree`` onto your system. - Uninstall PyBaMM ---------------- diff --git a/docs/source/user_guide/installation/index.rst b/docs/source/user_guide/installation/index.rst index f3b9ca8ccb..fc3a1bcdcc 100644 --- a/docs/source/user_guide/installation/index.rst +++ b/docs/source/user_guide/installation/index.rst @@ -36,7 +36,6 @@ Optional solvers The following solvers are optionally available: * `jax `_ -based solver, see :ref:`optional-jaxsolver` . -* `IREE `_ (`MLIR `_) support, see :ref:`optional-iree-mlir-support`. Dependencies ------------ @@ -220,17 +219,6 @@ Dependency Minimu `jaxlib `__ 0.4.20 jax Support library for JAX ========================================================================= ================== ================== ======================= -IREE dependencies -^^^^^^^^^^^^^^^^^ - -Installable with ``pip install "pybamm[iree]"`` (requires ``jax`` dependencies to be installed). - -========================================================================= ================== ================== ======================= -Dependency Minimum Version pip extra Notes -========================================================================= ================== ================== ======================= -`iree-compiler `__ 20240507.886 iree IREE compiler -========================================================================= ================== ================== ======================= - Full installation guide ----------------------- diff --git a/noxfile.py b/noxfile.py index 7892ad86ac..a84922add7 100644 --- a/noxfile.py +++ b/noxfile.py @@ -117,8 +117,6 @@ def set_dev(session): session.install("virtualenv", "cmake") session.run("virtualenv", os.fsdecode(VENV_DIR), silent=True) python = os.fsdecode(VENV_DIR.joinpath("bin/python")) - components = ["all", "dev", "jax"] - args = [] # Temporary fix for Python 3.12 CI. TODO: remove after # https://bitbucket.org/pybtex-devs/pybtex/issues/169/replace-pkg_resources-with # is fixed @@ -129,8 +127,7 @@ def set_dev(session): "pip", "install", "-e", - ".[{}]".format(",".join(components)), - *args, + ".[all,dev,jax]", external=True, ) diff --git a/pyproject.toml b/pyproject.toml index 745ee5d8d8..d7aef0b837 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -114,13 +114,12 @@ dev = [ "importlib-metadata; python_version < '3.10'", ] # For the Jax solver. -# Note: These must be kept in sync with the versions defined in pybamm/util.py, and -# must remain compatible with IREE (see noxfile.py for IREE compatibility). +# Note: These must be kept in sync with the versions defined in pybamm/util.py jax = [ "jax==0.4.27", "jaxlib==0.4.27", ] -# Contains all optional dependencies, except for jax, iree, and dev dependencies +# Contains all optional dependencies, except for jax and dev dependencies all = [ "scikit-fem>=8.1.0", "pybamm[examples,plot,cite,bpx,tqdm]", diff --git a/src/pybamm/__init__.py b/src/pybamm/__init__.py index 5ff23cf85f..8cb57ac66a 100644 --- a/src/pybamm/__init__.py +++ b/src/pybamm/__init__.py @@ -1,8 +1,5 @@ from pybamm.version import __version__ -# Demote expressions to 32-bit floats/ints - option used for IDAKLU-MLIR compilation -demote_expressions_to_32bit = False - # Utility classes and methods from .util import root_dir from .util import Timer, TimerTime, FuzzyDict @@ -175,7 +172,7 @@ from .solvers.jax_bdf_solver import jax_bdf_integrate from .solvers.idaklu_jax import IDAKLUJax -from .solvers.idaklu_solver import IDAKLUSolver, has_iree +from .solvers.idaklu_solver import IDAKLUSolver # Experiments from .experiment.experiment import Experiment diff --git a/src/pybamm/expression_tree/operations/evaluate_python.py b/src/pybamm/expression_tree/operations/evaluate_python.py index a8a37ea7b2..eb4a0f39b9 100644 --- a/src/pybamm/expression_tree/operations/evaluate_python.py +++ b/src/pybamm/expression_tree/operations/evaluate_python.py @@ -596,54 +596,9 @@ def __init__(self, symbol: pybamm.Symbol): static_argnums=self._static_argnums, ) - def _demote_constants(self): - """Demote 64-bit constants (f64, i64) to 32-bit (f32, i32)""" - if not pybamm.demote_expressions_to_32bit: - return # pragma: no cover - self._constants = EvaluatorJax._demote_64_to_32(self._constants) - - @classmethod - def _demote_64_to_32(cls, c): - """Demote 64-bit operations (f64, i64) to 32-bit (f32, i32)""" - - if not pybamm.demote_expressions_to_32bit: - return c - if isinstance(c, float): - c = jax.numpy.float32(c) - if isinstance(c, int): - c = jax.numpy.int32(c) - if isinstance(c, np.int64): - c = c.astype(jax.numpy.int32) - if isinstance(c, np.ndarray): - if c.dtype == np.float64: - c = c.astype(jax.numpy.float32) - if c.dtype == np.int64: - c = c.astype(jax.numpy.int32) - if isinstance(c, jax.numpy.ndarray): - if c.dtype == jax.numpy.float64: - c = c.astype(jax.numpy.float32) - if c.dtype == jax.numpy.int64: - c = c.astype(jax.numpy.int32) - if isinstance( - c, pybamm.expression_tree.operations.evaluate_python.JaxCooMatrix - ): - if c.data.dtype == np.float64: - c.data = c.data.astype(jax.numpy.float32) - if c.row.dtype == np.int64: - c.row = c.row.astype(jax.numpy.int32) - if c.col.dtype == np.int64: - c.col = c.col.astype(jax.numpy.int32) - if isinstance(c, dict): - c = {key: EvaluatorJax._demote_64_to_32(value) for key, value in c.items()} - if isinstance(c, tuple): - c = tuple(EvaluatorJax._demote_64_to_32(value) for value in c) - if isinstance(c, list): - c = [EvaluatorJax._demote_64_to_32(value) for value in c] - return c - @property def _constants(self): - return tuple(map(EvaluatorJax._demote_64_to_32, self.__constants)) + return self.__constants @_constants.setter def _constants(self, value): diff --git a/src/pybamm/solvers/idaklu_solver.py b/src/pybamm/solvers/idaklu_solver.py index 445b4d586a..94e69cf7b8 100644 --- a/src/pybamm/solvers/idaklu_solver.py +++ b/src/pybamm/solvers/idaklu_solver.py @@ -1,35 +1,13 @@ # mypy: ignore-errors -import os import casadi import pybamm import numpy as np import numbers -import scipy.sparse as sparse -from scipy.linalg import bandwidth import pybammsolvers.idaklu as idaklu import warnings -if pybamm.has_jax(): - import jax - from jax import numpy as jnp - - try: - import iree.compiler - except ImportError: # pragma: no cover - pass - - -def has_iree(): - try: - import iree.compiler # noqa: F401 - - return True - except ImportError: # pragma: no cover - return False - - class IDAKLUSolver(pybamm.BaseSolver): """ Solve a discretised model, using sundials with the KLU sparse linear solver. @@ -66,8 +44,6 @@ class IDAKLUSolver(pybamm.BaseSolver): "num_threads": 1, # Number of solvers to use in parallel (for solving multiple sets of input parameters in parallel) "num_solvers": num_threads, - # Evaluation engine to use for jax, can be 'jax'(native) or 'iree' - "jax_evaluator": "jax", ## Linear solver interface # name of sundials linear solver to use options are: "SUNLinSol_KLU", # "SUNLinSol_Dense", "SUNLinSol_Band", "SUNLinSol_SPBCGS", @@ -176,7 +152,6 @@ def __init__( "precon_half_bandwidth_keep": 5, "num_threads": 1, "num_solvers": 1, - "jax_evaluator": "jax", "linear_solver": "SUNLinSol_KLU", "linsol_max_iterations": 5, "epsilon_linear_tolerance": 0.05, @@ -210,10 +185,6 @@ def __init__( for key, value in default_options.items(): if key not in options: options[key] = value - if options["jax_evaluator"] not in ["jax", "iree"]: - raise pybamm.SolverError( - "Evaluation engine must be 'jax' or 'iree' for IDAKLU solver" - ) self._options = options self.output_variables = [] if output_variables is None else output_variables @@ -263,19 +234,10 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False): # stack inputs if inputs_dict: arrays_to_stack = [np.array(x).reshape(-1, 1) for x in inputs_dict.values()] - inputs_sizes = [len(array) for array in arrays_to_stack] inputs = np.vstack(arrays_to_stack) else: - inputs_sizes = [] inputs = np.array([[]]) - def inputs_to_dict(inputs): - index = 0 - for n, key in zip(inputs_sizes, inputs_dict.keys()): - inputs_dict[key] = inputs[index : (index + n)] - index += n - return inputs_dict - y0 = model.y0 if isinstance(y0, casadi.DM): y0 = y0.full() @@ -286,21 +248,12 @@ def inputs_to_dict(inputs): if model.convert_to_format not in ["casadi", "jax"]: msg = ( - "The python-idaklu solver has been deprecated. " - "To use the IDAKLU solver set `convert_to_format = 'casadi'`, or `jax`" - " if using IREE." + "The python-idaklu and IREE solvers have been deprecated. " + "To use the IDAKLU solver set `convert_to_format = 'casadi'` or `jax`" ) warnings.warn(msg, DeprecationWarning, stacklevel=2) - if model.convert_to_format == "jax": - if self._options["jax_evaluator"] != "iree": - raise pybamm.SolverError( - "Unsupported evaluation engine for convert_to_format=" - f"{model.convert_to_format} " - f"(jax_evaluator={self._options['jax_evaluator']})" - ) - mass_matrix = model.mass_matrix.entries.toarray() - elif model.convert_to_format == "casadi": + if model.convert_to_format == "casadi": if self._options["jacobian"] == "dense": mass_matrix = casadi.DM(model.mass_matrix.entries.toarray()) else: @@ -447,171 +400,6 @@ def inputs_to_dict(inputs): rootfn = idaklu.generate_function(rootfn.serialize()) mass_action = idaklu.generate_function(mass_action.serialize()) sensfn = idaklu.generate_function(sensfn.serialize()) - elif ( - model.convert_to_format == "jax" - and self._options["jax_evaluator"] == "iree" - ): - # Convert Jax functions to MLIR (also, demote to single precision) - idaklu_solver_fcn = idaklu.create_iree_solver_group - pybamm.demote_expressions_to_32bit = True - if pybamm.demote_expressions_to_32bit: - warnings.warn( - "Demoting expressions to 32-bit for MLIR conversion", - stacklevel=2, - ) - jnpfloat = jnp.float32 - else: # pragma: no cover - jnpfloat = jnp.float64 - raise pybamm.SolverError( - "Demoting expressions to 32-bit is required for MLIR conversion" - " at this time" - ) - - # input arguments (used for lowering) - t_eval = self._demote_64_to_32(jnp.array([0.0], dtype=jnpfloat)) - y0 = self._demote_64_to_32(model.y0) - inputs0 = self._demote_64_to_32(inputs_to_dict(inputs)) - cj = self._demote_64_to_32(jnp.array([1.0], dtype=jnpfloat)) # array - v0 = jnp.zeros(model.len_rhs_and_alg, jnpfloat) - mass_matrix = model.mass_matrix.entries.toarray() - mass_matrix_demoted = self._demote_64_to_32(mass_matrix) - - # rhs_algebraic - rhs_algebraic_demoted = model.rhs_algebraic_eval - rhs_algebraic_demoted._demote_constants() - - def fcn_rhs_algebraic(t, y, inputs): - # function wraps an expression tree (and names MLIR module) - return rhs_algebraic_demoted(t, y, inputs) - - rhs_algebraic = self._make_iree_function( - fcn_rhs_algebraic, t_eval, y0, inputs0 - ) - - # jac_times_cjmass - jac_rhs_algebraic_demoted = rhs_algebraic_demoted.get_jacobian() - - def fcn_jac_times_cjmass(t, y, p, cj): - return jac_rhs_algebraic_demoted(t, y, p) - cj * mass_matrix_demoted - - sparse_eval = sparse.csc_matrix( - fcn_jac_times_cjmass(t_eval, y0, inputs0, cj) - ) - jac_times_cjmass_nnz = sparse_eval.nnz - jac_times_cjmass_colptrs = sparse_eval.indptr - jac_times_cjmass_rowvals = sparse_eval.indices - jac_bw_lower, jac_bw_upper = bandwidth( - sparse_eval.todense() - ) # potentially slow - if jac_bw_upper <= 1: - jac_bw_upper = jac_bw_lower - 1 - if jac_bw_lower <= 1: - jac_bw_lower = jac_bw_upper + 1 - coo = sparse_eval.tocoo() # convert to COOrdinate format for indexing - - def fcn_jac_times_cjmass_sparse(t, y, p, cj): - return fcn_jac_times_cjmass(t, y, p, cj)[coo.row, coo.col] - - jac_times_cjmass = self._make_iree_function( - fcn_jac_times_cjmass_sparse, t_eval, y0, inputs0, cj - ) - - # Mass action - def fcn_mass_action(v): - return mass_matrix_demoted @ v - - mass_action_demoted = self._demote_64_to_32(fcn_mass_action) - mass_action = self._make_iree_function(mass_action_demoted, v0) - - # rootfn - for ix, _ in enumerate(model.terminate_events_eval): - model.terminate_events_eval[ix]._demote_constants() - - def fcn_rootfn(t, y, inputs): - return jnp.array( - [event(t, y, inputs) for event in model.terminate_events_eval], - dtype=jnpfloat, - ).reshape(-1) - - def fcn_rootfn_demoted(t, y, inputs): - return self._demote_64_to_32(fcn_rootfn)(t, y, inputs) - - rootfn = self._make_iree_function(fcn_rootfn_demoted, t_eval, y0, inputs0) - - # jac_rhs_algebraic_action - jac_rhs_algebraic_action_demoted = ( - rhs_algebraic_demoted.get_jacobian_action() - ) - - def fcn_jac_rhs_algebraic_action( - t, y, p, v - ): # sundials calls (t, y, inputs, v) - return jac_rhs_algebraic_action_demoted( - t, y, v, p - ) # jvp calls (t, y, v, inputs) - - jac_rhs_algebraic_action = self._make_iree_function( - fcn_jac_rhs_algebraic_action, t_eval, y0, inputs0, v0 - ) - - # sensfn - if model.jacp_rhs_algebraic_eval is None: - sensfn = idaklu.IREEBaseFunctionType() # empty equation - else: - sensfn_demoted = rhs_algebraic_demoted.get_sensitivities() - - def fcn_sensfn(t, y, p): - return sensfn_demoted(t, y, p) - - sensfn = self._make_iree_function( - fcn_sensfn, t_eval, jnp.zeros_like(y0), inputs0 - ) - - # output_variables - self.var_idaklu_fcns = [] - self.dvar_dy_idaklu_fcns = [] - self.dvar_dp_idaklu_fcns = [] - for key in self.output_variables: - fcn = self.computed_var_fcns[key] - fcn._demote_constants() - self.var_idaklu_fcns.append( - self._make_iree_function( - lambda t, y, p: fcn(t, y, p), # noqa: B023 - t_eval, - y0, - inputs0, - ) - ) - # Convert derivative functions for sensitivities - if (len(inputs) > 0) and (model.calculate_sensitivities): - dvar_dy = fcn.get_jacobian() - self.dvar_dy_idaklu_fcns.append( - self._make_iree_function( - lambda t, y, p: dvar_dy(t, y, p), # noqa: B023 - t_eval, - y0, - inputs0, - sparse_index=True, - ) - ) - dvar_dp = fcn.get_sensitivities() - self.dvar_dp_idaklu_fcns.append( - self._make_iree_function( - lambda t, y, p: dvar_dp(t, y, p), # noqa: B023 - t_eval, - y0, - inputs0, - ) - ) - - # Identify IREE library - iree_lib_path = os.path.join(iree.compiler.__path__[0], "_mlir_libs") - os.environ["IREE_COMPILER_LIB"] = os.path.join( - iree_lib_path, - next(f for f in os.listdir(iree_lib_path) if "IREECompiler" in f), - ) - - pybamm.demote_expressions_to_32bit = False else: # pragma: no cover raise pybamm.SolverError( "Unsupported evaluation engine for convert_to_format='jax'" @@ -670,57 +458,6 @@ def fcn_sensfn(t, y, p): return base_set_up_return - def _make_iree_function(self, fcn, *args, sparse_index=False): - # Initialise IREE function object - iree_fcn = idaklu.IREEBaseFunctionType() - # Get sparsity pattern index outputs as needed - try: - fcn_eval = fcn(*args) - if not isinstance(fcn_eval, np.ndarray): - fcn_eval = jax.flatten_util.ravel_pytree(fcn_eval)[0] - coo = sparse.coo_matrix(fcn_eval) - iree_fcn.nnz = coo.nnz - iree_fcn.numel = np.prod(coo.shape) - iree_fcn.col = coo.col - iree_fcn.row = coo.row - if sparse_index: - # Isolate NNZ elements while recording original sparsity structure - fcn_inner = fcn - - def fcn(*args): - return fcn_inner(*args)[coo.row, coo.col] - - elif coo.nnz != iree_fcn.numel: - iree_fcn.nnz = iree_fcn.numel - iree_fcn.col = list(range(iree_fcn.numel)) - iree_fcn.row = [0] * iree_fcn.numel - except (TypeError, AttributeError) as error: # pragma: no cover - raise pybamm.SolverError( - "Could not get sparsity pattern for function {fcn.__name__}" - ) from error - # Lower to MLIR - lowered = jax.jit(fcn).lower(*args) - iree_fcn.mlir = lowered.as_text() - self._check_mlir_conversion(fcn.__name__, iree_fcn.mlir) - iree_fcn.kept_var_idx = list(lowered._lowering.compile_args["kept_var_idx"]) - # Record number of variables in each argument (these will flatten in the mlir) - iree_fcn.pytree_shape = [ - len(jax.tree_util.tree_flatten(arg)[0]) for arg in args - ] - # Record array length of each mlir variable - iree_fcn.pytree_sizes = [ - len(arg) for arg in jax.tree_util.tree_flatten(args)[0] - ] - iree_fcn.n_args = len(args) - return iree_fcn - - def _check_mlir_conversion(self, name, mlir: str): - if mlir.count("f64") > 0: # pragma: no cover - warnings.warn(f"f64 found in {name} (x{mlir.count('f64')})", stacklevel=2) - - def _demote_64_to_32(self, x: pybamm.EvaluatorJax): - return pybamm.EvaluatorJax._demote_64_to_32(x) - @property def supports_parallel_solve(self): return True @@ -745,13 +482,7 @@ def _integrate(self, model, t_eval, inputs_list=None, t_interp=None): The times (in seconds) at which to interpolate the solution. Defaults to `None`, which returns the adaptive time-stepping times. """ - if not ( - model.convert_to_format == "casadi" - or ( - model.convert_to_format == "jax" - and self._options["jax_evaluator"] == "iree" - ) - ): # pragma: no cover + if not (model.convert_to_format == "casadi"): # pragma: no cover # Shouldn't ever reach this point raise pybamm.SolverError("Unsupported IDAKLU solver configuration.") @@ -867,18 +598,10 @@ def _post_process_solution(self, sol, model, integration_time, inputs_dict): self._setup["var_fcns"][var](0.0, 0.0, 0.0).sparsity().nnz() ) base_variables = [self._setup["var_fcns"][var]] - elif ( - model.convert_to_format == "jax" - and self._options["jax_evaluator"] == "iree" - ): - idx = self.output_variables.index(var) - len_of_var = self._setup["var_idaklu_fcns"][idx].nnz - base_variables = [self._setup["var_idaklu_fcns"][idx]] else: # pragma: no cover raise pybamm.SolverError( "Unsupported evaluation engine for convert_to_format=" - + f"{model.convert_to_format} " - + f"(jax_evaluator={self._options['jax_evaluator']})" + + f"{model.convert_to_format}" ) newsol._variables[var] = pybamm.ProcessedVariableComputed( [model.variables_and_events[var]], @@ -916,10 +639,6 @@ def _set_consistent_initialization(self, model, time, inputs_dict): super()._set_consistent_initialization(model, time, inputs_dict) casadi_format = model.convert_to_format == "casadi" - jax_iree_format = ( - model.convert_to_format == "jax" - and self._options["jax_evaluator"] == "iree" - ) y0 = model.y0 if isinstance(y0, casadi.DM): @@ -935,7 +654,7 @@ def _set_consistent_initialization(self, model, time, inputs_dict): else: ydot0 = np.zeros_like(y0) - sensitivity = (model.y0S is not None) and (jax_iree_format or casadi_format) + sensitivity = (model.y0S is not None) and casadi_format if sensitivity: y0full, ydot0full = self._sensitivity_consistent_initialization( y0, ydot0, model, time, inputs_dict @@ -944,12 +663,6 @@ def _set_consistent_initialization(self, model, time, inputs_dict): y0full = y0 ydot0full = ydot0 - if jax_iree_format: - pybamm.demote_expressions_to_32bit = True - y0full = self._demote_64_to_32(y0full) - ydot0full = self._demote_64_to_32(ydot0full) - pybamm.demote_expressions_to_32bit = False - model.y0full = y0full model.ydot0full = ydot0full @@ -1017,19 +730,9 @@ def _sensitivity_consistent_initialization( Any input parameters to pass to the model when solving. """ - - jax_iree_format = ( - model.convert_to_format == "jax" - and self._options["jax_evaluator"] == "iree" - ) - y0S = model.y0S - if jax_iree_format: - inputs_dict = inputs_dict or {} - inputs_dict_keys = list(inputs_dict.keys()) - y0S = np.concatenate([y0S[k] for k in inputs_dict_keys]) - elif isinstance(y0S, casadi.DM): + if isinstance(y0S, casadi.DM): y0S = (y0S,) if isinstance(y0S[0], casadi.DM): diff --git a/src/pybamm/solvers/processed_variable_computed.py b/src/pybamm/solvers/processed_variable_computed.py index befe6314b6..4602de4017 100644 --- a/src/pybamm/solvers/processed_variable_computed.py +++ b/src/pybamm/solvers/processed_variable_computed.py @@ -126,11 +126,6 @@ def _unroll_nnz(self, realdata=None): nnz = sp.nnz() numel = sp.numel() row = sp.row() - elif "nnz" in dir(self.base_variables_casadi[0]): # IREE fcn - sp = self.base_variables_casadi[0] - nnz = sp.nnz - numel = sp.numel - row = sp.row if nnz != numel: data = [None] * len(realdata) for datak in range(len(realdata)): diff --git a/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py b/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py index 14b980b358..2aa7bcaf30 100644 --- a/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py +++ b/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py @@ -671,68 +671,6 @@ def test_evaluator_jax_inputs(self): result = evaluator(inputs={"a": 2}) assert result == 4 - @pytest.mark.skipif(not pybamm.has_jax(), reason="jax or jaxlib is not installed") - def test_evaluator_jax_demotion(self): - for demote in [True, False]: - pybamm.demote_expressions_to_32bit = demote # global flag - target_dtype = "32" if demote else "64" - if demote: - # Test only works after conversion to jax.numpy - for c in [ - 1.0, - 1, - ]: - assert ( - str(pybamm.EvaluatorJax._demote_64_to_32(c).dtype)[-2:] - == target_dtype - ) - for c in [ - np.float64(1.0), - np.int64(1), - np.array([1.0], dtype=np.float64), - np.array([1], dtype=np.int64), - jax.numpy.array([1.0], dtype=np.float64), - jax.numpy.array([1], dtype=np.int64), - ]: - assert ( - str(pybamm.EvaluatorJax._demote_64_to_32(c).dtype)[-2:] - == target_dtype - ) - for c in [ - {key: np.float64(1.0) for key in ["a", "b"]}, - ]: - expr_demoted = pybamm.EvaluatorJax._demote_64_to_32(c) - assert all( - str(c_v.dtype)[-2:] == target_dtype - for c_k, c_v in expr_demoted.items() - ) - for c in [ - (np.float64(1.0), np.float64(2.0)), - [np.float64(1.0), np.float64(2.0)], - ]: - expr_demoted = pybamm.EvaluatorJax._demote_64_to_32(c) - assert all(str(c_i.dtype)[-2:] == target_dtype for c_i in expr_demoted) - for dtype in [ - np.float64, - jax.numpy.float64, - ]: - c = pybamm.JaxCooMatrix([0, 1], [0, 1], dtype([1.0, 2.0]), (2, 2)) - c_demoted = pybamm.EvaluatorJax._demote_64_to_32(c) - assert all( - str(c_i.dtype)[-2:] == target_dtype for c_i in c_demoted.data - ) - for dtype in [ - np.int64, - jax.numpy.int64, - ]: - c = pybamm.JaxCooMatrix( - dtype([0, 1]), dtype([0, 1]), [1.0, 2.0], (2, 2) - ) - c_demoted = pybamm.EvaluatorJax._demote_64_to_32(c) - assert all(str(c_i.dtype)[-2:] == target_dtype for c_i in c_demoted.row) - assert all(str(c_i.dtype)[-2:] == target_dtype for c_i in c_demoted.col) - pybamm.demote_expressions_to_32bit = False - @pytest.mark.skipif(not pybamm.has_jax(), reason="jax or jaxlib is not installed") def test_jax_coo_matrix(self): A = pybamm.JaxCooMatrix([0, 1], [0, 1], [1.0, 2.0], (2, 2)) diff --git a/tests/unit/test_solvers/test_idaklu_solver.py b/tests/unit/test_solvers/test_idaklu_solver.py index 28cc34ded9..81a4f9ae9f 100644 --- a/tests/unit/test_solvers/test_idaklu_solver.py +++ b/tests/unit/test_solvers/test_idaklu_solver.py @@ -13,49 +13,43 @@ def test_ida_roberts_klu(self): # this test implements a python version of the ida Roberts # example provided in sundials # see sundials ida examples pdf - for form in ["casadi", "iree"]: - if (form == "iree") and (not pybamm.has_jax() or not pybamm.has_iree()): - continue - if form == "casadi": - root_method = "casadi" - else: - root_method = "lm" - model = pybamm.BaseModel() - model.convert_to_format = "jax" if form == "iree" else form - u = pybamm.Variable("u") - v = pybamm.Variable("v") - model.rhs = {u: 0.1 * v} - model.algebraic = {v: 1 - v} - model.initial_conditions = {u: 0, v: 1} - model.events = [pybamm.Event("1", 0.2 - u), pybamm.Event("2", v)] + form = "casadi" + root_method = "casadi" + model = pybamm.BaseModel() + model.convert_to_format = form + u = pybamm.Variable("u") + v = pybamm.Variable("v") + model.rhs = {u: 0.1 * v} + model.algebraic = {v: 1 - v} + model.initial_conditions = {u: 0, v: 1} + model.events = [pybamm.Event("1", 0.2 - u), pybamm.Event("2", v)] - disc = pybamm.Discretisation() - disc.process_model(model) + disc = pybamm.Discretisation() + disc.process_model(model) - solver = pybamm.IDAKLUSolver( - root_method=root_method, - options={"jax_evaluator": "iree"} if form == "iree" else {}, - ) + solver = pybamm.IDAKLUSolver( + root_method=root_method, + ) - # Test - t_eval = [0, 3] - solution = solver.solve(model, t_eval) + # Test + t_eval = [0, 3] + solution = solver.solve(model, t_eval) - # test that final time is time of event - # y = 0.1 t + y0 so y=0.2 when t=2 - np.testing.assert_array_almost_equal(solution.t[-1], 2.0) + # test that final time is time of event + # y = 0.1 t + y0 so y=0.2 when t=2 + np.testing.assert_array_almost_equal(solution.t[-1], 2.0) - # test that final value is the event value - np.testing.assert_array_almost_equal(solution.y[0, -1], 0.2) + # test that final value is the event value + np.testing.assert_array_almost_equal(solution.y[0, -1], 0.2) - # test that y[1] remains constant - np.testing.assert_array_almost_equal( - solution.y[1, :], np.ones(solution.t.shape) - ) + # test that y[1] remains constant + np.testing.assert_array_almost_equal( + solution.y[1, :], np.ones(solution.t.shape) + ) - # test that y[0] = to true solution - true_solution = 0.1 * solution.t - np.testing.assert_array_almost_equal(solution.y[0, :], true_solution) + # test that y[0] = to true solution + true_solution = 0.1 * solution.t + np.testing.assert_array_almost_equal(solution.y[0, :], true_solution) def test_multiple_inputs(self): model = pybamm.BaseModel() @@ -99,488 +93,445 @@ def test_multiple_inputs(self): ) def test_model_events(self): - for form in ["casadi", "iree"]: - if (form == "iree") and (not pybamm.has_jax() or not pybamm.has_iree()): - continue - if form == "casadi": - root_method = "casadi" - else: - root_method = "lm" - # Create model - model = pybamm.BaseModel() - model.convert_to_format = "jax" if form == "iree" else form - var = pybamm.Variable("var") - model.rhs = {var: 0.1 * var} - model.initial_conditions = {var: 1} + form = "casadi" + root_method = "casadi" + # Create model + model = pybamm.BaseModel() + model.convert_to_format = form + var = pybamm.Variable("var") + model.rhs = {var: 0.1 * var} + model.initial_conditions = {var: 1} - # create discretisation - disc = pybamm.Discretisation() - model_disc = disc.process_model(model, inplace=False) - # Solve - solver = pybamm.IDAKLUSolver( - rtol=1e-8, - atol=1e-8, - root_method=root_method, - options={"jax_evaluator": "iree"} if form == "iree" else {}, - ) + # create discretisation + disc = pybamm.Discretisation() + model_disc = disc.process_model(model, inplace=False) + # Solve + solver = pybamm.IDAKLUSolver( + rtol=1e-8, + atol=1e-8, + root_method=root_method, + ) - t_interp = np.linspace(0, 1, 100) - t_eval = [t_interp[0], t_interp[-1]] + t_interp = np.linspace(0, 1, 100) + t_eval = [t_interp[0], t_interp[-1]] - solution = solver.solve(model_disc, t_eval, t_interp=t_interp) - np.testing.assert_array_equal( - solution.t, t_interp, err_msg=f"Failed for form {form}" - ) - np.testing.assert_array_almost_equal( - solution.y[0], - np.exp(0.1 * solution.t), - decimal=5, - err_msg=f"Failed for form {form}", - ) + solution = solver.solve(model_disc, t_eval, t_interp=t_interp) + np.testing.assert_array_equal( + solution.t, t_interp, err_msg=f"Failed for form {form}" + ) + np.testing.assert_array_almost_equal( + solution.y[0], + np.exp(0.1 * solution.t), + decimal=5, + err_msg=f"Failed for form {form}", + ) - # Check invalid atol type raises an error - with pytest.raises(pybamm.SolverError): - solver._check_atol_type({"key": "value"}, []) + # Check invalid atol type raises an error + with pytest.raises(pybamm.SolverError): + solver._check_atol_type({"key": "value"}, []) - # enforce events that won't be triggered - model.events = [pybamm.Event("an event", var + 1)] - model_disc = disc.process_model(model, inplace=False) - solver = pybamm.IDAKLUSolver( - rtol=1e-8, - atol=1e-8, - root_method=root_method, - options={"jax_evaluator": "iree"} if form == "iree" else {}, - ) - solution = solver.solve(model_disc, t_eval, t_interp=t_interp) - np.testing.assert_array_equal(solution.t, t_interp) - np.testing.assert_array_almost_equal( - solution.y[0], - np.exp(0.1 * solution.t), - decimal=5, - err_msg=f"Failed for form {form}", - ) + # enforce events that won't be triggered + model.events = [pybamm.Event("an event", var + 1)] + model_disc = disc.process_model(model, inplace=False) + solver = pybamm.IDAKLUSolver( + rtol=1e-8, + atol=1e-8, + root_method=root_method, + ) + solution = solver.solve(model_disc, t_eval, t_interp=t_interp) + np.testing.assert_array_equal(solution.t, t_interp) + np.testing.assert_array_almost_equal( + solution.y[0], + np.exp(0.1 * solution.t), + decimal=5, + err_msg=f"Failed for form {form}", + ) - # enforce events that will be triggered - model.events = [pybamm.Event("an event", 1.01 - var)] - model_disc = disc.process_model(model, inplace=False) - solver = pybamm.IDAKLUSolver( - rtol=1e-8, - atol=1e-8, - root_method=root_method, - options={"jax_evaluator": "iree"} if form == "iree" else {}, - ) - solution = solver.solve(model_disc, t_eval, t_interp=t_interp) - assert len(solution.t) < len(t_interp) - np.testing.assert_array_almost_equal( - solution.y[0], - np.exp(0.1 * solution.t), - decimal=5, - err_msg=f"Failed for form {form}", - ) + # enforce events that will be triggered + model.events = [pybamm.Event("an event", 1.01 - var)] + model_disc = disc.process_model(model, inplace=False) + solver = pybamm.IDAKLUSolver( + rtol=1e-8, + atol=1e-8, + root_method=root_method, + ) + solution = solver.solve(model_disc, t_eval, t_interp=t_interp) + assert len(solution.t) < len(t_interp) + np.testing.assert_array_almost_equal( + solution.y[0], + np.exp(0.1 * solution.t), + decimal=5, + err_msg=f"Failed for form {form}", + ) - # bigger dae model with multiple events - model = pybamm.BaseModel() - whole_cell = ["negative electrode", "separator", "positive electrode"] - var1 = pybamm.Variable("var1", domain=whole_cell) - var2 = pybamm.Variable("var2", domain=whole_cell) - model.rhs = {var1: 0.1 * var1} - model.algebraic = {var2: 2 * var1 - var2} - model.initial_conditions = {var1: 1, var2: 2} - model.events = [ - pybamm.Event("var1 = 1.5", pybamm.min(1.5 - var1)), - pybamm.Event("var2 = 2.5", pybamm.min(2.5 - var2)), - ] - disc = get_discretisation_for_testing() - disc.process_model(model) + # bigger dae model with multiple events + model = pybamm.BaseModel() + whole_cell = ["negative electrode", "separator", "positive electrode"] + var1 = pybamm.Variable("var1", domain=whole_cell) + var2 = pybamm.Variable("var2", domain=whole_cell) + model.rhs = {var1: 0.1 * var1} + model.algebraic = {var2: 2 * var1 - var2} + model.initial_conditions = {var1: 1, var2: 2} + model.events = [ + pybamm.Event("var1 = 1.5", pybamm.min(1.5 - var1)), + pybamm.Event("var2 = 2.5", pybamm.min(2.5 - var2)), + ] + disc = get_discretisation_for_testing() + disc.process_model(model) - solver = pybamm.IDAKLUSolver( - rtol=1e-8, - atol=1e-8, - root_method=root_method, - options={"jax_evaluator": "iree"} if form == "iree" else {}, - ) - t_eval = np.array([0, 5]) - solution = solver.solve(model, t_eval) - np.testing.assert_array_less(solution.y[0, :-1], 1.5) - np.testing.assert_array_less(solution.y[-1, :-1], 2.5) - np.testing.assert_equal(solution.t_event[0], solution.t[-1]) - np.testing.assert_array_equal(solution.y_event[:, 0], solution.y[:, -1]) - np.testing.assert_array_almost_equal( - solution.y[0], - np.exp(0.1 * solution.t), - decimal=5, - err_msg=f"Failed for form {form}", - ) - np.testing.assert_array_almost_equal( - solution.y[-1], - 2 * np.exp(0.1 * solution.t), - decimal=5, - err_msg=f"Failed for form {form}", - ) + solver = pybamm.IDAKLUSolver( + rtol=1e-8, + atol=1e-8, + root_method=root_method, + ) + t_eval = np.array([0, 5]) + solution = solver.solve(model, t_eval) + np.testing.assert_array_less(solution.y[0, :-1], 1.5) + np.testing.assert_array_less(solution.y[-1, :-1], 2.5) + np.testing.assert_equal(solution.t_event[0], solution.t[-1]) + np.testing.assert_array_equal(solution.y_event[:, 0], solution.y[:, -1]) + np.testing.assert_array_almost_equal( + solution.y[0], + np.exp(0.1 * solution.t), + decimal=5, + err_msg=f"Failed for form {form}", + ) + np.testing.assert_array_almost_equal( + solution.y[-1], + 2 * np.exp(0.1 * solution.t), + decimal=5, + err_msg=f"Failed for form {form}", + ) def test_input_params(self): # test a mix of scalar and vector input params - for form in ["casadi", "iree"]: - if (form == "iree") and (not pybamm.has_jax() or not pybamm.has_iree()): - continue - if form == "casadi": - root_method = "casadi" - else: - root_method = "lm" - model = pybamm.BaseModel() - model.convert_to_format = "jax" if form == "iree" else form - u1 = pybamm.Variable("u1") - u2 = pybamm.Variable("u2") - u3 = pybamm.Variable("u3") - v = pybamm.Variable("v") - a = pybamm.InputParameter("a") - b = pybamm.InputParameter("b", expected_size=2) - model.rhs = {u1: a * v, u2: pybamm.Index(b, 0), u3: pybamm.Index(b, 1)} - model.algebraic = {v: 1 - v} - model.initial_conditions = {u1: 0, u2: 0, u3: 0, v: 1} + form = "casadi" + root_method = "casadi" + model = pybamm.BaseModel() + model.convert_to_format = form + u1 = pybamm.Variable("u1") + u2 = pybamm.Variable("u2") + u3 = pybamm.Variable("u3") + v = pybamm.Variable("v") + a = pybamm.InputParameter("a") + b = pybamm.InputParameter("b", expected_size=2) + model.rhs = {u1: a * v, u2: pybamm.Index(b, 0), u3: pybamm.Index(b, 1)} + model.algebraic = {v: 1 - v} + model.initial_conditions = {u1: 0, u2: 0, u3: 0, v: 1} - disc = pybamm.Discretisation() - disc.process_model(model) + disc = pybamm.Discretisation() + disc.process_model(model) - solver = pybamm.IDAKLUSolver( - root_method=root_method, - options={"jax_evaluator": "iree"} if form == "iree" else {}, - ) + solver = pybamm.IDAKLUSolver( + root_method=root_method, + ) - t_interp = np.linspace(0, 3, 100) - t_eval = [t_interp[0], t_interp[-1]] - a_value = 0.1 - b_value = np.array([[0.2], [0.3]]) + t_interp = np.linspace(0, 3, 100) + t_eval = [t_interp[0], t_interp[-1]] + a_value = 0.1 + b_value = np.array([[0.2], [0.3]]) - sol = solver.solve( - model, - t_eval, - inputs={"a": a_value, "b": b_value}, - t_interp=t_interp, - ) + sol = solver.solve( + model, + t_eval, + inputs={"a": a_value, "b": b_value}, + t_interp=t_interp, + ) - # test that y[3] remains constant - np.testing.assert_array_almost_equal( - sol.y[3], np.ones(sol.t.shape), err_msg=f"Failed for form {form}" - ) + # test that y[3] remains constant + np.testing.assert_array_almost_equal( + sol.y[3], np.ones(sol.t.shape), err_msg=f"Failed for form {form}" + ) - # test that y[0] = to true solution - true_solution = a_value * sol.t - np.testing.assert_array_almost_equal( - sol.y[0], true_solution, err_msg=f"Failed for form {form}" - ) + # test that y[0] = to true solution + true_solution = a_value * sol.t + np.testing.assert_array_almost_equal( + sol.y[0], true_solution, err_msg=f"Failed for form {form}" + ) - # test that y[1:3] = to true solution - true_solution = b_value * sol.t - np.testing.assert_array_almost_equal( - sol.y[1:3], true_solution, err_msg=f"Failed for form {form}" - ) + # test that y[1:3] = to true solution + true_solution = b_value * sol.t + np.testing.assert_array_almost_equal( + sol.y[1:3], true_solution, err_msg=f"Failed for form {form}" + ) def test_sensitivities_initial_condition(self): - for form in ["casadi", "iree"]: - for output_variables in [[], ["2v"]]: - if (form == "iree") and (not pybamm.has_jax() or not pybamm.has_iree()): - continue - if form == "casadi": - root_method = "casadi" - else: - root_method = "lm" - model = pybamm.BaseModel() - model.convert_to_format = "jax" if form == "iree" else form - u = pybamm.Variable("u") - v = pybamm.Variable("v") - a = pybamm.InputParameter("a") - model.rhs = {u: -u} - model.algebraic = {v: a * u - v} - model.initial_conditions = {u: 1, v: 1} - model.variables = {"2v": 2 * v} - - disc = pybamm.Discretisation() - disc.process_model(model) - solver = pybamm.IDAKLUSolver( - rtol=1e-6, - atol=1e-6, - root_method=root_method, - output_variables=output_variables, - options={"jax_evaluator": "iree"} if form == "iree" else {}, - ) - - t_eval = [0, 3] - - a_value = 0.1 - - sol = solver.solve( - model, - t_eval, - inputs={"a": a_value}, - calculate_sensitivities=True, - ) - - np.testing.assert_array_almost_equal( - sol["2v"].sensitivities["a"].full().flatten(), - np.exp(-sol.t) * 2, - decimal=4, - err_msg=f"Failed for form {form}", - ) - - def test_ida_roberts_klu_sensitivities(self): - # this test implements a python version of the ida Roberts - # example provided in sundials - # see sundials ida examples pdf - for form in ["casadi", "iree"]: - if (form == "iree") and (not pybamm.has_jax() or not pybamm.has_iree()): - continue - if form == "casadi": - root_method = "casadi" - else: - root_method = "lm" + form = "casadi" + root_method = "casadi" + for output_variables in [[], ["2v"]]: model = pybamm.BaseModel() - model.convert_to_format = "jax" if form == "iree" else form + model.convert_to_format = form u = pybamm.Variable("u") v = pybamm.Variable("v") a = pybamm.InputParameter("a") - model.rhs = {u: a * v} - model.algebraic = {v: 1 - v} - model.initial_conditions = {u: 0, v: 1} - model.variables = {"2u": 2 * u} + model.rhs = {u: -u} + model.algebraic = {v: a * u - v} + model.initial_conditions = {u: 1, v: 1} + model.variables = {"2v": 2 * v} disc = pybamm.Discretisation() disc.process_model(model) - solver = pybamm.IDAKLUSolver( + rtol=1e-6, + atol=1e-6, root_method=root_method, - options={"jax_evaluator": "iree"} if form == "iree" else {}, + output_variables=output_variables, ) - t_interp = np.linspace(0, 3, 100) - t_eval = [t_interp[0], t_interp[-1]] + t_eval = [0, 3] + a_value = 0.1 - # solve first without sensitivities sol = solver.solve( model, t_eval, inputs={"a": a_value}, - t_interp=t_interp, + calculate_sensitivities=True, ) - # test that y[1] remains constant np.testing.assert_array_almost_equal( - sol.y[1, :], np.ones(sol.t.shape), err_msg=f"Failed for form {form}" + sol["2v"].sensitivities["a"].full().flatten(), + np.exp(-sol.t) * 2, + decimal=4, + err_msg=f"Failed for form {form}", ) - # test that y[0] = to true solution - true_solution = a_value * sol.t - np.testing.assert_array_almost_equal( - sol.y[0, :], true_solution, err_msg=f"Failed for form {form}" - ) + def test_ida_roberts_klu_sensitivities(self): + # this test implements a python version of the ida Roberts + # example provided in sundials + # see sundials ida examples pdf + form = "casadi" + root_method = "casadi" + model = pybamm.BaseModel() + model.convert_to_format = form + u = pybamm.Variable("u") + v = pybamm.Variable("v") + a = pybamm.InputParameter("a") + model.rhs = {u: a * v} + model.algebraic = {v: 1 - v} + model.initial_conditions = {u: 0, v: 1} + model.variables = {"2u": 2 * u} - # should be no sensitivities calculated - with pytest.raises(KeyError): - print(sol.sensitivities["a"]) + disc = pybamm.Discretisation() + disc.process_model(model) - # now solve with sensitivities (this should cause set_up to be run again) - sol = solver.solve( - model, - t_eval, - inputs={"a": a_value}, - calculate_sensitivities=True, - t_interp=t_interp, - ) + solver = pybamm.IDAKLUSolver( + root_method=root_method, + ) - # test that y[1] remains constant - np.testing.assert_array_almost_equal( - sol.y[1, :], np.ones(sol.t.shape), err_msg=f"Failed for form {form}" - ) + t_interp = np.linspace(0, 3, 100) + t_eval = [t_interp[0], t_interp[-1]] + a_value = 0.1 - # test that y[0] = to true solution - true_solution = a_value * sol.t - np.testing.assert_array_almost_equal( - sol.y[0, :], true_solution, err_msg=f"Failed for form {form}" - ) + # solve first without sensitivities + sol = solver.solve( + model, + t_eval, + inputs={"a": a_value}, + t_interp=t_interp, + ) - # evaluate the sensitivities using idas - dyda_ida = sol.sensitivities["a"] + # test that y[1] remains constant + np.testing.assert_array_almost_equal( + sol.y[1, :], np.ones(sol.t.shape), err_msg=f"Failed for form {form}" + ) - # evaluate the sensitivities using finite difference - h = 1e-6 - sol_plus = solver.solve( - model, t_eval, inputs={"a": a_value + 0.5 * h}, t_interp=t_interp - ) - sol_neg = solver.solve( - model, t_eval, inputs={"a": a_value - 0.5 * h}, t_interp=t_interp - ) - dyda_fd = (sol_plus.y - sol_neg.y) / h - dyda_fd = dyda_fd.transpose().reshape(-1, 1) + # test that y[0] = to true solution + true_solution = a_value * sol.t + np.testing.assert_array_almost_equal( + sol.y[0, :], true_solution, err_msg=f"Failed for form {form}" + ) - decimal = ( - 2 if form == "iree" else 6 - ) # iree currently operates with single precision - np.testing.assert_array_almost_equal( - dyda_ida, dyda_fd, decimal=decimal, err_msg=f"Failed for form {form}" - ) + # should be no sensitivities calculated + with pytest.raises(KeyError): + print(sol.sensitivities["a"]) - # get the sensitivities for the variable - d2uda = sol["2u"].sensitivities["a"] - np.testing.assert_array_almost_equal( - 2 * dyda_ida[0:200:2], - d2uda, - decimal=decimal, - err_msg=f"Failed for form {form}", - ) + # now solve with sensitivities (this should cause set_up to be run again) + sol = solver.solve( + model, + t_eval, + inputs={"a": a_value}, + calculate_sensitivities=True, + t_interp=t_interp, + ) + + # test that y[1] remains constant + np.testing.assert_array_almost_equal( + sol.y[1, :], np.ones(sol.t.shape), err_msg=f"Failed for form {form}" + ) + + # test that y[0] = to true solution + true_solution = a_value * sol.t + np.testing.assert_array_almost_equal( + sol.y[0, :], true_solution, err_msg=f"Failed for form {form}" + ) + + # evaluate the sensitivities using idas + dyda_ida = sol.sensitivities["a"] + + # evaluate the sensitivities using finite difference + h = 1e-6 + sol_plus = solver.solve( + model, t_eval, inputs={"a": a_value + 0.5 * h}, t_interp=t_interp + ) + sol_neg = solver.solve( + model, t_eval, inputs={"a": a_value - 0.5 * h}, t_interp=t_interp + ) + dyda_fd = (sol_plus.y - sol_neg.y) / h + dyda_fd = dyda_fd.transpose().reshape(-1, 1) + + decimal = 6 + np.testing.assert_array_almost_equal( + dyda_ida, dyda_fd, decimal=decimal, err_msg=f"Failed for form {form}" + ) + + # get the sensitivities for the variable + d2uda = sol["2u"].sensitivities["a"] + np.testing.assert_array_almost_equal( + 2 * dyda_ida[0:200:2], + d2uda, + decimal=decimal, + err_msg=f"Failed for form {form}", + ) def test_ida_roberts_consistent_initialization(self): # this test implements a python version of the ida Roberts # example provided in sundials # see sundials ida examples pdf - for form in ["casadi", "iree"]: - if (form == "iree") and (not pybamm.has_jax() or not pybamm.has_iree()): - continue - if form == "casadi": - root_method = "casadi" - else: - root_method = "lm" - model = pybamm.BaseModel() - model.convert_to_format = "jax" if form == "iree" else form - u = pybamm.Variable("u") - v = pybamm.Variable("v") - model.rhs = {u: 0.1 * v} - model.algebraic = {v: 1 - v} - model.initial_conditions = {u: 0, v: 2} + form = "casadi" + root_method = "casadi" + model = pybamm.BaseModel() + model.convert_to_format = form + u = pybamm.Variable("u") + v = pybamm.Variable("v") + model.rhs = {u: 0.1 * v} + model.algebraic = {v: 1 - v} + model.initial_conditions = {u: 0, v: 2} - disc = pybamm.Discretisation() - disc.process_model(model) + disc = pybamm.Discretisation() + disc.process_model(model) - solver = pybamm.IDAKLUSolver( - root_method=root_method, - options={"jax_evaluator": "iree"} if form == "iree" else {}, - ) + solver = pybamm.IDAKLUSolver( + root_method=root_method, + ) - # Set up and model consistently initializate the model - solver.set_up(model) - t0 = 0.0 - solver._set_consistent_initialization(model, t0, inputs_dict={}) + # Set up and model consistently initializate the model + solver.set_up(model) + t0 = 0.0 + solver._set_consistent_initialization(model, t0, inputs_dict={}) - # u(t0) = 0, v(t0) = 1 - np.testing.assert_array_almost_equal( - model.y0full, [0, 1], err_msg=f"Failed for form {form}" - ) - # u'(t0) = 0.1 * v(t0) = 0.1 - # Since v is algebraic, the initial derivative is set to 0 - np.testing.assert_array_almost_equal( - model.ydot0full, [0.1, 0], err_msg=f"Failed for form {form}" - ) + # u(t0) = 0, v(t0) = 1 + np.testing.assert_array_almost_equal( + model.y0full, [0, 1], err_msg=f"Failed for form {form}" + ) + # u'(t0) = 0.1 * v(t0) = 0.1 + # Since v is algebraic, the initial derivative is set to 0 + np.testing.assert_array_almost_equal( + model.ydot0full, [0.1, 0], err_msg=f"Failed for form {form}" + ) def test_sensitivities_with_events(self): # this test implements a python version of the ida Roberts # example provided in sundials # see sundials ida examples pdf - for form in ["casadi", "iree"]: - if (form == "iree") and (not pybamm.has_jax() or not pybamm.has_iree()): - continue - if form == "casadi": - root_method = "casadi" - else: - root_method = "lm" - model = pybamm.BaseModel() - model.convert_to_format = "jax" if form == "iree" else form - u = pybamm.Variable("u") - v = pybamm.Variable("v") - a = pybamm.InputParameter("a") - b = pybamm.InputParameter("b") - model.rhs = {u: a * v + b} - model.algebraic = {v: 1 - v} - model.initial_conditions = {u: 0, v: 1} - model.events = [pybamm.Event("1", 0.2 - u)] - - disc = pybamm.Discretisation() - disc.process_model(model) - - solver = pybamm.IDAKLUSolver( - root_method=root_method, - options={"jax_evaluator": "iree"} if form == "iree" else {}, - ) + form = "casadi" + root_method = "casadi" + model = pybamm.BaseModel() + model.convert_to_format = form + u = pybamm.Variable("u") + v = pybamm.Variable("v") + a = pybamm.InputParameter("a") + b = pybamm.InputParameter("b") + model.rhs = {u: a * v + b} + model.algebraic = {v: 1 - v} + model.initial_conditions = {u: 0, v: 1} + model.events = [pybamm.Event("1", 0.2 - u)] - t_interp = np.linspace(0, 3, 100) - t_eval = [t_interp[0], t_interp[-1]] + disc = pybamm.Discretisation() + disc.process_model(model) - a_value = 0.1 - b_value = 0.0 + solver = pybamm.IDAKLUSolver( + root_method=root_method, + ) - # solve first without sensitivities - sol = solver.solve( - model, - t_eval, - inputs={"a": a_value, "b": b_value}, - calculate_sensitivities=True, - t_interp=t_interp, - ) + t_interp = np.linspace(0, 3, 100) + t_eval = [t_interp[0], t_interp[-1]] - # test that y[1] remains constant - np.testing.assert_array_almost_equal( - sol.y[1, :], np.ones(sol.t.shape), err_msg=f"Failed for form {form}" - ) + a_value = 0.1 + b_value = 0.0 - # test that y[0] = to true solution - true_solution = a_value * sol.t - np.testing.assert_array_almost_equal( - sol.y[0, :], true_solution, err_msg=f"Failed for form {form}" - ) + # solve first without sensitivities + sol = solver.solve( + model, + t_eval, + inputs={"a": a_value, "b": b_value}, + calculate_sensitivities=True, + t_interp=t_interp, + ) - # evaluate the sensitivities using idas - dyda_ida = sol.sensitivities["a"] - dydb_ida = sol.sensitivities["b"] + # test that y[1] remains constant + np.testing.assert_array_almost_equal( + sol.y[1, :], np.ones(sol.t.shape), err_msg=f"Failed for form {form}" + ) - # evaluate the sensitivities using finite difference - h = 1e-6 - sol_plus = solver.solve( - model, - t_eval, - inputs={"a": a_value + 0.5 * h, "b": b_value}, - t_interp=t_interp, - ) - sol_neg = solver.solve( - model, - t_eval, - inputs={"a": a_value - 0.5 * h, "b": b_value}, - t_interp=t_interp, - ) - max_index = min(sol_plus.y.shape[1], sol_neg.y.shape[1]) - 1 - dyda_fd = (sol_plus.y[:, :max_index] - sol_neg.y[:, :max_index]) / h - dyda_fd = dyda_fd.transpose().reshape(-1, 1) + # test that y[0] = to true solution + true_solution = a_value * sol.t + np.testing.assert_array_almost_equal( + sol.y[0, :], true_solution, err_msg=f"Failed for form {form}" + ) - decimal = ( - 2 if form == "iree" else 6 - ) # iree currently operates with single precision - np.testing.assert_array_almost_equal( - dyda_ida[: (2 * max_index), :], - dyda_fd, - decimal=decimal, - err_msg=f"Failed for form {form}", - ) + # evaluate the sensitivities using idas + dyda_ida = sol.sensitivities["a"] + dydb_ida = sol.sensitivities["b"] - sol_plus = solver.solve( - model, - t_eval, - inputs={"a": a_value, "b": b_value + 0.5 * h}, - t_interp=t_interp, - ) - sol_neg = solver.solve( - model, - t_eval, - inputs={"a": a_value, "b": b_value - 0.5 * h}, - t_interp=t_interp, - ) - max_index = min(sol_plus.y.shape[1], sol_neg.y.shape[1]) - 1 - dydb_fd = (sol_plus.y[:, :max_index] - sol_neg.y[:, :max_index]) / h - dydb_fd = dydb_fd.transpose().reshape(-1, 1) + # evaluate the sensitivities using finite difference + h = 1e-6 + sol_plus = solver.solve( + model, + t_eval, + inputs={"a": a_value + 0.5 * h, "b": b_value}, + t_interp=t_interp, + ) + sol_neg = solver.solve( + model, + t_eval, + inputs={"a": a_value - 0.5 * h, "b": b_value}, + t_interp=t_interp, + ) + max_index = min(sol_plus.y.shape[1], sol_neg.y.shape[1]) - 1 + dyda_fd = (sol_plus.y[:, :max_index] - sol_neg.y[:, :max_index]) / h + dyda_fd = dyda_fd.transpose().reshape(-1, 1) + + decimal = 6 + np.testing.assert_array_almost_equal( + dyda_ida[: (2 * max_index), :], + dyda_fd, + decimal=decimal, + err_msg=f"Failed for form {form}", + ) - np.testing.assert_array_almost_equal( - dydb_ida[: (2 * max_index), :], - dydb_fd, - decimal=decimal, - err_msg=f"Failed for form {form}", - ) + sol_plus = solver.solve( + model, + t_eval, + inputs={"a": a_value, "b": b_value + 0.5 * h}, + t_interp=t_interp, + ) + sol_neg = solver.solve( + model, + t_eval, + inputs={"a": a_value, "b": b_value - 0.5 * h}, + t_interp=t_interp, + ) + max_index = min(sol_plus.y.shape[1], sol_neg.y.shape[1]) - 1 + dydb_fd = (sol_plus.y[:, :max_index] - sol_neg.y[:, :max_index]) / h + dydb_fd = dydb_fd.transpose().reshape(-1, 1) + + np.testing.assert_array_almost_equal( + dydb_ida[: (2 * max_index), :], + dydb_fd, + decimal=decimal, + err_msg=f"Failed for form {form}", + ) def test_failures(self): # this test implements a python version of the ida Roberts @@ -634,34 +585,28 @@ def test_failures(self): solver.solve(model, t_eval) def test_dae_solver_algebraic_model(self): - for form in ["casadi", "iree"]: - if (form == "iree") and (not pybamm.has_jax() or not pybamm.has_iree()): - continue - if form == "casadi": - root_method = "casadi" - else: - root_method = "lm" - model = pybamm.BaseModel() - model.convert_to_format = "jax" if form == "iree" else form - var = pybamm.Variable("var") - model.algebraic = {var: var + 1} - model.initial_conditions = {var: 0} + form = "casadi" + root_method = "casadi" + model = pybamm.BaseModel() + model.convert_to_format = form + var = pybamm.Variable("var") + model.algebraic = {var: var + 1} + model.initial_conditions = {var: 0} - disc = pybamm.Discretisation() - disc.process_model(model) + disc = pybamm.Discretisation() + disc.process_model(model) - solver = pybamm.IDAKLUSolver( - root_method=root_method, - options={"jax_evaluator": "iree"} if form == "iree" else {}, - ) - t_eval = [0, 1] - solution = solver.solve(model, t_eval) - np.testing.assert_array_equal(solution.y, -1) + solver = pybamm.IDAKLUSolver( + root_method=root_method, + ) + t_eval = [0, 1] + solution = solver.solve(model, t_eval) + np.testing.assert_array_equal(solution.y, -1) - # change initial_conditions and re-solve (to test if ics_only works) - model.concatenated_initial_conditions = pybamm.Vector(np.array([[1]])) - solution = solver.solve(model, t_eval) - np.testing.assert_array_equal(solution.y, -1) + # change initial_conditions and re-solve (to test if ics_only works) + model.concatenated_initial_conditions = pybamm.Vector(np.array([[1]])) + solution = solver.solve(model, t_eval) + np.testing.assert_array_equal(solution.y, -1) def test_banded(self): model = pybamm.lithium_ion.SPM() @@ -950,113 +895,90 @@ def test_with_output_variables_and_sensitivities(self): # Construct a model and solve for all variables, then test # the 'output_variables' option for each variable in turn, confirming # equivalence + form = "casadi" + root_method = "casadi" + input_parameters = { # Sensitivities dictionary + "Current function [A]": 0.222, + "Separator porosity": 0.3, + } - for form in ["casadi", "iree"]: - if (form == "iree") and (not pybamm.has_jax() or not pybamm.has_iree()): - continue - if form == "casadi": - root_method = "casadi" - else: - root_method = "lm" - input_parameters = { # Sensitivities dictionary - "Current function [A]": 0.222, - "Separator porosity": 0.3, - } - - # construct model - model = pybamm.lithium_ion.DFN() - model.convert_to_format = "jax" if form == "iree" else form - geometry = model.default_geometry - param = model.default_parameter_values - param.update({key: "[input]" for key in input_parameters}) - param.process_model(model) - param.process_geometry(geometry) - var_pts = {"x_n": 50, "x_s": 50, "x_p": 50, "r_n": 5, "r_p": 5} - mesh = pybamm.Mesh(geometry, model.default_submesh_types, var_pts) - disc = pybamm.Discretisation(mesh, model.default_spatial_methods) - disc.process_model(model) - - t_interp = np.linspace(0, 100, 5) - t_eval = [t_interp[0], t_interp[-1]] + # construct model + model = pybamm.lithium_ion.DFN() + model.convert_to_format = form + geometry = model.default_geometry + param = model.default_parameter_values + param.update({key: "[input]" for key in input_parameters}) + param.process_model(model) + param.process_geometry(geometry) + var_pts = {"x_n": 50, "x_s": 50, "x_p": 50, "r_n": 5, "r_p": 5} + mesh = pybamm.Mesh(geometry, model.default_submesh_types, var_pts) + disc = pybamm.Discretisation(mesh, model.default_spatial_methods) + disc.process_model(model) - options = { - "linear_solver": "SUNLinSol_KLU", - "jacobian": "sparse", - "num_threads": 4, - "max_num_steps": 1000, - } - if form == "iree": - options["jax_evaluator"] = "iree" - - # Use a selection of variables of different types - output_variables = [ - "Voltage [V]", - "Time [min]", - "x [m]", - "Negative particle flux [mol.m-2.s-1]", - "Throughput capacity [A.h]", # ExplicitTimeIntegral - ] - - # Use the full model as comparison (tested separately) - solver_all = pybamm.IDAKLUSolver( - root_method=root_method, - atol=1e-8 if form != "iree" else 1e-0, # iree has reduced precision - rtol=1e-8 if form != "iree" else 1e-0, # iree has reduced precision - options=options, - ) - sol_all = solver_all.solve( - model, - t_eval, - inputs=input_parameters, - calculate_sensitivities=True, - t_interp=t_interp, - ) + t_interp = np.linspace(0, 100, 5) + t_eval = [t_interp[0], t_interp[-1]] - # Solve for a subset of variables and compare results - solver = pybamm.IDAKLUSolver( - root_method=root_method, - atol=1e-8 if form != "iree" else 1e-0, # iree has reduced precision - rtol=1e-8 if form != "iree" else 1e-0, # iree has reduced precision - options=options, - output_variables=output_variables, - ) - sol = solver.solve( - model, - t_eval, - inputs=input_parameters, - calculate_sensitivities=True, - t_interp=t_interp, - ) + options = { + "linear_solver": "SUNLinSol_KLU", + "jacobian": "sparse", + "num_threads": 4, + "max_num_steps": 1000, + } - # Compare output to sol_all - tol = 1e-5 if form != "iree" else 1e-2 # iree has reduced precision - for varname in output_variables: - np.testing.assert_array_almost_equal( - sol[varname](t_interp), - sol_all[varname](t_interp), - tol, - err_msg=f"Failed for {varname} with form {form}", - ) + # Use a selection of variables of different types + output_variables = [ + "Voltage [V]", + "Time [min]", + "x [m]", + "Negative particle flux [mol.m-2.s-1]", + "Throughput capacity [A.h]", # ExplicitTimeIntegral + ] - # Mock a 1D current collector and initialise (none in the model) - sol["x_s [m]"].domain = ["current collector"] - sol["x_s [m]"].entries + # Use the full model as comparison (tested separately) + solver_all = pybamm.IDAKLUSolver( + root_method=root_method, + atol=1e-8, + rtol=1e-8, + options=options, + ) + sol_all = solver_all.solve( + model, + t_eval, + inputs=input_parameters, + calculate_sensitivities=True, + t_interp=t_interp, + ) - def test_bad_jax_evaluator(self): - model = pybamm.lithium_ion.DFN() - model.convert_to_format = "jax" - with pytest.raises(pybamm.SolverError): - pybamm.IDAKLUSolver(options={"jax_evaluator": "bad_evaluator"}) + # Solve for a subset of variables and compare results + solver = pybamm.IDAKLUSolver( + root_method=root_method, + atol=1e-8, + rtol=1e-8, + options=options, + output_variables=output_variables, + ) + sol = solver.solve( + model, + t_eval, + inputs=input_parameters, + calculate_sensitivities=True, + t_interp=t_interp, + ) - def test_bad_jax_evaluator_output_variables(self): - model = pybamm.lithium_ion.DFN() - model.convert_to_format = "jax" - with pytest.raises(pybamm.SolverError): - pybamm.IDAKLUSolver( - options={"jax_evaluator": "bad_evaluator"}, - output_variables=["Terminal voltage [V]"], + # Compare output to sol_all + tol = 1e-5 + for varname in output_variables: + np.testing.assert_array_almost_equal( + sol[varname](t_interp), + sol_all[varname](t_interp), + tol, + err_msg=f"Failed for {varname} with form {form}", ) + # Mock a 1D current collector and initialise (none in the model) + sol["x_s [m]"].domain = ["current collector"] + sol["x_s [m]"].entries + def test_with_output_variables_and_event_termination(self): model = pybamm.lithium_ion.DFN() parameter_values = pybamm.ParameterValues("Chen2020") @@ -1145,7 +1067,7 @@ def experiment_setup(period=None): ) def test_python_idaklu_deprecation_errors(self): - for form in ["python", "", "jax"]: + for form in ["python", "jax"]: if form == "jax" and not pybamm.has_jax(): continue @@ -1174,13 +1096,13 @@ def test_python_idaklu_deprecation_errors(self): ): with pytest.raises( DeprecationWarning, - match="The python-idaklu solver has been deprecated.", + match="The python-idaklu and IREE solvers have been deprecated.", ): _ = solver.solve(model, t_eval) elif form == "jax": with pytest.raises( pybamm.SolverError, - match="Unsupported evaluation engine for convert_to_format=jax", + match="Unsupported option for convert_to_format=jax", ): _ = solver.solve(model, t_eval)