diff --git a/.github/workflows/publish_pypi.yml b/.github/workflows/publish_pypi.yml index 402bf9e859..8d328118a1 100644 --- a/.github/workflows/publish_pypi.yml +++ b/.github/workflows/publish_pypi.yml @@ -317,7 +317,7 @@ jobs: if: github.event.inputs.target == 'testpypi' uses: pypa/gh-action-pypi-publish@release/v1 with: - packages-dir: files/ + packages-dir: artifacts/ repository-url: https://test.pypi.org/legacy/ open_failure_issue: diff --git a/.github/workflows/run_periodic_tests.yml b/.github/workflows/run_periodic_tests.yml index 6f79df76b2..2dd9ef8a89 100644 --- a/.github/workflows/run_periodic_tests.yml +++ b/.github/workflows/run_periodic_tests.yml @@ -14,6 +14,8 @@ on: env: FORCE_COLOR: 3 + PYBAMM_IDAKLU_EXPR_CASADI: ON + PYBAMM_IDAKLU_EXPR_IREE: ON concurrency: # github.workflow: name of the workflow, so that we don't cancel other workflows diff --git a/.github/workflows/scorecard.yml b/.github/workflows/scorecard.yml index 0c81f71bde..9a41c49b69 100644 --- a/.github/workflows/scorecard.yml +++ b/.github/workflows/scorecard.yml @@ -68,6 +68,6 @@ jobs: # Upload the results to GitHub's code scanning dashboard (optional). # Commenting out will disable upload of results to your repo's Code Scanning dashboard - name: "Upload to code-scanning" - uses: github/codeql-action/upload-sarif@b611370bb5703a7efb587f9d136a52ea24c5c38c # v3.25.11 + uses: github/codeql-action/upload-sarif@4fa2a7953630fd2f3fb380f21be14ede0169dd4f # v3.25.12 with: sarif_file: results.sarif diff --git a/.github/workflows/test_on_push.yml b/.github/workflows/test_on_push.yml index 97c37e8c28..adfb698a69 100644 --- a/.github/workflows/test_on_push.yml +++ b/.github/workflows/test_on_push.yml @@ -6,6 +6,8 @@ on: env: FORCE_COLOR: 3 + PYBAMM_IDAKLU_EXPR_CASADI: ON + PYBAMM_IDAKLU_EXPR_IREE: ON concurrency: # github.workflow: name of the workflow, so that we don't cancel other workflows diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8effca2b07..ecd3cd9199 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,7 +4,7 @@ ci: repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: "v0.5.1" + rev: "v0.5.2" hooks: - id: ruff args: [--fix, --show-fixes] diff --git a/CHANGELOG.md b/CHANGELOG.md index 9addd13346..c89062026f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,9 +1,10 @@ # [Unreleased](https://github.com/pybamm-team/PyBaMM/) -# [v24.5rc0](https://github.com/pybamm-team/PyBaMM/tree/v24.5rc0) - 2024-05-01 +# [v24.5rc2](https://github.com/pybamm-team/PyBaMM/tree/v24.5rc2) - 2024-07-12 ## Features +- Added additional user-configurable options to the (`IDAKLUSolver`). ([#4249](https://github.com/pybamm-team/PyBaMM/pull/4249)) - Added functionality to pass in arbitrary functions of time as the argument for a (`pybamm.step`). ([#4222](https://github.com/pybamm-team/PyBaMM/pull/4222)) - Added new parameters `"f{pref]Initial inner SEI on cracks thickness [m]"` and `"f{pref]Initial outer SEI on cracks thickness [m]"`, instead of hardcoding these to `L_inner_0 / 10000` and `L_outer_0 / 10000`. ([#4168](https://github.com/pybamm-team/PyBaMM/pull/4168)) - Added `pybamm.DataLoader` class to fetch data files from [pybamm-data](https://github.com/pybamm-team/pybamm-data/releases/tag/v1.0.0) and store it under local cache. ([#4098](https://github.com/pybamm-team/PyBaMM/pull/4098)) diff --git a/CITATION.cff b/CITATION.cff index 43fa574cdd..7e28662bac 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -24,6 +24,6 @@ keywords: - "expression tree" - "python" - "symbolic differentiation" -version: "24.5rc0" +version: "24.5rc2" repository-code: "https://github.com/pybamm-team/PyBaMM" title: "Python Battery Mathematical Modelling (PyBaMM)" diff --git a/CMakeLists.txt b/CMakeLists.txt index b9fe37c331..661f63457e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -35,32 +35,76 @@ add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0) if(NOT PYBIND11_DIR) set(PYBIND11_DIR pybind11) endif() - add_subdirectory(${PYBIND11_DIR}) -# The sources list should mirror the list in setup.py + +# Check Casadi build flag +if(NOT DEFINED PYBAMM_IDAKLU_EXPR_CASADI) + set(PYBAMM_IDAKLU_EXPR_CASADI ON) +endif() +message("PYBAMM_IDAKLU_EXPR_CASADI: ${PYBAMM_IDAKLU_EXPR_CASADI}") + +# Casadi PyBaMM source files +set(IDAKLU_EXPR_CASADI_SOURCE_FILES "") +if(${PYBAMM_IDAKLU_EXPR_CASADI} STREQUAL "ON" ) + add_compile_definitions(CASADI_ENABLE) + set(IDAKLU_EXPR_CASADI_SOURCE_FILES + pybamm/solvers/c_solvers/idaklu/Expressions/Casadi/CasadiFunctions.cpp + pybamm/solvers/c_solvers/idaklu/Expressions/Casadi/CasadiFunctions.hpp + ) +endif() + +# Check IREE build flag +if(NOT DEFINED PYBAMM_IDAKLU_EXPR_IREE) + set(PYBAMM_IDAKLU_EXPR_IREE OFF) +endif() +message("PYBAMM_IDAKLU_EXPR_IREE: ${PYBAMM_IDAKLU_EXPR_IREE}") + +# IREE (MLIR expression evaluation) PyBaMM source files +set(IDAKLU_EXPR_IREE_SOURCE_FILES "") +if(${PYBAMM_IDAKLU_EXPR_IREE} STREQUAL "ON" ) + add_compile_definitions(IREE_ENABLE) + # Source file list + set(IDAKLU_EXPR_IREE_SOURCE_FILES + pybamm/solvers/c_solvers/idaklu/Expressions/IREE/iree_jit.cpp + pybamm/solvers/c_solvers/idaklu/Expressions/IREE/iree_jit.hpp + pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEFunctions.cpp + pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEFunctions.hpp + pybamm/solvers/c_solvers/idaklu/Expressions/IREE/ModuleParser.cpp + pybamm/solvers/c_solvers/idaklu/Expressions/IREE/ModuleParser.hpp + ) +endif() + +# The complete (all dependencies) sources list should be mirrored in setup.py pybind11_add_module(idaklu - pybamm/solvers/c_solvers/idaklu/casadi_functions.cpp - pybamm/solvers/c_solvers/idaklu/casadi_functions.hpp - pybamm/solvers/c_solvers/idaklu/casadi_solver.cpp - pybamm/solvers/c_solvers/idaklu/casadi_solver.hpp - pybamm/solvers/c_solvers/idaklu/CasadiSolver.cpp - pybamm/solvers/c_solvers/idaklu/CasadiSolver.hpp - pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.cpp - pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.hpp - pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP_solvers.cpp - pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP_solvers.hpp - pybamm/solvers/c_solvers/idaklu/casadi_sundials_functions.cpp - pybamm/solvers/c_solvers/idaklu/casadi_sundials_functions.hpp - pybamm/solvers/c_solvers/idaklu/idaklu_jax.cpp - pybamm/solvers/c_solvers/idaklu/idaklu_jax.hpp + # pybind11 interface + pybamm/solvers/c_solvers/idaklu.cpp + # IDAKLU solver (SUNDIALS) + pybamm/solvers/c_solvers/idaklu/idaklu_solver.hpp + pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.cpp + pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.hpp + pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.inl + pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.hpp + pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP_solvers.cpp + pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP_solvers.hpp + pybamm/solvers/c_solvers/idaklu/sundials_functions.inl + pybamm/solvers/c_solvers/idaklu/sundials_functions.hpp + pybamm/solvers/c_solvers/idaklu/IdakluJax.cpp + pybamm/solvers/c_solvers/idaklu/IdakluJax.hpp pybamm/solvers/c_solvers/idaklu/common.hpp pybamm/solvers/c_solvers/idaklu/python.hpp pybamm/solvers/c_solvers/idaklu/python.cpp - pybamm/solvers/c_solvers/idaklu/solution.cpp - pybamm/solvers/c_solvers/idaklu/solution.hpp - pybamm/solvers/c_solvers/idaklu/options.hpp - pybamm/solvers/c_solvers/idaklu/options.cpp - pybamm/solvers/c_solvers/idaklu.cpp + pybamm/solvers/c_solvers/idaklu/Solution.cpp + pybamm/solvers/c_solvers/idaklu/Solution.hpp + pybamm/solvers/c_solvers/idaklu/Options.hpp + pybamm/solvers/c_solvers/idaklu/Options.cpp + # IDAKLU expressions / function evaluation [abstract] + pybamm/solvers/c_solvers/idaklu/Expressions/Expressions.hpp + pybamm/solvers/c_solvers/idaklu/Expressions/Base/Expression.hpp + pybamm/solvers/c_solvers/idaklu/Expressions/Base/ExpressionSet.hpp + pybamm/solvers/c_solvers/idaklu/Expressions/Base/ExpressionTypes.hpp + # IDAKLU expressions - concrete implementations + ${IDAKLU_EXPR_CASADI_SOURCE_FILES} + ${IDAKLU_EXPR_IREE_SOURCE_FILES} ) if (NOT DEFINED USE_PYTHON_CASADI) @@ -113,3 +157,16 @@ else() endif() include_directories(${SuiteSparse_INCLUDE_DIRS}) target_link_libraries(idaklu PRIVATE ${SuiteSparse_LIBRARIES}) + +# IREE (MLIR compiler and runtime library) build settings +if(${PYBAMM_IDAKLU_EXPR_IREE} STREQUAL "ON" ) + set(IREE_BUILD_COMPILER ON) + set(IREE_BUILD_TESTS OFF) + set(IREE_BUILD_SAMPLES OFF) + add_subdirectory(iree EXCLUDE_FROM_ALL) + set(IREE_COMPILER_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/iree/compiler") + target_include_directories(idaklu SYSTEM PRIVATE "${IREE_COMPILER_ROOT}/bindings/c/iree/compiler") + target_compile_options(idaklu PRIVATE ${IREE_DEFAULT_COPTS}) + target_link_libraries(idaklu PRIVATE iree_compiler_bindings_c_loader) + target_link_libraries(idaklu PRIVATE iree_runtime_runtime) +endif() diff --git a/bandit.yml b/bandit.yml new file mode 100644 index 0000000000..87da61e530 --- /dev/null +++ b/bandit.yml @@ -0,0 +1,2 @@ +# To ignore false flagging of assert statements in tests by Codacy. +skips: ['B101'] diff --git a/docs/source/examples/notebooks/models/spm1.png b/docs/source/examples/notebooks/models/spm1.png index 7e0e9ea9cc..a8509b442a 100644 Binary files a/docs/source/examples/notebooks/models/spm1.png and b/docs/source/examples/notebooks/models/spm1.png differ diff --git a/docs/source/user_guide/installation/gnu-linux-mac.rst b/docs/source/user_guide/installation/gnu-linux-mac.rst index 0be4b98e4c..121c6df437 100644 --- a/docs/source/user_guide/installation/gnu-linux-mac.rst +++ b/docs/source/user_guide/installation/gnu-linux-mac.rst @@ -101,6 +101,19 @@ Users can install ``jax`` and ``jaxlib`` to use the Jax solver. The ``pip install "pybamm[jax]"`` command automatically downloads and installs ``pybamm`` and the compatible versions of ``jax`` and ``jaxlib`` on your system. (``pybamm_install_jax`` is deprecated.) +.. _optional-iree-mlir-support: + +Optional - IREE / MLIR support +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Users can install ``iree`` (for MLIR just-in-time compilation) to use for main expression evaluation in the IDAKLU solver. Requires ``jax``. + +.. code:: bash + + pip install "pybamm[iree,jax]" + +The ``pip install "pybamm[iree,jax]"`` command automatically downloads and installs ``pybamm`` and the compatible versions of ``jax`` and ``iree`` onto your system. + Uninstall PyBaMM ---------------- diff --git a/docs/source/user_guide/installation/index.rst b/docs/source/user_guide/installation/index.rst index 778f64c3f9..d6411348c5 100644 --- a/docs/source/user_guide/installation/index.rst +++ b/docs/source/user_guide/installation/index.rst @@ -47,7 +47,8 @@ Optional solvers The following solvers are optionally available: -* `jax `_ -based solver, see `Optional - JaxSolver `_. +* `jax `_ -based solver, see `Optional - JaxSolver `_. +* `IREE `_ (`MLIR `_) support, see `Optional - IREE / MLIR Support `_. Dependencies ------------ @@ -205,6 +206,17 @@ Dependency Minimu `jaxlib `__ 0.4.20 jax Support library for JAX ========================================================================= ================== ================== ======================= +IREE dependencies +^^^^^^^^^^^^^^^^^^ + +Installable with ``pip install "pybamm[iree]"`` (requires ``jax`` dependencies to be installed). + +========================================================================= ================== ================== ======================= +Dependency Minimum Version pip extra Notes +========================================================================= ================== ================== ======================= +`iree-compiler `__ 20240507.886 iree IREE compiler +========================================================================= ================== ================== ======================= + Full installation guide ----------------------- diff --git a/noxfile.py b/noxfile.py index 7237786ef6..373b77f71f 100644 --- a/noxfile.py +++ b/noxfile.py @@ -1,6 +1,8 @@ import nox import os import sys +import warnings +import platform from pathlib import Path @@ -13,11 +15,54 @@ nox.options.sessions = ["pre-commit", "unit"] +def set_iree_state(): + """ + Check if IREE is enabled and set the environment variable accordingly. + + Returns + ------- + str + "ON" if IREE is enabled, "OFF" otherwise. + + """ + state = "ON" if os.getenv("PYBAMM_IDAKLU_EXPR_IREE", "OFF") == "ON" else "OFF" + if state == "ON": + if sys.platform == "win32": + warnings.warn( + ( + "IREE is not enabled on Windows yet. " + "Setting PYBAMM_IDAKLU_EXPR_IREE=OFF." + ), + stacklevel=2, + ) + return "OFF" + if sys.platform == "darwin": + # iree-compiler is currently only available as a wheel on macOS 13 (or + # higher) and Python version 3.11 + mac_ver = int(platform.mac_ver()[0].split(".")[0]) + if (not sys.version_info[:2] == (3, 11)) or mac_ver < 13: + warnings.warn( + ( + "IREE is only supported on MacOS 13 (or higher) and Python" + "version 3.11. Setting PYBAMM_IDAKLU_EXPR_IREE=OFF." + ), + stacklevel=2, + ) + return "OFF" + return state + + homedir = os.getenv("HOME") PYBAMM_ENV = { "SUNDIALS_INST": f"{homedir}/.local", "LD_LIBRARY_PATH": f"{homedir}/.local/lib", "PYTHONIOENCODING": "utf-8", + # Expression evaluators (...EXPR_CASADI cannot be fully disabled at this time) + "PYBAMM_IDAKLU_EXPR_CASADI": os.getenv("PYBAMM_IDAKLU_EXPR_CASADI", "ON"), + "PYBAMM_IDAKLU_EXPR_IREE": set_iree_state(), + "IREE_INDEX_URL": os.getenv( + "IREE_INDEX_URL", "https://iree.dev/pip-release-links.html" + ), } VENV_DIR = Path("./venv").resolve() @@ -59,6 +104,29 @@ def run_pybamm_requires(session): "advice.detachedHead=false", external=True, ) + if PYBAMM_ENV.get("PYBAMM_IDAKLU_EXPR_IREE") == "ON" and not os.path.exists( + "./iree" + ): + session.run( + "git", + "clone", + "--depth=1", + "--recurse-submodules", + "--shallow-submodules", + "--branch=candidate-20240507.886", + "https://github.com/openxla/iree", + "iree/", + external=True, + ) + with session.chdir("iree"): + session.run( + "git", + "submodule", + "update", + "--init", + "--recursive", + external=True, + ) else: session.error("nox -s pybamm-requires is only available on Linux & macOS.") @@ -70,6 +138,15 @@ def run_coverage(session): session.install("setuptools", silent=False) session.install("coverage", silent=False) session.install("-e", ".[all,dev,jax]", silent=False) + if PYBAMM_ENV.get("PYBAMM_IDAKLU_EXPR_IREE") == "ON": + # See comments in 'dev' session + session.install( + "-e", + ".[iree]", + "--find-links", + PYBAMM_ENV.get("IREE_INDEX_URL"), + silent=False, + ) session.run("pytest", "--cov=pybamm", "--cov-report=xml", "tests/unit") @@ -98,6 +175,15 @@ def run_unit(session): set_environment_variables(PYBAMM_ENV, session=session) session.install("setuptools", silent=False) session.install("-e", ".[all,dev,jax]", silent=False) + if PYBAMM_ENV.get("PYBAMM_IDAKLU_EXPR_IREE") == "ON": + # See comments in 'dev' session + session.install( + "-e", + ".[iree]", + "--find-links", + PYBAMM_ENV.get("IREE_INDEX_URL"), + silent=False, + ) session.run("python", "run-tests.py", "--unit") @@ -130,6 +216,17 @@ def set_dev(session): session.install("virtualenv", "cmake") session.run("virtualenv", os.fsdecode(VENV_DIR), silent=True) python = os.fsdecode(VENV_DIR.joinpath("bin/python")) + components = ["all", "dev", "jax"] + args = [] + if PYBAMM_ENV.get("PYBAMM_IDAKLU_EXPR_IREE") == "ON": + # Install IREE libraries for Jax-MLIR expression evaluation in the IDAKLU solver + # (optional). IREE is currently pre-release and relies on nightly jaxlib builds. + # When upgrading Jax/IREE ensure that the following are compatible with each other: + # - Jax and Jaxlib version [pyproject.toml] + # - IREE repository clone (use the matching nightly candidate) [noxfile.py] + # - IREE compiler matches Jaxlib (use the matching nightly build) [pyproject.toml] + components.append("iree") + args = ["--find-links", PYBAMM_ENV.get("IREE_INDEX_URL")] # Temporary fix for Python 3.12 CI. TODO: remove after # https://bitbucket.org/pybtex-devs/pybtex/issues/169/replace-pkg_resources-with # is fixed @@ -140,7 +237,8 @@ def set_dev(session): "pip", "install", "-e", - ".[all,dev,jax]", + ".[{}]".format(",".join(components)), + *args, external=True, ) diff --git a/pybamm/__init__.py b/pybamm/__init__.py index b3b7fafd3f..a371fdbc03 100644 --- a/pybamm/__init__.py +++ b/pybamm/__init__.py @@ -2,6 +2,9 @@ from pybamm.version import __version__ +# Demote expressions to 32-bit floats/ints - option used for IDAKLU-MLIR compilation +demote_expressions_to_32bit = False + # Utility classes and methods from .util import root_dir from .util import Timer, TimerTime, FuzzyDict @@ -168,7 +171,7 @@ from .solvers.jax_bdf_solver import jax_bdf_integrate from .solvers.idaklu_jax import IDAKLUJax -from .solvers.idaklu_solver import IDAKLUSolver, have_idaklu +from .solvers.idaklu_solver import IDAKLUSolver, have_idaklu, have_iree # Experiments from .experiment.experiment import Experiment diff --git a/pybamm/expression_tree/operations/evaluate_python.py b/pybamm/expression_tree/operations/evaluate_python.py index 6d13761756..20a6d4b4a2 100644 --- a/pybamm/expression_tree/operations/evaluate_python.py +++ b/pybamm/expression_tree/operations/evaluate_python.py @@ -281,7 +281,7 @@ def find_symbols( if isinstance(symbol, pybamm.Index): symbol_str = f"{children_vars[0]}[{symbol.slice.start}:{symbol.slice.stop}]" else: - symbol_str = symbol.name + children_vars[0] + symbol_str = symbol.name + "(" + children_vars[0] + ")" elif isinstance(symbol, pybamm.Function): children_str = "" @@ -596,6 +596,59 @@ def __init__(self, symbol: pybamm.Symbol): static_argnums=self._static_argnums, ) + def _demote_constants(self): + """Demote 64-bit constants (f64, i64) to 32-bit (f32, i32)""" + if not pybamm.demote_expressions_to_32bit: + return # pragma: no cover + self._constants = EvaluatorJax._demote_64_to_32(self._constants) + + @classmethod + def _demote_64_to_32(cls, c): + """Demote 64-bit operations (f64, i64) to 32-bit (f32, i32)""" + + if not pybamm.demote_expressions_to_32bit: + return c + if isinstance(c, float): + c = jax.numpy.float32(c) + if isinstance(c, int): + c = jax.numpy.int32(c) + if isinstance(c, np.int64): + c = c.astype(jax.numpy.int32) + if isinstance(c, np.ndarray): + if c.dtype == np.float64: + c = c.astype(jax.numpy.float32) + if c.dtype == np.int64: + c = c.astype(jax.numpy.int32) + if isinstance(c, jax.numpy.ndarray): + if c.dtype == jax.numpy.float64: + c = c.astype(jax.numpy.float32) + if c.dtype == jax.numpy.int64: + c = c.astype(jax.numpy.int32) + if isinstance( + c, pybamm.expression_tree.operations.evaluate_python.JaxCooMatrix + ): + if c.data.dtype == np.float64: + c.data = c.data.astype(jax.numpy.float32) + if c.row.dtype == np.int64: + c.row = c.row.astype(jax.numpy.int32) + if c.col.dtype == np.int64: + c.col = c.col.astype(jax.numpy.int32) + if isinstance(c, dict): + c = {key: EvaluatorJax._demote_64_to_32(value) for key, value in c.items()} + if isinstance(c, tuple): + c = tuple(EvaluatorJax._demote_64_to_32(value) for value in c) + if isinstance(c, list): + c = [EvaluatorJax._demote_64_to_32(value) for value in c] + return c + + @property + def _constants(self): + return tuple(map(EvaluatorJax._demote_64_to_32, self.__constants)) + + @_constants.setter + def _constants(self, value): + self.__constants = value + def get_jacobian(self): n = len(self._arg_list) diff --git a/pybamm/expression_tree/symbol.py b/pybamm/expression_tree/symbol.py index df549747c9..aa9ebe66db 100644 --- a/pybamm/expression_tree/symbol.py +++ b/pybamm/expression_tree/symbol.py @@ -282,7 +282,8 @@ def name(self): @name.setter def name(self, value: str): - assert isinstance(value, str) + if not isinstance(value, str): + raise TypeError(f"{value} must be of type str") self._name = value @property diff --git a/pybamm/models/full_battery_models/lithium_ion/base_lithium_ion_model.py b/pybamm/models/full_battery_models/lithium_ion/base_lithium_ion_model.py index 479e8203ed..6db56b74c4 100644 --- a/pybamm/models/full_battery_models/lithium_ion/base_lithium_ion_model.py +++ b/pybamm/models/full_battery_models/lithium_ion/base_lithium_ion_model.py @@ -265,9 +265,9 @@ def set_sei_submodel(self): reaction_loc = "x-average" else: reaction_loc = "full electrode" - sei_option = getattr(self.options, domain)["SEI"] phases = self.options.phases[domain] for phase in phases: + sei_option = getattr(getattr(self.options, domain), phase)["SEI"] if sei_option == "none": submodel = pybamm.sei.NoSEI(self.param, domain, self.options, phase) elif sei_option == "constant": @@ -333,9 +333,11 @@ def set_lithium_plating_submodel(self): for domain in self.options.whole_cell_domains: if domain != "separator": domain = domain.split()[0].lower() - lithium_plating_opt = getattr(self.options, domain)["lithium plating"] phases = self.options.phases[domain] for phase in phases: + lithium_plating_opt = getattr(getattr(self.options, domain), phase)[ + "lithium plating" + ] if lithium_plating_opt == "none": submodel = pybamm.lithium_plating.NoPlating( self.param, domain, self.options, phase diff --git a/pybamm/models/submodels/active_material/loss_active_material.py b/pybamm/models/submodels/active_material/loss_active_material.py index 7816122e07..6f027d89e6 100644 --- a/pybamm/models/submodels/active_material/loss_active_material.py +++ b/pybamm/models/submodels/active_material/loss_active_material.py @@ -60,7 +60,9 @@ def get_coupled_variables(self, variables): domain, Domain = self.domain_Domain deps_solid_dt = 0 - lam_option = getattr(self.options, self.domain)["loss of active material"] + lam_option = getattr(getattr(self.options, domain), self.phase)[ + "loss of active material" + ] if "stress" in lam_option: # obtain the rate of loss of active materials (LAM) by stress # This is loss of active material model by mechanical effects diff --git a/pybamm/settings.py b/pybamm/settings.py index 2ccd9bcd13..d190eaf47e 100644 --- a/pybamm/settings.py +++ b/pybamm/settings.py @@ -29,8 +29,9 @@ def debug_mode(self): return self._debug_mode @debug_mode.setter - def debug_mode(self, value): - assert isinstance(value, bool) + def debug_mode(self, value: bool): + if not isinstance(value, bool): + raise TypeError(f"{value} must be of type bool") self._debug_mode = value @property @@ -38,8 +39,9 @@ def simplify(self): return self._simplify @simplify.setter - def simplify(self, value): - assert isinstance(value, bool) + def simplify(self, value: bool): + if not isinstance(value, bool): + raise TypeError(f"{value} must be of type bool") self._simplify = value def set_smoothing_parameters(self, k): diff --git a/pybamm/solvers/base_solver.py b/pybamm/solvers/base_solver.py index 1425bf0845..0eb573e87a 100644 --- a/pybamm/solvers/base_solver.py +++ b/pybamm/solvers/base_solver.py @@ -256,32 +256,30 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False): model.casadi_sensitivities_rhs = jacp_rhs model.casadi_sensitivities_algebraic = jacp_algebraic - # if output_variables specified then convert functions to casadi - # expressions for evaluation within the respective solver - self.computed_var_fcns = {} - self.computed_dvar_dy_fcns = {} - self.computed_dvar_dp_fcns = {} - for key in self.output_variables: - # ExplicitTimeIntegral's are not computed as part of the solver and - # do not need to be converted - if isinstance( - model.variables_and_events[key], pybamm.ExplicitTimeIntegral - ): - continue - # Generate Casadi function to calculate variable and derivates - # to enable sensitivites to be computed within the solver - ( - self.computed_var_fcns[key], - self.computed_dvar_dy_fcns[key], - self.computed_dvar_dp_fcns[key], - _, - ) = process( - model.variables_and_events[key], - BaseSolver._wrangle_name(key), - vars_for_processing, - use_jacobian=True, - return_jacp_stacked=True, - ) + # if output_variables specified then convert functions to casadi + # expressions for evaluation within the respective solver + self.computed_var_fcns = {} + self.computed_dvar_dy_fcns = {} + self.computed_dvar_dp_fcns = {} + for key in self.output_variables: + # ExplicitTimeIntegral's are not computed as part of the solver and + # do not need to be converted + if isinstance(model.variables_and_events[key], pybamm.ExplicitTimeIntegral): + continue + # Generate Casadi function to calculate variable and derivates + # to enable sensitivites to be computed within the solver + ( + self.computed_var_fcns[key], + self.computed_dvar_dy_fcns[key], + self.computed_dvar_dp_fcns[key], + _, + ) = process( + model.variables_and_events[key], + BaseSolver._wrangle_name(key), + vars_for_processing, + use_jacobian=True, + return_jacp_stacked=True, + ) pybamm.logger.info("Finish solver set-up") diff --git a/pybamm/solvers/c_solvers/idaklu.cpp b/pybamm/solvers/c_solvers/idaklu.cpp index 9f99d4d3f4..3afed5faa8 100644 --- a/pybamm/solvers/c_solvers/idaklu.cpp +++ b/pybamm/solvers/c_solvers/idaklu.cpp @@ -8,14 +8,20 @@ #include #include -#include "idaklu/casadi_solver.hpp" -#include "idaklu/idaklu_jax.hpp" +#include "idaklu/idaklu_solver.hpp" +#include "idaklu/IdakluJax.hpp" #include "idaklu/common.hpp" #include "idaklu/python.hpp" +#include "idaklu/Expressions/Casadi/CasadiFunctions.hpp" -Function generate_function(const std::string &data) +#ifdef IREE_ENABLE +#include "idaklu/Expressions/IREE/IREEFunctions.hpp" +#endif + + +casadi::Function generate_casadi_function(const std::string &data) { - return Function::deserialize(data); + return casadi::Function::deserialize(data); } namespace py = pybind11; @@ -50,8 +56,8 @@ PYBIND11_MODULE(idaklu, m) py::arg("number_of_sensitivity_parameters"), py::return_value_policy::take_ownership); - py::class_(m, "CasadiSolver") - .def("solve", &CasadiSolver::solve, + py::class_(m, "IDAKLUSolver") + .def("solve", &IDAKLUSolver::solve, "perform a solve", py::arg("t"), py::arg("y0"), @@ -59,7 +65,7 @@ PYBIND11_MODULE(idaklu, m) py::arg("inputs"), py::return_value_policy::take_ownership); - m.def("create_casadi_solver", &create_casadi_solver, + m.def("create_casadi_solver", &create_idaklu_solver, "Create a casadi idaklu solver object", py::arg("number_of_states"), py::arg("number_of_parameters"), @@ -79,13 +85,41 @@ PYBIND11_MODULE(idaklu, m) py::arg("atol"), py::arg("rtol"), py::arg("inputs"), - py::arg("var_casadi_fcns"), + py::arg("var_fcns"), + py::arg("dvar_dy_fcns"), + py::arg("dvar_dp_fcns"), + py::arg("options"), + py::return_value_policy::take_ownership); + +#ifdef IREE_ENABLE + m.def("create_iree_solver", &create_idaklu_solver, + "Create a iree idaklu solver object", + py::arg("number_of_states"), + py::arg("number_of_parameters"), + py::arg("rhs_alg"), + py::arg("jac_times_cjmass"), + py::arg("jac_times_cjmass_colptrs"), + py::arg("jac_times_cjmass_rowvals"), + py::arg("jac_times_cjmass_nnz"), + py::arg("jac_bandwidth_lower"), + py::arg("jac_bandwidth_upper"), + py::arg("jac_action"), + py::arg("mass_action"), + py::arg("sens"), + py::arg("events"), + py::arg("number_of_events"), + py::arg("rhs_alg_id"), + py::arg("atol"), + py::arg("rtol"), + py::arg("inputs"), + py::arg("var_fcns"), py::arg("dvar_dy_fcns"), py::arg("dvar_dp_fcns"), py::arg("options"), py::return_value_policy::take_ownership); +#endif - m.def("generate_function", &generate_function, + m.def("generate_function", &generate_casadi_function, "Generate a casadi function", py::arg("string"), py::return_value_policy::take_ownership); @@ -133,11 +167,25 @@ PYBIND11_MODULE(idaklu, m) &Registrations ); - py::class_(m, "Function"); + py::class_(m, "Function"); + +#ifdef IREE_ENABLE + py::class_(m, "IREEBaseFunctionType") + .def(py::init<>()) + .def_readwrite("mlir", &IREEBaseFunctionType::mlir) + .def_readwrite("kept_var_idx", &IREEBaseFunctionType::kept_var_idx) + .def_readwrite("nnz", &IREEBaseFunctionType::nnz) + .def_readwrite("numel", &IREEBaseFunctionType::numel) + .def_readwrite("col", &IREEBaseFunctionType::col) + .def_readwrite("row", &IREEBaseFunctionType::row) + .def_readwrite("pytree_shape", &IREEBaseFunctionType::pytree_shape) + .def_readwrite("pytree_sizes", &IREEBaseFunctionType::pytree_sizes) + .def_readwrite("n_args", &IREEBaseFunctionType::n_args); +#endif py::class_(m, "solution") - .def_readwrite("t", &Solution::t) - .def_readwrite("y", &Solution::y) - .def_readwrite("yS", &Solution::yS) - .def_readwrite("flag", &Solution::flag); + .def_readwrite("t", &Solution::t) + .def_readwrite("y", &Solution::y) + .def_readwrite("yS", &Solution::yS) + .def_readwrite("flag", &Solution::flag); } diff --git a/pybamm/solvers/c_solvers/idaklu/CasadiSolver.cpp b/pybamm/solvers/c_solvers/idaklu/CasadiSolver.cpp deleted file mode 100644 index 16a04f8eb9..0000000000 --- a/pybamm/solvers/c_solvers/idaklu/CasadiSolver.cpp +++ /dev/null @@ -1 +0,0 @@ -#include "CasadiSolver.hpp" diff --git a/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.cpp b/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.cpp deleted file mode 100644 index ad51eda4e1..0000000000 --- a/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.cpp +++ /dev/null @@ -1,518 +0,0 @@ -#include "CasadiSolverOpenMP.hpp" -#include "casadi_sundials_functions.hpp" -#include -#include -#include - -CasadiSolverOpenMP::CasadiSolverOpenMP( - np_array atol_np, - double rel_tol, - np_array rhs_alg_id, - int number_of_parameters, - int number_of_events, - int jac_times_cjmass_nnz, - int jac_bandwidth_lower, - int jac_bandwidth_upper, - std::unique_ptr functions_arg, - const Options &options -) : - atol_np(atol_np), - rhs_alg_id(rhs_alg_id), - number_of_states(atol_np.request().size), - number_of_parameters(number_of_parameters), - number_of_events(number_of_events), - jac_times_cjmass_nnz(jac_times_cjmass_nnz), - jac_bandwidth_lower(jac_bandwidth_lower), - jac_bandwidth_upper(jac_bandwidth_upper), - functions(std::move(functions_arg)), - options(options) -{ - // Construction code moved to Initialize() which is called from the - // (child) CasadiSolver_XXX class constructors. - DEBUG("CasadiSolverOpenMP::CasadiSolverOpenMP"); - auto atol = atol_np.unchecked<1>(); - - // create SUNDIALS context object - SUNContext_Create(NULL, &sunctx); // calls null-wrapper if Sundials Ver<6 - - // allocate memory for solver - ida_mem = IDACreate(sunctx); - - // create the vector of initial values - AllocateVectors(); - if (number_of_parameters > 0) - { - yyS = N_VCloneVectorArray(number_of_parameters, yy); - ypS = N_VCloneVectorArray(number_of_parameters, yp); - } - // set initial values - realtype *atval = N_VGetArrayPointer(avtol); - for (int i = 0; i < number_of_states; i++) - atval[i] = atol[i]; - for (int is = 0; is < number_of_parameters; is++) - { - N_VConst(RCONST(0.0), yyS[is]); - N_VConst(RCONST(0.0), ypS[is]); - } - - // create Matrix objects - SetMatrix(); - - // initialise solver - IDAInit(ida_mem, residual_casadi, 0, yy, yp); - - // set tolerances - rtol = RCONST(rel_tol); - IDASVtolerances(ida_mem, rtol, avtol); - - // set events - IDARootInit(ida_mem, number_of_events, events_casadi); - void *user_data = functions.get(); - IDASetUserData(ida_mem, user_data); - - // specify preconditioner type - precon_type = SUN_PREC_NONE; - if (options.preconditioner != "none") { - precon_type = SUN_PREC_LEFT; - } -} - -void CasadiSolverOpenMP::AllocateVectors() { - // Create vectors - yy = N_VNew_OpenMP(number_of_states, options.num_threads, sunctx); - yp = N_VNew_OpenMP(number_of_states, options.num_threads, sunctx); - avtol = N_VNew_OpenMP(number_of_states, options.num_threads, sunctx); - id = N_VNew_OpenMP(number_of_states, options.num_threads, sunctx); -} - -void CasadiSolverOpenMP::SetMatrix() { - // Create Matrix object - if (options.jacobian == "sparse") - { - DEBUG("\tsetting sparse matrix"); - J = SUNSparseMatrix( - number_of_states, - number_of_states, - jac_times_cjmass_nnz, - CSC_MAT, // CSC is used by casadi; CSR requires a conversion step - sunctx - ); - } - else if (options.jacobian == "banded") { - DEBUG("\tsetting banded matrix"); - J = SUNBandMatrix( - number_of_states, - jac_bandwidth_upper, - jac_bandwidth_lower, - sunctx - ); - } else if (options.jacobian == "dense" || options.jacobian == "none") - { - DEBUG("\tsetting dense matrix"); - J = SUNDenseMatrix( - number_of_states, - number_of_states, - sunctx - ); - } - else if (options.jacobian == "matrix-free") - { - DEBUG("\tsetting matrix-free"); - J = NULL; - } - else - throw std::invalid_argument("Unsupported matrix requested"); -} - -void CasadiSolverOpenMP::Initialize() { - // Call after setting the solver - - // attach the linear solver - if (LS == nullptr) - throw std::invalid_argument("Linear solver not set"); - IDASetLinearSolver(ida_mem, LS, J); - - if (options.preconditioner != "none") - { - DEBUG("\tsetting IDADDB preconditioner"); - // setup preconditioner - IDABBDPrecInit( - ida_mem, number_of_states, options.precon_half_bandwidth, - options.precon_half_bandwidth, options.precon_half_bandwidth_keep, - options.precon_half_bandwidth_keep, 0.0, residual_casadi_approx, NULL); - } - - if (options.jacobian == "matrix-free") - IDASetJacTimes(ida_mem, NULL, jtimes_casadi); - else if (options.jacobian != "none") - IDASetJacFn(ida_mem, jacobian_casadi); - - if (number_of_parameters > 0) - { - IDASensInit(ida_mem, number_of_parameters, IDA_SIMULTANEOUS, - sensitivities_casadi, yyS, ypS); - IDASensEEtolerances(ida_mem); - } - - SUNLinSolInitialize(LS); - - auto id_np_val = rhs_alg_id.unchecked<1>(); - realtype *id_val; - id_val = N_VGetArrayPointer(id); - - int ii; - for (ii = 0; ii < number_of_states; ii++) - id_val[ii] = id_np_val[ii]; - - IDASetId(ida_mem, id); -} - -CasadiSolverOpenMP::~CasadiSolverOpenMP() -{ - // Free memory - if (number_of_parameters > 0) - IDASensFree(ida_mem); - - SUNLinSolFree(LS); - SUNMatDestroy(J); - N_VDestroy(avtol); - N_VDestroy(yy); - N_VDestroy(yp); - N_VDestroy(id); - - if (number_of_parameters > 0) - { - N_VDestroyVectorArray(yyS, number_of_parameters); - N_VDestroyVectorArray(ypS, number_of_parameters); - } - - IDAFree(&ida_mem); - SUNContext_Free(&sunctx); -} - -void CasadiSolverOpenMP::CalcVars( - realtype *y_return, - size_t length_of_return_vector, - size_t t_i, - realtype *tret, - realtype *yval, - const std::vector& ySval, - realtype *yS_return, - size_t *ySk -) { - // Evaluate casadi functions for each requested variable and store - size_t j = 0; - for (auto& var_fcn : functions->var_casadi_fcns) { - var_fcn({tret, yval, functions->inputs.data()}, {res}); - // store in return vector - for (size_t jj=0; jj& ySval, - realtype *yS_return, - size_t *ySk -) { - // Calculate sensitivities - - // Loop over variables - realtype* dens_dvar_dp = new realtype[number_of_parameters]; - for (size_t dvar_k=0; dvar_kdvar_dy_fcns.size(); dvar_k++) { - // Isolate functions - CasadiFunction dvar_dy = functions->dvar_dy_fcns[dvar_k]; - CasadiFunction dvar_dp = functions->dvar_dp_fcns[dvar_k]; - // Calculate dvar/dy - dvar_dy({tret, yval, functions->inputs.data()}, {res_dvar_dy}); - casadi::Sparsity spdy = dvar_dy.sparsity_out(0); - // Calculate dvar/dp and convert to dense array for indexing - dvar_dp({tret, yval, functions->inputs.data()}, {res_dvar_dp}); - casadi::Sparsity spdp = dvar_dp.sparsity_out(0); - for(int k=0; k(); - realtype t0 = RCONST(t(0)); - auto y0 = y0_np.unchecked<1>(); - auto yp0 = yp0_np.unchecked<1>(); - auto n_coeffs = number_of_states + number_of_parameters * number_of_states; - - if (y0.size() != n_coeffs) - throw std::domain_error( - "y0 has wrong size. Expected " + std::to_string(n_coeffs) + - " but got " + std::to_string(y0.size())); - - if (yp0.size() != n_coeffs) - throw std::domain_error( - "yp0 has wrong size. Expected " + std::to_string(n_coeffs) + - " but got " + std::to_string(yp0.size())); - - // set inputs - auto p_inputs = inputs.unchecked<2>(); - for (int i = 0; i < functions->inputs.size(); i++) - functions->inputs[i] = p_inputs(i, 0); - - // set initial conditions - realtype *yval = N_VGetArrayPointer(yy); - realtype *ypval = N_VGetArrayPointer(yp); - std::vector ySval(number_of_parameters); - std::vector ypSval(number_of_parameters); - for (int p = 0 ; p < number_of_parameters; p++) { - ySval[p] = N_VGetArrayPointer(yyS[p]); - ypSval[p] = N_VGetArrayPointer(ypS[p]); - for (int i = 0; i < number_of_states; i++) { - ySval[p][i] = y0[i + (p + 1) * number_of_states]; - ypSval[p][i] = yp0[i + (p + 1) * number_of_states]; - } - } - - for (int i = 0; i < number_of_states; i++) - { - yval[i] = y0[i]; - ypval[i] = yp0[i]; - } - - IDAReInit(ida_mem, t0, yy, yp); - if (number_of_parameters > 0) - IDASensReInit(ida_mem, IDA_SIMULTANEOUS, yyS, ypS); - - // correct initial values - DEBUG("IDACalcIC"); - IDACalcIC(ida_mem, IDA_YA_YDP_INIT, t(1)); - if (number_of_parameters > 0) - IDAGetSens(ida_mem, &t0, yyS); - - realtype tret; - realtype t_final = t(number_of_timesteps - 1); - - // set return vectors - int length_of_return_vector = 0; - size_t max_res_size = 0; // maximum result size (for common result buffer) - size_t max_res_dvar_dy = 0, max_res_dvar_dp = 0; - if (functions->var_casadi_fcns.size() > 0) { - // return only the requested variables list after computation - for (auto& var_fcn : functions->var_casadi_fcns) { - max_res_size = std::max(max_res_size, size_t(var_fcn.nnz_out())); - length_of_return_vector += var_fcn.nnz_out(); - for (auto& dvar_fcn : functions->dvar_dy_fcns) - max_res_dvar_dy = std::max(max_res_dvar_dy, size_t(dvar_fcn.nnz_out())); - for (auto& dvar_fcn : functions->dvar_dp_fcns) - max_res_dvar_dp = std::max(max_res_dvar_dp, size_t(dvar_fcn.nnz_out())); - } - } else { - // Return full y state-vector - length_of_return_vector = number_of_states; - } - realtype *t_return = new realtype[number_of_timesteps]; - realtype *y_return = new realtype[number_of_timesteps * - length_of_return_vector]; - realtype *yS_return = new realtype[number_of_parameters * - number_of_timesteps * - length_of_return_vector]; - - res = new realtype[max_res_size]; - res_dvar_dy = new realtype[max_res_dvar_dy]; - res_dvar_dp = new realtype[max_res_dvar_dp]; - - py::capsule free_t_when_done( - t_return, - [](void *f) { - realtype *vect = reinterpret_cast(f); - delete[] vect; - } - ); - py::capsule free_y_when_done( - y_return, - [](void *f) { - realtype *vect = reinterpret_cast(f); - delete[] vect; - } - ); - py::capsule free_yS_when_done( - yS_return, - [](void *f) { - realtype *vect = reinterpret_cast(f); - delete[] vect; - } - ); - - // Initial state (t_i=0) - int t_i = 0; - size_t ySk = 0; - t_return[t_i] = t(t_i); - if (functions->var_casadi_fcns.size() > 0) { - // Evaluate casadi functions for each requested variable and store - CalcVars(y_return, length_of_return_vector, t_i, - &tret, yval, ySval, yS_return, &ySk); - } else { - // Retain complete copy of the state vector - for (int j = 0; j < number_of_states; j++) - y_return[j] = yval[j]; - for (int j = 0; j < number_of_parameters; j++) - { - const int base_index = j * number_of_timesteps * number_of_states; - for (int k = 0; k < number_of_states; k++) - yS_return[base_index + k] = ySval[j][k]; - } - } - - // Subsequent states (t_i>0) - int retval; - t_i = 1; - while (true) - { - realtype t_next = t(t_i); - IDASetStopTime(ida_mem, t_next); - DEBUG("IDASolve"); - retval = IDASolve(ida_mem, t_final, &tret, yy, yp, IDA_NORMAL); - - if (retval == IDA_TSTOP_RETURN || - retval == IDA_SUCCESS || - retval == IDA_ROOT_RETURN) - { - if (number_of_parameters > 0) - IDAGetSens(ida_mem, &tret, yyS); - - // Evaluate and store results for the time step - t_return[t_i] = tret; - if (functions->var_casadi_fcns.size() > 0) { - // Evaluate casadi functions for each requested variable and store - // NOTE: Indexing of yS_return is (time:var:param) - CalcVars(y_return, length_of_return_vector, t_i, - &tret, yval, ySval, yS_return, &ySk); - } else { - // Retain complete copy of the state vector - for (int j = 0; j < number_of_states; j++) - y_return[t_i * number_of_states + j] = yval[j]; - for (int j = 0; j < number_of_parameters; j++) - { - const int base_index = - j * number_of_timesteps * number_of_states + - t_i * number_of_states; - for (int k = 0; k < number_of_states; k++) - // NOTE: Indexing of yS_return is (time:param:yvec) - yS_return[base_index + k] = ySval[j][k]; - } - } - t_i += 1; - - if (retval == IDA_SUCCESS || - retval == IDA_ROOT_RETURN) - break; - } - else - { - // failed - break; - } - } - - np_array t_ret = np_array( - t_i, - &t_return[0], - free_t_when_done - ); - np_array y_ret = np_array( - t_i * length_of_return_vector, - &y_return[0], - free_y_when_done - ); - // Note: Ordering of vector is differnet if computing variables vs returning - // the complete state vector - np_array yS_ret; - if (functions->var_casadi_fcns.size() > 0) { - yS_ret = np_array( - std::vector { - number_of_timesteps, - length_of_return_vector, - number_of_parameters - }, - &yS_return[0], - free_yS_when_done - ); - } else { - yS_ret = np_array( - std::vector { - number_of_parameters, - number_of_timesteps, - length_of_return_vector - }, - &yS_return[0], - free_yS_when_done - ); - } - - Solution sol(retval, t_ret, y_ret, yS_ret); - - if (options.print_stats) - { - long nsteps, nrevals, nlinsetups, netfails; - int klast, kcur; - realtype hinused, hlast, hcur, tcur; - - IDAGetIntegratorStats( - ida_mem, - &nsteps, - &nrevals, - &nlinsetups, - &netfails, - &klast, - &kcur, - &hinused, - &hlast, - &hcur, - &tcur - ); - - long nniters, nncfails; - IDAGetNonlinSolvStats(ida_mem, &nniters, &nncfails); - - long int ngevalsBBDP = 0; - if (options.using_iterative_solver) - IDABBDPrecGetNumGfnEvals(ida_mem, &ngevalsBBDP); - - py::print("Solver Stats:"); - py::print("\tNumber of steps =", nsteps); - py::print("\tNumber of calls to residual function =", nrevals); - py::print("\tNumber of calls to residual function in preconditioner =", - ngevalsBBDP); - py::print("\tNumber of linear solver setup calls =", nlinsetups); - py::print("\tNumber of error test failures =", netfails); - py::print("\tMethod order used on last step =", klast); - py::print("\tMethod order used on next step =", kcur); - py::print("\tInitial step size =", hinused); - py::print("\tStep size on last step =", hlast); - py::print("\tStep size on next step =", hcur); - py::print("\tCurrent internal time reached =", tcur); - py::print("\tNumber of nonlinear iterations performed =", nniters); - py::print("\tNumber of nonlinear convergence failures =", nncfails); - } - - return sol; -} diff --git a/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP_solvers.cpp b/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP_solvers.cpp deleted file mode 100644 index 868d2b2138..0000000000 --- a/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP_solvers.cpp +++ /dev/null @@ -1 +0,0 @@ -#include "CasadiSolverOpenMP_solvers.hpp" diff --git a/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP_solvers.hpp b/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP_solvers.hpp deleted file mode 100644 index 3e39e5a303..0000000000 --- a/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP_solvers.hpp +++ /dev/null @@ -1,125 +0,0 @@ -#ifndef PYBAMM_IDAKLU_CASADI_SOLVER_OPENMP_HPP -#define PYBAMM_IDAKLU_CASADI_SOLVER_OPENMP_HPP - -#include "CasadiSolverOpenMP.hpp" -#include "casadi_solver.hpp" - -/** - * @brief CasadiSolver Dense implementation with OpenMP class - */ -class CasadiSolverOpenMP_Dense : public CasadiSolverOpenMP { -public: - template - CasadiSolverOpenMP_Dense(Args&& ... args) - : CasadiSolverOpenMP(std::forward(args) ...) - { - LS = SUNLinSol_Dense(yy, J, sunctx); - Initialize(); - } -}; - -/** - * @brief CasadiSolver KLU implementation with OpenMP class - */ -class CasadiSolverOpenMP_KLU : public CasadiSolverOpenMP { -public: - template - CasadiSolverOpenMP_KLU(Args&& ... args) - : CasadiSolverOpenMP(std::forward(args) ...) - { - LS = SUNLinSol_KLU(yy, J, sunctx); - Initialize(); - } -}; - -/** - * @brief CasadiSolver Banded implementation with OpenMP class - */ -class CasadiSolverOpenMP_Band : public CasadiSolverOpenMP { -public: - template - CasadiSolverOpenMP_Band(Args&& ... args) - : CasadiSolverOpenMP(std::forward(args) ...) - { - LS = SUNLinSol_Band(yy, J, sunctx); - Initialize(); - } -}; - -/** - * @brief CasadiSolver SPBCGS implementation with OpenMP class - */ -class CasadiSolverOpenMP_SPBCGS : public CasadiSolverOpenMP { -public: - template - CasadiSolverOpenMP_SPBCGS(Args&& ... args) - : CasadiSolverOpenMP(std::forward(args) ...) - { - LS = SUNLinSol_SPBCGS( - yy, - precon_type, - options.linsol_max_iterations, - sunctx - ); - Initialize(); - } -}; - -/** - * @brief CasadiSolver SPFGMR implementation with OpenMP class - */ -class CasadiSolverOpenMP_SPFGMR : public CasadiSolverOpenMP { -public: - template - CasadiSolverOpenMP_SPFGMR(Args&& ... args) - : CasadiSolverOpenMP(std::forward(args) ...) - { - LS = SUNLinSol_SPFGMR( - yy, - precon_type, - options.linsol_max_iterations, - sunctx - ); - Initialize(); - } -}; - -/** - * @brief CasadiSolver SPGMR implementation with OpenMP class - */ -class CasadiSolverOpenMP_SPGMR : public CasadiSolverOpenMP { -public: - template - CasadiSolverOpenMP_SPGMR(Args&& ... args) - : CasadiSolverOpenMP(std::forward(args) ...) - { - LS = SUNLinSol_SPGMR( - yy, - precon_type, - options.linsol_max_iterations, - sunctx - ); - Initialize(); - } -}; - -/** - * @brief CasadiSolver SPTFQMR implementation with OpenMP class - */ -class CasadiSolverOpenMP_SPTFQMR : public CasadiSolverOpenMP { -public: - template - CasadiSolverOpenMP_SPTFQMR(Args&& ... args) - : CasadiSolverOpenMP(std::forward(args) ...) - { - LS = SUNLinSol_SPTFQMR( - yy, - precon_type, - options.linsol_max_iterations, - sunctx - ); - Initialize(); - } -}; - -#endif // PYBAMM_IDAKLU_CASADI_SOLVER_OPENMP_HPP diff --git a/pybamm/solvers/c_solvers/idaklu/Expressions/Base/Expression.hpp b/pybamm/solvers/c_solvers/idaklu/Expressions/Base/Expression.hpp new file mode 100644 index 0000000000..bbf60b4568 --- /dev/null +++ b/pybamm/solvers/c_solvers/idaklu/Expressions/Base/Expression.hpp @@ -0,0 +1,69 @@ +#ifndef PYBAMM_EXPRESSION_HPP +#define PYBAMM_EXPRESSION_HPP + +#include "ExpressionTypes.hpp" +#include "../../common.hpp" +#include "../../Options.hpp" +#include +#include + +class Expression { +public: // method declarations + /** + * @brief Constructor + */ + Expression() = default; + + /** + * @brief Evaluation operator (for use after setting input and output data references) + */ + virtual void operator()() = 0; + + /** + * @brief Evaluation operator (supplying data references) + */ + virtual void operator()( + const std::vector& inputs, + const std::vector& results) = 0; + + /** + * @brief The maximum number of elements returned by the k'th output + * + * This is used to allocate memory for the output of the function and usually (but + * not always) corresponds to the number of non-zero elements (NNZ). + */ + virtual expr_int out_shape(int k) = 0; + + /** + * @brief Return the number of non-zero elements for the function output + */ + virtual expr_int nnz() = 0; + + /** + * @brief Return the number of non-zero elements for the function output + */ + virtual expr_int nnz_out() = 0; + + /** + * @brief Returns row indices in COO format (where the output data represents sparse matrix elements) + */ + virtual std::vector get_row() = 0; + + /** + * @brief Returns column indices in COO format (where the output data represents sparse matrix elements) + */ + virtual std::vector get_col() = 0; + +public: // data members + /** + * @brief Vector of pointers to the input data + */ + std::vector m_arg; // cppcheck-suppress unusedStructMember + + /** + * @brief Vector of pointers to the output data + */ + std::vector m_res; // cppcheck-suppress unusedStructMember +}; + +#endif // PYBAMM_EXPRESSION_HPP diff --git a/pybamm/solvers/c_solvers/idaklu/Expressions/Base/ExpressionSet.hpp b/pybamm/solvers/c_solvers/idaklu/Expressions/Base/ExpressionSet.hpp new file mode 100644 index 0000000000..13c746a37d --- /dev/null +++ b/pybamm/solvers/c_solvers/idaklu/Expressions/Base/ExpressionSet.hpp @@ -0,0 +1,86 @@ +#ifndef PYBAMM_IDAKLU_EXPRESSION_SET_HPP +#define PYBAMM_IDAKLU_EXPRESSION_SET_HPP + +#include "ExpressionTypes.hpp" +#include "Expression.hpp" +#include "../../common.hpp" +#include "../../Options.hpp" +#include + +template +class ExpressionSet +{ +public: + + /** + * @brief Constructor + */ + ExpressionSet( + Expression* rhs_alg, + Expression* jac_times_cjmass, + const int jac_times_cjmass_nnz, + const int jac_bandwidth_lower, + const int jac_bandwidth_upper, + const np_array_int &jac_times_cjmass_rowvals_arg, // cppcheck-suppress unusedStructMember + const np_array_int &jac_times_cjmass_colptrs_arg, // cppcheck-suppress unusedStructMember + const int inputs_length, + Expression* jac_action, + Expression* mass_action, + Expression* sens, + Expression* events, + const int n_s, + const int n_e, + const int n_p, + const SetupOptions& options) + : number_of_states(n_s), + number_of_events(n_e), + number_of_parameters(n_p), + number_of_nnz(jac_times_cjmass_nnz), + jac_bandwidth_lower(jac_bandwidth_lower), + jac_bandwidth_upper(jac_bandwidth_upper), + rhs_alg(rhs_alg), + jac_times_cjmass(jac_times_cjmass), + jac_action(jac_action), + mass_action(mass_action), + sens(sens), + events(events), + tmp_state_vector(number_of_states), + tmp_sparse_jacobian_data(jac_times_cjmass_nnz), + setup_opts(options) + {}; + + int number_of_states; + int number_of_parameters; + int number_of_events; + int number_of_nnz; + int jac_bandwidth_lower; + int jac_bandwidth_upper; + + Expression *rhs_alg = nullptr; + Expression *jac_times_cjmass = nullptr; + Expression *jac_action = nullptr; + Expression *mass_action = nullptr; + Expression *sens = nullptr; + Expression *events = nullptr; + + // `cppcheck-suppress unusedStructMember` is used because codacy reports + // these members as unused, but they are inherited through variadics + std::vector var_fcns; // cppcheck-suppress unusedStructMember + std::vector dvar_dy_fcns; // cppcheck-suppress unusedStructMember + std::vector dvar_dp_fcns; // cppcheck-suppress unusedStructMember + + std::vector jac_times_cjmass_rowvals; // cppcheck-suppress unusedStructMember + std::vector jac_times_cjmass_colptrs; // cppcheck-suppress unusedStructMember + std::vector inputs; // cppcheck-suppress unusedStructMember + + SetupOptions setup_opts; + + virtual realtype *get_tmp_state_vector() = 0; + virtual realtype *get_tmp_sparse_jacobian_data() = 0; + +protected: + std::vector tmp_state_vector; + std::vector tmp_sparse_jacobian_data; +}; + +#endif // PYBAMM_IDAKLU_EXPRESSION_SET_HPP diff --git a/pybamm/solvers/c_solvers/idaklu/Expressions/Base/ExpressionTypes.hpp b/pybamm/solvers/c_solvers/idaklu/Expressions/Base/ExpressionTypes.hpp new file mode 100644 index 0000000000..c8d690c125 --- /dev/null +++ b/pybamm/solvers/c_solvers/idaklu/Expressions/Base/ExpressionTypes.hpp @@ -0,0 +1,6 @@ +#ifndef PYBAMM_EXPRESSION_TYPES_HPP +#define PYBAMM_EXPRESSION_TYPES_HPP + +using expr_int = long long int; + +#endif // PYBAMM_EXPRESSION_TYPES_HPP diff --git a/pybamm/solvers/c_solvers/idaklu/Expressions/Casadi/CasadiFunctions.cpp b/pybamm/solvers/c_solvers/idaklu/Expressions/Casadi/CasadiFunctions.cpp new file mode 100644 index 0000000000..b0c8ab1142 --- /dev/null +++ b/pybamm/solvers/c_solvers/idaklu/Expressions/Casadi/CasadiFunctions.cpp @@ -0,0 +1,80 @@ +#include "CasadiFunctions.hpp" +#include + +CasadiFunction::CasadiFunction(const BaseFunctionType &f) : Expression(), m_func(f) +{ + DEBUG("CasadiFunction constructor: " << m_func.name()); + + size_t sz_arg; + size_t sz_res; + size_t sz_iw; + size_t sz_w; + m_func.sz_work(sz_arg, sz_res, sz_iw, sz_w); + + int nnz = (sz_res>0) ? m_func.nnz_out() : 0; // cppcheck-suppress unreadVariable + DEBUG("name = "<< m_func.name() << " arg = " << sz_arg << " res = " + << sz_res << " iw = " << sz_iw << " w = " << sz_w << " nnz = " << nnz); + + m_arg.resize(sz_arg, nullptr); + m_res.resize(sz_res, nullptr); + m_iw.resize(sz_iw, 0); + m_w.resize(sz_w, 0); +} + +// only call this once m_arg and m_res have been set appropriately +void CasadiFunction::operator()() +{ + DEBUG("CasadiFunction operator(): " << m_func.name()); + int mem = m_func.checkout(); + m_func(m_arg.data(), m_res.data(), m_iw.data(), m_w.data(), mem); + m_func.release(mem); +} + +expr_int CasadiFunction::out_shape(int k) { + DEBUG("CasadiFunctions out_shape(): " << m_func.name() << " " << m_func.nnz_out()); + return static_cast(m_func.nnz_out()); +} + +expr_int CasadiFunction::nnz() { + DEBUG("CasadiFunction nnz(): " << m_func.name() << " " << static_cast(m_func.nnz_out())); + return static_cast(m_func.nnz_out()); +} + +expr_int CasadiFunction::nnz_out() { + DEBUG("CasadiFunction nnz_out(): " << m_func.name() << " " << static_cast(m_func.nnz_out())); + return static_cast(m_func.nnz_out()); +} + +std::vector CasadiFunction::get_row() { + return get_row(0); +} + +std::vector CasadiFunction::get_row(expr_int ind) { + DEBUG("CasadiFunction get_row(): " << m_func.name()); + casadi::Sparsity casadi_sparsity = m_func.sparsity_out(ind); + return casadi_sparsity.get_row(); +} + +std::vector CasadiFunction::get_col() { + return get_col(0); +} + +std::vector CasadiFunction::get_col(expr_int ind) { + DEBUG("CasadiFunction get_col(): " << m_func.name()); + casadi::Sparsity casadi_sparsity = m_func.sparsity_out(ind); + return casadi_sparsity.get_col(); +} + +void CasadiFunction::operator()(const std::vector& inputs, + const std::vector& results) +{ + DEBUG("CasadiFunction operator() with inputs and results: " << m_func.name()); + + // Set-up input arguments, provide result vector, then execute function + // Example call: fcn({in1, in2, in3}, {out1}) + for(size_t k=0; k +#include +#include +#include + +/** + * @brief Class for handling individual casadi functions + */ +class CasadiFunction : public Expression +{ +public: + + typedef casadi::Function BaseFunctionType; + + /** + * @brief Constructor + */ + explicit CasadiFunction(const BaseFunctionType &f); + + // Method overrides + void operator()() override; + void operator()(const std::vector& inputs, + const std::vector& results) override; + expr_int out_shape(int k) override; + expr_int nnz() override; + expr_int nnz_out() override; + std::vector get_row() override; + std::vector get_row(expr_int ind); + std::vector get_col() override; + std::vector get_col(expr_int ind); + +public: + /* + * @brief Casadi function + */ + BaseFunctionType m_func; + +private: + std::vector m_iw; // cppcheck-suppress unusedStructMember + std::vector m_w; // cppcheck-suppress unusedStructMember +}; + +/** + * @brief Class for handling casadi functions + */ +class CasadiFunctions : public ExpressionSet +{ +public: + + typedef CasadiFunction::BaseFunctionType BaseFunctionType; // expose typedef in class + + /** + * @brief Create a new CasadiFunctions object + */ + CasadiFunctions( + const BaseFunctionType &rhs_alg, + const BaseFunctionType &jac_times_cjmass, + const int jac_times_cjmass_nnz, + const int jac_bandwidth_lower, + const int jac_bandwidth_upper, + const np_array_int &jac_times_cjmass_rowvals_arg, + const np_array_int &jac_times_cjmass_colptrs_arg, + const int inputs_length, + const BaseFunctionType &jac_action, + const BaseFunctionType &mass_action, + const BaseFunctionType &sens, + const BaseFunctionType &events, + const int n_s, + const int n_e, + const int n_p, + const std::vector& var_fcns, + const std::vector& dvar_dy_fcns, + const std::vector& dvar_dp_fcns, + const SetupOptions& setup_opts + ) : + rhs_alg_casadi(rhs_alg), + jac_times_cjmass_casadi(jac_times_cjmass), + jac_action_casadi(jac_action), + mass_action_casadi(mass_action), + sens_casadi(sens), + events_casadi(events), + ExpressionSet( + static_cast(&rhs_alg_casadi), + static_cast(&jac_times_cjmass_casadi), + jac_times_cjmass_nnz, + jac_bandwidth_lower, + jac_bandwidth_upper, + jac_times_cjmass_rowvals_arg, + jac_times_cjmass_colptrs_arg, + inputs_length, + static_cast(&jac_action_casadi), + static_cast(&mass_action_casadi), + static_cast(&sens_casadi), + static_cast(&events_casadi), + n_s, n_e, n_p, + setup_opts) + { + // convert BaseFunctionType list to CasadiFunction list + // NOTE: You must allocate ALL std::vector elements before taking references + for (auto& var : var_fcns) + var_fcns_casadi.push_back(CasadiFunction(*var)); + for (int k = 0; k < var_fcns_casadi.size(); k++) + ExpressionSet::var_fcns.push_back(&this->var_fcns_casadi[k]); + + for (auto& var : dvar_dy_fcns) + dvar_dy_fcns_casadi.push_back(CasadiFunction(*var)); + for (int k = 0; k < dvar_dy_fcns_casadi.size(); k++) + this->dvar_dy_fcns.push_back(&this->dvar_dy_fcns_casadi[k]); + + for (auto& var : dvar_dp_fcns) + dvar_dp_fcns_casadi.push_back(CasadiFunction(*var)); + for (int k = 0; k < dvar_dp_fcns_casadi.size(); k++) + this->dvar_dp_fcns.push_back(&this->dvar_dp_fcns_casadi[k]); + + // copy across numpy array values + const int n_row_vals = jac_times_cjmass_rowvals_arg.request().size; + auto p_jac_times_cjmass_rowvals = jac_times_cjmass_rowvals_arg.unchecked<1>(); + jac_times_cjmass_rowvals.resize(n_row_vals); + for (int i = 0; i < n_row_vals; i++) { + jac_times_cjmass_rowvals[i] = p_jac_times_cjmass_rowvals[i]; + } + + const int n_col_ptrs = jac_times_cjmass_colptrs_arg.request().size; + auto p_jac_times_cjmass_colptrs = jac_times_cjmass_colptrs_arg.unchecked<1>(); + jac_times_cjmass_colptrs.resize(n_col_ptrs); + for (int i = 0; i < n_col_ptrs; i++) { + jac_times_cjmass_colptrs[i] = p_jac_times_cjmass_colptrs[i]; + } + + inputs.resize(inputs_length); + } + + CasadiFunction rhs_alg_casadi; + CasadiFunction jac_times_cjmass_casadi; + CasadiFunction jac_action_casadi; + CasadiFunction mass_action_casadi; + CasadiFunction sens_casadi; + CasadiFunction events_casadi; + + std::vector var_fcns_casadi; + std::vector dvar_dy_fcns_casadi; + std::vector dvar_dp_fcns_casadi; + + realtype* get_tmp_state_vector() override { + return tmp_state_vector.data(); + } + realtype* get_tmp_sparse_jacobian_data() override { + return tmp_sparse_jacobian_data.data(); + } +}; + +#endif // PYBAMM_IDAKLU_CASADI_FUNCTIONS_HPP diff --git a/pybamm/solvers/c_solvers/idaklu/Expressions/Expressions.hpp b/pybamm/solvers/c_solvers/idaklu/Expressions/Expressions.hpp new file mode 100644 index 0000000000..70380eaba7 --- /dev/null +++ b/pybamm/solvers/c_solvers/idaklu/Expressions/Expressions.hpp @@ -0,0 +1,6 @@ +#ifndef PYBAMM_IDAKLU_EXPRESSIONS_HPP +#define PYBAMM_IDAKLU_EXPRESSIONS_HPP + +#include "Base/ExpressionSet.hpp" + +#endif // PYBAMM_IDAKLU_EXPRESSIONS_HPP diff --git a/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEBaseFunction.hpp b/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEBaseFunction.hpp new file mode 100644 index 0000000000..d2ba7e4de0 --- /dev/null +++ b/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEBaseFunction.hpp @@ -0,0 +1,27 @@ +#ifndef PYBAMM_IDAKLU_IREE_BASE_FUNCTION_HPP +#define PYBAMM_IDAKLU_IREE_BASE_FUNCTION_HPP + +#include +#include + +/* + * @brief Function definition passed from PyBaMM + */ +class IREEBaseFunctionType +{ +public: // methods + const std::string& get_mlir() const { return mlir; } + +public: // data members + std::string mlir; // cppcheck-suppress unusedStructMember + std::vector kept_var_idx; // cppcheck-suppress unusedStructMember + expr_int nnz; // cppcheck-suppress unusedStructMember + expr_int numel; // cppcheck-suppress unusedStructMember + std::vector col; // cppcheck-suppress unusedStructMember + std::vector row; // cppcheck-suppress unusedStructMember + std::vector pytree_shape; // cppcheck-suppress unusedStructMember + std::vector pytree_sizes; // cppcheck-suppress unusedStructMember + expr_int n_args; // cppcheck-suppress unusedStructMember +}; + +#endif // PYBAMM_IDAKLU_IREE_BASE_FUNCTION_HPP diff --git a/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEFunction.hpp b/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEFunction.hpp new file mode 100644 index 0000000000..26f81c8f98 --- /dev/null +++ b/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEFunction.hpp @@ -0,0 +1,59 @@ +#ifndef PYBAMM_IDAKLU_IREE_FUNCTION_HPP +#define PYBAMM_IDAKLU_IREE_FUNCTION_HPP + +#include "../../Options.hpp" +#include "../Expressions.hpp" +#include +#include "iree_jit.hpp" +#include "IREEBaseFunction.hpp" + +/** + * @brief Class for handling individual iree functions + */ +class IREEFunction : public Expression +{ +public: + typedef IREEBaseFunctionType BaseFunctionType; + + /* + * @brief Constructor + */ + explicit IREEFunction(const BaseFunctionType &f); + + // Method overrides + void operator()() override; + void operator()(const std::vector& inputs, + const std::vector& results) override; + expr_int out_shape(int k) override; + expr_int nnz() override; + expr_int nnz_out() override; + std::vector get_col() override; + std::vector get_row() override; + + /* + * @brief Evaluate the MLIR function + */ + void evaluate(); + + /* + * @brief Evaluate the MLIR function + * @param n_outputs The number of outputs to return + */ + void evaluate(int n_outputs); + +public: + std::unique_ptr session; + std::vector> result; // cppcheck-suppress unusedStructMember + std::vector> input_shape; // cppcheck-suppress unusedStructMember + std::vector> output_shape; // cppcheck-suppress unusedStructMember + std::vector> input_data; // cppcheck-suppress unusedStructMember + + BaseFunctionType m_func; // cppcheck-suppress unusedStructMember + std::string module_name; // cppcheck-suppress unusedStructMember + std::string function_name; // cppcheck-suppress unusedStructMember + std::vector m_arg_argno; // cppcheck-suppress unusedStructMember + std::vector m_arg_argix; // cppcheck-suppress unusedStructMember + std::vector numel; // cppcheck-suppress unusedStructMember +}; + +#endif // PYBAMM_IDAKLU_IREE_FUNCTION_HPP diff --git a/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEFunctions.cpp b/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEFunctions.cpp new file mode 100644 index 0000000000..6837d21198 --- /dev/null +++ b/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEFunctions.cpp @@ -0,0 +1,230 @@ +#include +#include +#include +#include +#include + +#include "IREEFunctions.hpp" +#include "iree_jit.hpp" +#include "ModuleParser.hpp" + +IREEFunction::IREEFunction(const BaseFunctionType &f) : Expression(), m_func(f) +{ + DEBUG("IreeFunction constructor"); + const std::string& mlir = f.get_mlir(); + + // Parse IREE (MLIR) function string + if (mlir.size() == 0) { + DEBUG("Empty function --- skipping..."); + return; + } + + // Parse MLIR for module name, input and output shapes + ModuleParser parser(mlir); + module_name = parser.getModuleName(); + function_name = parser.getFunctionName(); + input_shape = parser.getInputShape(); + output_shape = parser.getOutputShape(); + + DEBUG("Compiling module: '" << module_name << "'"); + const char* device_uri = "local-sync"; + session = std::make_unique(device_uri, mlir); + DEBUG("compile complete."); + // Create index vectors into m_arg + // This is required since Jax expands input arguments through PyTrees, which need to + // be remapped to the corresponding expression call. For example: + // fcn(t, y, inputs, cj) with inputs = [[in1], [in2], [in3]] + // will produce a function with six inputs; we therefore need to be able to map + // arguments to their 1) corresponding input argument, and 2) the correct position + // within that argument. + m_arg_argno.clear(); + m_arg_argix.clear(); + int current_element = 0; + for (int i=0; i 2) || + ((input_shape[j].size() == 2) && (input_shape[j][1] > 1)) + ) { + std::cerr << "Unsupported input shape: " << input_shape[j].size() << " ["; + for (int k=0; k {res0} signature (i.e. x and z are reduced out) + // with kept_var_idx = [1] + // + // *********************************************************************************** + + DEBUG("Copying inputs, shape " << input_shape.size() << " - " << m_func.kept_var_idx.size()); + for (int j=0; j 1) { + // Index into argument using appropriate shape + for(int k=0; k(m_arg[m_arg_from][m_arg_argix[mlir_arg]+k]); + } + } else { + // Copy the entire vector + for(int k=0; k(m_arg[m_arg_from][k]); + } + } + } + + // Call the 'main' function of the module + const std::string mlir = m_func.get_mlir(); + DEBUG("Calling function '" << function_name << "'"); + auto status = session->iree_runtime_exec(function_name, input_shape, input_data, result); + if (!iree_status_is_ok(status)) { + iree_status_fprint(stderr, status); + std::cerr << "MLIR: " << mlir.substr(0,1000) << std::endl; + throw std::runtime_error("Execution failed"); + } + + // Copy results to output array + for(size_t k=0; k(result[k][j]); + } + } + + DEBUG("IreeFunction operator() complete"); +} + +expr_int IREEFunction::out_shape(int k) { + DEBUG("IreeFunction nnz(" << k << "): " << m_func.nnz); + auto elements = 1; + for (auto i : output_shape[k]) { + elements *= i; + } + return elements; +} + +expr_int IREEFunction::nnz() { + DEBUG("IreeFunction nnz: " << m_func.nnz); + return nnz_out(); +} + +expr_int IREEFunction::nnz_out() { + DEBUG("IreeFunction nnz_out" << m_func.nnz); + return m_func.nnz; +} + +std::vector IREEFunction::get_row() { + DEBUG("IreeFunction get_row" << m_func.row.size()); + return m_func.row; +} + +std::vector IREEFunction::get_col() { + DEBUG("IreeFunction get_col" << m_func.col.size()); + return m_func.col; +} + +void IREEFunction::operator()(const std::vector& inputs, + const std::vector& results) +{ + DEBUG("IreeFunction operator() with inputs and results"); + // Set-up input arguments, provide result vector, then execute function + // Example call: fcn({in1, in2, in3}, {out1}) + ASSERT(inputs.size() == m_func.n_args); + for(size_t k=0; k +#include "iree_jit.hpp" +#include "IREEFunction.hpp" + +/** + * @brief Class for handling iree functions + */ +class IREEFunctions : public ExpressionSet +{ +public: + std::unique_ptr iree_compiler; + + typedef IREEFunction::BaseFunctionType BaseFunctionType; // expose typedef in class + + int iree_init_status; + + int iree_init(const std::string& device_uri, const std::string& target_backends) { + // Initialise IREE + DEBUG("IREEFunctions: Initialising IREECompiler"); + iree_compiler = std::make_unique(device_uri.c_str()); + + int iree_argc = 2; + std::string target_backends_str = "--iree-hal-target-backends=" + target_backends; + const char* iree_argv[2] = {"iree", target_backends_str.c_str()}; + iree_compiler->init(iree_argc, iree_argv); + DEBUG("IREEFunctions: Initialised IREECompiler"); + return 0; + } + + int iree_init() { + return iree_init("local-sync", "llvm-cpu"); + } + + + /** + * @brief Create a new IREEFunctions object + */ + IREEFunctions( + const BaseFunctionType &rhs_alg, + const BaseFunctionType &jac_times_cjmass, + const int jac_times_cjmass_nnz, + const int jac_bandwidth_lower, + const int jac_bandwidth_upper, + const np_array_int &jac_times_cjmass_rowvals_arg, + const np_array_int &jac_times_cjmass_colptrs_arg, + const int inputs_length, + const BaseFunctionType &jac_action, + const BaseFunctionType &mass_action, + const BaseFunctionType &sens, + const BaseFunctionType &events, + const int n_s, + const int n_e, + const int n_p, + const std::vector& var_fcns, + const std::vector& dvar_dy_fcns, + const std::vector& dvar_dp_fcns, + const SetupOptions& setup_opts + ) : + iree_init_status(iree_init()), + rhs_alg_iree(rhs_alg), + jac_times_cjmass_iree(jac_times_cjmass), + jac_action_iree(jac_action), + mass_action_iree(mass_action), + sens_iree(sens), + events_iree(events), + ExpressionSet( + static_cast(&rhs_alg_iree), + static_cast(&jac_times_cjmass_iree), + jac_times_cjmass_nnz, + jac_bandwidth_lower, + jac_bandwidth_upper, + jac_times_cjmass_rowvals_arg, + jac_times_cjmass_colptrs_arg, + inputs_length, + static_cast(&jac_action_iree), + static_cast(&mass_action_iree), + static_cast(&sens_iree), + static_cast(&events_iree), + n_s, n_e, n_p, + setup_opts) + { + // convert BaseFunctionType list to IREEFunction list + // NOTE: You must allocate ALL std::vector elements before taking references + for (auto& var : var_fcns) + var_fcns_iree.push_back(IREEFunction(*var)); + for (int k = 0; k < var_fcns_iree.size(); k++) + ExpressionSet::var_fcns.push_back(&this->var_fcns_iree[k]); + + for (auto& var : dvar_dy_fcns) + dvar_dy_fcns_iree.push_back(IREEFunction(*var)); + for (int k = 0; k < dvar_dy_fcns_iree.size(); k++) + this->dvar_dy_fcns.push_back(&this->dvar_dy_fcns_iree[k]); + + for (auto& var : dvar_dp_fcns) + dvar_dp_fcns_iree.push_back(IREEFunction(*var)); + for (int k = 0; k < dvar_dp_fcns_iree.size(); k++) + this->dvar_dp_fcns.push_back(&this->dvar_dp_fcns_iree[k]); + + // copy across numpy array values + const int n_row_vals = jac_times_cjmass_rowvals_arg.request().size; + auto p_jac_times_cjmass_rowvals = jac_times_cjmass_rowvals_arg.unchecked<1>(); + jac_times_cjmass_rowvals.resize(n_row_vals); + for (int i = 0; i < n_row_vals; i++) { + jac_times_cjmass_rowvals[i] = p_jac_times_cjmass_rowvals[i]; + } + + const int n_col_ptrs = jac_times_cjmass_colptrs_arg.request().size; + auto p_jac_times_cjmass_colptrs = jac_times_cjmass_colptrs_arg.unchecked<1>(); + jac_times_cjmass_colptrs.resize(n_col_ptrs); + for (int i = 0; i < n_col_ptrs; i++) { + jac_times_cjmass_colptrs[i] = p_jac_times_cjmass_colptrs[i]; + } + + inputs.resize(inputs_length); + } + + IREEFunction rhs_alg_iree; + IREEFunction jac_times_cjmass_iree; + IREEFunction jac_action_iree; + IREEFunction mass_action_iree; + IREEFunction sens_iree; + IREEFunction events_iree; + + std::vector var_fcns_iree; + std::vector dvar_dy_fcns_iree; + std::vector dvar_dp_fcns_iree; + + realtype* get_tmp_state_vector() override { + return tmp_state_vector.data(); + } + realtype* get_tmp_sparse_jacobian_data() override { + return tmp_sparse_jacobian_data.data(); + } + + ~IREEFunctions() { + // cleanup IREE + iree_compiler->cleanup(); + } +}; + +#endif // PYBAMM_IDAKLU_IREE_FUNCTIONS_HPP diff --git a/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/ModuleParser.cpp b/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/ModuleParser.cpp new file mode 100644 index 0000000000..d1c5575ee2 --- /dev/null +++ b/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/ModuleParser.cpp @@ -0,0 +1,91 @@ +#include "ModuleParser.hpp" + +ModuleParser::ModuleParser(const std::string& mlir) : mlir(mlir) +{ + parse(); +} + +void ModuleParser::parse() +{ + // Parse module name + std::regex module_name_regex("module @([^\\s]+)"); // Match until first whitespace + std::smatch module_name_match; + std::regex_search(this->mlir, module_name_match, module_name_regex); + if (module_name_match.size() == 0) { + std::cerr << "Could not find module name in module" << std::endl; + std::cerr << "Module snippet: " << this->mlir.substr(0, 1000) << std::endl; + throw std::runtime_error("Could not find module name in module"); + } + module_name = module_name_match[1].str(); + DEBUG("Module name: " << module_name); + + // Assign function name + function_name = module_name + ".main"; + + // Isolate 'main' function call signature + std::regex main_func("public @main\\((.*?)\\) -> \\((.*?)\\)"); + std::smatch match; + std::regex_search(this->mlir, match, main_func); + if (match.size() == 0) { + std::cerr << "Could not find 'main' function in module" << std::endl; + std::cerr << "Module snippet: " << this->mlir.substr(0, 1000) << std::endl; + throw std::runtime_error("Could not find 'main' function in module"); + } + std::string main_sig_inputs = match[1].str(); + std::string main_sig_outputs = match[2].str(); + DEBUG( + "Main function signature: " << main_sig_inputs << " -> " << main_sig_outputs << '\n' + ); + + // Parse input sizes + input_shape.clear(); + std::regex input_size("tensor<(.*?)>"); + for(std::sregex_iterator i = std::sregex_iterator(main_sig_inputs.begin(), main_sig_inputs.end(), input_size); + i != std::sregex_iterator(); + ++i) + { + std::smatch matchi = *i; + std::string match_str = matchi.str(); + std::string shape_str = match_str.substr(7, match_str.size() - 8); // Remove 'tensor<>' from string + std::vector shape; + std::string dim_str; + for (char c : shape_str) { + if (c == 'x') { + shape.push_back(std::stoi(dim_str)); + dim_str = ""; + } else { + dim_str += c; + } + } + input_shape.push_back(shape); + } + + // Parse output sizes + output_shape.clear(); + std::regex output_size("tensor<(.*?)>"); + for( + std::sregex_iterator i = std::sregex_iterator(main_sig_outputs.begin(), main_sig_outputs.end(), output_size); + i != std::sregex_iterator(); + ++i + ) { + std::smatch matchi = *i; + std::string match_str = matchi.str(); + std::string shape_str = match_str.substr(7, match_str.size() - 8); // Remove 'tensor<>' from string + std::vector shape; + std::string dim_str; + for (char c : shape_str) { + if (c == 'x') { + shape.push_back(std::stoi(dim_str)); + dim_str = ""; + } else { + dim_str += c; + } + } + // If shape is empty, assume scalar (i.e. "tensor" or some singleton variant) + if (shape.size() == 0) { + shape.push_back(1); + } + // Add output to list + output_shape.push_back(shape); + } +} diff --git a/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/ModuleParser.hpp b/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/ModuleParser.hpp new file mode 100644 index 0000000000..2fbfdc086c --- /dev/null +++ b/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/ModuleParser.hpp @@ -0,0 +1,55 @@ +#ifndef PYBAMM_IDAKLU_IREE_MODULE_PARSER_HPP +#define PYBAMM_IDAKLU_IREE_MODULE_PARSER_HPP + +#include +#include +#include +#include +#include + +#include "../../common.hpp" + +class ModuleParser { +private: + std::string mlir; // cppcheck-suppress unusedStructMember + // codacy fix: member is referenced as this->mlir in parse() + std::string module_name; + std::string function_name; + std::vector> input_shape; + std::vector> output_shape; +public: + /** + * @brief Constructor + * @param mlir: string representation of MLIR code for the module + */ + explicit ModuleParser(const std::string& mlir); + + /** + * @brief Get the module name + * @return module name + */ + const std::string& getModuleName() const { return module_name; } + + /** + * @brief Get the function name + * @return function name + */ + const std::string& getFunctionName() const { return function_name; } + + /** + * @brief Get the input shape + * @return input shape + */ + const std::vector>& getInputShape() const { return input_shape; } + + /** + * @brief Get the output shape + * @return output shape + */ + const std::vector>& getOutputShape() const { return output_shape; } + +private: + void parse(); +}; + +#endif // PYBAMM_IDAKLU_IREE_MODULE_PARSER_HPP diff --git a/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/iree_jit.cpp b/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/iree_jit.cpp new file mode 100644 index 0000000000..c84c3928bd --- /dev/null +++ b/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/iree_jit.cpp @@ -0,0 +1,408 @@ +#include "iree_jit.hpp" +#include "iree/hal/buffer_view.h" +#include "iree/hal/buffer_view_util.h" +#include "../../common.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +// Used to suppress stderr output (see initIREE below) +#ifdef _WIN32 +#include +#define close _close +#define dup _dup +#define fileno _fileno +#define open _open +#define dup2 _dup2 +#define NULL_DEVICE "NUL" +#else +#define NULL_DEVICE "/dev/null" +#endif + +void IREESession::handle_compiler_error(iree_compiler_error_t *error) { + const char *msg = ireeCompilerErrorGetMessage(error); + fprintf(stderr, "Error from compiler API:\n%s\n", msg); + ireeCompilerErrorDestroy(error); +} + +void IREESession::cleanup_compiler_state(compiler_state_t s) { + if (s.inv) + ireeCompilerInvocationDestroy(s.inv); + if (s.output) + ireeCompilerOutputDestroy(s.output); + if (s.source) + ireeCompilerSourceDestroy(s.source); + if (s.session) + ireeCompilerSessionDestroy(s.session); +} + +IREECompiler::IREECompiler() { + this->device_uri = "local-sync"; +}; + +IREECompiler::~IREECompiler() { + ireeCompilerGlobalShutdown(); +}; + +int IREECompiler::init(int argc, const char **argv) { + return initIREE(argc, argv); // Initialisation and version checking +}; + +int IREECompiler::cleanup() { + return 0; +}; + +IREESession::IREESession() { + s.session = NULL; + s.source = NULL; + s.output = NULL; + s.inv = NULL; +}; + +IREESession::IREESession(const char *device_uri, const std::string& mlir_code) : IREESession() { + this->device_uri=device_uri; + this->mlir_code=mlir_code; + init(); +} + +int IREESession::init() { + if (initCompiler() != 0) // Prepare compiler inputs and outputs + return 1; + if (initCompileToByteCode() != 0) // Compile to bytecode + return 1; + if (initRuntime() != 0) // Initialise runtime environment + return 1; + return 0; +}; + +int IREECompiler::initIREE(int argc, const char **argv) { + + if (device_uri == NULL) { + DEBUG("No device URI provided, using local-sync\n"); + this->device_uri = "local-sync"; + } + + int cl_argc = argc; + const char *iree_compiler_lib = std::getenv("IREE_COMPILER_LIB"); + + // Load the compiler library and initialize it + // NOTE: On second and subsequent calls, the function will return false and display + // a message on stderr, but it is safe to ignore this message. For an improved user + // experience we actively suppress stderr during the call to this function but since + // this also suppresses any other error message, we actively check for the presence + // of the library file prior to the call. + + // Check if the library file exists + if (iree_compiler_lib == NULL) { + fprintf(stderr, "Error: IREE_COMPILER_LIB environment variable not set\n"); + return 1; + } + if (access(iree_compiler_lib, F_OK) == -1) { + fprintf(stderr, "Error: IREE_COMPILER_LIB file not found\n"); + return 1; + } + // Suppress stderr + int saved_stderr = dup(fileno(stderr)); + if (!freopen(NULL_DEVICE, "w", stderr)) + DEBUG("Error: failed redirecting stderr"); + // Load library + bool result = ireeCompilerLoadLibrary(iree_compiler_lib); + // Restore stderr + fflush(stderr); + dup2(saved_stderr, fileno(stderr)); + close(saved_stderr); + // Process result + if (!result) { + // Library may have already been loaded (can be safely ignored), + // or may not be found (critical error), we cannot tell which from the return value. + return 1; + } + // Must be balanced with a call to ireeCompilerGlobalShutdown() + ireeCompilerGlobalInitialize(); + + // To set global options (see `iree-compile --help` for possibilities), use + // |ireeCompilerGetProcessCLArgs| and |ireeCompilerSetupGlobalCL| + ireeCompilerGetProcessCLArgs(&cl_argc, &argv); + ireeCompilerSetupGlobalCL(cl_argc, argv, "iree-jit", false); + + // Check the API version before proceeding any further + uint32_t api_version = (uint32_t)ireeCompilerGetAPIVersion(); + uint16_t api_version_major = (uint16_t)((api_version >> 16) & 0xFFFFUL); + uint16_t api_version_minor = (uint16_t)(api_version & 0xFFFFUL); + DEBUG("Compiler API version: " << api_version_major << "." << api_version_minor); + if (api_version_major > IREE_COMPILER_EXPECTED_API_MAJOR || + api_version_minor < IREE_COMPILER_EXPECTED_API_MINOR) { + fprintf(stderr, + "Error: incompatible API version; built for version %" PRIu16 + ".%" PRIu16 " but loaded version %" PRIu16 ".%" PRIu16 "\n", + IREE_COMPILER_EXPECTED_API_MAJOR, IREE_COMPILER_EXPECTED_API_MINOR, + api_version_major, api_version_minor); + ireeCompilerGlobalShutdown(); + return 1; + } + + // Check for a build tag with release version information + const char *revision = ireeCompilerGetRevision(); // cppcheck-suppress unreadVariable + DEBUG("Compiler revision: '" << revision << "'"); + return 0; +}; + +int IREESession::initCompiler() { + + // A session provides a scope where one or more invocations can be executed + s.session = ireeCompilerSessionCreate(); + + // Read the MLIR from memory + error = ireeCompilerSourceWrapBuffer( + s.session, + "expr_buffer", // name of the buffer (does not need to match MLIR) + mlir_code.c_str(), + mlir_code.length() + 1, + true, + &s.source + ); + if (error) { + fprintf(stderr, "Error wrapping source buffer\n"); + handle_compiler_error(error); + cleanup_compiler_state(s); + return 1; + } + DEBUG("Wrapped buffer as a compiler source"); + + return 0; +}; + +int IREESession::initCompileToByteCode() { + // Use an invocation to compile from the input source to the output stream + iree_compiler_invocation_t *inv = ireeCompilerInvocationCreate(s.session); + ireeCompilerInvocationEnableConsoleDiagnostics(inv); + + if (!ireeCompilerInvocationParseSource(inv, s.source)) { + fprintf(stderr, "Error parsing input source into invocation\n"); + cleanup_compiler_state(s); + return 1; + } + + // Compile, specifying the target dialect phase + ireeCompilerInvocationSetCompileToPhase(inv, "end"); + + // Run the compiler invocation pipeline + if (!ireeCompilerInvocationPipeline(inv, IREE_COMPILER_PIPELINE_STD)) { + fprintf(stderr, "Error running compiler invocation\n"); + cleanup_compiler_state(s); + return 1; + } + DEBUG("Compilation successful"); + + // Create compiler 'output' to a memory buffer + error = ireeCompilerOutputOpenMembuffer(&s.output); + if (error) { + fprintf(stderr, "Error opening output membuffer\n"); + handle_compiler_error(error); + cleanup_compiler_state(s); + return 1; + } + + // Create bytecode in memory + error = ireeCompilerInvocationOutputVMBytecode(inv, s.output); + if (error) { + fprintf(stderr, "Error creating VM bytecode\n"); + handle_compiler_error(error); + cleanup_compiler_state(s); + return 1; + } + + // Once the bytecode has been written, retrieve the memory map + ireeCompilerOutputMapMemory(s.output, &contents, &size); + + return 0; +}; + +int IREESession::initRuntime() { + // Setup the shared runtime instance + iree_runtime_instance_options_t instance_options; + iree_runtime_instance_options_initialize(&instance_options); + iree_runtime_instance_options_use_all_available_drivers(&instance_options); + status = iree_runtime_instance_create( + &instance_options, iree_allocator_system(), &instance); + + // Create the HAL device used to run the workloads + if (iree_status_is_ok(status)) { + status = iree_hal_create_device( + iree_runtime_instance_driver_registry(instance), + iree_make_cstring_view(device_uri), + iree_runtime_instance_host_allocator(instance), &device); + } + + // Set up the session to run the module + if (iree_status_is_ok(status)) { + iree_runtime_session_options_t session_options; + iree_runtime_session_options_initialize(&session_options); + status = iree_runtime_session_create_with_device( + instance, &session_options, device, + iree_runtime_instance_host_allocator(instance), &session); + } + + // Load the compiled user module from a file + if (iree_status_is_ok(status)) { + /*status = iree_runtime_session_append_bytecode_module_from_file(session, module_path);*/ + status = iree_runtime_session_append_bytecode_module_from_memory( + session, + iree_make_const_byte_span(contents, size), + iree_allocator_null()); + } + + if (!iree_status_is_ok(status)) + return 1; + + return 0; +}; + +// Release the session and free all cached resources. +int IREESession::cleanup() { + iree_runtime_session_release(session); + iree_hal_device_release(device); + iree_runtime_instance_release(instance); + + int ret = (int)iree_status_code(status); + if (!iree_status_is_ok(status)) { + iree_status_fprint(stderr, status); + iree_status_ignore(status); + } + cleanup_compiler_state(s); + return ret; +} + +iree_status_t IREESession::iree_runtime_exec( + const std::string& function_name, + const std::vector>& inputs, + const std::vector>& data, + std::vector>& result +) { + + // Initialize the call to the function. + status = iree_runtime_call_initialize_by_name( + session, iree_make_cstring_view(function_name.c_str()), &call); + if (!iree_status_is_ok(status)) { + std::cerr << "Error: iree_runtime_call_initialize_by_name failed" << std::endl; + iree_status_fprint(stderr, status); + return status; + } + + // Append the function inputs with the HAL device allocator in use by the + // session. The buffers will be usable within the session and _may_ be usable + // in other sessions depending on whether they share a compatible device. + iree_hal_allocator_t* device_allocator = + iree_runtime_session_device_allocator(session); + host_allocator = iree_runtime_session_host_allocator(session); + status = iree_ok_status(); + if (iree_status_is_ok(status)) { + + for(int k=0; k arg_shape(input_shape.size()); + for (int i = 0; i < input_shape.size(); i++) { + arg_shape[i] = input_shape[i]; + } + int numel = 1; + for(int i = 0; i < input_shape.size(); i++) { + numel *= input_shape[i]; + } + std::vector arg_data(numel); + for(int i = 0; i < numel; i++) { + arg_data[i] = input_data[i]; + } + + status = iree_hal_buffer_view_allocate_buffer_copy( + device, device_allocator, + // Shape rank and dimensions: + arg_shape.size(), arg_shape.data(), + // Element type: + IREE_HAL_ELEMENT_TYPE_FLOAT_32, + // Encoding type: + IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, + (iree_hal_buffer_params_t){ + // Intended usage of the buffer (transfers, dispatches, etc): + .usage = IREE_HAL_BUFFER_USAGE_DEFAULT, + // Access to allow to this memory: + .access = IREE_HAL_MEMORY_ACCESS_ALL, + // Where to allocate (host or device): + .type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL, + }, + // The actual heap buffer to wrap or clone and its allocator: + iree_make_const_byte_span(&arg_data[0], sizeof(float) * arg_data.size()), + // Buffer view + storage are returned and owned by the caller: + &arg); + } + if (iree_status_is_ok(status)) { + // Add to the call inputs list (which retains the buffer view). + status = iree_runtime_call_inputs_push_back_buffer_view(&call, arg); + if (!iree_status_is_ok(status)) { + std::cerr << "Error: iree_runtime_call_inputs_push_back_buffer_view failed" << std::endl; + iree_status_fprint(stderr, status); + } + } + // Since the call retains the buffer view we can release it here. + iree_hal_buffer_view_release(arg); + } + } + + // Synchronously perform the call. + if (iree_status_is_ok(status)) { + status = iree_runtime_call_invoke(&call, /*flags=*/0); + } + if (!iree_status_is_ok(status)) { + std::cerr << "Error: iree_runtime_call_invoke failed" << std::endl; + iree_status_fprint(stderr, status); + } + + for(int k=0; k +#include +#include +#include + +#include +#include +#include + +#define IREE_COMPILER_EXPECTED_API_MAJOR 1 // At most this major version +#define IREE_COMPILER_EXPECTED_API_MINOR 2 // At least this minor version + +// Forward declaration +class IREESession; + +/* + * @brief IREECompiler class + * @details This class is used to compile MLIR code to IREE bytecode and + * create IREE sessions. + */ +class IREECompiler { +private: + /* + * @brief Device Uniform Resource Identifier (URI) + * @details The device URI is used to specify the device to be used by the + * IREE runtime. E.g. "local-sync" for CPU, "vulkan" for GPU, etc. + */ + const char *device_uri = NULL; + +private: + /* + * @brief Initialize the IREE runtime + */ + int initIREE(int argc, const char **argv); + +public: + /* + * @brief Default constructor + */ + IREECompiler(); + + /* + * @brief Destructor + */ + ~IREECompiler(); + + /* + * @brief Constructor with device URI + * @param device_uri Device URI + */ + explicit IREECompiler(const char *device_uri) + : IREECompiler() { this->device_uri=device_uri; } + + /* + * @brief Initialize the compiler + */ + int init(int argc, const char **argv); + + /* + * @brief Cleanup the compiler + * @details This method cleans up the compiler and all the IREE sessions + * created by the compiler. Returns 0 on success. + */ + int cleanup(); +}; + +/* + * @brief Compiler state + */ +typedef struct compiler_state_t { + iree_compiler_session_t *session; // cppcheck-suppress unusedStructMember + iree_compiler_source_t *source; // cppcheck-suppress unusedStructMember + iree_compiler_output_t *output; // cppcheck-suppress unusedStructMember + iree_compiler_invocation_t *inv; // cppcheck-suppress unusedStructMember +} compiler_state_t; + +/* + * @brief IREE session class + */ +class IREESession { +private: // data members + const char *device_uri = NULL; + compiler_state_t s; + iree_compiler_error_t *error = NULL; + void *contents = NULL; + uint64_t size = 0; + iree_runtime_session_t* session = NULL; + iree_status_t status; + iree_hal_device_t* device = NULL; + iree_runtime_instance_t* instance = NULL; + std::string mlir_code; // cppcheck-suppress unusedStructMember + iree_runtime_call_t call; + iree_allocator_t host_allocator; + +private: // private methods + void handle_compiler_error(iree_compiler_error_t *error); + void cleanup_compiler_state(compiler_state_t s); + int init(); + int initCompiler(); + int initCompileToByteCode(); + int initRuntime(); + +public: // public methods + + /* + * @brief Default constructor + */ + IREESession(); + + /* + * @brief Constructor with device URI and MLIR code + * @param device_uri Device URI + * @param mlir_code MLIR code + */ + explicit IREESession(const char *device_uri, const std::string& mlir_code); + + /* + * @brief Cleanup the IREE session + */ + int cleanup(); + + /* + * @brief Execute the pre-compiled byte-code with the given inputs + * @param function_name Function name to execute + * @param inputs List of input shapes + * @param data List of input data + * @param result List of output data + */ + iree_status_t iree_runtime_exec( + const std::string& function_name, + const std::vector>& inputs, + const std::vector>& data, + std::vector>& result + ); +}; + +#endif // IREE_JIT_HPP diff --git a/pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.cpp b/pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.cpp new file mode 100644 index 0000000000..b769d4d1d4 --- /dev/null +++ b/pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.cpp @@ -0,0 +1 @@ +#include "IDAKLUSolver.hpp" diff --git a/pybamm/solvers/c_solvers/idaklu/CasadiSolver.hpp b/pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.hpp similarity index 75% rename from pybamm/solvers/c_solvers/idaklu/CasadiSolver.hpp rename to pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.hpp index dac94579f3..26e587e424 100644 --- a/pybamm/solvers/c_solvers/idaklu/CasadiSolver.hpp +++ b/pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.hpp @@ -1,33 +1,27 @@ #ifndef PYBAMM_IDAKLU_CASADI_SOLVER_HPP #define PYBAMM_IDAKLU_CASADI_SOLVER_HPP -#include -using Function = casadi::Function; - -#include "casadi_functions.hpp" #include "common.hpp" -#include "options.hpp" -#include "solution.hpp" -#include "sundials_legacy_wrapper.hpp" +#include "Solution.hpp" /** * Abstract base class for solutions that can use different solvers and vector * implementations. * @brief An abstract base class for the Idaklu solver */ -class CasadiSolver +class IDAKLUSolver { public: /** * @brief Default constructor */ - CasadiSolver() = default; + IDAKLUSolver() = default; /** * @brief Default destructor */ - ~CasadiSolver() = default; + ~IDAKLUSolver() = default; /** * @brief Abstract solver method that returns a Solution class diff --git a/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.hpp b/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.hpp similarity index 78% rename from pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.hpp rename to pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.hpp index 2312f9cf8f..98148a3c9f 100644 --- a/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.hpp +++ b/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.hpp @@ -1,14 +1,10 @@ -#ifndef PYBAMM_IDAKLU_CASADISOLVEROPENMP_HPP -#define PYBAMM_IDAKLU_CASADISOLVEROPENMP_HPP +#ifndef PYBAMM_IDAKLU_SOLVEROPENMP_HPP +#define PYBAMM_IDAKLU_SOLVEROPENMP_HPP -#include "CasadiSolver.hpp" -#include -using Function = casadi::Function; - -#include "casadi_functions.hpp" +#include "IDAKLUSolver.hpp" #include "common.hpp" -#include "options.hpp" -#include "solution.hpp" +#include "Options.hpp" +#include "Solution.hpp" #include "sundials_legacy_wrapper.hpp" /** @@ -40,7 +36,8 @@ using Function = casadi::Function; * 19. Destroy objects * 20. (N/A) Finalize MPI */ -class CasadiSolverOpenMP : public CasadiSolver +template +class IDAKLUSolverOpenMP : public IDAKLUSolver { // NB: cppcheck-suppress unusedStructMember is used because codacy reports // these members as unused even though they are important in child @@ -63,11 +60,12 @@ class CasadiSolverOpenMP : public CasadiSolver int jac_bandwidth_upper; // cppcheck-suppress unusedStructMember SUNMatrix J; SUNLinearSolver LS = nullptr; - std::unique_ptr functions; - realtype *res = nullptr; - realtype *res_dvar_dy = nullptr; - realtype *res_dvar_dp = nullptr; - Options options; + std::unique_ptr functions; + std::vector res; + std::vector res_dvar_dy; + std::vector res_dvar_dp; + SetupOptions setup_opts; + SolverOptions solver_opts; #if SUNDIALS_VERSION_MAJOR >= 6 SUNContext sunctx; @@ -77,7 +75,7 @@ class CasadiSolverOpenMP : public CasadiSolver /** * @brief Constructor */ - CasadiSolverOpenMP( + IDAKLUSolverOpenMP( np_array atol_np, double rel_tol, np_array rhs_alg_id, @@ -86,18 +84,20 @@ class CasadiSolverOpenMP : public CasadiSolver int jac_times_cjmass_nnz, int jac_bandwidth_lower, int jac_bandwidth_upper, - std::unique_ptr functions, - const Options& options); + std::unique_ptr functions, + const SetupOptions &setup_opts, + const SolverOptions &solver_opts + ); /** * @brief Destructor */ - ~CasadiSolverOpenMP(); + ~IDAKLUSolverOpenMP(); /** - * Evaluate casadi functions (including sensitivies) for each requested + * Evaluate functions (including sensitivies) for each requested * variable and store - * @brief Evaluate casadi functions + * @brief Evaluate functions */ void CalcVars( realtype *y_return, @@ -110,7 +110,7 @@ class CasadiSolverOpenMP : public CasadiSolver size_t *ySk); /** - * @brief Evaluate casadi functions for sensitivities + * @brief Evaluate functions for sensitivities */ void CalcVarsSensitivities( realtype *tret, @@ -142,6 +142,18 @@ class CasadiSolverOpenMP : public CasadiSolver * @brief Allocate memory for matrices (noting appropriate matrix format/types) */ void SetMatrix(); + + /** + * @brief Apply user-configurable IDA options + */ + void SetSolverOptions(); + + /** + * @brief Check the return flag for errors + */ + void CheckErrors(int const & flag); }; -#endif // PYBAMM_IDAKLU_CASADISOLVEROPENMP_HPP +#include "IDAKLUSolverOpenMP.inl" + +#endif // PYBAMM_IDAKLU_SOLVEROPENMP_HPP diff --git a/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.inl b/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.inl new file mode 100644 index 0000000000..aaeaceb41d --- /dev/null +++ b/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.inl @@ -0,0 +1,619 @@ +#include "Expressions/Expressions.hpp" +#include "sundials_functions.hpp" + +template +IDAKLUSolverOpenMP::IDAKLUSolverOpenMP( + np_array atol_np_input, + double rel_tol, + np_array rhs_alg_id_input, + int number_of_parameters_input, + int number_of_events_input, + int jac_times_cjmass_nnz_input, + int jac_bandwidth_lower_input, + int jac_bandwidth_upper_input, + std::unique_ptr functions_arg, + const SetupOptions &setup_input, + const SolverOptions &solver_input +) : + atol_np(atol_np_input), + rhs_alg_id(rhs_alg_id_input), + number_of_states(atol_np_input.request().size), + number_of_parameters(number_of_parameters_input), + number_of_events(number_of_events_input), + jac_times_cjmass_nnz(jac_times_cjmass_nnz_input), + jac_bandwidth_lower(jac_bandwidth_lower_input), + jac_bandwidth_upper(jac_bandwidth_upper_input), + functions(std::move(functions_arg)), + setup_opts(setup_input), + solver_opts(solver_input) +{ + // Construction code moved to Initialize() which is called from the + // (child) IDAKLUSolver_* class constructors. + DEBUG("IDAKLUSolverOpenMP:IDAKLUSolverOpenMP"); + auto atol = atol_np.unchecked<1>(); + + // create SUNDIALS context object + SUNContext_Create(NULL, &sunctx); // calls null-wrapper if Sundials Ver<6 + + // allocate memory for solver + ida_mem = IDACreate(sunctx); + + // create the vector of initial values + AllocateVectors(); + if (number_of_parameters > 0) + { + yyS = N_VCloneVectorArray(number_of_parameters, yy); + ypS = N_VCloneVectorArray(number_of_parameters, yp); + } + // set initial values + realtype *atval = N_VGetArrayPointer(avtol); + for (int i = 0; i < number_of_states; i++) + atval[i] = atol[i]; + for (int is = 0; is < number_of_parameters; is++) + { + N_VConst(RCONST(0.0), yyS[is]); + N_VConst(RCONST(0.0), ypS[is]); + } + + // create Matrix objects + SetMatrix(); + + // initialise solver + IDAInit(ida_mem, residual_eval, 0, yy, yp); + + // set tolerances + rtol = RCONST(rel_tol); + IDASVtolerances(ida_mem, rtol, avtol); + + // Set events + IDARootInit(ida_mem, number_of_events, events_eval); + + // Set user data + void *user_data = functions.get(); + IDASetUserData(ida_mem, user_data); + + // Specify preconditioner type + precon_type = SUN_PREC_NONE; + if (this->setup_opts.preconditioner != "none") { + precon_type = SUN_PREC_LEFT; + } +} + +template +void IDAKLUSolverOpenMP::AllocateVectors() { + // Create vectors + yy = N_VNew_OpenMP(number_of_states, setup_opts.num_threads, sunctx); + yp = N_VNew_OpenMP(number_of_states, setup_opts.num_threads, sunctx); + avtol = N_VNew_OpenMP(number_of_states, setup_opts.num_threads, sunctx); + id = N_VNew_OpenMP(number_of_states, setup_opts.num_threads, sunctx); +} + +template +void IDAKLUSolverOpenMP::SetSolverOptions() { + // Maximum order of the linear multistep method + CheckErrors(IDASetMaxOrd(ida_mem, solver_opts.max_order_bdf)); + + // Maximum number of steps to be taken by the solver in its attempt to reach + // the next output time + CheckErrors(IDASetMaxNumSteps(ida_mem, solver_opts.max_num_steps)); + + // Initial step size + CheckErrors(IDASetInitStep(ida_mem, solver_opts.dt_init)); + + // Maximum absolute step size + CheckErrors(IDASetMaxStep(ida_mem, solver_opts.dt_max)); + + // Maximum number of error test failures in attempting one step + CheckErrors(IDASetMaxErrTestFails(ida_mem, solver_opts.max_error_test_failures)); + + // Maximum number of nonlinear solver iterations at one step + CheckErrors(IDASetMaxNonlinIters(ida_mem, solver_opts.max_nonlinear_iterations)); + + // Maximum number of nonlinear solver convergence failures at one step + CheckErrors(IDASetMaxConvFails(ida_mem, solver_opts.max_convergence_failures)); + + // Safety factor in the nonlinear convergence test + CheckErrors(IDASetNonlinConvCoef(ida_mem, solver_opts.nonlinear_convergence_coefficient)); + + // Suppress algebraic variables from error test + CheckErrors(IDASetSuppressAlg(ida_mem, solver_opts.suppress_algebraic_error)); + + // Positive constant in the Newton iteration convergence test within the initial + // condition calculation + CheckErrors(IDASetNonlinConvCoefIC(ida_mem, solver_opts.nonlinear_convergence_coefficient_ic)); + + // Maximum number of steps allowed when icopt=IDA_YA_YDP_INIT in IDACalcIC + CheckErrors(IDASetMaxNumStepsIC(ida_mem, solver_opts.max_num_steps_ic)); + + // Maximum number of the approximate Jacobian or preconditioner evaluations + // allowed when the Newton iteration appears to be slowly converging + CheckErrors(IDASetMaxNumJacsIC(ida_mem, solver_opts.max_num_jacobians_ic)); + + // Maximum number of Newton iterations allowed in any one attempt to solve + // the initial conditions calculation problem + CheckErrors(IDASetMaxNumItersIC(ida_mem, solver_opts.max_num_iterations_ic)); + + // Maximum number of linesearch backtracks allowed in any Newton iteration, + // when solving the initial conditions calculation problem + CheckErrors(IDASetMaxBacksIC(ida_mem, solver_opts.max_linesearch_backtracks_ic)); + + // Turn off linesearch + CheckErrors(IDASetLineSearchOffIC(ida_mem, solver_opts.linesearch_off_ic)); + + // Ratio between linear and nonlinear tolerances + CheckErrors(IDASetEpsLin(ida_mem, solver_opts.epsilon_linear_tolerance)); + + // Increment factor used in DQ Jv approximation + CheckErrors(IDASetIncrementFactor(ida_mem, solver_opts.increment_factor)); + + int LS_type = SUNLinSolGetType(LS); + if (LS_type == SUNLINEARSOLVER_DIRECT || LS_type == SUNLINEARSOLVER_MATRIX_ITERATIVE) { + // Enable or disable linear solution scaling + CheckErrors(IDASetLinearSolutionScaling(ida_mem, solver_opts.linear_solution_scaling)); + } +} + + + +template +void IDAKLUSolverOpenMP::SetMatrix() { + // Create Matrix object + if (setup_opts.jacobian == "sparse") + { + DEBUG("\tsetting sparse matrix"); + J = SUNSparseMatrix( + number_of_states, + number_of_states, + jac_times_cjmass_nnz, + CSC_MAT, + sunctx + ); + } + else if (setup_opts.jacobian == "banded") { + DEBUG("\tsetting banded matrix"); + J = SUNBandMatrix( + number_of_states, + jac_bandwidth_upper, + jac_bandwidth_lower, + sunctx + ); + } else if (setup_opts.jacobian == "dense" || setup_opts.jacobian == "none") + { + DEBUG("\tsetting dense matrix"); + J = SUNDenseMatrix( + number_of_states, + number_of_states, + sunctx + ); + } + else if (setup_opts.jacobian == "matrix-free") + { + DEBUG("\tsetting matrix-free"); + J = NULL; + } + else + throw std::invalid_argument("Unsupported matrix requested"); +} + +template +void IDAKLUSolverOpenMP::Initialize() { + // Call after setting the solver + + // attach the linear solver + if (LS == nullptr) { + throw std::invalid_argument("Linear solver not set"); + } + CheckErrors(IDASetLinearSolver(ida_mem, LS, J)); + + if (setup_opts.preconditioner != "none") + { + DEBUG("\tsetting IDADDB preconditioner"); + // setup preconditioner + CheckErrors(IDABBDPrecInit( + ida_mem, number_of_states, setup_opts.precon_half_bandwidth, + setup_opts.precon_half_bandwidth, setup_opts.precon_half_bandwidth_keep, + setup_opts.precon_half_bandwidth_keep, 0.0, residual_eval_approx, NULL)); + } + + if (setup_opts.jacobian == "matrix-free") { + CheckErrors(IDASetJacTimes(ida_mem, NULL, jtimes_eval)); + } else if (setup_opts.jacobian != "none") { + CheckErrors(IDASetJacFn(ida_mem, jacobian_eval)); + } + + if (number_of_parameters > 0) + { + CheckErrors(IDASensInit(ida_mem, number_of_parameters, IDA_SIMULTANEOUS, + sensitivities_eval, yyS, ypS)); + CheckErrors(IDASensEEtolerances(ida_mem)); + } + + CheckErrors(SUNLinSolInitialize(LS)); + + auto id_np_val = rhs_alg_id.unchecked<1>(); + realtype *id_val; + id_val = N_VGetArrayPointer(id); + + int ii; + for (ii = 0; ii < number_of_states; ii++) + id_val[ii] = id_np_val[ii]; + + // Variable types: differential (1) and algebraic (0) + CheckErrors(IDASetId(ida_mem, id)); +} + +template +IDAKLUSolverOpenMP::~IDAKLUSolverOpenMP() +{ + // Free memory + if (number_of_parameters > 0) { + IDASensFree(ida_mem); + } + + CheckErrors(SUNLinSolFree(LS)); + + SUNMatDestroy(J); + N_VDestroy(avtol); + N_VDestroy(yy); + N_VDestroy(yp); + N_VDestroy(id); + + if (number_of_parameters > 0) + { + N_VDestroyVectorArray(yyS, number_of_parameters); + N_VDestroyVectorArray(ypS, number_of_parameters); + } + + IDAFree(&ida_mem); + SUNContext_Free(&sunctx); +} + +template +void IDAKLUSolverOpenMP::CalcVars( + realtype *y_return, + size_t length_of_return_vector, + size_t t_i, + realtype *tret, + realtype *yval, + const std::vector& ySval, + realtype *yS_return, + size_t *ySk +) { + DEBUG("IDAKLUSolver::CalcVars"); + // Evaluate functions for each requested variable and store + size_t j = 0; + for (auto& var_fcn : functions->var_fcns) { + (*var_fcn)({tret, yval, functions->inputs.data()}, {&res[0]}); + // store in return vector + for (size_t jj=0; jjnnz_out(); jj++) + y_return[t_i*length_of_return_vector + j++] = res[jj]; + } + // calculate sensitivities + CalcVarsSensitivities(tret, yval, ySval, yS_return, ySk); +} + +template +void IDAKLUSolverOpenMP::CalcVarsSensitivities( + realtype *tret, + realtype *yval, + const std::vector& ySval, + realtype *yS_return, + size_t *ySk +) { + DEBUG("IDAKLUSolver::CalcVarsSensitivities"); + // Calculate sensitivities + std::vector dens_dvar_dp = std::vector(number_of_parameters, 0); + for (size_t dvar_k=0; dvar_kdvar_dy_fcns.size(); dvar_k++) { + // Isolate functions + Expression* dvar_dy = functions->dvar_dy_fcns[dvar_k]; + Expression* dvar_dp = functions->dvar_dp_fcns[dvar_k]; + // Calculate dvar/dy + (*dvar_dy)({tret, yval, functions->inputs.data()}, {&res_dvar_dy[0]}); + // Calculate dvar/dp and convert to dense array for indexing + (*dvar_dp)({tret, yval, functions->inputs.data()}, {&res_dvar_dp[0]}); + for(int k=0; knnz_out(); k++) + dens_dvar_dp[dvar_dp->get_row()[k]] = res_dvar_dp[k]; + // Calculate sensitivities + for(int paramk=0; paramknnz_out(); spk++) + yS_return[*ySk] += res_dvar_dy[spk] * ySval[paramk][dvar_dy->get_col()[spk]]; + (*ySk)++; + } + } +} + +template +Solution IDAKLUSolverOpenMP::solve( + np_array t_np, + np_array y0_np, + np_array yp0_np, + np_array_dense inputs +) +{ + DEBUG("IDAKLUSolver::solve"); + + int number_of_timesteps = t_np.request().size; + auto t = t_np.unchecked<1>(); + realtype t0 = RCONST(t(0)); + auto y0 = y0_np.unchecked<1>(); + auto yp0 = yp0_np.unchecked<1>(); + auto n_coeffs = number_of_states + number_of_parameters * number_of_states; + bool const sensitivity = number_of_parameters > 0; + + if (y0.size() != n_coeffs) { + throw std::domain_error( + "y0 has wrong size. Expected " + std::to_string(n_coeffs) + + " but got " + std::to_string(y0.size())); + } + + if (yp0.size() != n_coeffs) { + throw std::domain_error( + "yp0 has wrong size. Expected " + std::to_string(n_coeffs) + + " but got " + std::to_string(yp0.size())); + } + + // set inputs + auto p_inputs = inputs.unchecked<2>(); + for (int i = 0; i < functions->inputs.size(); i++) + functions->inputs[i] = p_inputs(i, 0); + + // set initial conditions + realtype *yval = N_VGetArrayPointer(yy); + realtype *ypval = N_VGetArrayPointer(yp); + std::vector ySval(number_of_parameters); + std::vector ypSval(number_of_parameters); + for (int p = 0 ; p < number_of_parameters; p++) { + ySval[p] = N_VGetArrayPointer(yyS[p]); + ypSval[p] = N_VGetArrayPointer(ypS[p]); + for (int i = 0; i < number_of_states; i++) { + ySval[p][i] = y0[i + (p + 1) * number_of_states]; + ypSval[p][i] = yp0[i + (p + 1) * number_of_states]; + } + } + + for (int i = 0; i < number_of_states; i++) + { + yval[i] = y0[i]; + ypval[i] = yp0[i]; + } + + SetSolverOptions(); + + CheckErrors(IDAReInit(ida_mem, t0, yy, yp)); + if (sensitivity) { + CheckErrors(IDASensReInit(ida_mem, IDA_SIMULTANEOUS, yyS, ypS)); + } + + // correct initial values + int const init_type = solver_opts.init_all_y_ic ? IDA_Y_INIT : IDA_YA_YDP_INIT; + if (solver_opts.calc_ic) { + DEBUG("IDACalcIC"); + // Do not throw a warning if the initial conditions calculation fails + // as the solver will still run + IDACalcIC(ida_mem, init_type, t(1)); + } + + if (sensitivity) { + CheckErrors(IDAGetSens(ida_mem, &t0, yyS)); + } + + realtype tret; + realtype t_final = t(number_of_timesteps - 1); + + // set return vectors + int length_of_return_vector = 0; + size_t max_res_size = 0; // maximum result size (for common result buffer) + size_t max_res_dvar_dy = 0, max_res_dvar_dp = 0; + if (functions->var_fcns.size() > 0) { + // return only the requested variables list after computation + for (auto& var_fcn : functions->var_fcns) { + max_res_size = std::max(max_res_size, size_t(var_fcn->out_shape(0))); + length_of_return_vector += var_fcn->nnz_out(); + for (auto& dvar_fcn : functions->dvar_dy_fcns) + max_res_dvar_dy = std::max(max_res_dvar_dy, size_t(dvar_fcn->out_shape(0))); + for (auto& dvar_fcn : functions->dvar_dp_fcns) + max_res_dvar_dp = std::max(max_res_dvar_dp, size_t(dvar_fcn->out_shape(0))); + } + } else { + // Return full y state-vector + length_of_return_vector = number_of_states; + } + realtype *t_return = new realtype[number_of_timesteps]; + realtype *y_return = new realtype[number_of_timesteps * + length_of_return_vector]; + realtype *yS_return = new realtype[number_of_parameters * + number_of_timesteps * + length_of_return_vector]; + + res.resize(max_res_size); + res_dvar_dy.resize(max_res_dvar_dy); + res_dvar_dp.resize(max_res_dvar_dp); + + py::capsule free_t_when_done( + t_return, + [](void *f) { + realtype *vect = reinterpret_cast(f); + delete[] vect; + } + ); + py::capsule free_y_when_done( + y_return, + [](void *f) { + realtype *vect = reinterpret_cast(f); + delete[] vect; + } + ); + py::capsule free_yS_when_done( + yS_return, + [](void *f) { + realtype *vect = reinterpret_cast(f); + delete[] vect; + } + ); + + // Initial state (t_i=0) + int t_i = 0; + size_t ySk = 0; + t_return[t_i] = t(t_i); + if (functions->var_fcns.size() > 0) { + // Evaluate functions for each requested variable and store + CalcVars(y_return, length_of_return_vector, t_i, + &tret, yval, ySval, yS_return, &ySk); + } else { + // Retain complete copy of the state vector + for (int j = 0; j < number_of_states; j++) + y_return[j] = yval[j]; + for (int j = 0; j < number_of_parameters; j++) + { + const int base_index = j * number_of_timesteps * number_of_states; + for (int k = 0; k < number_of_states; k++) + yS_return[base_index + k] = ySval[j][k]; + } + } + + // Subsequent states (t_i>0) + int retval; + t_i = 1; + while (true) + { + realtype t_next = t(t_i); + IDASetStopTime(ida_mem, t_next); + DEBUG("IDASolve"); + retval = IDASolve(ida_mem, t_final, &tret, yy, yp, IDA_NORMAL); + + if (!(retval == IDA_TSTOP_RETURN || + retval == IDA_SUCCESS || + retval == IDA_ROOT_RETURN)) { + // failed + break; + } + + if (number_of_parameters > 0) { + CheckErrors(IDAGetSens(ida_mem, &tret, yyS)); + } + + // Evaluate and store results for the time step + t_return[t_i] = tret; + if (functions->var_fcns.size() > 0) { + // Evaluate functions for each requested variable and store + // NOTE: Indexing of yS_return is (time:var:param) + CalcVars(y_return, length_of_return_vector, t_i, + &tret, yval, ySval, yS_return, &ySk); + } else { + // Retain complete copy of the state vector + for (int j = 0; j < number_of_states; j++) + y_return[t_i * number_of_states + j] = yval[j]; + for (int j = 0; j < number_of_parameters; j++) + { + const int base_index = + j * number_of_timesteps * number_of_states + + t_i * number_of_states; + for (int k = 0; k < number_of_states; k++) + // NOTE: Indexing of yS_return is (time:param:yvec) + yS_return[base_index + k] = ySval[j][k]; + } + } + t_i += 1; + + if (retval == IDA_SUCCESS || + retval == IDA_ROOT_RETURN) + break; + } + + np_array t_ret = np_array( + t_i, + &t_return[0], + free_t_when_done + ); + np_array y_ret = np_array( + t_i * length_of_return_vector, + &y_return[0], + free_y_when_done + ); + // Note: Ordering of vector is differnet if computing variables vs returning + // the complete state vector + np_array yS_ret; + if (functions->var_fcns.size() > 0) { + yS_ret = np_array( + std::vector { + number_of_timesteps, + length_of_return_vector, + number_of_parameters + }, + &yS_return[0], + free_yS_when_done + ); + } else { + yS_ret = np_array( + std::vector { + number_of_parameters, + number_of_timesteps, + length_of_return_vector + }, + &yS_return[0], + free_yS_when_done + ); + } + + Solution sol(retval, t_ret, y_ret, yS_ret); + + if (solver_opts.print_stats) + { + long nsteps, nrevals, nlinsetups, netfails; + int klast, kcur; + realtype hinused, hlast, hcur, tcur; + + CheckErrors(IDAGetIntegratorStats( + ida_mem, + &nsteps, + &nrevals, + &nlinsetups, + &netfails, + &klast, + &kcur, + &hinused, + &hlast, + &hcur, + &tcur + )); + + long nniters, nncfails; + CheckErrors(IDAGetNonlinSolvStats(ida_mem, &nniters, &nncfails)); + + long int ngevalsBBDP = 0; + if (setup_opts.using_iterative_solver) + { + CheckErrors(IDABBDPrecGetNumGfnEvals(ida_mem, &ngevalsBBDP)); + } + + py::print("Solver Stats:"); + py::print("\tNumber of steps =", nsteps); + py::print("\tNumber of calls to residual function =", nrevals); + py::print("\tNumber of calls to residual function in preconditioner =", + ngevalsBBDP); + py::print("\tNumber of linear solver setup calls =", nlinsetups); + py::print("\tNumber of error test failures =", netfails); + py::print("\tMethod order used on last step =", klast); + py::print("\tMethod order used on next step =", kcur); + py::print("\tInitial step size =", hinused); + py::print("\tStep size on last step =", hlast); + py::print("\tStep size on next step =", hcur); + py::print("\tCurrent internal time reached =", tcur); + py::print("\tNumber of nonlinear iterations performed =", nniters); + py::print("\tNumber of nonlinear convergence failures =", nncfails); + } + + return sol; +} + +template +void IDAKLUSolverOpenMP::CheckErrors(int const & flag) { + if (flag < 0) { + auto message = (std::string("IDA failed with flag ") + std::to_string(flag)).c_str(); + py::set_error(PyExc_ValueError, message); + throw py::error_already_set(); + } +} diff --git a/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP_solvers.cpp b/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP_solvers.cpp new file mode 100644 index 0000000000..45ceed0ada --- /dev/null +++ b/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP_solvers.cpp @@ -0,0 +1 @@ +#include "IDAKLUSolverOpenMP_solvers.hpp" diff --git a/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP_solvers.hpp b/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP_solvers.hpp new file mode 100644 index 0000000000..5f6f29b47b --- /dev/null +++ b/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP_solvers.hpp @@ -0,0 +1,131 @@ +#ifndef PYBAMM_IDAKLU_CASADI_SOLVER_OPENMP_HPP +#define PYBAMM_IDAKLU_CASADI_SOLVER_OPENMP_HPP + +#include "IDAKLUSolverOpenMP.hpp" + +/** + * @brief IDAKLUSolver Dense implementation with OpenMP class + */ +template +class IDAKLUSolverOpenMP_Dense : public IDAKLUSolverOpenMP { +public: + using Base = IDAKLUSolverOpenMP; + template + IDAKLUSolverOpenMP_Dense(Args&& ... args) : Base(std::forward(args) ...) + { + Base::LS = SUNLinSol_Dense(Base::yy, Base::J, Base::sunctx); + Base::Initialize(); + } +}; + +/** + * @brief IDAKLUSolver KLU implementation with OpenMP class + */ +template +class IDAKLUSolverOpenMP_KLU : public IDAKLUSolverOpenMP { +public: + using Base = IDAKLUSolverOpenMP; + template + IDAKLUSolverOpenMP_KLU(Args&& ... args) : Base(std::forward(args) ...) + { + Base::LS = SUNLinSol_KLU(Base::yy, Base::J, Base::sunctx); + Base::Initialize(); + } +}; + +/** + * @brief IDAKLUSolver Banded implementation with OpenMP class + */ +template +class IDAKLUSolverOpenMP_Band : public IDAKLUSolverOpenMP { +public: + using Base = IDAKLUSolverOpenMP; + template + IDAKLUSolverOpenMP_Band(Args&& ... args) : Base(std::forward(args) ...) + { + Base::LS = SUNLinSol_Band(Base::yy, Base::J, Base::sunctx); + Base::Initialize(); + } +}; + +/** + * @brief IDAKLUSolver SPBCGS implementation with OpenMP class + */ +template +class IDAKLUSolverOpenMP_SPBCGS : public IDAKLUSolverOpenMP { +public: + using Base = IDAKLUSolverOpenMP; + template + IDAKLUSolverOpenMP_SPBCGS(Args&& ... args) : Base(std::forward(args) ...) + { + Base::LS = SUNLinSol_SPBCGS( + Base::yy, + Base::precon_type, + Base::setup_opts.linsol_max_iterations, + Base::sunctx + ); + Base::Initialize(); + } +}; + +/** + * @brief IDAKLUSolver SPFGMR implementation with OpenMP class + */ +template +class IDAKLUSolverOpenMP_SPFGMR : public IDAKLUSolverOpenMP { +public: + using Base = IDAKLUSolverOpenMP; + template + IDAKLUSolverOpenMP_SPFGMR(Args&& ... args) : Base(std::forward(args) ...) + { + Base::LS = SUNLinSol_SPFGMR( + Base::yy, + Base::precon_type, + Base::setup_opts.linsol_max_iterations, + Base::sunctx + ); + Base::Initialize(); + } +}; + +/** + * @brief IDAKLUSolver SPGMR implementation with OpenMP class + */ +template +class IDAKLUSolverOpenMP_SPGMR : public IDAKLUSolverOpenMP { +public: + using Base = IDAKLUSolverOpenMP; + template + IDAKLUSolverOpenMP_SPGMR(Args&& ... args) : Base(std::forward(args) ...) + { + Base::LS = SUNLinSol_SPGMR( + Base::yy, + Base::precon_type, + Base::setup_opts.linsol_max_iterations, + Base::sunctx + ); + Base::Initialize(); + } +}; + +/** + * @brief IDAKLUSolver SPTFQMR implementation with OpenMP class + */ +template +class IDAKLUSolverOpenMP_SPTFQMR : public IDAKLUSolverOpenMP { +public: + using Base = IDAKLUSolverOpenMP; + template + IDAKLUSolverOpenMP_SPTFQMR(Args&& ... args) : Base(std::forward(args) ...) + { + Base::LS = SUNLinSol_SPTFQMR( + Base::yy, + Base::precon_type, + Base::setup_opts.linsol_max_iterations, + Base::sunctx + ); + Base::Initialize(); + } +}; + +#endif // PYBAMM_IDAKLU_CASADI_SOLVER_OPENMP_HPP diff --git a/pybamm/solvers/c_solvers/idaklu/idaklu_jax.cpp b/pybamm/solvers/c_solvers/idaklu/IdakluJax.cpp similarity index 99% rename from pybamm/solvers/c_solvers/idaklu/idaklu_jax.cpp rename to pybamm/solvers/c_solvers/idaklu/IdakluJax.cpp index b338560259..15c2b2d811 100644 --- a/pybamm/solvers/c_solvers/idaklu/idaklu_jax.cpp +++ b/pybamm/solvers/c_solvers/idaklu/IdakluJax.cpp @@ -1,4 +1,4 @@ -#include "idaklu_jax.hpp" +#include "IdakluJax.hpp" #include #include diff --git a/pybamm/solvers/c_solvers/idaklu/idaklu_jax.hpp b/pybamm/solvers/c_solvers/idaklu/IdakluJax.hpp similarity index 100% rename from pybamm/solvers/c_solvers/idaklu/idaklu_jax.hpp rename to pybamm/solvers/c_solvers/idaklu/IdakluJax.hpp diff --git a/pybamm/solvers/c_solvers/idaklu/options.cpp b/pybamm/solvers/c_solvers/idaklu/Options.cpp similarity index 60% rename from pybamm/solvers/c_solvers/idaklu/options.cpp rename to pybamm/solvers/c_solvers/idaklu/Options.cpp index efad4d5de0..6a7545b627 100644 --- a/pybamm/solvers/c_solvers/idaklu/options.cpp +++ b/pybamm/solvers/c_solvers/idaklu/Options.cpp @@ -1,19 +1,18 @@ -#include "options.hpp" +#include "Options.hpp" #include #include using namespace std::string_literals; -Options::Options(py::dict options) - : print_stats(options["print_stats"].cast()), - jacobian(options["jacobian"].cast()), - preconditioner(options["preconditioner"].cast()), - linsol_max_iterations(options["linsol_max_iterations"].cast()), - linear_solver(options["linear_solver"].cast()), - precon_half_bandwidth(options["precon_half_bandwidth"].cast()), - precon_half_bandwidth_keep(options["precon_half_bandwidth_keep"].cast()), - num_threads(options["num_threads"].cast()) +SetupOptions::SetupOptions(py::dict &py_opts) + : jacobian(py_opts["jacobian"].cast()), + preconditioner(py_opts["preconditioner"].cast()), + precon_half_bandwidth(py_opts["precon_half_bandwidth"].cast()), + precon_half_bandwidth_keep(py_opts["precon_half_bandwidth_keep"].cast()), + num_threads(py_opts["num_threads"].cast()), + linear_solver(py_opts["linear_solver"].cast()), + linsol_max_iterations(py_opts["linsol_max_iterations"].cast()) { using_sparse_matrix = true; @@ -119,3 +118,32 @@ Options::Options(py::dict options) preconditioner = "none"; } } + +SolverOptions::SolverOptions(py::dict &py_opts) + : print_stats(py_opts["print_stats"].cast()), + // IDA main solver + max_order_bdf(py_opts["max_order_bdf"].cast()), + max_num_steps(py_opts["max_num_steps"].cast()), + dt_init(RCONST(py_opts["dt_init"].cast())), + dt_max(RCONST(py_opts["dt_max"].cast())), + max_error_test_failures(py_opts["max_error_test_failures"].cast()), + max_nonlinear_iterations(py_opts["max_nonlinear_iterations"].cast()), + max_convergence_failures(py_opts["max_convergence_failures"].cast()), + nonlinear_convergence_coefficient(RCONST(py_opts["nonlinear_convergence_coefficient"].cast())), + nonlinear_convergence_coefficient_ic(RCONST(py_opts["nonlinear_convergence_coefficient_ic"].cast())), + suppress_algebraic_error(py_opts["suppress_algebraic_error"].cast()), + // IDA initial conditions calculation + calc_ic(py_opts["calc_ic"].cast()), + init_all_y_ic(py_opts["init_all_y_ic"].cast()), + max_num_steps_ic(py_opts["max_num_steps_ic"].cast()), + max_num_jacobians_ic(py_opts["max_num_jacobians_ic"].cast()), + max_num_iterations_ic(py_opts["max_num_iterations_ic"].cast()), + max_linesearch_backtracks_ic(py_opts["max_linesearch_backtracks_ic"].cast()), + linesearch_off_ic(py_opts["linesearch_off_ic"].cast()), + // IDALS linear solver interface + linear_solution_scaling(py_opts["linear_solution_scaling"].cast()), + epsilon_linear_tolerance(RCONST(py_opts["epsilon_linear_tolerance"].cast())), + increment_factor(RCONST(py_opts["increment_factor"].cast())) +{ + +} diff --git a/pybamm/solvers/c_solvers/idaklu/Options.hpp b/pybamm/solvers/c_solvers/idaklu/Options.hpp new file mode 100644 index 0000000000..66a175cfff --- /dev/null +++ b/pybamm/solvers/c_solvers/idaklu/Options.hpp @@ -0,0 +1,55 @@ +#ifndef PYBAMM_OPTIONS_HPP +#define PYBAMM_OPTIONS_HPP + +#include "common.hpp" + +/** + * @brief SetupOptions passed to the idaklu setup by pybamm + */ +struct SetupOptions { + bool using_sparse_matrix; + bool using_banded_matrix; + bool using_iterative_solver; + std::string jacobian; + std::string preconditioner; // spbcg + int precon_half_bandwidth; + int precon_half_bandwidth_keep; + int num_threads; + // IDALS linear solver interface + std::string linear_solver; // klu, lapack, spbcg + int linsol_max_iterations; + explicit SetupOptions(py::dict &py_opts); +}; + +/** + * @brief SolverOptions passed to the idaklu solver by pybamm + */ +struct SolverOptions { + bool print_stats; + // IDA main solver + int max_order_bdf; + int max_num_steps; + double dt_init; + double dt_max; + int max_error_test_failures; + int max_nonlinear_iterations; + int max_convergence_failures; + double nonlinear_convergence_coefficient; + double nonlinear_convergence_coefficient_ic; + sunbooleantype suppress_algebraic_error; + // IDA initial conditions calculation + bool calc_ic; + bool init_all_y_ic; + int max_num_steps_ic; + int max_num_jacobians_ic; + int max_num_iterations_ic; + int max_linesearch_backtracks_ic; + sunbooleantype linesearch_off_ic; + // IDALS linear solver interface + sunbooleantype linear_solution_scaling; + double epsilon_linear_tolerance; + double increment_factor; + explicit SolverOptions(py::dict &py_opts); +}; + +#endif // PYBAMM_OPTIONS_HPP diff --git a/pybamm/solvers/c_solvers/idaklu/Solution.cpp b/pybamm/solvers/c_solvers/idaklu/Solution.cpp new file mode 100644 index 0000000000..7b50364379 --- /dev/null +++ b/pybamm/solvers/c_solvers/idaklu/Solution.cpp @@ -0,0 +1 @@ +#include "Solution.hpp" diff --git a/pybamm/solvers/c_solvers/idaklu/solution.hpp b/pybamm/solvers/c_solvers/idaklu/Solution.hpp similarity index 100% rename from pybamm/solvers/c_solvers/idaklu/solution.hpp rename to pybamm/solvers/c_solvers/idaklu/Solution.hpp diff --git a/pybamm/solvers/c_solvers/idaklu/casadi_functions.cpp b/pybamm/solvers/c_solvers/idaklu/casadi_functions.cpp deleted file mode 100644 index ddad4612c9..0000000000 --- a/pybamm/solvers/c_solvers/idaklu/casadi_functions.cpp +++ /dev/null @@ -1,105 +0,0 @@ -#include "casadi_functions.hpp" - -CasadiFunction::CasadiFunction(const Function &f) : m_func(f) -{ - size_t sz_arg; - size_t sz_res; - size_t sz_iw; - size_t sz_w; - m_func.sz_work(sz_arg, sz_res, sz_iw, sz_w); - //int nnz = (sz_res>0) ? m_func.nnz_out() : 0; - //std::cout << "name = "<< m_func.name() << " arg = " << sz_arg << " res = " - // << sz_res << " iw = " << sz_iw << " w = " << sz_w << " nnz = " << nnz << - // std::endl; - m_arg.resize(sz_arg, nullptr); - m_res.resize(sz_res, nullptr); - m_iw.resize(sz_iw, 0); - m_w.resize(sz_w, 0); -} - -// only call this once m_arg and m_res have been set appropriately -void CasadiFunction::operator()() -{ - int mem = m_func.checkout(); - m_func(m_arg.data(), m_res.data(), m_iw.data(), m_w.data(), mem); - m_func.release(mem); -} - -casadi_int CasadiFunction::nnz_out() { - return m_func.nnz_out(); -} - -casadi::Sparsity CasadiFunction::sparsity_out(casadi_int ind) { - return m_func.sparsity_out(ind); -} - -void CasadiFunction::operator()(const std::vector& inputs, - const std::vector& results) -{ - // Set-up input arguments, provide result vector, then execute function - // Example call: fcn({in1, in2, in3}, {out1}) - for(size_t k=0; k& var_casadi_fcns, - const std::vector& dvar_dy_fcns, - const std::vector& dvar_dp_fcns, - const Options& options) - : number_of_states(n_s), number_of_events(n_e), number_of_parameters(n_p), - number_of_nnz(jac_times_cjmass_nnz), - jac_bandwidth_lower(jac_bandwidth_lower), jac_bandwidth_upper(jac_bandwidth_upper), - rhs_alg(rhs_alg), - jac_times_cjmass(jac_times_cjmass), jac_action(jac_action), - mass_action(mass_action), sens(sens), events(events), - tmp_state_vector(number_of_states), - tmp_sparse_jacobian_data(jac_times_cjmass_nnz), - options(options) -{ - // convert casadi::Function list to CasadiFunction list - for (auto& var : var_casadi_fcns) { - this->var_casadi_fcns.push_back(CasadiFunction(*var)); - } - for (auto& var : dvar_dy_fcns) { - this->dvar_dy_fcns.push_back(CasadiFunction(*var)); - } - for (auto& var : dvar_dp_fcns) { - this->dvar_dp_fcns.push_back(CasadiFunction(*var)); - } - - // copy across numpy array values - const int n_row_vals = jac_times_cjmass_rowvals_arg.request().size; - auto p_jac_times_cjmass_rowvals = jac_times_cjmass_rowvals_arg.unchecked<1>(); - jac_times_cjmass_rowvals.resize(n_row_vals); - for (int i = 0; i < n_row_vals; i++) { - jac_times_cjmass_rowvals[i] = p_jac_times_cjmass_rowvals[i]; - } - - const int n_col_ptrs = jac_times_cjmass_colptrs_arg.request().size; - auto p_jac_times_cjmass_colptrs = jac_times_cjmass_colptrs_arg.unchecked<1>(); - jac_times_cjmass_colptrs.resize(n_col_ptrs); - for (int i = 0; i < n_col_ptrs; i++) { - jac_times_cjmass_colptrs[i] = p_jac_times_cjmass_colptrs[i]; - } - - inputs.resize(inputs_length); -} - -realtype *CasadiFunctions::get_tmp_state_vector() { - return tmp_state_vector.data(); -} -realtype *CasadiFunctions::get_tmp_sparse_jacobian_data() { - return tmp_sparse_jacobian_data.data(); -} diff --git a/pybamm/solvers/c_solvers/idaklu/casadi_functions.hpp b/pybamm/solvers/c_solvers/idaklu/casadi_functions.hpp deleted file mode 100644 index 1a33b957f8..0000000000 --- a/pybamm/solvers/c_solvers/idaklu/casadi_functions.hpp +++ /dev/null @@ -1,160 +0,0 @@ -#ifndef PYBAMM_IDAKLU_CASADI_FUNCTIONS_HPP -#define PYBAMM_IDAKLU_CASADI_FUNCTIONS_HPP - -#include "common.hpp" -#include "options.hpp" -#include -#include -#include - -/** - * Utility function to convert compressed-sparse-column (CSC) to/from - * compressed-sparse-row (CSR) matrix representation. Conversion is symmetric / - * invertible using this function. - * @brief Utility function to convert to/from CSC/CSR matrix representations. - * @param f Data vector containing the sparse matrix elements - * @param c Index pointer to column starts - * @param r Array of row indices - * @param nf New data vector that will contain the transformed sparse matrix - * @param nc New array of column indices - * @param nr New index pointer to row starts - */ -template -void csc_csr(const realtype f[], const T1 c[], const T1 r[], realtype nf[], T2 nc[], T2 nr[], int N, int cols) { - std::vector nn(cols+1); - std::vector rr(N); - for (int i=0; i& inputs, - const std::vector& results); - - /** - * @brief Return the number of non-zero elements for the function output - */ - casadi_int nnz_out(); - - /** - * @brief Return the number of non-zero elements for the function output - */ - casadi::Sparsity sparsity_out(casadi_int ind); - -public: - std::vector m_arg; - std::vector m_res; - -private: - const Function &m_func; - std::vector m_iw; - std::vector m_w; -}; - -/** - * @brief Class for handling casadi functions - */ -class CasadiFunctions -{ -public: - /** - * @brief Create a new CasadiFunctions object - */ - CasadiFunctions( - const Function &rhs_alg, - const Function &jac_times_cjmass, - const int jac_times_cjmass_nnz, - const int jac_bandwidth_lower, - const int jac_bandwidth_upper, - const np_array_int &jac_times_cjmass_rowvals, - const np_array_int &jac_times_cjmass_colptrs, - const int inputs_length, - const Function &jac_action, - const Function &mass_action, - const Function &sens, - const Function &events, - const int n_s, - const int n_e, - const int n_p, - const std::vector& var_casadi_fcns, - const std::vector& dvar_dy_fcns, - const std::vector& dvar_dp_fcns, - const Options& options - ); - -public: - int number_of_states; - int number_of_parameters; - int number_of_events; - int number_of_nnz; - int jac_bandwidth_lower; - int jac_bandwidth_upper; - - CasadiFunction rhs_alg; - CasadiFunction sens; - CasadiFunction jac_times_cjmass; - CasadiFunction jac_action; - CasadiFunction mass_action; - CasadiFunction events; - - // NB: cppcheck-suppress unusedStructMember is used because codacy reports - // these members as unused even though they are important - std::vector var_casadi_fcns; // cppcheck-suppress unusedStructMember - std::vector dvar_dy_fcns; // cppcheck-suppress unusedStructMember - std::vector dvar_dp_fcns; // cppcheck-suppress unusedStructMember - - std::vector jac_times_cjmass_rowvals; - std::vector jac_times_cjmass_colptrs; - std::vector inputs; - - Options options; - - realtype *get_tmp_state_vector(); - realtype *get_tmp_sparse_jacobian_data(); - -private: - std::vector tmp_state_vector; - std::vector tmp_sparse_jacobian_data; -}; - -#endif // PYBAMM_IDAKLU_CASADI_FUNCTIONS_HPP diff --git a/pybamm/solvers/c_solvers/idaklu/casadi_solver.hpp b/pybamm/solvers/c_solvers/idaklu/casadi_solver.hpp deleted file mode 100644 index 335907a93a..0000000000 --- a/pybamm/solvers/c_solvers/idaklu/casadi_solver.hpp +++ /dev/null @@ -1,36 +0,0 @@ -#ifndef PYBAMM_IDAKLU_CREATE_CASADI_SOLVER_HPP -#define PYBAMM_IDAKLU_CREATE_CASADI_SOLVER_HPP - -#include "CasadiSolver.hpp" - -/** - * Creates a concrete casadi solver given a linear solver, as specified in - * options_cpp.linear_solver. - * @brief Create a concrete casadi solver given a linear solver - */ -CasadiSolver *create_casadi_solver( - int number_of_states, - int number_of_parameters, - const Function &rhs_alg, - const Function &jac_times_cjmass, - const np_array_int &jac_times_cjmass_colptrs, - const np_array_int &jac_times_cjmass_rowvals, - const int jac_times_cjmass_nnz, - const int jac_bandwidth_lower, - const int jac_bandwidth_upper, - const Function &jac_action, - const Function &mass_action, - const Function &sens, - const Function &event, - const int number_of_events, - np_array rhs_alg_id, - np_array atol_np, - double rel_tol, - int inputs_length, - const std::vector& var_casadi_fcns, - const std::vector& dvar_dy_fcns, - const std::vector& dvar_dp_fcns, - py::dict options -); - -#endif // PYBAMM_IDAKLU_CREATE_CASADI_SOLVER_HPP diff --git a/pybamm/solvers/c_solvers/idaklu/casadi_sundials_functions.hpp b/pybamm/solvers/c_solvers/idaklu/casadi_sundials_functions.hpp deleted file mode 100644 index a2192030b4..0000000000 --- a/pybamm/solvers/c_solvers/idaklu/casadi_sundials_functions.hpp +++ /dev/null @@ -1,27 +0,0 @@ -#ifndef PYBAMM_IDAKLU_CASADI_SUNDIALS_FUNCTIONS_HPP -#define PYBAMM_IDAKLU_CASADI_SUNDIALS_FUNCTIONS_HPP - -#include "common.hpp" - -int residual_casadi(realtype tres, N_Vector yy, N_Vector yp, N_Vector rr, - void *user_data); - -int jtimes_casadi(realtype tt, N_Vector yy, N_Vector yp, N_Vector rr, - N_Vector v, N_Vector Jv, realtype cj, void *user_data, - N_Vector tmp1, N_Vector tmp2); - -int events_casadi(realtype t, N_Vector yy, N_Vector yp, realtype *events_ptr, - void *user_data); - -int sensitivities_casadi(int Ns, realtype t, N_Vector yy, N_Vector yp, - N_Vector resval, N_Vector *yS, N_Vector *ypS, - N_Vector *resvalS, void *user_data, N_Vector tmp1, - N_Vector tmp2, N_Vector tmp3); - -int jacobian_casadi(realtype tt, realtype cj, N_Vector yy, N_Vector yp, - N_Vector resvec, SUNMatrix JJ, void *user_data, - N_Vector tempv1, N_Vector tempv2, N_Vector tempv3); - -int residual_casadi_approx(sunindextype Nlocal, realtype tt, N_Vector yy, - N_Vector yp, N_Vector gval, void *user_data); -#endif // PYBAMM_IDAKLU_CASADI_SUNDIALS_FUNCTIONS_HPP diff --git a/pybamm/solvers/c_solvers/idaklu/common.hpp b/pybamm/solvers/c_solvers/idaklu/common.hpp index e0abbb5a1d..0ef7ee60a0 100644 --- a/pybamm/solvers/c_solvers/idaklu/common.hpp +++ b/pybamm/solvers/c_solvers/idaklu/common.hpp @@ -1,6 +1,8 @@ #ifndef PYBAMM_IDAKLU_COMMON_HPP #define PYBAMM_IDAKLU_COMMON_HPP +#include + #include /* prototypes for IDAS fcts., consts. */ #include /* access to IDABBDPRE preconditioner */ @@ -33,16 +35,58 @@ using np_array = py::array_t; using np_array_dense = py::array_t; using np_array_int = py::array_t; +/** + * Utility function to convert compressed-sparse-column (CSC) to/from + * compressed-sparse-row (CSR) matrix representation. Conversion is symmetric / + * invertible using this function. + * @brief Utility function to convert to/from CSC/CSR matrix representations. + * @param f Data vector containing the sparse matrix elements + * @param c Index pointer to column starts + * @param r Array of row indices + * @param nf New data vector that will contain the transformed sparse matrix + * @param nc New array of column indices + * @param nr New index pointer to row starts + */ +template +void csc_csr(const realtype f[], const T1 c[], const T1 r[], realtype nf[], T2 nc[], T2 nr[], int N, int cols) { + std::vector nn(cols+1); + std::vector rr(N); + for (int i=0; i; } \ std::cout << "]" << std::endl; } -#define DEBUG_v(v, N) {\ +#define DEBUG_v(v, M) {\ + int N = 2; \ std::cout << #v << "[n=" << N << "] = ["; \ for (int i = 0; i < N; i++) { \ std::cout << v[i]; \ @@ -82,6 +127,13 @@ using np_array_int = py::array_t; std::cerr << __FILE__ << ":" << __LINE__ << "," << #x << " = " << x << std::endl; \ } +#define ASSERT(x) { \ + if (!(x)) { \ + std::cerr << __FILE__ << ":" << __LINE__ << " Assertion failed: " << #x << std::endl; \ + throw std::runtime_error("Assertion failed: " #x); \ + } \ + } + #endif #endif // PYBAMM_IDAKLU_COMMON_HPP diff --git a/pybamm/solvers/c_solvers/idaklu/casadi_solver.cpp b/pybamm/solvers/c_solvers/idaklu/idaklu_solver.hpp similarity index 55% rename from pybamm/solvers/c_solvers/idaklu/casadi_solver.cpp rename to pybamm/solvers/c_solvers/idaklu/idaklu_solver.hpp index 9fcfa06510..ce1765aa82 100644 --- a/pybamm/solvers/c_solvers/idaklu/casadi_solver.cpp +++ b/pybamm/solvers/c_solvers/idaklu/idaklu_solver.hpp @@ -1,37 +1,43 @@ -#include "casadi_solver.hpp" -#include "CasadiSolver.hpp" -#include "CasadiSolverOpenMP_solvers.hpp" -#include "casadi_sundials_functions.hpp" -#include "common.hpp" +#ifndef PYBAMM_CREATE_IDAKLU_SOLVER_HPP +#define PYBAMM_CREATE_IDAKLU_SOLVER_HPP + +#include "IDAKLUSolverOpenMP_solvers.hpp" #include #include -CasadiSolver *create_casadi_solver( +/** + * Creates a concrete solver given a linear solver, as specified in + * options_cpp.linear_solver. + * @brief Create a concrete solver given a linear solver + */ +template +IDAKLUSolver *create_idaklu_solver( int number_of_states, int number_of_parameters, - const Function &rhs_alg, - const Function &jac_times_cjmass, + const typename ExprSet::BaseFunctionType &rhs_alg, + const typename ExprSet::BaseFunctionType &jac_times_cjmass, const np_array_int &jac_times_cjmass_colptrs, const np_array_int &jac_times_cjmass_rowvals, const int jac_times_cjmass_nnz, const int jac_bandwidth_lower, const int jac_bandwidth_upper, - const Function &jac_action, - const Function &mass_action, - const Function &sens, - const Function &events, + const typename ExprSet::BaseFunctionType &jac_action, + const typename ExprSet::BaseFunctionType &mass_action, + const typename ExprSet::BaseFunctionType &sens, + const typename ExprSet::BaseFunctionType &events, const int number_of_events, np_array rhs_alg_id, np_array atol_np, double rel_tol, int inputs_length, - const std::vector& var_casadi_fcns, - const std::vector& dvar_dy_fcns, - const std::vector& dvar_dp_fcns, - py::dict options + const std::vector& var_fcns, + const std::vector& dvar_dy_fcns, + const std::vector& dvar_dp_fcns, + py::dict py_opts ) { - auto options_cpp = Options(options); - auto functions = std::make_unique( + auto setup_opts = SetupOptions(py_opts); + auto solver_opts = SolverOptions(py_opts); + auto functions = std::make_unique( rhs_alg, jac_times_cjmass, jac_times_cjmass_nnz, @@ -47,19 +53,19 @@ CasadiSolver *create_casadi_solver( number_of_states, number_of_events, number_of_parameters, - var_casadi_fcns, + var_fcns, dvar_dy_fcns, dvar_dp_fcns, - options_cpp + setup_opts ); - CasadiSolver *casadiSolver = nullptr; + IDAKLUSolver *idakluSolver = nullptr; // Instantiate solver class - if (options_cpp.linear_solver == "SUNLinSol_Dense") + if (setup_opts.linear_solver == "SUNLinSol_Dense") { DEBUG("\tsetting SUNLinSol_Dense linear solver"); - casadiSolver = new CasadiSolverOpenMP_Dense( + idakluSolver = new IDAKLUSolverOpenMP_Dense( atol_np, rel_tol, rhs_alg_id, @@ -69,13 +75,14 @@ CasadiSolver *create_casadi_solver( jac_bandwidth_lower, jac_bandwidth_upper, std::move(functions), - options_cpp + setup_opts, + solver_opts ); } - else if (options_cpp.linear_solver == "SUNLinSol_KLU") + else if (setup_opts.linear_solver == "SUNLinSol_KLU") { DEBUG("\tsetting SUNLinSol_KLU linear solver"); - casadiSolver = new CasadiSolverOpenMP_KLU( + idakluSolver = new IDAKLUSolverOpenMP_KLU( atol_np, rel_tol, rhs_alg_id, @@ -85,13 +92,14 @@ CasadiSolver *create_casadi_solver( jac_bandwidth_lower, jac_bandwidth_upper, std::move(functions), - options_cpp + setup_opts, + solver_opts ); } - else if (options_cpp.linear_solver == "SUNLinSol_Band") + else if (setup_opts.linear_solver == "SUNLinSol_Band") { DEBUG("\tsetting SUNLinSol_Band linear solver"); - casadiSolver = new CasadiSolverOpenMP_Band( + idakluSolver = new IDAKLUSolverOpenMP_Band( atol_np, rel_tol, rhs_alg_id, @@ -101,13 +109,14 @@ CasadiSolver *create_casadi_solver( jac_bandwidth_lower, jac_bandwidth_upper, std::move(functions), - options_cpp + setup_opts, + solver_opts ); } - else if (options_cpp.linear_solver == "SUNLinSol_SPBCGS") + else if (setup_opts.linear_solver == "SUNLinSol_SPBCGS") { DEBUG("\tsetting SUNLinSol_SPBCGS_linear solver"); - casadiSolver = new CasadiSolverOpenMP_SPBCGS( + idakluSolver = new IDAKLUSolverOpenMP_SPBCGS( atol_np, rel_tol, rhs_alg_id, @@ -117,13 +126,14 @@ CasadiSolver *create_casadi_solver( jac_bandwidth_lower, jac_bandwidth_upper, std::move(functions), - options_cpp + setup_opts, + solver_opts ); } - else if (options_cpp.linear_solver == "SUNLinSol_SPFGMR") + else if (setup_opts.linear_solver == "SUNLinSol_SPFGMR") { DEBUG("\tsetting SUNLinSol_SPFGMR_linear solver"); - casadiSolver = new CasadiSolverOpenMP_SPFGMR( + idakluSolver = new IDAKLUSolverOpenMP_SPFGMR( atol_np, rel_tol, rhs_alg_id, @@ -133,13 +143,14 @@ CasadiSolver *create_casadi_solver( jac_bandwidth_lower, jac_bandwidth_upper, std::move(functions), - options_cpp + setup_opts, + solver_opts ); } - else if (options_cpp.linear_solver == "SUNLinSol_SPGMR") + else if (setup_opts.linear_solver == "SUNLinSol_SPGMR") { DEBUG("\tsetting SUNLinSol_SPGMR solver"); - casadiSolver = new CasadiSolverOpenMP_SPGMR( + idakluSolver = new IDAKLUSolverOpenMP_SPGMR( atol_np, rel_tol, rhs_alg_id, @@ -149,13 +160,14 @@ CasadiSolver *create_casadi_solver( jac_bandwidth_lower, jac_bandwidth_upper, std::move(functions), - options_cpp + setup_opts, + solver_opts ); } - else if (options_cpp.linear_solver == "SUNLinSol_SPTFQMR") + else if (setup_opts.linear_solver == "SUNLinSol_SPTFQMR") { DEBUG("\tsetting SUNLinSol_SPGMR solver"); - casadiSolver = new CasadiSolverOpenMP_SPTFQMR( + idakluSolver = new IDAKLUSolverOpenMP_SPTFQMR( atol_np, rel_tol, rhs_alg_id, @@ -165,13 +177,16 @@ CasadiSolver *create_casadi_solver( jac_bandwidth_lower, jac_bandwidth_upper, std::move(functions), - options_cpp + setup_opts, + solver_opts ); } - if (casadiSolver == nullptr) { + if (idakluSolver == nullptr) { throw std::invalid_argument("Unsupported solver requested"); } - return casadiSolver; + return idakluSolver; } + +#endif // PYBAMM_CREATE_IDAKLU_SOLVER_HPP diff --git a/pybamm/solvers/c_solvers/idaklu/options.hpp b/pybamm/solvers/c_solvers/idaklu/options.hpp deleted file mode 100644 index b70d0f4a30..0000000000 --- a/pybamm/solvers/c_solvers/idaklu/options.hpp +++ /dev/null @@ -1,25 +0,0 @@ -#ifndef PYBAMM_OPTIONS_HPP -#define PYBAMM_OPTIONS_HPP - -#include "common.hpp" - -/** - * @brief Options passed to the idaklu solver by pybamm - */ -struct Options { - bool print_stats; - bool using_sparse_matrix; - bool using_banded_matrix; - bool using_iterative_solver; - std::string jacobian; - std::string linear_solver; // klu, lapack, spbcg - std::string preconditioner; // spbcg - int linsol_max_iterations; - int precon_half_bandwidth; - int precon_half_bandwidth_keep; - int num_threads; - explicit Options(py::dict options); - -}; - -#endif // PYBAMM_OPTIONS_HPP diff --git a/pybamm/solvers/c_solvers/idaklu/python.hpp b/pybamm/solvers/c_solvers/idaklu/python.hpp index 0478d0946f..6231d13eb6 100644 --- a/pybamm/solvers/c_solvers/idaklu/python.hpp +++ b/pybamm/solvers/c_solvers/idaklu/python.hpp @@ -2,7 +2,7 @@ #define PYBAMM_IDAKLU_HPP #include "common.hpp" -#include "solution.hpp" +#include "Solution.hpp" #include using residual_type = std::function< diff --git a/pybamm/solvers/c_solvers/idaklu/solution.cpp b/pybamm/solvers/c_solvers/idaklu/solution.cpp deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/pybamm/solvers/c_solvers/idaklu/sundials_functions.hpp b/pybamm/solvers/c_solvers/idaklu/sundials_functions.hpp new file mode 100644 index 0000000000..c4024bc20a --- /dev/null +++ b/pybamm/solvers/c_solvers/idaklu/sundials_functions.hpp @@ -0,0 +1,36 @@ +#ifndef PYBAMM_SUNDIALS_FUNCTIONS_HPP +#define PYBAMM_SUNDIALS_FUNCTIONS_HPP + +#include "common.hpp" + +template +void axpy(int n, T alpha, const T* x, T* y) { + if (!x || !y) return; + for (int i=0; i #define NV_DATA NV_DATA_OMP // Serial: NV_DATA_S -int residual_casadi(realtype tres, N_Vector yy, N_Vector yp, N_Vector rr, - void *user_data) +template +int residual_eval(realtype tres, N_Vector yy, N_Vector yp, N_Vector rr, void *user_data) { - DEBUG("residual_casadi"); - CasadiFunctions *p_python_functions = - static_cast(user_data); + DEBUG("residual_eval"); + ExpressionSet *p_python_functions = + static_cast *>(user_data); - p_python_functions->rhs_alg.m_arg[0] = &tres; - p_python_functions->rhs_alg.m_arg[1] = NV_DATA(yy); - p_python_functions->rhs_alg.m_arg[2] = p_python_functions->inputs.data(); - p_python_functions->rhs_alg.m_res[0] = NV_DATA(rr); - p_python_functions->rhs_alg(); + DEBUG_VECTORn(yy, 100); + DEBUG_VECTORn(yp, 100); + + p_python_functions->rhs_alg->m_arg[0] = &tres; + p_python_functions->rhs_alg->m_arg[1] = NV_DATA(yy); + p_python_functions->rhs_alg->m_arg[2] = p_python_functions->inputs.data(); + p_python_functions->rhs_alg->m_res[0] = NV_DATA(rr); + (*p_python_functions->rhs_alg)(); + + DEBUG_VECTORn(rr, 100); realtype *tmp = p_python_functions->get_tmp_state_vector(); - p_python_functions->mass_action.m_arg[0] = NV_DATA(yp); - p_python_functions->mass_action.m_res[0] = tmp; - p_python_functions->mass_action(); + p_python_functions->mass_action->m_arg[0] = NV_DATA(yp); + p_python_functions->mass_action->m_res[0] = tmp; + (*p_python_functions->mass_action)(); // AXPY: y <- a*x + y const int ns = p_python_functions->number_of_states; - casadi::casadi_axpy(ns, -1., tmp, NV_DATA(rr)); + axpy(ns, -1., tmp, NV_DATA(rr)); - //DEBUG_VECTOR(yy); - //DEBUG_VECTOR(yp); - //DEBUG_VECTOR(rr); + DEBUG("mass - rhs"); + DEBUG_VECTORn(rr, 100); // now rr has rhs_alg(t, y) - mass_matrix * yp return 0; @@ -64,13 +68,14 @@ int residual_casadi(realtype tres, N_Vector yy, N_Vector yp, N_Vector rr, // within user_data. // // The case where G is mathematically identical to F is allowed. -int residual_casadi_approx(sunindextype Nlocal, realtype tt, N_Vector yy, +template +int residual_eval_approx(sunindextype Nlocal, realtype tt, N_Vector yy, N_Vector yp, N_Vector gval, void *user_data) { - DEBUG("residual_casadi_approx"); + DEBUG("residual_eval_approx"); // Just use true residual for now - int result = residual_casadi(tt, yy, yp, gval, user_data); + int result = residual_eval(tt, yy, yp, gval, user_data); return result; } @@ -94,32 +99,35 @@ int residual_casadi_approx(sunindextype Nlocal, realtype tt, N_Vector yy, // tmp2 are pointers to memory allocated for variables of type N Vector // which can // be used by IDALsJacTimesVecFn as temporary storage or work space. -int jtimes_casadi(realtype tt, N_Vector yy, N_Vector yp, N_Vector rr, +template +int jtimes_eval(realtype tt, N_Vector yy, N_Vector yp, N_Vector rr, N_Vector v, N_Vector Jv, realtype cj, void *user_data, N_Vector tmp1, N_Vector tmp2) { - DEBUG("jtimes_casadi"); - CasadiFunctions *p_python_functions = - static_cast(user_data); + DEBUG("jtimes_eval"); + T *p_python_functions = + static_cast(user_data); // Jv has ∂F/∂y v - p_python_functions->jac_action.m_arg[0] = &tt; - p_python_functions->jac_action.m_arg[1] = NV_DATA(yy); - p_python_functions->jac_action.m_arg[2] = p_python_functions->inputs.data(); - p_python_functions->jac_action.m_arg[3] = NV_DATA(v); - p_python_functions->jac_action.m_res[0] = NV_DATA(Jv); - p_python_functions->jac_action(); + p_python_functions->jac_action->m_arg[0] = &tt; + p_python_functions->jac_action->m_arg[1] = NV_DATA(yy); + p_python_functions->jac_action->m_arg[2] = p_python_functions->inputs.data(); + p_python_functions->jac_action->m_arg[3] = NV_DATA(v); + p_python_functions->jac_action->m_res[0] = NV_DATA(Jv); + (*p_python_functions->jac_action)(); // tmp has -∂F/∂y˙ v realtype *tmp = p_python_functions->get_tmp_state_vector(); - p_python_functions->mass_action.m_arg[0] = NV_DATA(v); - p_python_functions->mass_action.m_res[0] = tmp; - p_python_functions->mass_action(); + p_python_functions->mass_action->m_arg[0] = NV_DATA(v); + p_python_functions->mass_action->m_res[0] = tmp; + (*p_python_functions->mass_action)(); // AXPY: y <- a*x + y // Jv has ∂F/∂y v + cj ∂F/∂y˙ v const int ns = p_python_functions->number_of_states; - casadi::casadi_axpy(ns, -cj, tmp, NV_DATA(Jv)); + axpy(ns, -cj, tmp, NV_DATA(Jv)); + + DEBUG_VECTORn(Jv, 10); return 0; } @@ -141,22 +149,23 @@ int jtimes_casadi(realtype tt, N_Vector yy, N_Vector yp, N_Vector rr, // tmp3 are pointers to memory allocated for variables of type N Vector which // can // be used by IDALsJacFn function as temporary storage or work space. -int jacobian_casadi(realtype tt, realtype cj, N_Vector yy, N_Vector yp, +template +int jacobian_eval(realtype tt, realtype cj, N_Vector yy, N_Vector yp, N_Vector resvec, SUNMatrix JJ, void *user_data, N_Vector tempv1, N_Vector tempv2, N_Vector tempv3) { - DEBUG("jacobian_casadi"); + DEBUG("jacobian_eval"); - CasadiFunctions *p_python_functions = - static_cast(user_data); + T *p_python_functions = + static_cast(user_data); // create pointer to jac data, column pointers, and row values realtype *jac_data; - if (p_python_functions->options.using_sparse_matrix) + if (p_python_functions->setup_opts.using_sparse_matrix) { jac_data = SUNSparseMatrix_Data(JJ); } - else if (p_python_functions->options.using_banded_matrix) { + else if (p_python_functions->setup_opts.using_banded_matrix) { jac_data = p_python_functions->get_tmp_sparse_jacobian_data(); } else @@ -164,18 +173,25 @@ int jacobian_casadi(realtype tt, realtype cj, N_Vector yy, N_Vector yp, jac_data = SUNDenseMatrix_Data(JJ); } + DEBUG_VECTORn(yy, 100); + // args are t, y, cj, put result in jacobian data matrix - p_python_functions->jac_times_cjmass.m_arg[0] = &tt; - p_python_functions->jac_times_cjmass.m_arg[1] = NV_DATA(yy); - p_python_functions->jac_times_cjmass.m_arg[2] = + p_python_functions->jac_times_cjmass->m_arg[0] = &tt; + p_python_functions->jac_times_cjmass->m_arg[1] = NV_DATA(yy); + p_python_functions->jac_times_cjmass->m_arg[2] = p_python_functions->inputs.data(); - p_python_functions->jac_times_cjmass.m_arg[3] = &cj; - p_python_functions->jac_times_cjmass.m_res[0] = jac_data; - - p_python_functions->jac_times_cjmass(); - - - if (p_python_functions->options.using_banded_matrix) + p_python_functions->jac_times_cjmass->m_arg[3] = &cj; + p_python_functions->jac_times_cjmass->m_res[0] = jac_data; + (*p_python_functions->jac_times_cjmass)(); + + DEBUG("jac_times_cjmass [" << sizeof(jac_data) << "]"); + DEBUG("t = " << tt); + DEBUG_VECTORn(yy, 100); + DEBUG("inputs = " << p_python_functions->inputs); + DEBUG("cj = " << cj); + DEBUG_v(jac_data, 100); + + if (p_python_functions->setup_opts.using_banded_matrix) { // copy data from temporary matrix to the banded matrix auto jac_colptrs = p_python_functions->jac_times_cjmass_colptrs.data(); @@ -191,7 +207,7 @@ int jacobian_casadi(realtype tt, realtype cj, N_Vector yy, N_Vector yp, } } } - else if (p_python_functions->options.using_sparse_matrix) + else if (p_python_functions->setup_opts.using_sparse_matrix) { if (SUNSparseMatrix_SparseType(JJ) == CSC_MAT) { @@ -219,20 +235,12 @@ int jacobian_casadi(realtype tt, realtype cj, N_Vector yy, N_Vector yp, jac_colptrs[i] = p_jac_times_cjmass_colptrs[i]; } } else if (SUNSparseMatrix_SparseType(JJ) == CSR_MAT) { - std::vector newjac(SUNSparseMatrix_NNZ(JJ)); + // make a copy so that we can overwrite jac_data as CSR + std::vector newjac(&jac_data[0], &jac_data[SUNSparseMatrix_NNZ(JJ)]); sunindextype *jac_ptrs = SUNSparseMatrix_IndexPointers(JJ); sunindextype *jac_vals = SUNSparseMatrix_IndexValues(JJ); - // args are t, y, cj, put result in jacobian data matrix - p_python_functions->jac_times_cjmass.m_arg[0] = &tt; - p_python_functions->jac_times_cjmass.m_arg[1] = NV_DATA(yy); - p_python_functions->jac_times_cjmass.m_arg[2] = - p_python_functions->inputs.data(); - p_python_functions->jac_times_cjmass.m_arg[3] = &cj; - p_python_functions->jac_times_cjmass.m_res[0] = newjac.data(); - p_python_functions->jac_times_cjmass(); - - // convert (casadi's) CSC format to CSR + // convert CSC format to CSR csc_csr< std::remove_pointer_tjac_times_cjmass_rowvals.data())>, std::remove_pointer_t @@ -253,18 +261,20 @@ int jacobian_casadi(realtype tt, realtype cj, N_Vector yy, N_Vector yp, return (0); } -int events_casadi(realtype t, N_Vector yy, N_Vector yp, realtype *events_ptr, +template +int events_eval(realtype t, N_Vector yy, N_Vector yp, realtype *events_ptr, void *user_data) { - CasadiFunctions *p_python_functions = - static_cast(user_data); + DEBUG("events_eval"); + T *p_python_functions = + static_cast(user_data); // args are t, y, put result in events_ptr - p_python_functions->events.m_arg[0] = &t; - p_python_functions->events.m_arg[1] = NV_DATA(yy); - p_python_functions->events.m_arg[2] = p_python_functions->inputs.data(); - p_python_functions->events.m_res[0] = events_ptr; - p_python_functions->events(); + p_python_functions->events->m_arg[0] = &t; + p_python_functions->events->m_arg[1] = NV_DATA(yy); + p_python_functions->events->m_arg[2] = p_python_functions->inputs.data(); + p_python_functions->events->m_res[0] = events_ptr; + (*p_python_functions->events)(); return (0); } @@ -290,52 +300,52 @@ int events_casadi(realtype t, N_Vector yy, N_Vector yp, realtype *events_ptr, // occurred (in which case idas will attempt to correct), // or a negative value if it failed unrecoverably (in which case the integration // is halted and IDA SRES FAIL is returned) -// -int sensitivities_casadi(int Ns, realtype t, N_Vector yy, N_Vector yp, +template +int sensitivities_eval(int Ns, realtype t, N_Vector yy, N_Vector yp, N_Vector resval, N_Vector *yS, N_Vector *ypS, N_Vector *resvalS, void *user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3) { - DEBUG("sensitivities_casadi"); - CasadiFunctions *p_python_functions = - static_cast(user_data); + DEBUG("sensitivities_eval"); + T *p_python_functions = + static_cast(user_data); const int np = p_python_functions->number_of_parameters; // args are t, y put result in rr - p_python_functions->sens.m_arg[0] = &t; - p_python_functions->sens.m_arg[1] = NV_DATA(yy); - p_python_functions->sens.m_arg[2] = p_python_functions->inputs.data(); + p_python_functions->sens->m_arg[0] = &t; + p_python_functions->sens->m_arg[1] = NV_DATA(yy); + p_python_functions->sens->m_arg[2] = p_python_functions->inputs.data(); for (int i = 0; i < np; i++) { - p_python_functions->sens.m_res[i] = NV_DATA(resvalS[i]); + p_python_functions->sens->m_res[i] = NV_DATA(resvalS[i]); } // resvalsS now has (∂F/∂p i ) - p_python_functions->sens(); + (*p_python_functions->sens)(); for (int i = 0; i < np; i++) { // put (∂F/∂y)s i (t) in tmp realtype *tmp = p_python_functions->get_tmp_state_vector(); - p_python_functions->jac_action.m_arg[0] = &t; - p_python_functions->jac_action.m_arg[1] = NV_DATA(yy); - p_python_functions->jac_action.m_arg[2] = p_python_functions->inputs.data(); - p_python_functions->jac_action.m_arg[3] = NV_DATA(yS[i]); - p_python_functions->jac_action.m_res[0] = tmp; - p_python_functions->jac_action(); + p_python_functions->jac_action->m_arg[0] = &t; + p_python_functions->jac_action->m_arg[1] = NV_DATA(yy); + p_python_functions->jac_action->m_arg[2] = p_python_functions->inputs.data(); + p_python_functions->jac_action->m_arg[3] = NV_DATA(yS[i]); + p_python_functions->jac_action->m_res[0] = tmp; + (*p_python_functions->jac_action)(); const int ns = p_python_functions->number_of_states; - casadi::casadi_axpy(ns, 1., tmp, NV_DATA(resvalS[i])); + axpy(ns, 1., tmp, NV_DATA(resvalS[i])); // put -(∂F/∂ ẏ) ṡ i (t) in tmp2 - p_python_functions->mass_action.m_arg[0] = NV_DATA(ypS[i]); - p_python_functions->mass_action.m_res[0] = tmp; - p_python_functions->mass_action(); + p_python_functions->mass_action->m_arg[0] = NV_DATA(ypS[i]); + p_python_functions->mass_action->m_res[0] = tmp; + (*p_python_functions->mass_action)(); // (∂F/∂y)s i (t)+(∂F/∂ ẏ) ṡ i (t)+(∂F/∂p i ) // AXPY: y <- a*x + y - casadi::casadi_axpy(ns, -1., tmp, NV_DATA(resvalS[i])); + axpy(ns, -1., tmp, NV_DATA(resvalS[i])); } return 0; diff --git a/pybamm/solvers/idaklu_solver.py b/pybamm/solvers/idaklu_solver.py index fef4cbce3c..f6e4d51644 100644 --- a/pybamm/solvers/idaklu_solver.py +++ b/pybamm/solvers/idaklu_solver.py @@ -2,13 +2,25 @@ # Solver class using sundials with the KLU sparse linear solver # # mypy: ignore-errors +import os import casadi import pybamm import numpy as np import numbers import scipy.sparse as sparse +from scipy.linalg import bandwidth import importlib +import warnings + +if pybamm.have_jax(): + import jax + from jax import numpy as jnp + + try: + import iree.compiler + except ImportError: # pragma: no cover + pass idaklu_spec = importlib.util.find_spec("pybamm.solvers.idaklu") if idaklu_spec is not None: @@ -24,6 +36,15 @@ def have_idaklu(): return idaklu_spec is not None +def have_iree(): + try: + import iree.compiler # noqa: F401 + + return True + except ImportError: # pragma: no cover + return False + + class IDAKLUSolver(pybamm.BaseSolver): """ Solve a discretised model, using sundials with the KLU sparse linear solver. @@ -75,6 +96,8 @@ class IDAKLUSolver(pybamm.BaseSolver): "precon_half_bandwidth_keep": 5, # Number of threads available for OpenMP "num_threads": 1, + # Evaluation engine to use for jax, can be 'jax'(native) or 'iree' + "jax_evaluator": "jax", } Note: These options only have an effect if model.convert_to_format == 'casadi' @@ -97,12 +120,36 @@ def __init__( default_options = { "print_stats": False, "jacobian": "sparse", - "linear_solver": "SUNLinSol_KLU", "preconditioner": "BBDP", - "linsol_max_iterations": 5, "precon_half_bandwidth": 5, "precon_half_bandwidth_keep": 5, "num_threads": 1, + "jax_evaluator": "jax", + # IDA main solver + "max_order_bdf": 5, + "max_num_steps": 500, + "dt_init": 0.0, # The solver default is used if this is left at zero + "dt_max": 0.0, # The solver default is used if this is left at zero + "max_error_test_failures": 10, + "max_nonlinear_iterations": 4, + "max_convergence_failures": 10, + "nonlinear_convergence_coefficient": 0.33, + "suppress_algebraic_error": False, + # IDA initial conditions calculation + "nonlinear_convergence_coefficient_ic": 0.0033, + "max_num_steps_ic": 5, + "max_num_jacobians_ic": 4, + "max_num_iterations_ic": 10, + "max_linesearch_backtracks_ic": 100, + "linesearch_off_ic": False, + "init_all_y_ic": False, + "calc_ic": True, + # IDALS linear solver interface + "linear_solver": "SUNLinSol_KLU", + "linsol_max_iterations": 5, + "epsilon_linear_tolerance": 0.05, + "increment_factor": 1.0, + "linear_solution_scaling": True, } if options is None: options = default_options @@ -110,6 +157,10 @@ def __init__( for key, value in default_options.items(): if key not in options: options[key] = value + if options["jax_evaluator"] not in ["jax", "iree"]: + raise pybamm.SolverError( + "Evaluation engine must be 'jax' or 'iree' for IDAKLU solver" + ) self._options = options self.output_variables = [] if output_variables is None else output_variables @@ -183,10 +234,14 @@ def inputs_to_dict(inputs): # only casadi solver needs sensitivity ics if model.convert_to_format != "casadi": y0S = None - if self.output_variables: + if self.output_variables and not ( + model.convert_to_format == "jax" + and self._options["jax_evaluator"] == "iree" + ): raise pybamm.SolverError( "output_variables can only be specified " - 'with convert_to_format="casadi"' + 'with convert_to_format="casadi", or convert_to_format="jax" ' + 'with jax_evaluator="iree"' ) # pragma: no cover if y0S is not None: if isinstance(y0S, casadi.DM): @@ -293,7 +348,7 @@ def resfn(t, y, inputs, ydot): ) ) - else: + elif self._options["jax_evaluator"] == "jax": t0 = 0 if t_eval is None else t_eval[0] jac_y0_t0 = model.jac_rhs_algebraic_eval(t0, y0, inputs_dict) if sparse.issparse(jac_y0_t0): @@ -355,7 +410,7 @@ def get_jac_col_ptrs(self): ) ], ) - else: + elif self._options["jax_evaluator"] == "jax": def rootfn(t, y, inputs): new_inputs = inputs_to_dict(inputs) @@ -437,40 +492,220 @@ def sensfn(resvalS, t, y, inputs, yp, yS, ypS): rtol = self.rtol atol = self._check_atol_type(atol, y0.size) - if model.convert_to_format == "casadi": - rhs_algebraic = idaklu.generate_function(rhs_algebraic.serialize()) - jac_times_cjmass = idaklu.generate_function(jac_times_cjmass.serialize()) - jac_rhs_algebraic_action = idaklu.generate_function( - jac_rhs_algebraic_action.serialize() - ) - rootfn = idaklu.generate_function(rootfn.serialize()) - mass_action = idaklu.generate_function(mass_action.serialize()) - sensfn = idaklu.generate_function(sensfn.serialize()) + if model.convert_to_format == "casadi" or ( + model.convert_to_format == "jax" + and self._options["jax_evaluator"] == "iree" + ): + if model.convert_to_format == "casadi": + # Serialize casadi functions + idaklu_solver_fcn = idaklu.create_casadi_solver + rhs_algebraic = idaklu.generate_function(rhs_algebraic.serialize()) + jac_times_cjmass = idaklu.generate_function( + jac_times_cjmass.serialize() + ) + jac_rhs_algebraic_action = idaklu.generate_function( + jac_rhs_algebraic_action.serialize() + ) + rootfn = idaklu.generate_function(rootfn.serialize()) + mass_action = idaklu.generate_function(mass_action.serialize()) + sensfn = idaklu.generate_function(sensfn.serialize()) + elif ( + model.convert_to_format == "jax" + and self._options["jax_evaluator"] == "iree" + ): + # Convert Jax functions to MLIR (also, demote to single precision) + idaklu_solver_fcn = idaklu.create_iree_solver + pybamm.demote_expressions_to_32bit = True + if pybamm.demote_expressions_to_32bit: + warnings.warn( + "Demoting expressions to 32-bit for MLIR conversion", + stacklevel=2, + ) + jnpfloat = jnp.float32 + else: # pragma: no cover + jnpfloat = jnp.float64 + raise pybamm.SolverError( + "Demoting expressions to 32-bit is required for MLIR conversion" + " at this time" + ) + + # input arguments (used for lowering) + t_eval = self._demote_64_to_32(jnp.array([0.0], dtype=jnpfloat)) + y0 = self._demote_64_to_32(model.y0) + inputs0 = self._demote_64_to_32(inputs_to_dict(inputs)) + cj = self._demote_64_to_32(jnp.array([1.0], dtype=jnpfloat)) # array + v0 = jnp.zeros(model.len_rhs_and_alg, jnpfloat) + mass_matrix = model.mass_matrix.entries.toarray() + mass_matrix_demoted = self._demote_64_to_32(mass_matrix) + + # rhs_algebraic + rhs_algebraic_demoted = model.rhs_algebraic_eval + rhs_algebraic_demoted._demote_constants() + + def fcn_rhs_algebraic(t, y, inputs): + # function wraps an expression tree (and names MLIR module) + return rhs_algebraic_demoted(t, y, inputs) + + rhs_algebraic = self._make_iree_function( + fcn_rhs_algebraic, t_eval, y0, inputs0 + ) + + # jac_times_cjmass + jac_rhs_algebraic_demoted = rhs_algebraic_demoted.get_jacobian() + + def fcn_jac_times_cjmass(t, y, p, cj): + return jac_rhs_algebraic_demoted(t, y, p) - cj * mass_matrix_demoted + + sparse_eval = sparse.csc_matrix( + fcn_jac_times_cjmass(t_eval, y0, inputs0, cj) + ) + jac_times_cjmass_nnz = sparse_eval.nnz + jac_times_cjmass_colptrs = sparse_eval.indptr + jac_times_cjmass_rowvals = sparse_eval.indices + jac_bw_lower, jac_bw_upper = bandwidth( + sparse_eval.todense() + ) # potentially slow + if jac_bw_upper <= 1: + jac_bw_upper = jac_bw_lower - 1 + if jac_bw_lower <= 1: + jac_bw_lower = jac_bw_upper + 1 + coo = sparse_eval.tocoo() # convert to COOrdinate format for indexing + + def fcn_jac_times_cjmass_sparse(t, y, p, cj): + return fcn_jac_times_cjmass(t, y, p, cj)[coo.row, coo.col] + + jac_times_cjmass = self._make_iree_function( + fcn_jac_times_cjmass_sparse, t_eval, y0, inputs0, cj + ) + + # Mass action + def fcn_mass_action(v): + return mass_matrix_demoted @ v + + mass_action_demoted = self._demote_64_to_32(fcn_mass_action) + mass_action = self._make_iree_function(mass_action_demoted, v0) + + # rootfn + for ix, _ in enumerate(model.terminate_events_eval): + model.terminate_events_eval[ix]._demote_constants() + + def fcn_rootfn(t, y, inputs): + return jnp.array( + [event(t, y, inputs) for event in model.terminate_events_eval], + dtype=jnpfloat, + ).reshape(-1) + + def fcn_rootfn_demoted(t, y, inputs): + return self._demote_64_to_32(fcn_rootfn)(t, y, inputs) + + rootfn = self._make_iree_function( + fcn_rootfn_demoted, t_eval, y0, inputs0 + ) + + # jac_rhs_algebraic_action + jac_rhs_algebraic_action_demoted = ( + rhs_algebraic_demoted.get_jacobian_action() + ) + + def fcn_jac_rhs_algebraic_action( + t, y, p, v + ): # sundials calls (t, y, inputs, v) + return jac_rhs_algebraic_action_demoted( + t, y, v, p + ) # jvp calls (t, y, v, inputs) + + jac_rhs_algebraic_action = self._make_iree_function( + fcn_jac_rhs_algebraic_action, t_eval, y0, inputs0, v0 + ) + + # sensfn + if model.jacp_rhs_algebraic_eval is None: + sensfn = idaklu.IREEBaseFunctionType() # empty equation + else: + sensfn_demoted = rhs_algebraic_demoted.get_sensitivities() + + def fcn_sensfn(t, y, p): + return sensfn_demoted(t, y, p) + + sensfn = self._make_iree_function( + fcn_sensfn, t_eval, jnp.zeros_like(y0), inputs0 + ) + + # output_variables + self.var_idaklu_fcns = [] + self.dvar_dy_idaklu_fcns = [] + self.dvar_dp_idaklu_fcns = [] + for key in self.output_variables: + fcn = self.computed_var_fcns[key] + fcn._demote_constants() + self.var_idaklu_fcns.append( + self._make_iree_function( + lambda t, y, p: fcn(t, y, p), # noqa: B023 + t_eval, + y0, + inputs0, + ) + ) + # Convert derivative functions for sensitivities + if (len(inputs) > 0) and (model.calculate_sensitivities): + dvar_dy = fcn.get_jacobian() + self.dvar_dy_idaklu_fcns.append( + self._make_iree_function( + lambda t, y, p: dvar_dy(t, y, p), # noqa: B023 + t_eval, + y0, + inputs0, + sparse_index=True, + ) + ) + dvar_dp = fcn.get_sensitivities() + self.dvar_dp_idaklu_fcns.append( + self._make_iree_function( + lambda t, y, p: dvar_dp(t, y, p), # noqa: B023 + t_eval, + y0, + inputs0, + ) + ) + + # Identify IREE library + iree_lib_path = os.path.join(iree.compiler.__path__[0], "_mlir_libs") + os.environ["IREE_COMPILER_LIB"] = os.path.join( + iree_lib_path, + next(f for f in os.listdir(iree_lib_path) if "IREECompiler" in f), + ) + + pybamm.demote_expressions_to_32bit = False + else: # pragma: no cover + raise pybamm.SolverError( + "Unsupported evaluation engine for convert_to_format='jax'" + ) self._setup = { - "jac_bandwidth_upper": jac_bw_upper, - "jac_bandwidth_lower": jac_bw_lower, - "rhs_algebraic": rhs_algebraic, - "jac_times_cjmass": jac_times_cjmass, - "jac_times_cjmass_colptrs": jac_times_cjmass_colptrs, - "jac_times_cjmass_rowvals": jac_times_cjmass_rowvals, - "jac_times_cjmass_nnz": jac_times_cjmass_nnz, - "jac_rhs_algebraic_action": jac_rhs_algebraic_action, - "mass_action": mass_action, - "sensfn": sensfn, - "rootfn": rootfn, - "num_of_events": num_of_events, - "ids": ids, + "solver_function": idaklu_solver_fcn, # callable + "jac_bandwidth_upper": jac_bw_upper, # int + "jac_bandwidth_lower": jac_bw_lower, # int + "rhs_algebraic": rhs_algebraic, # function + "jac_times_cjmass": jac_times_cjmass, # function + "jac_times_cjmass_colptrs": jac_times_cjmass_colptrs, # array + "jac_times_cjmass_rowvals": jac_times_cjmass_rowvals, # array + "jac_times_cjmass_nnz": jac_times_cjmass_nnz, # int + "jac_rhs_algebraic_action": jac_rhs_algebraic_action, # function + "mass_action": mass_action, # function + "sensfn": sensfn, # function + "rootfn": rootfn, # function + "num_of_events": num_of_events, # int + "ids": ids, # array "sensitivity_names": sensitivity_names, "number_of_sensitivity_parameters": number_of_sensitivity_parameters, "output_variables": self.output_variables, - "var_casadi_fcns": self.computed_var_fcns, + "var_fcns": self.computed_var_fcns, "var_idaklu_fcns": self.var_idaklu_fcns, "dvar_dy_idaklu_fcns": self.dvar_dy_idaklu_fcns, "dvar_dp_idaklu_fcns": self.dvar_dp_idaklu_fcns, } - solver = idaklu.create_casadi_solver( + solver = self._setup["solver_function"]( number_of_states=len(y0), number_of_parameters=self._setup["number_of_sensitivity_parameters"], rhs_alg=self._setup["rhs_algebraic"], @@ -489,7 +724,7 @@ def sensfn(resvalS, t, y, inputs, yp, yS, ypS): atol=atol, rtol=rtol, inputs=len(inputs), - var_casadi_fcns=self._setup["var_idaklu_fcns"], + var_fcns=self._setup["var_idaklu_fcns"], dvar_dy_fcns=self._setup["dvar_dy_idaklu_fcns"], dvar_dp_fcns=self._setup["dvar_dp_idaklu_fcns"], options=self._options, @@ -511,6 +746,56 @@ def sensfn(resvalS, t, y, inputs, yp, yS, ypS): return base_set_up_return + def _make_iree_function(self, fcn, *args, sparse_index=False): + # Initialise IREE function object + iree_fcn = idaklu.IREEBaseFunctionType() + # Get sparsity pattern index outputs as needed + try: + fcn_eval = fcn(*args) + if not isinstance(fcn_eval, np.ndarray): + fcn_eval = jax.flatten_util.ravel_pytree(fcn_eval)[0] + coo = sparse.coo_matrix(fcn_eval) + iree_fcn.nnz = coo.nnz + iree_fcn.numel = np.prod(coo.shape) + iree_fcn.col = coo.col + iree_fcn.row = coo.row + if sparse_index: + # Isolate NNZ elements while recording original sparsity structure + fcn_inner = fcn + + def fcn(*args): + return fcn_inner(*args)[coo.row, coo.col] + elif coo.nnz != iree_fcn.numel: + iree_fcn.nnz = iree_fcn.numel + iree_fcn.col = list(range(iree_fcn.numel)) + iree_fcn.row = [0] * iree_fcn.numel + except (TypeError, AttributeError) as error: # pragma: no cover + raise pybamm.SolverError( + "Could not get sparsity pattern for function {fcn.__name__}" + ) from error + # Lower to MLIR + lowered = jax.jit(fcn).lower(*args) + iree_fcn.mlir = lowered.as_text() + self._check_mlir_conversion(fcn.__name__, iree_fcn.mlir) + iree_fcn.kept_var_idx = list(lowered._lowering.compile_args["kept_var_idx"]) + # Record number of variables in each argument (these will flatten in the mlir) + iree_fcn.pytree_shape = [ + len(jax.tree_util.tree_flatten(arg)[0]) for arg in args + ] + # Record array length of each mlir variable + iree_fcn.pytree_sizes = [ + len(arg) for arg in jax.tree_util.tree_flatten(args)[0] + ] + iree_fcn.n_args = len(args) + return iree_fcn + + def _check_mlir_conversion(self, name, mlir: str): + if mlir.count("f64") > 0: # pragma: no cover + warnings.warn(f"f64 found in {name} (x{mlir.count('f64')})", stacklevel=2) + + def _demote_64_to_32(self, x: pybamm.EvaluatorJax): + return pybamm.EvaluatorJax._demote_64_to_32(x) + def _integrate(self, model, t_eval, inputs_dict=None): """ Solve a DAE model defined by residuals with initial conditions y0. @@ -527,10 +812,12 @@ def _integrate(self, model, t_eval, inputs_dict=None): inputs_dict = inputs_dict or {} # stack inputs if inputs_dict: + inputs_dict_keys = list(inputs_dict.keys()) # save order arrays_to_stack = [np.array(x).reshape(-1, 1) for x in inputs_dict.values()] inputs = np.vstack(arrays_to_stack) else: inputs = np.array([[]]) + inputs_dict_keys = [] # do this here cause y0 is set after set_up (calc consistent conditions) y0 = model.y0 @@ -539,25 +826,45 @@ def _integrate(self, model, t_eval, inputs_dict=None): y0 = y0.flatten() y0S = model.y0S - # only casadi solver needs sensitivity ics - if model.convert_to_format != "casadi": - y0S = None - if y0S is not None: - if isinstance(y0S, casadi.DM): - y0S = (y0S,) - - y0S = (x.full() for x in y0S) - y0S = [x.flatten() for x in y0S] - - # solver works with ydot0 set to zero - ydot0 = np.zeros_like(y0) - if y0S is not None: - ydot0S = [np.zeros_like(y0S_i) for y0S_i in y0S] - y0full = np.concatenate([y0, *y0S]) - ydot0full = np.concatenate([ydot0, *ydot0S]) + if ( + model.convert_to_format == "jax" + and self._options["jax_evaluator"] == "iree" + ): + if y0S is not None: + pybamm.demote_expressions_to_32bit = True + # preserve order of inputs + y0S = self._demote_64_to_32( + np.concatenate([y0S[k] for k in inputs_dict_keys]).flatten() + ) + y0full = self._demote_64_to_32(np.concatenate([y0, y0S]).flatten()) + ydot0S = self._demote_64_to_32(np.zeros_like(y0S)) + ydot0full = self._demote_64_to_32( + np.concatenate([np.zeros_like(y0), ydot0S]).flatten() + ) + pybamm.demote_expressions_to_32bit = False + else: + y0full = y0 + ydot0full = np.zeros_like(y0) else: - y0full = y0 - ydot0full = ydot0 + # only casadi solver needs sensitivity ics + if model.convert_to_format != "casadi": + y0S = None + if y0S is not None: + if isinstance(y0S, casadi.DM): + y0S = (y0S,) + + y0S = (x.full() for x in y0S) + y0S = [x.flatten() for x in y0S] + + # solver works with ydot0 set to zero + ydot0 = np.zeros_like(y0) + if y0S is not None: + ydot0S = [np.zeros_like(y0S_i) for y0S_i in y0S] + y0full = np.concatenate([y0, *y0S]) + ydot0full = np.concatenate([ydot0, *ydot0S]) + else: + y0full = y0 + ydot0full = ydot0 try: atol = model.atol @@ -568,7 +875,10 @@ def _integrate(self, model, t_eval, inputs_dict=None): atol = self._check_atol_type(atol, y0.size) timer = pybamm.Timer() - if model.convert_to_format == "casadi": + if model.convert_to_format == "casadi" or ( + model.convert_to_format == "jax" + and self._options["jax_evaluator"] == "iree" + ): sol = self._setup["solver"].solve( t_eval, y0full, @@ -625,58 +935,71 @@ def _integrate(self, model, t_eval, inputs_dict=None): else: yS_out = False - if sol.flag in [0, 2]: - # 0 = solved for all t_eval - if sol.flag == 0: - termination = "final time" - # 2 = found root(s) - elif sol.flag == 2: - termination = "event" - - newsol = pybamm.Solution( - sol.t, - np.transpose(y_out), - model, - inputs_dict, - np.array([t[-1]]), - np.transpose(y_out[-1])[:, np.newaxis], - termination, - sensitivities=yS_out, - ) - newsol.integration_time = integration_time - if self.output_variables: - # Populate variables and sensititivies dictionaries directly - number_of_samples = sol.y.shape[0] // number_of_timesteps - sol.y = sol.y.reshape((number_of_timesteps, number_of_samples)) - startk = 0 - for _, var in enumerate(self.output_variables): - # ExplicitTimeIntegral's are not computed as part of the solver and - # do not need to be converted - if isinstance( - model.variables_and_events[var], pybamm.ExplicitTimeIntegral - ): - continue + # 0 = solved for all t_eval + if sol.flag == 0: + termination = "final time" + # 2 = found root(s) + elif sol.flag == 2: + termination = "event" + else: + raise pybamm.SolverError("idaklu solver failed") + newsol = pybamm.Solution( + sol.t, + np.transpose(y_out), + model, + inputs_dict, + np.array([t[-1]]), + np.transpose(y_out[-1])[:, np.newaxis], + termination, + sensitivities=yS_out, + ) + newsol.integration_time = integration_time + if self.output_variables: + # Populate variables and sensititivies dictionaries directly + number_of_samples = sol.y.shape[0] // number_of_timesteps + sol.y = sol.y.reshape((number_of_timesteps, number_of_samples)) + startk = 0 + for var in self.output_variables: + # ExplicitTimeIntegral's are not computed as part of the solver and + # do not need to be converted + if isinstance( + model.variables_and_events[var], pybamm.ExplicitTimeIntegral + ): + continue + if model.convert_to_format == "casadi": len_of_var = ( - self._setup["var_casadi_fcns"][var](0, 0, 0).sparsity().nnz() + self._setup["var_fcns"][var](0.0, 0.0, 0.0).sparsity().nnz() ) - newsol._variables[var] = pybamm.ProcessedVariableComputed( - [model.variables_and_events[var]], - [self._setup["var_casadi_fcns"][var]], - [sol.y[:, startk : (startk + len_of_var)]], - newsol, + base_variables = [self._setup["var_fcns"][var]] + elif ( + model.convert_to_format == "jax" + and self._options["jax_evaluator"] == "iree" + ): + idx = self.output_variables.index(var) + len_of_var = self._setup["var_idaklu_fcns"][idx].nnz + base_variables = [self._setup["var_idaklu_fcns"][idx]] + else: # pragma: no cover + raise pybamm.SolverError( + "Unsupported evaluation engine for convert_to_format=" + + f"{model.convert_to_format} " + + f"(jax_evaluator={self._options['jax_evaluator']})" ) - # Add sensitivities - newsol[var]._sensitivities = {} - if model.calculate_sensitivities: - for paramk, param in enumerate(inputs_dict.keys()): - newsol[var].add_sensitivity( - param, - [sol.yS[:, startk : (startk + len_of_var), paramk]], - ) - startk += len_of_var - return newsol - else: - raise pybamm.SolverError("idaklu solver failed") + newsol._variables[var] = pybamm.ProcessedVariableComputed( + [model.variables_and_events[var]], + base_variables, + [sol.y[:, startk : (startk + len_of_var)]], + newsol, + ) + # Add sensitivities + newsol[var]._sensitivities = {} + if model.calculate_sensitivities: + for paramk, param in enumerate(inputs_dict.keys()): + newsol[var].add_sensitivity( + param, + [sol.yS[:, startk : (startk + len_of_var), paramk]], + ) + startk += len_of_var + return newsol def jaxify( self, diff --git a/pybamm/solvers/processed_variable_computed.py b/pybamm/solvers/processed_variable_computed.py index a069342254..a717c8b0cb 100644 --- a/pybamm/solvers/processed_variable_computed.py +++ b/pybamm/solvers/processed_variable_computed.py @@ -120,16 +120,25 @@ def _unroll_nnz(self, realdata=None): # unroll in nnz != numel, otherwise copy if realdata is None: realdata = self.base_variables_data - sp = self.base_variables_casadi[0](0, 0, 0).sparsity() - if sp.nnz() != sp.numel(): + if isinstance(self.base_variables_casadi[0], casadi.Function): # casadi fcn + sp = self.base_variables_casadi[0](0, 0, 0).sparsity() + nnz = sp.nnz() + numel = sp.numel() + row = sp.row() + elif "nnz" in dir(self.base_variables_casadi[0]): # IREE fcn + sp = self.base_variables_casadi[0] + nnz = sp.nnz + numel = sp.numel + row = sp.row + if nnz != numel: data = [None] * len(realdata) for datak in range(len(realdata)): data[datak] = np.zeros(self.base_eval_shape[0] * len(self.t_pts)) var_data = realdata[0].flatten() k = 0 for t_i in range(len(self.t_pts)): - base = t_i * sp.numel() - for r in sp.row(): + base = t_i * numel + for r in row: data[datak][base + r] = var_data[k] k = k + 1 else: diff --git a/pybamm/version.py b/pybamm/version.py index bc0f2f5d12..4c1e268285 100644 --- a/pybamm/version.py +++ b/pybamm/version.py @@ -1 +1 @@ -__version__ = "24.5rc0" +__version__ = "24.5rc2" diff --git a/pyproject.toml b/pyproject.toml index 890f884769..2bc6f4f3d7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ build-backend = "setuptools.build_meta" [project] name = "pybamm" -version = "24.5rc0" +version = "24.5rc2" license = { file = "LICENSE.txt" } description = "Python Battery Mathematical Modelling" authors = [{name = "The PyBaMM Team", email = "pybamm@pybamm.org"}] @@ -116,12 +116,19 @@ dev = [ # To access the metadata for python packages "importlib-metadata; python_version < '3.10'", ] -# For the Jax solver. Note: these must be kept in sync with the versions defined in pybamm/util.py. +# For the Jax solver. +# Note: These must be kept in sync with the versions defined in pybamm/util.py, and +# must remain compatible with IREE (see noxfile.py for IREE compatibility). jax = [ "jax==0.4.27", "jaxlib==0.4.27", ] -# Contains all optional dependencies, except for jax and dev dependencies +# For MLIR expression evaluation (IDAKLU Solver) +iree = [ + # must be pip installed with --find-links=https://iree.dev/pip-release-links.html + "iree-compiler==20240507.886", # see IREE compatibility notes in noxfile.py +] +# Contains all optional dependencies, except for jax, iree, and dev dependencies all = [ "scikit-fem>=8.1.0", "pybamm[examples,plot,cite,bpx,tqdm]", @@ -193,6 +200,7 @@ extend-select = [ "UP", # pyupgrade "YTT", # flake8-2020 "TID252", # relative-imports + "S101", # to identify use of assert statement ] ignore = [ "E741", # Ambiguous variable name @@ -213,7 +221,7 @@ ignore = [ ] [tool.ruff.lint.per-file-ignores] -"tests/*" = ["T20"] +"tests/*" = ["T20", "S101"] "docs/*" = ["T20"] "examples/*" = ["T20"] "**.ipynb" = ["E402", "E703"] diff --git a/setup.py b/setup.py index 6b97f73058..21dabcebb2 100644 --- a/setup.py +++ b/setup.py @@ -92,10 +92,14 @@ def run(self): use_python_casadi = True build_type = os.getenv("PYBAMM_CPP_BUILD_TYPE", "RELEASE") + idaklu_expr_casadi = os.getenv("PYBAMM_IDAKLU_EXPR_CASADI", "ON") + idaklu_expr_iree = os.getenv("PYBAMM_IDAKLU_EXPR_IREE", "OFF") cmake_args = [ f"-DCMAKE_BUILD_TYPE={build_type}", f"-DPYTHON_EXECUTABLE={sys.executable}", "-DUSE_PYTHON_CASADI={}".format("TRUE" if use_python_casadi else "FALSE"), + f"-DPYBAMM_IDAKLU_EXPR_CASADI={idaklu_expr_casadi}", + f"-DPYBAMM_IDAKLU_EXPR_IREE={idaklu_expr_iree}", ] if self.suitesparse_root: cmake_args.append( @@ -291,27 +295,39 @@ def compile_KLU(): name="pybamm.solvers.idaklu", # The sources list should mirror the list in CMakeLists.txt sources=[ - "pybamm/solvers/c_solvers/idaklu/casadi_functions.cpp", - "pybamm/solvers/c_solvers/idaklu/casadi_functions.hpp", - "pybamm/solvers/c_solvers/idaklu/casadi_solver.cpp", - "pybamm/solvers/c_solvers/idaklu/casadi_solver.hpp", - "pybamm/solvers/c_solvers/idaklu/CasadiSolver.cpp", - "pybamm/solvers/c_solvers/idaklu/CasadiSolver.hpp", - "pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.cpp", - "pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.hpp", - "pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP_solvers.cpp", - "pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP_solvers.hpp", - "pybamm/solvers/c_solvers/idaklu/casadi_sundials_functions.cpp", - "pybamm/solvers/c_solvers/idaklu/casadi_sundials_functions.hpp", - "pybamm/solvers/c_solvers/idaklu/idaklu_jax.cpp", - "pybamm/solvers/c_solvers/idaklu/idaklu_jax.hpp", + "pybamm/solvers/c_solvers/idaklu/Expressions/Expressions.hpp", + "pybamm/solvers/c_solvers/idaklu/Expressions/Base/Expression.hpp", + "pybamm/solvers/c_solvers/idaklu/Expressions/Base/ExpressionSet.hpp", + "pybamm/solvers/c_solvers/idaklu/Expressions/Base/ExpressionTypes.hpp", + "pybamm/solvers/c_solvers/idaklu/Expressions/Base/ExpressionSparsity.hpp", + "pybamm/solvers/c_solvers/idaklu/Expressions/Casadi/CasadiFunctions.cpp", + "pybamm/solvers/c_solvers/idaklu/Expressions/Casadi/CasadiFunctions.hpp", + "pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEBaseFunction.hpp", + "pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEFunction.hpp", + "pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEFunctions.cpp", + "pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEFunctions.hpp", + "pybamm/solvers/c_solvers/idaklu/Expressions/IREE/iree_jit.cpp", + "pybamm/solvers/c_solvers/idaklu/Expressions/IREE/iree_jit.hpp", + "pybamm/solvers/c_solvers/idaklu/Expressions/IREE/ModuleParser.cpp", + "pybamm/solvers/c_solvers/idaklu/Expressions/IREE/ModuleParser.hpp", + "pybamm/solvers/c_solvers/idaklu/idaklu_solver.hpp", + "pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.cpp", + "pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.hpp", + "pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.inl", + "pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.hpp", + "pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP_solvers.cpp", + "pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP_solvers.hpp", + "pybamm/solvers/c_solvers/idaklu/sundials_functions.inl", + "pybamm/solvers/c_solvers/idaklu/sundials_functions.hpp", + "pybamm/solvers/c_solvers/idaklu/IdakluJax.cpp", + "pybamm/solvers/c_solvers/idaklu/IdakluJax.hpp", "pybamm/solvers/c_solvers/idaklu/common.hpp", "pybamm/solvers/c_solvers/idaklu/python.hpp", "pybamm/solvers/c_solvers/idaklu/python.cpp", - "pybamm/solvers/c_solvers/idaklu/solution.cpp", - "pybamm/solvers/c_solvers/idaklu/solution.hpp", - "pybamm/solvers/c_solvers/idaklu/options.hpp", - "pybamm/solvers/c_solvers/idaklu/options.cpp", + "pybamm/solvers/c_solvers/idaklu/Solution.cpp", + "pybamm/solvers/c_solvers/idaklu/Solution.hpp", + "pybamm/solvers/c_solvers/idaklu/Options.hpp", + "pybamm/solvers/c_solvers/idaklu/Options.cpp", "pybamm/solvers/c_solvers/idaklu.cpp", ], ) diff --git a/tests/integration/test_models/standard_model_tests.py b/tests/integration/test_models/standard_model_tests.py index 7f0e9e6137..3f9cb56354 100644 --- a/tests/integration/test_models/standard_model_tests.py +++ b/tests/integration/test_models/standard_model_tests.py @@ -3,10 +3,9 @@ # import pybamm import tests -import uuid +import tempfile import numpy as np -import os class StandardModelTest: @@ -141,9 +140,8 @@ def test_sensitivities(self, param_name, param_value, output_name="Voltage [V]") ) def test_serialisation(self, solver=None, t_eval=None): - # Generating unique file names to avoid race conditions when run in parallel. - unique_id = uuid.uuid4() - file_name = f"test_model_{unique_id}" + temp = tempfile.NamedTemporaryFile(prefix="test_model") + file_name = temp.name self.model.save_model( file_name, variables=self.model.variables, mesh=self.disc.mesh ) @@ -178,8 +176,7 @@ def test_serialisation(self, solver=None, t_eval=None): np.testing.assert_array_almost_equal( new_solution.all_ys[x], self.solution.all_ys[x], decimal=accuracy ) - - os.remove(file_name + ".json") + temp.close() def test_all( self, param=None, disc=None, solver=None, t_eval=None, skip_output_tests=False diff --git a/tests/unit/test_doc_utils.py b/tests/unit/test_doc_utils.py index a7a4a1e5d5..8e8a626535 100644 --- a/tests/unit/test_doc_utils.py +++ b/tests/unit/test_doc_utils.py @@ -3,14 +3,11 @@ # is generated, but rather that the docstrings are correctly modified # -import pybamm -import unittest -from tests import TestCase from inspect import getmro from pybamm.doc_utils import copy_parameter_doc_from_parent, doc_extend_parent -class TestDocUtils(TestCase): +class TestDocUtils: def test_copy_parameter_doc_from_parent(self): """Test if parameters from the parent class are copied to child class docstring""" @@ -38,7 +35,7 @@ def __init__(self, foo, bar): base_parameters = "".join(Base.__doc__.partition("Parameters")[1:]) derived_parameters = "".join(Derived.__doc__.partition("Parameters")[1:]) # check that the parameters section is in the docstring - self.assertMultiLineEqual(base_parameters, derived_parameters) + assert base_parameters == derived_parameters def test_doc_extend_parent(self): """Test if the child class has the Extends directive in its docstring""" @@ -57,21 +54,11 @@ def __init__(self, param): super().__init__(param) # check that the Extends directive is in the docstring - self.assertIn("**Extends:**", Derived.__doc__) + assert "**Extends:**" in Derived.__doc__ # check that the Extends directive maps to the correct base class base_cls_name = f"{getmro(Derived)[1].__module__}.{getmro(Derived)[1].__name__}" - self.assertEqual( - Derived.__doc__.partition("**Extends:**")[2].strip(), - f":class:`{base_cls_name}`", + assert ( + Derived.__doc__.partition("**Extends:**")[2].strip() + == f":class:`{base_cls_name}`" ) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_expression_tree/test_array.py b/tests/unit/test_expression_tree/test_array.py index b75c313f47..ffd29baa7e 100644 --- a/tests/unit/test_expression_tree/test_array.py +++ b/tests/unit/test_expression_tree/test_array.py @@ -1,20 +1,17 @@ # # Tests for the Array class # -from tests import TestCase -import unittest -import unittest.mock as mock + import numpy as np import sympy - import pybamm -class TestArray(TestCase): +class TestArray: def test_name(self): arr = pybamm.Array(np.array([1, 2, 3])) - self.assertEqual(arr.name, "Array of shape (3, 1)") + assert arr.name == "Array of shape (3, 1)" def test_list_entries(self): vect = pybamm.Array([1, 2, 3]) @@ -38,16 +35,14 @@ def test_meshgrid(self): np.testing.assert_array_equal(B, D.entries) def test_to_equation(self): - self.assertEqual( - pybamm.Array([1, 2]).to_equation(), sympy.Array([[1.0], [2.0]]) - ) + assert pybamm.Array([1, 2]).to_equation() == sympy.Array([[1.0], [2.0]]) - def test_to_from_json(self): + def test_to_from_json(self, mocker): arr = pybamm.Array(np.array([1, 2, 3])) json_dict = { "name": "Array of shape (3, 1)", - "id": mock.ANY, + "id": mocker.ANY, "domains": { "primary": [], "secondary": [], @@ -59,17 +54,7 @@ def test_to_from_json(self): # array to json conversion created_json = arr.to_json() - self.assertEqual(created_json, json_dict) + assert created_json == json_dict # json to array conversion - self.assertEqual(pybamm.Array._from_json(created_json), arr) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() + assert pybamm.Array._from_json(created_json) == arr diff --git a/tests/unit/test_expression_tree/test_d_dt.py b/tests/unit/test_expression_tree/test_d_dt.py index b5632f9f64..38d7e20e13 100644 --- a/tests/unit/test_expression_tree/test_d_dt.py +++ b/tests/unit/test_expression_tree/test_d_dt.py @@ -1,43 +1,42 @@ # # Tests for the Scalar class # -from tests import TestCase +import pytest import pybamm -import unittest import numpy as np -class TestDDT(TestCase): +class TestDDT: def test_time_derivative(self): a = pybamm.Scalar(5).diff(pybamm.t) - self.assertIsInstance(a, pybamm.Scalar) - self.assertEqual(a.value, 0) + assert isinstance(a, pybamm.Scalar) + assert a.value == 0 a = pybamm.t.diff(pybamm.t) - self.assertIsInstance(a, pybamm.Scalar) - self.assertEqual(a.value, 1) + assert isinstance(a, pybamm.Scalar) + assert a.value == 1 a = (pybamm.t**2).diff(pybamm.t) - self.assertEqual(a, (2 * pybamm.t**1 * 1)) - self.assertEqual(a.evaluate(t=1), 2) + assert a == (2 * pybamm.t**1 * 1) + assert a.evaluate(t=1) == 2 a = (2 + pybamm.t**2).diff(pybamm.t) - self.assertEqual(a.evaluate(t=1), 2) + assert a.evaluate(t=1) == 2 def test_time_derivative_of_variable(self): a = (pybamm.Variable("a")).diff(pybamm.t) - self.assertIsInstance(a, pybamm.VariableDot) - self.assertEqual(a.name, "a'") + assert isinstance(a, pybamm.VariableDot) + assert a.name == "a'" p = pybamm.Parameter("p") a = 1 + p * pybamm.Variable("a") diff_a = a.diff(pybamm.t) - self.assertIsInstance(diff_a, pybamm.Multiplication) - self.assertEqual(diff_a.children[0].name, "p") - self.assertEqual(diff_a.children[1].name, "a'") + assert isinstance(diff_a, pybamm.Multiplication) + assert diff_a.children[0].name == "p" + assert diff_a.children[1].name == "a'" - with self.assertRaises(pybamm.ModelError): + with pytest.raises(pybamm.ModelError): a = (pybamm.Variable("a")).diff(pybamm.t).diff(pybamm.t) def test_time_derivative_of_state_vector(self): @@ -45,21 +44,11 @@ def test_time_derivative_of_state_vector(self): y_dot = np.linspace(0, 2, 19) a = sv.diff(pybamm.t) - self.assertIsInstance(a, pybamm.StateVectorDot) - self.assertEqual(a.name[-1], "'") + assert isinstance(a, pybamm.StateVectorDot) + assert a.name[-1] == "'" np.testing.assert_array_equal( a.evaluate(y_dot=y_dot), np.linspace(0, 1, 10)[:, np.newaxis] ) - with self.assertRaises(pybamm.ModelError): + with pytest.raises(pybamm.ModelError): a = (sv).diff(pybamm.t).diff(pybamm.t) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_expression_tree/test_independent_variable.py b/tests/unit/test_expression_tree/test_independent_variable.py index f747b60d40..79c5ab9ea2 100644 --- a/tests/unit/test_expression_tree/test_independent_variable.py +++ b/tests/unit/test_expression_tree/test_independent_variable.py @@ -1,87 +1,73 @@ # # Tests for the Parameter class # -from tests import TestCase -import unittest - +import pytest import pybamm import sympy -class TestIndependentVariable(TestCase): +class TestIndependentVariable: def test_variable_init(self): a = pybamm.IndependentVariable("a") - self.assertEqual(a.name, "a") - self.assertEqual(a.domain, []) + assert a.name == "a" + assert a.domain == [] a = pybamm.IndependentVariable("a", domain=["test"]) - self.assertEqual(a.domain[0], "test") + assert a.domain[0] == "test" a = pybamm.IndependentVariable("a", domain="test") - self.assertEqual(a.domain[0], "test") - with self.assertRaises(TypeError): + assert a.domain[0] == "test" + with pytest.raises(TypeError): pybamm.IndependentVariable("a", domain=1) def test_time(self): t = pybamm.Time() - self.assertEqual(t.name, "time") - self.assertEqual(t.evaluate(4), 4) - with self.assertRaises(ValueError): + assert t.name == "time" + assert t.evaluate(4) == 4 + with pytest.raises(ValueError): t.evaluate(None) t = pybamm.t - self.assertEqual(t.name, "time") - self.assertEqual(t.evaluate(4), 4) - with self.assertRaises(ValueError): + assert t.name == "time" + assert t.evaluate(4) == 4 + with pytest.raises(ValueError): t.evaluate(None) - self.assertEqual(t.evaluate_for_shape(), 0) + assert t.evaluate_for_shape() == 0 def test_spatial_variable(self): x = pybamm.SpatialVariable("x", "negative electrode") - self.assertEqual(x.name, "x") - self.assertFalse(x.evaluates_on_edges("primary")) + assert x.name == "x" + assert not x.evaluates_on_edges("primary") y = pybamm.SpatialVariable("y", "separator") - self.assertEqual(y.name, "y") + assert y.name == "y" z = pybamm.SpatialVariable("z", "positive electrode") - self.assertEqual(z.name, "z") + assert z.name == "z" r = pybamm.SpatialVariable("r", "negative particle") - self.assertEqual(r.name, "r") - with self.assertRaises(NotImplementedError): + assert r.name == "r" + with pytest.raises(NotImplementedError): x.evaluate() - with self.assertRaisesRegex(ValueError, "domain must be"): + with pytest.raises(ValueError, match="domain must be"): pybamm.SpatialVariable("x", []) - with self.assertRaises(pybamm.DomainError): + with pytest.raises(pybamm.DomainError): pybamm.SpatialVariable("r_n", ["positive particle"]) - with self.assertRaises(pybamm.DomainError): + with pytest.raises(pybamm.DomainError): pybamm.SpatialVariable("r_p", ["negative particle"]) - with self.assertRaises(pybamm.DomainError): + with pytest.raises(pybamm.DomainError): pybamm.SpatialVariable("x", ["negative particle"]) def test_spatial_variable_edge(self): x = pybamm.SpatialVariableEdge("x", "negative electrode") - self.assertEqual(x.name, "x") - self.assertTrue(x.evaluates_on_edges("primary")) + assert x.name == "x" + assert x.evaluates_on_edges("primary") def test_to_equation(self): # Test print_name func = pybamm.IndependentVariable("a") func.print_name = "test" - self.assertEqual(func.to_equation(), sympy.Symbol("test")) + assert func.to_equation() == sympy.Symbol("test") - self.assertEqual( - pybamm.IndependentVariable("a").to_equation(), sympy.Symbol("a") - ) + assert pybamm.IndependentVariable("a").to_equation() == sympy.Symbol("a") # Test time - self.assertEqual(pybamm.t.to_equation(), sympy.Symbol("t")) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() + assert pybamm.t.to_equation() == sympy.Symbol("t") diff --git a/tests/unit/test_expression_tree/test_input_parameter.py b/tests/unit/test_expression_tree/test_input_parameter.py index a5fc79f2e2..87cbe79a31 100644 --- a/tests/unit/test_expression_tree/test_input_parameter.py +++ b/tests/unit/test_expression_tree/test_input_parameter.py @@ -1,54 +1,52 @@ # # Tests for the InputParameter class # -from tests import TestCase import numpy as np import pybamm -import unittest - +import pytest import unittest.mock as mock -class TestInputParameter(TestCase): +class TestInputParameter: def test_input_parameter_init(self): a = pybamm.InputParameter("a") - self.assertEqual(a.name, "a") - self.assertEqual(a.evaluate(inputs={"a": 1}), 1) - self.assertEqual(a.evaluate(inputs={"a": 5}), 5) + assert a.name == "a" + assert a.evaluate(inputs={"a": 1}) == 1 + assert a.evaluate(inputs={"a": 5}) == 5 a = pybamm.InputParameter("a", expected_size=10) - self.assertEqual(a._expected_size, 10) + assert a._expected_size == 10 np.testing.assert_array_equal( a.evaluate(inputs="shape test"), np.nan * np.ones((10, 1)) ) y = np.linspace(0, 1, 10) np.testing.assert_array_equal(a.evaluate(inputs={"a": y}), y[:, np.newaxis]) - with self.assertRaisesRegex( + with pytest.raises( ValueError, - "Input parameter 'a' was given an object of size '1' but was expecting an " + match="Input parameter 'a' was given an object of size '1' but was expecting an " "object of size '10'", ): a.evaluate(inputs={"a": 5}) def test_evaluate_for_shape(self): a = pybamm.InputParameter("a") - self.assertTrue(np.isnan(a.evaluate_for_shape())) - self.assertEqual(a.shape, ()) + assert np.isnan(a.evaluate_for_shape()) + assert a.shape == () a = pybamm.InputParameter("a", expected_size=10) - self.assertEqual(a.shape, (10, 1)) + assert a.shape == (10, 1) np.testing.assert_equal(a.evaluate_for_shape(), np.nan * np.ones((10, 1))) - self.assertEqual(a.evaluate_for_shape().shape, (10, 1)) + assert a.evaluate_for_shape().shape == (10, 1) def test_errors(self): a = pybamm.InputParameter("a") - with self.assertRaises(TypeError): + with pytest.raises(TypeError): a.evaluate(inputs="not a dictionary") - with self.assertRaises(KeyError): + with pytest.raises(KeyError): a.evaluate(inputs={"bad param": 5}) # if u is not provided it gets turned into a dictionary and then raises KeyError - with self.assertRaises(KeyError): + with pytest.raises(KeyError): a.evaluate() def test_to_from_json(self): @@ -62,17 +60,7 @@ def test_to_from_json(self): } # to_json - self.assertEqual(a.to_json(), json_dict) + assert a.to_json() == json_dict # from_json - self.assertEqual(pybamm.InputParameter._from_json(json_dict), a) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() + assert pybamm.InputParameter._from_json(json_dict) == a diff --git a/tests/unit/test_expression_tree/test_matrix.py b/tests/unit/test_expression_tree/test_matrix.py index 055902b15e..d34af0d83f 100644 --- a/tests/unit/test_expression_tree/test_matrix.py +++ b/tests/unit/test_expression_tree/test_matrix.py @@ -1,26 +1,22 @@ # # Tests for the Matrix class # -from tests import TestCase import pybamm import numpy as np from scipy.sparse import csr_matrix -import unittest -import unittest.mock as mock - -class TestMatrix(TestCase): - def setUp(self): +class TestMatrix: + def setup_method(self): self.A = np.array([[1, 2, 0], [0, 1, 0], [0, 0, 1]]) self.x = np.array([1, 2, 3]) self.mat = pybamm.Matrix(self.A) self.vect = pybamm.Vector(self.x) def test_array_wrapper(self): - self.assertEqual(self.mat.ndim, 2) - self.assertEqual(self.mat.shape, (3, 3)) - self.assertEqual(self.mat.size, 9) + assert self.mat.ndim == 2 + assert self.mat.shape == (3, 3) + assert self.mat.size == 9 def test_list_entry(self): mat = pybamm.Matrix([[1, 2, 0], [0, 1, 0], [0, 0, 1]]) @@ -40,11 +36,11 @@ def test_matrix_operations(self): (self.mat @ self.vect).evaluate(), np.array([[5], [2], [3]]) ) - def test_to_from_json(self): + def test_to_from_json(self, mocker): arr = pybamm.Matrix(csr_matrix([[0, 1, 0, 0], [0, 0, 0, 1]])) json_dict = { "name": "Sparse Matrix (2, 4)", - "id": mock.ANY, + "id": mocker.ANY, "domains": { "primary": [], "secondary": [], @@ -59,16 +55,6 @@ def test_to_from_json(self): }, } - self.assertEqual(arr.to_json(), json_dict) - - self.assertEqual(pybamm.Matrix._from_json(json_dict), arr) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys + assert arr.to_json() == json_dict - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() + assert pybamm.Matrix._from_json(json_dict) == arr diff --git a/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py b/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py index b02c75f386..6e1b155eca 100644 --- a/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py +++ b/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py @@ -80,7 +80,7 @@ def test_find_symbols(self): # test values of variable_symbols self.assertEqual(next(iter(variable_symbols.values())), "y[0:1]") self.assertEqual(list(variable_symbols.values())[1], "y[1:2]") - self.assertEqual(list(variable_symbols.values())[2], f"-{var_b}") + self.assertEqual(list(variable_symbols.values())[2], f"-({var_b})") var_child = pybamm.id_to_python_variable(expr.children[1].id) self.assertEqual( list(variable_symbols.values())[3], f"np.maximum({var_a},{var_child})" @@ -674,6 +674,76 @@ def test_evaluator_jax_inputs(self): result = evaluator(inputs={"a": 2}) self.assertEqual(result, 4) + @unittest.skipIf(not pybamm.have_jax(), "jax or jaxlib is not installed") + def test_evaluator_jax_demotion(self): + for demote in [True, False]: + pybamm.demote_expressions_to_32bit = demote # global flag + target_dtype = "32" if demote else "64" + if demote: + # Test only works after conversion to jax.numpy + for c in [ + 1.0, + 1, + ]: + self.assertEqual( + str(pybamm.EvaluatorJax._demote_64_to_32(c).dtype)[-2:], + target_dtype, + ) + for c in [ + np.float64(1.0), + np.int64(1), + np.array([1.0], dtype=np.float64), + np.array([1], dtype=np.int64), + jax.numpy.array([1.0], dtype=np.float64), + jax.numpy.array([1], dtype=np.int64), + ]: + self.assertEqual( + str(pybamm.EvaluatorJax._demote_64_to_32(c).dtype)[-2:], + target_dtype, + ) + for c in [ + {key: np.float64(1.0) for key in ["a", "b"]}, + ]: + expr_demoted = pybamm.EvaluatorJax._demote_64_to_32(c) + self.assertTrue( + all( + str(c_v.dtype)[-2:] == target_dtype + for c_k, c_v in expr_demoted.items() + ) + ) + for c in [ + (np.float64(1.0), np.float64(2.0)), + [np.float64(1.0), np.float64(2.0)], + ]: + expr_demoted = pybamm.EvaluatorJax._demote_64_to_32(c) + self.assertTrue( + all(str(c_i.dtype)[-2:] == target_dtype for c_i in expr_demoted) + ) + for dtype in [ + np.float64, + jax.numpy.float64, + ]: + c = pybamm.JaxCooMatrix([0, 1], [0, 1], dtype([1.0, 2.0]), (2, 2)) + c_demoted = pybamm.EvaluatorJax._demote_64_to_32(c) + self.assertTrue( + all(str(c_i.dtype)[-2:] == target_dtype for c_i in c_demoted.data) + ) + for dtype in [ + np.int64, + jax.numpy.int64, + ]: + c = pybamm.JaxCooMatrix( + dtype([0, 1]), dtype([0, 1]), [1.0, 2.0], (2, 2) + ) + c_demoted = pybamm.EvaluatorJax._demote_64_to_32(c) + self.assertTrue( + all(str(c_i.dtype)[-2:] == target_dtype for c_i in c_demoted.row) + ) + self.assertTrue( + all(str(c_i.dtype)[-2:] == target_dtype for c_i in c_demoted.col) + ) + pybamm.demote_expressions_to_32bit = False + @unittest.skipIf(not pybamm.have_jax(), "jax or jaxlib is not installed") def test_jax_coo_matrix(self): A = pybamm.JaxCooMatrix([0, 1], [0, 1], [1.0, 2.0], (2, 2)) diff --git a/tests/unit/test_expression_tree/test_operations/test_unpack_symbols.py b/tests/unit/test_expression_tree/test_operations/test_unpack_symbols.py index ce669212ae..6736094288 100644 --- a/tests/unit/test_expression_tree/test_operations/test_unpack_symbols.py +++ b/tests/unit/test_expression_tree/test_operations/test_unpack_symbols.py @@ -1,27 +1,25 @@ # # Tests for the symbol unpacker # -from tests import TestCase import pybamm -import unittest -class TestSymbolUnpacker(TestCase): +class TestSymbolUnpacker: def test_basic_symbols(self): a = pybamm.Scalar(1) unpacker = pybamm.SymbolUnpacker(pybamm.Scalar) unpacked = unpacker.unpack_symbol(a) - self.assertEqual(unpacked, set([a])) + assert unpacked == set([a]) b = pybamm.Parameter("b") unpacker_param = pybamm.SymbolUnpacker(pybamm.Parameter) unpacked = unpacker_param.unpack_symbol(a) - self.assertEqual(unpacked, set()) + assert unpacked == set() unpacked = unpacker_param.unpack_symbol(b) - self.assertEqual(unpacked, set([b])) + assert unpacked == set([b]) def test_binary(self): a = pybamm.Scalar(1) @@ -29,11 +27,11 @@ def test_binary(self): unpacker = pybamm.SymbolUnpacker(pybamm.Scalar) unpacked = unpacker.unpack_symbol(a + b) - self.assertEqual(unpacked, set([a])) + assert unpacked == set([a]) unpacker_param = pybamm.SymbolUnpacker(pybamm.Parameter) unpacked = unpacker_param.unpack_symbol(a + b) - self.assertEqual(unpacked, set([b])) + assert unpacked == set([b]) def test_unpack_list_of_symbols(self): a = pybamm.Scalar(1) @@ -42,14 +40,4 @@ def test_unpack_list_of_symbols(self): unpacker = pybamm.SymbolUnpacker(pybamm.Parameter) unpacked = unpacker.unpack_list_of_symbols([a + b, a - c, b + c]) - self.assertEqual(unpacked, set([b, c])) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() + assert unpacked == set([b, c]) diff --git a/tests/unit/test_expression_tree/test_printing/test_print_name.py b/tests/unit/test_expression_tree/test_printing/test_print_name.py index 9d74d6f1ab..c15ce18616 100644 --- a/tests/unit/test_expression_tree/test_printing/test_print_name.py +++ b/tests/unit/test_expression_tree/test_printing/test_print_name.py @@ -2,57 +2,42 @@ Tests for the print_name.py """ -from tests import TestCase -import unittest - import pybamm -class TestPrintName(TestCase): +class TestPrintName: def test_prettify_print_name(self): param = pybamm.LithiumIonParameters() param2 = pybamm.LeadAcidParameters() # Test PRINT_NAME_OVERRIDES - self.assertEqual(param.current_with_time.print_name, "I") + assert param.current_with_time.print_name == "I" # Test superscripts - self.assertEqual( - param.n.prim.c_init.print_name, r"c_{\mathrm{n}}^{\mathrm{init}}" - ) + assert param.n.prim.c_init.print_name == r"c_{\mathrm{n}}^{\mathrm{init}}" # Test subscripts - self.assertEqual(param.n.C_dl(0).print_name, r"C_{\mathrm{dl,n}}") + assert param.n.C_dl(0).print_name == r"C_{\mathrm{dl,n}}" # Test bar c_e_av = pybamm.Variable("c_e_av") c_e_av.print_name = "c_e_av" - self.assertEqual(c_e_av.print_name, r"\overline{c}_{\mathrm{e}}") + assert c_e_av.print_name == r"\overline{c}_{\mathrm{e}}" # Test greek letters - self.assertEqual(param2.delta.print_name, r"\delta") + assert param2.delta.print_name == r"\delta" # Test create_copy() a_n = param2.n.prim.a - self.assertEqual(a_n.create_copy().print_name, r"a_{\mathrm{n}}") + assert a_n.create_copy().print_name == r"a_{\mathrm{n}}" # Test eps eps_n = pybamm.Variable("eps_n") - self.assertEqual(eps_n.print_name, r"\epsilon_{\mathrm{n}}") + assert eps_n.print_name == r"\epsilon_{\mathrm{n}}" eps_n = pybamm.Variable("eps_c_e_n") - self.assertEqual(eps_n.print_name, r"(\epsilon c)_{\mathrm{e,n}}") + assert eps_n.print_name == r"(\epsilon c)_{\mathrm{e,n}}" # tplus t_plus = pybamm.Variable("t_plus") - self.assertEqual(t_plus.print_name, r"t_{\mathrm{+}}") - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() + assert t_plus.print_name == r"t_{\mathrm{+}}" diff --git a/tests/unit/test_expression_tree/test_printing/test_sympy_overrides.py b/tests/unit/test_expression_tree/test_printing/test_sympy_overrides.py index 4b19c7d822..4ce073af4b 100644 --- a/tests/unit/test_expression_tree/test_printing/test_sympy_overrides.py +++ b/tests/unit/test_expression_tree/test_printing/test_sympy_overrides.py @@ -2,31 +2,17 @@ Tests for the sympy_overrides.py """ -from tests import TestCase -import unittest - -import pybamm from pybamm.expression_tree.printing.sympy_overrides import custom_print_func import sympy -class TestCustomPrint(TestCase): - def test_print_Derivative(self): +class TestCustomPrint: + def test_print_derivative(self): # Test force_partial der1 = sympy.Derivative("y", "x") der1.force_partial = True - self.assertEqual(custom_print_func(der1), "\\frac{\\partial}{\\partial x} y") + assert custom_print_func(der1) == "\\frac{\\partial}{\\partial x} y" # Test derivative der2 = sympy.Derivative("x") - self.assertEqual(custom_print_func(der2), "\\frac{d}{d x} x") - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() + assert custom_print_func(der2) == "\\frac{d}{d x} x" diff --git a/tests/unit/test_expression_tree/test_scalar.py b/tests/unit/test_expression_tree/test_scalar.py index 34ea1aa514..986d3d3ccb 100644 --- a/tests/unit/test_expression_tree/test_scalar.py +++ b/tests/unit/test_expression_tree/test_scalar.py @@ -1,64 +1,51 @@ # # Tests for the Scalar class # -from tests import TestCase -import unittest -import unittest.mock as mock import pybamm -class TestScalar(TestCase): +class TestScalar: def test_scalar_eval(self): a = pybamm.Scalar(5) - self.assertEqual(a.value, 5) - self.assertEqual(a.evaluate(), 5) + assert a.value == 5 + assert a.evaluate() == 5 def test_scalar_operations(self): a = pybamm.Scalar(5) b = pybamm.Scalar(6) - self.assertEqual((a + b).evaluate(), 11) - self.assertEqual((a - b).evaluate(), -1) - self.assertEqual((a * b).evaluate(), 30) - self.assertEqual((a / b).evaluate(), 5 / 6) + assert (a + b).evaluate() == 11 + assert (a - b).evaluate() == -1 + assert (a * b).evaluate() == 30 + assert (a / b).evaluate() == 5 / 6 def test_scalar_eq(self): a1 = pybamm.Scalar(4) a2 = pybamm.Scalar(4) - self.assertEqual(a1, a2) + assert a1 == a2 a3 = pybamm.Scalar(5) - self.assertNotEqual(a1, a3) + assert a1 != a3 def test_to_equation(self): a = pybamm.Scalar(3) b = pybamm.Scalar(4) # Test value - self.assertEqual(str(a.to_equation()), "3.0") + assert str(a.to_equation()) == "3.0" # Test print_name b.print_name = "test" - self.assertEqual(str(b.to_equation()), "test") + assert str(b.to_equation()) == "test" def test_copy(self): a = pybamm.Scalar(5) b = a.create_copy() - self.assertEqual(a, b) + assert a == b - def test_to_from_json(self): + def test_to_from_json(self, mocker): a = pybamm.Scalar(5) - json_dict = {"name": "5.0", "id": mock.ANY, "value": 5.0} + json_dict = {"name": "5.0", "id": mocker.ANY, "value": 5.0} - self.assertEqual(a.to_json(), json_dict) + assert a.to_json() == json_dict - self.assertEqual(pybamm.Scalar._from_json(json_dict), a) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() + assert pybamm.Scalar._from_json(json_dict) == a diff --git a/tests/unit/test_expression_tree/test_symbol.py b/tests/unit/test_expression_tree/test_symbol.py index 668c076907..e42f8dc8ef 100644 --- a/tests/unit/test_expression_tree/test_symbol.py +++ b/tests/unit/test_expression_tree/test_symbol.py @@ -18,6 +18,8 @@ class TestSymbol(TestCase): def test_symbol_init(self): sym = pybamm.Symbol("a symbol") + with self.assertRaises(TypeError): + sym.name = 1 self.assertEqual(sym.name, "a symbol") self.assertEqual(str(sym), "a symbol") diff --git a/tests/unit/test_expression_tree/test_vector.py b/tests/unit/test_expression_tree/test_vector.py index 34f817cf9c..e7b902fc73 100644 --- a/tests/unit/test_expression_tree/test_vector.py +++ b/tests/unit/test_expression_tree/test_vector.py @@ -1,22 +1,21 @@ # # Tests for the Vector class # -from tests import TestCase import pybamm import numpy as np -import unittest +import pytest -class TestVector(TestCase): - def setUp(self): +class TestVector: + def setup_method(self): self.x = np.array([[1], [2], [3]]) self.vect = pybamm.Vector(self.x) def test_array_wrapper(self): - self.assertEqual(self.vect.ndim, 2) - self.assertEqual(self.vect.shape, (3, 1)) - self.assertEqual(self.vect.size, 3) + assert self.vect.ndim == 2 + assert self.vect.shape == (3, 1) + assert self.vect.size == 3 def test_column_reshape(self): vect1d = pybamm.Vector(np.array([1, 2, 3])) @@ -39,17 +38,7 @@ def test_vector_operations(self): ) def test_wrong_size_entries(self): - with self.assertRaisesRegex( - ValueError, "Entries must have 1 dimension or be column vector" + with pytest.raises( + ValueError, match="Entries must have 1 dimension or be column vector" ): pybamm.Vector(np.ones((4, 5))) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_logger.py b/tests/unit/test_logger.py index 0897bc5835..06e2444c16 100644 --- a/tests/unit/test_logger.py +++ b/tests/unit/test_logger.py @@ -1,28 +1,27 @@ # # Tests the logger class. # -from tests import TestCase +import pytest import pybamm -import unittest -class TestLogger(TestCase): +class TestLogger: def test_logger(self): logger = pybamm.logger - self.assertEqual(logger.level, 30) + assert logger.level == 30 pybamm.set_logging_level("INFO") - self.assertEqual(logger.level, 20) + assert logger.level == 20 pybamm.set_logging_level("ERROR") - self.assertEqual(logger.level, 40) + assert logger.level == 40 pybamm.set_logging_level("VERBOSE") - self.assertEqual(logger.level, 15) + assert logger.level == 15 pybamm.set_logging_level("NOTICE") - self.assertEqual(logger.level, 25) + assert logger.level == 25 pybamm.set_logging_level("SUCCESS") - self.assertEqual(logger.level, 35) + assert logger.level == 35 pybamm.set_logging_level("SPAM") - self.assertEqual(logger.level, 5) + assert logger.level == 5 pybamm.logger.spam("Test spam level") pybamm.logger.verbose("Test verbose level") pybamm.logger.notice("Test notice level") @@ -32,15 +31,5 @@ def test_logger(self): pybamm.set_logging_level("WARNING") def test_exceptions(self): - with self.assertRaises(ValueError): + with pytest.raises(ValueError): pybamm.get_new_logger("test", None) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_meshes/test_zero_dimensional_submesh.py b/tests/unit/test_meshes/test_zero_dimensional_submesh.py index 8bc1bc2e75..d9e3ebb5dd 100644 --- a/tests/unit/test_meshes/test_zero_dimensional_submesh.py +++ b/tests/unit/test_meshes/test_zero_dimensional_submesh.py @@ -1,12 +1,11 @@ import pybamm -import unittest -from tests import TestCase +import pytest -class TestSubMesh0D(TestCase): +class TestSubMesh0D: def test_exceptions(self): position = {"x": {"position": 0}, "y": {"position": 0}} - with self.assertRaises(pybamm.GeometryError): + with pytest.raises(pybamm.GeometryError): pybamm.SubMesh0D(position) def test_init(self): @@ -14,13 +13,3 @@ def test_init(self): generator = pybamm.SubMesh0D mesh = generator(position, None) mesh.add_ghost_meshes() - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_models/test_event.py b/tests/unit/test_models/test_event.py index 84b0dcde84..0636a0f5bd 100644 --- a/tests/unit/test_models/test_event.py +++ b/tests/unit/test_models/test_event.py @@ -1,27 +1,25 @@ # # Tests Event class # -from tests import TestCase import pybamm import numpy as np -import unittest -class TestEvent(TestCase): +class TestEvent: def test_event(self): expression = pybamm.Scalar(1) event = pybamm.Event("my event", expression) - self.assertEqual(event.name, "my event") - self.assertEqual(event.__str__(), "my event") - self.assertEqual(event.expression, expression) - self.assertEqual(event.event_type, pybamm.EventType.TERMINATION) + assert event.name == "my event" + assert event.__str__() == "my event" + assert event.expression == expression + assert event.event_type == pybamm.EventType.TERMINATION def test_expression_evaluate(self): # Test t expression = pybamm.t event = pybamm.Event("my event", expression) - self.assertEqual(event.evaluate(t=1), 1) + assert event.evaluate(t=1) == 1 # Test y sv = pybamm.StateVector(slice(0, 10)) @@ -46,7 +44,7 @@ def test_event_types(self): for event_type in event_types: event = pybamm.Event("my event", pybamm.Scalar(1), event_type) - self.assertEqual(event.event_type, event_type) + assert event.event_type == event_type def test_to_from_json(self): expression = pybamm.Scalar(1) @@ -58,24 +56,14 @@ def test_to_from_json(self): } event_ser_json = event.to_json() - self.assertEqual(event_ser_json, event_json) + assert event_ser_json == event_json event_json["expression"] = expression new_event = pybamm.Event._from_json(event_json) # check for equal expressions - self.assertEqual(new_event.expression, event.expression) + assert new_event.expression == event.expression # check for equal event types - self.assertEqual(new_event.event_type, event.event_type) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() + assert new_event.event_type == event.event_type diff --git a/tests/unit/test_models/test_full_battery_models/test_lead_acid/test_base_lead_acid_model.py b/tests/unit/test_models/test_full_battery_models/test_lead_acid/test_base_lead_acid_model.py index ec280cdd1f..5d9ea27e2f 100644 --- a/tests/unit/test_models/test_full_battery_models/test_lead_acid/test_base_lead_acid_model.py +++ b/tests/unit/test_models/test_full_battery_models/test_lead_acid/test_base_lead_acid_model.py @@ -1,46 +1,33 @@ # # Tests for the base lead acid model class # -from tests import TestCase import pybamm -import unittest +import pytest -class TestBaseLeadAcidModel(TestCase): +class TestBaseLeadAcidModel: def test_default_geometry(self): model = pybamm.lead_acid.BaseModel({"dimensionality": 0}) - self.assertEqual( - model.default_geometry["current collector"]["z"]["position"], 1 - ) + assert model.default_geometry["current collector"]["z"]["position"] == 1 model = pybamm.lead_acid.BaseModel({"dimensionality": 1}) - self.assertEqual(model.default_geometry["current collector"]["z"]["min"], 0) + assert model.default_geometry["current collector"]["z"]["min"] == 0 model = pybamm.lead_acid.BaseModel({"dimensionality": 2}) - self.assertEqual(model.default_geometry["current collector"]["y"]["min"], 0) + assert model.default_geometry["current collector"]["y"]["min"] == 0 def test_incompatible_options(self): - with self.assertRaisesRegex( + with pytest.raises( pybamm.OptionError, - "Lead-acid models can only have thermal effects if dimensionality is 0.", + match="Lead-acid models can only have thermal effects if dimensionality is 0.", ): pybamm.lead_acid.BaseModel({"dimensionality": 1, "thermal": "lumped"}) - with self.assertRaisesRegex(pybamm.OptionError, "SEI"): + with pytest.raises(pybamm.OptionError, match="SEI"): pybamm.lead_acid.BaseModel({"SEI": "constant"}) - with self.assertRaisesRegex(pybamm.OptionError, "lithium plating"): + with pytest.raises(pybamm.OptionError, match="lithium plating"): pybamm.lead_acid.BaseModel({"lithium plating": "reversible"}) - with self.assertRaisesRegex(pybamm.OptionError, "MSMR"): + with pytest.raises(pybamm.OptionError, match="MSMR"): pybamm.lead_acid.BaseModel( { "open-circuit potential": "MSMR", "particle": "MSMR", } ) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_models/test_full_battery_models/test_lead_acid/test_basic_models.py b/tests/unit/test_models/test_full_battery_models/test_lead_acid/test_basic_models.py index 65b9f6bc9f..a7a708b394 100644 --- a/tests/unit/test_models/test_full_battery_models/test_lead_acid/test_basic_models.py +++ b/tests/unit/test_models/test_full_battery_models/test_lead_acid/test_basic_models.py @@ -1,22 +1,10 @@ # # Tests for the basic lead acid models # -from tests import TestCase import pybamm -import unittest -class TestBasicModels(TestCase): +class TestBasicModels: def test_basic_full_lead_acid_well_posed(self): model = pybamm.lead_acid.BasicFull() model.check_well_posedness() - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_models/test_full_battery_models/test_lead_acid/test_full.py b/tests/unit/test_models/test_full_battery_models/test_lead_acid/test_full.py index c07c5c84c6..569851ec2a 100644 --- a/tests/unit/test_models/test_full_battery_models/test_lead_acid/test_full.py +++ b/tests/unit/test_models/test_full_battery_models/test_lead_acid/test_full.py @@ -1,12 +1,10 @@ # # Tests for the lead-acid Full model # -from tests import TestCase import pybamm -import unittest -class TestLeadAcidFull(TestCase): +class TestLeadAcidFull: def test_well_posed(self): model = pybamm.lead_acid.Full() model.check_well_posedness() @@ -21,7 +19,7 @@ def test_well_posed_with_convection(self): model.check_well_posedness() -class TestLeadAcidFullSurfaceForm(TestCase): +class TestLeadAcidFullSurfaceForm: def test_well_posed_differential(self): options = {"surface form": "differential"} model = pybamm.lead_acid.Full(options) @@ -38,7 +36,7 @@ def test_well_posed_algebraic(self): model.check_well_posedness() -class TestLeadAcidFullSideReactions(TestCase): +class TestLeadAcidFullSideReactions: def test_well_posed(self): options = {"hydrolysis": "true"} model = pybamm.lead_acid.Full(options) @@ -48,20 +46,10 @@ def test_well_posed_surface_form_differential(self): options = {"hydrolysis": "true", "surface form": "differential"} model = pybamm.lead_acid.Full(options) model.check_well_posedness() - self.assertIsInstance(model.default_solver, pybamm.CasadiSolver) + assert isinstance(model.default_solver, pybamm.CasadiSolver) def test_well_posed_surface_form_algebraic(self): options = {"hydrolysis": "true", "surface form": "algebraic"} model = pybamm.lead_acid.Full(options) model.check_well_posedness() - self.assertIsInstance(model.default_solver, pybamm.CasadiSolver) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() + assert isinstance(model.default_solver, pybamm.CasadiSolver) diff --git a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/base_lithium_ion_tests.py b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/base_lithium_ion_tests.py index 7e1f2d5cac..c8a3f6b509 100644 --- a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/base_lithium_ion_tests.py +++ b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/base_lithium_ion_tests.py @@ -559,3 +559,28 @@ def test_well_posed_composite_diffusion_hysteresis(self): "open-circuit potential": (("current sigmoid", "single"), "single"), } self.check_well_posedness(options) + + def test_well_posed_composite_different_degradation(self): + # phases have same degradation + options = { + "particle phases": ("2", "1"), + "SEI": ("ec reaction limited", "none"), + "lithium plating": ("reversible", "none"), + "open-circuit potential": (("current sigmoid", "single"), "single"), + } + self.check_well_posedness(options) + # phases have different degradation + options = { + "particle phases": ("2", "1"), + "SEI": (("ec reaction limited", "solvent-diffusion limited"), "none"), + "lithium plating": (("reversible", "irreversible"), "none"), + "open-circuit potential": (("current sigmoid", "single"), "single"), + } + self.check_well_posedness(options) + # one of the phases has no degradation + options = { + "particle phases": ("2", "1"), + "SEI": (("none", "solvent-diffusion limited"), "none"), + "lithium plating": (("none", "irreversible"), "none"), + } + self.check_well_posedness(options) diff --git a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_Yang2017.py b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_Yang2017.py index 9631cf9f82..2fd18c17c6 100644 --- a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_Yang2017.py +++ b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_Yang2017.py @@ -1,22 +1,10 @@ # # Tests for the lithium-ion DFN model # -from tests import TestCase import pybamm -import unittest -class TestYang2017(TestCase): +class TestYang2017: def test_well_posed(self): model = pybamm.lithium_ion.Yang2017() model.check_well_posedness() - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_base_lithium_ion_model.py b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_base_lithium_ion_model.py index fbc916d4a5..bfeb489661 100644 --- a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_base_lithium_ion_model.py +++ b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_base_lithium_ion_model.py @@ -1,59 +1,44 @@ # # Tests for the base lead acid model class # -from tests import TestCase import pybamm -import unittest import os +import pytest -class TestBaseLithiumIonModel(TestCase): +class TestBaseLithiumIonModel: def test_incompatible_options(self): - with self.assertRaisesRegex(pybamm.OptionError, "convection not implemented"): + with pytest.raises(pybamm.OptionError, match="convection not implemented"): pybamm.lithium_ion.BaseModel({"convection": "uniform transverse"}) def test_default_parameters(self): # check parameters are read in ok model = pybamm.lithium_ion.BaseModel() - self.assertEqual( - model.default_parameter_values["Reference temperature [K]"], 298.15 - ) + assert model.default_parameter_values["Reference temperature [K]"] == 298.15 # change path and try again cwd = os.getcwd() os.chdir("..") model = pybamm.lithium_ion.BaseModel() - self.assertEqual( - model.default_parameter_values["Reference temperature [K]"], 298.15 - ) + assert model.default_parameter_values["Reference temperature [K]"] == 298.15 os.chdir(cwd) def test_insert_reference_electrode(self): model = pybamm.lithium_ion.SPM() model.insert_reference_electrode() - self.assertIn("Negative electrode 3E potential [V]", model.variables) - self.assertIn("Positive electrode 3E potential [V]", model.variables) - self.assertIn("Reference electrode potential [V]", model.variables) + assert "Negative electrode 3E potential [V]" in model.variables + assert "Positive electrode 3E potential [V]" in model.variables + assert "Reference electrode potential [V]" in model.variables model = pybamm.lithium_ion.SPM({"working electrode": "positive"}) model.insert_reference_electrode() - self.assertNotIn("Negative electrode potential [V]", model.variables) - self.assertIn("Positive electrode 3E potential [V]", model.variables) - self.assertIn("Reference electrode potential [V]", model.variables) + assert "Negative electrode potential [V]" not in model.variables + assert "Positive electrode 3E potential [V]" in model.variables + assert "Reference electrode potential [V]" in model.variables model = pybamm.lithium_ion.SPM({"dimensionality": 2}) - with self.assertRaisesRegex( - NotImplementedError, "Reference electrode can only be" + with pytest.raises( + NotImplementedError, match="Reference electrode can only be" ): model.insert_reference_electrode() - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_basic_models.py b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_basic_models.py index 2f00bb260c..8462e7c803 100644 --- a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_basic_models.py +++ b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_basic_models.py @@ -1,12 +1,10 @@ # # Tests for the basic lithium-ion models # -from tests import TestCase import pybamm -import unittest -class TestBasicModels(TestCase): +class TestBasicModels: def test_dfn_well_posed(self): model = pybamm.lithium_ion.BasicDFN() model.check_well_posedness() @@ -23,13 +21,3 @@ def test_dfn_half_cell_well_posed(self): def test_dfn_composite_well_posed(self): model = pybamm.lithium_ion.BasicDFNComposite() model.check_well_posedness() - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_dfn.py b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_dfn.py index 20fc69e541..cddd59c352 100644 --- a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_dfn.py +++ b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_dfn.py @@ -1,19 +1,19 @@ # # Tests for the lithium-ion DFN model # -from tests import TestCase import pybamm -import unittest +import pytest from tests import BaseUnitTestLithiumIon -class TestDFN(BaseUnitTestLithiumIon, TestCase): +class TestDFN(BaseUnitTestLithiumIon): + @pytest.fixture(autouse=True) def setUp(self): self.model = pybamm.lithium_ion.DFN def test_electrolyte_options(self): options = {"electrolyte conductivity": "integrated"} - with self.assertRaisesRegex(pybamm.OptionError, "electrolyte conductivity"): + with pytest.raises(pybamm.OptionError, match="electrolyte conductivity"): pybamm.lithium_ion.DFN(options) def test_well_posed_size_distribution(self): @@ -66,13 +66,3 @@ def test_well_posed_msmr_with_psd(self): "intercalation kinetics": "MSMR", } self.check_well_posedness(options) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_dfn_half_cell.py b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_dfn_half_cell.py index 78d9ebda94..389fcf9429 100644 --- a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_dfn_half_cell.py +++ b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_dfn_half_cell.py @@ -1,22 +1,13 @@ # # Tests for the lithium-ion half-cell DFN model # -from tests import TestCase + import pybamm -import unittest from tests import BaseUnitTestLithiumIonHalfCell +import pytest -class TestDFNHalfCell(BaseUnitTestLithiumIonHalfCell, TestCase): +class TestDFNHalfCell(BaseUnitTestLithiumIonHalfCell): + @pytest.fixture(autouse=True) def setUp(self): self.model = pybamm.lithium_ion.DFN - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_mpm_half_cell.py b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_mpm_half_cell.py index 77d51f6cf7..e5637968c3 100644 --- a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_mpm_half_cell.py +++ b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_mpm_half_cell.py @@ -1,12 +1,10 @@ # # Tests for the lithium-ion MPM model # -from tests import TestCase import pybamm -import unittest -class TestMPM(TestCase): +class TestMPM: def test_well_posed(self): options = {"thermal": "isothermal", "working electrode": "positive"} model = pybamm.lithium_ion.MPM(options) @@ -20,9 +18,9 @@ def test_well_posed(self): def test_default_parameter_values(self): # check default parameters are added correctly model = pybamm.lithium_ion.MPM({"working electrode": "positive"}) - self.assertEqual( - model.default_parameter_values["Positive minimum particle radius [m]"], - 0.0, + assert ( + model.default_parameter_values["Positive minimum particle radius [m]"] + == 0.0 ) def test_lumped_thermal_model_1D(self): @@ -44,7 +42,7 @@ def test_differential_surface_form(self): model.check_well_posedness() -class TestMPMExternalCircuits(TestCase): +class TestMPMExternalCircuits: def test_well_posed_voltage(self): options = {"operating mode": "voltage", "working electrode": "positive"} model = pybamm.lithium_ion.MPM(options) @@ -67,13 +65,3 @@ def external_circuit_function(variables): } model = pybamm.lithium_ion.MPM(options) model.check_well_posedness() - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_msmr.py b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_msmr.py index 96369fbac2..4f1958d095 100644 --- a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_msmr.py +++ b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_msmr.py @@ -1,22 +1,10 @@ # # Tests for the lithium-ion MSMR model # -from tests import TestCase import pybamm -import unittest -class TestMSMR(TestCase): +class TestMSMR: def test_well_posed(self): model = pybamm.lithium_ion.MSMR({"number of MSMR reactions": ("6", "4")}) model.check_well_posedness() - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_newman_tobias.py b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_newman_tobias.py index 5369d94b29..c979474e13 100644 --- a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_newman_tobias.py +++ b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_newman_tobias.py @@ -1,42 +1,41 @@ # # Tests for the lithium-ion Newman-Tobias model # -from tests import TestCase import pybamm -import unittest +import pytest from tests import BaseUnitTestLithiumIon -class TestNewmanTobias(BaseUnitTestLithiumIon, TestCase): +class TestNewmanTobias(BaseUnitTestLithiumIon): + @pytest.fixture(autouse=True) def setUp(self): self.model = pybamm.lithium_ion.NewmanTobias def test_electrolyte_options(self): options = {"electrolyte conductivity": "integrated"} - with self.assertRaisesRegex(pybamm.OptionError, "electrolyte conductivity"): + with pytest.raises(pybamm.OptionError, match="electrolyte conductivity"): pybamm.lithium_ion.NewmanTobias(options) + @pytest.mark.skip(reason="Test currently not implemented") def test_well_posed_particle_phases(self): pass # skip this test + @pytest.mark.skip(reason="Test currently not implemented") def test_well_posed_particle_phases_thermal(self): pass # Skip this test + @pytest.mark.skip(reason="Test currently not implemented") def test_well_posed_particle_phases_sei(self): pass # skip this test + @pytest.mark.skip(reason="Test currently not implemented") def test_well_posed_composite_kinetic_hysteresis(self): pass # skip this test + @pytest.mark.skip(reason="Test currently not implemented") def test_well_posed_composite_diffusion_hysteresis(self): pass # skip this test - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() + @pytest.mark.skip(reason="Test currently not implemented") + def test_well_posed_composite_different_degradation(self): + pass # skip this test diff --git a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_spm.py b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_spm.py index 45cf00877b..99affc7ddd 100644 --- a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_spm.py +++ b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_spm.py @@ -1,19 +1,19 @@ # # Tests for the lithium-ion SPM model # -from tests import TestCase import pybamm -import unittest from tests import BaseUnitTestLithiumIon +import pytest -class TestSPM(BaseUnitTestLithiumIon, TestCase): +class TestSPM(BaseUnitTestLithiumIon): + @pytest.fixture(autouse=True) def setUp(self): self.model = pybamm.lithium_ion.SPM def test_electrolyte_options(self): options = {"electrolyte conductivity": "full"} - with self.assertRaisesRegex(pybamm.OptionError, "electrolyte conductivity"): + with pytest.raises(pybamm.OptionError, match="electrolyte conductivity"): pybamm.lithium_ion.SPM(options) def test_kinetics_options(self): @@ -21,7 +21,7 @@ def test_kinetics_options(self): "surface form": "false", "intercalation kinetics": "Marcus-Hush-Chidsey", } - with self.assertRaisesRegex(pybamm.OptionError, "Inverse kinetics"): + with pytest.raises(pybamm.OptionError, match="Inverse kinetics"): pybamm.lithium_ion.SPM(options) def test_x_average_options(self): @@ -37,11 +37,11 @@ def test_x_average_options(self): # Check model with distributed side reactions throws an error options["x-average side reactions"] = "false" - with self.assertRaisesRegex(pybamm.OptionError, "cannot be 'false' for SPM"): + with pytest.raises(pybamm.OptionError, match="cannot be 'false' for SPM"): pybamm.lithium_ion.SPM(options) def test_distribution_options(self): - with self.assertRaisesRegex(pybamm.OptionError, "surface form"): + with pytest.raises(pybamm.OptionError, match="surface form"): pybamm.lithium_ion.SPM({"particle size": "distribution"}) def test_particle_size_distribution(self): @@ -53,10 +53,10 @@ def test_new_model(self): new_model = model.new_copy() model_T_eqn = model.rhs[model.variables["Cell temperature [K]"]] new_model_T_eqn = new_model.rhs[new_model.variables["Cell temperature [K]"]] - self.assertEqual(new_model_T_eqn, model_T_eqn) - self.assertEqual(new_model.name, model.name) - self.assertEqual(new_model.use_jacobian, model.use_jacobian) - self.assertEqual(new_model.convert_to_format, model.convert_to_format) + assert new_model_T_eqn == model_T_eqn + assert new_model.name == model.name + assert new_model.use_jacobian == model.use_jacobian + assert new_model.convert_to_format == model.convert_to_format # with custom submodels options = {"stress-induced diffusion": "false", "thermal": "x-full"} @@ -72,14 +72,4 @@ def test_new_model(self): new_model = model.new_copy() new_model_cs_eqn = list(new_model.rhs.values())[1] model_cs_eqn = list(model.rhs.values())[1] - self.assertEqual(new_model_cs_eqn, model_cs_eqn) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() + assert new_model_cs_eqn == model_cs_eqn diff --git a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_spm_half_cell.py b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_spm_half_cell.py index 0d6ba93ce0..c1b6b34745 100644 --- a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_spm_half_cell.py +++ b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_spm_half_cell.py @@ -1,22 +1,12 @@ # # Tests for the lithium-ion half-cell SPM model # -from tests import TestCase import pybamm -import unittest from tests import BaseUnitTestLithiumIonHalfCell +import pytest -class TestSPMHalfCell(BaseUnitTestLithiumIonHalfCell, TestCase): +class TestSPMHalfCell(BaseUnitTestLithiumIonHalfCell): + @pytest.fixture(autouse=True) def setUp(self): self.model = pybamm.lithium_ion.SPM - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_spme.py b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_spme.py index 72222ee060..b0d38fa9c7 100644 --- a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_spme.py +++ b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_spme.py @@ -1,13 +1,13 @@ # # Tests for the lithium-ion SPMe model # -from tests import TestCase import pybamm -import unittest from tests import BaseUnitTestLithiumIon +import pytest -class TestSPMe(BaseUnitTestLithiumIon, TestCase): +class TestSPMe(BaseUnitTestLithiumIon): + @pytest.fixture(autouse=True) def setUp(self): self.model = pybamm.lithium_ion.SPMe @@ -31,19 +31,9 @@ def setUp(self): def test_electrolyte_options(self): options = {"electrolyte conductivity": "full"} - with self.assertRaisesRegex(pybamm.OptionError, "electrolyte conductivity"): + with pytest.raises(pybamm.OptionError, match="electrolyte conductivity"): pybamm.lithium_ion.SPMe(options) def test_integrated_conductivity(self): options = {"electrolyte conductivity": "integrated"} self.check_well_posedness(options) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_spme_half_cell.py b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_spme_half_cell.py index f1930df026..2a814c113e 100644 --- a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_spme_half_cell.py +++ b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_spme_half_cell.py @@ -3,21 +3,11 @@ # This is achieved by using the {"working electrode": "positive"} option # import pybamm -import unittest -from tests import TestCase from tests import BaseUnitTestLithiumIonHalfCell +import pytest -class TestSPMeHalfCell(BaseUnitTestLithiumIonHalfCell, TestCase): +class TestSPMeHalfCell(BaseUnitTestLithiumIonHalfCell): + @pytest.fixture(autouse=True) def setUp(self): self.model = pybamm.lithium_ion.SPMe - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_models/test_model_info.py b/tests/unit/test_models/test_model_info.py index 144d763bf1..b754399872 100644 --- a/tests/unit/test_models/test_model_info.py +++ b/tests/unit/test_models/test_model_info.py @@ -1,12 +1,10 @@ # # Tests getting model info # -from tests import TestCase import pybamm -import unittest -class TestModelInfo(TestCase): +class TestModelInfo: def test_find_parameter_info(self): model = pybamm.lithium_ion.SPM() model.info("Negative particle diffusivity [m2.s-1]") @@ -16,13 +14,3 @@ def test_find_parameter_info(self): model.info("Negative particle diffusivity [m2.s-1]") model.info("Not a parameter") - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_models/test_submodels/test_base_submodel.py b/tests/unit/test_models/test_submodels/test_base_submodel.py index 1519a2fea2..9f2a9c3549 100644 --- a/tests/unit/test_models/test_submodels/test_base_submodel.py +++ b/tests/unit/test_models/test_submodels/test_base_submodel.py @@ -1,52 +1,50 @@ # # Test base submodel # -from tests import TestCase - +import pytest import pybamm -import unittest -class TestBaseSubModel(TestCase): +class TestBaseSubModel: def test_domain(self): # Accepted string submodel = pybamm.BaseSubModel(None, "negative", phase="primary") - self.assertEqual(submodel.domain, "negative") + assert submodel.domain == "negative" # None submodel = pybamm.BaseSubModel(None, None) - self.assertEqual(submodel.domain, None) + assert submodel.domain is None # bad string - with self.assertRaises(pybamm.DomainError): + with pytest.raises(pybamm.DomainError): pybamm.BaseSubModel(None, "bad string") def test_phase(self): # Without domain submodel = pybamm.BaseSubModel(None, None) - self.assertEqual(submodel.phase, None) + assert submodel.phase is None - with self.assertRaisesRegex(ValueError, "Phase must be None"): + with pytest.raises(ValueError, match="Phase must be None"): pybamm.BaseSubModel(None, None, phase="primary") # With domain submodel = pybamm.BaseSubModel(None, "negative", phase="primary") - self.assertEqual(submodel.phase, "primary") - self.assertEqual(submodel.phase_name, "") + assert submodel.phase == "primary" + assert submodel.phase_name == "" submodel = pybamm.BaseSubModel( None, "negative", options={"particle phases": "2"}, phase="secondary" ) - self.assertEqual(submodel.phase, "secondary") - self.assertEqual(submodel.phase_name, "secondary ") + assert submodel.phase == "secondary" + assert submodel.phase_name == "secondary " - with self.assertRaisesRegex(ValueError, "Phase must be 'primary'"): + with pytest.raises(ValueError, match="Phase must be 'primary'"): pybamm.BaseSubModel(None, "negative", phase="secondary") - with self.assertRaisesRegex(ValueError, "Phase must be either 'primary'"): + with pytest.raises(ValueError, match="Phase must be either 'primary'"): pybamm.BaseSubModel( None, "negative", options={"particle phases": "2"}, phase="tertiary" ) - with self.assertRaisesRegex(ValueError, "Phase must be 'primary'"): + with pytest.raises(ValueError, match="Phase must be 'primary'"): # 2 phases in the negative but only 1 in the positive pybamm.BaseSubModel( None, @@ -54,13 +52,3 @@ def test_phase(self): options={"particle phases": ("2", "1")}, phase="secondary", ) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_models/test_submodels/test_effective_current_collector.py b/tests/unit/test_models/test_submodels/test_effective_current_collector.py index cbab3134d4..b2437ec1d9 100644 --- a/tests/unit/test_models/test_submodels/test_effective_current_collector.py +++ b/tests/unit/test_models/test_submodels/test_effective_current_collector.py @@ -1,13 +1,12 @@ # # Tests for the Effective Current Collector Resistance models # -from tests import TestCase +import pytest import pybamm -import unittest import numpy as np -class TestEffectiveResistance(TestCase): +class TestEffectiveResistance: def test_well_posed(self): model = pybamm.current_collector.EffectiveResistance({"dimensionality": 1}) model.check_well_posedness() @@ -17,36 +16,34 @@ def test_well_posed(self): def test_default_parameters(self): model = pybamm.current_collector.EffectiveResistance({"dimensionality": 1}) - self.assertEqual( - model.default_parameter_values, pybamm.ParameterValues("Marquis2019") - ) + assert model.default_parameter_values == pybamm.ParameterValues("Marquis2019") def test_default_geometry(self): model = pybamm.current_collector.EffectiveResistance({"dimensionality": 1}) - self.assertTrue("current collector" in model.default_geometry) - self.assertNotIn("negative electrode", model.default_geometry) + assert "current collector" in model.default_geometry + assert "negative electrode" not in model.default_geometry model = pybamm.current_collector.EffectiveResistance({"dimensionality": 2}) - self.assertTrue("current collector" in model.default_geometry) - self.assertNotIn("negative electrode", model.default_geometry) + assert "current collector" in model.default_geometry + assert "negative electrode" not in model.default_geometry def test_default_var_pts(self): model = pybamm.current_collector.EffectiveResistance({"dimensionality": 1}) - self.assertEqual(model.default_var_pts, {"y": 32, "z": 32}) + assert model.default_var_pts == {"y": 32, "z": 32} def test_default_solver(self): model = pybamm.current_collector.EffectiveResistance({"dimensionality": 1}) - self.assertIsInstance(model.default_solver, pybamm.CasadiAlgebraicSolver) + assert isinstance(model.default_solver, pybamm.CasadiAlgebraicSolver) model = pybamm.current_collector.EffectiveResistance({"dimensionality": 2}) - self.assertIsInstance(model.default_solver, pybamm.CasadiAlgebraicSolver) + assert isinstance(model.default_solver, pybamm.CasadiAlgebraicSolver) def test_bad_option(self): - with self.assertRaisesRegex(pybamm.OptionError, "Dimension of"): + with pytest.raises(pybamm.OptionError, match="Dimension of"): pybamm.current_collector.EffectiveResistance({"dimensionality": 10}) -class TestEffectiveResistancePostProcess(TestCase): +class TestEffectiveResistancePostProcess: def test_get_processed_variables(self): # solve cheap SPM to test post-processing (think of an alternative test?) models = [ @@ -87,13 +84,3 @@ def test_get_processed_variables(self): processed_var(t=solution_1D.t[5], z=pts) else: processed_var(t=solution_1D.t[5], y=pts, z=pts) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_models/test_submodels/test_particle_polynomial_profile.py b/tests/unit/test_models/test_submodels/test_particle_polynomial_profile.py index 787230d9f3..57f1436f2d 100644 --- a/tests/unit/test_models/test_submodels/test_particle_polynomial_profile.py +++ b/tests/unit/test_models/test_submodels/test_particle_polynomial_profile.py @@ -1,22 +1,11 @@ # # Tests for the polynomial profile submodel # -from tests import TestCase import pybamm -import unittest +import pytest -class TestParticlePolynomialProfile(TestCase): +class TestParticlePolynomialProfile: def test_errors(self): - with self.assertRaisesRegex(ValueError, "Particle type must be"): + with pytest.raises(ValueError, match="Particle type must be"): pybamm.particle.PolynomialProfile(None, "negative", {}) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_parameters/test_base_parameters.py b/tests/unit/test_parameters/test_base_parameters.py index 6c87cdcd88..2c48074a71 100644 --- a/tests/unit/test_parameters/test_base_parameters.py +++ b/tests/unit/test_parameters/test_base_parameters.py @@ -2,48 +2,37 @@ Tests for the base_parameters.py """ -from tests import TestCase import pybamm -import unittest +import pytest -class TestBaseParameters(TestCase): +class TestBaseParameters: def test_getattr__(self): param = pybamm.LithiumIonParameters() # ending in _n / _s / _p - with self.assertRaisesRegex(AttributeError, "param.n.L"): + with pytest.raises(AttributeError, match="param.n.L"): param.L_n - with self.assertRaisesRegex(AttributeError, "param.s.L"): + with pytest.raises(AttributeError, match="param.s.L"): param.L_s - with self.assertRaisesRegex(AttributeError, "param.p.L"): + with pytest.raises(AttributeError, match="param.p.L"): param.L_p # _n_ in the name - with self.assertRaisesRegex(AttributeError, "param.n.prim.c_max"): + with pytest.raises(AttributeError, match="param.n.prim.c_max"): param.c_n_max # _n_ or _p_ not in name - with self.assertRaisesRegex( - AttributeError, "has no attribute 'c_n_not_a_parameter" + with pytest.raises( + AttributeError, match="has no attribute 'c_n_not_a_parameter" ): param.c_n_not_a_parameter - with self.assertRaisesRegex(AttributeError, "has no attribute 'c_s_test"): + with pytest.raises(AttributeError, match="has no attribute 'c_s_test"): pybamm.electrical_parameters.c_s_test - self.assertEqual(param.n.cap_init, param.n.Q_init) - self.assertEqual(param.p.prim.cap_init, param.p.prim.Q_init) + assert param.n.cap_init == param.n.Q_init + assert param.p.prim.cap_init == param.p.prim.Q_init def test__setattr__(self): # domain gets added as a subscript param = pybamm.GeometricParameters() - self.assertEqual(param.n.L.print_name, r"L_{\mathrm{n}}") - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() + assert param.n.L.print_name == r"L_{\mathrm{n}}" diff --git a/tests/unit/test_parameters/test_electrical_parameters.py b/tests/unit/test_parameters/test_electrical_parameters.py index 92bceaf632..7601c30721 100644 --- a/tests/unit/test_parameters/test_electrical_parameters.py +++ b/tests/unit/test_parameters/test_electrical_parameters.py @@ -1,13 +1,11 @@ # # Tests for the electrical parameters # -from tests import TestCase +import pytest import pybamm -import unittest - -class TestElectricalParameters(TestCase): +class TestElectricalParameters: def test_current_functions(self): # create current functions param = pybamm.electrical_parameters @@ -27,17 +25,7 @@ def test_current_functions(self): current_density_eval = parameter_values.process_symbol(current_density) # check current - self.assertEqual(current_eval.evaluate(t=3), 2) + assert current_eval.evaluate(t=3) == 2 # check current density - self.assertAlmostEqual(current_density_eval.evaluate(t=3), 2 / (8 * 0.1 * 0.1)) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() + assert current_density_eval.evaluate(t=3) == pytest.approx(2 / (8 * 0.1 * 0.1)) diff --git a/tests/unit/test_parameters/test_geometric_parameters.py b/tests/unit/test_parameters/test_geometric_parameters.py index 6e59259a12..7e000bf645 100644 --- a/tests/unit/test_parameters/test_geometric_parameters.py +++ b/tests/unit/test_parameters/test_geometric_parameters.py @@ -1,12 +1,10 @@ # # Tests for the standard parameters # -from tests import TestCase import pybamm -import unittest -class TestGeometricParameters(TestCase): +class TestGeometricParameters: def test_macroscale_parameters(self): geo = pybamm.geometric_parameters L_n = geo.n.L @@ -26,16 +24,4 @@ def test_macroscale_parameters(self): L_p_eval = parameter_values.process_symbol(L_p) L_x_eval = parameter_values.process_symbol(L_x) - self.assertEqual( - (L_n_eval + L_s_eval + L_p_eval).evaluate(), L_x_eval.evaluate() - ) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() + assert (L_n_eval + L_s_eval + L_p_eval).evaluate() == L_x_eval.evaluate() diff --git a/tests/unit/test_parameters/test_parameter_sets/test_Ai2020.py b/tests/unit/test_parameters/test_parameter_sets/test_Ai2020.py index 8816551ab6..f7302330bf 100644 --- a/tests/unit/test_parameters/test_parameter_sets/test_Ai2020.py +++ b/tests/unit/test_parameters/test_parameter_sets/test_Ai2020.py @@ -1,12 +1,11 @@ # # Tests for Ai (2020) Enertech parameter set loads # -from tests import TestCase +import pytest import pybamm -import unittest -class TestAi2020(TestCase): +class TestAi2020: def test_functions(self): param = pybamm.ParameterValues("Ai2020") sto = pybamm.Scalar(0.5) @@ -42,16 +41,6 @@ def test_functions(self): } for name, value in fun_test.items(): - self.assertAlmostEqual( - param.evaluate(param[name](*value[0])), value[1], places=4 + assert param.evaluate(param[name](*value[0])) == pytest.approx( + value[1], abs=0.0001 ) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_parameters/test_parameter_sets/test_Ecker2015_graphite_halfcell.py b/tests/unit/test_parameters/test_parameter_sets/test_Ecker2015_graphite_halfcell.py index f435ef6d36..6000b997b7 100644 --- a/tests/unit/test_parameters/test_parameter_sets/test_Ecker2015_graphite_halfcell.py +++ b/tests/unit/test_parameters/test_parameter_sets/test_Ecker2015_graphite_halfcell.py @@ -1,12 +1,11 @@ # # Tests for O'Kane (2022) parameter set # -from tests import TestCase +import pytest import pybamm -import unittest -class TestEcker2015_graphite_halfcell(TestCase): +class TestEcker2015_graphite_halfcell: def test_functions(self): param = pybamm.ParameterValues("Ecker2015_graphite_halfcell") sto = pybamm.Scalar(0.5) @@ -33,16 +32,6 @@ def test_functions(self): } for name, value in fun_test.items(): - self.assertAlmostEqual( - param.evaluate(param[name](*value[0])), value[1], places=4 + assert param.evaluate(param[name](*value[0])) == pytest.approx( + value[1], abs=0.0001 ) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_parameters/test_parameter_sets/test_LCO_Ramadass2004.py b/tests/unit/test_parameters/test_parameter_sets/test_LCO_Ramadass2004.py index 2de67b9e62..e6c4b04fdf 100644 --- a/tests/unit/test_parameters/test_parameter_sets/test_LCO_Ramadass2004.py +++ b/tests/unit/test_parameters/test_parameter_sets/test_LCO_Ramadass2004.py @@ -1,12 +1,11 @@ # # Tests for Ai (2020) Enertech parameter set loads # -from tests import TestCase +import pytest import pybamm -import unittest -class TestRamadass2004(TestCase): +class TestRamadass2004: def test_functions(self): param = pybamm.ParameterValues("Ramadass2004") sto = pybamm.Scalar(0.5) @@ -40,16 +39,6 @@ def test_functions(self): } for name, value in fun_test.items(): - self.assertAlmostEqual( - param.evaluate(param[name](*value[0])), value[1], places=4 + assert param.evaluate(param[name](*value[0])) == pytest.approx( + value[1], abs=0.0001 ) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_parameters/test_parameter_sets/test_LGM50_ORegan2022.py b/tests/unit/test_parameters/test_parameter_sets/test_LGM50_ORegan2022.py index f878b7d790..05a38b6245 100644 --- a/tests/unit/test_parameters/test_parameter_sets/test_LGM50_ORegan2022.py +++ b/tests/unit/test_parameters/test_parameter_sets/test_LGM50_ORegan2022.py @@ -1,12 +1,11 @@ # # Tests for LG M50 parameter set loads # -from tests import TestCase +import pytest import pybamm -import unittest -class TestORegan2022(TestCase): +class TestORegan2022: def test_functions(self): param = pybamm.ParameterValues("ORegan2022") T = pybamm.Scalar(298.15) @@ -68,16 +67,6 @@ def test_functions(self): } for name, value in fun_test.items(): - self.assertAlmostEqual( - param.evaluate(param[name](*value[0])), value[1], places=4 + assert param.evaluate(param[name](*value[0])) == pytest.approx( + value[1], abs=0.0001 ) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_parameters/test_parameter_sets/test_OKane2022_negative_halfcell.py b/tests/unit/test_parameters/test_parameter_sets/test_OKane2022_negative_halfcell.py index 04a19e1002..bf39457dc4 100644 --- a/tests/unit/test_parameters/test_parameter_sets/test_OKane2022_negative_halfcell.py +++ b/tests/unit/test_parameters/test_parameter_sets/test_OKane2022_negative_halfcell.py @@ -1,12 +1,11 @@ # # Tests for O'Kane (2022) parameter set # -from tests import TestCase +import pytest import pybamm -import unittest -class TestOKane2022_graphite_SiOx_halfcell(TestCase): +class TestOKane2022_graphite_SiOx_halfcell: def test_functions(self): param = pybamm.ParameterValues("OKane2022_graphite_SiOx_halfcell") sto = pybamm.Scalar(0.9) @@ -31,16 +30,6 @@ def test_functions(self): } for name, value in fun_test.items(): - self.assertAlmostEqual( - param.evaluate(param[name](*value[0])), value[1], places=4 + assert param.evaluate(param[name](*value[0])) == pytest.approx( + value[1], abs=0.0001 ) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_parameters/test_parameter_sets_class.py b/tests/unit/test_parameters/test_parameter_sets_class.py index b14000f987..342cf127aa 100644 --- a/tests/unit/test_parameters/test_parameter_sets_class.py +++ b/tests/unit/test_parameters/test_parameter_sets_class.py @@ -1,23 +1,22 @@ # # Tests for the ParameterSets class # -from tests import TestCase - +import pytest +import re import pybamm -import unittest -class TestParameterSets(TestCase): +class TestParameterSets: def test_name_interface(self): """Test that pybamm.parameters_sets. returns the name of the parameter set and a depreciation warning """ - with self.assertWarns(DeprecationWarning): + with pytest.warns(DeprecationWarning): out = pybamm.parameter_sets.Marquis2019 - self.assertEqual(out, "Marquis2019") + assert out == "Marquis2019" - # Expect error for parameter set's that aren't real - with self.assertRaises(AttributeError): + # Expect an error for parameter sets that aren't real + with pytest.raises(AttributeError): pybamm.parameter_sets.not_a_real_parameter_set def test_all_registered(self): @@ -26,26 +25,15 @@ def test_all_registered(self): known_entry_points = set( ep.name for ep in pybamm.parameter_sets.get_entries("pybamm_parameter_sets") ) - self.assertEqual(set(pybamm.parameter_sets.keys()), known_entry_points) - self.assertEqual(len(known_entry_points), len(pybamm.parameter_sets)) + assert set(pybamm.parameter_sets.keys()) == known_entry_points + assert len(known_entry_points) == len(pybamm.parameter_sets) def test_get_docstring(self): """Test that :meth:`pybamm.parameter_sets.get_doctstring` works""" docstring = pybamm.parameter_sets.get_docstring("Marquis2019") - self.assertRegex(docstring, "Parameters for a Kokam SLPB78205130H cell") + assert re.search("Parameters for a Kokam SLPB78205130H cell", docstring) def test_iter(self): """Test that iterating `pybamm.parameter_sets` iterates over keys""" for k in pybamm.parameter_sets: - self.assertIsInstance(k, str) - self.assertIn(k, pybamm.parameter_sets) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() + assert isinstance(k, str) diff --git a/tests/unit/test_parameters/test_size_distribution_parameters.py b/tests/unit/test_parameters/test_size_distribution_parameters.py index 5deeaa62be..414b422055 100644 --- a/tests/unit/test_parameters/test_size_distribution_parameters.py +++ b/tests/unit/test_parameters/test_size_distribution_parameters.py @@ -2,13 +2,12 @@ # Tests particle size distribution parameters are loaded into a parameter set # and give expected values # +import pytest import pybamm -import unittest import numpy as np -from tests import TestCase -class TestSizeDistributionParameters(TestCase): +class TestSizeDistributionParameters: def test_parameter_values(self): values = pybamm.lithium_ion.BaseModel().default_parameter_values param = pybamm.LithiumIonParameters() @@ -20,7 +19,7 @@ def test_parameter_values(self): ) # check negative parameters aren't there yet - with self.assertRaises(KeyError): + with pytest.raises(KeyError): values["Negative maximum particle radius [m]"] # now add distribution parameter values for negative electrode @@ -41,13 +40,3 @@ def test_parameter_values(self): R_test = pybamm.Scalar(1.0) values.evaluate(param.n.prim.f_a_dist(R_test)) values.evaluate(param.p.prim.f_a_dist(R_test)) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_plotting/test_plot.py b/tests/unit/test_plotting/test_plot.py index f36e20cd6f..1c049269c3 100644 --- a/tests/unit/test_plotting/test_plot.py +++ b/tests/unit/test_plotting/test_plot.py @@ -1,14 +1,13 @@ import pybamm -import unittest +import pytest import numpy as np -from tests import TestCase import matplotlib.pyplot as plt from matplotlib import use use("Agg") -class TestPlot(TestCase): +class TestPlot: def test_plot(self): x = pybamm.Array(np.array([0, 3, 10])) y = pybamm.Array(np.array([6, 16, 78])) @@ -16,13 +15,13 @@ def test_plot(self): _, ax = plt.subplots() ax_out = pybamm.plot(x, y, ax=ax, show_plot=False) - self.assertEqual(ax_out, ax) + assert ax_out == ax def test_plot_fail(self): x = pybamm.Array(np.array([0])) - with self.assertRaisesRegex(TypeError, "x must be 'pybamm.Array'"): + with pytest.raises(TypeError, match="x must be 'pybamm.Array'"): pybamm.plot("bad", x) - with self.assertRaisesRegex(TypeError, "y must be 'pybamm.Array'"): + with pytest.raises(TypeError, match="y must be 'pybamm.Array'"): pybamm.plot(x, "bad") def test_plot2D(self): @@ -38,23 +37,13 @@ def test_plot2D(self): _, ax = plt.subplots() ax_out = pybamm.plot2D(X, Y, Y, ax=ax, show_plot=False) - self.assertEqual(ax_out, ax) + assert ax_out == ax def test_plot2D_fail(self): x = pybamm.Array(np.array([0])) - with self.assertRaisesRegex(TypeError, "x must be 'pybamm.Array'"): + with pytest.raises(TypeError, match="x must be 'pybamm.Array'"): pybamm.plot2D("bad", x, x) - with self.assertRaisesRegex(TypeError, "y must be 'pybamm.Array'"): + with pytest.raises(TypeError, match="y must be 'pybamm.Array'"): pybamm.plot2D(x, "bad", x) - with self.assertRaisesRegex(TypeError, "z must be 'pybamm.Array'"): + with pytest.raises(TypeError, match="z must be 'pybamm.Array'"): pybamm.plot2D(x, x, "bad") - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_plotting/test_plot_summary_variables.py b/tests/unit/test_plotting/test_plot_summary_variables.py index e896b1f468..5f1a650ced 100644 --- a/tests/unit/test_plotting/test_plot_summary_variables.py +++ b/tests/unit/test_plotting/test_plot_summary_variables.py @@ -1,10 +1,8 @@ import pybamm -import unittest import numpy as np -from tests import TestCase -class TestPlotSummaryVariables(TestCase): +class TestPlotSummaryVariables: def test_plot(self): model = pybamm.lithium_ion.SPM({"SEI": "ec reaction limited"}) parameter_values = pybamm.ParameterValues("Mohtat2020") @@ -39,11 +37,11 @@ def test_plot(self): axes = pybamm.plot_summary_variables(sol, show_plot=False) axes = axes.flatten() - self.assertEqual(len(axes), 9) + assert len(axes) == 9 for output_var, ax in zip(output_variables, axes): - self.assertEqual(ax.get_xlabel(), "Cycle number") - self.assertEqual(ax.get_ylabel(), output_var) + assert ax.get_xlabel() == "Cycle number" + assert ax.get_ylabel() == output_var cycle_number, var = ax.get_lines()[0].get_data() np.testing.assert_array_equal( @@ -56,11 +54,11 @@ def test_plot(self): ) axes = axes.flatten() - self.assertEqual(len(axes), 9) + assert len(axes) == 9 for output_var, ax in zip(output_variables, axes): - self.assertEqual(ax.get_xlabel(), "Cycle number") - self.assertEqual(ax.get_ylabel(), output_var) + assert ax.get_xlabel() == "Cycle number" + assert ax.get_ylabel() == output_var cycle_number, var = ax.get_lines()[0].get_data() np.testing.assert_array_equal( @@ -73,13 +71,3 @@ def test_plot(self): cycle_number, sol.summary_variables["Cycle number"] ) np.testing.assert_array_equal(var, sol.summary_variables[output_var]) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_plotting/test_plot_thermal_components.py b/tests/unit/test_plotting/test_plot_thermal_components.py index 99b3d40cac..2b4cdf1e1e 100644 --- a/tests/unit/test_plotting/test_plot_thermal_components.py +++ b/tests/unit/test_plotting/test_plot_thermal_components.py @@ -1,14 +1,13 @@ +import pytest import pybamm -import unittest import numpy as np -from tests import TestCase import matplotlib.pyplot as plt from matplotlib import use use("Agg") -class TestPlotThermalComponents(TestCase): +class TestPlotThermalComponents: def test_plot_with_solution(self): model = pybamm.lithium_ion.SPM({"thermal": "lumped"}) sim = pybamm.Simulation( @@ -30,22 +29,12 @@ def test_plot_with_solution(self): _, ax = plt.subplots(1, 2) _, ax_out = pybamm.plot_thermal_components(sol, ax=ax, show_legend=True) - self.assertEqual(ax_out[0], ax[0]) - self.assertEqual(ax_out[1], ax[1]) + assert ax_out[0] == ax[0] + assert ax_out[1] == ax[1] def test_not_implemented(self): model = pybamm.lithium_ion.SPM({"thermal": "x-full"}) sim = pybamm.Simulation(model) sol = sim.solve([0, 3600]) - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): pybamm.plot_thermal_components(sol) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_plotting/test_plot_voltage_components.py b/tests/unit/test_plotting/test_plot_voltage_components.py index 1773d576d9..2b9da43fc1 100644 --- a/tests/unit/test_plotting/test_plot_voltage_components.py +++ b/tests/unit/test_plotting/test_plot_voltage_components.py @@ -1,14 +1,13 @@ +import pytest import pybamm -import unittest import numpy as np -from tests import TestCase import matplotlib.pyplot as plt from matplotlib import use use("Agg") -class TestPlotVoltageComponents(TestCase): +class TestPlotVoltageComponents: def test_plot_with_solution(self): model = pybamm.lithium_ion.SPM() sim = pybamm.Simulation(model) @@ -23,7 +22,7 @@ def test_plot_with_solution(self): _, ax = plt.subplots() _, ax_out = pybamm.plot_voltage_components(sol, ax=ax, show_legend=True) - self.assertEqual(ax_out, ax) + assert ax_out == ax def test_plot_with_simulation(self): model = pybamm.lithium_ion.SPM() @@ -40,7 +39,7 @@ def test_plot_with_simulation(self): _, ax = plt.subplots() _, ax_out = pybamm.plot_voltage_components(sim, ax=ax, show_legend=True) - self.assertEqual(ax_out, ax) + assert ax_out == ax def test_plot_from_solution(self): model = pybamm.lithium_ion.SPM() @@ -56,7 +55,7 @@ def test_plot_from_solution(self): _, ax = plt.subplots() _, ax_out = sol.plot_voltage_components(ax=ax, show_legend=True) - self.assertEqual(ax_out, ax) + assert ax_out == ax def test_plot_from_simulation(self): model = pybamm.lithium_ion.SPM() @@ -73,25 +72,12 @@ def test_plot_from_simulation(self): _, ax = plt.subplots() _, ax_out = sim.plot_voltage_components(ax=ax, show_legend=True) - self.assertEqual(ax_out, ax) + assert ax_out == ax def test_plot_without_solution(self): model = pybamm.lithium_ion.SPM() sim = pybamm.Simulation(model) - with self.assertRaises(ValueError) as error: + with pytest.raises(ValueError) as error: sim.plot_voltage_components() - - self.assertEqual( - str(error.exception), "The simulation has not been solved yet." - ) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() + assert str(error.exception) == "The simulation has not been solved yet." diff --git a/tests/unit/test_settings.py b/tests/unit/test_settings.py index a3b62f8ee4..6573929ad9 100644 --- a/tests/unit/test_settings.py +++ b/tests/unit/test_settings.py @@ -1,52 +1,49 @@ # # Tests the settings class. # -from tests import TestCase + import pybamm -import unittest +import pytest -class TestSettings(TestCase): +class TestSettings: def test_simplify(self): - self.assertTrue(pybamm.settings.simplify) + with pytest.raises(TypeError): + pybamm.settings.simplify = "Not Bool" + + assert pybamm.settings.simplify pybamm.settings.simplify = False - self.assertFalse(pybamm.settings.simplify) + assert not pybamm.settings.simplify pybamm.settings.simplify = True + def test_debug_mode(self): + with pytest.raises(TypeError): + pybamm.settings.debug_mode = "Not bool" + def test_smoothing_parameters(self): - self.assertEqual(pybamm.settings.min_max_mode, "exact") - self.assertEqual(pybamm.settings.heaviside_smoothing, "exact") - self.assertEqual(pybamm.settings.abs_smoothing, "exact") + assert pybamm.settings.min_max_mode == "exact" + assert pybamm.settings.heaviside_smoothing == "exact" + assert pybamm.settings.abs_smoothing == "exact" pybamm.settings.set_smoothing_parameters(10) - self.assertEqual(pybamm.settings.min_max_smoothing, 10) - self.assertEqual(pybamm.settings.heaviside_smoothing, 10) - self.assertEqual(pybamm.settings.abs_smoothing, 10) + assert pybamm.settings.min_max_smoothing == 10 + assert pybamm.settings.heaviside_smoothing == 10 + assert pybamm.settings.abs_smoothing == 10 pybamm.settings.set_smoothing_parameters("exact") # Test errors - with self.assertRaisesRegex(ValueError, "greater than 1"): + with pytest.raises(ValueError, match="greater than 1"): pybamm.settings.min_max_mode = "smooth" pybamm.settings.min_max_smoothing = 0.9 - with self.assertRaisesRegex(ValueError, "positive number"): + with pytest.raises(ValueError, match="positive number"): pybamm.settings.min_max_mode = "soft" pybamm.settings.min_max_smoothing = -10 - with self.assertRaisesRegex(ValueError, "positive number"): + with pytest.raises(ValueError, match="positive number"): pybamm.settings.heaviside_smoothing = -10 - with self.assertRaisesRegex(ValueError, "positive number"): + with pytest.raises(ValueError, match="positive number"): pybamm.settings.abs_smoothing = -10 - with self.assertRaisesRegex(ValueError, "'soft', or 'smooth'"): + with pytest.raises(ValueError, match="'soft', or 'smooth'"): pybamm.settings.min_max_mode = "unknown" pybamm.settings.set_smoothing_parameters("exact") - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_solvers/test_dummy_solver.py b/tests/unit/test_solvers/test_dummy_solver.py index 7c7b9a35f7..acce1d3543 100644 --- a/tests/unit/test_solvers/test_dummy_solver.py +++ b/tests/unit/test_solvers/test_dummy_solver.py @@ -1,14 +1,11 @@ # # Tests for the Dummy Solver class # -from tests import TestCase import pybamm import numpy as np -import unittest -import sys -class TestDummySolver(TestCase): +class TestDummySolver: def test_dummy_solver(self): model = pybamm.BaseModel() v = pybamm.Scalar(1) @@ -44,12 +41,3 @@ def test_dummy_solver_step(self): np.testing.assert_array_equal(len(sol.t), t_eval.size * 2 - 2) np.testing.assert_array_equal(sol.y, np.zeros((1, sol.t.size))) np.testing.assert_array_equal(sol["v"].data, np.ones(sol.t.size)) - - -if __name__ == "__main__": - print("Add -v for more debug output") - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_solvers/test_idaklu_solver.py b/tests/unit/test_solvers/test_idaklu_solver.py index 1697623486..567be324e3 100644 --- a/tests/unit/test_solvers/test_idaklu_solver.py +++ b/tests/unit/test_solvers/test_idaklu_solver.py @@ -18,15 +18,17 @@ def test_ida_roberts_klu(self): # this test implements a python version of the ida Roberts # example provided in sundials # see sundials ida examples pdf - for form in ["python", "casadi", "jax"]: - if form == "jax" and not pybamm.have_jax(): + for form in ["python", "casadi", "jax", "iree"]: + if (form == "jax" or form == "iree") and not pybamm.have_jax(): + continue + if (form == "iree") and not pybamm.have_iree(): continue if form == "casadi": root_method = "casadi" else: root_method = "lm" model = pybamm.BaseModel() - model.convert_to_format = form + model.convert_to_format = "jax" if form == "iree" else form u = pybamm.Variable("u") v = pybamm.Variable("v") model.rhs = {u: 0.1 * v} @@ -37,7 +39,10 @@ def test_ida_roberts_klu(self): disc = pybamm.Discretisation() disc.process_model(model) - solver = pybamm.IDAKLUSolver(root_method=root_method) + solver = pybamm.IDAKLUSolver( + root_method=root_method, + options={"jax_evaluator": "iree"} if form == "iree" else {}, + ) t_eval = np.linspace(0, 3, 100) solution = solver.solve(model, t_eval) @@ -59,8 +64,10 @@ def test_ida_roberts_klu(self): np.testing.assert_array_almost_equal(solution.y[0, :], true_solution) def test_model_events(self): - for form in ["python", "casadi", "jax"]: - if form == "jax" and not pybamm.have_jax(): + for form in ["python", "casadi", "jax", "iree"]: + if (form == "jax" or form == "iree") and not pybamm.have_jax(): + continue + if (form == "iree") and not pybamm.have_iree(): continue if form == "casadi": root_method = "casadi" @@ -68,7 +75,7 @@ def test_model_events(self): root_method = "lm" # Create model model = pybamm.BaseModel() - model.convert_to_format = form + model.convert_to_format = "jax" if form == "iree" else form var = pybamm.Variable("var") model.rhs = {var: 0.1 * var} model.initial_conditions = {var: 1} @@ -77,7 +84,12 @@ def test_model_events(self): disc = pybamm.Discretisation() model_disc = disc.process_model(model, inplace=False) # Solve - solver = pybamm.IDAKLUSolver(rtol=1e-8, atol=1e-8, root_method=root_method) + solver = pybamm.IDAKLUSolver( + rtol=1e-8, + atol=1e-8, + root_method=root_method, + options={"jax_evaluator": "iree"} if form == "iree" else {}, + ) t_eval = np.linspace(0, 1, 100) solution = solver.solve(model_disc, t_eval) np.testing.assert_array_equal(solution.t, t_eval) @@ -92,7 +104,12 @@ def test_model_events(self): # enforce events that won't be triggered model.events = [pybamm.Event("an event", var + 1)] model_disc = disc.process_model(model, inplace=False) - solver = pybamm.IDAKLUSolver(rtol=1e-8, atol=1e-8, root_method=root_method) + solver = pybamm.IDAKLUSolver( + rtol=1e-8, + atol=1e-8, + root_method=root_method, + options={"jax_evaluator": "iree"} if form == "iree" else {}, + ) solution = solver.solve(model_disc, t_eval) np.testing.assert_array_equal(solution.t, t_eval) np.testing.assert_array_almost_equal( @@ -102,7 +119,12 @@ def test_model_events(self): # enforce events that will be triggered model.events = [pybamm.Event("an event", 1.01 - var)] model_disc = disc.process_model(model, inplace=False) - solver = pybamm.IDAKLUSolver(rtol=1e-8, atol=1e-8, root_method=root_method) + solver = pybamm.IDAKLUSolver( + rtol=1e-8, + atol=1e-8, + root_method=root_method, + options={"jax_evaluator": "iree"} if form == "iree" else {}, + ) solution = solver.solve(model_disc, t_eval) self.assertLess(len(solution.t), len(t_eval)) np.testing.assert_array_almost_equal( @@ -124,7 +146,12 @@ def test_model_events(self): disc = get_discretisation_for_testing() disc.process_model(model) - solver = pybamm.IDAKLUSolver(rtol=1e-8, atol=1e-8, root_method=root_method) + solver = pybamm.IDAKLUSolver( + rtol=1e-8, + atol=1e-8, + root_method=root_method, + options={"jax_evaluator": "iree"} if form == "iree" else {}, + ) t_eval = np.linspace(0, 5, 100) solution = solver.solve(model, t_eval) np.testing.assert_array_less(solution.y[0, :-1], 1.5) @@ -140,15 +167,17 @@ def test_model_events(self): def test_input_params(self): # test a mix of scalar and vector input params - for form in ["python", "casadi", "jax"]: - if form == "jax" and not pybamm.have_jax(): + for form in ["python", "casadi", "jax", "iree"]: + if (form == "jax" or form == "iree") and not pybamm.have_jax(): + continue + if (form == "iree") and not pybamm.have_iree(): continue if form == "casadi": root_method = "casadi" else: root_method = "lm" model = pybamm.BaseModel() - model.convert_to_format = form + model.convert_to_format = "jax" if form == "iree" else form u1 = pybamm.Variable("u1") u2 = pybamm.Variable("u2") u3 = pybamm.Variable("u3") @@ -162,7 +191,10 @@ def test_input_params(self): disc = pybamm.Discretisation() disc.process_model(model) - solver = pybamm.IDAKLUSolver(root_method=root_method) + solver = pybamm.IDAKLUSolver( + root_method=root_method, + options={"jax_evaluator": "iree"} if form == "iree" else {}, + ) t_eval = np.linspace(0, 3, 100) a_value = 0.1 @@ -185,48 +217,63 @@ def test_input_params(self): true_solution = b_value * sol.t np.testing.assert_array_almost_equal(sol.y[1:3], true_solution) - def test_sensitivites_initial_condition(self): - for output_variables in [[], ["2v"]]: - model = pybamm.BaseModel() - model.convert_to_format = "casadi" - u = pybamm.Variable("u") - v = pybamm.Variable("v") - a = pybamm.InputParameter("a") - model.rhs = {u: -u} - model.algebraic = {v: a * u - v} - model.initial_conditions = {u: 1, v: 1} - model.variables = {"2v": 2 * v} - - disc = pybamm.Discretisation() - disc.process_model(model) - solver = pybamm.IDAKLUSolver(output_variables=output_variables) - - t_eval = np.linspace(0, 3, 100) - a_value = 0.1 - - sol = solver.solve( - model, t_eval, inputs={"a": a_value}, calculate_sensitivities=True - ) - - np.testing.assert_array_almost_equal( - sol["2v"].sensitivities["a"].full().flatten(), - np.exp(-sol.t) * 2, - decimal=4, - ) + def test_sensitivities_initial_condition(self): + for form in ["casadi", "iree"]: + for output_variables in [[], ["2v"]]: + if (form == "jax" or form == "iree") and not pybamm.have_jax(): + continue + if (form == "iree") and not pybamm.have_iree(): + continue + if form == "casadi": + root_method = "casadi" + else: + root_method = "lm" + model = pybamm.BaseModel() + model.convert_to_format = "jax" if form == "iree" else form + u = pybamm.Variable("u") + v = pybamm.Variable("v") + a = pybamm.InputParameter("a") + model.rhs = {u: -u} + model.algebraic = {v: a * u - v} + model.initial_conditions = {u: 1, v: 1} + model.variables = {"2v": 2 * v} + + disc = pybamm.Discretisation() + disc.process_model(model) + solver = pybamm.IDAKLUSolver( + root_method=root_method, + output_variables=output_variables, + options={"jax_evaluator": "iree"} if form == "iree" else {}, + ) + + t_eval = np.linspace(0, 3, 100) + a_value = 0.1 + + sol = solver.solve( + model, t_eval, inputs={"a": a_value}, calculate_sensitivities=True + ) + + np.testing.assert_array_almost_equal( + sol["2v"].sensitivities["a"].full().flatten(), + np.exp(-sol.t) * 2, + decimal=4, + ) def test_ida_roberts_klu_sensitivities(self): # this test implements a python version of the ida Roberts # example provided in sundials # see sundials ida examples pdf - for form in ["python", "casadi", "jax"]: - if form == "jax" and not pybamm.have_jax(): + for form in ["python", "casadi", "jax", "iree"]: + if (form == "jax" or form == "iree") and not pybamm.have_jax(): + continue + if (form == "iree") and not pybamm.have_iree(): continue if form == "casadi": root_method = "casadi" else: root_method = "lm" model = pybamm.BaseModel() - model.convert_to_format = form + model.convert_to_format = "jax" if form == "iree" else form u = pybamm.Variable("u") v = pybamm.Variable("v") a = pybamm.InputParameter("a") @@ -238,7 +285,10 @@ def test_ida_roberts_klu_sensitivities(self): disc = pybamm.Discretisation() disc.process_model(model) - solver = pybamm.IDAKLUSolver(root_method=root_method) + solver = pybamm.IDAKLUSolver( + root_method=root_method, + options={"jax_evaluator": "iree"} if form == "iree" else {}, + ) t_eval = np.linspace(0, 3, 100) a_value = 0.1 @@ -283,25 +333,32 @@ def test_ida_roberts_klu_sensitivities(self): dyda_fd = (sol_plus.y - sol_neg.y) / h dyda_fd = dyda_fd.transpose().reshape(-1, 1) - np.testing.assert_array_almost_equal(dyda_ida, dyda_fd) + decimal = ( + 2 if form == "iree" else 6 + ) # iree currently operates with single precision + np.testing.assert_array_almost_equal(dyda_ida, dyda_fd, decimal=decimal) # get the sensitivities for the variable d2uda = sol["2u"].sensitivities["a"] - np.testing.assert_array_almost_equal(2 * dyda_ida[0:200:2], d2uda) + np.testing.assert_array_almost_equal( + 2 * dyda_ida[0:200:2], d2uda, decimal=decimal + ) def test_sensitivities_with_events(self): # this test implements a python version of the ida Roberts # example provided in sundials # see sundials ida examples pdf - for form in ["casadi", "python", "jax"]: - if form == "jax" and not pybamm.have_jax(): + for form in ["casadi", "python", "jax", "iree"]: + if (form == "jax" or form == "iree") and not pybamm.have_jax(): + continue + if (form == "iree") and not pybamm.have_iree(): continue if form == "casadi": root_method = "casadi" else: root_method = "lm" model = pybamm.BaseModel() - model.convert_to_format = form + model.convert_to_format = "jax" if form == "iree" else form u = pybamm.Variable("u") v = pybamm.Variable("v") a = pybamm.InputParameter("a") @@ -314,7 +371,10 @@ def test_sensitivities_with_events(self): disc = pybamm.Discretisation() disc.process_model(model) - solver = pybamm.IDAKLUSolver(root_method=root_method) + solver = pybamm.IDAKLUSolver( + root_method=root_method, + options={"jax_evaluator": "iree"} if form == "iree" else {}, + ) t_eval = np.linspace(0, 3, 100) a_value = 0.1 @@ -351,8 +411,11 @@ def test_sensitivities_with_events(self): dyda_fd = (sol_plus.y[:, :max_index] - sol_neg.y[:, :max_index]) / h dyda_fd = dyda_fd.transpose().reshape(-1, 1) + decimal = ( + 2 if form == "iree" else 6 + ) # iree currently operates with single precision np.testing.assert_array_almost_equal( - dyda_ida[: (2 * max_index), :], dyda_fd + dyda_ida[: (2 * max_index), :], dyda_fd, decimal=decimal ) sol_plus = solver.solve( @@ -366,7 +429,7 @@ def test_sensitivities_with_events(self): dydb_fd = dydb_fd.transpose().reshape(-1, 1) np.testing.assert_array_almost_equal( - dydb_ida[: (2 * max_index), :], dydb_fd + dydb_ida[: (2 * max_index), :], dydb_fd, decimal=decimal ) def test_failures(self): @@ -421,15 +484,17 @@ def test_failures(self): solver.solve(model, t_eval) def test_dae_solver_algebraic_model(self): - for form in ["python", "casadi", "jax"]: - if form == "jax" and not pybamm.have_jax(): + for form in ["python", "casadi", "jax", "iree"]: + if (form == "jax" or form == "iree") and not pybamm.have_jax(): + continue + if (form == "iree") and not pybamm.have_iree(): continue if form == "casadi": root_method = "casadi" else: root_method = "lm" model = pybamm.BaseModel() - model.convert_to_format = form + model.convert_to_format = "jax" if form == "iree" else form var = pybamm.Variable("var") model.algebraic = {var: var + 1} model.initial_conditions = {var: 0} @@ -437,7 +502,10 @@ def test_dae_solver_algebraic_model(self): disc = pybamm.Discretisation() disc.process_model(model) - solver = pybamm.IDAKLUSolver(root_method=root_method) + solver = pybamm.IDAKLUSolver( + root_method=root_method, + options={"jax_evaluator": "iree"} if form == "iree" else {}, + ) t_eval = np.linspace(0, 1) solution = solver.solve(model, t_eval) np.testing.assert_array_equal(solution.y, -1) @@ -471,7 +539,7 @@ def test_banded(self): np.testing.assert_array_almost_equal(soln.y, soln_banded.y, 5) - def test_options(self): + def test_setup_options(self): model = pybamm.BaseModel() u = pybamm.Variable("u") v = pybamm.Variable("v") @@ -516,8 +584,13 @@ def test_options(self): "jacobian": jacobian, "linear_solver": linear_solver, "preconditioner": precon, + "max_num_steps": 10000, } - solver = pybamm.IDAKLUSolver(options=options) + solver = pybamm.IDAKLUSolver( + atol=1e-8, + rtol=1e-8, + options=options, + ) if ( jacobian == "none" and (linear_solver == "SUNLinSol_Dense") @@ -546,8 +619,72 @@ def test_options(self): with self.assertRaises(ValueError): soln = solver.solve(model, t_eval) + def test_solver_options(self): + model = pybamm.BaseModel() + u = pybamm.Variable("u") + v = pybamm.Variable("v") + model.rhs = {u: -0.1 * u} + model.algebraic = {v: v - u} + model.initial_conditions = {u: 1, v: 1} + disc = pybamm.Discretisation() + disc.process_model(model) + + t_eval = np.linspace(0, 1) + solver = pybamm.IDAKLUSolver() + soln_base = solver.solve(model, t_eval) + + options_success = { + "max_order_bdf": 4, + "max_num_steps": 490, + "dt_init": 0.01, + "dt_max": 1000.9, + "max_error_test_failures": 11, + "max_nonlinear_iterations": 5, + "max_convergence_failures": 11, + "nonlinear_convergence_coefficient": 1.0, + "suppress_algebraic_error": True, + "nonlinear_convergence_coefficient_ic": 0.01, + "max_num_steps_ic": 6, + "max_num_jacobians_ic": 5, + "max_num_iterations_ic": 11, + "max_linesearch_backtracks_ic": 101, + "linesearch_off_ic": True, + "init_all_y_ic": False, + "linear_solver": "SUNLinSol_KLU", + "linsol_max_iterations": 6, + "epsilon_linear_tolerance": 0.06, + "increment_factor": 0.99, + "linear_solution_scaling": False, + } + + # test everything works + for option in options_success: + options = {option: options_success[option]} + solver = pybamm.IDAKLUSolver(options=options) + soln = solver.solve(model, t_eval) + + np.testing.assert_array_almost_equal(soln.y, soln_base.y, 5) + + options_fail = { + "max_order_bdf": -1, + "max_num_steps_ic": -1, + "max_num_jacobians_ic": -1, + "max_num_iterations_ic": -1, + "max_linesearch_backtracks_ic": -1, + "epsilon_linear_tolerance": -1.0, + "increment_factor": -1.0, + } + + # test that the solver throws a warning + for option in options_fail: + options = {option: options_fail[option]} + solver = pybamm.IDAKLUSolver(options=options) + + with self.assertRaises(ValueError): + solver.solve(model, t_eval) + def test_with_output_variables(self): - # Construct a model and solve for all vairables, then test + # Construct a model and solve for all variables, then test # the 'output_variables' option for each variable in turn, confirming # equivalence input_parameters = {} # Sensitivities dictionary @@ -649,76 +786,110 @@ def construct_model(): sol["x_s [m]"].initialise_1D() def test_with_output_variables_and_sensitivities(self): - # Construct a model and solve for all vairables, then test + # Construct a model and solve for all variables, then test # the 'output_variables' option for each variable in turn, confirming # equivalence - # construct model - model = pybamm.lithium_ion.DFN() - geometry = model.default_geometry - param = model.default_parameter_values - input_parameters = { # Sensitivities dictionary - "Current function [A]": 0.680616, - "Separator porosity": 1.0, - } - param.update({key: "[input]" for key in input_parameters}) - param.process_model(model) - param.process_geometry(geometry) - var_pts = {"x_n": 50, "x_s": 50, "x_p": 50, "r_n": 5, "r_p": 5} - mesh = pybamm.Mesh(geometry, model.default_submesh_types, var_pts) - disc = pybamm.Discretisation(mesh, model.default_spatial_methods) - disc.process_model(model) - t_eval = np.linspace(0, 3600, 100) + for form in ["casadi", "iree"]: + if (form == "jax" or form == "iree") and not pybamm.have_jax(): + continue + if (form == "iree") and not pybamm.have_iree(): + continue + if form == "casadi": + root_method = "casadi" + else: + root_method = "lm" + input_parameters = { # Sensitivities dictionary + "Current function [A]": 0.222, + "Separator porosity": 0.3, + } - options = { - "linear_solver": "SUNLinSol_KLU", - "jacobian": "sparse", - "num_threads": 4, - } + # construct model + model = pybamm.lithium_ion.DFN() + model.convert_to_format = "jax" if form == "iree" else form + geometry = model.default_geometry + param = model.default_parameter_values + param.update({key: "[input]" for key in input_parameters}) + param.process_model(model) + param.process_geometry(geometry) + var_pts = {"x_n": 50, "x_s": 50, "x_p": 50, "r_n": 5, "r_p": 5} + mesh = pybamm.Mesh(geometry, model.default_submesh_types, var_pts) + disc = pybamm.Discretisation(mesh, model.default_spatial_methods) + disc.process_model(model) - # Use a selection of variables of different types - output_variables = [ - "Voltage [V]", - "Time [min]", - "x [m]", - "Negative particle flux [mol.m-2.s-1]", - "Throughput capacity [A.h]", # ExplicitTimeIntegral - ] + t_eval = np.linspace(0, 3600, 100) + + options = { + "linear_solver": "SUNLinSol_KLU", + "jacobian": "sparse", + "num_threads": 4, + } + if form == "iree": + options["jax_evaluator"] = "iree" + + # Use a selection of variables of different types + output_variables = [ + "Voltage [V]", + "Time [min]", + "x [m]", + "Negative particle flux [mol.m-2.s-1]", + "Throughput capacity [A.h]", # ExplicitTimeIntegral + ] - # Use the full model as comparison (tested separately) - solver_all = pybamm.IDAKLUSolver( - atol=1e-8, - rtol=1e-8, - options=options, - ) - sol_all = solver_all.solve( - model, - t_eval, - inputs=input_parameters, - calculate_sensitivities=True, - ) + # Use the full model as comparison (tested separately) + solver_all = pybamm.IDAKLUSolver( + root_method=root_method, + atol=1e-8 if form != "iree" else 1e-0, # iree has reduced precision + rtol=1e-8 if form != "iree" else 1e-0, # iree has reduced precision + options=options, + ) + sol_all = solver_all.solve( + model, + t_eval, + inputs=input_parameters, + calculate_sensitivities=True, + ) - # Solve for a subset of variables and compare results - solver = pybamm.IDAKLUSolver( - atol=1e-8, - rtol=1e-8, - options=options, - output_variables=output_variables, - ) - sol = solver.solve( - model, - t_eval, - inputs=input_parameters, - calculate_sensitivities=True, - ) + # Solve for a subset of variables and compare results + solver = pybamm.IDAKLUSolver( + root_method=root_method, + atol=1e-8 if form != "iree" else 1e-0, # iree has reduced precision + rtol=1e-8 if form != "iree" else 1e-0, # iree has reduced precision + options=options, + output_variables=output_variables, + ) + sol = solver.solve( + model, + t_eval, + inputs=input_parameters, + calculate_sensitivities=True, + ) - # Compare output to sol_all - for varname in output_variables: - self.assertTrue(np.allclose(sol[varname].data, sol_all[varname].data)) + # Compare output to sol_all + tol = 1e-5 if form != "iree" else 1e-2 # iree has reduced precision + for varname in output_variables: + np.testing.assert_array_almost_equal( + sol[varname].data, sol_all[varname].data, tol + ) - # Mock a 1D current collector and initialise (none in the model) - sol["x_s [m]"].domain = ["current collector"] - sol["x_s [m]"].initialise_1D() + # Mock a 1D current collector and initialise (none in the model) + sol["x_s [m]"].domain = ["current collector"] + sol["x_s [m]"].initialise_1D() + + def test_bad_jax_evaluator(self): + model = pybamm.lithium_ion.DFN() + model.convert_to_format = "jax" + with self.assertRaises(pybamm.SolverError): + pybamm.IDAKLUSolver(options={"jax_evaluator": "bad_evaluator"}) + + def test_bad_jax_evaluator_output_variables(self): + model = pybamm.lithium_ion.DFN() + model.convert_to_format = "jax" + with self.assertRaises(pybamm.SolverError): + pybamm.IDAKLUSolver( + options={"jax_evaluator": "bad_evaluator"}, + output_variables=["Terminal voltage [V]"], + ) if __name__ == "__main__": diff --git a/tests/unit/test_solvers/test_lrudict.py b/tests/unit/test_solvers/test_lrudict.py index a5378da786..ab38bddbc5 100644 --- a/tests/unit/test_solvers/test_lrudict.py +++ b/tests/unit/test_solvers/test_lrudict.py @@ -1,12 +1,12 @@ # # Tests for the LRUDict class # -import unittest +import pytest from pybamm.solvers.lrudict import LRUDict from collections import OrderedDict -class TestLRUDict(unittest.TestCase): +class TestLRUDict: def test_lrudict_defaultbehaviour(self): """Default behaviour [no LRU] mimics Dict""" d = LRUDict() @@ -20,27 +20,27 @@ def test_lrudict_defaultbehaviour(self): dd.get(count - 2) # assertCountEqual checks that the same elements are present in # both lists, not just that the lists are of equal count - self.assertCountEqual(set(d.keys()), set(dd.keys())) - self.assertCountEqual(set(d.values()), set(dd.values())) + assert set(d.keys()) == set(dd.keys()) + assert set(d.values()) == set(dd.values()) def test_lrudict_noitems(self): """Edge case: no items in LRU, raises KeyError on assignment""" d = LRUDict(maxsize=-1) - with self.assertRaises(KeyError): + with pytest.raises(KeyError): d["a"] = 1 def test_lrudict_singleitem(self): """Only the last added element should ever be present""" d = LRUDict(maxsize=1) item_list = range(1, 100) - self.assertEqual(len(d), 0) + assert len(d) == 0 for item in item_list: d[item] = item - self.assertEqual(len(d), 1) - self.assertIsNotNone(d[item]) + assert len(d) == 1 + assert d[item] is not None # Finally, pop the only item and check that the dictionary is empty d.popitem() - self.assertEqual(len(d), 0) + assert len(d) == 0 def test_lrudict_multiitem(self): """Check that the correctly ordered items are always present""" @@ -59,17 +59,17 @@ def test_lrudict_multiitem(self): expected = OrderedDict( (k, expected[k]) for k in list(expected.keys())[-maxsize:] ) - self.assertListEqual(list(d.keys()), list(expected.keys())) - self.assertListEqual(list(d.values()), list(expected.values())) + assert list(d.keys()) == list(expected.keys()) + assert list(d.values()) == list(expected.values()) def test_lrudict_invalidkey(self): d = LRUDict() value = 1 d["a"] = value # Access with valid key - self.assertEqual(d["a"], value) # checks getitem() - self.assertEqual(d.get("a"), value) # checks get() + assert d["a"] == value # checks getitem() + assert d.get("a") == value # checks get() # Access with invalid key - with self.assertRaises(KeyError): + with pytest.raises(KeyError): _ = d["b"] # checks getitem() - self.assertIsNone(d.get("b")) # checks get() + assert d.get("b") is None # checks get() diff --git a/tests/unit/test_spatial_methods/test_zero_dimensional_method.py b/tests/unit/test_spatial_methods/test_zero_dimensional_method.py index b3ec859412..1c620c7872 100644 --- a/tests/unit/test_spatial_methods/test_zero_dimensional_method.py +++ b/tests/unit/test_spatial_methods/test_zero_dimensional_method.py @@ -1,14 +1,12 @@ # # Test for the base Spatial Method class # -from tests import TestCase import numpy as np import pybamm -import unittest from tests import get_mesh_for_testing, get_discretisation_for_testing -class TestZeroDimensionalSpatialMethod(TestCase): +class TestZeroDimensionalSpatialMethod: def test_identity_ops(self): test_mesh = np.array([1, 2, 3]) spatial_method = pybamm.ZeroDimensionalSpatialMethod() @@ -16,14 +14,14 @@ def test_identity_ops(self): np.testing.assert_array_equal(spatial_method._mesh, test_mesh) a = pybamm.Symbol("a") - self.assertEqual(a, spatial_method.integral(None, a, "primary")) - self.assertEqual(a, spatial_method.indefinite_integral(None, a, "forward")) - self.assertEqual(a, spatial_method.boundary_value_or_flux(None, a)) - self.assertEqual((-a), spatial_method.indefinite_integral(None, a, "backward")) + assert a == spatial_method.integral(None, a, "primary") + assert a == spatial_method.indefinite_integral(None, a, "forward") + assert a == spatial_method.boundary_value_or_flux(None, a) + assert (-a) == spatial_method.indefinite_integral(None, a, "backward") mass_matrix = spatial_method.mass_matrix(None, None) - self.assertIsInstance(mass_matrix, pybamm.Matrix) - self.assertEqual(mass_matrix.shape, (1, 1)) + assert isinstance(mass_matrix, pybamm.Matrix) + assert mass_matrix.shape == (1, 1) np.testing.assert_array_equal(mass_matrix.entries, 1) def test_discretise_spatial_variable(self): @@ -38,7 +36,7 @@ def test_discretise_spatial_variable(self): r = pybamm.SpatialVariable("r", ["negative particle"]) for var in [x1, x2, r]: var_disc = spatial_method.spatial_variable(var) - self.assertIsInstance(var_disc, pybamm.Vector) + assert isinstance(var_disc, pybamm.Vector) np.testing.assert_array_equal( var_disc.evaluate()[:, 0], mesh[var.domain].nodes ) @@ -49,7 +47,7 @@ def test_discretise_spatial_variable(self): r_edge = pybamm.SpatialVariableEdge("r", ["negative particle"]) for var in [x1_edge, x2_edge, r_edge]: var_disc = spatial_method.spatial_variable(var) - self.assertIsInstance(var_disc, pybamm.Vector) + assert isinstance(var_disc, pybamm.Vector) np.testing.assert_array_equal( var_disc.evaluate()[:, 0], mesh[var.domain].edges ) @@ -70,13 +68,3 @@ def test_averages(self): np.testing.assert_array_equal( var_disc.evaluate(y=y), expr_disc.evaluate(y=y) ) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_timer.py b/tests/unit/test_timer.py index 228cdd5dce..6ef62f791e 100644 --- a/tests/unit/test_timer.py +++ b/tests/unit/test_timer.py @@ -5,11 +5,9 @@ # (see https://github.com/pints-team/pints) # import pybamm -import unittest -from tests import TestCase -class TestTimer(TestCase): +class TestTimer: """ Tests the basic methods of the Timer class. """ @@ -20,64 +18,54 @@ def __init__(self, name): def test_timing(self): t = pybamm.Timer() a = t.time().value - self.assertGreaterEqual(a, 0) + assert a >= 0 for _ in range(100): - self.assertGreater(t.time().value, a) + assert t.time().value > a a = t.time().value t.reset() b = t.time().value - self.assertGreaterEqual(b, 0) - self.assertLess(b, a) + assert b >= 0 + assert b < a def test_timer_format(self): - self.assertEqual(str(pybamm.TimerTime(1e-9)), "1.000 ns") - self.assertEqual(str(pybamm.TimerTime(0.000000123456789)), "123.457 ns") - self.assertEqual(str(pybamm.TimerTime(1e-6)), "1.000 us") - self.assertEqual(str(pybamm.TimerTime(0.000123456789)), "123.457 us") - self.assertEqual(str(pybamm.TimerTime(0.999e-3)), "999.000 us") - self.assertEqual(str(pybamm.TimerTime(1e-3)), "1.000 ms") - self.assertEqual(str(pybamm.TimerTime(0.123456789)), "123.457 ms") - self.assertEqual(str(pybamm.TimerTime(2)), "2.000 s") - self.assertEqual(str(pybamm.TimerTime(2.5)), "2.500 s") - self.assertEqual(str(pybamm.TimerTime(12.5)), "12.500 s") - self.assertEqual(str(pybamm.TimerTime(59.41)), "59.410 s") - self.assertEqual(str(pybamm.TimerTime(59.4126347547)), "59.413 s") - self.assertEqual(str(pybamm.TimerTime(60.2)), "1 minute, 0 seconds") - self.assertEqual(str(pybamm.TimerTime(61)), "1 minute, 1 second") - self.assertEqual(str(pybamm.TimerTime(121)), "2 minutes, 1 second") - self.assertEqual( - str(pybamm.TimerTime(604800)), - "1 week, 0 days, 0 hours, 0 minutes, 0 seconds", + assert str(pybamm.TimerTime(1e-9)) == "1.000 ns" + assert str(pybamm.TimerTime(0.000000123456789)) == "123.457 ns" + assert str(pybamm.TimerTime(1e-6)) == "1.000 us" + assert str(pybamm.TimerTime(0.000123456789)) == "123.457 us" + assert str(pybamm.TimerTime(0.999e-3)) == "999.000 us" + assert str(pybamm.TimerTime(1e-3)) == "1.000 ms" + assert str(pybamm.TimerTime(0.123456789)) == "123.457 ms" + assert str(pybamm.TimerTime(2)) == "2.000 s" + assert str(pybamm.TimerTime(2.5)) == "2.500 s" + assert str(pybamm.TimerTime(12.5)) == "12.500 s" + assert str(pybamm.TimerTime(59.41)) == "59.410 s" + assert str(pybamm.TimerTime(59.4126347547)) == "59.413 s" + assert str(pybamm.TimerTime(60.2)) == "1 minute, 0 seconds" + assert str(pybamm.TimerTime(61)) == "1 minute, 1 second" + assert str(pybamm.TimerTime(121)) == "2 minutes, 1 second" + assert ( + str(pybamm.TimerTime(604800)) + == "1 week, 0 days, 0 hours, 0 minutes, 0 seconds" ) - self.assertEqual( - str(pybamm.TimerTime(2 * 604800 + 3 * 3600 + 60 + 4)), - "2 weeks, 0 days, 3 hours, 1 minute, 4 seconds", + assert ( + str(pybamm.TimerTime(2 * 604800 + 3 * 3600 + 60 + 4)) + == "2 weeks, 0 days, 3 hours, 1 minute, 4 seconds" ) - self.assertEqual(repr(pybamm.TimerTime(1.5)), "pybamm.TimerTime(1.5)") + assert repr(pybamm.TimerTime(1.5)) == "pybamm.TimerTime(1.5)" def test_timer_operations(self): - self.assertEqual((pybamm.TimerTime(1) + 2).value, 3) - self.assertEqual((1 + pybamm.TimerTime(1)).value, 2) - self.assertEqual((pybamm.TimerTime(1) - 2).value, -1) - self.assertEqual((pybamm.TimerTime(1) - pybamm.TimerTime(2)).value, -1) - self.assertEqual((1 - pybamm.TimerTime(1)).value, 0) - self.assertEqual((pybamm.TimerTime(4) * 2).value, 8) - self.assertEqual((pybamm.TimerTime(4) * pybamm.TimerTime(2)).value, 8) - self.assertEqual((2 * pybamm.TimerTime(5)).value, 10) - self.assertEqual((pybamm.TimerTime(4) / 2).value, 2) - self.assertEqual((pybamm.TimerTime(4) / pybamm.TimerTime(2)).value, 2) - self.assertEqual((2 / pybamm.TimerTime(5)).value, 2 / 5) + assert (pybamm.TimerTime(1) + 2).value == 3 + assert (1 + pybamm.TimerTime(1)).value == 2 + assert (pybamm.TimerTime(1) - 2).value == -1 + assert (pybamm.TimerTime(1) - pybamm.TimerTime(2)).value == -1 + assert (1 - pybamm.TimerTime(1)).value == 0 + assert (pybamm.TimerTime(4) * 2).value == 8 + assert (pybamm.TimerTime(4) * pybamm.TimerTime(2)).value == 8 + assert (2 * pybamm.TimerTime(5)).value == 10 + assert (pybamm.TimerTime(4) / 2).value == 2 + assert (pybamm.TimerTime(4) / pybamm.TimerTime(2)).value == 2 + assert (2 / pybamm.TimerTime(5)).value == 2 / 5 - self.assertTrue(pybamm.TimerTime(1) == pybamm.TimerTime(1)) - self.assertTrue(pybamm.TimerTime(1) != pybamm.TimerTime(2)) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() + assert pybamm.TimerTime(1) == pybamm.TimerTime(1) + assert pybamm.TimerTime(1) != pybamm.TimerTime(2) diff --git a/vcpkg.json b/vcpkg.json index 4e2fb4fe7e..9134ac3fd9 100644 --- a/vcpkg.json +++ b/vcpkg.json @@ -1,6 +1,6 @@ { "name": "pybamm", - "version-string": "24.5rc0", + "version-string": "24.5rc2", "dependencies": [ "casadi", {