diff --git a/.github/workflows/publish_pypi.yml b/.github/workflows/publish_pypi.yml index 402bf9e859..8d328118a1 100644 --- a/.github/workflows/publish_pypi.yml +++ b/.github/workflows/publish_pypi.yml @@ -317,7 +317,7 @@ jobs: if: github.event.inputs.target == 'testpypi' uses: pypa/gh-action-pypi-publish@release/v1 with: - packages-dir: files/ + packages-dir: artifacts/ repository-url: https://test.pypi.org/legacy/ open_failure_issue: diff --git a/.github/workflows/run_periodic_tests.yml b/.github/workflows/run_periodic_tests.yml index 6f79df76b2..2dd9ef8a89 100644 --- a/.github/workflows/run_periodic_tests.yml +++ b/.github/workflows/run_periodic_tests.yml @@ -14,6 +14,8 @@ on: env: FORCE_COLOR: 3 + PYBAMM_IDAKLU_EXPR_CASADI: ON + PYBAMM_IDAKLU_EXPR_IREE: ON concurrency: # github.workflow: name of the workflow, so that we don't cancel other workflows diff --git a/.github/workflows/scorecard.yml b/.github/workflows/scorecard.yml index 0c81f71bde..224725e0f7 100644 --- a/.github/workflows/scorecard.yml +++ b/.github/workflows/scorecard.yml @@ -68,6 +68,6 @@ jobs: # Upload the results to GitHub's code scanning dashboard (optional). # Commenting out will disable upload of results to your repo's Code Scanning dashboard - name: "Upload to code-scanning" - uses: github/codeql-action/upload-sarif@b611370bb5703a7efb587f9d136a52ea24c5c38c # v3.25.11 + uses: github/codeql-action/upload-sarif@2d790406f505036ef40ecba973cc774a50395aac # v3.25.13 with: sarif_file: results.sarif diff --git a/.github/workflows/test_on_push.yml b/.github/workflows/test_on_push.yml index 97c37e8c28..adfb698a69 100644 --- a/.github/workflows/test_on_push.yml +++ b/.github/workflows/test_on_push.yml @@ -6,6 +6,8 @@ on: env: FORCE_COLOR: 3 + PYBAMM_IDAKLU_EXPR_CASADI: ON + PYBAMM_IDAKLU_EXPR_IREE: ON concurrency: # github.workflow: name of the workflow, so that we don't cancel other workflows diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8effca2b07..6b2f300f38 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,7 +4,7 @@ ci: repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: "v0.5.1" + rev: "v0.5.4" hooks: - id: ruff args: [--fix, --show-fixes] diff --git a/CHANGELOG.md b/CHANGELOG.md index 9addd13346..decbacf529 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ # [Unreleased](https://github.com/pybamm-team/PyBaMM/) -# [v24.5rc0](https://github.com/pybamm-team/PyBaMM/tree/v24.5rc0) - 2024-05-01 +# [v24.5rc2](https://github.com/pybamm-team/PyBaMM/tree/v24.5rc2) - 2024-07-12 ## Features diff --git a/CITATION.cff b/CITATION.cff index 43fa574cdd..7e28662bac 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -24,6 +24,6 @@ keywords: - "expression tree" - "python" - "symbolic differentiation" -version: "24.5rc0" +version: "24.5rc2" repository-code: "https://github.com/pybamm-team/PyBaMM" title: "Python Battery Mathematical Modelling (PyBaMM)" diff --git a/CMakeLists.txt b/CMakeLists.txt index b9fe37c331..661f63457e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -35,32 +35,76 @@ add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0) if(NOT PYBIND11_DIR) set(PYBIND11_DIR pybind11) endif() - add_subdirectory(${PYBIND11_DIR}) -# The sources list should mirror the list in setup.py + +# Check Casadi build flag +if(NOT DEFINED PYBAMM_IDAKLU_EXPR_CASADI) + set(PYBAMM_IDAKLU_EXPR_CASADI ON) +endif() +message("PYBAMM_IDAKLU_EXPR_CASADI: ${PYBAMM_IDAKLU_EXPR_CASADI}") + +# Casadi PyBaMM source files +set(IDAKLU_EXPR_CASADI_SOURCE_FILES "") +if(${PYBAMM_IDAKLU_EXPR_CASADI} STREQUAL "ON" ) + add_compile_definitions(CASADI_ENABLE) + set(IDAKLU_EXPR_CASADI_SOURCE_FILES + pybamm/solvers/c_solvers/idaklu/Expressions/Casadi/CasadiFunctions.cpp + pybamm/solvers/c_solvers/idaklu/Expressions/Casadi/CasadiFunctions.hpp + ) +endif() + +# Check IREE build flag +if(NOT DEFINED PYBAMM_IDAKLU_EXPR_IREE) + set(PYBAMM_IDAKLU_EXPR_IREE OFF) +endif() +message("PYBAMM_IDAKLU_EXPR_IREE: ${PYBAMM_IDAKLU_EXPR_IREE}") + +# IREE (MLIR expression evaluation) PyBaMM source files +set(IDAKLU_EXPR_IREE_SOURCE_FILES "") +if(${PYBAMM_IDAKLU_EXPR_IREE} STREQUAL "ON" ) + add_compile_definitions(IREE_ENABLE) + # Source file list + set(IDAKLU_EXPR_IREE_SOURCE_FILES + pybamm/solvers/c_solvers/idaklu/Expressions/IREE/iree_jit.cpp + pybamm/solvers/c_solvers/idaklu/Expressions/IREE/iree_jit.hpp + pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEFunctions.cpp + pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEFunctions.hpp + pybamm/solvers/c_solvers/idaklu/Expressions/IREE/ModuleParser.cpp + pybamm/solvers/c_solvers/idaklu/Expressions/IREE/ModuleParser.hpp + ) +endif() + +# The complete (all dependencies) sources list should be mirrored in setup.py pybind11_add_module(idaklu - pybamm/solvers/c_solvers/idaklu/casadi_functions.cpp - pybamm/solvers/c_solvers/idaklu/casadi_functions.hpp - pybamm/solvers/c_solvers/idaklu/casadi_solver.cpp - pybamm/solvers/c_solvers/idaklu/casadi_solver.hpp - pybamm/solvers/c_solvers/idaklu/CasadiSolver.cpp - pybamm/solvers/c_solvers/idaklu/CasadiSolver.hpp - pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.cpp - pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.hpp - pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP_solvers.cpp - pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP_solvers.hpp - pybamm/solvers/c_solvers/idaklu/casadi_sundials_functions.cpp - pybamm/solvers/c_solvers/idaklu/casadi_sundials_functions.hpp - pybamm/solvers/c_solvers/idaklu/idaklu_jax.cpp - pybamm/solvers/c_solvers/idaklu/idaklu_jax.hpp + # pybind11 interface + pybamm/solvers/c_solvers/idaklu.cpp + # IDAKLU solver (SUNDIALS) + pybamm/solvers/c_solvers/idaklu/idaklu_solver.hpp + pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.cpp + pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.hpp + pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.inl + pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.hpp + pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP_solvers.cpp + pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP_solvers.hpp + pybamm/solvers/c_solvers/idaklu/sundials_functions.inl + pybamm/solvers/c_solvers/idaklu/sundials_functions.hpp + pybamm/solvers/c_solvers/idaklu/IdakluJax.cpp + pybamm/solvers/c_solvers/idaklu/IdakluJax.hpp pybamm/solvers/c_solvers/idaklu/common.hpp pybamm/solvers/c_solvers/idaklu/python.hpp pybamm/solvers/c_solvers/idaklu/python.cpp - pybamm/solvers/c_solvers/idaklu/solution.cpp - pybamm/solvers/c_solvers/idaklu/solution.hpp - pybamm/solvers/c_solvers/idaklu/options.hpp - pybamm/solvers/c_solvers/idaklu/options.cpp - pybamm/solvers/c_solvers/idaklu.cpp + pybamm/solvers/c_solvers/idaklu/Solution.cpp + pybamm/solvers/c_solvers/idaklu/Solution.hpp + pybamm/solvers/c_solvers/idaklu/Options.hpp + pybamm/solvers/c_solvers/idaklu/Options.cpp + # IDAKLU expressions / function evaluation [abstract] + pybamm/solvers/c_solvers/idaklu/Expressions/Expressions.hpp + pybamm/solvers/c_solvers/idaklu/Expressions/Base/Expression.hpp + pybamm/solvers/c_solvers/idaklu/Expressions/Base/ExpressionSet.hpp + pybamm/solvers/c_solvers/idaklu/Expressions/Base/ExpressionTypes.hpp + # IDAKLU expressions - concrete implementations + ${IDAKLU_EXPR_CASADI_SOURCE_FILES} + ${IDAKLU_EXPR_IREE_SOURCE_FILES} ) if (NOT DEFINED USE_PYTHON_CASADI) @@ -113,3 +157,16 @@ else() endif() include_directories(${SuiteSparse_INCLUDE_DIRS}) target_link_libraries(idaklu PRIVATE ${SuiteSparse_LIBRARIES}) + +# IREE (MLIR compiler and runtime library) build settings +if(${PYBAMM_IDAKLU_EXPR_IREE} STREQUAL "ON" ) + set(IREE_BUILD_COMPILER ON) + set(IREE_BUILD_TESTS OFF) + set(IREE_BUILD_SAMPLES OFF) + add_subdirectory(iree EXCLUDE_FROM_ALL) + set(IREE_COMPILER_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/iree/compiler") + target_include_directories(idaklu SYSTEM PRIVATE "${IREE_COMPILER_ROOT}/bindings/c/iree/compiler") + target_compile_options(idaklu PRIVATE ${IREE_DEFAULT_COPTS}) + target_link_libraries(idaklu PRIVATE iree_compiler_bindings_c_loader) + target_link_libraries(idaklu PRIVATE iree_runtime_runtime) +endif() diff --git a/bandit.yml b/bandit.yml new file mode 100644 index 0000000000..87da61e530 --- /dev/null +++ b/bandit.yml @@ -0,0 +1,2 @@ +# To ignore false flagging of assert statements in tests by Codacy. +skips: ['B101'] diff --git a/docs/source/examples/notebooks/getting_started/tutorial-5-run-experiments.ipynb b/docs/source/examples/notebooks/getting_started/tutorial-5-run-experiments.ipynb index eb82f59719..85be34e421 100644 --- a/docs/source/examples/notebooks/getting_started/tutorial-5-run-experiments.ipynb +++ b/docs/source/examples/notebooks/getting_started/tutorial-5-run-experiments.ipynb @@ -25,18 +25,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "\n", - "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n", - "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n", "Note: you may need to restart the kernel to use updated packages.\n" ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n" - ] } ], "source": [ @@ -163,18 +153,19 @@ "name": "stderr", "output_type": "stream", "text": [ - "At t = 522.66 and h = 1.1556e-13, the corrector convergence failed repeatedly or with |h| = hmin.\n" + "At t = 339.952 and h = 1.4337e-18, the corrector convergence failed repeatedly or with |h| = hmin.\n", + "At t = 522.687 and h = 4.04917e-14, the corrector convergence failed repeatedly or with |h| = hmin.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "5ab1f22de6af4878b6ca43d27ffc01c5", + "model_id": "93feca98298f4111909ae487e2a1e273", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "interactive(children=(FloatSlider(value=0.0, description='t', max=40.132949019384355, step=0.40132949019384356…" + "interactive(children=(FloatSlider(value=0.0, description='t', max=40.13268704803602, step=0.4013268704803602),…" ] }, "metadata": {}, @@ -183,7 +174,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 5, @@ -211,12 +202,12 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "7cdac234d74241a2814918053454f6a6", + "model_id": "4d6e43032f4e4aa6be5843c4916b4b50", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "interactive(children=(FloatSlider(value=0.0, description='t', max=13.076977041121545, step=0.13076977041121546…" + "interactive(children=(FloatSlider(value=0.0, description='t', max=13.076887099589111, step=0.1307688709958911)…" ] }, "metadata": {}, @@ -225,7 +216,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 6, @@ -255,7 +246,7 @@ { "data": { "text/plain": [ - "_Step(C-rate, 1.0, duration=1 hour, period=1 minute, temperature=25oC, tags=['tag1'], description=Discharge at 1C for 1 hour)" + "Step(1.0, duration=1 hour, period=1 minute, temperature=25oC, tags=['tag1'], description=Discharge at 1C for 1 hour, direction=Discharge)" ] }, "execution_count": 7, @@ -293,7 +284,7 @@ { "data": { "text/plain": [ - "_Step(current, 1, duration=1 hour, termination=2.5 V)" + "Step(1, duration=1 hour, termination=2.5 V, direction=Discharge)" ] }, "execution_count": 8, @@ -321,7 +312,7 @@ { "data": { "text/plain": [ - "_Step(current, 1.0, duration=1 hour, termination=2.5V, description=Discharge at 1A for 1 hour or until 2.5V)" + "Step(1.0, duration=1 hour, termination=2.5V, description=Discharge at 1A for 1 hour or until 2.5V, direction=Discharge)" ] }, "execution_count": 9, @@ -348,10 +339,78 @@ "execution_count": 10, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-07-10 14:41:02.625 - [WARNING] callbacks.on_experiment_infeasible_time(240): \n", + "\n", + "\tExperiment is infeasible: default duration (1.0 seconds) was reached during 'Step([[ 0.00000000e+00 0.00000000e+00]\n", + " [ 1.69491525e-02 5.31467428e-02]\n", + " [ 3.38983051e-02 1.05691312e-01]\n", + " [ 5.08474576e-02 1.57038356e-01]\n", + " [ 6.77966102e-02 2.06606093e-01]\n", + " [ 8.47457627e-02 2.53832900e-01]\n", + " [ 1.01694915e-01 2.98183679e-01]\n", + " [ 1.18644068e-01 3.39155918e-01]\n", + " [ 1.35593220e-01 3.76285385e-01]\n", + " [ 1.52542373e-01 4.09151388e-01]\n", + " [ 1.69491525e-01 4.37381542e-01]\n", + " [ 1.86440678e-01 4.60655989e-01]\n", + " [ 2.03389831e-01 4.78711019e-01]\n", + " [ 2.20338983e-01 4.91342062e-01]\n", + " [ 2.37288136e-01 4.98406004e-01]\n", + " [ 2.54237288e-01 4.99822806e-01]\n", + " [ 2.71186441e-01 4.95576416e-01]\n", + " [ 2.88135593e-01 4.85714947e-01]\n", + " [ 3.05084746e-01 4.70350133e-01]\n", + " [ 3.22033898e-01 4.49656065e-01]\n", + " [ 3.38983051e-01 4.23867214e-01]\n", + " [ 3.55932203e-01 3.93275778e-01]\n", + " [ 3.72881356e-01 3.58228370e-01]\n", + " [ 3.89830508e-01 3.19122092e-01]\n", + " [ 4.06779661e-01 2.76400033e-01]\n", + " [ 4.23728814e-01 2.30546251e-01]\n", + " [ 4.40677966e-01 1.82080288e-01]\n", + " [ 4.57627119e-01 1.31551282e-01]\n", + " [ 4.74576271e-01 7.95317480e-02]\n", + " [ 4.91525424e-01 2.66110874e-02]\n", + " [ 5.08474576e-01 -2.66110874e-02]\n", + " [ 5.25423729e-01 -7.95317480e-02]\n", + " [ 5.42372881e-01 -1.31551282e-01]\n", + " [ 5.59322034e-01 -1.82080288e-01]\n", + " [ 5.76271186e-01 -2.30546251e-01]\n", + " [ 5.93220339e-01 -2.76400033e-01]\n", + " [ 6.10169492e-01 -3.19122092e-01]\n", + " [ 6.27118644e-01 -3.58228370e-01]\n", + " [ 6.44067797e-01 -3.93275778e-01]\n", + " [ 6.61016949e-01 -4.23867214e-01]\n", + " [ 6.77966102e-01 -4.49656065e-01]\n", + " [ 6.94915254e-01 -4.70350133e-01]\n", + " [ 7.11864407e-01 -4.85714947e-01]\n", + " [ 7.28813559e-01 -4.95576416e-01]\n", + " [ 7.45762712e-01 -4.99822806e-01]\n", + " [ 7.62711864e-01 -4.98406004e-01]\n", + " [ 7.79661017e-01 -4.91342062e-01]\n", + " [ 7.96610169e-01 -4.78711019e-01]\n", + " [ 8.13559322e-01 -4.60655989e-01]\n", + " [ 8.30508475e-01 -4.37381542e-01]\n", + " [ 8.47457627e-01 -4.09151388e-01]\n", + " [ 8.64406780e-01 -3.76285385e-01]\n", + " [ 8.81355932e-01 -3.39155918e-01]\n", + " [ 8.98305085e-01 -2.98183679e-01]\n", + " [ 9.15254237e-01 -2.53832900e-01]\n", + " [ 9.32203390e-01 -2.06606093e-01]\n", + " [ 9.49152542e-01 -1.57038356e-01]\n", + " [ 9.66101695e-01 -1.05691312e-01]\n", + " [ 9.83050847e-01 -5.31467428e-02]\n", + " [ 1.00000000e+00 -1.22464680e-16]], duration=1.0, period=0.016949152542372836, direction=Rest)'. The returned solution only contains up to step 1 of cycle 1. Please specify a duration in the step instructions.\n" + ] + }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "730d5e19b17e447ebde5679de68c46ef", + "model_id": "6364b4579fc447e2a607f2f8414172ba", "version_major": 2, "version_minor": 0 }, @@ -365,7 +424,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 10, @@ -419,13 +478,14 @@ "output_type": "stream", "text": [ "[1] Joel A. E. Andersson, Joris Gillis, Greg Horn, James B. Rawlings, and Moritz Diehl. CasADi – A software framework for nonlinear optimization and optimal control. Mathematical Programming Computation, 11(1):1–36, 2019. doi:10.1007/s12532-018-0139-4.\n", - "[2] Marc Doyle, Thomas F. Fuller, and John Newman. Modeling of galvanostatic charge and discharge of the lithium/polymer/insertion cell. Journal of the Electrochemical society, 140(6):1526–1533, 1993. doi:10.1149/1.2221597.\n", - "[3] Charles R. Harris, K. Jarrod Millman, Stéfan J. van der Walt, Ralf Gommers, Pauli Virtanen, David Cournapeau, Eric Wieser, Julian Taylor, Sebastian Berg, Nathaniel J. Smith, and others. Array programming with NumPy. Nature, 585(7825):357–362, 2020. doi:10.1038/s41586-020-2649-2.\n", - "[4] Scott G. Marquis, Valentin Sulzer, Robert Timms, Colin P. Please, and S. Jon Chapman. An asymptotic derivation of a single particle model with electrolyte. Journal of The Electrochemical Society, 166(15):A3693–A3706, 2019. doi:10.1149/2.0341915jes.\n", - "[5] Peyman Mohtat, Suhak Lee, Jason B Siegel, and Anna G Stefanopoulou. Towards better estimability of electrode-specific state of health: decoding the cell expansion. Journal of Power Sources, 427:101–111, 2019.\n", - "[6] Valentin Sulzer, Scott G. Marquis, Robert Timms, Martin Robinson, and S. Jon Chapman. Python Battery Mathematical Modelling (PyBaMM). Journal of Open Research Software, 9(1):14, 2021. doi:10.5334/jors.309.\n", - "[7] Pauli Virtanen, Ralf Gommers, Travis E. Oliphant, Matt Haberland, Tyler Reddy, David Cournapeau, Evgeni Burovski, Pearu Peterson, Warren Weckesser, Jonathan Bright, and others. SciPy 1.0: fundamental algorithms for scientific computing in Python. Nature Methods, 17(3):261–272, 2020. doi:10.1038/s41592-019-0686-2.\n", - "[8] Andrew Weng, Jason B Siegel, and Anna Stefanopoulou. Differential voltage analysis for battery manufacturing process control. arXiv preprint arXiv:2303.07088, 2023.\n", + "[2] Von DAG Bruggeman. Berechnung verschiedener physikalischer konstanten von heterogenen substanzen. i. dielektrizitätskonstanten und leitfähigkeiten der mischkörper aus isotropen substanzen. Annalen der physik, 416(7):636–664, 1935.\n", + "[3] Marc Doyle, Thomas F. Fuller, and John Newman. Modeling of galvanostatic charge and discharge of the lithium/polymer/insertion cell. Journal of the Electrochemical society, 140(6):1526–1533, 1993. doi:10.1149/1.2221597.\n", + "[4] Charles R. Harris, K. Jarrod Millman, Stéfan J. van der Walt, Ralf Gommers, Pauli Virtanen, David Cournapeau, Eric Wieser, Julian Taylor, Sebastian Berg, Nathaniel J. Smith, and others. Array programming with NumPy. Nature, 585(7825):357–362, 2020. doi:10.1038/s41586-020-2649-2.\n", + "[5] Scott G. Marquis, Valentin Sulzer, Robert Timms, Colin P. Please, and S. Jon Chapman. An asymptotic derivation of a single particle model with electrolyte. Journal of The Electrochemical Society, 166(15):A3693–A3706, 2019. doi:10.1149/2.0341915jes.\n", + "[6] Peyman Mohtat, Suhak Lee, Jason B Siegel, and Anna G Stefanopoulou. Towards better estimability of electrode-specific state of health: decoding the cell expansion. Journal of Power Sources, 427:101–111, 2019.\n", + "[7] Valentin Sulzer, Scott G. Marquis, Robert Timms, Martin Robinson, and S. Jon Chapman. Python Battery Mathematical Modelling (PyBaMM). Journal of Open Research Software, 9(1):14, 2021. doi:10.5334/jors.309.\n", + "[8] Pauli Virtanen, Ralf Gommers, Travis E. Oliphant, Matt Haberland, Tyler Reddy, David Cournapeau, Evgeni Burovski, Pearu Peterson, Warren Weckesser, Jonathan Bright, and others. SciPy 1.0: fundamental algorithms for scientific computing in Python. Nature Methods, 17(3):261–272, 2020. doi:10.1038/s41592-019-0686-2.\n", + "[9] Andrew Weng, Jason B Siegel, and Anna Stefanopoulou. Differential voltage analysis for battery manufacturing process control. arXiv preprint arXiv:2303.07088, 2023.\n", "\n" ] } diff --git a/docs/source/examples/notebooks/models/spm1.png b/docs/source/examples/notebooks/models/spm1.png index 7e0e9ea9cc..a8509b442a 100644 Binary files a/docs/source/examples/notebooks/models/spm1.png and b/docs/source/examples/notebooks/models/spm1.png differ diff --git a/docs/source/user_guide/installation/gnu-linux-mac.rst b/docs/source/user_guide/installation/gnu-linux-mac.rst index 0be4b98e4c..121c6df437 100644 --- a/docs/source/user_guide/installation/gnu-linux-mac.rst +++ b/docs/source/user_guide/installation/gnu-linux-mac.rst @@ -101,6 +101,19 @@ Users can install ``jax`` and ``jaxlib`` to use the Jax solver. The ``pip install "pybamm[jax]"`` command automatically downloads and installs ``pybamm`` and the compatible versions of ``jax`` and ``jaxlib`` on your system. (``pybamm_install_jax`` is deprecated.) +.. _optional-iree-mlir-support: + +Optional - IREE / MLIR support +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Users can install ``iree`` (for MLIR just-in-time compilation) to use for main expression evaluation in the IDAKLU solver. Requires ``jax``. + +.. code:: bash + + pip install "pybamm[iree,jax]" + +The ``pip install "pybamm[iree,jax]"`` command automatically downloads and installs ``pybamm`` and the compatible versions of ``jax`` and ``iree`` onto your system. + Uninstall PyBaMM ---------------- diff --git a/docs/source/user_guide/installation/index.rst b/docs/source/user_guide/installation/index.rst index 778f64c3f9..d6411348c5 100644 --- a/docs/source/user_guide/installation/index.rst +++ b/docs/source/user_guide/installation/index.rst @@ -47,7 +47,8 @@ Optional solvers The following solvers are optionally available: -* `jax `_ -based solver, see `Optional - JaxSolver `_. +* `jax `_ -based solver, see `Optional - JaxSolver `_. +* `IREE `_ (`MLIR `_) support, see `Optional - IREE / MLIR Support `_. Dependencies ------------ @@ -205,6 +206,17 @@ Dependency Minimu `jaxlib `__ 0.4.20 jax Support library for JAX ========================================================================= ================== ================== ======================= +IREE dependencies +^^^^^^^^^^^^^^^^^^ + +Installable with ``pip install "pybamm[iree]"`` (requires ``jax`` dependencies to be installed). + +========================================================================= ================== ================== ======================= +Dependency Minimum Version pip extra Notes +========================================================================= ================== ================== ======================= +`iree-compiler `__ 20240507.886 iree IREE compiler +========================================================================= ================== ================== ======================= + Full installation guide ----------------------- diff --git a/noxfile.py b/noxfile.py index 7237786ef6..373b77f71f 100644 --- a/noxfile.py +++ b/noxfile.py @@ -1,6 +1,8 @@ import nox import os import sys +import warnings +import platform from pathlib import Path @@ -13,11 +15,54 @@ nox.options.sessions = ["pre-commit", "unit"] +def set_iree_state(): + """ + Check if IREE is enabled and set the environment variable accordingly. + + Returns + ------- + str + "ON" if IREE is enabled, "OFF" otherwise. + + """ + state = "ON" if os.getenv("PYBAMM_IDAKLU_EXPR_IREE", "OFF") == "ON" else "OFF" + if state == "ON": + if sys.platform == "win32": + warnings.warn( + ( + "IREE is not enabled on Windows yet. " + "Setting PYBAMM_IDAKLU_EXPR_IREE=OFF." + ), + stacklevel=2, + ) + return "OFF" + if sys.platform == "darwin": + # iree-compiler is currently only available as a wheel on macOS 13 (or + # higher) and Python version 3.11 + mac_ver = int(platform.mac_ver()[0].split(".")[0]) + if (not sys.version_info[:2] == (3, 11)) or mac_ver < 13: + warnings.warn( + ( + "IREE is only supported on MacOS 13 (or higher) and Python" + "version 3.11. Setting PYBAMM_IDAKLU_EXPR_IREE=OFF." + ), + stacklevel=2, + ) + return "OFF" + return state + + homedir = os.getenv("HOME") PYBAMM_ENV = { "SUNDIALS_INST": f"{homedir}/.local", "LD_LIBRARY_PATH": f"{homedir}/.local/lib", "PYTHONIOENCODING": "utf-8", + # Expression evaluators (...EXPR_CASADI cannot be fully disabled at this time) + "PYBAMM_IDAKLU_EXPR_CASADI": os.getenv("PYBAMM_IDAKLU_EXPR_CASADI", "ON"), + "PYBAMM_IDAKLU_EXPR_IREE": set_iree_state(), + "IREE_INDEX_URL": os.getenv( + "IREE_INDEX_URL", "https://iree.dev/pip-release-links.html" + ), } VENV_DIR = Path("./venv").resolve() @@ -59,6 +104,29 @@ def run_pybamm_requires(session): "advice.detachedHead=false", external=True, ) + if PYBAMM_ENV.get("PYBAMM_IDAKLU_EXPR_IREE") == "ON" and not os.path.exists( + "./iree" + ): + session.run( + "git", + "clone", + "--depth=1", + "--recurse-submodules", + "--shallow-submodules", + "--branch=candidate-20240507.886", + "https://github.com/openxla/iree", + "iree/", + external=True, + ) + with session.chdir("iree"): + session.run( + "git", + "submodule", + "update", + "--init", + "--recursive", + external=True, + ) else: session.error("nox -s pybamm-requires is only available on Linux & macOS.") @@ -70,6 +138,15 @@ def run_coverage(session): session.install("setuptools", silent=False) session.install("coverage", silent=False) session.install("-e", ".[all,dev,jax]", silent=False) + if PYBAMM_ENV.get("PYBAMM_IDAKLU_EXPR_IREE") == "ON": + # See comments in 'dev' session + session.install( + "-e", + ".[iree]", + "--find-links", + PYBAMM_ENV.get("IREE_INDEX_URL"), + silent=False, + ) session.run("pytest", "--cov=pybamm", "--cov-report=xml", "tests/unit") @@ -98,6 +175,15 @@ def run_unit(session): set_environment_variables(PYBAMM_ENV, session=session) session.install("setuptools", silent=False) session.install("-e", ".[all,dev,jax]", silent=False) + if PYBAMM_ENV.get("PYBAMM_IDAKLU_EXPR_IREE") == "ON": + # See comments in 'dev' session + session.install( + "-e", + ".[iree]", + "--find-links", + PYBAMM_ENV.get("IREE_INDEX_URL"), + silent=False, + ) session.run("python", "run-tests.py", "--unit") @@ -130,6 +216,17 @@ def set_dev(session): session.install("virtualenv", "cmake") session.run("virtualenv", os.fsdecode(VENV_DIR), silent=True) python = os.fsdecode(VENV_DIR.joinpath("bin/python")) + components = ["all", "dev", "jax"] + args = [] + if PYBAMM_ENV.get("PYBAMM_IDAKLU_EXPR_IREE") == "ON": + # Install IREE libraries for Jax-MLIR expression evaluation in the IDAKLU solver + # (optional). IREE is currently pre-release and relies on nightly jaxlib builds. + # When upgrading Jax/IREE ensure that the following are compatible with each other: + # - Jax and Jaxlib version [pyproject.toml] + # - IREE repository clone (use the matching nightly candidate) [noxfile.py] + # - IREE compiler matches Jaxlib (use the matching nightly build) [pyproject.toml] + components.append("iree") + args = ["--find-links", PYBAMM_ENV.get("IREE_INDEX_URL")] # Temporary fix for Python 3.12 CI. TODO: remove after # https://bitbucket.org/pybtex-devs/pybtex/issues/169/replace-pkg_resources-with # is fixed @@ -140,7 +237,8 @@ def set_dev(session): "pip", "install", "-e", - ".[all,dev,jax]", + ".[{}]".format(",".join(components)), + *args, external=True, ) diff --git a/pybamm/__init__.py b/pybamm/__init__.py index b3b7fafd3f..a371fdbc03 100644 --- a/pybamm/__init__.py +++ b/pybamm/__init__.py @@ -2,6 +2,9 @@ from pybamm.version import __version__ +# Demote expressions to 32-bit floats/ints - option used for IDAKLU-MLIR compilation +demote_expressions_to_32bit = False + # Utility classes and methods from .util import root_dir from .util import Timer, TimerTime, FuzzyDict @@ -168,7 +171,7 @@ from .solvers.jax_bdf_solver import jax_bdf_integrate from .solvers.idaklu_jax import IDAKLUJax -from .solvers.idaklu_solver import IDAKLUSolver, have_idaklu +from .solvers.idaklu_solver import IDAKLUSolver, have_idaklu, have_iree # Experiments from .experiment.experiment import Experiment diff --git a/pybamm/callbacks.py b/pybamm/callbacks.py index 57f2426d0f..4e8c67c8be 100644 --- a/pybamm/callbacks.py +++ b/pybamm/callbacks.py @@ -36,37 +36,37 @@ def on_experiment_start(self, logs): """ Called at the start of an experiment simulation. """ - pass + pass # pragma: no cover def on_cycle_start(self, logs): """ Called at the start of each cycle in an experiment simulation. """ - pass + pass # pragma: no cover def on_step_start(self, logs): """ Called at the start of each step in an experiment simulation. """ - pass + pass # pragma: no cover def on_step_end(self, logs): """ Called at the end of each step in an experiment simulation. """ - pass + pass # pragma: no cover def on_cycle_end(self, logs): """ Called at the end of each cycle in an experiment simulation. """ - pass + pass # pragma: no cover def on_experiment_end(self, logs): """ Called at the end of an experiment simulation. """ - pass + pass # pragma: no cover def on_experiment_error(self, logs): """ @@ -75,13 +75,19 @@ def on_experiment_error(self, logs): For example, this could be used to send an error alert with a bug report when running batch simulations in the cloud. """ - pass + pass # pragma: no cover - def on_experiment_infeasible(self, logs): + def on_experiment_infeasible_time(self, logs): """ - Called when an experiment simulation is infeasible. + Called when an experiment simulation is infeasible due to reaching maximum time. """ - pass + pass # pragma: no cover + + def on_experiment_infeasible_event(self, logs): + """ + Called when an experiment simulation is infeasible due to an event. + """ + pass # pragma: no cover ######################################################################################## @@ -226,7 +232,19 @@ def on_experiment_error(self, logs): error = logs["error"] pybamm.logger.error(f"Simulation error: {error}") - def on_experiment_infeasible(self, logs): + def on_experiment_infeasible_time(self, logs): + duration = logs["step duration"] + cycle_num = logs["cycle number"][0] + step_num = logs["step number"][0] + operating_conditions = logs["step operating conditions"] + self.logger.warning( + f"\n\n\tExperiment is infeasible: default duration ({duration} seconds) " + f"was reached during '{operating_conditions}'. The returned solution only " + f"contains up to step {step_num} of cycle {cycle_num}. " + "Please specify a duration in the step instructions." + ) + + def on_experiment_infeasible_event(self, logs): termination = logs["termination"] cycle_num = logs["cycle number"][0] step_num = logs["step number"][0] diff --git a/pybamm/experiment/step/base_step.py b/pybamm/experiment/step/base_step.py index 16224a6afa..6b77bed2cf 100644 --- a/pybamm/experiment/step/base_step.py +++ b/pybamm/experiment/step/base_step.py @@ -71,6 +71,8 @@ def __init__( description=None, direction=None, ): + self.input_duration = duration + self.input_value = value # Check if drive cycle is_drive_cycle = isinstance(value, np.ndarray) is_python_function = callable(value) @@ -100,8 +102,11 @@ def __init__( f"Input function must return a real number output at t = {t0}" ) + # Record whether the step uses the default duration + # This will be used by the experiment to check whether the step is feasible + self.uses_default_duration = duration is None # Set duration - if duration is None: + if self.uses_default_duration: duration = self.default_duration(value) self.duration = _convert_time_to_seconds(duration) @@ -195,8 +200,8 @@ def copy(self): A copy of the step. """ return self.__class__( - self.value, - duration=self.duration, + self.input_value, + duration=self.input_duration, termination=self.termination, period=self.period, temperature=self.temperature, @@ -259,7 +264,7 @@ def default_duration(self, value): t = value[:, 0] return t[-1] else: - return 24 * 3600 # 24 hours in seconds + return 24 * 3600 # one day in seconds def process_model(self, model, parameter_values): new_model = model.new_copy() @@ -411,10 +416,16 @@ def set_up(self, new_model, new_parameter_values): def _convert_time_to_seconds(time_and_units): """Convert a time in seconds, minutes or hours to a time in seconds""" - # If the time is a number, assume it is in seconds - if isinstance(time_and_units, numbers.Number) or time_and_units is None: + if time_and_units is None: return time_and_units + # If the time is a number, assume it is in seconds + if isinstance(time_and_units, numbers.Number): + if time_and_units <= 0: + raise ValueError("time must be positive") + else: + return time_and_units + # Split number and units units = time_and_units.lstrip("0123456789.- ") time = time_and_units[: -len(units)] diff --git a/pybamm/experiment/step/steps.py b/pybamm/experiment/step/steps.py index e3322104bf..e66178dc81 100644 --- a/pybamm/experiment/step/steps.py +++ b/pybamm/experiment/step/steps.py @@ -156,6 +156,11 @@ def __init__(self, value, **kwargs): def current_value(self, variables): return self.value * pybamm.Parameter("Nominal cell capacity [A.h]") + def default_duration(self, value): + # "value" is C-rate, so duration is "1 / value" hours in seconds + # with a 2x safety factor + return 1 / abs(value) * 3600 * 2 + def c_rate(value, **kwargs): """ diff --git a/pybamm/expression_tree/operations/evaluate_python.py b/pybamm/expression_tree/operations/evaluate_python.py index 6d13761756..20a6d4b4a2 100644 --- a/pybamm/expression_tree/operations/evaluate_python.py +++ b/pybamm/expression_tree/operations/evaluate_python.py @@ -281,7 +281,7 @@ def find_symbols( if isinstance(symbol, pybamm.Index): symbol_str = f"{children_vars[0]}[{symbol.slice.start}:{symbol.slice.stop}]" else: - symbol_str = symbol.name + children_vars[0] + symbol_str = symbol.name + "(" + children_vars[0] + ")" elif isinstance(symbol, pybamm.Function): children_str = "" @@ -596,6 +596,59 @@ def __init__(self, symbol: pybamm.Symbol): static_argnums=self._static_argnums, ) + def _demote_constants(self): + """Demote 64-bit constants (f64, i64) to 32-bit (f32, i32)""" + if not pybamm.demote_expressions_to_32bit: + return # pragma: no cover + self._constants = EvaluatorJax._demote_64_to_32(self._constants) + + @classmethod + def _demote_64_to_32(cls, c): + """Demote 64-bit operations (f64, i64) to 32-bit (f32, i32)""" + + if not pybamm.demote_expressions_to_32bit: + return c + if isinstance(c, float): + c = jax.numpy.float32(c) + if isinstance(c, int): + c = jax.numpy.int32(c) + if isinstance(c, np.int64): + c = c.astype(jax.numpy.int32) + if isinstance(c, np.ndarray): + if c.dtype == np.float64: + c = c.astype(jax.numpy.float32) + if c.dtype == np.int64: + c = c.astype(jax.numpy.int32) + if isinstance(c, jax.numpy.ndarray): + if c.dtype == jax.numpy.float64: + c = c.astype(jax.numpy.float32) + if c.dtype == jax.numpy.int64: + c = c.astype(jax.numpy.int32) + if isinstance( + c, pybamm.expression_tree.operations.evaluate_python.JaxCooMatrix + ): + if c.data.dtype == np.float64: + c.data = c.data.astype(jax.numpy.float32) + if c.row.dtype == np.int64: + c.row = c.row.astype(jax.numpy.int32) + if c.col.dtype == np.int64: + c.col = c.col.astype(jax.numpy.int32) + if isinstance(c, dict): + c = {key: EvaluatorJax._demote_64_to_32(value) for key, value in c.items()} + if isinstance(c, tuple): + c = tuple(EvaluatorJax._demote_64_to_32(value) for value in c) + if isinstance(c, list): + c = [EvaluatorJax._demote_64_to_32(value) for value in c] + return c + + @property + def _constants(self): + return tuple(map(EvaluatorJax._demote_64_to_32, self.__constants)) + + @_constants.setter + def _constants(self, value): + self.__constants = value + def get_jacobian(self): n = len(self._arg_list) diff --git a/pybamm/expression_tree/symbol.py b/pybamm/expression_tree/symbol.py index df549747c9..aa9ebe66db 100644 --- a/pybamm/expression_tree/symbol.py +++ b/pybamm/expression_tree/symbol.py @@ -282,7 +282,8 @@ def name(self): @name.setter def name(self, value: str): - assert isinstance(value, str) + if not isinstance(value, str): + raise TypeError(f"{value} must be of type str") self._name = value @property diff --git a/pybamm/models/full_battery_models/lithium_ion/base_lithium_ion_model.py b/pybamm/models/full_battery_models/lithium_ion/base_lithium_ion_model.py index 479e8203ed..6db56b74c4 100644 --- a/pybamm/models/full_battery_models/lithium_ion/base_lithium_ion_model.py +++ b/pybamm/models/full_battery_models/lithium_ion/base_lithium_ion_model.py @@ -265,9 +265,9 @@ def set_sei_submodel(self): reaction_loc = "x-average" else: reaction_loc = "full electrode" - sei_option = getattr(self.options, domain)["SEI"] phases = self.options.phases[domain] for phase in phases: + sei_option = getattr(getattr(self.options, domain), phase)["SEI"] if sei_option == "none": submodel = pybamm.sei.NoSEI(self.param, domain, self.options, phase) elif sei_option == "constant": @@ -333,9 +333,11 @@ def set_lithium_plating_submodel(self): for domain in self.options.whole_cell_domains: if domain != "separator": domain = domain.split()[0].lower() - lithium_plating_opt = getattr(self.options, domain)["lithium plating"] phases = self.options.phases[domain] for phase in phases: + lithium_plating_opt = getattr(getattr(self.options, domain), phase)[ + "lithium plating" + ] if lithium_plating_opt == "none": submodel = pybamm.lithium_plating.NoPlating( self.param, domain, self.options, phase diff --git a/pybamm/models/submodels/active_material/loss_active_material.py b/pybamm/models/submodels/active_material/loss_active_material.py index 7816122e07..6f027d89e6 100644 --- a/pybamm/models/submodels/active_material/loss_active_material.py +++ b/pybamm/models/submodels/active_material/loss_active_material.py @@ -60,7 +60,9 @@ def get_coupled_variables(self, variables): domain, Domain = self.domain_Domain deps_solid_dt = 0 - lam_option = getattr(self.options, self.domain)["loss of active material"] + lam_option = getattr(getattr(self.options, domain), self.phase)[ + "loss of active material" + ] if "stress" in lam_option: # obtain the rate of loss of active materials (LAM) by stress # This is loss of active material model by mechanical effects diff --git a/pybamm/settings.py b/pybamm/settings.py index 2ccd9bcd13..d190eaf47e 100644 --- a/pybamm/settings.py +++ b/pybamm/settings.py @@ -29,8 +29,9 @@ def debug_mode(self): return self._debug_mode @debug_mode.setter - def debug_mode(self, value): - assert isinstance(value, bool) + def debug_mode(self, value: bool): + if not isinstance(value, bool): + raise TypeError(f"{value} must be of type bool") self._debug_mode = value @property @@ -38,8 +39,9 @@ def simplify(self): return self._simplify @simplify.setter - def simplify(self, value): - assert isinstance(value, bool) + def simplify(self, value: bool): + if not isinstance(value, bool): + raise TypeError(f"{value} must be of type bool") self._simplify = value def set_smoothing_parameters(self, k): diff --git a/pybamm/simulation.py b/pybamm/simulation.py index 8ec85d67d4..a55310870e 100644 --- a/pybamm/simulation.py +++ b/pybamm/simulation.py @@ -679,6 +679,7 @@ def solve( logs["step number"] = (step_num, cycle_length) logs["step operating conditions"] = step_str + logs["step duration"] = step.duration callbacks.on_step_start(logs) inputs = { @@ -767,23 +768,33 @@ def solve( callbacks.on_step_end(logs) logs["termination"] = step_solution.termination - # Only allow events specified by experiment - if not ( + + # Check for some cases that would make the experiment end early + if step_termination == "final time" and step.uses_default_duration: + # reached the default duration of a step (typically we should + # reach an event before the default duration) + callbacks.on_experiment_infeasible_time(logs) + feasible = False + break + + elif not ( isinstance(step_solution, pybamm.EmptySolution) or step_termination == "final time" or "[experiment]" in step_termination ): - callbacks.on_experiment_infeasible(logs) + # Step has reached an event that is not specified in the + # experiment + callbacks.on_experiment_infeasible_event(logs) feasible = False break - if time_stop is not None: - max_time = cycle_solution.t[-1] - if max_time >= time_stop: - break + elif time_stop is not None and logs["experiment time"] >= time_stop: + # reached the time limit of the experiment + break - # Increment index for next iteration - idx += 1 + else: + # Increment index for next iteration, then continue + idx += 1 if save_this_cycle or feasible is False: self._solution = self._solution + cycle_solution diff --git a/pybamm/solvers/base_solver.py b/pybamm/solvers/base_solver.py index 1425bf0845..0eb573e87a 100644 --- a/pybamm/solvers/base_solver.py +++ b/pybamm/solvers/base_solver.py @@ -256,32 +256,30 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False): model.casadi_sensitivities_rhs = jacp_rhs model.casadi_sensitivities_algebraic = jacp_algebraic - # if output_variables specified then convert functions to casadi - # expressions for evaluation within the respective solver - self.computed_var_fcns = {} - self.computed_dvar_dy_fcns = {} - self.computed_dvar_dp_fcns = {} - for key in self.output_variables: - # ExplicitTimeIntegral's are not computed as part of the solver and - # do not need to be converted - if isinstance( - model.variables_and_events[key], pybamm.ExplicitTimeIntegral - ): - continue - # Generate Casadi function to calculate variable and derivates - # to enable sensitivites to be computed within the solver - ( - self.computed_var_fcns[key], - self.computed_dvar_dy_fcns[key], - self.computed_dvar_dp_fcns[key], - _, - ) = process( - model.variables_and_events[key], - BaseSolver._wrangle_name(key), - vars_for_processing, - use_jacobian=True, - return_jacp_stacked=True, - ) + # if output_variables specified then convert functions to casadi + # expressions for evaluation within the respective solver + self.computed_var_fcns = {} + self.computed_dvar_dy_fcns = {} + self.computed_dvar_dp_fcns = {} + for key in self.output_variables: + # ExplicitTimeIntegral's are not computed as part of the solver and + # do not need to be converted + if isinstance(model.variables_and_events[key], pybamm.ExplicitTimeIntegral): + continue + # Generate Casadi function to calculate variable and derivates + # to enable sensitivites to be computed within the solver + ( + self.computed_var_fcns[key], + self.computed_dvar_dy_fcns[key], + self.computed_dvar_dp_fcns[key], + _, + ) = process( + model.variables_and_events[key], + BaseSolver._wrangle_name(key), + vars_for_processing, + use_jacobian=True, + return_jacp_stacked=True, + ) pybamm.logger.info("Finish solver set-up") diff --git a/pybamm/solvers/c_solvers/idaklu.cpp b/pybamm/solvers/c_solvers/idaklu.cpp index 9f99d4d3f4..3afed5faa8 100644 --- a/pybamm/solvers/c_solvers/idaklu.cpp +++ b/pybamm/solvers/c_solvers/idaklu.cpp @@ -8,14 +8,20 @@ #include #include -#include "idaklu/casadi_solver.hpp" -#include "idaklu/idaklu_jax.hpp" +#include "idaklu/idaklu_solver.hpp" +#include "idaklu/IdakluJax.hpp" #include "idaklu/common.hpp" #include "idaklu/python.hpp" +#include "idaklu/Expressions/Casadi/CasadiFunctions.hpp" -Function generate_function(const std::string &data) +#ifdef IREE_ENABLE +#include "idaklu/Expressions/IREE/IREEFunctions.hpp" +#endif + + +casadi::Function generate_casadi_function(const std::string &data) { - return Function::deserialize(data); + return casadi::Function::deserialize(data); } namespace py = pybind11; @@ -50,8 +56,8 @@ PYBIND11_MODULE(idaklu, m) py::arg("number_of_sensitivity_parameters"), py::return_value_policy::take_ownership); - py::class_(m, "CasadiSolver") - .def("solve", &CasadiSolver::solve, + py::class_(m, "IDAKLUSolver") + .def("solve", &IDAKLUSolver::solve, "perform a solve", py::arg("t"), py::arg("y0"), @@ -59,7 +65,7 @@ PYBIND11_MODULE(idaklu, m) py::arg("inputs"), py::return_value_policy::take_ownership); - m.def("create_casadi_solver", &create_casadi_solver, + m.def("create_casadi_solver", &create_idaklu_solver, "Create a casadi idaklu solver object", py::arg("number_of_states"), py::arg("number_of_parameters"), @@ -79,13 +85,41 @@ PYBIND11_MODULE(idaklu, m) py::arg("atol"), py::arg("rtol"), py::arg("inputs"), - py::arg("var_casadi_fcns"), + py::arg("var_fcns"), + py::arg("dvar_dy_fcns"), + py::arg("dvar_dp_fcns"), + py::arg("options"), + py::return_value_policy::take_ownership); + +#ifdef IREE_ENABLE + m.def("create_iree_solver", &create_idaklu_solver, + "Create a iree idaklu solver object", + py::arg("number_of_states"), + py::arg("number_of_parameters"), + py::arg("rhs_alg"), + py::arg("jac_times_cjmass"), + py::arg("jac_times_cjmass_colptrs"), + py::arg("jac_times_cjmass_rowvals"), + py::arg("jac_times_cjmass_nnz"), + py::arg("jac_bandwidth_lower"), + py::arg("jac_bandwidth_upper"), + py::arg("jac_action"), + py::arg("mass_action"), + py::arg("sens"), + py::arg("events"), + py::arg("number_of_events"), + py::arg("rhs_alg_id"), + py::arg("atol"), + py::arg("rtol"), + py::arg("inputs"), + py::arg("var_fcns"), py::arg("dvar_dy_fcns"), py::arg("dvar_dp_fcns"), py::arg("options"), py::return_value_policy::take_ownership); +#endif - m.def("generate_function", &generate_function, + m.def("generate_function", &generate_casadi_function, "Generate a casadi function", py::arg("string"), py::return_value_policy::take_ownership); @@ -133,11 +167,25 @@ PYBIND11_MODULE(idaklu, m) &Registrations ); - py::class_(m, "Function"); + py::class_(m, "Function"); + +#ifdef IREE_ENABLE + py::class_(m, "IREEBaseFunctionType") + .def(py::init<>()) + .def_readwrite("mlir", &IREEBaseFunctionType::mlir) + .def_readwrite("kept_var_idx", &IREEBaseFunctionType::kept_var_idx) + .def_readwrite("nnz", &IREEBaseFunctionType::nnz) + .def_readwrite("numel", &IREEBaseFunctionType::numel) + .def_readwrite("col", &IREEBaseFunctionType::col) + .def_readwrite("row", &IREEBaseFunctionType::row) + .def_readwrite("pytree_shape", &IREEBaseFunctionType::pytree_shape) + .def_readwrite("pytree_sizes", &IREEBaseFunctionType::pytree_sizes) + .def_readwrite("n_args", &IREEBaseFunctionType::n_args); +#endif py::class_(m, "solution") - .def_readwrite("t", &Solution::t) - .def_readwrite("y", &Solution::y) - .def_readwrite("yS", &Solution::yS) - .def_readwrite("flag", &Solution::flag); + .def_readwrite("t", &Solution::t) + .def_readwrite("y", &Solution::y) + .def_readwrite("yS", &Solution::yS) + .def_readwrite("flag", &Solution::flag); } diff --git a/pybamm/solvers/c_solvers/idaklu/CasadiSolver.cpp b/pybamm/solvers/c_solvers/idaklu/CasadiSolver.cpp deleted file mode 100644 index 16a04f8eb9..0000000000 --- a/pybamm/solvers/c_solvers/idaklu/CasadiSolver.cpp +++ /dev/null @@ -1 +0,0 @@ -#include "CasadiSolver.hpp" diff --git a/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP_solvers.cpp b/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP_solvers.cpp deleted file mode 100644 index 868d2b2138..0000000000 --- a/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP_solvers.cpp +++ /dev/null @@ -1 +0,0 @@ -#include "CasadiSolverOpenMP_solvers.hpp" diff --git a/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP_solvers.hpp b/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP_solvers.hpp deleted file mode 100644 index 3e39e5a303..0000000000 --- a/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP_solvers.hpp +++ /dev/null @@ -1,125 +0,0 @@ -#ifndef PYBAMM_IDAKLU_CASADI_SOLVER_OPENMP_HPP -#define PYBAMM_IDAKLU_CASADI_SOLVER_OPENMP_HPP - -#include "CasadiSolverOpenMP.hpp" -#include "casadi_solver.hpp" - -/** - * @brief CasadiSolver Dense implementation with OpenMP class - */ -class CasadiSolverOpenMP_Dense : public CasadiSolverOpenMP { -public: - template - CasadiSolverOpenMP_Dense(Args&& ... args) - : CasadiSolverOpenMP(std::forward(args) ...) - { - LS = SUNLinSol_Dense(yy, J, sunctx); - Initialize(); - } -}; - -/** - * @brief CasadiSolver KLU implementation with OpenMP class - */ -class CasadiSolverOpenMP_KLU : public CasadiSolverOpenMP { -public: - template - CasadiSolverOpenMP_KLU(Args&& ... args) - : CasadiSolverOpenMP(std::forward(args) ...) - { - LS = SUNLinSol_KLU(yy, J, sunctx); - Initialize(); - } -}; - -/** - * @brief CasadiSolver Banded implementation with OpenMP class - */ -class CasadiSolverOpenMP_Band : public CasadiSolverOpenMP { -public: - template - CasadiSolverOpenMP_Band(Args&& ... args) - : CasadiSolverOpenMP(std::forward(args) ...) - { - LS = SUNLinSol_Band(yy, J, sunctx); - Initialize(); - } -}; - -/** - * @brief CasadiSolver SPBCGS implementation with OpenMP class - */ -class CasadiSolverOpenMP_SPBCGS : public CasadiSolverOpenMP { -public: - template - CasadiSolverOpenMP_SPBCGS(Args&& ... args) - : CasadiSolverOpenMP(std::forward(args) ...) - { - LS = SUNLinSol_SPBCGS( - yy, - precon_type, - options.linsol_max_iterations, - sunctx - ); - Initialize(); - } -}; - -/** - * @brief CasadiSolver SPFGMR implementation with OpenMP class - */ -class CasadiSolverOpenMP_SPFGMR : public CasadiSolverOpenMP { -public: - template - CasadiSolverOpenMP_SPFGMR(Args&& ... args) - : CasadiSolverOpenMP(std::forward(args) ...) - { - LS = SUNLinSol_SPFGMR( - yy, - precon_type, - options.linsol_max_iterations, - sunctx - ); - Initialize(); - } -}; - -/** - * @brief CasadiSolver SPGMR implementation with OpenMP class - */ -class CasadiSolverOpenMP_SPGMR : public CasadiSolverOpenMP { -public: - template - CasadiSolverOpenMP_SPGMR(Args&& ... args) - : CasadiSolverOpenMP(std::forward(args) ...) - { - LS = SUNLinSol_SPGMR( - yy, - precon_type, - options.linsol_max_iterations, - sunctx - ); - Initialize(); - } -}; - -/** - * @brief CasadiSolver SPTFQMR implementation with OpenMP class - */ -class CasadiSolverOpenMP_SPTFQMR : public CasadiSolverOpenMP { -public: - template - CasadiSolverOpenMP_SPTFQMR(Args&& ... args) - : CasadiSolverOpenMP(std::forward(args) ...) - { - LS = SUNLinSol_SPTFQMR( - yy, - precon_type, - options.linsol_max_iterations, - sunctx - ); - Initialize(); - } -}; - -#endif // PYBAMM_IDAKLU_CASADI_SOLVER_OPENMP_HPP diff --git a/pybamm/solvers/c_solvers/idaklu/Expressions/Base/Expression.hpp b/pybamm/solvers/c_solvers/idaklu/Expressions/Base/Expression.hpp new file mode 100644 index 0000000000..bbf60b4568 --- /dev/null +++ b/pybamm/solvers/c_solvers/idaklu/Expressions/Base/Expression.hpp @@ -0,0 +1,69 @@ +#ifndef PYBAMM_EXPRESSION_HPP +#define PYBAMM_EXPRESSION_HPP + +#include "ExpressionTypes.hpp" +#include "../../common.hpp" +#include "../../Options.hpp" +#include +#include + +class Expression { +public: // method declarations + /** + * @brief Constructor + */ + Expression() = default; + + /** + * @brief Evaluation operator (for use after setting input and output data references) + */ + virtual void operator()() = 0; + + /** + * @brief Evaluation operator (supplying data references) + */ + virtual void operator()( + const std::vector& inputs, + const std::vector& results) = 0; + + /** + * @brief The maximum number of elements returned by the k'th output + * + * This is used to allocate memory for the output of the function and usually (but + * not always) corresponds to the number of non-zero elements (NNZ). + */ + virtual expr_int out_shape(int k) = 0; + + /** + * @brief Return the number of non-zero elements for the function output + */ + virtual expr_int nnz() = 0; + + /** + * @brief Return the number of non-zero elements for the function output + */ + virtual expr_int nnz_out() = 0; + + /** + * @brief Returns row indices in COO format (where the output data represents sparse matrix elements) + */ + virtual std::vector get_row() = 0; + + /** + * @brief Returns column indices in COO format (where the output data represents sparse matrix elements) + */ + virtual std::vector get_col() = 0; + +public: // data members + /** + * @brief Vector of pointers to the input data + */ + std::vector m_arg; // cppcheck-suppress unusedStructMember + + /** + * @brief Vector of pointers to the output data + */ + std::vector m_res; // cppcheck-suppress unusedStructMember +}; + +#endif // PYBAMM_EXPRESSION_HPP diff --git a/pybamm/solvers/c_solvers/idaklu/Expressions/Base/ExpressionSet.hpp b/pybamm/solvers/c_solvers/idaklu/Expressions/Base/ExpressionSet.hpp new file mode 100644 index 0000000000..a32f906a38 --- /dev/null +++ b/pybamm/solvers/c_solvers/idaklu/Expressions/Base/ExpressionSet.hpp @@ -0,0 +1,86 @@ +#ifndef PYBAMM_IDAKLU_EXPRESSION_SET_HPP +#define PYBAMM_IDAKLU_EXPRESSION_SET_HPP + +#include "ExpressionTypes.hpp" +#include "Expression.hpp" +#include "../../common.hpp" +#include "../../Options.hpp" +#include + +template +class ExpressionSet +{ +public: + + /** + * @brief Constructor + */ + ExpressionSet( + Expression* rhs_alg, + Expression* jac_times_cjmass, + const int jac_times_cjmass_nnz, + const int jac_bandwidth_lower, + const int jac_bandwidth_upper, + const np_array_int &jac_times_cjmass_rowvals_arg, // cppcheck-suppress unusedStructMember + const np_array_int &jac_times_cjmass_colptrs_arg, // cppcheck-suppress unusedStructMember + const int inputs_length, + Expression* jac_action, + Expression* mass_action, + Expression* sens, + Expression* events, + const int n_s, + const int n_e, + const int n_p, + const Options& options) + : number_of_states(n_s), + number_of_events(n_e), + number_of_parameters(n_p), + number_of_nnz(jac_times_cjmass_nnz), + jac_bandwidth_lower(jac_bandwidth_lower), + jac_bandwidth_upper(jac_bandwidth_upper), + rhs_alg(rhs_alg), + jac_times_cjmass(jac_times_cjmass), + jac_action(jac_action), + mass_action(mass_action), + sens(sens), + events(events), + tmp_state_vector(number_of_states), + tmp_sparse_jacobian_data(jac_times_cjmass_nnz), + options(options) + {}; + + int number_of_states; + int number_of_parameters; + int number_of_events; + int number_of_nnz; + int jac_bandwidth_lower; + int jac_bandwidth_upper; + + Expression *rhs_alg = nullptr; + Expression *jac_times_cjmass = nullptr; + Expression *jac_action = nullptr; + Expression *mass_action = nullptr; + Expression *sens = nullptr; + Expression *events = nullptr; + + // `cppcheck-suppress unusedStructMember` is used because codacy reports + // these members as unused, but they are inherited through variadics + std::vector var_fcns; // cppcheck-suppress unusedStructMember + std::vector dvar_dy_fcns; // cppcheck-suppress unusedStructMember + std::vector dvar_dp_fcns; // cppcheck-suppress unusedStructMember + + std::vector jac_times_cjmass_rowvals; // cppcheck-suppress unusedStructMember + std::vector jac_times_cjmass_colptrs; // cppcheck-suppress unusedStructMember + std::vector inputs; // cppcheck-suppress unusedStructMember + + Options options; + + virtual realtype *get_tmp_state_vector() = 0; + virtual realtype *get_tmp_sparse_jacobian_data() = 0; + +protected: + std::vector tmp_state_vector; + std::vector tmp_sparse_jacobian_data; +}; + +#endif // PYBAMM_IDAKLU_EXPRESSION_SET_HPP diff --git a/pybamm/solvers/c_solvers/idaklu/Expressions/Base/ExpressionTypes.hpp b/pybamm/solvers/c_solvers/idaklu/Expressions/Base/ExpressionTypes.hpp new file mode 100644 index 0000000000..c8d690c125 --- /dev/null +++ b/pybamm/solvers/c_solvers/idaklu/Expressions/Base/ExpressionTypes.hpp @@ -0,0 +1,6 @@ +#ifndef PYBAMM_EXPRESSION_TYPES_HPP +#define PYBAMM_EXPRESSION_TYPES_HPP + +using expr_int = long long int; + +#endif // PYBAMM_EXPRESSION_TYPES_HPP diff --git a/pybamm/solvers/c_solvers/idaklu/Expressions/Casadi/CasadiFunctions.cpp b/pybamm/solvers/c_solvers/idaklu/Expressions/Casadi/CasadiFunctions.cpp new file mode 100644 index 0000000000..b0c8ab1142 --- /dev/null +++ b/pybamm/solvers/c_solvers/idaklu/Expressions/Casadi/CasadiFunctions.cpp @@ -0,0 +1,80 @@ +#include "CasadiFunctions.hpp" +#include + +CasadiFunction::CasadiFunction(const BaseFunctionType &f) : Expression(), m_func(f) +{ + DEBUG("CasadiFunction constructor: " << m_func.name()); + + size_t sz_arg; + size_t sz_res; + size_t sz_iw; + size_t sz_w; + m_func.sz_work(sz_arg, sz_res, sz_iw, sz_w); + + int nnz = (sz_res>0) ? m_func.nnz_out() : 0; // cppcheck-suppress unreadVariable + DEBUG("name = "<< m_func.name() << " arg = " << sz_arg << " res = " + << sz_res << " iw = " << sz_iw << " w = " << sz_w << " nnz = " << nnz); + + m_arg.resize(sz_arg, nullptr); + m_res.resize(sz_res, nullptr); + m_iw.resize(sz_iw, 0); + m_w.resize(sz_w, 0); +} + +// only call this once m_arg and m_res have been set appropriately +void CasadiFunction::operator()() +{ + DEBUG("CasadiFunction operator(): " << m_func.name()); + int mem = m_func.checkout(); + m_func(m_arg.data(), m_res.data(), m_iw.data(), m_w.data(), mem); + m_func.release(mem); +} + +expr_int CasadiFunction::out_shape(int k) { + DEBUG("CasadiFunctions out_shape(): " << m_func.name() << " " << m_func.nnz_out()); + return static_cast(m_func.nnz_out()); +} + +expr_int CasadiFunction::nnz() { + DEBUG("CasadiFunction nnz(): " << m_func.name() << " " << static_cast(m_func.nnz_out())); + return static_cast(m_func.nnz_out()); +} + +expr_int CasadiFunction::nnz_out() { + DEBUG("CasadiFunction nnz_out(): " << m_func.name() << " " << static_cast(m_func.nnz_out())); + return static_cast(m_func.nnz_out()); +} + +std::vector CasadiFunction::get_row() { + return get_row(0); +} + +std::vector CasadiFunction::get_row(expr_int ind) { + DEBUG("CasadiFunction get_row(): " << m_func.name()); + casadi::Sparsity casadi_sparsity = m_func.sparsity_out(ind); + return casadi_sparsity.get_row(); +} + +std::vector CasadiFunction::get_col() { + return get_col(0); +} + +std::vector CasadiFunction::get_col(expr_int ind) { + DEBUG("CasadiFunction get_col(): " << m_func.name()); + casadi::Sparsity casadi_sparsity = m_func.sparsity_out(ind); + return casadi_sparsity.get_col(); +} + +void CasadiFunction::operator()(const std::vector& inputs, + const std::vector& results) +{ + DEBUG("CasadiFunction operator() with inputs and results: " << m_func.name()); + + // Set-up input arguments, provide result vector, then execute function + // Example call: fcn({in1, in2, in3}, {out1}) + for(size_t k=0; k +#include +#include +#include + +/** + * @brief Class for handling individual casadi functions + */ +class CasadiFunction : public Expression +{ +public: + + typedef casadi::Function BaseFunctionType; + + /** + * @brief Constructor + */ + explicit CasadiFunction(const BaseFunctionType &f); + + // Method overrides + void operator()() override; + void operator()(const std::vector& inputs, + const std::vector& results) override; + expr_int out_shape(int k) override; + expr_int nnz() override; + expr_int nnz_out() override; + std::vector get_row() override; + std::vector get_row(expr_int ind); + std::vector get_col() override; + std::vector get_col(expr_int ind); + +public: + /* + * @brief Casadi function + */ + BaseFunctionType m_func; + +private: + std::vector m_iw; // cppcheck-suppress unusedStructMember + std::vector m_w; // cppcheck-suppress unusedStructMember +}; + +/** + * @brief Class for handling casadi functions + */ +class CasadiFunctions : public ExpressionSet +{ +public: + + typedef CasadiFunction::BaseFunctionType BaseFunctionType; // expose typedef in class + + /** + * @brief Create a new CasadiFunctions object + */ + CasadiFunctions( + const BaseFunctionType &rhs_alg, + const BaseFunctionType &jac_times_cjmass, + const int jac_times_cjmass_nnz, + const int jac_bandwidth_lower, + const int jac_bandwidth_upper, + const np_array_int &jac_times_cjmass_rowvals_arg, + const np_array_int &jac_times_cjmass_colptrs_arg, + const int inputs_length, + const BaseFunctionType &jac_action, + const BaseFunctionType &mass_action, + const BaseFunctionType &sens, + const BaseFunctionType &events, + const int n_s, + const int n_e, + const int n_p, + const std::vector& var_fcns, + const std::vector& dvar_dy_fcns, + const std::vector& dvar_dp_fcns, + const Options& options + ) : + rhs_alg_casadi(rhs_alg), + jac_times_cjmass_casadi(jac_times_cjmass), + jac_action_casadi(jac_action), + mass_action_casadi(mass_action), + sens_casadi(sens), + events_casadi(events), + ExpressionSet( + static_cast(&rhs_alg_casadi), + static_cast(&jac_times_cjmass_casadi), + jac_times_cjmass_nnz, + jac_bandwidth_lower, + jac_bandwidth_upper, + jac_times_cjmass_rowvals_arg, + jac_times_cjmass_colptrs_arg, + inputs_length, + static_cast(&jac_action_casadi), + static_cast(&mass_action_casadi), + static_cast(&sens_casadi), + static_cast(&events_casadi), + n_s, n_e, n_p, + options) + { + // convert BaseFunctionType list to CasadiFunction list + // NOTE: You must allocate ALL std::vector elements before taking references + for (auto& var : var_fcns) + var_fcns_casadi.push_back(CasadiFunction(*var)); + for (int k = 0; k < var_fcns_casadi.size(); k++) + ExpressionSet::var_fcns.push_back(&this->var_fcns_casadi[k]); + + for (auto& var : dvar_dy_fcns) + dvar_dy_fcns_casadi.push_back(CasadiFunction(*var)); + for (int k = 0; k < dvar_dy_fcns_casadi.size(); k++) + this->dvar_dy_fcns.push_back(&this->dvar_dy_fcns_casadi[k]); + + for (auto& var : dvar_dp_fcns) + dvar_dp_fcns_casadi.push_back(CasadiFunction(*var)); + for (int k = 0; k < dvar_dp_fcns_casadi.size(); k++) + this->dvar_dp_fcns.push_back(&this->dvar_dp_fcns_casadi[k]); + + // copy across numpy array values + const int n_row_vals = jac_times_cjmass_rowvals_arg.request().size; + auto p_jac_times_cjmass_rowvals = jac_times_cjmass_rowvals_arg.unchecked<1>(); + jac_times_cjmass_rowvals.resize(n_row_vals); + for (int i = 0; i < n_row_vals; i++) { + jac_times_cjmass_rowvals[i] = p_jac_times_cjmass_rowvals[i]; + } + + const int n_col_ptrs = jac_times_cjmass_colptrs_arg.request().size; + auto p_jac_times_cjmass_colptrs = jac_times_cjmass_colptrs_arg.unchecked<1>(); + jac_times_cjmass_colptrs.resize(n_col_ptrs); + for (int i = 0; i < n_col_ptrs; i++) { + jac_times_cjmass_colptrs[i] = p_jac_times_cjmass_colptrs[i]; + } + + inputs.resize(inputs_length); + } + + CasadiFunction rhs_alg_casadi; + CasadiFunction jac_times_cjmass_casadi; + CasadiFunction jac_action_casadi; + CasadiFunction mass_action_casadi; + CasadiFunction sens_casadi; + CasadiFunction events_casadi; + + std::vector var_fcns_casadi; + std::vector dvar_dy_fcns_casadi; + std::vector dvar_dp_fcns_casadi; + + realtype* get_tmp_state_vector() override { + return tmp_state_vector.data(); + } + realtype* get_tmp_sparse_jacobian_data() override { + return tmp_sparse_jacobian_data.data(); + } +}; + +#endif // PYBAMM_IDAKLU_CASADI_FUNCTIONS_HPP diff --git a/pybamm/solvers/c_solvers/idaklu/Expressions/Expressions.hpp b/pybamm/solvers/c_solvers/idaklu/Expressions/Expressions.hpp new file mode 100644 index 0000000000..70380eaba7 --- /dev/null +++ b/pybamm/solvers/c_solvers/idaklu/Expressions/Expressions.hpp @@ -0,0 +1,6 @@ +#ifndef PYBAMM_IDAKLU_EXPRESSIONS_HPP +#define PYBAMM_IDAKLU_EXPRESSIONS_HPP + +#include "Base/ExpressionSet.hpp" + +#endif // PYBAMM_IDAKLU_EXPRESSIONS_HPP diff --git a/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEBaseFunction.hpp b/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEBaseFunction.hpp new file mode 100644 index 0000000000..d2ba7e4de0 --- /dev/null +++ b/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEBaseFunction.hpp @@ -0,0 +1,27 @@ +#ifndef PYBAMM_IDAKLU_IREE_BASE_FUNCTION_HPP +#define PYBAMM_IDAKLU_IREE_BASE_FUNCTION_HPP + +#include +#include + +/* + * @brief Function definition passed from PyBaMM + */ +class IREEBaseFunctionType +{ +public: // methods + const std::string& get_mlir() const { return mlir; } + +public: // data members + std::string mlir; // cppcheck-suppress unusedStructMember + std::vector kept_var_idx; // cppcheck-suppress unusedStructMember + expr_int nnz; // cppcheck-suppress unusedStructMember + expr_int numel; // cppcheck-suppress unusedStructMember + std::vector col; // cppcheck-suppress unusedStructMember + std::vector row; // cppcheck-suppress unusedStructMember + std::vector pytree_shape; // cppcheck-suppress unusedStructMember + std::vector pytree_sizes; // cppcheck-suppress unusedStructMember + expr_int n_args; // cppcheck-suppress unusedStructMember +}; + +#endif // PYBAMM_IDAKLU_IREE_BASE_FUNCTION_HPP diff --git a/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEFunction.hpp b/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEFunction.hpp new file mode 100644 index 0000000000..26f81c8f98 --- /dev/null +++ b/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEFunction.hpp @@ -0,0 +1,59 @@ +#ifndef PYBAMM_IDAKLU_IREE_FUNCTION_HPP +#define PYBAMM_IDAKLU_IREE_FUNCTION_HPP + +#include "../../Options.hpp" +#include "../Expressions.hpp" +#include +#include "iree_jit.hpp" +#include "IREEBaseFunction.hpp" + +/** + * @brief Class for handling individual iree functions + */ +class IREEFunction : public Expression +{ +public: + typedef IREEBaseFunctionType BaseFunctionType; + + /* + * @brief Constructor + */ + explicit IREEFunction(const BaseFunctionType &f); + + // Method overrides + void operator()() override; + void operator()(const std::vector& inputs, + const std::vector& results) override; + expr_int out_shape(int k) override; + expr_int nnz() override; + expr_int nnz_out() override; + std::vector get_col() override; + std::vector get_row() override; + + /* + * @brief Evaluate the MLIR function + */ + void evaluate(); + + /* + * @brief Evaluate the MLIR function + * @param n_outputs The number of outputs to return + */ + void evaluate(int n_outputs); + +public: + std::unique_ptr session; + std::vector> result; // cppcheck-suppress unusedStructMember + std::vector> input_shape; // cppcheck-suppress unusedStructMember + std::vector> output_shape; // cppcheck-suppress unusedStructMember + std::vector> input_data; // cppcheck-suppress unusedStructMember + + BaseFunctionType m_func; // cppcheck-suppress unusedStructMember + std::string module_name; // cppcheck-suppress unusedStructMember + std::string function_name; // cppcheck-suppress unusedStructMember + std::vector m_arg_argno; // cppcheck-suppress unusedStructMember + std::vector m_arg_argix; // cppcheck-suppress unusedStructMember + std::vector numel; // cppcheck-suppress unusedStructMember +}; + +#endif // PYBAMM_IDAKLU_IREE_FUNCTION_HPP diff --git a/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEFunctions.cpp b/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEFunctions.cpp new file mode 100644 index 0000000000..6837d21198 --- /dev/null +++ b/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEFunctions.cpp @@ -0,0 +1,230 @@ +#include +#include +#include +#include +#include + +#include "IREEFunctions.hpp" +#include "iree_jit.hpp" +#include "ModuleParser.hpp" + +IREEFunction::IREEFunction(const BaseFunctionType &f) : Expression(), m_func(f) +{ + DEBUG("IreeFunction constructor"); + const std::string& mlir = f.get_mlir(); + + // Parse IREE (MLIR) function string + if (mlir.size() == 0) { + DEBUG("Empty function --- skipping..."); + return; + } + + // Parse MLIR for module name, input and output shapes + ModuleParser parser(mlir); + module_name = parser.getModuleName(); + function_name = parser.getFunctionName(); + input_shape = parser.getInputShape(); + output_shape = parser.getOutputShape(); + + DEBUG("Compiling module: '" << module_name << "'"); + const char* device_uri = "local-sync"; + session = std::make_unique(device_uri, mlir); + DEBUG("compile complete."); + // Create index vectors into m_arg + // This is required since Jax expands input arguments through PyTrees, which need to + // be remapped to the corresponding expression call. For example: + // fcn(t, y, inputs, cj) with inputs = [[in1], [in2], [in3]] + // will produce a function with six inputs; we therefore need to be able to map + // arguments to their 1) corresponding input argument, and 2) the correct position + // within that argument. + m_arg_argno.clear(); + m_arg_argix.clear(); + int current_element = 0; + for (int i=0; i 2) || + ((input_shape[j].size() == 2) && (input_shape[j][1] > 1)) + ) { + std::cerr << "Unsupported input shape: " << input_shape[j].size() << " ["; + for (int k=0; k {res0} signature (i.e. x and z are reduced out) + // with kept_var_idx = [1] + // + // *********************************************************************************** + + DEBUG("Copying inputs, shape " << input_shape.size() << " - " << m_func.kept_var_idx.size()); + for (int j=0; j 1) { + // Index into argument using appropriate shape + for(int k=0; k(m_arg[m_arg_from][m_arg_argix[mlir_arg]+k]); + } + } else { + // Copy the entire vector + for(int k=0; k(m_arg[m_arg_from][k]); + } + } + } + + // Call the 'main' function of the module + const std::string mlir = m_func.get_mlir(); + DEBUG("Calling function '" << function_name << "'"); + auto status = session->iree_runtime_exec(function_name, input_shape, input_data, result); + if (!iree_status_is_ok(status)) { + iree_status_fprint(stderr, status); + std::cerr << "MLIR: " << mlir.substr(0,1000) << std::endl; + throw std::runtime_error("Execution failed"); + } + + // Copy results to output array + for(size_t k=0; k(result[k][j]); + } + } + + DEBUG("IreeFunction operator() complete"); +} + +expr_int IREEFunction::out_shape(int k) { + DEBUG("IreeFunction nnz(" << k << "): " << m_func.nnz); + auto elements = 1; + for (auto i : output_shape[k]) { + elements *= i; + } + return elements; +} + +expr_int IREEFunction::nnz() { + DEBUG("IreeFunction nnz: " << m_func.nnz); + return nnz_out(); +} + +expr_int IREEFunction::nnz_out() { + DEBUG("IreeFunction nnz_out" << m_func.nnz); + return m_func.nnz; +} + +std::vector IREEFunction::get_row() { + DEBUG("IreeFunction get_row" << m_func.row.size()); + return m_func.row; +} + +std::vector IREEFunction::get_col() { + DEBUG("IreeFunction get_col" << m_func.col.size()); + return m_func.col; +} + +void IREEFunction::operator()(const std::vector& inputs, + const std::vector& results) +{ + DEBUG("IreeFunction operator() with inputs and results"); + // Set-up input arguments, provide result vector, then execute function + // Example call: fcn({in1, in2, in3}, {out1}) + ASSERT(inputs.size() == m_func.n_args); + for(size_t k=0; k +#include "iree_jit.hpp" +#include "IREEFunction.hpp" + +/** + * @brief Class for handling iree functions + */ +class IREEFunctions : public ExpressionSet +{ +public: + std::unique_ptr iree_compiler; + + typedef IREEFunction::BaseFunctionType BaseFunctionType; // expose typedef in class + + int iree_init_status; + + int iree_init(const std::string& device_uri, const std::string& target_backends) { + // Initialise IREE + DEBUG("IREEFunctions: Initialising IREECompiler"); + iree_compiler = std::make_unique(device_uri.c_str()); + + int iree_argc = 2; + std::string target_backends_str = "--iree-hal-target-backends=" + target_backends; + const char* iree_argv[2] = {"iree", target_backends_str.c_str()}; + iree_compiler->init(iree_argc, iree_argv); + DEBUG("IREEFunctions: Initialised IREECompiler"); + return 0; + } + + int iree_init() { + return iree_init("local-sync", "llvm-cpu"); + } + + + /** + * @brief Create a new IREEFunctions object + */ + IREEFunctions( + const BaseFunctionType &rhs_alg, + const BaseFunctionType &jac_times_cjmass, + const int jac_times_cjmass_nnz, + const int jac_bandwidth_lower, + const int jac_bandwidth_upper, + const np_array_int &jac_times_cjmass_rowvals_arg, + const np_array_int &jac_times_cjmass_colptrs_arg, + const int inputs_length, + const BaseFunctionType &jac_action, + const BaseFunctionType &mass_action, + const BaseFunctionType &sens, + const BaseFunctionType &events, + const int n_s, + const int n_e, + const int n_p, + const std::vector& var_fcns, + const std::vector& dvar_dy_fcns, + const std::vector& dvar_dp_fcns, + const Options& options + ) : + iree_init_status(iree_init()), + rhs_alg_iree(rhs_alg), + jac_times_cjmass_iree(jac_times_cjmass), + jac_action_iree(jac_action), + mass_action_iree(mass_action), + sens_iree(sens), + events_iree(events), + ExpressionSet( + static_cast(&rhs_alg_iree), + static_cast(&jac_times_cjmass_iree), + jac_times_cjmass_nnz, + jac_bandwidth_lower, + jac_bandwidth_upper, + jac_times_cjmass_rowvals_arg, + jac_times_cjmass_colptrs_arg, + inputs_length, + static_cast(&jac_action_iree), + static_cast(&mass_action_iree), + static_cast(&sens_iree), + static_cast(&events_iree), + n_s, n_e, n_p, + options) + { + // convert BaseFunctionType list to IREEFunction list + // NOTE: You must allocate ALL std::vector elements before taking references + for (auto& var : var_fcns) + var_fcns_iree.push_back(IREEFunction(*var)); + for (int k = 0; k < var_fcns_iree.size(); k++) + ExpressionSet::var_fcns.push_back(&this->var_fcns_iree[k]); + + for (auto& var : dvar_dy_fcns) + dvar_dy_fcns_iree.push_back(IREEFunction(*var)); + for (int k = 0; k < dvar_dy_fcns_iree.size(); k++) + this->dvar_dy_fcns.push_back(&this->dvar_dy_fcns_iree[k]); + + for (auto& var : dvar_dp_fcns) + dvar_dp_fcns_iree.push_back(IREEFunction(*var)); + for (int k = 0; k < dvar_dp_fcns_iree.size(); k++) + this->dvar_dp_fcns.push_back(&this->dvar_dp_fcns_iree[k]); + + // copy across numpy array values + const int n_row_vals = jac_times_cjmass_rowvals_arg.request().size; + auto p_jac_times_cjmass_rowvals = jac_times_cjmass_rowvals_arg.unchecked<1>(); + jac_times_cjmass_rowvals.resize(n_row_vals); + for (int i = 0; i < n_row_vals; i++) { + jac_times_cjmass_rowvals[i] = p_jac_times_cjmass_rowvals[i]; + } + + const int n_col_ptrs = jac_times_cjmass_colptrs_arg.request().size; + auto p_jac_times_cjmass_colptrs = jac_times_cjmass_colptrs_arg.unchecked<1>(); + jac_times_cjmass_colptrs.resize(n_col_ptrs); + for (int i = 0; i < n_col_ptrs; i++) { + jac_times_cjmass_colptrs[i] = p_jac_times_cjmass_colptrs[i]; + } + + inputs.resize(inputs_length); + } + + IREEFunction rhs_alg_iree; + IREEFunction jac_times_cjmass_iree; + IREEFunction jac_action_iree; + IREEFunction mass_action_iree; + IREEFunction sens_iree; + IREEFunction events_iree; + + std::vector var_fcns_iree; + std::vector dvar_dy_fcns_iree; + std::vector dvar_dp_fcns_iree; + + realtype* get_tmp_state_vector() override { + return tmp_state_vector.data(); + } + realtype* get_tmp_sparse_jacobian_data() override { + return tmp_sparse_jacobian_data.data(); + } + + ~IREEFunctions() { + // cleanup IREE + iree_compiler->cleanup(); + } +}; + +#endif // PYBAMM_IDAKLU_IREE_FUNCTIONS_HPP diff --git a/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/ModuleParser.cpp b/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/ModuleParser.cpp new file mode 100644 index 0000000000..d1c5575ee2 --- /dev/null +++ b/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/ModuleParser.cpp @@ -0,0 +1,91 @@ +#include "ModuleParser.hpp" + +ModuleParser::ModuleParser(const std::string& mlir) : mlir(mlir) +{ + parse(); +} + +void ModuleParser::parse() +{ + // Parse module name + std::regex module_name_regex("module @([^\\s]+)"); // Match until first whitespace + std::smatch module_name_match; + std::regex_search(this->mlir, module_name_match, module_name_regex); + if (module_name_match.size() == 0) { + std::cerr << "Could not find module name in module" << std::endl; + std::cerr << "Module snippet: " << this->mlir.substr(0, 1000) << std::endl; + throw std::runtime_error("Could not find module name in module"); + } + module_name = module_name_match[1].str(); + DEBUG("Module name: " << module_name); + + // Assign function name + function_name = module_name + ".main"; + + // Isolate 'main' function call signature + std::regex main_func("public @main\\((.*?)\\) -> \\((.*?)\\)"); + std::smatch match; + std::regex_search(this->mlir, match, main_func); + if (match.size() == 0) { + std::cerr << "Could not find 'main' function in module" << std::endl; + std::cerr << "Module snippet: " << this->mlir.substr(0, 1000) << std::endl; + throw std::runtime_error("Could not find 'main' function in module"); + } + std::string main_sig_inputs = match[1].str(); + std::string main_sig_outputs = match[2].str(); + DEBUG( + "Main function signature: " << main_sig_inputs << " -> " << main_sig_outputs << '\n' + ); + + // Parse input sizes + input_shape.clear(); + std::regex input_size("tensor<(.*?)>"); + for(std::sregex_iterator i = std::sregex_iterator(main_sig_inputs.begin(), main_sig_inputs.end(), input_size); + i != std::sregex_iterator(); + ++i) + { + std::smatch matchi = *i; + std::string match_str = matchi.str(); + std::string shape_str = match_str.substr(7, match_str.size() - 8); // Remove 'tensor<>' from string + std::vector shape; + std::string dim_str; + for (char c : shape_str) { + if (c == 'x') { + shape.push_back(std::stoi(dim_str)); + dim_str = ""; + } else { + dim_str += c; + } + } + input_shape.push_back(shape); + } + + // Parse output sizes + output_shape.clear(); + std::regex output_size("tensor<(.*?)>"); + for( + std::sregex_iterator i = std::sregex_iterator(main_sig_outputs.begin(), main_sig_outputs.end(), output_size); + i != std::sregex_iterator(); + ++i + ) { + std::smatch matchi = *i; + std::string match_str = matchi.str(); + std::string shape_str = match_str.substr(7, match_str.size() - 8); // Remove 'tensor<>' from string + std::vector shape; + std::string dim_str; + for (char c : shape_str) { + if (c == 'x') { + shape.push_back(std::stoi(dim_str)); + dim_str = ""; + } else { + dim_str += c; + } + } + // If shape is empty, assume scalar (i.e. "tensor" or some singleton variant) + if (shape.size() == 0) { + shape.push_back(1); + } + // Add output to list + output_shape.push_back(shape); + } +} diff --git a/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/ModuleParser.hpp b/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/ModuleParser.hpp new file mode 100644 index 0000000000..2fbfdc086c --- /dev/null +++ b/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/ModuleParser.hpp @@ -0,0 +1,55 @@ +#ifndef PYBAMM_IDAKLU_IREE_MODULE_PARSER_HPP +#define PYBAMM_IDAKLU_IREE_MODULE_PARSER_HPP + +#include +#include +#include +#include +#include + +#include "../../common.hpp" + +class ModuleParser { +private: + std::string mlir; // cppcheck-suppress unusedStructMember + // codacy fix: member is referenced as this->mlir in parse() + std::string module_name; + std::string function_name; + std::vector> input_shape; + std::vector> output_shape; +public: + /** + * @brief Constructor + * @param mlir: string representation of MLIR code for the module + */ + explicit ModuleParser(const std::string& mlir); + + /** + * @brief Get the module name + * @return module name + */ + const std::string& getModuleName() const { return module_name; } + + /** + * @brief Get the function name + * @return function name + */ + const std::string& getFunctionName() const { return function_name; } + + /** + * @brief Get the input shape + * @return input shape + */ + const std::vector>& getInputShape() const { return input_shape; } + + /** + * @brief Get the output shape + * @return output shape + */ + const std::vector>& getOutputShape() const { return output_shape; } + +private: + void parse(); +}; + +#endif // PYBAMM_IDAKLU_IREE_MODULE_PARSER_HPP diff --git a/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/iree_jit.cpp b/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/iree_jit.cpp new file mode 100644 index 0000000000..c84c3928bd --- /dev/null +++ b/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/iree_jit.cpp @@ -0,0 +1,408 @@ +#include "iree_jit.hpp" +#include "iree/hal/buffer_view.h" +#include "iree/hal/buffer_view_util.h" +#include "../../common.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +// Used to suppress stderr output (see initIREE below) +#ifdef _WIN32 +#include +#define close _close +#define dup _dup +#define fileno _fileno +#define open _open +#define dup2 _dup2 +#define NULL_DEVICE "NUL" +#else +#define NULL_DEVICE "/dev/null" +#endif + +void IREESession::handle_compiler_error(iree_compiler_error_t *error) { + const char *msg = ireeCompilerErrorGetMessage(error); + fprintf(stderr, "Error from compiler API:\n%s\n", msg); + ireeCompilerErrorDestroy(error); +} + +void IREESession::cleanup_compiler_state(compiler_state_t s) { + if (s.inv) + ireeCompilerInvocationDestroy(s.inv); + if (s.output) + ireeCompilerOutputDestroy(s.output); + if (s.source) + ireeCompilerSourceDestroy(s.source); + if (s.session) + ireeCompilerSessionDestroy(s.session); +} + +IREECompiler::IREECompiler() { + this->device_uri = "local-sync"; +}; + +IREECompiler::~IREECompiler() { + ireeCompilerGlobalShutdown(); +}; + +int IREECompiler::init(int argc, const char **argv) { + return initIREE(argc, argv); // Initialisation and version checking +}; + +int IREECompiler::cleanup() { + return 0; +}; + +IREESession::IREESession() { + s.session = NULL; + s.source = NULL; + s.output = NULL; + s.inv = NULL; +}; + +IREESession::IREESession(const char *device_uri, const std::string& mlir_code) : IREESession() { + this->device_uri=device_uri; + this->mlir_code=mlir_code; + init(); +} + +int IREESession::init() { + if (initCompiler() != 0) // Prepare compiler inputs and outputs + return 1; + if (initCompileToByteCode() != 0) // Compile to bytecode + return 1; + if (initRuntime() != 0) // Initialise runtime environment + return 1; + return 0; +}; + +int IREECompiler::initIREE(int argc, const char **argv) { + + if (device_uri == NULL) { + DEBUG("No device URI provided, using local-sync\n"); + this->device_uri = "local-sync"; + } + + int cl_argc = argc; + const char *iree_compiler_lib = std::getenv("IREE_COMPILER_LIB"); + + // Load the compiler library and initialize it + // NOTE: On second and subsequent calls, the function will return false and display + // a message on stderr, but it is safe to ignore this message. For an improved user + // experience we actively suppress stderr during the call to this function but since + // this also suppresses any other error message, we actively check for the presence + // of the library file prior to the call. + + // Check if the library file exists + if (iree_compiler_lib == NULL) { + fprintf(stderr, "Error: IREE_COMPILER_LIB environment variable not set\n"); + return 1; + } + if (access(iree_compiler_lib, F_OK) == -1) { + fprintf(stderr, "Error: IREE_COMPILER_LIB file not found\n"); + return 1; + } + // Suppress stderr + int saved_stderr = dup(fileno(stderr)); + if (!freopen(NULL_DEVICE, "w", stderr)) + DEBUG("Error: failed redirecting stderr"); + // Load library + bool result = ireeCompilerLoadLibrary(iree_compiler_lib); + // Restore stderr + fflush(stderr); + dup2(saved_stderr, fileno(stderr)); + close(saved_stderr); + // Process result + if (!result) { + // Library may have already been loaded (can be safely ignored), + // or may not be found (critical error), we cannot tell which from the return value. + return 1; + } + // Must be balanced with a call to ireeCompilerGlobalShutdown() + ireeCompilerGlobalInitialize(); + + // To set global options (see `iree-compile --help` for possibilities), use + // |ireeCompilerGetProcessCLArgs| and |ireeCompilerSetupGlobalCL| + ireeCompilerGetProcessCLArgs(&cl_argc, &argv); + ireeCompilerSetupGlobalCL(cl_argc, argv, "iree-jit", false); + + // Check the API version before proceeding any further + uint32_t api_version = (uint32_t)ireeCompilerGetAPIVersion(); + uint16_t api_version_major = (uint16_t)((api_version >> 16) & 0xFFFFUL); + uint16_t api_version_minor = (uint16_t)(api_version & 0xFFFFUL); + DEBUG("Compiler API version: " << api_version_major << "." << api_version_minor); + if (api_version_major > IREE_COMPILER_EXPECTED_API_MAJOR || + api_version_minor < IREE_COMPILER_EXPECTED_API_MINOR) { + fprintf(stderr, + "Error: incompatible API version; built for version %" PRIu16 + ".%" PRIu16 " but loaded version %" PRIu16 ".%" PRIu16 "\n", + IREE_COMPILER_EXPECTED_API_MAJOR, IREE_COMPILER_EXPECTED_API_MINOR, + api_version_major, api_version_minor); + ireeCompilerGlobalShutdown(); + return 1; + } + + // Check for a build tag with release version information + const char *revision = ireeCompilerGetRevision(); // cppcheck-suppress unreadVariable + DEBUG("Compiler revision: '" << revision << "'"); + return 0; +}; + +int IREESession::initCompiler() { + + // A session provides a scope where one or more invocations can be executed + s.session = ireeCompilerSessionCreate(); + + // Read the MLIR from memory + error = ireeCompilerSourceWrapBuffer( + s.session, + "expr_buffer", // name of the buffer (does not need to match MLIR) + mlir_code.c_str(), + mlir_code.length() + 1, + true, + &s.source + ); + if (error) { + fprintf(stderr, "Error wrapping source buffer\n"); + handle_compiler_error(error); + cleanup_compiler_state(s); + return 1; + } + DEBUG("Wrapped buffer as a compiler source"); + + return 0; +}; + +int IREESession::initCompileToByteCode() { + // Use an invocation to compile from the input source to the output stream + iree_compiler_invocation_t *inv = ireeCompilerInvocationCreate(s.session); + ireeCompilerInvocationEnableConsoleDiagnostics(inv); + + if (!ireeCompilerInvocationParseSource(inv, s.source)) { + fprintf(stderr, "Error parsing input source into invocation\n"); + cleanup_compiler_state(s); + return 1; + } + + // Compile, specifying the target dialect phase + ireeCompilerInvocationSetCompileToPhase(inv, "end"); + + // Run the compiler invocation pipeline + if (!ireeCompilerInvocationPipeline(inv, IREE_COMPILER_PIPELINE_STD)) { + fprintf(stderr, "Error running compiler invocation\n"); + cleanup_compiler_state(s); + return 1; + } + DEBUG("Compilation successful"); + + // Create compiler 'output' to a memory buffer + error = ireeCompilerOutputOpenMembuffer(&s.output); + if (error) { + fprintf(stderr, "Error opening output membuffer\n"); + handle_compiler_error(error); + cleanup_compiler_state(s); + return 1; + } + + // Create bytecode in memory + error = ireeCompilerInvocationOutputVMBytecode(inv, s.output); + if (error) { + fprintf(stderr, "Error creating VM bytecode\n"); + handle_compiler_error(error); + cleanup_compiler_state(s); + return 1; + } + + // Once the bytecode has been written, retrieve the memory map + ireeCompilerOutputMapMemory(s.output, &contents, &size); + + return 0; +}; + +int IREESession::initRuntime() { + // Setup the shared runtime instance + iree_runtime_instance_options_t instance_options; + iree_runtime_instance_options_initialize(&instance_options); + iree_runtime_instance_options_use_all_available_drivers(&instance_options); + status = iree_runtime_instance_create( + &instance_options, iree_allocator_system(), &instance); + + // Create the HAL device used to run the workloads + if (iree_status_is_ok(status)) { + status = iree_hal_create_device( + iree_runtime_instance_driver_registry(instance), + iree_make_cstring_view(device_uri), + iree_runtime_instance_host_allocator(instance), &device); + } + + // Set up the session to run the module + if (iree_status_is_ok(status)) { + iree_runtime_session_options_t session_options; + iree_runtime_session_options_initialize(&session_options); + status = iree_runtime_session_create_with_device( + instance, &session_options, device, + iree_runtime_instance_host_allocator(instance), &session); + } + + // Load the compiled user module from a file + if (iree_status_is_ok(status)) { + /*status = iree_runtime_session_append_bytecode_module_from_file(session, module_path);*/ + status = iree_runtime_session_append_bytecode_module_from_memory( + session, + iree_make_const_byte_span(contents, size), + iree_allocator_null()); + } + + if (!iree_status_is_ok(status)) + return 1; + + return 0; +}; + +// Release the session and free all cached resources. +int IREESession::cleanup() { + iree_runtime_session_release(session); + iree_hal_device_release(device); + iree_runtime_instance_release(instance); + + int ret = (int)iree_status_code(status); + if (!iree_status_is_ok(status)) { + iree_status_fprint(stderr, status); + iree_status_ignore(status); + } + cleanup_compiler_state(s); + return ret; +} + +iree_status_t IREESession::iree_runtime_exec( + const std::string& function_name, + const std::vector>& inputs, + const std::vector>& data, + std::vector>& result +) { + + // Initialize the call to the function. + status = iree_runtime_call_initialize_by_name( + session, iree_make_cstring_view(function_name.c_str()), &call); + if (!iree_status_is_ok(status)) { + std::cerr << "Error: iree_runtime_call_initialize_by_name failed" << std::endl; + iree_status_fprint(stderr, status); + return status; + } + + // Append the function inputs with the HAL device allocator in use by the + // session. The buffers will be usable within the session and _may_ be usable + // in other sessions depending on whether they share a compatible device. + iree_hal_allocator_t* device_allocator = + iree_runtime_session_device_allocator(session); + host_allocator = iree_runtime_session_host_allocator(session); + status = iree_ok_status(); + if (iree_status_is_ok(status)) { + + for(int k=0; k arg_shape(input_shape.size()); + for (int i = 0; i < input_shape.size(); i++) { + arg_shape[i] = input_shape[i]; + } + int numel = 1; + for(int i = 0; i < input_shape.size(); i++) { + numel *= input_shape[i]; + } + std::vector arg_data(numel); + for(int i = 0; i < numel; i++) { + arg_data[i] = input_data[i]; + } + + status = iree_hal_buffer_view_allocate_buffer_copy( + device, device_allocator, + // Shape rank and dimensions: + arg_shape.size(), arg_shape.data(), + // Element type: + IREE_HAL_ELEMENT_TYPE_FLOAT_32, + // Encoding type: + IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, + (iree_hal_buffer_params_t){ + // Intended usage of the buffer (transfers, dispatches, etc): + .usage = IREE_HAL_BUFFER_USAGE_DEFAULT, + // Access to allow to this memory: + .access = IREE_HAL_MEMORY_ACCESS_ALL, + // Where to allocate (host or device): + .type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL, + }, + // The actual heap buffer to wrap or clone and its allocator: + iree_make_const_byte_span(&arg_data[0], sizeof(float) * arg_data.size()), + // Buffer view + storage are returned and owned by the caller: + &arg); + } + if (iree_status_is_ok(status)) { + // Add to the call inputs list (which retains the buffer view). + status = iree_runtime_call_inputs_push_back_buffer_view(&call, arg); + if (!iree_status_is_ok(status)) { + std::cerr << "Error: iree_runtime_call_inputs_push_back_buffer_view failed" << std::endl; + iree_status_fprint(stderr, status); + } + } + // Since the call retains the buffer view we can release it here. + iree_hal_buffer_view_release(arg); + } + } + + // Synchronously perform the call. + if (iree_status_is_ok(status)) { + status = iree_runtime_call_invoke(&call, /*flags=*/0); + } + if (!iree_status_is_ok(status)) { + std::cerr << "Error: iree_runtime_call_invoke failed" << std::endl; + iree_status_fprint(stderr, status); + } + + for(int k=0; k +#include +#include +#include + +#include +#include +#include + +#define IREE_COMPILER_EXPECTED_API_MAJOR 1 // At most this major version +#define IREE_COMPILER_EXPECTED_API_MINOR 2 // At least this minor version + +// Forward declaration +class IREESession; + +/* + * @brief IREECompiler class + * @details This class is used to compile MLIR code to IREE bytecode and + * create IREE sessions. + */ +class IREECompiler { +private: + /* + * @brief Device Uniform Resource Identifier (URI) + * @details The device URI is used to specify the device to be used by the + * IREE runtime. E.g. "local-sync" for CPU, "vulkan" for GPU, etc. + */ + const char *device_uri = NULL; + +private: + /* + * @brief Initialize the IREE runtime + */ + int initIREE(int argc, const char **argv); + +public: + /* + * @brief Default constructor + */ + IREECompiler(); + + /* + * @brief Destructor + */ + ~IREECompiler(); + + /* + * @brief Constructor with device URI + * @param device_uri Device URI + */ + explicit IREECompiler(const char *device_uri) + : IREECompiler() { this->device_uri=device_uri; } + + /* + * @brief Initialize the compiler + */ + int init(int argc, const char **argv); + + /* + * @brief Cleanup the compiler + * @details This method cleans up the compiler and all the IREE sessions + * created by the compiler. Returns 0 on success. + */ + int cleanup(); +}; + +/* + * @brief Compiler state + */ +typedef struct compiler_state_t { + iree_compiler_session_t *session; // cppcheck-suppress unusedStructMember + iree_compiler_source_t *source; // cppcheck-suppress unusedStructMember + iree_compiler_output_t *output; // cppcheck-suppress unusedStructMember + iree_compiler_invocation_t *inv; // cppcheck-suppress unusedStructMember +} compiler_state_t; + +/* + * @brief IREE session class + */ +class IREESession { +private: // data members + const char *device_uri = NULL; + compiler_state_t s; + iree_compiler_error_t *error = NULL; + void *contents = NULL; + uint64_t size = 0; + iree_runtime_session_t* session = NULL; + iree_status_t status; + iree_hal_device_t* device = NULL; + iree_runtime_instance_t* instance = NULL; + std::string mlir_code; // cppcheck-suppress unusedStructMember + iree_runtime_call_t call; + iree_allocator_t host_allocator; + +private: // private methods + void handle_compiler_error(iree_compiler_error_t *error); + void cleanup_compiler_state(compiler_state_t s); + int init(); + int initCompiler(); + int initCompileToByteCode(); + int initRuntime(); + +public: // public methods + + /* + * @brief Default constructor + */ + IREESession(); + + /* + * @brief Constructor with device URI and MLIR code + * @param device_uri Device URI + * @param mlir_code MLIR code + */ + explicit IREESession(const char *device_uri, const std::string& mlir_code); + + /* + * @brief Cleanup the IREE session + */ + int cleanup(); + + /* + * @brief Execute the pre-compiled byte-code with the given inputs + * @param function_name Function name to execute + * @param inputs List of input shapes + * @param data List of input data + * @param result List of output data + */ + iree_status_t iree_runtime_exec( + const std::string& function_name, + const std::vector>& inputs, + const std::vector>& data, + std::vector>& result + ); +}; + +#endif // IREE_JIT_HPP diff --git a/pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.cpp b/pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.cpp new file mode 100644 index 0000000000..b769d4d1d4 --- /dev/null +++ b/pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.cpp @@ -0,0 +1 @@ +#include "IDAKLUSolver.hpp" diff --git a/pybamm/solvers/c_solvers/idaklu/CasadiSolver.hpp b/pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.hpp similarity index 75% rename from pybamm/solvers/c_solvers/idaklu/CasadiSolver.hpp rename to pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.hpp index dac94579f3..26e587e424 100644 --- a/pybamm/solvers/c_solvers/idaklu/CasadiSolver.hpp +++ b/pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.hpp @@ -1,33 +1,27 @@ #ifndef PYBAMM_IDAKLU_CASADI_SOLVER_HPP #define PYBAMM_IDAKLU_CASADI_SOLVER_HPP -#include -using Function = casadi::Function; - -#include "casadi_functions.hpp" #include "common.hpp" -#include "options.hpp" -#include "solution.hpp" -#include "sundials_legacy_wrapper.hpp" +#include "Solution.hpp" /** * Abstract base class for solutions that can use different solvers and vector * implementations. * @brief An abstract base class for the Idaklu solver */ -class CasadiSolver +class IDAKLUSolver { public: /** * @brief Default constructor */ - CasadiSolver() = default; + IDAKLUSolver() = default; /** * @brief Default destructor */ - ~CasadiSolver() = default; + ~IDAKLUSolver() = default; /** * @brief Abstract solver method that returns a Solution class diff --git a/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.hpp b/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.hpp similarity index 82% rename from pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.hpp rename to pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.hpp index 2312f9cf8f..8c49069b30 100644 --- a/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.hpp +++ b/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.hpp @@ -1,14 +1,10 @@ -#ifndef PYBAMM_IDAKLU_CASADISOLVEROPENMP_HPP -#define PYBAMM_IDAKLU_CASADISOLVEROPENMP_HPP +#ifndef PYBAMM_IDAKLU_SOLVEROPENMP_HPP +#define PYBAMM_IDAKLU_SOLVEROPENMP_HPP -#include "CasadiSolver.hpp" -#include -using Function = casadi::Function; - -#include "casadi_functions.hpp" +#include "IDAKLUSolver.hpp" #include "common.hpp" -#include "options.hpp" -#include "solution.hpp" +#include "Options.hpp" +#include "Solution.hpp" #include "sundials_legacy_wrapper.hpp" /** @@ -40,7 +36,8 @@ using Function = casadi::Function; * 19. Destroy objects * 20. (N/A) Finalize MPI */ -class CasadiSolverOpenMP : public CasadiSolver +template +class IDAKLUSolverOpenMP : public IDAKLUSolver { // NB: cppcheck-suppress unusedStructMember is used because codacy reports // these members as unused even though they are important in child @@ -63,10 +60,10 @@ class CasadiSolverOpenMP : public CasadiSolver int jac_bandwidth_upper; // cppcheck-suppress unusedStructMember SUNMatrix J; SUNLinearSolver LS = nullptr; - std::unique_ptr functions; - realtype *res = nullptr; - realtype *res_dvar_dy = nullptr; - realtype *res_dvar_dp = nullptr; + std::unique_ptr functions; + std::vector res; + std::vector res_dvar_dy; + std::vector res_dvar_dp; Options options; #if SUNDIALS_VERSION_MAJOR >= 6 @@ -77,7 +74,7 @@ class CasadiSolverOpenMP : public CasadiSolver /** * @brief Constructor */ - CasadiSolverOpenMP( + IDAKLUSolverOpenMP( np_array atol_np, double rel_tol, np_array rhs_alg_id, @@ -86,18 +83,18 @@ class CasadiSolverOpenMP : public CasadiSolver int jac_times_cjmass_nnz, int jac_bandwidth_lower, int jac_bandwidth_upper, - std::unique_ptr functions, + std::unique_ptr functions, const Options& options); /** * @brief Destructor */ - ~CasadiSolverOpenMP(); + ~IDAKLUSolverOpenMP(); /** - * Evaluate casadi functions (including sensitivies) for each requested + * Evaluate functions (including sensitivies) for each requested * variable and store - * @brief Evaluate casadi functions + * @brief Evaluate functions */ void CalcVars( realtype *y_return, @@ -110,7 +107,7 @@ class CasadiSolverOpenMP : public CasadiSolver size_t *ySk); /** - * @brief Evaluate casadi functions for sensitivities + * @brief Evaluate functions for sensitivities */ void CalcVarsSensitivities( realtype *tret, @@ -144,4 +141,6 @@ class CasadiSolverOpenMP : public CasadiSolver void SetMatrix(); }; -#endif // PYBAMM_IDAKLU_CASADISOLVEROPENMP_HPP +#include "IDAKLUSolverOpenMP.inl" + +#endif // PYBAMM_IDAKLU_SOLVEROPENMP_HPP diff --git a/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.cpp b/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.inl similarity index 82% rename from pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.cpp rename to pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.inl index ad51eda4e1..383037e2ca 100644 --- a/pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.cpp +++ b/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.inl @@ -1,10 +1,8 @@ -#include "CasadiSolverOpenMP.hpp" -#include "casadi_sundials_functions.hpp" -#include -#include -#include +#include "Expressions/Expressions.hpp" +#include "sundials_functions.hpp" -CasadiSolverOpenMP::CasadiSolverOpenMP( +template +IDAKLUSolverOpenMP::IDAKLUSolverOpenMP( np_array atol_np, double rel_tol, np_array rhs_alg_id, @@ -13,7 +11,7 @@ CasadiSolverOpenMP::CasadiSolverOpenMP( int jac_times_cjmass_nnz, int jac_bandwidth_lower, int jac_bandwidth_upper, - std::unique_ptr functions_arg, + std::unique_ptr functions_arg, const Options &options ) : atol_np(atol_np), @@ -28,8 +26,8 @@ CasadiSolverOpenMP::CasadiSolverOpenMP( options(options) { // Construction code moved to Initialize() which is called from the - // (child) CasadiSolver_XXX class constructors. - DEBUG("CasadiSolverOpenMP::CasadiSolverOpenMP"); + // (child) IDAKLUSolver_* class constructors. + DEBUG("IDAKLUSolverOpenMP:IDAKLUSolverOpenMP"); auto atol = atol_np.unchecked<1>(); // create SUNDIALS context object @@ -59,14 +57,14 @@ CasadiSolverOpenMP::CasadiSolverOpenMP( SetMatrix(); // initialise solver - IDAInit(ida_mem, residual_casadi, 0, yy, yp); + IDAInit(ida_mem, residual_eval, 0, yy, yp); // set tolerances rtol = RCONST(rel_tol); IDASVtolerances(ida_mem, rtol, avtol); // set events - IDARootInit(ida_mem, number_of_events, events_casadi); + IDARootInit(ida_mem, number_of_events, events_eval); void *user_data = functions.get(); IDASetUserData(ida_mem, user_data); @@ -77,7 +75,8 @@ CasadiSolverOpenMP::CasadiSolverOpenMP( } } -void CasadiSolverOpenMP::AllocateVectors() { +template +void IDAKLUSolverOpenMP::AllocateVectors() { // Create vectors yy = N_VNew_OpenMP(number_of_states, options.num_threads, sunctx); yp = N_VNew_OpenMP(number_of_states, options.num_threads, sunctx); @@ -85,7 +84,8 @@ void CasadiSolverOpenMP::AllocateVectors() { id = N_VNew_OpenMP(number_of_states, options.num_threads, sunctx); } -void CasadiSolverOpenMP::SetMatrix() { +template +void IDAKLUSolverOpenMP::SetMatrix() { // Create Matrix object if (options.jacobian == "sparse") { @@ -94,7 +94,7 @@ void CasadiSolverOpenMP::SetMatrix() { number_of_states, number_of_states, jac_times_cjmass_nnz, - CSC_MAT, // CSC is used by casadi; CSR requires a conversion step + CSC_MAT, sunctx ); } @@ -124,7 +124,8 @@ void CasadiSolverOpenMP::SetMatrix() { throw std::invalid_argument("Unsupported matrix requested"); } -void CasadiSolverOpenMP::Initialize() { +template +void IDAKLUSolverOpenMP::Initialize() { // Call after setting the solver // attach the linear solver @@ -139,18 +140,18 @@ void CasadiSolverOpenMP::Initialize() { IDABBDPrecInit( ida_mem, number_of_states, options.precon_half_bandwidth, options.precon_half_bandwidth, options.precon_half_bandwidth_keep, - options.precon_half_bandwidth_keep, 0.0, residual_casadi_approx, NULL); + options.precon_half_bandwidth_keep, 0.0, residual_eval_approx, NULL); } if (options.jacobian == "matrix-free") - IDASetJacTimes(ida_mem, NULL, jtimes_casadi); + IDASetJacTimes(ida_mem, NULL, jtimes_eval); else if (options.jacobian != "none") - IDASetJacFn(ida_mem, jacobian_casadi); + IDASetJacFn(ida_mem, jacobian_eval); if (number_of_parameters > 0) { IDASensInit(ida_mem, number_of_parameters, IDA_SIMULTANEOUS, - sensitivities_casadi, yyS, ypS); + sensitivities_eval, yyS, ypS); IDASensEEtolerances(ida_mem); } @@ -167,7 +168,8 @@ void CasadiSolverOpenMP::Initialize() { IDASetId(ida_mem, id); } -CasadiSolverOpenMP::~CasadiSolverOpenMP() +template +IDAKLUSolverOpenMP::~IDAKLUSolverOpenMP() { // Free memory if (number_of_parameters > 0) @@ -190,7 +192,8 @@ CasadiSolverOpenMP::~CasadiSolverOpenMP() SUNContext_Free(&sunctx); } -void CasadiSolverOpenMP::CalcVars( +template +void IDAKLUSolverOpenMP::CalcVars( realtype *y_return, size_t length_of_return_vector, size_t t_i, @@ -200,61 +203,61 @@ void CasadiSolverOpenMP::CalcVars( realtype *yS_return, size_t *ySk ) { - // Evaluate casadi functions for each requested variable and store + DEBUG("IDAKLUSolver::CalcVars"); + // Evaluate functions for each requested variable and store size_t j = 0; - for (auto& var_fcn : functions->var_casadi_fcns) { - var_fcn({tret, yval, functions->inputs.data()}, {res}); + for (auto& var_fcn : functions->var_fcns) { + (*var_fcn)({tret, yval, functions->inputs.data()}, {&res[0]}); // store in return vector - for (size_t jj=0; jjnnz_out(); jj++) y_return[t_i*length_of_return_vector + j++] = res[jj]; } // calculate sensitivities CalcVarsSensitivities(tret, yval, ySval, yS_return, ySk); } -void CasadiSolverOpenMP::CalcVarsSensitivities( +template +void IDAKLUSolverOpenMP::CalcVarsSensitivities( realtype *tret, realtype *yval, const std::vector& ySval, realtype *yS_return, size_t *ySk ) { + DEBUG("IDAKLUSolver::CalcVarsSensitivities"); // Calculate sensitivities - - // Loop over variables - realtype* dens_dvar_dp = new realtype[number_of_parameters]; + std::vector dens_dvar_dp = std::vector(number_of_parameters, 0); for (size_t dvar_k=0; dvar_kdvar_dy_fcns.size(); dvar_k++) { // Isolate functions - CasadiFunction dvar_dy = functions->dvar_dy_fcns[dvar_k]; - CasadiFunction dvar_dp = functions->dvar_dp_fcns[dvar_k]; + Expression* dvar_dy = functions->dvar_dy_fcns[dvar_k]; + Expression* dvar_dp = functions->dvar_dp_fcns[dvar_k]; // Calculate dvar/dy - dvar_dy({tret, yval, functions->inputs.data()}, {res_dvar_dy}); - casadi::Sparsity spdy = dvar_dy.sparsity_out(0); + (*dvar_dy)({tret, yval, functions->inputs.data()}, {&res_dvar_dy[0]}); // Calculate dvar/dp and convert to dense array for indexing - dvar_dp({tret, yval, functions->inputs.data()}, {res_dvar_dp}); - casadi::Sparsity spdp = dvar_dp.sparsity_out(0); + (*dvar_dp)({tret, yval, functions->inputs.data()}, {&res_dvar_dp[0]}); for(int k=0; knnz_out(); k++) + dens_dvar_dp[dvar_dp->get_row()[k]] = res_dvar_dp[k]; // Calculate sensitivities for(int paramk=0; paramknnz_out(); spk++) + yS_return[*ySk] += res_dvar_dy[spk] * ySval[paramk][dvar_dy->get_col()[spk]]; (*ySk)++; } } } -Solution CasadiSolverOpenMP::solve( +template +Solution IDAKLUSolverOpenMP::solve( np_array t_np, np_array y0_np, np_array yp0_np, np_array_dense inputs ) { - DEBUG("CasadiSolver::solve"); + DEBUG("IDAKLUSolver::solve"); int number_of_timesteps = t_np.request().size; auto t = t_np.unchecked<1>(); @@ -315,15 +318,15 @@ Solution CasadiSolverOpenMP::solve( int length_of_return_vector = 0; size_t max_res_size = 0; // maximum result size (for common result buffer) size_t max_res_dvar_dy = 0, max_res_dvar_dp = 0; - if (functions->var_casadi_fcns.size() > 0) { + if (functions->var_fcns.size() > 0) { // return only the requested variables list after computation - for (auto& var_fcn : functions->var_casadi_fcns) { - max_res_size = std::max(max_res_size, size_t(var_fcn.nnz_out())); - length_of_return_vector += var_fcn.nnz_out(); + for (auto& var_fcn : functions->var_fcns) { + max_res_size = std::max(max_res_size, size_t(var_fcn->out_shape(0))); + length_of_return_vector += var_fcn->nnz_out(); for (auto& dvar_fcn : functions->dvar_dy_fcns) - max_res_dvar_dy = std::max(max_res_dvar_dy, size_t(dvar_fcn.nnz_out())); + max_res_dvar_dy = std::max(max_res_dvar_dy, size_t(dvar_fcn->out_shape(0))); for (auto& dvar_fcn : functions->dvar_dp_fcns) - max_res_dvar_dp = std::max(max_res_dvar_dp, size_t(dvar_fcn.nnz_out())); + max_res_dvar_dp = std::max(max_res_dvar_dp, size_t(dvar_fcn->out_shape(0))); } } else { // Return full y state-vector @@ -336,9 +339,9 @@ Solution CasadiSolverOpenMP::solve( number_of_timesteps * length_of_return_vector]; - res = new realtype[max_res_size]; - res_dvar_dy = new realtype[max_res_dvar_dy]; - res_dvar_dp = new realtype[max_res_dvar_dp]; + res.resize(max_res_size); + res_dvar_dy.resize(max_res_dvar_dy); + res_dvar_dp.resize(max_res_dvar_dp); py::capsule free_t_when_done( t_return, @@ -366,8 +369,8 @@ Solution CasadiSolverOpenMP::solve( int t_i = 0; size_t ySk = 0; t_return[t_i] = t(t_i); - if (functions->var_casadi_fcns.size() > 0) { - // Evaluate casadi functions for each requested variable and store + if (functions->var_fcns.size() > 0) { + // Evaluate functions for each requested variable and store CalcVars(y_return, length_of_return_vector, t_i, &tret, yval, ySval, yS_return, &ySk); } else { @@ -401,8 +404,8 @@ Solution CasadiSolverOpenMP::solve( // Evaluate and store results for the time step t_return[t_i] = tret; - if (functions->var_casadi_fcns.size() > 0) { - // Evaluate casadi functions for each requested variable and store + if (functions->var_fcns.size() > 0) { + // Evaluate functions for each requested variable and store // NOTE: Indexing of yS_return is (time:var:param) CalcVars(y_return, length_of_return_vector, t_i, &tret, yval, ySval, yS_return, &ySk); @@ -446,7 +449,7 @@ Solution CasadiSolverOpenMP::solve( // Note: Ordering of vector is differnet if computing variables vs returning // the complete state vector np_array yS_ret; - if (functions->var_casadi_fcns.size() > 0) { + if (functions->var_fcns.size() > 0) { yS_ret = np_array( std::vector { number_of_timesteps, diff --git a/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP_solvers.cpp b/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP_solvers.cpp new file mode 100644 index 0000000000..45ceed0ada --- /dev/null +++ b/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP_solvers.cpp @@ -0,0 +1 @@ +#include "IDAKLUSolverOpenMP_solvers.hpp" diff --git a/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP_solvers.hpp b/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP_solvers.hpp new file mode 100644 index 0000000000..ebeb543232 --- /dev/null +++ b/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP_solvers.hpp @@ -0,0 +1,131 @@ +#ifndef PYBAMM_IDAKLU_CASADI_SOLVER_OPENMP_HPP +#define PYBAMM_IDAKLU_CASADI_SOLVER_OPENMP_HPP + +#include "IDAKLUSolverOpenMP.hpp" + +/** + * @brief IDAKLUSolver Dense implementation with OpenMP class + */ +template +class IDAKLUSolverOpenMP_Dense : public IDAKLUSolverOpenMP { +public: + using Base = IDAKLUSolverOpenMP; + template + IDAKLUSolverOpenMP_Dense(Args&& ... args) : Base(std::forward(args) ...) + { + Base::LS = SUNLinSol_Dense(Base::yy, Base::J, Base::sunctx); + Base::Initialize(); + } +}; + +/** + * @brief IDAKLUSolver KLU implementation with OpenMP class + */ +template +class IDAKLUSolverOpenMP_KLU : public IDAKLUSolverOpenMP { +public: + using Base = IDAKLUSolverOpenMP; + template + IDAKLUSolverOpenMP_KLU(Args&& ... args) : Base(std::forward(args) ...) + { + Base::LS = SUNLinSol_KLU(Base::yy, Base::J, Base::sunctx); + Base::Initialize(); + } +}; + +/** + * @brief IDAKLUSolver Banded implementation with OpenMP class + */ +template +class IDAKLUSolverOpenMP_Band : public IDAKLUSolverOpenMP { +public: + using Base = IDAKLUSolverOpenMP; + template + IDAKLUSolverOpenMP_Band(Args&& ... args) : Base(std::forward(args) ...) + { + Base::LS = SUNLinSol_Band(Base::yy, Base::J, Base::sunctx); + Base::Initialize(); + } +}; + +/** + * @brief IDAKLUSolver SPBCGS implementation with OpenMP class + */ +template +class IDAKLUSolverOpenMP_SPBCGS : public IDAKLUSolverOpenMP { +public: + using Base = IDAKLUSolverOpenMP; + template + IDAKLUSolverOpenMP_SPBCGS(Args&& ... args) : Base(std::forward(args) ...) + { + Base::LS = SUNLinSol_SPBCGS( + Base::yy, + Base::precon_type, + Base::options.linsol_max_iterations, + Base::sunctx + ); + Base::Initialize(); + } +}; + +/** + * @brief IDAKLUSolver SPFGMR implementation with OpenMP class + */ +template +class IDAKLUSolverOpenMP_SPFGMR : public IDAKLUSolverOpenMP { +public: + using Base = IDAKLUSolverOpenMP; + template + IDAKLUSolverOpenMP_SPFGMR(Args&& ... args) : Base(std::forward(args) ...) + { + Base::LS = SUNLinSol_SPFGMR( + Base::yy, + Base::precon_type, + Base::options.linsol_max_iterations, + Base::sunctx + ); + Base::Initialize(); + } +}; + +/** + * @brief IDAKLUSolver SPGMR implementation with OpenMP class + */ +template +class IDAKLUSolverOpenMP_SPGMR : public IDAKLUSolverOpenMP { +public: + using Base = IDAKLUSolverOpenMP; + template + IDAKLUSolverOpenMP_SPGMR(Args&& ... args) : Base(std::forward(args) ...) + { + Base::LS = SUNLinSol_SPGMR( + Base::yy, + Base::precon_type, + Base::options.linsol_max_iterations, + Base::sunctx + ); + Base::Initialize(); + } +}; + +/** + * @brief IDAKLUSolver SPTFQMR implementation with OpenMP class + */ +template +class IDAKLUSolverOpenMP_SPTFQMR : public IDAKLUSolverOpenMP { +public: + using Base = IDAKLUSolverOpenMP; + template + IDAKLUSolverOpenMP_SPTFQMR(Args&& ... args) : Base(std::forward(args) ...) + { + Base::LS = SUNLinSol_SPTFQMR( + Base::yy, + Base::precon_type, + Base::options.linsol_max_iterations, + Base::sunctx + ); + Base::Initialize(); + } +}; + +#endif // PYBAMM_IDAKLU_CASADI_SOLVER_OPENMP_HPP diff --git a/pybamm/solvers/c_solvers/idaklu/idaklu_jax.cpp b/pybamm/solvers/c_solvers/idaklu/IdakluJax.cpp similarity index 99% rename from pybamm/solvers/c_solvers/idaklu/idaklu_jax.cpp rename to pybamm/solvers/c_solvers/idaklu/IdakluJax.cpp index b338560259..15c2b2d811 100644 --- a/pybamm/solvers/c_solvers/idaklu/idaklu_jax.cpp +++ b/pybamm/solvers/c_solvers/idaklu/IdakluJax.cpp @@ -1,4 +1,4 @@ -#include "idaklu_jax.hpp" +#include "IdakluJax.hpp" #include #include diff --git a/pybamm/solvers/c_solvers/idaklu/idaklu_jax.hpp b/pybamm/solvers/c_solvers/idaklu/IdakluJax.hpp similarity index 100% rename from pybamm/solvers/c_solvers/idaklu/idaklu_jax.hpp rename to pybamm/solvers/c_solvers/idaklu/IdakluJax.hpp diff --git a/pybamm/solvers/c_solvers/idaklu/options.cpp b/pybamm/solvers/c_solvers/idaklu/Options.cpp similarity index 99% rename from pybamm/solvers/c_solvers/idaklu/options.cpp rename to pybamm/solvers/c_solvers/idaklu/Options.cpp index efad4d5de0..684ab47f33 100644 --- a/pybamm/solvers/c_solvers/idaklu/options.cpp +++ b/pybamm/solvers/c_solvers/idaklu/Options.cpp @@ -1,4 +1,4 @@ -#include "options.hpp" +#include "Options.hpp" #include #include diff --git a/pybamm/solvers/c_solvers/idaklu/options.hpp b/pybamm/solvers/c_solvers/idaklu/Options.hpp similarity index 100% rename from pybamm/solvers/c_solvers/idaklu/options.hpp rename to pybamm/solvers/c_solvers/idaklu/Options.hpp diff --git a/pybamm/solvers/c_solvers/idaklu/Solution.cpp b/pybamm/solvers/c_solvers/idaklu/Solution.cpp new file mode 100644 index 0000000000..7b50364379 --- /dev/null +++ b/pybamm/solvers/c_solvers/idaklu/Solution.cpp @@ -0,0 +1 @@ +#include "Solution.hpp" diff --git a/pybamm/solvers/c_solvers/idaklu/solution.hpp b/pybamm/solvers/c_solvers/idaklu/Solution.hpp similarity index 100% rename from pybamm/solvers/c_solvers/idaklu/solution.hpp rename to pybamm/solvers/c_solvers/idaklu/Solution.hpp diff --git a/pybamm/solvers/c_solvers/idaklu/casadi_functions.cpp b/pybamm/solvers/c_solvers/idaklu/casadi_functions.cpp deleted file mode 100644 index ddad4612c9..0000000000 --- a/pybamm/solvers/c_solvers/idaklu/casadi_functions.cpp +++ /dev/null @@ -1,105 +0,0 @@ -#include "casadi_functions.hpp" - -CasadiFunction::CasadiFunction(const Function &f) : m_func(f) -{ - size_t sz_arg; - size_t sz_res; - size_t sz_iw; - size_t sz_w; - m_func.sz_work(sz_arg, sz_res, sz_iw, sz_w); - //int nnz = (sz_res>0) ? m_func.nnz_out() : 0; - //std::cout << "name = "<< m_func.name() << " arg = " << sz_arg << " res = " - // << sz_res << " iw = " << sz_iw << " w = " << sz_w << " nnz = " << nnz << - // std::endl; - m_arg.resize(sz_arg, nullptr); - m_res.resize(sz_res, nullptr); - m_iw.resize(sz_iw, 0); - m_w.resize(sz_w, 0); -} - -// only call this once m_arg and m_res have been set appropriately -void CasadiFunction::operator()() -{ - int mem = m_func.checkout(); - m_func(m_arg.data(), m_res.data(), m_iw.data(), m_w.data(), mem); - m_func.release(mem); -} - -casadi_int CasadiFunction::nnz_out() { - return m_func.nnz_out(); -} - -casadi::Sparsity CasadiFunction::sparsity_out(casadi_int ind) { - return m_func.sparsity_out(ind); -} - -void CasadiFunction::operator()(const std::vector& inputs, - const std::vector& results) -{ - // Set-up input arguments, provide result vector, then execute function - // Example call: fcn({in1, in2, in3}, {out1}) - for(size_t k=0; k& var_casadi_fcns, - const std::vector& dvar_dy_fcns, - const std::vector& dvar_dp_fcns, - const Options& options) - : number_of_states(n_s), number_of_events(n_e), number_of_parameters(n_p), - number_of_nnz(jac_times_cjmass_nnz), - jac_bandwidth_lower(jac_bandwidth_lower), jac_bandwidth_upper(jac_bandwidth_upper), - rhs_alg(rhs_alg), - jac_times_cjmass(jac_times_cjmass), jac_action(jac_action), - mass_action(mass_action), sens(sens), events(events), - tmp_state_vector(number_of_states), - tmp_sparse_jacobian_data(jac_times_cjmass_nnz), - options(options) -{ - // convert casadi::Function list to CasadiFunction list - for (auto& var : var_casadi_fcns) { - this->var_casadi_fcns.push_back(CasadiFunction(*var)); - } - for (auto& var : dvar_dy_fcns) { - this->dvar_dy_fcns.push_back(CasadiFunction(*var)); - } - for (auto& var : dvar_dp_fcns) { - this->dvar_dp_fcns.push_back(CasadiFunction(*var)); - } - - // copy across numpy array values - const int n_row_vals = jac_times_cjmass_rowvals_arg.request().size; - auto p_jac_times_cjmass_rowvals = jac_times_cjmass_rowvals_arg.unchecked<1>(); - jac_times_cjmass_rowvals.resize(n_row_vals); - for (int i = 0; i < n_row_vals; i++) { - jac_times_cjmass_rowvals[i] = p_jac_times_cjmass_rowvals[i]; - } - - const int n_col_ptrs = jac_times_cjmass_colptrs_arg.request().size; - auto p_jac_times_cjmass_colptrs = jac_times_cjmass_colptrs_arg.unchecked<1>(); - jac_times_cjmass_colptrs.resize(n_col_ptrs); - for (int i = 0; i < n_col_ptrs; i++) { - jac_times_cjmass_colptrs[i] = p_jac_times_cjmass_colptrs[i]; - } - - inputs.resize(inputs_length); -} - -realtype *CasadiFunctions::get_tmp_state_vector() { - return tmp_state_vector.data(); -} -realtype *CasadiFunctions::get_tmp_sparse_jacobian_data() { - return tmp_sparse_jacobian_data.data(); -} diff --git a/pybamm/solvers/c_solvers/idaklu/casadi_functions.hpp b/pybamm/solvers/c_solvers/idaklu/casadi_functions.hpp deleted file mode 100644 index 1a33b957f8..0000000000 --- a/pybamm/solvers/c_solvers/idaklu/casadi_functions.hpp +++ /dev/null @@ -1,160 +0,0 @@ -#ifndef PYBAMM_IDAKLU_CASADI_FUNCTIONS_HPP -#define PYBAMM_IDAKLU_CASADI_FUNCTIONS_HPP - -#include "common.hpp" -#include "options.hpp" -#include -#include -#include - -/** - * Utility function to convert compressed-sparse-column (CSC) to/from - * compressed-sparse-row (CSR) matrix representation. Conversion is symmetric / - * invertible using this function. - * @brief Utility function to convert to/from CSC/CSR matrix representations. - * @param f Data vector containing the sparse matrix elements - * @param c Index pointer to column starts - * @param r Array of row indices - * @param nf New data vector that will contain the transformed sparse matrix - * @param nc New array of column indices - * @param nr New index pointer to row starts - */ -template -void csc_csr(const realtype f[], const T1 c[], const T1 r[], realtype nf[], T2 nc[], T2 nr[], int N, int cols) { - std::vector nn(cols+1); - std::vector rr(N); - for (int i=0; i& inputs, - const std::vector& results); - - /** - * @brief Return the number of non-zero elements for the function output - */ - casadi_int nnz_out(); - - /** - * @brief Return the number of non-zero elements for the function output - */ - casadi::Sparsity sparsity_out(casadi_int ind); - -public: - std::vector m_arg; - std::vector m_res; - -private: - const Function &m_func; - std::vector m_iw; - std::vector m_w; -}; - -/** - * @brief Class for handling casadi functions - */ -class CasadiFunctions -{ -public: - /** - * @brief Create a new CasadiFunctions object - */ - CasadiFunctions( - const Function &rhs_alg, - const Function &jac_times_cjmass, - const int jac_times_cjmass_nnz, - const int jac_bandwidth_lower, - const int jac_bandwidth_upper, - const np_array_int &jac_times_cjmass_rowvals, - const np_array_int &jac_times_cjmass_colptrs, - const int inputs_length, - const Function &jac_action, - const Function &mass_action, - const Function &sens, - const Function &events, - const int n_s, - const int n_e, - const int n_p, - const std::vector& var_casadi_fcns, - const std::vector& dvar_dy_fcns, - const std::vector& dvar_dp_fcns, - const Options& options - ); - -public: - int number_of_states; - int number_of_parameters; - int number_of_events; - int number_of_nnz; - int jac_bandwidth_lower; - int jac_bandwidth_upper; - - CasadiFunction rhs_alg; - CasadiFunction sens; - CasadiFunction jac_times_cjmass; - CasadiFunction jac_action; - CasadiFunction mass_action; - CasadiFunction events; - - // NB: cppcheck-suppress unusedStructMember is used because codacy reports - // these members as unused even though they are important - std::vector var_casadi_fcns; // cppcheck-suppress unusedStructMember - std::vector dvar_dy_fcns; // cppcheck-suppress unusedStructMember - std::vector dvar_dp_fcns; // cppcheck-suppress unusedStructMember - - std::vector jac_times_cjmass_rowvals; - std::vector jac_times_cjmass_colptrs; - std::vector inputs; - - Options options; - - realtype *get_tmp_state_vector(); - realtype *get_tmp_sparse_jacobian_data(); - -private: - std::vector tmp_state_vector; - std::vector tmp_sparse_jacobian_data; -}; - -#endif // PYBAMM_IDAKLU_CASADI_FUNCTIONS_HPP diff --git a/pybamm/solvers/c_solvers/idaklu/casadi_solver.hpp b/pybamm/solvers/c_solvers/idaklu/casadi_solver.hpp deleted file mode 100644 index 335907a93a..0000000000 --- a/pybamm/solvers/c_solvers/idaklu/casadi_solver.hpp +++ /dev/null @@ -1,36 +0,0 @@ -#ifndef PYBAMM_IDAKLU_CREATE_CASADI_SOLVER_HPP -#define PYBAMM_IDAKLU_CREATE_CASADI_SOLVER_HPP - -#include "CasadiSolver.hpp" - -/** - * Creates a concrete casadi solver given a linear solver, as specified in - * options_cpp.linear_solver. - * @brief Create a concrete casadi solver given a linear solver - */ -CasadiSolver *create_casadi_solver( - int number_of_states, - int number_of_parameters, - const Function &rhs_alg, - const Function &jac_times_cjmass, - const np_array_int &jac_times_cjmass_colptrs, - const np_array_int &jac_times_cjmass_rowvals, - const int jac_times_cjmass_nnz, - const int jac_bandwidth_lower, - const int jac_bandwidth_upper, - const Function &jac_action, - const Function &mass_action, - const Function &sens, - const Function &event, - const int number_of_events, - np_array rhs_alg_id, - np_array atol_np, - double rel_tol, - int inputs_length, - const std::vector& var_casadi_fcns, - const std::vector& dvar_dy_fcns, - const std::vector& dvar_dp_fcns, - py::dict options -); - -#endif // PYBAMM_IDAKLU_CREATE_CASADI_SOLVER_HPP diff --git a/pybamm/solvers/c_solvers/idaklu/casadi_sundials_functions.hpp b/pybamm/solvers/c_solvers/idaklu/casadi_sundials_functions.hpp deleted file mode 100644 index a2192030b4..0000000000 --- a/pybamm/solvers/c_solvers/idaklu/casadi_sundials_functions.hpp +++ /dev/null @@ -1,27 +0,0 @@ -#ifndef PYBAMM_IDAKLU_CASADI_SUNDIALS_FUNCTIONS_HPP -#define PYBAMM_IDAKLU_CASADI_SUNDIALS_FUNCTIONS_HPP - -#include "common.hpp" - -int residual_casadi(realtype tres, N_Vector yy, N_Vector yp, N_Vector rr, - void *user_data); - -int jtimes_casadi(realtype tt, N_Vector yy, N_Vector yp, N_Vector rr, - N_Vector v, N_Vector Jv, realtype cj, void *user_data, - N_Vector tmp1, N_Vector tmp2); - -int events_casadi(realtype t, N_Vector yy, N_Vector yp, realtype *events_ptr, - void *user_data); - -int sensitivities_casadi(int Ns, realtype t, N_Vector yy, N_Vector yp, - N_Vector resval, N_Vector *yS, N_Vector *ypS, - N_Vector *resvalS, void *user_data, N_Vector tmp1, - N_Vector tmp2, N_Vector tmp3); - -int jacobian_casadi(realtype tt, realtype cj, N_Vector yy, N_Vector yp, - N_Vector resvec, SUNMatrix JJ, void *user_data, - N_Vector tempv1, N_Vector tempv2, N_Vector tempv3); - -int residual_casadi_approx(sunindextype Nlocal, realtype tt, N_Vector yy, - N_Vector yp, N_Vector gval, void *user_data); -#endif // PYBAMM_IDAKLU_CASADI_SUNDIALS_FUNCTIONS_HPP diff --git a/pybamm/solvers/c_solvers/idaklu/common.hpp b/pybamm/solvers/c_solvers/idaklu/common.hpp index e0abbb5a1d..0ef7ee60a0 100644 --- a/pybamm/solvers/c_solvers/idaklu/common.hpp +++ b/pybamm/solvers/c_solvers/idaklu/common.hpp @@ -1,6 +1,8 @@ #ifndef PYBAMM_IDAKLU_COMMON_HPP #define PYBAMM_IDAKLU_COMMON_HPP +#include + #include /* prototypes for IDAS fcts., consts. */ #include /* access to IDABBDPRE preconditioner */ @@ -33,16 +35,58 @@ using np_array = py::array_t; using np_array_dense = py::array_t; using np_array_int = py::array_t; +/** + * Utility function to convert compressed-sparse-column (CSC) to/from + * compressed-sparse-row (CSR) matrix representation. Conversion is symmetric / + * invertible using this function. + * @brief Utility function to convert to/from CSC/CSR matrix representations. + * @param f Data vector containing the sparse matrix elements + * @param c Index pointer to column starts + * @param r Array of row indices + * @param nf New data vector that will contain the transformed sparse matrix + * @param nc New array of column indices + * @param nr New index pointer to row starts + */ +template +void csc_csr(const realtype f[], const T1 c[], const T1 r[], realtype nf[], T2 nc[], T2 nr[], int N, int cols) { + std::vector nn(cols+1); + std::vector rr(N); + for (int i=0; i; } \ std::cout << "]" << std::endl; } -#define DEBUG_v(v, N) {\ +#define DEBUG_v(v, M) {\ + int N = 2; \ std::cout << #v << "[n=" << N << "] = ["; \ for (int i = 0; i < N; i++) { \ std::cout << v[i]; \ @@ -82,6 +127,13 @@ using np_array_int = py::array_t; std::cerr << __FILE__ << ":" << __LINE__ << "," << #x << " = " << x << std::endl; \ } +#define ASSERT(x) { \ + if (!(x)) { \ + std::cerr << __FILE__ << ":" << __LINE__ << " Assertion failed: " << #x << std::endl; \ + throw std::runtime_error("Assertion failed: " #x); \ + } \ + } + #endif #endif // PYBAMM_IDAKLU_COMMON_HPP diff --git a/pybamm/solvers/c_solvers/idaklu/casadi_solver.cpp b/pybamm/solvers/c_solvers/idaklu/idaklu_solver.hpp similarity index 69% rename from pybamm/solvers/c_solvers/idaklu/casadi_solver.cpp rename to pybamm/solvers/c_solvers/idaklu/idaklu_solver.hpp index 9fcfa06510..a53b167ac4 100644 --- a/pybamm/solvers/c_solvers/idaklu/casadi_solver.cpp +++ b/pybamm/solvers/c_solvers/idaklu/idaklu_solver.hpp @@ -1,37 +1,42 @@ -#include "casadi_solver.hpp" -#include "CasadiSolver.hpp" -#include "CasadiSolverOpenMP_solvers.hpp" -#include "casadi_sundials_functions.hpp" -#include "common.hpp" +#ifndef PYBAMM_CREATE_IDAKLU_SOLVER_HPP +#define PYBAMM_CREATE_IDAKLU_SOLVER_HPP + +#include "IDAKLUSolverOpenMP_solvers.hpp" #include #include -CasadiSolver *create_casadi_solver( +/** + * Creates a concrete solver given a linear solver, as specified in + * options_cpp.linear_solver. + * @brief Create a concrete solver given a linear solver + */ +template +IDAKLUSolver *create_idaklu_solver( int number_of_states, int number_of_parameters, - const Function &rhs_alg, - const Function &jac_times_cjmass, + const typename ExprSet::BaseFunctionType &rhs_alg, + const typename ExprSet::BaseFunctionType &jac_times_cjmass, const np_array_int &jac_times_cjmass_colptrs, const np_array_int &jac_times_cjmass_rowvals, const int jac_times_cjmass_nnz, const int jac_bandwidth_lower, const int jac_bandwidth_upper, - const Function &jac_action, - const Function &mass_action, - const Function &sens, - const Function &events, + const typename ExprSet::BaseFunctionType &jac_action, + const typename ExprSet::BaseFunctionType &mass_action, + const typename ExprSet::BaseFunctionType &sens, + const typename ExprSet::BaseFunctionType &events, const int number_of_events, np_array rhs_alg_id, np_array atol_np, double rel_tol, int inputs_length, - const std::vector& var_casadi_fcns, - const std::vector& dvar_dy_fcns, - const std::vector& dvar_dp_fcns, + const std::vector& var_fcns, + const std::vector& dvar_dy_fcns, + const std::vector& dvar_dp_fcns, py::dict options ) { auto options_cpp = Options(options); - auto functions = std::make_unique( + auto functions = std::make_unique( rhs_alg, jac_times_cjmass, jac_times_cjmass_nnz, @@ -47,19 +52,19 @@ CasadiSolver *create_casadi_solver( number_of_states, number_of_events, number_of_parameters, - var_casadi_fcns, + var_fcns, dvar_dy_fcns, dvar_dp_fcns, options_cpp ); - CasadiSolver *casadiSolver = nullptr; + IDAKLUSolver *idakluSolver = nullptr; // Instantiate solver class if (options_cpp.linear_solver == "SUNLinSol_Dense") { DEBUG("\tsetting SUNLinSol_Dense linear solver"); - casadiSolver = new CasadiSolverOpenMP_Dense( + idakluSolver = new IDAKLUSolverOpenMP_Dense( atol_np, rel_tol, rhs_alg_id, @@ -75,7 +80,7 @@ CasadiSolver *create_casadi_solver( else if (options_cpp.linear_solver == "SUNLinSol_KLU") { DEBUG("\tsetting SUNLinSol_KLU linear solver"); - casadiSolver = new CasadiSolverOpenMP_KLU( + idakluSolver = new IDAKLUSolverOpenMP_KLU( atol_np, rel_tol, rhs_alg_id, @@ -91,7 +96,7 @@ CasadiSolver *create_casadi_solver( else if (options_cpp.linear_solver == "SUNLinSol_Band") { DEBUG("\tsetting SUNLinSol_Band linear solver"); - casadiSolver = new CasadiSolverOpenMP_Band( + idakluSolver = new IDAKLUSolverOpenMP_Band( atol_np, rel_tol, rhs_alg_id, @@ -107,7 +112,7 @@ CasadiSolver *create_casadi_solver( else if (options_cpp.linear_solver == "SUNLinSol_SPBCGS") { DEBUG("\tsetting SUNLinSol_SPBCGS_linear solver"); - casadiSolver = new CasadiSolverOpenMP_SPBCGS( + idakluSolver = new IDAKLUSolverOpenMP_SPBCGS( atol_np, rel_tol, rhs_alg_id, @@ -123,7 +128,7 @@ CasadiSolver *create_casadi_solver( else if (options_cpp.linear_solver == "SUNLinSol_SPFGMR") { DEBUG("\tsetting SUNLinSol_SPFGMR_linear solver"); - casadiSolver = new CasadiSolverOpenMP_SPFGMR( + idakluSolver = new IDAKLUSolverOpenMP_SPFGMR( atol_np, rel_tol, rhs_alg_id, @@ -139,7 +144,7 @@ CasadiSolver *create_casadi_solver( else if (options_cpp.linear_solver == "SUNLinSol_SPGMR") { DEBUG("\tsetting SUNLinSol_SPGMR solver"); - casadiSolver = new CasadiSolverOpenMP_SPGMR( + idakluSolver = new IDAKLUSolverOpenMP_SPGMR( atol_np, rel_tol, rhs_alg_id, @@ -155,7 +160,7 @@ CasadiSolver *create_casadi_solver( else if (options_cpp.linear_solver == "SUNLinSol_SPTFQMR") { DEBUG("\tsetting SUNLinSol_SPGMR solver"); - casadiSolver = new CasadiSolverOpenMP_SPTFQMR( + idakluSolver = new IDAKLUSolverOpenMP_SPTFQMR( atol_np, rel_tol, rhs_alg_id, @@ -169,9 +174,11 @@ CasadiSolver *create_casadi_solver( ); } - if (casadiSolver == nullptr) { + if (idakluSolver == nullptr) { throw std::invalid_argument("Unsupported solver requested"); } - return casadiSolver; + return idakluSolver; } + +#endif // PYBAMM_CREATE_IDAKLU_SOLVER_HPP diff --git a/pybamm/solvers/c_solvers/idaklu/python.hpp b/pybamm/solvers/c_solvers/idaklu/python.hpp index 0478d0946f..6231d13eb6 100644 --- a/pybamm/solvers/c_solvers/idaklu/python.hpp +++ b/pybamm/solvers/c_solvers/idaklu/python.hpp @@ -2,7 +2,7 @@ #define PYBAMM_IDAKLU_HPP #include "common.hpp" -#include "solution.hpp" +#include "Solution.hpp" #include using residual_type = std::function< diff --git a/pybamm/solvers/c_solvers/idaklu/solution.cpp b/pybamm/solvers/c_solvers/idaklu/solution.cpp deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/pybamm/solvers/c_solvers/idaklu/sundials_functions.hpp b/pybamm/solvers/c_solvers/idaklu/sundials_functions.hpp new file mode 100644 index 0000000000..c4024bc20a --- /dev/null +++ b/pybamm/solvers/c_solvers/idaklu/sundials_functions.hpp @@ -0,0 +1,36 @@ +#ifndef PYBAMM_SUNDIALS_FUNCTIONS_HPP +#define PYBAMM_SUNDIALS_FUNCTIONS_HPP + +#include "common.hpp" + +template +void axpy(int n, T alpha, const T* x, T* y) { + if (!x || !y) return; + for (int i=0; i #define NV_DATA NV_DATA_OMP // Serial: NV_DATA_S -int residual_casadi(realtype tres, N_Vector yy, N_Vector yp, N_Vector rr, - void *user_data) +template +int residual_eval(realtype tres, N_Vector yy, N_Vector yp, N_Vector rr, void *user_data) { - DEBUG("residual_casadi"); - CasadiFunctions *p_python_functions = - static_cast(user_data); + DEBUG("residual_eval"); + ExpressionSet *p_python_functions = + static_cast *>(user_data); - p_python_functions->rhs_alg.m_arg[0] = &tres; - p_python_functions->rhs_alg.m_arg[1] = NV_DATA(yy); - p_python_functions->rhs_alg.m_arg[2] = p_python_functions->inputs.data(); - p_python_functions->rhs_alg.m_res[0] = NV_DATA(rr); - p_python_functions->rhs_alg(); + DEBUG_VECTORn(yy, 100); + DEBUG_VECTORn(yp, 100); + + p_python_functions->rhs_alg->m_arg[0] = &tres; + p_python_functions->rhs_alg->m_arg[1] = NV_DATA(yy); + p_python_functions->rhs_alg->m_arg[2] = p_python_functions->inputs.data(); + p_python_functions->rhs_alg->m_res[0] = NV_DATA(rr); + (*p_python_functions->rhs_alg)(); + + DEBUG_VECTORn(rr, 100); realtype *tmp = p_python_functions->get_tmp_state_vector(); - p_python_functions->mass_action.m_arg[0] = NV_DATA(yp); - p_python_functions->mass_action.m_res[0] = tmp; - p_python_functions->mass_action(); + p_python_functions->mass_action->m_arg[0] = NV_DATA(yp); + p_python_functions->mass_action->m_res[0] = tmp; + (*p_python_functions->mass_action)(); // AXPY: y <- a*x + y const int ns = p_python_functions->number_of_states; - casadi::casadi_axpy(ns, -1., tmp, NV_DATA(rr)); + axpy(ns, -1., tmp, NV_DATA(rr)); - //DEBUG_VECTOR(yy); - //DEBUG_VECTOR(yp); - //DEBUG_VECTOR(rr); + DEBUG("mass - rhs"); + DEBUG_VECTORn(rr, 100); // now rr has rhs_alg(t, y) - mass_matrix * yp return 0; @@ -64,13 +68,14 @@ int residual_casadi(realtype tres, N_Vector yy, N_Vector yp, N_Vector rr, // within user_data. // // The case where G is mathematically identical to F is allowed. -int residual_casadi_approx(sunindextype Nlocal, realtype tt, N_Vector yy, +template +int residual_eval_approx(sunindextype Nlocal, realtype tt, N_Vector yy, N_Vector yp, N_Vector gval, void *user_data) { - DEBUG("residual_casadi_approx"); + DEBUG("residual_eval_approx"); // Just use true residual for now - int result = residual_casadi(tt, yy, yp, gval, user_data); + int result = residual_eval(tt, yy, yp, gval, user_data); return result; } @@ -94,32 +99,35 @@ int residual_casadi_approx(sunindextype Nlocal, realtype tt, N_Vector yy, // tmp2 are pointers to memory allocated for variables of type N Vector // which can // be used by IDALsJacTimesVecFn as temporary storage or work space. -int jtimes_casadi(realtype tt, N_Vector yy, N_Vector yp, N_Vector rr, +template +int jtimes_eval(realtype tt, N_Vector yy, N_Vector yp, N_Vector rr, N_Vector v, N_Vector Jv, realtype cj, void *user_data, N_Vector tmp1, N_Vector tmp2) { - DEBUG("jtimes_casadi"); - CasadiFunctions *p_python_functions = - static_cast(user_data); + DEBUG("jtimes_eval"); + T *p_python_functions = + static_cast(user_data); // Jv has ∂F/∂y v - p_python_functions->jac_action.m_arg[0] = &tt; - p_python_functions->jac_action.m_arg[1] = NV_DATA(yy); - p_python_functions->jac_action.m_arg[2] = p_python_functions->inputs.data(); - p_python_functions->jac_action.m_arg[3] = NV_DATA(v); - p_python_functions->jac_action.m_res[0] = NV_DATA(Jv); - p_python_functions->jac_action(); + p_python_functions->jac_action->m_arg[0] = &tt; + p_python_functions->jac_action->m_arg[1] = NV_DATA(yy); + p_python_functions->jac_action->m_arg[2] = p_python_functions->inputs.data(); + p_python_functions->jac_action->m_arg[3] = NV_DATA(v); + p_python_functions->jac_action->m_res[0] = NV_DATA(Jv); + (*p_python_functions->jac_action)(); // tmp has -∂F/∂y˙ v realtype *tmp = p_python_functions->get_tmp_state_vector(); - p_python_functions->mass_action.m_arg[0] = NV_DATA(v); - p_python_functions->mass_action.m_res[0] = tmp; - p_python_functions->mass_action(); + p_python_functions->mass_action->m_arg[0] = NV_DATA(v); + p_python_functions->mass_action->m_res[0] = tmp; + (*p_python_functions->mass_action)(); // AXPY: y <- a*x + y // Jv has ∂F/∂y v + cj ∂F/∂y˙ v const int ns = p_python_functions->number_of_states; - casadi::casadi_axpy(ns, -cj, tmp, NV_DATA(Jv)); + axpy(ns, -cj, tmp, NV_DATA(Jv)); + + DEBUG_VECTORn(Jv, 10); return 0; } @@ -141,14 +149,15 @@ int jtimes_casadi(realtype tt, N_Vector yy, N_Vector yp, N_Vector rr, // tmp3 are pointers to memory allocated for variables of type N Vector which // can // be used by IDALsJacFn function as temporary storage or work space. -int jacobian_casadi(realtype tt, realtype cj, N_Vector yy, N_Vector yp, +template +int jacobian_eval(realtype tt, realtype cj, N_Vector yy, N_Vector yp, N_Vector resvec, SUNMatrix JJ, void *user_data, N_Vector tempv1, N_Vector tempv2, N_Vector tempv3) { - DEBUG("jacobian_casadi"); + DEBUG("jacobian_eval"); - CasadiFunctions *p_python_functions = - static_cast(user_data); + T *p_python_functions = + static_cast(user_data); // create pointer to jac data, column pointers, and row values realtype *jac_data; @@ -164,16 +173,23 @@ int jacobian_casadi(realtype tt, realtype cj, N_Vector yy, N_Vector yp, jac_data = SUNDenseMatrix_Data(JJ); } + DEBUG_VECTORn(yy, 100); + // args are t, y, cj, put result in jacobian data matrix - p_python_functions->jac_times_cjmass.m_arg[0] = &tt; - p_python_functions->jac_times_cjmass.m_arg[1] = NV_DATA(yy); - p_python_functions->jac_times_cjmass.m_arg[2] = + p_python_functions->jac_times_cjmass->m_arg[0] = &tt; + p_python_functions->jac_times_cjmass->m_arg[1] = NV_DATA(yy); + p_python_functions->jac_times_cjmass->m_arg[2] = p_python_functions->inputs.data(); - p_python_functions->jac_times_cjmass.m_arg[3] = &cj; - p_python_functions->jac_times_cjmass.m_res[0] = jac_data; - - p_python_functions->jac_times_cjmass(); + p_python_functions->jac_times_cjmass->m_arg[3] = &cj; + p_python_functions->jac_times_cjmass->m_res[0] = jac_data; + (*p_python_functions->jac_times_cjmass)(); + DEBUG("jac_times_cjmass [" << sizeof(jac_data) << "]"); + DEBUG("t = " << tt); + DEBUG_VECTORn(yy, 100); + DEBUG("inputs = " << p_python_functions->inputs); + DEBUG("cj = " << cj); + DEBUG_v(jac_data, 100); if (p_python_functions->options.using_banded_matrix) { @@ -219,20 +235,12 @@ int jacobian_casadi(realtype tt, realtype cj, N_Vector yy, N_Vector yp, jac_colptrs[i] = p_jac_times_cjmass_colptrs[i]; } } else if (SUNSparseMatrix_SparseType(JJ) == CSR_MAT) { - std::vector newjac(SUNSparseMatrix_NNZ(JJ)); + // make a copy so that we can overwrite jac_data as CSR + std::vector newjac(&jac_data[0], &jac_data[SUNSparseMatrix_NNZ(JJ)]); sunindextype *jac_ptrs = SUNSparseMatrix_IndexPointers(JJ); sunindextype *jac_vals = SUNSparseMatrix_IndexValues(JJ); - // args are t, y, cj, put result in jacobian data matrix - p_python_functions->jac_times_cjmass.m_arg[0] = &tt; - p_python_functions->jac_times_cjmass.m_arg[1] = NV_DATA(yy); - p_python_functions->jac_times_cjmass.m_arg[2] = - p_python_functions->inputs.data(); - p_python_functions->jac_times_cjmass.m_arg[3] = &cj; - p_python_functions->jac_times_cjmass.m_res[0] = newjac.data(); - p_python_functions->jac_times_cjmass(); - - // convert (casadi's) CSC format to CSR + // convert CSC format to CSR csc_csr< std::remove_pointer_tjac_times_cjmass_rowvals.data())>, std::remove_pointer_t @@ -253,18 +261,20 @@ int jacobian_casadi(realtype tt, realtype cj, N_Vector yy, N_Vector yp, return (0); } -int events_casadi(realtype t, N_Vector yy, N_Vector yp, realtype *events_ptr, +template +int events_eval(realtype t, N_Vector yy, N_Vector yp, realtype *events_ptr, void *user_data) { - CasadiFunctions *p_python_functions = - static_cast(user_data); + DEBUG("events_eval"); + T *p_python_functions = + static_cast(user_data); // args are t, y, put result in events_ptr - p_python_functions->events.m_arg[0] = &t; - p_python_functions->events.m_arg[1] = NV_DATA(yy); - p_python_functions->events.m_arg[2] = p_python_functions->inputs.data(); - p_python_functions->events.m_res[0] = events_ptr; - p_python_functions->events(); + p_python_functions->events->m_arg[0] = &t; + p_python_functions->events->m_arg[1] = NV_DATA(yy); + p_python_functions->events->m_arg[2] = p_python_functions->inputs.data(); + p_python_functions->events->m_res[0] = events_ptr; + (*p_python_functions->events)(); return (0); } @@ -290,52 +300,52 @@ int events_casadi(realtype t, N_Vector yy, N_Vector yp, realtype *events_ptr, // occurred (in which case idas will attempt to correct), // or a negative value if it failed unrecoverably (in which case the integration // is halted and IDA SRES FAIL is returned) -// -int sensitivities_casadi(int Ns, realtype t, N_Vector yy, N_Vector yp, +template +int sensitivities_eval(int Ns, realtype t, N_Vector yy, N_Vector yp, N_Vector resval, N_Vector *yS, N_Vector *ypS, N_Vector *resvalS, void *user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3) { - DEBUG("sensitivities_casadi"); - CasadiFunctions *p_python_functions = - static_cast(user_data); + DEBUG("sensitivities_eval"); + T *p_python_functions = + static_cast(user_data); const int np = p_python_functions->number_of_parameters; // args are t, y put result in rr - p_python_functions->sens.m_arg[0] = &t; - p_python_functions->sens.m_arg[1] = NV_DATA(yy); - p_python_functions->sens.m_arg[2] = p_python_functions->inputs.data(); + p_python_functions->sens->m_arg[0] = &t; + p_python_functions->sens->m_arg[1] = NV_DATA(yy); + p_python_functions->sens->m_arg[2] = p_python_functions->inputs.data(); for (int i = 0; i < np; i++) { - p_python_functions->sens.m_res[i] = NV_DATA(resvalS[i]); + p_python_functions->sens->m_res[i] = NV_DATA(resvalS[i]); } // resvalsS now has (∂F/∂p i ) - p_python_functions->sens(); + (*p_python_functions->sens)(); for (int i = 0; i < np; i++) { // put (∂F/∂y)s i (t) in tmp realtype *tmp = p_python_functions->get_tmp_state_vector(); - p_python_functions->jac_action.m_arg[0] = &t; - p_python_functions->jac_action.m_arg[1] = NV_DATA(yy); - p_python_functions->jac_action.m_arg[2] = p_python_functions->inputs.data(); - p_python_functions->jac_action.m_arg[3] = NV_DATA(yS[i]); - p_python_functions->jac_action.m_res[0] = tmp; - p_python_functions->jac_action(); + p_python_functions->jac_action->m_arg[0] = &t; + p_python_functions->jac_action->m_arg[1] = NV_DATA(yy); + p_python_functions->jac_action->m_arg[2] = p_python_functions->inputs.data(); + p_python_functions->jac_action->m_arg[3] = NV_DATA(yS[i]); + p_python_functions->jac_action->m_res[0] = tmp; + (*p_python_functions->jac_action)(); const int ns = p_python_functions->number_of_states; - casadi::casadi_axpy(ns, 1., tmp, NV_DATA(resvalS[i])); + axpy(ns, 1., tmp, NV_DATA(resvalS[i])); // put -(∂F/∂ ẏ) ṡ i (t) in tmp2 - p_python_functions->mass_action.m_arg[0] = NV_DATA(ypS[i]); - p_python_functions->mass_action.m_res[0] = tmp; - p_python_functions->mass_action(); + p_python_functions->mass_action->m_arg[0] = NV_DATA(ypS[i]); + p_python_functions->mass_action->m_res[0] = tmp; + (*p_python_functions->mass_action)(); // (∂F/∂y)s i (t)+(∂F/∂ ẏ) ṡ i (t)+(∂F/∂p i ) // AXPY: y <- a*x + y - casadi::casadi_axpy(ns, -1., tmp, NV_DATA(resvalS[i])); + axpy(ns, -1., tmp, NV_DATA(resvalS[i])); } return 0; diff --git a/pybamm/solvers/idaklu_solver.py b/pybamm/solvers/idaklu_solver.py index fef4cbce3c..f1f32b1e63 100644 --- a/pybamm/solvers/idaklu_solver.py +++ b/pybamm/solvers/idaklu_solver.py @@ -2,13 +2,25 @@ # Solver class using sundials with the KLU sparse linear solver # # mypy: ignore-errors +import os import casadi import pybamm import numpy as np import numbers import scipy.sparse as sparse +from scipy.linalg import bandwidth import importlib +import warnings + +if pybamm.have_jax(): + import jax + from jax import numpy as jnp + + try: + import iree.compiler + except ImportError: # pragma: no cover + pass idaklu_spec = importlib.util.find_spec("pybamm.solvers.idaklu") if idaklu_spec is not None: @@ -24,6 +36,15 @@ def have_idaklu(): return idaklu_spec is not None +def have_iree(): + try: + import iree.compiler # noqa: F401 + + return True + except ImportError: # pragma: no cover + return False + + class IDAKLUSolver(pybamm.BaseSolver): """ Solve a discretised model, using sundials with the KLU sparse linear solver. @@ -75,6 +96,8 @@ class IDAKLUSolver(pybamm.BaseSolver): "precon_half_bandwidth_keep": 5, # Number of threads available for OpenMP "num_threads": 1, + # Evaluation engine to use for jax, can be 'jax'(native) or 'iree' + "jax_evaluator": "jax", } Note: These options only have an effect if model.convert_to_format == 'casadi' @@ -103,6 +126,7 @@ def __init__( "precon_half_bandwidth": 5, "precon_half_bandwidth_keep": 5, "num_threads": 1, + "jax_evaluator": "jax", } if options is None: options = default_options @@ -110,6 +134,10 @@ def __init__( for key, value in default_options.items(): if key not in options: options[key] = value + if options["jax_evaluator"] not in ["jax", "iree"]: + raise pybamm.SolverError( + "Evaluation engine must be 'jax' or 'iree' for IDAKLU solver" + ) self._options = options self.output_variables = [] if output_variables is None else output_variables @@ -183,10 +211,14 @@ def inputs_to_dict(inputs): # only casadi solver needs sensitivity ics if model.convert_to_format != "casadi": y0S = None - if self.output_variables: + if self.output_variables and not ( + model.convert_to_format == "jax" + and self._options["jax_evaluator"] == "iree" + ): raise pybamm.SolverError( "output_variables can only be specified " - 'with convert_to_format="casadi"' + 'with convert_to_format="casadi", or convert_to_format="jax" ' + 'with jax_evaluator="iree"' ) # pragma: no cover if y0S is not None: if isinstance(y0S, casadi.DM): @@ -293,7 +325,7 @@ def resfn(t, y, inputs, ydot): ) ) - else: + elif self._options["jax_evaluator"] == "jax": t0 = 0 if t_eval is None else t_eval[0] jac_y0_t0 = model.jac_rhs_algebraic_eval(t0, y0, inputs_dict) if sparse.issparse(jac_y0_t0): @@ -355,7 +387,7 @@ def get_jac_col_ptrs(self): ) ], ) - else: + elif self._options["jax_evaluator"] == "jax": def rootfn(t, y, inputs): new_inputs = inputs_to_dict(inputs) @@ -437,40 +469,220 @@ def sensfn(resvalS, t, y, inputs, yp, yS, ypS): rtol = self.rtol atol = self._check_atol_type(atol, y0.size) - if model.convert_to_format == "casadi": - rhs_algebraic = idaklu.generate_function(rhs_algebraic.serialize()) - jac_times_cjmass = idaklu.generate_function(jac_times_cjmass.serialize()) - jac_rhs_algebraic_action = idaklu.generate_function( - jac_rhs_algebraic_action.serialize() - ) - rootfn = idaklu.generate_function(rootfn.serialize()) - mass_action = idaklu.generate_function(mass_action.serialize()) - sensfn = idaklu.generate_function(sensfn.serialize()) + if model.convert_to_format == "casadi" or ( + model.convert_to_format == "jax" + and self._options["jax_evaluator"] == "iree" + ): + if model.convert_to_format == "casadi": + # Serialize casadi functions + idaklu_solver_fcn = idaklu.create_casadi_solver + rhs_algebraic = idaklu.generate_function(rhs_algebraic.serialize()) + jac_times_cjmass = idaklu.generate_function( + jac_times_cjmass.serialize() + ) + jac_rhs_algebraic_action = idaklu.generate_function( + jac_rhs_algebraic_action.serialize() + ) + rootfn = idaklu.generate_function(rootfn.serialize()) + mass_action = idaklu.generate_function(mass_action.serialize()) + sensfn = idaklu.generate_function(sensfn.serialize()) + elif ( + model.convert_to_format == "jax" + and self._options["jax_evaluator"] == "iree" + ): + # Convert Jax functions to MLIR (also, demote to single precision) + idaklu_solver_fcn = idaklu.create_iree_solver + pybamm.demote_expressions_to_32bit = True + if pybamm.demote_expressions_to_32bit: + warnings.warn( + "Demoting expressions to 32-bit for MLIR conversion", + stacklevel=2, + ) + jnpfloat = jnp.float32 + else: # pragma: no cover + jnpfloat = jnp.float64 + raise pybamm.SolverError( + "Demoting expressions to 32-bit is required for MLIR conversion" + " at this time" + ) + + # input arguments (used for lowering) + t_eval = self._demote_64_to_32(jnp.array([0.0], dtype=jnpfloat)) + y0 = self._demote_64_to_32(model.y0) + inputs0 = self._demote_64_to_32(inputs_to_dict(inputs)) + cj = self._demote_64_to_32(jnp.array([1.0], dtype=jnpfloat)) # array + v0 = jnp.zeros(model.len_rhs_and_alg, jnpfloat) + mass_matrix = model.mass_matrix.entries.toarray() + mass_matrix_demoted = self._demote_64_to_32(mass_matrix) + + # rhs_algebraic + rhs_algebraic_demoted = model.rhs_algebraic_eval + rhs_algebraic_demoted._demote_constants() + + def fcn_rhs_algebraic(t, y, inputs): + # function wraps an expression tree (and names MLIR module) + return rhs_algebraic_demoted(t, y, inputs) + + rhs_algebraic = self._make_iree_function( + fcn_rhs_algebraic, t_eval, y0, inputs0 + ) + + # jac_times_cjmass + jac_rhs_algebraic_demoted = rhs_algebraic_demoted.get_jacobian() + + def fcn_jac_times_cjmass(t, y, p, cj): + return jac_rhs_algebraic_demoted(t, y, p) - cj * mass_matrix_demoted + + sparse_eval = sparse.csc_matrix( + fcn_jac_times_cjmass(t_eval, y0, inputs0, cj) + ) + jac_times_cjmass_nnz = sparse_eval.nnz + jac_times_cjmass_colptrs = sparse_eval.indptr + jac_times_cjmass_rowvals = sparse_eval.indices + jac_bw_lower, jac_bw_upper = bandwidth( + sparse_eval.todense() + ) # potentially slow + if jac_bw_upper <= 1: + jac_bw_upper = jac_bw_lower - 1 + if jac_bw_lower <= 1: + jac_bw_lower = jac_bw_upper + 1 + coo = sparse_eval.tocoo() # convert to COOrdinate format for indexing + + def fcn_jac_times_cjmass_sparse(t, y, p, cj): + return fcn_jac_times_cjmass(t, y, p, cj)[coo.row, coo.col] + + jac_times_cjmass = self._make_iree_function( + fcn_jac_times_cjmass_sparse, t_eval, y0, inputs0, cj + ) + + # Mass action + def fcn_mass_action(v): + return mass_matrix_demoted @ v + + mass_action_demoted = self._demote_64_to_32(fcn_mass_action) + mass_action = self._make_iree_function(mass_action_demoted, v0) + + # rootfn + for ix, _ in enumerate(model.terminate_events_eval): + model.terminate_events_eval[ix]._demote_constants() + + def fcn_rootfn(t, y, inputs): + return jnp.array( + [event(t, y, inputs) for event in model.terminate_events_eval], + dtype=jnpfloat, + ).reshape(-1) + + def fcn_rootfn_demoted(t, y, inputs): + return self._demote_64_to_32(fcn_rootfn)(t, y, inputs) + + rootfn = self._make_iree_function( + fcn_rootfn_demoted, t_eval, y0, inputs0 + ) + + # jac_rhs_algebraic_action + jac_rhs_algebraic_action_demoted = ( + rhs_algebraic_demoted.get_jacobian_action() + ) + + def fcn_jac_rhs_algebraic_action( + t, y, p, v + ): # sundials calls (t, y, inputs, v) + return jac_rhs_algebraic_action_demoted( + t, y, v, p + ) # jvp calls (t, y, v, inputs) + + jac_rhs_algebraic_action = self._make_iree_function( + fcn_jac_rhs_algebraic_action, t_eval, y0, inputs0, v0 + ) + + # sensfn + if model.jacp_rhs_algebraic_eval is None: + sensfn = idaklu.IREEBaseFunctionType() # empty equation + else: + sensfn_demoted = rhs_algebraic_demoted.get_sensitivities() + + def fcn_sensfn(t, y, p): + return sensfn_demoted(t, y, p) + + sensfn = self._make_iree_function( + fcn_sensfn, t_eval, jnp.zeros_like(y0), inputs0 + ) + + # output_variables + self.var_idaklu_fcns = [] + self.dvar_dy_idaklu_fcns = [] + self.dvar_dp_idaklu_fcns = [] + for key in self.output_variables: + fcn = self.computed_var_fcns[key] + fcn._demote_constants() + self.var_idaklu_fcns.append( + self._make_iree_function( + lambda t, y, p: fcn(t, y, p), # noqa: B023 + t_eval, + y0, + inputs0, + ) + ) + # Convert derivative functions for sensitivities + if (len(inputs) > 0) and (model.calculate_sensitivities): + dvar_dy = fcn.get_jacobian() + self.dvar_dy_idaklu_fcns.append( + self._make_iree_function( + lambda t, y, p: dvar_dy(t, y, p), # noqa: B023 + t_eval, + y0, + inputs0, + sparse_index=True, + ) + ) + dvar_dp = fcn.get_sensitivities() + self.dvar_dp_idaklu_fcns.append( + self._make_iree_function( + lambda t, y, p: dvar_dp(t, y, p), # noqa: B023 + t_eval, + y0, + inputs0, + ) + ) + + # Identify IREE library + iree_lib_path = os.path.join(iree.compiler.__path__[0], "_mlir_libs") + os.environ["IREE_COMPILER_LIB"] = os.path.join( + iree_lib_path, + next(f for f in os.listdir(iree_lib_path) if "IREECompiler" in f), + ) + + pybamm.demote_expressions_to_32bit = False + else: # pragma: no cover + raise pybamm.SolverError( + "Unsupported evaluation engine for convert_to_format='jax'" + ) self._setup = { - "jac_bandwidth_upper": jac_bw_upper, - "jac_bandwidth_lower": jac_bw_lower, - "rhs_algebraic": rhs_algebraic, - "jac_times_cjmass": jac_times_cjmass, - "jac_times_cjmass_colptrs": jac_times_cjmass_colptrs, - "jac_times_cjmass_rowvals": jac_times_cjmass_rowvals, - "jac_times_cjmass_nnz": jac_times_cjmass_nnz, - "jac_rhs_algebraic_action": jac_rhs_algebraic_action, - "mass_action": mass_action, - "sensfn": sensfn, - "rootfn": rootfn, - "num_of_events": num_of_events, - "ids": ids, + "solver_function": idaklu_solver_fcn, # callable + "jac_bandwidth_upper": jac_bw_upper, # int + "jac_bandwidth_lower": jac_bw_lower, # int + "rhs_algebraic": rhs_algebraic, # function + "jac_times_cjmass": jac_times_cjmass, # function + "jac_times_cjmass_colptrs": jac_times_cjmass_colptrs, # array + "jac_times_cjmass_rowvals": jac_times_cjmass_rowvals, # array + "jac_times_cjmass_nnz": jac_times_cjmass_nnz, # int + "jac_rhs_algebraic_action": jac_rhs_algebraic_action, # function + "mass_action": mass_action, # function + "sensfn": sensfn, # function + "rootfn": rootfn, # function + "num_of_events": num_of_events, # int + "ids": ids, # array "sensitivity_names": sensitivity_names, "number_of_sensitivity_parameters": number_of_sensitivity_parameters, "output_variables": self.output_variables, - "var_casadi_fcns": self.computed_var_fcns, + "var_fcns": self.computed_var_fcns, "var_idaklu_fcns": self.var_idaklu_fcns, "dvar_dy_idaklu_fcns": self.dvar_dy_idaklu_fcns, "dvar_dp_idaklu_fcns": self.dvar_dp_idaklu_fcns, } - solver = idaklu.create_casadi_solver( + solver = self._setup["solver_function"]( number_of_states=len(y0), number_of_parameters=self._setup["number_of_sensitivity_parameters"], rhs_alg=self._setup["rhs_algebraic"], @@ -489,7 +701,7 @@ def sensfn(resvalS, t, y, inputs, yp, yS, ypS): atol=atol, rtol=rtol, inputs=len(inputs), - var_casadi_fcns=self._setup["var_idaklu_fcns"], + var_fcns=self._setup["var_idaklu_fcns"], dvar_dy_fcns=self._setup["dvar_dy_idaklu_fcns"], dvar_dp_fcns=self._setup["dvar_dp_idaklu_fcns"], options=self._options, @@ -511,6 +723,56 @@ def sensfn(resvalS, t, y, inputs, yp, yS, ypS): return base_set_up_return + def _make_iree_function(self, fcn, *args, sparse_index=False): + # Initialise IREE function object + iree_fcn = idaklu.IREEBaseFunctionType() + # Get sparsity pattern index outputs as needed + try: + fcn_eval = fcn(*args) + if not isinstance(fcn_eval, np.ndarray): + fcn_eval = jax.flatten_util.ravel_pytree(fcn_eval)[0] + coo = sparse.coo_matrix(fcn_eval) + iree_fcn.nnz = coo.nnz + iree_fcn.numel = np.prod(coo.shape) + iree_fcn.col = coo.col + iree_fcn.row = coo.row + if sparse_index: + # Isolate NNZ elements while recording original sparsity structure + fcn_inner = fcn + + def fcn(*args): + return fcn_inner(*args)[coo.row, coo.col] + elif coo.nnz != iree_fcn.numel: + iree_fcn.nnz = iree_fcn.numel + iree_fcn.col = list(range(iree_fcn.numel)) + iree_fcn.row = [0] * iree_fcn.numel + except (TypeError, AttributeError) as error: # pragma: no cover + raise pybamm.SolverError( + "Could not get sparsity pattern for function {fcn.__name__}" + ) from error + # Lower to MLIR + lowered = jax.jit(fcn).lower(*args) + iree_fcn.mlir = lowered.as_text() + self._check_mlir_conversion(fcn.__name__, iree_fcn.mlir) + iree_fcn.kept_var_idx = list(lowered._lowering.compile_args["kept_var_idx"]) + # Record number of variables in each argument (these will flatten in the mlir) + iree_fcn.pytree_shape = [ + len(jax.tree_util.tree_flatten(arg)[0]) for arg in args + ] + # Record array length of each mlir variable + iree_fcn.pytree_sizes = [ + len(arg) for arg in jax.tree_util.tree_flatten(args)[0] + ] + iree_fcn.n_args = len(args) + return iree_fcn + + def _check_mlir_conversion(self, name, mlir: str): + if mlir.count("f64") > 0: # pragma: no cover + warnings.warn(f"f64 found in {name} (x{mlir.count('f64')})", stacklevel=2) + + def _demote_64_to_32(self, x: pybamm.EvaluatorJax): + return pybamm.EvaluatorJax._demote_64_to_32(x) + def _integrate(self, model, t_eval, inputs_dict=None): """ Solve a DAE model defined by residuals with initial conditions y0. @@ -527,10 +789,12 @@ def _integrate(self, model, t_eval, inputs_dict=None): inputs_dict = inputs_dict or {} # stack inputs if inputs_dict: + inputs_dict_keys = list(inputs_dict.keys()) # save order arrays_to_stack = [np.array(x).reshape(-1, 1) for x in inputs_dict.values()] inputs = np.vstack(arrays_to_stack) else: inputs = np.array([[]]) + inputs_dict_keys = [] # do this here cause y0 is set after set_up (calc consistent conditions) y0 = model.y0 @@ -539,25 +803,45 @@ def _integrate(self, model, t_eval, inputs_dict=None): y0 = y0.flatten() y0S = model.y0S - # only casadi solver needs sensitivity ics - if model.convert_to_format != "casadi": - y0S = None - if y0S is not None: - if isinstance(y0S, casadi.DM): - y0S = (y0S,) - - y0S = (x.full() for x in y0S) - y0S = [x.flatten() for x in y0S] - - # solver works with ydot0 set to zero - ydot0 = np.zeros_like(y0) - if y0S is not None: - ydot0S = [np.zeros_like(y0S_i) for y0S_i in y0S] - y0full = np.concatenate([y0, *y0S]) - ydot0full = np.concatenate([ydot0, *ydot0S]) + if ( + model.convert_to_format == "jax" + and self._options["jax_evaluator"] == "iree" + ): + if y0S is not None: + pybamm.demote_expressions_to_32bit = True + # preserve order of inputs + y0S = self._demote_64_to_32( + np.concatenate([y0S[k] for k in inputs_dict_keys]).flatten() + ) + y0full = self._demote_64_to_32(np.concatenate([y0, y0S]).flatten()) + ydot0S = self._demote_64_to_32(np.zeros_like(y0S)) + ydot0full = self._demote_64_to_32( + np.concatenate([np.zeros_like(y0), ydot0S]).flatten() + ) + pybamm.demote_expressions_to_32bit = False + else: + y0full = y0 + ydot0full = np.zeros_like(y0) else: - y0full = y0 - ydot0full = ydot0 + # only casadi solver needs sensitivity ics + if model.convert_to_format != "casadi": + y0S = None + if y0S is not None: + if isinstance(y0S, casadi.DM): + y0S = (y0S,) + + y0S = (x.full() for x in y0S) + y0S = [x.flatten() for x in y0S] + + # solver works with ydot0 set to zero + ydot0 = np.zeros_like(y0) + if y0S is not None: + ydot0S = [np.zeros_like(y0S_i) for y0S_i in y0S] + y0full = np.concatenate([y0, *y0S]) + ydot0full = np.concatenate([ydot0, *ydot0S]) + else: + y0full = y0 + ydot0full = ydot0 try: atol = model.atol @@ -568,7 +852,10 @@ def _integrate(self, model, t_eval, inputs_dict=None): atol = self._check_atol_type(atol, y0.size) timer = pybamm.Timer() - if model.convert_to_format == "casadi": + if model.convert_to_format == "casadi" or ( + model.convert_to_format == "jax" + and self._options["jax_evaluator"] == "iree" + ): sol = self._setup["solver"].solve( t_eval, y0full, @@ -656,12 +943,27 @@ def _integrate(self, model, t_eval, inputs_dict=None): model.variables_and_events[var], pybamm.ExplicitTimeIntegral ): continue - len_of_var = ( - self._setup["var_casadi_fcns"][var](0, 0, 0).sparsity().nnz() - ) + if model.convert_to_format == "casadi": + len_of_var = ( + self._setup["var_fcns"][var](0.0, 0.0, 0.0).sparsity().nnz() + ) + base_variables = [self._setup["var_fcns"][var]] + elif ( + model.convert_to_format == "jax" + and self._options["jax_evaluator"] == "iree" + ): + idx = self.output_variables.index(var) + len_of_var = self._setup["var_idaklu_fcns"][idx].nnz + base_variables = [self._setup["var_idaklu_fcns"][idx]] + else: # pragma: no cover + raise pybamm.SolverError( + "Unsupported evaluation engine for convert_to_format=" + + f"{model.convert_to_format} " + + f"(jax_evaluator={self._options['jax_evaluator']})" + ) newsol._variables[var] = pybamm.ProcessedVariableComputed( [model.variables_and_events[var]], - [self._setup["var_casadi_fcns"][var]], + base_variables, [sol.y[:, startk : (startk + len_of_var)]], newsol, ) diff --git a/pybamm/solvers/processed_variable_computed.py b/pybamm/solvers/processed_variable_computed.py index a069342254..a717c8b0cb 100644 --- a/pybamm/solvers/processed_variable_computed.py +++ b/pybamm/solvers/processed_variable_computed.py @@ -120,16 +120,25 @@ def _unroll_nnz(self, realdata=None): # unroll in nnz != numel, otherwise copy if realdata is None: realdata = self.base_variables_data - sp = self.base_variables_casadi[0](0, 0, 0).sparsity() - if sp.nnz() != sp.numel(): + if isinstance(self.base_variables_casadi[0], casadi.Function): # casadi fcn + sp = self.base_variables_casadi[0](0, 0, 0).sparsity() + nnz = sp.nnz() + numel = sp.numel() + row = sp.row() + elif "nnz" in dir(self.base_variables_casadi[0]): # IREE fcn + sp = self.base_variables_casadi[0] + nnz = sp.nnz + numel = sp.numel + row = sp.row + if nnz != numel: data = [None] * len(realdata) for datak in range(len(realdata)): data[datak] = np.zeros(self.base_eval_shape[0] * len(self.t_pts)) var_data = realdata[0].flatten() k = 0 for t_i in range(len(self.t_pts)): - base = t_i * sp.numel() - for r in sp.row(): + base = t_i * numel + for r in row: data[datak][base + r] = var_data[k] k = k + 1 else: diff --git a/pybamm/version.py b/pybamm/version.py index bc0f2f5d12..4c1e268285 100644 --- a/pybamm/version.py +++ b/pybamm/version.py @@ -1 +1 @@ -__version__ = "24.5rc0" +__version__ = "24.5rc2" diff --git a/pyproject.toml b/pyproject.toml index 890f884769..2bc6f4f3d7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ build-backend = "setuptools.build_meta" [project] name = "pybamm" -version = "24.5rc0" +version = "24.5rc2" license = { file = "LICENSE.txt" } description = "Python Battery Mathematical Modelling" authors = [{name = "The PyBaMM Team", email = "pybamm@pybamm.org"}] @@ -116,12 +116,19 @@ dev = [ # To access the metadata for python packages "importlib-metadata; python_version < '3.10'", ] -# For the Jax solver. Note: these must be kept in sync with the versions defined in pybamm/util.py. +# For the Jax solver. +# Note: These must be kept in sync with the versions defined in pybamm/util.py, and +# must remain compatible with IREE (see noxfile.py for IREE compatibility). jax = [ "jax==0.4.27", "jaxlib==0.4.27", ] -# Contains all optional dependencies, except for jax and dev dependencies +# For MLIR expression evaluation (IDAKLU Solver) +iree = [ + # must be pip installed with --find-links=https://iree.dev/pip-release-links.html + "iree-compiler==20240507.886", # see IREE compatibility notes in noxfile.py +] +# Contains all optional dependencies, except for jax, iree, and dev dependencies all = [ "scikit-fem>=8.1.0", "pybamm[examples,plot,cite,bpx,tqdm]", @@ -193,6 +200,7 @@ extend-select = [ "UP", # pyupgrade "YTT", # flake8-2020 "TID252", # relative-imports + "S101", # to identify use of assert statement ] ignore = [ "E741", # Ambiguous variable name @@ -213,7 +221,7 @@ ignore = [ ] [tool.ruff.lint.per-file-ignores] -"tests/*" = ["T20"] +"tests/*" = ["T20", "S101"] "docs/*" = ["T20"] "examples/*" = ["T20"] "**.ipynb" = ["E402", "E703"] diff --git a/setup.py b/setup.py index 6b97f73058..21dabcebb2 100644 --- a/setup.py +++ b/setup.py @@ -92,10 +92,14 @@ def run(self): use_python_casadi = True build_type = os.getenv("PYBAMM_CPP_BUILD_TYPE", "RELEASE") + idaklu_expr_casadi = os.getenv("PYBAMM_IDAKLU_EXPR_CASADI", "ON") + idaklu_expr_iree = os.getenv("PYBAMM_IDAKLU_EXPR_IREE", "OFF") cmake_args = [ f"-DCMAKE_BUILD_TYPE={build_type}", f"-DPYTHON_EXECUTABLE={sys.executable}", "-DUSE_PYTHON_CASADI={}".format("TRUE" if use_python_casadi else "FALSE"), + f"-DPYBAMM_IDAKLU_EXPR_CASADI={idaklu_expr_casadi}", + f"-DPYBAMM_IDAKLU_EXPR_IREE={idaklu_expr_iree}", ] if self.suitesparse_root: cmake_args.append( @@ -291,27 +295,39 @@ def compile_KLU(): name="pybamm.solvers.idaklu", # The sources list should mirror the list in CMakeLists.txt sources=[ - "pybamm/solvers/c_solvers/idaklu/casadi_functions.cpp", - "pybamm/solvers/c_solvers/idaklu/casadi_functions.hpp", - "pybamm/solvers/c_solvers/idaklu/casadi_solver.cpp", - "pybamm/solvers/c_solvers/idaklu/casadi_solver.hpp", - "pybamm/solvers/c_solvers/idaklu/CasadiSolver.cpp", - "pybamm/solvers/c_solvers/idaklu/CasadiSolver.hpp", - "pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.cpp", - "pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.hpp", - "pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP_solvers.cpp", - "pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP_solvers.hpp", - "pybamm/solvers/c_solvers/idaklu/casadi_sundials_functions.cpp", - "pybamm/solvers/c_solvers/idaklu/casadi_sundials_functions.hpp", - "pybamm/solvers/c_solvers/idaklu/idaklu_jax.cpp", - "pybamm/solvers/c_solvers/idaklu/idaklu_jax.hpp", + "pybamm/solvers/c_solvers/idaklu/Expressions/Expressions.hpp", + "pybamm/solvers/c_solvers/idaklu/Expressions/Base/Expression.hpp", + "pybamm/solvers/c_solvers/idaklu/Expressions/Base/ExpressionSet.hpp", + "pybamm/solvers/c_solvers/idaklu/Expressions/Base/ExpressionTypes.hpp", + "pybamm/solvers/c_solvers/idaklu/Expressions/Base/ExpressionSparsity.hpp", + "pybamm/solvers/c_solvers/idaklu/Expressions/Casadi/CasadiFunctions.cpp", + "pybamm/solvers/c_solvers/idaklu/Expressions/Casadi/CasadiFunctions.hpp", + "pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEBaseFunction.hpp", + "pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEFunction.hpp", + "pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEFunctions.cpp", + "pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEFunctions.hpp", + "pybamm/solvers/c_solvers/idaklu/Expressions/IREE/iree_jit.cpp", + "pybamm/solvers/c_solvers/idaklu/Expressions/IREE/iree_jit.hpp", + "pybamm/solvers/c_solvers/idaklu/Expressions/IREE/ModuleParser.cpp", + "pybamm/solvers/c_solvers/idaklu/Expressions/IREE/ModuleParser.hpp", + "pybamm/solvers/c_solvers/idaklu/idaklu_solver.hpp", + "pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.cpp", + "pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.hpp", + "pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.inl", + "pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.hpp", + "pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP_solvers.cpp", + "pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP_solvers.hpp", + "pybamm/solvers/c_solvers/idaklu/sundials_functions.inl", + "pybamm/solvers/c_solvers/idaklu/sundials_functions.hpp", + "pybamm/solvers/c_solvers/idaklu/IdakluJax.cpp", + "pybamm/solvers/c_solvers/idaklu/IdakluJax.hpp", "pybamm/solvers/c_solvers/idaklu/common.hpp", "pybamm/solvers/c_solvers/idaklu/python.hpp", "pybamm/solvers/c_solvers/idaklu/python.cpp", - "pybamm/solvers/c_solvers/idaklu/solution.cpp", - "pybamm/solvers/c_solvers/idaklu/solution.hpp", - "pybamm/solvers/c_solvers/idaklu/options.hpp", - "pybamm/solvers/c_solvers/idaklu/options.cpp", + "pybamm/solvers/c_solvers/idaklu/Solution.cpp", + "pybamm/solvers/c_solvers/idaklu/Solution.hpp", + "pybamm/solvers/c_solvers/idaklu/Options.hpp", + "pybamm/solvers/c_solvers/idaklu/Options.cpp", "pybamm/solvers/c_solvers/idaklu.cpp", ], ) diff --git a/tests/integration/test_models/standard_model_tests.py b/tests/integration/test_models/standard_model_tests.py index 7f0e9e6137..3f9cb56354 100644 --- a/tests/integration/test_models/standard_model_tests.py +++ b/tests/integration/test_models/standard_model_tests.py @@ -3,10 +3,9 @@ # import pybamm import tests -import uuid +import tempfile import numpy as np -import os class StandardModelTest: @@ -141,9 +140,8 @@ def test_sensitivities(self, param_name, param_value, output_name="Voltage [V]") ) def test_serialisation(self, solver=None, t_eval=None): - # Generating unique file names to avoid race conditions when run in parallel. - unique_id = uuid.uuid4() - file_name = f"test_model_{unique_id}" + temp = tempfile.NamedTemporaryFile(prefix="test_model") + file_name = temp.name self.model.save_model( file_name, variables=self.model.variables, mesh=self.disc.mesh ) @@ -178,8 +176,7 @@ def test_serialisation(self, solver=None, t_eval=None): np.testing.assert_array_almost_equal( new_solution.all_ys[x], self.solution.all_ys[x], decimal=accuracy ) - - os.remove(file_name + ".json") + temp.close() def test_all( self, param=None, disc=None, solver=None, t_eval=None, skip_output_tests=False diff --git a/tests/unit/test_callbacks.py b/tests/unit/test_callbacks.py index b36fef9ec6..649c7d9ec8 100644 --- a/tests/unit/test_callbacks.py +++ b/tests/unit/test_callbacks.py @@ -81,6 +81,7 @@ def test_logging_callback(self): "cycle number": (5, 12), "step number": (1, 4), "elapsed time": 0.45, + "step duration": 1, "step operating conditions": "Charge", "termination": "event", } @@ -96,10 +97,14 @@ def test_logging_callback(self): with open("test_callback.log") as f: self.assertIn("Cycle 5/12, step 1/4", f.read()) - callback.on_experiment_infeasible(logs) + callback.on_experiment_infeasible_event(logs) with open("test_callback.log") as f: self.assertIn("Experiment is infeasible: 'event'", f.read()) + callback.on_experiment_infeasible_time(logs) + with open("test_callback.log") as f: + self.assertIn("Experiment is infeasible: default duration", f.read()) + callback.on_experiment_end(logs) with open("test_callback.log") as f: self.assertIn("took 0.45", f.read()) diff --git a/tests/unit/test_doc_utils.py b/tests/unit/test_doc_utils.py index a7a4a1e5d5..8e8a626535 100644 --- a/tests/unit/test_doc_utils.py +++ b/tests/unit/test_doc_utils.py @@ -3,14 +3,11 @@ # is generated, but rather that the docstrings are correctly modified # -import pybamm -import unittest -from tests import TestCase from inspect import getmro from pybamm.doc_utils import copy_parameter_doc_from_parent, doc_extend_parent -class TestDocUtils(TestCase): +class TestDocUtils: def test_copy_parameter_doc_from_parent(self): """Test if parameters from the parent class are copied to child class docstring""" @@ -38,7 +35,7 @@ def __init__(self, foo, bar): base_parameters = "".join(Base.__doc__.partition("Parameters")[1:]) derived_parameters = "".join(Derived.__doc__.partition("Parameters")[1:]) # check that the parameters section is in the docstring - self.assertMultiLineEqual(base_parameters, derived_parameters) + assert base_parameters == derived_parameters def test_doc_extend_parent(self): """Test if the child class has the Extends directive in its docstring""" @@ -57,21 +54,11 @@ def __init__(self, param): super().__init__(param) # check that the Extends directive is in the docstring - self.assertIn("**Extends:**", Derived.__doc__) + assert "**Extends:**" in Derived.__doc__ # check that the Extends directive maps to the correct base class base_cls_name = f"{getmro(Derived)[1].__module__}.{getmro(Derived)[1].__name__}" - self.assertEqual( - Derived.__doc__.partition("**Extends:**")[2].strip(), - f":class:`{base_cls_name}`", + assert ( + Derived.__doc__.partition("**Extends:**")[2].strip() + == f":class:`{base_cls_name}`" ) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_experiments/test_experiment_steps.py b/tests/unit/test_experiments/test_experiment_steps.py index 9d1abcc133..4bb686986f 100644 --- a/tests/unit/test_experiments/test_experiment_steps.py +++ b/tests/unit/test_experiments/test_experiment_steps.py @@ -43,15 +43,20 @@ def test_step(self): with self.assertRaisesRegex(ValueError, "temperature units"): step = pybamm.step.current(1, temperature="298T") + with self.assertRaisesRegex(ValueError, "time must be positive"): + pybamm.step.current(1, duration=0) + def test_specific_steps(self): current = pybamm.step.current(1) self.assertIsInstance(current, pybamm.step.Current) self.assertEqual(current.value, 1) self.assertEqual(str(current), repr(current)) + self.assertEqual(current.duration, 24 * 3600) c_rate = pybamm.step.c_rate(1) self.assertIsInstance(c_rate, pybamm.step.CRate) self.assertEqual(c_rate.value, 1) + self.assertEqual(c_rate.duration, 3600 * 2) voltage = pybamm.step.voltage(1) self.assertIsInstance(voltage, pybamm.step.Voltage) @@ -145,19 +150,19 @@ def test_step_string(self): { "type": "CRate", "value": -1, - "duration": 86400, + "duration": 7200, "termination": [pybamm.step.VoltageTermination(4.1)], }, { "value": 4.1, "type": "Voltage", - "duration": 86400, + "duration": 3600 * 24, "termination": [pybamm.step.CurrentTermination(0.05)], }, { "value": 3, "type": "Voltage", - "duration": 86400, + "duration": 3600 * 24, "termination": [pybamm.step.CrateTermination(0.02)], }, { diff --git a/tests/unit/test_experiments/test_simulation_with_experiment.py b/tests/unit/test_experiments/test_simulation_with_experiment.py index bfc5ad6dee..4b3fa3366a 100644 --- a/tests/unit/test_experiments/test_simulation_with_experiment.py +++ b/tests/unit/test_experiments/test_simulation_with_experiment.py @@ -10,6 +10,12 @@ from datetime import datetime +class ShortDurationCRate(pybamm.step.CRate): + def default_duration(self, value): + # Set a short default duration for testing early stopping due to infeasible time + return 1 + + class TestSimulationExperiment(TestCase): def test_set_up(self): experiment = pybamm.Experiment( @@ -272,6 +278,19 @@ def test_run_experiment_breaks_early_error(self): # Different callback - this is for coverage on the `Callback` class sol = sim.solve(callbacks=pybamm.callbacks.Callback()) + def test_run_experiment_infeasible_time(self): + experiment = pybamm.Experiment( + [ShortDurationCRate(1, termination="2.5V"), "Rest for 1 hour"] + ) + model = pybamm.lithium_ion.SPM() + parameter_values = pybamm.ParameterValues("Chen2020") + sim = pybamm.Simulation( + model, parameter_values=parameter_values, experiment=experiment + ) + sol = sim.solve() + self.assertEqual(len(sol.cycles), 1) + self.assertEqual(len(sol.cycles[0].steps), 1) + def test_run_experiment_termination_capacity(self): # with percent experiment = pybamm.Experiment( diff --git a/tests/unit/test_expression_tree/test_array.py b/tests/unit/test_expression_tree/test_array.py index b75c313f47..ffd29baa7e 100644 --- a/tests/unit/test_expression_tree/test_array.py +++ b/tests/unit/test_expression_tree/test_array.py @@ -1,20 +1,17 @@ # # Tests for the Array class # -from tests import TestCase -import unittest -import unittest.mock as mock + import numpy as np import sympy - import pybamm -class TestArray(TestCase): +class TestArray: def test_name(self): arr = pybamm.Array(np.array([1, 2, 3])) - self.assertEqual(arr.name, "Array of shape (3, 1)") + assert arr.name == "Array of shape (3, 1)" def test_list_entries(self): vect = pybamm.Array([1, 2, 3]) @@ -38,16 +35,14 @@ def test_meshgrid(self): np.testing.assert_array_equal(B, D.entries) def test_to_equation(self): - self.assertEqual( - pybamm.Array([1, 2]).to_equation(), sympy.Array([[1.0], [2.0]]) - ) + assert pybamm.Array([1, 2]).to_equation() == sympy.Array([[1.0], [2.0]]) - def test_to_from_json(self): + def test_to_from_json(self, mocker): arr = pybamm.Array(np.array([1, 2, 3])) json_dict = { "name": "Array of shape (3, 1)", - "id": mock.ANY, + "id": mocker.ANY, "domains": { "primary": [], "secondary": [], @@ -59,17 +54,7 @@ def test_to_from_json(self): # array to json conversion created_json = arr.to_json() - self.assertEqual(created_json, json_dict) + assert created_json == json_dict # json to array conversion - self.assertEqual(pybamm.Array._from_json(created_json), arr) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() + assert pybamm.Array._from_json(created_json) == arr diff --git a/tests/unit/test_expression_tree/test_d_dt.py b/tests/unit/test_expression_tree/test_d_dt.py index b5632f9f64..38d7e20e13 100644 --- a/tests/unit/test_expression_tree/test_d_dt.py +++ b/tests/unit/test_expression_tree/test_d_dt.py @@ -1,43 +1,42 @@ # # Tests for the Scalar class # -from tests import TestCase +import pytest import pybamm -import unittest import numpy as np -class TestDDT(TestCase): +class TestDDT: def test_time_derivative(self): a = pybamm.Scalar(5).diff(pybamm.t) - self.assertIsInstance(a, pybamm.Scalar) - self.assertEqual(a.value, 0) + assert isinstance(a, pybamm.Scalar) + assert a.value == 0 a = pybamm.t.diff(pybamm.t) - self.assertIsInstance(a, pybamm.Scalar) - self.assertEqual(a.value, 1) + assert isinstance(a, pybamm.Scalar) + assert a.value == 1 a = (pybamm.t**2).diff(pybamm.t) - self.assertEqual(a, (2 * pybamm.t**1 * 1)) - self.assertEqual(a.evaluate(t=1), 2) + assert a == (2 * pybamm.t**1 * 1) + assert a.evaluate(t=1) == 2 a = (2 + pybamm.t**2).diff(pybamm.t) - self.assertEqual(a.evaluate(t=1), 2) + assert a.evaluate(t=1) == 2 def test_time_derivative_of_variable(self): a = (pybamm.Variable("a")).diff(pybamm.t) - self.assertIsInstance(a, pybamm.VariableDot) - self.assertEqual(a.name, "a'") + assert isinstance(a, pybamm.VariableDot) + assert a.name == "a'" p = pybamm.Parameter("p") a = 1 + p * pybamm.Variable("a") diff_a = a.diff(pybamm.t) - self.assertIsInstance(diff_a, pybamm.Multiplication) - self.assertEqual(diff_a.children[0].name, "p") - self.assertEqual(diff_a.children[1].name, "a'") + assert isinstance(diff_a, pybamm.Multiplication) + assert diff_a.children[0].name == "p" + assert diff_a.children[1].name == "a'" - with self.assertRaises(pybamm.ModelError): + with pytest.raises(pybamm.ModelError): a = (pybamm.Variable("a")).diff(pybamm.t).diff(pybamm.t) def test_time_derivative_of_state_vector(self): @@ -45,21 +44,11 @@ def test_time_derivative_of_state_vector(self): y_dot = np.linspace(0, 2, 19) a = sv.diff(pybamm.t) - self.assertIsInstance(a, pybamm.StateVectorDot) - self.assertEqual(a.name[-1], "'") + assert isinstance(a, pybamm.StateVectorDot) + assert a.name[-1] == "'" np.testing.assert_array_equal( a.evaluate(y_dot=y_dot), np.linspace(0, 1, 10)[:, np.newaxis] ) - with self.assertRaises(pybamm.ModelError): + with pytest.raises(pybamm.ModelError): a = (sv).diff(pybamm.t).diff(pybamm.t) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_expression_tree/test_independent_variable.py b/tests/unit/test_expression_tree/test_independent_variable.py index f747b60d40..79c5ab9ea2 100644 --- a/tests/unit/test_expression_tree/test_independent_variable.py +++ b/tests/unit/test_expression_tree/test_independent_variable.py @@ -1,87 +1,73 @@ # # Tests for the Parameter class # -from tests import TestCase -import unittest - +import pytest import pybamm import sympy -class TestIndependentVariable(TestCase): +class TestIndependentVariable: def test_variable_init(self): a = pybamm.IndependentVariable("a") - self.assertEqual(a.name, "a") - self.assertEqual(a.domain, []) + assert a.name == "a" + assert a.domain == [] a = pybamm.IndependentVariable("a", domain=["test"]) - self.assertEqual(a.domain[0], "test") + assert a.domain[0] == "test" a = pybamm.IndependentVariable("a", domain="test") - self.assertEqual(a.domain[0], "test") - with self.assertRaises(TypeError): + assert a.domain[0] == "test" + with pytest.raises(TypeError): pybamm.IndependentVariable("a", domain=1) def test_time(self): t = pybamm.Time() - self.assertEqual(t.name, "time") - self.assertEqual(t.evaluate(4), 4) - with self.assertRaises(ValueError): + assert t.name == "time" + assert t.evaluate(4) == 4 + with pytest.raises(ValueError): t.evaluate(None) t = pybamm.t - self.assertEqual(t.name, "time") - self.assertEqual(t.evaluate(4), 4) - with self.assertRaises(ValueError): + assert t.name == "time" + assert t.evaluate(4) == 4 + with pytest.raises(ValueError): t.evaluate(None) - self.assertEqual(t.evaluate_for_shape(), 0) + assert t.evaluate_for_shape() == 0 def test_spatial_variable(self): x = pybamm.SpatialVariable("x", "negative electrode") - self.assertEqual(x.name, "x") - self.assertFalse(x.evaluates_on_edges("primary")) + assert x.name == "x" + assert not x.evaluates_on_edges("primary") y = pybamm.SpatialVariable("y", "separator") - self.assertEqual(y.name, "y") + assert y.name == "y" z = pybamm.SpatialVariable("z", "positive electrode") - self.assertEqual(z.name, "z") + assert z.name == "z" r = pybamm.SpatialVariable("r", "negative particle") - self.assertEqual(r.name, "r") - with self.assertRaises(NotImplementedError): + assert r.name == "r" + with pytest.raises(NotImplementedError): x.evaluate() - with self.assertRaisesRegex(ValueError, "domain must be"): + with pytest.raises(ValueError, match="domain must be"): pybamm.SpatialVariable("x", []) - with self.assertRaises(pybamm.DomainError): + with pytest.raises(pybamm.DomainError): pybamm.SpatialVariable("r_n", ["positive particle"]) - with self.assertRaises(pybamm.DomainError): + with pytest.raises(pybamm.DomainError): pybamm.SpatialVariable("r_p", ["negative particle"]) - with self.assertRaises(pybamm.DomainError): + with pytest.raises(pybamm.DomainError): pybamm.SpatialVariable("x", ["negative particle"]) def test_spatial_variable_edge(self): x = pybamm.SpatialVariableEdge("x", "negative electrode") - self.assertEqual(x.name, "x") - self.assertTrue(x.evaluates_on_edges("primary")) + assert x.name == "x" + assert x.evaluates_on_edges("primary") def test_to_equation(self): # Test print_name func = pybamm.IndependentVariable("a") func.print_name = "test" - self.assertEqual(func.to_equation(), sympy.Symbol("test")) + assert func.to_equation() == sympy.Symbol("test") - self.assertEqual( - pybamm.IndependentVariable("a").to_equation(), sympy.Symbol("a") - ) + assert pybamm.IndependentVariable("a").to_equation() == sympy.Symbol("a") # Test time - self.assertEqual(pybamm.t.to_equation(), sympy.Symbol("t")) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() + assert pybamm.t.to_equation() == sympy.Symbol("t") diff --git a/tests/unit/test_expression_tree/test_input_parameter.py b/tests/unit/test_expression_tree/test_input_parameter.py index a5fc79f2e2..87cbe79a31 100644 --- a/tests/unit/test_expression_tree/test_input_parameter.py +++ b/tests/unit/test_expression_tree/test_input_parameter.py @@ -1,54 +1,52 @@ # # Tests for the InputParameter class # -from tests import TestCase import numpy as np import pybamm -import unittest - +import pytest import unittest.mock as mock -class TestInputParameter(TestCase): +class TestInputParameter: def test_input_parameter_init(self): a = pybamm.InputParameter("a") - self.assertEqual(a.name, "a") - self.assertEqual(a.evaluate(inputs={"a": 1}), 1) - self.assertEqual(a.evaluate(inputs={"a": 5}), 5) + assert a.name == "a" + assert a.evaluate(inputs={"a": 1}) == 1 + assert a.evaluate(inputs={"a": 5}) == 5 a = pybamm.InputParameter("a", expected_size=10) - self.assertEqual(a._expected_size, 10) + assert a._expected_size == 10 np.testing.assert_array_equal( a.evaluate(inputs="shape test"), np.nan * np.ones((10, 1)) ) y = np.linspace(0, 1, 10) np.testing.assert_array_equal(a.evaluate(inputs={"a": y}), y[:, np.newaxis]) - with self.assertRaisesRegex( + with pytest.raises( ValueError, - "Input parameter 'a' was given an object of size '1' but was expecting an " + match="Input parameter 'a' was given an object of size '1' but was expecting an " "object of size '10'", ): a.evaluate(inputs={"a": 5}) def test_evaluate_for_shape(self): a = pybamm.InputParameter("a") - self.assertTrue(np.isnan(a.evaluate_for_shape())) - self.assertEqual(a.shape, ()) + assert np.isnan(a.evaluate_for_shape()) + assert a.shape == () a = pybamm.InputParameter("a", expected_size=10) - self.assertEqual(a.shape, (10, 1)) + assert a.shape == (10, 1) np.testing.assert_equal(a.evaluate_for_shape(), np.nan * np.ones((10, 1))) - self.assertEqual(a.evaluate_for_shape().shape, (10, 1)) + assert a.evaluate_for_shape().shape == (10, 1) def test_errors(self): a = pybamm.InputParameter("a") - with self.assertRaises(TypeError): + with pytest.raises(TypeError): a.evaluate(inputs="not a dictionary") - with self.assertRaises(KeyError): + with pytest.raises(KeyError): a.evaluate(inputs={"bad param": 5}) # if u is not provided it gets turned into a dictionary and then raises KeyError - with self.assertRaises(KeyError): + with pytest.raises(KeyError): a.evaluate() def test_to_from_json(self): @@ -62,17 +60,7 @@ def test_to_from_json(self): } # to_json - self.assertEqual(a.to_json(), json_dict) + assert a.to_json() == json_dict # from_json - self.assertEqual(pybamm.InputParameter._from_json(json_dict), a) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() + assert pybamm.InputParameter._from_json(json_dict) == a diff --git a/tests/unit/test_expression_tree/test_matrix.py b/tests/unit/test_expression_tree/test_matrix.py index 055902b15e..d34af0d83f 100644 --- a/tests/unit/test_expression_tree/test_matrix.py +++ b/tests/unit/test_expression_tree/test_matrix.py @@ -1,26 +1,22 @@ # # Tests for the Matrix class # -from tests import TestCase import pybamm import numpy as np from scipy.sparse import csr_matrix -import unittest -import unittest.mock as mock - -class TestMatrix(TestCase): - def setUp(self): +class TestMatrix: + def setup_method(self): self.A = np.array([[1, 2, 0], [0, 1, 0], [0, 0, 1]]) self.x = np.array([1, 2, 3]) self.mat = pybamm.Matrix(self.A) self.vect = pybamm.Vector(self.x) def test_array_wrapper(self): - self.assertEqual(self.mat.ndim, 2) - self.assertEqual(self.mat.shape, (3, 3)) - self.assertEqual(self.mat.size, 9) + assert self.mat.ndim == 2 + assert self.mat.shape == (3, 3) + assert self.mat.size == 9 def test_list_entry(self): mat = pybamm.Matrix([[1, 2, 0], [0, 1, 0], [0, 0, 1]]) @@ -40,11 +36,11 @@ def test_matrix_operations(self): (self.mat @ self.vect).evaluate(), np.array([[5], [2], [3]]) ) - def test_to_from_json(self): + def test_to_from_json(self, mocker): arr = pybamm.Matrix(csr_matrix([[0, 1, 0, 0], [0, 0, 0, 1]])) json_dict = { "name": "Sparse Matrix (2, 4)", - "id": mock.ANY, + "id": mocker.ANY, "domains": { "primary": [], "secondary": [], @@ -59,16 +55,6 @@ def test_to_from_json(self): }, } - self.assertEqual(arr.to_json(), json_dict) - - self.assertEqual(pybamm.Matrix._from_json(json_dict), arr) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys + assert arr.to_json() == json_dict - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() + assert pybamm.Matrix._from_json(json_dict) == arr diff --git a/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py b/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py index b02c75f386..6e1b155eca 100644 --- a/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py +++ b/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py @@ -80,7 +80,7 @@ def test_find_symbols(self): # test values of variable_symbols self.assertEqual(next(iter(variable_symbols.values())), "y[0:1]") self.assertEqual(list(variable_symbols.values())[1], "y[1:2]") - self.assertEqual(list(variable_symbols.values())[2], f"-{var_b}") + self.assertEqual(list(variable_symbols.values())[2], f"-({var_b})") var_child = pybamm.id_to_python_variable(expr.children[1].id) self.assertEqual( list(variable_symbols.values())[3], f"np.maximum({var_a},{var_child})" @@ -674,6 +674,76 @@ def test_evaluator_jax_inputs(self): result = evaluator(inputs={"a": 2}) self.assertEqual(result, 4) + @unittest.skipIf(not pybamm.have_jax(), "jax or jaxlib is not installed") + def test_evaluator_jax_demotion(self): + for demote in [True, False]: + pybamm.demote_expressions_to_32bit = demote # global flag + target_dtype = "32" if demote else "64" + if demote: + # Test only works after conversion to jax.numpy + for c in [ + 1.0, + 1, + ]: + self.assertEqual( + str(pybamm.EvaluatorJax._demote_64_to_32(c).dtype)[-2:], + target_dtype, + ) + for c in [ + np.float64(1.0), + np.int64(1), + np.array([1.0], dtype=np.float64), + np.array([1], dtype=np.int64), + jax.numpy.array([1.0], dtype=np.float64), + jax.numpy.array([1], dtype=np.int64), + ]: + self.assertEqual( + str(pybamm.EvaluatorJax._demote_64_to_32(c).dtype)[-2:], + target_dtype, + ) + for c in [ + {key: np.float64(1.0) for key in ["a", "b"]}, + ]: + expr_demoted = pybamm.EvaluatorJax._demote_64_to_32(c) + self.assertTrue( + all( + str(c_v.dtype)[-2:] == target_dtype + for c_k, c_v in expr_demoted.items() + ) + ) + for c in [ + (np.float64(1.0), np.float64(2.0)), + [np.float64(1.0), np.float64(2.0)], + ]: + expr_demoted = pybamm.EvaluatorJax._demote_64_to_32(c) + self.assertTrue( + all(str(c_i.dtype)[-2:] == target_dtype for c_i in expr_demoted) + ) + for dtype in [ + np.float64, + jax.numpy.float64, + ]: + c = pybamm.JaxCooMatrix([0, 1], [0, 1], dtype([1.0, 2.0]), (2, 2)) + c_demoted = pybamm.EvaluatorJax._demote_64_to_32(c) + self.assertTrue( + all(str(c_i.dtype)[-2:] == target_dtype for c_i in c_demoted.data) + ) + for dtype in [ + np.int64, + jax.numpy.int64, + ]: + c = pybamm.JaxCooMatrix( + dtype([0, 1]), dtype([0, 1]), [1.0, 2.0], (2, 2) + ) + c_demoted = pybamm.EvaluatorJax._demote_64_to_32(c) + self.assertTrue( + all(str(c_i.dtype)[-2:] == target_dtype for c_i in c_demoted.row) + ) + self.assertTrue( + all(str(c_i.dtype)[-2:] == target_dtype for c_i in c_demoted.col) + ) + pybamm.demote_expressions_to_32bit = False + @unittest.skipIf(not pybamm.have_jax(), "jax or jaxlib is not installed") def test_jax_coo_matrix(self): A = pybamm.JaxCooMatrix([0, 1], [0, 1], [1.0, 2.0], (2, 2)) diff --git a/tests/unit/test_expression_tree/test_operations/test_unpack_symbols.py b/tests/unit/test_expression_tree/test_operations/test_unpack_symbols.py index ce669212ae..6736094288 100644 --- a/tests/unit/test_expression_tree/test_operations/test_unpack_symbols.py +++ b/tests/unit/test_expression_tree/test_operations/test_unpack_symbols.py @@ -1,27 +1,25 @@ # # Tests for the symbol unpacker # -from tests import TestCase import pybamm -import unittest -class TestSymbolUnpacker(TestCase): +class TestSymbolUnpacker: def test_basic_symbols(self): a = pybamm.Scalar(1) unpacker = pybamm.SymbolUnpacker(pybamm.Scalar) unpacked = unpacker.unpack_symbol(a) - self.assertEqual(unpacked, set([a])) + assert unpacked == set([a]) b = pybamm.Parameter("b") unpacker_param = pybamm.SymbolUnpacker(pybamm.Parameter) unpacked = unpacker_param.unpack_symbol(a) - self.assertEqual(unpacked, set()) + assert unpacked == set() unpacked = unpacker_param.unpack_symbol(b) - self.assertEqual(unpacked, set([b])) + assert unpacked == set([b]) def test_binary(self): a = pybamm.Scalar(1) @@ -29,11 +27,11 @@ def test_binary(self): unpacker = pybamm.SymbolUnpacker(pybamm.Scalar) unpacked = unpacker.unpack_symbol(a + b) - self.assertEqual(unpacked, set([a])) + assert unpacked == set([a]) unpacker_param = pybamm.SymbolUnpacker(pybamm.Parameter) unpacked = unpacker_param.unpack_symbol(a + b) - self.assertEqual(unpacked, set([b])) + assert unpacked == set([b]) def test_unpack_list_of_symbols(self): a = pybamm.Scalar(1) @@ -42,14 +40,4 @@ def test_unpack_list_of_symbols(self): unpacker = pybamm.SymbolUnpacker(pybamm.Parameter) unpacked = unpacker.unpack_list_of_symbols([a + b, a - c, b + c]) - self.assertEqual(unpacked, set([b, c])) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() + assert unpacked == set([b, c]) diff --git a/tests/unit/test_expression_tree/test_printing/test_print_name.py b/tests/unit/test_expression_tree/test_printing/test_print_name.py index 9d74d6f1ab..c15ce18616 100644 --- a/tests/unit/test_expression_tree/test_printing/test_print_name.py +++ b/tests/unit/test_expression_tree/test_printing/test_print_name.py @@ -2,57 +2,42 @@ Tests for the print_name.py """ -from tests import TestCase -import unittest - import pybamm -class TestPrintName(TestCase): +class TestPrintName: def test_prettify_print_name(self): param = pybamm.LithiumIonParameters() param2 = pybamm.LeadAcidParameters() # Test PRINT_NAME_OVERRIDES - self.assertEqual(param.current_with_time.print_name, "I") + assert param.current_with_time.print_name == "I" # Test superscripts - self.assertEqual( - param.n.prim.c_init.print_name, r"c_{\mathrm{n}}^{\mathrm{init}}" - ) + assert param.n.prim.c_init.print_name == r"c_{\mathrm{n}}^{\mathrm{init}}" # Test subscripts - self.assertEqual(param.n.C_dl(0).print_name, r"C_{\mathrm{dl,n}}") + assert param.n.C_dl(0).print_name == r"C_{\mathrm{dl,n}}" # Test bar c_e_av = pybamm.Variable("c_e_av") c_e_av.print_name = "c_e_av" - self.assertEqual(c_e_av.print_name, r"\overline{c}_{\mathrm{e}}") + assert c_e_av.print_name == r"\overline{c}_{\mathrm{e}}" # Test greek letters - self.assertEqual(param2.delta.print_name, r"\delta") + assert param2.delta.print_name == r"\delta" # Test create_copy() a_n = param2.n.prim.a - self.assertEqual(a_n.create_copy().print_name, r"a_{\mathrm{n}}") + assert a_n.create_copy().print_name == r"a_{\mathrm{n}}" # Test eps eps_n = pybamm.Variable("eps_n") - self.assertEqual(eps_n.print_name, r"\epsilon_{\mathrm{n}}") + assert eps_n.print_name == r"\epsilon_{\mathrm{n}}" eps_n = pybamm.Variable("eps_c_e_n") - self.assertEqual(eps_n.print_name, r"(\epsilon c)_{\mathrm{e,n}}") + assert eps_n.print_name == r"(\epsilon c)_{\mathrm{e,n}}" # tplus t_plus = pybamm.Variable("t_plus") - self.assertEqual(t_plus.print_name, r"t_{\mathrm{+}}") - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() + assert t_plus.print_name == r"t_{\mathrm{+}}" diff --git a/tests/unit/test_expression_tree/test_printing/test_sympy_overrides.py b/tests/unit/test_expression_tree/test_printing/test_sympy_overrides.py index 4b19c7d822..4ce073af4b 100644 --- a/tests/unit/test_expression_tree/test_printing/test_sympy_overrides.py +++ b/tests/unit/test_expression_tree/test_printing/test_sympy_overrides.py @@ -2,31 +2,17 @@ Tests for the sympy_overrides.py """ -from tests import TestCase -import unittest - -import pybamm from pybamm.expression_tree.printing.sympy_overrides import custom_print_func import sympy -class TestCustomPrint(TestCase): - def test_print_Derivative(self): +class TestCustomPrint: + def test_print_derivative(self): # Test force_partial der1 = sympy.Derivative("y", "x") der1.force_partial = True - self.assertEqual(custom_print_func(der1), "\\frac{\\partial}{\\partial x} y") + assert custom_print_func(der1) == "\\frac{\\partial}{\\partial x} y" # Test derivative der2 = sympy.Derivative("x") - self.assertEqual(custom_print_func(der2), "\\frac{d}{d x} x") - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() + assert custom_print_func(der2) == "\\frac{d}{d x} x" diff --git a/tests/unit/test_expression_tree/test_scalar.py b/tests/unit/test_expression_tree/test_scalar.py index 34ea1aa514..986d3d3ccb 100644 --- a/tests/unit/test_expression_tree/test_scalar.py +++ b/tests/unit/test_expression_tree/test_scalar.py @@ -1,64 +1,51 @@ # # Tests for the Scalar class # -from tests import TestCase -import unittest -import unittest.mock as mock import pybamm -class TestScalar(TestCase): +class TestScalar: def test_scalar_eval(self): a = pybamm.Scalar(5) - self.assertEqual(a.value, 5) - self.assertEqual(a.evaluate(), 5) + assert a.value == 5 + assert a.evaluate() == 5 def test_scalar_operations(self): a = pybamm.Scalar(5) b = pybamm.Scalar(6) - self.assertEqual((a + b).evaluate(), 11) - self.assertEqual((a - b).evaluate(), -1) - self.assertEqual((a * b).evaluate(), 30) - self.assertEqual((a / b).evaluate(), 5 / 6) + assert (a + b).evaluate() == 11 + assert (a - b).evaluate() == -1 + assert (a * b).evaluate() == 30 + assert (a / b).evaluate() == 5 / 6 def test_scalar_eq(self): a1 = pybamm.Scalar(4) a2 = pybamm.Scalar(4) - self.assertEqual(a1, a2) + assert a1 == a2 a3 = pybamm.Scalar(5) - self.assertNotEqual(a1, a3) + assert a1 != a3 def test_to_equation(self): a = pybamm.Scalar(3) b = pybamm.Scalar(4) # Test value - self.assertEqual(str(a.to_equation()), "3.0") + assert str(a.to_equation()) == "3.0" # Test print_name b.print_name = "test" - self.assertEqual(str(b.to_equation()), "test") + assert str(b.to_equation()) == "test" def test_copy(self): a = pybamm.Scalar(5) b = a.create_copy() - self.assertEqual(a, b) + assert a == b - def test_to_from_json(self): + def test_to_from_json(self, mocker): a = pybamm.Scalar(5) - json_dict = {"name": "5.0", "id": mock.ANY, "value": 5.0} + json_dict = {"name": "5.0", "id": mocker.ANY, "value": 5.0} - self.assertEqual(a.to_json(), json_dict) + assert a.to_json() == json_dict - self.assertEqual(pybamm.Scalar._from_json(json_dict), a) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() + assert pybamm.Scalar._from_json(json_dict) == a diff --git a/tests/unit/test_expression_tree/test_symbol.py b/tests/unit/test_expression_tree/test_symbol.py index 668c076907..e42f8dc8ef 100644 --- a/tests/unit/test_expression_tree/test_symbol.py +++ b/tests/unit/test_expression_tree/test_symbol.py @@ -18,6 +18,8 @@ class TestSymbol(TestCase): def test_symbol_init(self): sym = pybamm.Symbol("a symbol") + with self.assertRaises(TypeError): + sym.name = 1 self.assertEqual(sym.name, "a symbol") self.assertEqual(str(sym), "a symbol") diff --git a/tests/unit/test_expression_tree/test_vector.py b/tests/unit/test_expression_tree/test_vector.py index 34f817cf9c..e7b902fc73 100644 --- a/tests/unit/test_expression_tree/test_vector.py +++ b/tests/unit/test_expression_tree/test_vector.py @@ -1,22 +1,21 @@ # # Tests for the Vector class # -from tests import TestCase import pybamm import numpy as np -import unittest +import pytest -class TestVector(TestCase): - def setUp(self): +class TestVector: + def setup_method(self): self.x = np.array([[1], [2], [3]]) self.vect = pybamm.Vector(self.x) def test_array_wrapper(self): - self.assertEqual(self.vect.ndim, 2) - self.assertEqual(self.vect.shape, (3, 1)) - self.assertEqual(self.vect.size, 3) + assert self.vect.ndim == 2 + assert self.vect.shape == (3, 1) + assert self.vect.size == 3 def test_column_reshape(self): vect1d = pybamm.Vector(np.array([1, 2, 3])) @@ -39,17 +38,7 @@ def test_vector_operations(self): ) def test_wrong_size_entries(self): - with self.assertRaisesRegex( - ValueError, "Entries must have 1 dimension or be column vector" + with pytest.raises( + ValueError, match="Entries must have 1 dimension or be column vector" ): pybamm.Vector(np.ones((4, 5))) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_logger.py b/tests/unit/test_logger.py index 0897bc5835..06e2444c16 100644 --- a/tests/unit/test_logger.py +++ b/tests/unit/test_logger.py @@ -1,28 +1,27 @@ # # Tests the logger class. # -from tests import TestCase +import pytest import pybamm -import unittest -class TestLogger(TestCase): +class TestLogger: def test_logger(self): logger = pybamm.logger - self.assertEqual(logger.level, 30) + assert logger.level == 30 pybamm.set_logging_level("INFO") - self.assertEqual(logger.level, 20) + assert logger.level == 20 pybamm.set_logging_level("ERROR") - self.assertEqual(logger.level, 40) + assert logger.level == 40 pybamm.set_logging_level("VERBOSE") - self.assertEqual(logger.level, 15) + assert logger.level == 15 pybamm.set_logging_level("NOTICE") - self.assertEqual(logger.level, 25) + assert logger.level == 25 pybamm.set_logging_level("SUCCESS") - self.assertEqual(logger.level, 35) + assert logger.level == 35 pybamm.set_logging_level("SPAM") - self.assertEqual(logger.level, 5) + assert logger.level == 5 pybamm.logger.spam("Test spam level") pybamm.logger.verbose("Test verbose level") pybamm.logger.notice("Test notice level") @@ -32,15 +31,5 @@ def test_logger(self): pybamm.set_logging_level("WARNING") def test_exceptions(self): - with self.assertRaises(ValueError): + with pytest.raises(ValueError): pybamm.get_new_logger("test", None) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_meshes/test_zero_dimensional_submesh.py b/tests/unit/test_meshes/test_zero_dimensional_submesh.py index 8bc1bc2e75..d9e3ebb5dd 100644 --- a/tests/unit/test_meshes/test_zero_dimensional_submesh.py +++ b/tests/unit/test_meshes/test_zero_dimensional_submesh.py @@ -1,12 +1,11 @@ import pybamm -import unittest -from tests import TestCase +import pytest -class TestSubMesh0D(TestCase): +class TestSubMesh0D: def test_exceptions(self): position = {"x": {"position": 0}, "y": {"position": 0}} - with self.assertRaises(pybamm.GeometryError): + with pytest.raises(pybamm.GeometryError): pybamm.SubMesh0D(position) def test_init(self): @@ -14,13 +13,3 @@ def test_init(self): generator = pybamm.SubMesh0D mesh = generator(position, None) mesh.add_ghost_meshes() - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_models/test_event.py b/tests/unit/test_models/test_event.py index 84b0dcde84..0636a0f5bd 100644 --- a/tests/unit/test_models/test_event.py +++ b/tests/unit/test_models/test_event.py @@ -1,27 +1,25 @@ # # Tests Event class # -from tests import TestCase import pybamm import numpy as np -import unittest -class TestEvent(TestCase): +class TestEvent: def test_event(self): expression = pybamm.Scalar(1) event = pybamm.Event("my event", expression) - self.assertEqual(event.name, "my event") - self.assertEqual(event.__str__(), "my event") - self.assertEqual(event.expression, expression) - self.assertEqual(event.event_type, pybamm.EventType.TERMINATION) + assert event.name == "my event" + assert event.__str__() == "my event" + assert event.expression == expression + assert event.event_type == pybamm.EventType.TERMINATION def test_expression_evaluate(self): # Test t expression = pybamm.t event = pybamm.Event("my event", expression) - self.assertEqual(event.evaluate(t=1), 1) + assert event.evaluate(t=1) == 1 # Test y sv = pybamm.StateVector(slice(0, 10)) @@ -46,7 +44,7 @@ def test_event_types(self): for event_type in event_types: event = pybamm.Event("my event", pybamm.Scalar(1), event_type) - self.assertEqual(event.event_type, event_type) + assert event.event_type == event_type def test_to_from_json(self): expression = pybamm.Scalar(1) @@ -58,24 +56,14 @@ def test_to_from_json(self): } event_ser_json = event.to_json() - self.assertEqual(event_ser_json, event_json) + assert event_ser_json == event_json event_json["expression"] = expression new_event = pybamm.Event._from_json(event_json) # check for equal expressions - self.assertEqual(new_event.expression, event.expression) + assert new_event.expression == event.expression # check for equal event types - self.assertEqual(new_event.event_type, event.event_type) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() + assert new_event.event_type == event.event_type diff --git a/tests/unit/test_models/test_full_battery_models/test_lead_acid/test_base_lead_acid_model.py b/tests/unit/test_models/test_full_battery_models/test_lead_acid/test_base_lead_acid_model.py index ec280cdd1f..5d9ea27e2f 100644 --- a/tests/unit/test_models/test_full_battery_models/test_lead_acid/test_base_lead_acid_model.py +++ b/tests/unit/test_models/test_full_battery_models/test_lead_acid/test_base_lead_acid_model.py @@ -1,46 +1,33 @@ # # Tests for the base lead acid model class # -from tests import TestCase import pybamm -import unittest +import pytest -class TestBaseLeadAcidModel(TestCase): +class TestBaseLeadAcidModel: def test_default_geometry(self): model = pybamm.lead_acid.BaseModel({"dimensionality": 0}) - self.assertEqual( - model.default_geometry["current collector"]["z"]["position"], 1 - ) + assert model.default_geometry["current collector"]["z"]["position"] == 1 model = pybamm.lead_acid.BaseModel({"dimensionality": 1}) - self.assertEqual(model.default_geometry["current collector"]["z"]["min"], 0) + assert model.default_geometry["current collector"]["z"]["min"] == 0 model = pybamm.lead_acid.BaseModel({"dimensionality": 2}) - self.assertEqual(model.default_geometry["current collector"]["y"]["min"], 0) + assert model.default_geometry["current collector"]["y"]["min"] == 0 def test_incompatible_options(self): - with self.assertRaisesRegex( + with pytest.raises( pybamm.OptionError, - "Lead-acid models can only have thermal effects if dimensionality is 0.", + match="Lead-acid models can only have thermal effects if dimensionality is 0.", ): pybamm.lead_acid.BaseModel({"dimensionality": 1, "thermal": "lumped"}) - with self.assertRaisesRegex(pybamm.OptionError, "SEI"): + with pytest.raises(pybamm.OptionError, match="SEI"): pybamm.lead_acid.BaseModel({"SEI": "constant"}) - with self.assertRaisesRegex(pybamm.OptionError, "lithium plating"): + with pytest.raises(pybamm.OptionError, match="lithium plating"): pybamm.lead_acid.BaseModel({"lithium plating": "reversible"}) - with self.assertRaisesRegex(pybamm.OptionError, "MSMR"): + with pytest.raises(pybamm.OptionError, match="MSMR"): pybamm.lead_acid.BaseModel( { "open-circuit potential": "MSMR", "particle": "MSMR", } ) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_models/test_full_battery_models/test_lead_acid/test_basic_models.py b/tests/unit/test_models/test_full_battery_models/test_lead_acid/test_basic_models.py index 65b9f6bc9f..a7a708b394 100644 --- a/tests/unit/test_models/test_full_battery_models/test_lead_acid/test_basic_models.py +++ b/tests/unit/test_models/test_full_battery_models/test_lead_acid/test_basic_models.py @@ -1,22 +1,10 @@ # # Tests for the basic lead acid models # -from tests import TestCase import pybamm -import unittest -class TestBasicModels(TestCase): +class TestBasicModels: def test_basic_full_lead_acid_well_posed(self): model = pybamm.lead_acid.BasicFull() model.check_well_posedness() - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_models/test_full_battery_models/test_lead_acid/test_full.py b/tests/unit/test_models/test_full_battery_models/test_lead_acid/test_full.py index c07c5c84c6..569851ec2a 100644 --- a/tests/unit/test_models/test_full_battery_models/test_lead_acid/test_full.py +++ b/tests/unit/test_models/test_full_battery_models/test_lead_acid/test_full.py @@ -1,12 +1,10 @@ # # Tests for the lead-acid Full model # -from tests import TestCase import pybamm -import unittest -class TestLeadAcidFull(TestCase): +class TestLeadAcidFull: def test_well_posed(self): model = pybamm.lead_acid.Full() model.check_well_posedness() @@ -21,7 +19,7 @@ def test_well_posed_with_convection(self): model.check_well_posedness() -class TestLeadAcidFullSurfaceForm(TestCase): +class TestLeadAcidFullSurfaceForm: def test_well_posed_differential(self): options = {"surface form": "differential"} model = pybamm.lead_acid.Full(options) @@ -38,7 +36,7 @@ def test_well_posed_algebraic(self): model.check_well_posedness() -class TestLeadAcidFullSideReactions(TestCase): +class TestLeadAcidFullSideReactions: def test_well_posed(self): options = {"hydrolysis": "true"} model = pybamm.lead_acid.Full(options) @@ -48,20 +46,10 @@ def test_well_posed_surface_form_differential(self): options = {"hydrolysis": "true", "surface form": "differential"} model = pybamm.lead_acid.Full(options) model.check_well_posedness() - self.assertIsInstance(model.default_solver, pybamm.CasadiSolver) + assert isinstance(model.default_solver, pybamm.CasadiSolver) def test_well_posed_surface_form_algebraic(self): options = {"hydrolysis": "true", "surface form": "algebraic"} model = pybamm.lead_acid.Full(options) model.check_well_posedness() - self.assertIsInstance(model.default_solver, pybamm.CasadiSolver) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() + assert isinstance(model.default_solver, pybamm.CasadiSolver) diff --git a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/base_lithium_ion_tests.py b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/base_lithium_ion_tests.py index 7e1f2d5cac..c8a3f6b509 100644 --- a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/base_lithium_ion_tests.py +++ b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/base_lithium_ion_tests.py @@ -559,3 +559,28 @@ def test_well_posed_composite_diffusion_hysteresis(self): "open-circuit potential": (("current sigmoid", "single"), "single"), } self.check_well_posedness(options) + + def test_well_posed_composite_different_degradation(self): + # phases have same degradation + options = { + "particle phases": ("2", "1"), + "SEI": ("ec reaction limited", "none"), + "lithium plating": ("reversible", "none"), + "open-circuit potential": (("current sigmoid", "single"), "single"), + } + self.check_well_posedness(options) + # phases have different degradation + options = { + "particle phases": ("2", "1"), + "SEI": (("ec reaction limited", "solvent-diffusion limited"), "none"), + "lithium plating": (("reversible", "irreversible"), "none"), + "open-circuit potential": (("current sigmoid", "single"), "single"), + } + self.check_well_posedness(options) + # one of the phases has no degradation + options = { + "particle phases": ("2", "1"), + "SEI": (("none", "solvent-diffusion limited"), "none"), + "lithium plating": (("none", "irreversible"), "none"), + } + self.check_well_posedness(options) diff --git a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_Yang2017.py b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_Yang2017.py index 9631cf9f82..2fd18c17c6 100644 --- a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_Yang2017.py +++ b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_Yang2017.py @@ -1,22 +1,10 @@ # # Tests for the lithium-ion DFN model # -from tests import TestCase import pybamm -import unittest -class TestYang2017(TestCase): +class TestYang2017: def test_well_posed(self): model = pybamm.lithium_ion.Yang2017() model.check_well_posedness() - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_base_lithium_ion_model.py b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_base_lithium_ion_model.py index fbc916d4a5..bfeb489661 100644 --- a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_base_lithium_ion_model.py +++ b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_base_lithium_ion_model.py @@ -1,59 +1,44 @@ # # Tests for the base lead acid model class # -from tests import TestCase import pybamm -import unittest import os +import pytest -class TestBaseLithiumIonModel(TestCase): +class TestBaseLithiumIonModel: def test_incompatible_options(self): - with self.assertRaisesRegex(pybamm.OptionError, "convection not implemented"): + with pytest.raises(pybamm.OptionError, match="convection not implemented"): pybamm.lithium_ion.BaseModel({"convection": "uniform transverse"}) def test_default_parameters(self): # check parameters are read in ok model = pybamm.lithium_ion.BaseModel() - self.assertEqual( - model.default_parameter_values["Reference temperature [K]"], 298.15 - ) + assert model.default_parameter_values["Reference temperature [K]"] == 298.15 # change path and try again cwd = os.getcwd() os.chdir("..") model = pybamm.lithium_ion.BaseModel() - self.assertEqual( - model.default_parameter_values["Reference temperature [K]"], 298.15 - ) + assert model.default_parameter_values["Reference temperature [K]"] == 298.15 os.chdir(cwd) def test_insert_reference_electrode(self): model = pybamm.lithium_ion.SPM() model.insert_reference_electrode() - self.assertIn("Negative electrode 3E potential [V]", model.variables) - self.assertIn("Positive electrode 3E potential [V]", model.variables) - self.assertIn("Reference electrode potential [V]", model.variables) + assert "Negative electrode 3E potential [V]" in model.variables + assert "Positive electrode 3E potential [V]" in model.variables + assert "Reference electrode potential [V]" in model.variables model = pybamm.lithium_ion.SPM({"working electrode": "positive"}) model.insert_reference_electrode() - self.assertNotIn("Negative electrode potential [V]", model.variables) - self.assertIn("Positive electrode 3E potential [V]", model.variables) - self.assertIn("Reference electrode potential [V]", model.variables) + assert "Negative electrode potential [V]" not in model.variables + assert "Positive electrode 3E potential [V]" in model.variables + assert "Reference electrode potential [V]" in model.variables model = pybamm.lithium_ion.SPM({"dimensionality": 2}) - with self.assertRaisesRegex( - NotImplementedError, "Reference electrode can only be" + with pytest.raises( + NotImplementedError, match="Reference electrode can only be" ): model.insert_reference_electrode() - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_basic_models.py b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_basic_models.py index 2f00bb260c..8462e7c803 100644 --- a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_basic_models.py +++ b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_basic_models.py @@ -1,12 +1,10 @@ # # Tests for the basic lithium-ion models # -from tests import TestCase import pybamm -import unittest -class TestBasicModels(TestCase): +class TestBasicModels: def test_dfn_well_posed(self): model = pybamm.lithium_ion.BasicDFN() model.check_well_posedness() @@ -23,13 +21,3 @@ def test_dfn_half_cell_well_posed(self): def test_dfn_composite_well_posed(self): model = pybamm.lithium_ion.BasicDFNComposite() model.check_well_posedness() - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_dfn.py b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_dfn.py index 20fc69e541..cddd59c352 100644 --- a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_dfn.py +++ b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_dfn.py @@ -1,19 +1,19 @@ # # Tests for the lithium-ion DFN model # -from tests import TestCase import pybamm -import unittest +import pytest from tests import BaseUnitTestLithiumIon -class TestDFN(BaseUnitTestLithiumIon, TestCase): +class TestDFN(BaseUnitTestLithiumIon): + @pytest.fixture(autouse=True) def setUp(self): self.model = pybamm.lithium_ion.DFN def test_electrolyte_options(self): options = {"electrolyte conductivity": "integrated"} - with self.assertRaisesRegex(pybamm.OptionError, "electrolyte conductivity"): + with pytest.raises(pybamm.OptionError, match="electrolyte conductivity"): pybamm.lithium_ion.DFN(options) def test_well_posed_size_distribution(self): @@ -66,13 +66,3 @@ def test_well_posed_msmr_with_psd(self): "intercalation kinetics": "MSMR", } self.check_well_posedness(options) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_dfn_half_cell.py b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_dfn_half_cell.py index 78d9ebda94..389fcf9429 100644 --- a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_dfn_half_cell.py +++ b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_dfn_half_cell.py @@ -1,22 +1,13 @@ # # Tests for the lithium-ion half-cell DFN model # -from tests import TestCase + import pybamm -import unittest from tests import BaseUnitTestLithiumIonHalfCell +import pytest -class TestDFNHalfCell(BaseUnitTestLithiumIonHalfCell, TestCase): +class TestDFNHalfCell(BaseUnitTestLithiumIonHalfCell): + @pytest.fixture(autouse=True) def setUp(self): self.model = pybamm.lithium_ion.DFN - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_mpm_half_cell.py b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_mpm_half_cell.py index 77d51f6cf7..e5637968c3 100644 --- a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_mpm_half_cell.py +++ b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_mpm_half_cell.py @@ -1,12 +1,10 @@ # # Tests for the lithium-ion MPM model # -from tests import TestCase import pybamm -import unittest -class TestMPM(TestCase): +class TestMPM: def test_well_posed(self): options = {"thermal": "isothermal", "working electrode": "positive"} model = pybamm.lithium_ion.MPM(options) @@ -20,9 +18,9 @@ def test_well_posed(self): def test_default_parameter_values(self): # check default parameters are added correctly model = pybamm.lithium_ion.MPM({"working electrode": "positive"}) - self.assertEqual( - model.default_parameter_values["Positive minimum particle radius [m]"], - 0.0, + assert ( + model.default_parameter_values["Positive minimum particle radius [m]"] + == 0.0 ) def test_lumped_thermal_model_1D(self): @@ -44,7 +42,7 @@ def test_differential_surface_form(self): model.check_well_posedness() -class TestMPMExternalCircuits(TestCase): +class TestMPMExternalCircuits: def test_well_posed_voltage(self): options = {"operating mode": "voltage", "working electrode": "positive"} model = pybamm.lithium_ion.MPM(options) @@ -67,13 +65,3 @@ def external_circuit_function(variables): } model = pybamm.lithium_ion.MPM(options) model.check_well_posedness() - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_msmr.py b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_msmr.py index 96369fbac2..4f1958d095 100644 --- a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_msmr.py +++ b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_msmr.py @@ -1,22 +1,10 @@ # # Tests for the lithium-ion MSMR model # -from tests import TestCase import pybamm -import unittest -class TestMSMR(TestCase): +class TestMSMR: def test_well_posed(self): model = pybamm.lithium_ion.MSMR({"number of MSMR reactions": ("6", "4")}) model.check_well_posedness() - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_newman_tobias.py b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_newman_tobias.py index 5369d94b29..c979474e13 100644 --- a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_newman_tobias.py +++ b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_newman_tobias.py @@ -1,42 +1,41 @@ # # Tests for the lithium-ion Newman-Tobias model # -from tests import TestCase import pybamm -import unittest +import pytest from tests import BaseUnitTestLithiumIon -class TestNewmanTobias(BaseUnitTestLithiumIon, TestCase): +class TestNewmanTobias(BaseUnitTestLithiumIon): + @pytest.fixture(autouse=True) def setUp(self): self.model = pybamm.lithium_ion.NewmanTobias def test_electrolyte_options(self): options = {"electrolyte conductivity": "integrated"} - with self.assertRaisesRegex(pybamm.OptionError, "electrolyte conductivity"): + with pytest.raises(pybamm.OptionError, match="electrolyte conductivity"): pybamm.lithium_ion.NewmanTobias(options) + @pytest.mark.skip(reason="Test currently not implemented") def test_well_posed_particle_phases(self): pass # skip this test + @pytest.mark.skip(reason="Test currently not implemented") def test_well_posed_particle_phases_thermal(self): pass # Skip this test + @pytest.mark.skip(reason="Test currently not implemented") def test_well_posed_particle_phases_sei(self): pass # skip this test + @pytest.mark.skip(reason="Test currently not implemented") def test_well_posed_composite_kinetic_hysteresis(self): pass # skip this test + @pytest.mark.skip(reason="Test currently not implemented") def test_well_posed_composite_diffusion_hysteresis(self): pass # skip this test - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() + @pytest.mark.skip(reason="Test currently not implemented") + def test_well_posed_composite_different_degradation(self): + pass # skip this test diff --git a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_spm.py b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_spm.py index 45cf00877b..99affc7ddd 100644 --- a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_spm.py +++ b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_spm.py @@ -1,19 +1,19 @@ # # Tests for the lithium-ion SPM model # -from tests import TestCase import pybamm -import unittest from tests import BaseUnitTestLithiumIon +import pytest -class TestSPM(BaseUnitTestLithiumIon, TestCase): +class TestSPM(BaseUnitTestLithiumIon): + @pytest.fixture(autouse=True) def setUp(self): self.model = pybamm.lithium_ion.SPM def test_electrolyte_options(self): options = {"electrolyte conductivity": "full"} - with self.assertRaisesRegex(pybamm.OptionError, "electrolyte conductivity"): + with pytest.raises(pybamm.OptionError, match="electrolyte conductivity"): pybamm.lithium_ion.SPM(options) def test_kinetics_options(self): @@ -21,7 +21,7 @@ def test_kinetics_options(self): "surface form": "false", "intercalation kinetics": "Marcus-Hush-Chidsey", } - with self.assertRaisesRegex(pybamm.OptionError, "Inverse kinetics"): + with pytest.raises(pybamm.OptionError, match="Inverse kinetics"): pybamm.lithium_ion.SPM(options) def test_x_average_options(self): @@ -37,11 +37,11 @@ def test_x_average_options(self): # Check model with distributed side reactions throws an error options["x-average side reactions"] = "false" - with self.assertRaisesRegex(pybamm.OptionError, "cannot be 'false' for SPM"): + with pytest.raises(pybamm.OptionError, match="cannot be 'false' for SPM"): pybamm.lithium_ion.SPM(options) def test_distribution_options(self): - with self.assertRaisesRegex(pybamm.OptionError, "surface form"): + with pytest.raises(pybamm.OptionError, match="surface form"): pybamm.lithium_ion.SPM({"particle size": "distribution"}) def test_particle_size_distribution(self): @@ -53,10 +53,10 @@ def test_new_model(self): new_model = model.new_copy() model_T_eqn = model.rhs[model.variables["Cell temperature [K]"]] new_model_T_eqn = new_model.rhs[new_model.variables["Cell temperature [K]"]] - self.assertEqual(new_model_T_eqn, model_T_eqn) - self.assertEqual(new_model.name, model.name) - self.assertEqual(new_model.use_jacobian, model.use_jacobian) - self.assertEqual(new_model.convert_to_format, model.convert_to_format) + assert new_model_T_eqn == model_T_eqn + assert new_model.name == model.name + assert new_model.use_jacobian == model.use_jacobian + assert new_model.convert_to_format == model.convert_to_format # with custom submodels options = {"stress-induced diffusion": "false", "thermal": "x-full"} @@ -72,14 +72,4 @@ def test_new_model(self): new_model = model.new_copy() new_model_cs_eqn = list(new_model.rhs.values())[1] model_cs_eqn = list(model.rhs.values())[1] - self.assertEqual(new_model_cs_eqn, model_cs_eqn) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() + assert new_model_cs_eqn == model_cs_eqn diff --git a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_spm_half_cell.py b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_spm_half_cell.py index 0d6ba93ce0..c1b6b34745 100644 --- a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_spm_half_cell.py +++ b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_spm_half_cell.py @@ -1,22 +1,12 @@ # # Tests for the lithium-ion half-cell SPM model # -from tests import TestCase import pybamm -import unittest from tests import BaseUnitTestLithiumIonHalfCell +import pytest -class TestSPMHalfCell(BaseUnitTestLithiumIonHalfCell, TestCase): +class TestSPMHalfCell(BaseUnitTestLithiumIonHalfCell): + @pytest.fixture(autouse=True) def setUp(self): self.model = pybamm.lithium_ion.SPM - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_spme.py b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_spme.py index 72222ee060..b0d38fa9c7 100644 --- a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_spme.py +++ b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_spme.py @@ -1,13 +1,13 @@ # # Tests for the lithium-ion SPMe model # -from tests import TestCase import pybamm -import unittest from tests import BaseUnitTestLithiumIon +import pytest -class TestSPMe(BaseUnitTestLithiumIon, TestCase): +class TestSPMe(BaseUnitTestLithiumIon): + @pytest.fixture(autouse=True) def setUp(self): self.model = pybamm.lithium_ion.SPMe @@ -31,19 +31,9 @@ def setUp(self): def test_electrolyte_options(self): options = {"electrolyte conductivity": "full"} - with self.assertRaisesRegex(pybamm.OptionError, "electrolyte conductivity"): + with pytest.raises(pybamm.OptionError, match="electrolyte conductivity"): pybamm.lithium_ion.SPMe(options) def test_integrated_conductivity(self): options = {"electrolyte conductivity": "integrated"} self.check_well_posedness(options) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_spme_half_cell.py b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_spme_half_cell.py index f1930df026..2a814c113e 100644 --- a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_spme_half_cell.py +++ b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_spme_half_cell.py @@ -3,21 +3,11 @@ # This is achieved by using the {"working electrode": "positive"} option # import pybamm -import unittest -from tests import TestCase from tests import BaseUnitTestLithiumIonHalfCell +import pytest -class TestSPMeHalfCell(BaseUnitTestLithiumIonHalfCell, TestCase): +class TestSPMeHalfCell(BaseUnitTestLithiumIonHalfCell): + @pytest.fixture(autouse=True) def setUp(self): self.model = pybamm.lithium_ion.SPMe - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_models/test_model_info.py b/tests/unit/test_models/test_model_info.py index 144d763bf1..b754399872 100644 --- a/tests/unit/test_models/test_model_info.py +++ b/tests/unit/test_models/test_model_info.py @@ -1,12 +1,10 @@ # # Tests getting model info # -from tests import TestCase import pybamm -import unittest -class TestModelInfo(TestCase): +class TestModelInfo: def test_find_parameter_info(self): model = pybamm.lithium_ion.SPM() model.info("Negative particle diffusivity [m2.s-1]") @@ -16,13 +14,3 @@ def test_find_parameter_info(self): model.info("Negative particle diffusivity [m2.s-1]") model.info("Not a parameter") - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_models/test_submodels/test_base_submodel.py b/tests/unit/test_models/test_submodels/test_base_submodel.py index 1519a2fea2..9f2a9c3549 100644 --- a/tests/unit/test_models/test_submodels/test_base_submodel.py +++ b/tests/unit/test_models/test_submodels/test_base_submodel.py @@ -1,52 +1,50 @@ # # Test base submodel # -from tests import TestCase - +import pytest import pybamm -import unittest -class TestBaseSubModel(TestCase): +class TestBaseSubModel: def test_domain(self): # Accepted string submodel = pybamm.BaseSubModel(None, "negative", phase="primary") - self.assertEqual(submodel.domain, "negative") + assert submodel.domain == "negative" # None submodel = pybamm.BaseSubModel(None, None) - self.assertEqual(submodel.domain, None) + assert submodel.domain is None # bad string - with self.assertRaises(pybamm.DomainError): + with pytest.raises(pybamm.DomainError): pybamm.BaseSubModel(None, "bad string") def test_phase(self): # Without domain submodel = pybamm.BaseSubModel(None, None) - self.assertEqual(submodel.phase, None) + assert submodel.phase is None - with self.assertRaisesRegex(ValueError, "Phase must be None"): + with pytest.raises(ValueError, match="Phase must be None"): pybamm.BaseSubModel(None, None, phase="primary") # With domain submodel = pybamm.BaseSubModel(None, "negative", phase="primary") - self.assertEqual(submodel.phase, "primary") - self.assertEqual(submodel.phase_name, "") + assert submodel.phase == "primary" + assert submodel.phase_name == "" submodel = pybamm.BaseSubModel( None, "negative", options={"particle phases": "2"}, phase="secondary" ) - self.assertEqual(submodel.phase, "secondary") - self.assertEqual(submodel.phase_name, "secondary ") + assert submodel.phase == "secondary" + assert submodel.phase_name == "secondary " - with self.assertRaisesRegex(ValueError, "Phase must be 'primary'"): + with pytest.raises(ValueError, match="Phase must be 'primary'"): pybamm.BaseSubModel(None, "negative", phase="secondary") - with self.assertRaisesRegex(ValueError, "Phase must be either 'primary'"): + with pytest.raises(ValueError, match="Phase must be either 'primary'"): pybamm.BaseSubModel( None, "negative", options={"particle phases": "2"}, phase="tertiary" ) - with self.assertRaisesRegex(ValueError, "Phase must be 'primary'"): + with pytest.raises(ValueError, match="Phase must be 'primary'"): # 2 phases in the negative but only 1 in the positive pybamm.BaseSubModel( None, @@ -54,13 +52,3 @@ def test_phase(self): options={"particle phases": ("2", "1")}, phase="secondary", ) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_models/test_submodels/test_effective_current_collector.py b/tests/unit/test_models/test_submodels/test_effective_current_collector.py index cbab3134d4..b2437ec1d9 100644 --- a/tests/unit/test_models/test_submodels/test_effective_current_collector.py +++ b/tests/unit/test_models/test_submodels/test_effective_current_collector.py @@ -1,13 +1,12 @@ # # Tests for the Effective Current Collector Resistance models # -from tests import TestCase +import pytest import pybamm -import unittest import numpy as np -class TestEffectiveResistance(TestCase): +class TestEffectiveResistance: def test_well_posed(self): model = pybamm.current_collector.EffectiveResistance({"dimensionality": 1}) model.check_well_posedness() @@ -17,36 +16,34 @@ def test_well_posed(self): def test_default_parameters(self): model = pybamm.current_collector.EffectiveResistance({"dimensionality": 1}) - self.assertEqual( - model.default_parameter_values, pybamm.ParameterValues("Marquis2019") - ) + assert model.default_parameter_values == pybamm.ParameterValues("Marquis2019") def test_default_geometry(self): model = pybamm.current_collector.EffectiveResistance({"dimensionality": 1}) - self.assertTrue("current collector" in model.default_geometry) - self.assertNotIn("negative electrode", model.default_geometry) + assert "current collector" in model.default_geometry + assert "negative electrode" not in model.default_geometry model = pybamm.current_collector.EffectiveResistance({"dimensionality": 2}) - self.assertTrue("current collector" in model.default_geometry) - self.assertNotIn("negative electrode", model.default_geometry) + assert "current collector" in model.default_geometry + assert "negative electrode" not in model.default_geometry def test_default_var_pts(self): model = pybamm.current_collector.EffectiveResistance({"dimensionality": 1}) - self.assertEqual(model.default_var_pts, {"y": 32, "z": 32}) + assert model.default_var_pts == {"y": 32, "z": 32} def test_default_solver(self): model = pybamm.current_collector.EffectiveResistance({"dimensionality": 1}) - self.assertIsInstance(model.default_solver, pybamm.CasadiAlgebraicSolver) + assert isinstance(model.default_solver, pybamm.CasadiAlgebraicSolver) model = pybamm.current_collector.EffectiveResistance({"dimensionality": 2}) - self.assertIsInstance(model.default_solver, pybamm.CasadiAlgebraicSolver) + assert isinstance(model.default_solver, pybamm.CasadiAlgebraicSolver) def test_bad_option(self): - with self.assertRaisesRegex(pybamm.OptionError, "Dimension of"): + with pytest.raises(pybamm.OptionError, match="Dimension of"): pybamm.current_collector.EffectiveResistance({"dimensionality": 10}) -class TestEffectiveResistancePostProcess(TestCase): +class TestEffectiveResistancePostProcess: def test_get_processed_variables(self): # solve cheap SPM to test post-processing (think of an alternative test?) models = [ @@ -87,13 +84,3 @@ def test_get_processed_variables(self): processed_var(t=solution_1D.t[5], z=pts) else: processed_var(t=solution_1D.t[5], y=pts, z=pts) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_models/test_submodels/test_particle_polynomial_profile.py b/tests/unit/test_models/test_submodels/test_particle_polynomial_profile.py index 787230d9f3..57f1436f2d 100644 --- a/tests/unit/test_models/test_submodels/test_particle_polynomial_profile.py +++ b/tests/unit/test_models/test_submodels/test_particle_polynomial_profile.py @@ -1,22 +1,11 @@ # # Tests for the polynomial profile submodel # -from tests import TestCase import pybamm -import unittest +import pytest -class TestParticlePolynomialProfile(TestCase): +class TestParticlePolynomialProfile: def test_errors(self): - with self.assertRaisesRegex(ValueError, "Particle type must be"): + with pytest.raises(ValueError, match="Particle type must be"): pybamm.particle.PolynomialProfile(None, "negative", {}) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_parameters/test_base_parameters.py b/tests/unit/test_parameters/test_base_parameters.py index 6c87cdcd88..2c48074a71 100644 --- a/tests/unit/test_parameters/test_base_parameters.py +++ b/tests/unit/test_parameters/test_base_parameters.py @@ -2,48 +2,37 @@ Tests for the base_parameters.py """ -from tests import TestCase import pybamm -import unittest +import pytest -class TestBaseParameters(TestCase): +class TestBaseParameters: def test_getattr__(self): param = pybamm.LithiumIonParameters() # ending in _n / _s / _p - with self.assertRaisesRegex(AttributeError, "param.n.L"): + with pytest.raises(AttributeError, match="param.n.L"): param.L_n - with self.assertRaisesRegex(AttributeError, "param.s.L"): + with pytest.raises(AttributeError, match="param.s.L"): param.L_s - with self.assertRaisesRegex(AttributeError, "param.p.L"): + with pytest.raises(AttributeError, match="param.p.L"): param.L_p # _n_ in the name - with self.assertRaisesRegex(AttributeError, "param.n.prim.c_max"): + with pytest.raises(AttributeError, match="param.n.prim.c_max"): param.c_n_max # _n_ or _p_ not in name - with self.assertRaisesRegex( - AttributeError, "has no attribute 'c_n_not_a_parameter" + with pytest.raises( + AttributeError, match="has no attribute 'c_n_not_a_parameter" ): param.c_n_not_a_parameter - with self.assertRaisesRegex(AttributeError, "has no attribute 'c_s_test"): + with pytest.raises(AttributeError, match="has no attribute 'c_s_test"): pybamm.electrical_parameters.c_s_test - self.assertEqual(param.n.cap_init, param.n.Q_init) - self.assertEqual(param.p.prim.cap_init, param.p.prim.Q_init) + assert param.n.cap_init == param.n.Q_init + assert param.p.prim.cap_init == param.p.prim.Q_init def test__setattr__(self): # domain gets added as a subscript param = pybamm.GeometricParameters() - self.assertEqual(param.n.L.print_name, r"L_{\mathrm{n}}") - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() + assert param.n.L.print_name == r"L_{\mathrm{n}}" diff --git a/tests/unit/test_parameters/test_electrical_parameters.py b/tests/unit/test_parameters/test_electrical_parameters.py index 92bceaf632..7601c30721 100644 --- a/tests/unit/test_parameters/test_electrical_parameters.py +++ b/tests/unit/test_parameters/test_electrical_parameters.py @@ -1,13 +1,11 @@ # # Tests for the electrical parameters # -from tests import TestCase +import pytest import pybamm -import unittest - -class TestElectricalParameters(TestCase): +class TestElectricalParameters: def test_current_functions(self): # create current functions param = pybamm.electrical_parameters @@ -27,17 +25,7 @@ def test_current_functions(self): current_density_eval = parameter_values.process_symbol(current_density) # check current - self.assertEqual(current_eval.evaluate(t=3), 2) + assert current_eval.evaluate(t=3) == 2 # check current density - self.assertAlmostEqual(current_density_eval.evaluate(t=3), 2 / (8 * 0.1 * 0.1)) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() + assert current_density_eval.evaluate(t=3) == pytest.approx(2 / (8 * 0.1 * 0.1)) diff --git a/tests/unit/test_parameters/test_geometric_parameters.py b/tests/unit/test_parameters/test_geometric_parameters.py index 6e59259a12..7e000bf645 100644 --- a/tests/unit/test_parameters/test_geometric_parameters.py +++ b/tests/unit/test_parameters/test_geometric_parameters.py @@ -1,12 +1,10 @@ # # Tests for the standard parameters # -from tests import TestCase import pybamm -import unittest -class TestGeometricParameters(TestCase): +class TestGeometricParameters: def test_macroscale_parameters(self): geo = pybamm.geometric_parameters L_n = geo.n.L @@ -26,16 +24,4 @@ def test_macroscale_parameters(self): L_p_eval = parameter_values.process_symbol(L_p) L_x_eval = parameter_values.process_symbol(L_x) - self.assertEqual( - (L_n_eval + L_s_eval + L_p_eval).evaluate(), L_x_eval.evaluate() - ) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() + assert (L_n_eval + L_s_eval + L_p_eval).evaluate() == L_x_eval.evaluate() diff --git a/tests/unit/test_parameters/test_parameter_sets/test_Ai2020.py b/tests/unit/test_parameters/test_parameter_sets/test_Ai2020.py index 8816551ab6..f7302330bf 100644 --- a/tests/unit/test_parameters/test_parameter_sets/test_Ai2020.py +++ b/tests/unit/test_parameters/test_parameter_sets/test_Ai2020.py @@ -1,12 +1,11 @@ # # Tests for Ai (2020) Enertech parameter set loads # -from tests import TestCase +import pytest import pybamm -import unittest -class TestAi2020(TestCase): +class TestAi2020: def test_functions(self): param = pybamm.ParameterValues("Ai2020") sto = pybamm.Scalar(0.5) @@ -42,16 +41,6 @@ def test_functions(self): } for name, value in fun_test.items(): - self.assertAlmostEqual( - param.evaluate(param[name](*value[0])), value[1], places=4 + assert param.evaluate(param[name](*value[0])) == pytest.approx( + value[1], abs=0.0001 ) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_parameters/test_parameter_sets/test_Ecker2015_graphite_halfcell.py b/tests/unit/test_parameters/test_parameter_sets/test_Ecker2015_graphite_halfcell.py index f435ef6d36..6000b997b7 100644 --- a/tests/unit/test_parameters/test_parameter_sets/test_Ecker2015_graphite_halfcell.py +++ b/tests/unit/test_parameters/test_parameter_sets/test_Ecker2015_graphite_halfcell.py @@ -1,12 +1,11 @@ # # Tests for O'Kane (2022) parameter set # -from tests import TestCase +import pytest import pybamm -import unittest -class TestEcker2015_graphite_halfcell(TestCase): +class TestEcker2015_graphite_halfcell: def test_functions(self): param = pybamm.ParameterValues("Ecker2015_graphite_halfcell") sto = pybamm.Scalar(0.5) @@ -33,16 +32,6 @@ def test_functions(self): } for name, value in fun_test.items(): - self.assertAlmostEqual( - param.evaluate(param[name](*value[0])), value[1], places=4 + assert param.evaluate(param[name](*value[0])) == pytest.approx( + value[1], abs=0.0001 ) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_parameters/test_parameter_sets/test_LCO_Ramadass2004.py b/tests/unit/test_parameters/test_parameter_sets/test_LCO_Ramadass2004.py index 2de67b9e62..e6c4b04fdf 100644 --- a/tests/unit/test_parameters/test_parameter_sets/test_LCO_Ramadass2004.py +++ b/tests/unit/test_parameters/test_parameter_sets/test_LCO_Ramadass2004.py @@ -1,12 +1,11 @@ # # Tests for Ai (2020) Enertech parameter set loads # -from tests import TestCase +import pytest import pybamm -import unittest -class TestRamadass2004(TestCase): +class TestRamadass2004: def test_functions(self): param = pybamm.ParameterValues("Ramadass2004") sto = pybamm.Scalar(0.5) @@ -40,16 +39,6 @@ def test_functions(self): } for name, value in fun_test.items(): - self.assertAlmostEqual( - param.evaluate(param[name](*value[0])), value[1], places=4 + assert param.evaluate(param[name](*value[0])) == pytest.approx( + value[1], abs=0.0001 ) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_parameters/test_parameter_sets/test_LGM50_ORegan2022.py b/tests/unit/test_parameters/test_parameter_sets/test_LGM50_ORegan2022.py index f878b7d790..05a38b6245 100644 --- a/tests/unit/test_parameters/test_parameter_sets/test_LGM50_ORegan2022.py +++ b/tests/unit/test_parameters/test_parameter_sets/test_LGM50_ORegan2022.py @@ -1,12 +1,11 @@ # # Tests for LG M50 parameter set loads # -from tests import TestCase +import pytest import pybamm -import unittest -class TestORegan2022(TestCase): +class TestORegan2022: def test_functions(self): param = pybamm.ParameterValues("ORegan2022") T = pybamm.Scalar(298.15) @@ -68,16 +67,6 @@ def test_functions(self): } for name, value in fun_test.items(): - self.assertAlmostEqual( - param.evaluate(param[name](*value[0])), value[1], places=4 + assert param.evaluate(param[name](*value[0])) == pytest.approx( + value[1], abs=0.0001 ) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_parameters/test_parameter_sets/test_OKane2022_negative_halfcell.py b/tests/unit/test_parameters/test_parameter_sets/test_OKane2022_negative_halfcell.py index 04a19e1002..bf39457dc4 100644 --- a/tests/unit/test_parameters/test_parameter_sets/test_OKane2022_negative_halfcell.py +++ b/tests/unit/test_parameters/test_parameter_sets/test_OKane2022_negative_halfcell.py @@ -1,12 +1,11 @@ # # Tests for O'Kane (2022) parameter set # -from tests import TestCase +import pytest import pybamm -import unittest -class TestOKane2022_graphite_SiOx_halfcell(TestCase): +class TestOKane2022_graphite_SiOx_halfcell: def test_functions(self): param = pybamm.ParameterValues("OKane2022_graphite_SiOx_halfcell") sto = pybamm.Scalar(0.9) @@ -31,16 +30,6 @@ def test_functions(self): } for name, value in fun_test.items(): - self.assertAlmostEqual( - param.evaluate(param[name](*value[0])), value[1], places=4 + assert param.evaluate(param[name](*value[0])) == pytest.approx( + value[1], abs=0.0001 ) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_parameters/test_parameter_sets_class.py b/tests/unit/test_parameters/test_parameter_sets_class.py index b14000f987..342cf127aa 100644 --- a/tests/unit/test_parameters/test_parameter_sets_class.py +++ b/tests/unit/test_parameters/test_parameter_sets_class.py @@ -1,23 +1,22 @@ # # Tests for the ParameterSets class # -from tests import TestCase - +import pytest +import re import pybamm -import unittest -class TestParameterSets(TestCase): +class TestParameterSets: def test_name_interface(self): """Test that pybamm.parameters_sets. returns the name of the parameter set and a depreciation warning """ - with self.assertWarns(DeprecationWarning): + with pytest.warns(DeprecationWarning): out = pybamm.parameter_sets.Marquis2019 - self.assertEqual(out, "Marquis2019") + assert out == "Marquis2019" - # Expect error for parameter set's that aren't real - with self.assertRaises(AttributeError): + # Expect an error for parameter sets that aren't real + with pytest.raises(AttributeError): pybamm.parameter_sets.not_a_real_parameter_set def test_all_registered(self): @@ -26,26 +25,15 @@ def test_all_registered(self): known_entry_points = set( ep.name for ep in pybamm.parameter_sets.get_entries("pybamm_parameter_sets") ) - self.assertEqual(set(pybamm.parameter_sets.keys()), known_entry_points) - self.assertEqual(len(known_entry_points), len(pybamm.parameter_sets)) + assert set(pybamm.parameter_sets.keys()) == known_entry_points + assert len(known_entry_points) == len(pybamm.parameter_sets) def test_get_docstring(self): """Test that :meth:`pybamm.parameter_sets.get_doctstring` works""" docstring = pybamm.parameter_sets.get_docstring("Marquis2019") - self.assertRegex(docstring, "Parameters for a Kokam SLPB78205130H cell") + assert re.search("Parameters for a Kokam SLPB78205130H cell", docstring) def test_iter(self): """Test that iterating `pybamm.parameter_sets` iterates over keys""" for k in pybamm.parameter_sets: - self.assertIsInstance(k, str) - self.assertIn(k, pybamm.parameter_sets) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() + assert isinstance(k, str) diff --git a/tests/unit/test_parameters/test_size_distribution_parameters.py b/tests/unit/test_parameters/test_size_distribution_parameters.py index 5deeaa62be..414b422055 100644 --- a/tests/unit/test_parameters/test_size_distribution_parameters.py +++ b/tests/unit/test_parameters/test_size_distribution_parameters.py @@ -2,13 +2,12 @@ # Tests particle size distribution parameters are loaded into a parameter set # and give expected values # +import pytest import pybamm -import unittest import numpy as np -from tests import TestCase -class TestSizeDistributionParameters(TestCase): +class TestSizeDistributionParameters: def test_parameter_values(self): values = pybamm.lithium_ion.BaseModel().default_parameter_values param = pybamm.LithiumIonParameters() @@ -20,7 +19,7 @@ def test_parameter_values(self): ) # check negative parameters aren't there yet - with self.assertRaises(KeyError): + with pytest.raises(KeyError): values["Negative maximum particle radius [m]"] # now add distribution parameter values for negative electrode @@ -41,13 +40,3 @@ def test_parameter_values(self): R_test = pybamm.Scalar(1.0) values.evaluate(param.n.prim.f_a_dist(R_test)) values.evaluate(param.p.prim.f_a_dist(R_test)) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_plotting/test_plot.py b/tests/unit/test_plotting/test_plot.py index f36e20cd6f..1c049269c3 100644 --- a/tests/unit/test_plotting/test_plot.py +++ b/tests/unit/test_plotting/test_plot.py @@ -1,14 +1,13 @@ import pybamm -import unittest +import pytest import numpy as np -from tests import TestCase import matplotlib.pyplot as plt from matplotlib import use use("Agg") -class TestPlot(TestCase): +class TestPlot: def test_plot(self): x = pybamm.Array(np.array([0, 3, 10])) y = pybamm.Array(np.array([6, 16, 78])) @@ -16,13 +15,13 @@ def test_plot(self): _, ax = plt.subplots() ax_out = pybamm.plot(x, y, ax=ax, show_plot=False) - self.assertEqual(ax_out, ax) + assert ax_out == ax def test_plot_fail(self): x = pybamm.Array(np.array([0])) - with self.assertRaisesRegex(TypeError, "x must be 'pybamm.Array'"): + with pytest.raises(TypeError, match="x must be 'pybamm.Array'"): pybamm.plot("bad", x) - with self.assertRaisesRegex(TypeError, "y must be 'pybamm.Array'"): + with pytest.raises(TypeError, match="y must be 'pybamm.Array'"): pybamm.plot(x, "bad") def test_plot2D(self): @@ -38,23 +37,13 @@ def test_plot2D(self): _, ax = plt.subplots() ax_out = pybamm.plot2D(X, Y, Y, ax=ax, show_plot=False) - self.assertEqual(ax_out, ax) + assert ax_out == ax def test_plot2D_fail(self): x = pybamm.Array(np.array([0])) - with self.assertRaisesRegex(TypeError, "x must be 'pybamm.Array'"): + with pytest.raises(TypeError, match="x must be 'pybamm.Array'"): pybamm.plot2D("bad", x, x) - with self.assertRaisesRegex(TypeError, "y must be 'pybamm.Array'"): + with pytest.raises(TypeError, match="y must be 'pybamm.Array'"): pybamm.plot2D(x, "bad", x) - with self.assertRaisesRegex(TypeError, "z must be 'pybamm.Array'"): + with pytest.raises(TypeError, match="z must be 'pybamm.Array'"): pybamm.plot2D(x, x, "bad") - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_plotting/test_plot_summary_variables.py b/tests/unit/test_plotting/test_plot_summary_variables.py index e896b1f468..5f1a650ced 100644 --- a/tests/unit/test_plotting/test_plot_summary_variables.py +++ b/tests/unit/test_plotting/test_plot_summary_variables.py @@ -1,10 +1,8 @@ import pybamm -import unittest import numpy as np -from tests import TestCase -class TestPlotSummaryVariables(TestCase): +class TestPlotSummaryVariables: def test_plot(self): model = pybamm.lithium_ion.SPM({"SEI": "ec reaction limited"}) parameter_values = pybamm.ParameterValues("Mohtat2020") @@ -39,11 +37,11 @@ def test_plot(self): axes = pybamm.plot_summary_variables(sol, show_plot=False) axes = axes.flatten() - self.assertEqual(len(axes), 9) + assert len(axes) == 9 for output_var, ax in zip(output_variables, axes): - self.assertEqual(ax.get_xlabel(), "Cycle number") - self.assertEqual(ax.get_ylabel(), output_var) + assert ax.get_xlabel() == "Cycle number" + assert ax.get_ylabel() == output_var cycle_number, var = ax.get_lines()[0].get_data() np.testing.assert_array_equal( @@ -56,11 +54,11 @@ def test_plot(self): ) axes = axes.flatten() - self.assertEqual(len(axes), 9) + assert len(axes) == 9 for output_var, ax in zip(output_variables, axes): - self.assertEqual(ax.get_xlabel(), "Cycle number") - self.assertEqual(ax.get_ylabel(), output_var) + assert ax.get_xlabel() == "Cycle number" + assert ax.get_ylabel() == output_var cycle_number, var = ax.get_lines()[0].get_data() np.testing.assert_array_equal( @@ -73,13 +71,3 @@ def test_plot(self): cycle_number, sol.summary_variables["Cycle number"] ) np.testing.assert_array_equal(var, sol.summary_variables[output_var]) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_plotting/test_plot_thermal_components.py b/tests/unit/test_plotting/test_plot_thermal_components.py index 99b3d40cac..2b4cdf1e1e 100644 --- a/tests/unit/test_plotting/test_plot_thermal_components.py +++ b/tests/unit/test_plotting/test_plot_thermal_components.py @@ -1,14 +1,13 @@ +import pytest import pybamm -import unittest import numpy as np -from tests import TestCase import matplotlib.pyplot as plt from matplotlib import use use("Agg") -class TestPlotThermalComponents(TestCase): +class TestPlotThermalComponents: def test_plot_with_solution(self): model = pybamm.lithium_ion.SPM({"thermal": "lumped"}) sim = pybamm.Simulation( @@ -30,22 +29,12 @@ def test_plot_with_solution(self): _, ax = plt.subplots(1, 2) _, ax_out = pybamm.plot_thermal_components(sol, ax=ax, show_legend=True) - self.assertEqual(ax_out[0], ax[0]) - self.assertEqual(ax_out[1], ax[1]) + assert ax_out[0] == ax[0] + assert ax_out[1] == ax[1] def test_not_implemented(self): model = pybamm.lithium_ion.SPM({"thermal": "x-full"}) sim = pybamm.Simulation(model) sol = sim.solve([0, 3600]) - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): pybamm.plot_thermal_components(sol) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_plotting/test_plot_voltage_components.py b/tests/unit/test_plotting/test_plot_voltage_components.py index 1773d576d9..2b9da43fc1 100644 --- a/tests/unit/test_plotting/test_plot_voltage_components.py +++ b/tests/unit/test_plotting/test_plot_voltage_components.py @@ -1,14 +1,13 @@ +import pytest import pybamm -import unittest import numpy as np -from tests import TestCase import matplotlib.pyplot as plt from matplotlib import use use("Agg") -class TestPlotVoltageComponents(TestCase): +class TestPlotVoltageComponents: def test_plot_with_solution(self): model = pybamm.lithium_ion.SPM() sim = pybamm.Simulation(model) @@ -23,7 +22,7 @@ def test_plot_with_solution(self): _, ax = plt.subplots() _, ax_out = pybamm.plot_voltage_components(sol, ax=ax, show_legend=True) - self.assertEqual(ax_out, ax) + assert ax_out == ax def test_plot_with_simulation(self): model = pybamm.lithium_ion.SPM() @@ -40,7 +39,7 @@ def test_plot_with_simulation(self): _, ax = plt.subplots() _, ax_out = pybamm.plot_voltage_components(sim, ax=ax, show_legend=True) - self.assertEqual(ax_out, ax) + assert ax_out == ax def test_plot_from_solution(self): model = pybamm.lithium_ion.SPM() @@ -56,7 +55,7 @@ def test_plot_from_solution(self): _, ax = plt.subplots() _, ax_out = sol.plot_voltage_components(ax=ax, show_legend=True) - self.assertEqual(ax_out, ax) + assert ax_out == ax def test_plot_from_simulation(self): model = pybamm.lithium_ion.SPM() @@ -73,25 +72,12 @@ def test_plot_from_simulation(self): _, ax = plt.subplots() _, ax_out = sim.plot_voltage_components(ax=ax, show_legend=True) - self.assertEqual(ax_out, ax) + assert ax_out == ax def test_plot_without_solution(self): model = pybamm.lithium_ion.SPM() sim = pybamm.Simulation(model) - with self.assertRaises(ValueError) as error: + with pytest.raises(ValueError) as error: sim.plot_voltage_components() - - self.assertEqual( - str(error.exception), "The simulation has not been solved yet." - ) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() + assert str(error.exception) == "The simulation has not been solved yet." diff --git a/tests/unit/test_settings.py b/tests/unit/test_settings.py index a3b62f8ee4..6573929ad9 100644 --- a/tests/unit/test_settings.py +++ b/tests/unit/test_settings.py @@ -1,52 +1,49 @@ # # Tests the settings class. # -from tests import TestCase + import pybamm -import unittest +import pytest -class TestSettings(TestCase): +class TestSettings: def test_simplify(self): - self.assertTrue(pybamm.settings.simplify) + with pytest.raises(TypeError): + pybamm.settings.simplify = "Not Bool" + + assert pybamm.settings.simplify pybamm.settings.simplify = False - self.assertFalse(pybamm.settings.simplify) + assert not pybamm.settings.simplify pybamm.settings.simplify = True + def test_debug_mode(self): + with pytest.raises(TypeError): + pybamm.settings.debug_mode = "Not bool" + def test_smoothing_parameters(self): - self.assertEqual(pybamm.settings.min_max_mode, "exact") - self.assertEqual(pybamm.settings.heaviside_smoothing, "exact") - self.assertEqual(pybamm.settings.abs_smoothing, "exact") + assert pybamm.settings.min_max_mode == "exact" + assert pybamm.settings.heaviside_smoothing == "exact" + assert pybamm.settings.abs_smoothing == "exact" pybamm.settings.set_smoothing_parameters(10) - self.assertEqual(pybamm.settings.min_max_smoothing, 10) - self.assertEqual(pybamm.settings.heaviside_smoothing, 10) - self.assertEqual(pybamm.settings.abs_smoothing, 10) + assert pybamm.settings.min_max_smoothing == 10 + assert pybamm.settings.heaviside_smoothing == 10 + assert pybamm.settings.abs_smoothing == 10 pybamm.settings.set_smoothing_parameters("exact") # Test errors - with self.assertRaisesRegex(ValueError, "greater than 1"): + with pytest.raises(ValueError, match="greater than 1"): pybamm.settings.min_max_mode = "smooth" pybamm.settings.min_max_smoothing = 0.9 - with self.assertRaisesRegex(ValueError, "positive number"): + with pytest.raises(ValueError, match="positive number"): pybamm.settings.min_max_mode = "soft" pybamm.settings.min_max_smoothing = -10 - with self.assertRaisesRegex(ValueError, "positive number"): + with pytest.raises(ValueError, match="positive number"): pybamm.settings.heaviside_smoothing = -10 - with self.assertRaisesRegex(ValueError, "positive number"): + with pytest.raises(ValueError, match="positive number"): pybamm.settings.abs_smoothing = -10 - with self.assertRaisesRegex(ValueError, "'soft', or 'smooth'"): + with pytest.raises(ValueError, match="'soft', or 'smooth'"): pybamm.settings.min_max_mode = "unknown" pybamm.settings.set_smoothing_parameters("exact") - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_solvers/test_dummy_solver.py b/tests/unit/test_solvers/test_dummy_solver.py index 7c7b9a35f7..acce1d3543 100644 --- a/tests/unit/test_solvers/test_dummy_solver.py +++ b/tests/unit/test_solvers/test_dummy_solver.py @@ -1,14 +1,11 @@ # # Tests for the Dummy Solver class # -from tests import TestCase import pybamm import numpy as np -import unittest -import sys -class TestDummySolver(TestCase): +class TestDummySolver: def test_dummy_solver(self): model = pybamm.BaseModel() v = pybamm.Scalar(1) @@ -44,12 +41,3 @@ def test_dummy_solver_step(self): np.testing.assert_array_equal(len(sol.t), t_eval.size * 2 - 2) np.testing.assert_array_equal(sol.y, np.zeros((1, sol.t.size))) np.testing.assert_array_equal(sol["v"].data, np.ones(sol.t.size)) - - -if __name__ == "__main__": - print("Add -v for more debug output") - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_solvers/test_idaklu_solver.py b/tests/unit/test_solvers/test_idaklu_solver.py index 1697623486..a80ab74b9e 100644 --- a/tests/unit/test_solvers/test_idaklu_solver.py +++ b/tests/unit/test_solvers/test_idaklu_solver.py @@ -18,15 +18,17 @@ def test_ida_roberts_klu(self): # this test implements a python version of the ida Roberts # example provided in sundials # see sundials ida examples pdf - for form in ["python", "casadi", "jax"]: - if form == "jax" and not pybamm.have_jax(): + for form in ["python", "casadi", "jax", "iree"]: + if (form == "jax" or form == "iree") and not pybamm.have_jax(): + continue + if (form == "iree") and not pybamm.have_iree(): continue if form == "casadi": root_method = "casadi" else: root_method = "lm" model = pybamm.BaseModel() - model.convert_to_format = form + model.convert_to_format = "jax" if form == "iree" else form u = pybamm.Variable("u") v = pybamm.Variable("v") model.rhs = {u: 0.1 * v} @@ -37,7 +39,10 @@ def test_ida_roberts_klu(self): disc = pybamm.Discretisation() disc.process_model(model) - solver = pybamm.IDAKLUSolver(root_method=root_method) + solver = pybamm.IDAKLUSolver( + root_method=root_method, + options={"jax_evaluator": "iree"} if form == "iree" else {}, + ) t_eval = np.linspace(0, 3, 100) solution = solver.solve(model, t_eval) @@ -59,8 +64,10 @@ def test_ida_roberts_klu(self): np.testing.assert_array_almost_equal(solution.y[0, :], true_solution) def test_model_events(self): - for form in ["python", "casadi", "jax"]: - if form == "jax" and not pybamm.have_jax(): + for form in ["python", "casadi", "jax", "iree"]: + if (form == "jax" or form == "iree") and not pybamm.have_jax(): + continue + if (form == "iree") and not pybamm.have_iree(): continue if form == "casadi": root_method = "casadi" @@ -68,7 +75,7 @@ def test_model_events(self): root_method = "lm" # Create model model = pybamm.BaseModel() - model.convert_to_format = form + model.convert_to_format = "jax" if form == "iree" else form var = pybamm.Variable("var") model.rhs = {var: 0.1 * var} model.initial_conditions = {var: 1} @@ -77,7 +84,12 @@ def test_model_events(self): disc = pybamm.Discretisation() model_disc = disc.process_model(model, inplace=False) # Solve - solver = pybamm.IDAKLUSolver(rtol=1e-8, atol=1e-8, root_method=root_method) + solver = pybamm.IDAKLUSolver( + rtol=1e-8, + atol=1e-8, + root_method=root_method, + options={"jax_evaluator": "iree"} if form == "iree" else {}, + ) t_eval = np.linspace(0, 1, 100) solution = solver.solve(model_disc, t_eval) np.testing.assert_array_equal(solution.t, t_eval) @@ -92,7 +104,12 @@ def test_model_events(self): # enforce events that won't be triggered model.events = [pybamm.Event("an event", var + 1)] model_disc = disc.process_model(model, inplace=False) - solver = pybamm.IDAKLUSolver(rtol=1e-8, atol=1e-8, root_method=root_method) + solver = pybamm.IDAKLUSolver( + rtol=1e-8, + atol=1e-8, + root_method=root_method, + options={"jax_evaluator": "iree"} if form == "iree" else {}, + ) solution = solver.solve(model_disc, t_eval) np.testing.assert_array_equal(solution.t, t_eval) np.testing.assert_array_almost_equal( @@ -102,7 +119,12 @@ def test_model_events(self): # enforce events that will be triggered model.events = [pybamm.Event("an event", 1.01 - var)] model_disc = disc.process_model(model, inplace=False) - solver = pybamm.IDAKLUSolver(rtol=1e-8, atol=1e-8, root_method=root_method) + solver = pybamm.IDAKLUSolver( + rtol=1e-8, + atol=1e-8, + root_method=root_method, + options={"jax_evaluator": "iree"} if form == "iree" else {}, + ) solution = solver.solve(model_disc, t_eval) self.assertLess(len(solution.t), len(t_eval)) np.testing.assert_array_almost_equal( @@ -124,7 +146,12 @@ def test_model_events(self): disc = get_discretisation_for_testing() disc.process_model(model) - solver = pybamm.IDAKLUSolver(rtol=1e-8, atol=1e-8, root_method=root_method) + solver = pybamm.IDAKLUSolver( + rtol=1e-8, + atol=1e-8, + root_method=root_method, + options={"jax_evaluator": "iree"} if form == "iree" else {}, + ) t_eval = np.linspace(0, 5, 100) solution = solver.solve(model, t_eval) np.testing.assert_array_less(solution.y[0, :-1], 1.5) @@ -140,15 +167,17 @@ def test_model_events(self): def test_input_params(self): # test a mix of scalar and vector input params - for form in ["python", "casadi", "jax"]: - if form == "jax" and not pybamm.have_jax(): + for form in ["python", "casadi", "jax", "iree"]: + if (form == "jax" or form == "iree") and not pybamm.have_jax(): + continue + if (form == "iree") and not pybamm.have_iree(): continue if form == "casadi": root_method = "casadi" else: root_method = "lm" model = pybamm.BaseModel() - model.convert_to_format = form + model.convert_to_format = "jax" if form == "iree" else form u1 = pybamm.Variable("u1") u2 = pybamm.Variable("u2") u3 = pybamm.Variable("u3") @@ -162,7 +191,10 @@ def test_input_params(self): disc = pybamm.Discretisation() disc.process_model(model) - solver = pybamm.IDAKLUSolver(root_method=root_method) + solver = pybamm.IDAKLUSolver( + root_method=root_method, + options={"jax_evaluator": "iree"} if form == "iree" else {}, + ) t_eval = np.linspace(0, 3, 100) a_value = 0.1 @@ -185,48 +217,63 @@ def test_input_params(self): true_solution = b_value * sol.t np.testing.assert_array_almost_equal(sol.y[1:3], true_solution) - def test_sensitivites_initial_condition(self): - for output_variables in [[], ["2v"]]: - model = pybamm.BaseModel() - model.convert_to_format = "casadi" - u = pybamm.Variable("u") - v = pybamm.Variable("v") - a = pybamm.InputParameter("a") - model.rhs = {u: -u} - model.algebraic = {v: a * u - v} - model.initial_conditions = {u: 1, v: 1} - model.variables = {"2v": 2 * v} - - disc = pybamm.Discretisation() - disc.process_model(model) - solver = pybamm.IDAKLUSolver(output_variables=output_variables) - - t_eval = np.linspace(0, 3, 100) - a_value = 0.1 - - sol = solver.solve( - model, t_eval, inputs={"a": a_value}, calculate_sensitivities=True - ) - - np.testing.assert_array_almost_equal( - sol["2v"].sensitivities["a"].full().flatten(), - np.exp(-sol.t) * 2, - decimal=4, - ) + def test_sensitivities_initial_condition(self): + for form in ["casadi", "iree"]: + for output_variables in [[], ["2v"]]: + if (form == "jax" or form == "iree") and not pybamm.have_jax(): + continue + if (form == "iree") and not pybamm.have_iree(): + continue + if form == "casadi": + root_method = "casadi" + else: + root_method = "lm" + model = pybamm.BaseModel() + model.convert_to_format = "jax" if form == "iree" else form + u = pybamm.Variable("u") + v = pybamm.Variable("v") + a = pybamm.InputParameter("a") + model.rhs = {u: -u} + model.algebraic = {v: a * u - v} + model.initial_conditions = {u: 1, v: 1} + model.variables = {"2v": 2 * v} + + disc = pybamm.Discretisation() + disc.process_model(model) + solver = pybamm.IDAKLUSolver( + root_method=root_method, + output_variables=output_variables, + options={"jax_evaluator": "iree"} if form == "iree" else {}, + ) + + t_eval = np.linspace(0, 3, 100) + a_value = 0.1 + + sol = solver.solve( + model, t_eval, inputs={"a": a_value}, calculate_sensitivities=True + ) + + np.testing.assert_array_almost_equal( + sol["2v"].sensitivities["a"].full().flatten(), + np.exp(-sol.t) * 2, + decimal=4, + ) def test_ida_roberts_klu_sensitivities(self): # this test implements a python version of the ida Roberts # example provided in sundials # see sundials ida examples pdf - for form in ["python", "casadi", "jax"]: - if form == "jax" and not pybamm.have_jax(): + for form in ["python", "casadi", "jax", "iree"]: + if (form == "jax" or form == "iree") and not pybamm.have_jax(): + continue + if (form == "iree") and not pybamm.have_iree(): continue if form == "casadi": root_method = "casadi" else: root_method = "lm" model = pybamm.BaseModel() - model.convert_to_format = form + model.convert_to_format = "jax" if form == "iree" else form u = pybamm.Variable("u") v = pybamm.Variable("v") a = pybamm.InputParameter("a") @@ -238,7 +285,10 @@ def test_ida_roberts_klu_sensitivities(self): disc = pybamm.Discretisation() disc.process_model(model) - solver = pybamm.IDAKLUSolver(root_method=root_method) + solver = pybamm.IDAKLUSolver( + root_method=root_method, + options={"jax_evaluator": "iree"} if form == "iree" else {}, + ) t_eval = np.linspace(0, 3, 100) a_value = 0.1 @@ -283,25 +333,32 @@ def test_ida_roberts_klu_sensitivities(self): dyda_fd = (sol_plus.y - sol_neg.y) / h dyda_fd = dyda_fd.transpose().reshape(-1, 1) - np.testing.assert_array_almost_equal(dyda_ida, dyda_fd) + decimal = ( + 2 if form == "iree" else 6 + ) # iree currently operates with single precision + np.testing.assert_array_almost_equal(dyda_ida, dyda_fd, decimal=decimal) # get the sensitivities for the variable d2uda = sol["2u"].sensitivities["a"] - np.testing.assert_array_almost_equal(2 * dyda_ida[0:200:2], d2uda) + np.testing.assert_array_almost_equal( + 2 * dyda_ida[0:200:2], d2uda, decimal=decimal + ) def test_sensitivities_with_events(self): # this test implements a python version of the ida Roberts # example provided in sundials # see sundials ida examples pdf - for form in ["casadi", "python", "jax"]: - if form == "jax" and not pybamm.have_jax(): + for form in ["casadi", "python", "jax", "iree"]: + if (form == "jax" or form == "iree") and not pybamm.have_jax(): + continue + if (form == "iree") and not pybamm.have_iree(): continue if form == "casadi": root_method = "casadi" else: root_method = "lm" model = pybamm.BaseModel() - model.convert_to_format = form + model.convert_to_format = "jax" if form == "iree" else form u = pybamm.Variable("u") v = pybamm.Variable("v") a = pybamm.InputParameter("a") @@ -314,7 +371,10 @@ def test_sensitivities_with_events(self): disc = pybamm.Discretisation() disc.process_model(model) - solver = pybamm.IDAKLUSolver(root_method=root_method) + solver = pybamm.IDAKLUSolver( + root_method=root_method, + options={"jax_evaluator": "iree"} if form == "iree" else {}, + ) t_eval = np.linspace(0, 3, 100) a_value = 0.1 @@ -351,8 +411,11 @@ def test_sensitivities_with_events(self): dyda_fd = (sol_plus.y[:, :max_index] - sol_neg.y[:, :max_index]) / h dyda_fd = dyda_fd.transpose().reshape(-1, 1) + decimal = ( + 2 if form == "iree" else 6 + ) # iree currently operates with single precision np.testing.assert_array_almost_equal( - dyda_ida[: (2 * max_index), :], dyda_fd + dyda_ida[: (2 * max_index), :], dyda_fd, decimal=decimal ) sol_plus = solver.solve( @@ -366,7 +429,7 @@ def test_sensitivities_with_events(self): dydb_fd = dydb_fd.transpose().reshape(-1, 1) np.testing.assert_array_almost_equal( - dydb_ida[: (2 * max_index), :], dydb_fd + dydb_ida[: (2 * max_index), :], dydb_fd, decimal=decimal ) def test_failures(self): @@ -421,15 +484,17 @@ def test_failures(self): solver.solve(model, t_eval) def test_dae_solver_algebraic_model(self): - for form in ["python", "casadi", "jax"]: - if form == "jax" and not pybamm.have_jax(): + for form in ["python", "casadi", "jax", "iree"]: + if (form == "jax" or form == "iree") and not pybamm.have_jax(): + continue + if (form == "iree") and not pybamm.have_iree(): continue if form == "casadi": root_method = "casadi" else: root_method = "lm" model = pybamm.BaseModel() - model.convert_to_format = form + model.convert_to_format = "jax" if form == "iree" else form var = pybamm.Variable("var") model.algebraic = {var: var + 1} model.initial_conditions = {var: 0} @@ -437,7 +502,10 @@ def test_dae_solver_algebraic_model(self): disc = pybamm.Discretisation() disc.process_model(model) - solver = pybamm.IDAKLUSolver(root_method=root_method) + solver = pybamm.IDAKLUSolver( + root_method=root_method, + options={"jax_evaluator": "iree"} if form == "iree" else {}, + ) t_eval = np.linspace(0, 1) solution = solver.solve(model, t_eval) np.testing.assert_array_equal(solution.y, -1) @@ -547,7 +615,7 @@ def test_options(self): soln = solver.solve(model, t_eval) def test_with_output_variables(self): - # Construct a model and solve for all vairables, then test + # Construct a model and solve for all variables, then test # the 'output_variables' option for each variable in turn, confirming # equivalence input_parameters = {} # Sensitivities dictionary @@ -649,76 +717,110 @@ def construct_model(): sol["x_s [m]"].initialise_1D() def test_with_output_variables_and_sensitivities(self): - # Construct a model and solve for all vairables, then test + # Construct a model and solve for all variables, then test # the 'output_variables' option for each variable in turn, confirming # equivalence - # construct model - model = pybamm.lithium_ion.DFN() - geometry = model.default_geometry - param = model.default_parameter_values - input_parameters = { # Sensitivities dictionary - "Current function [A]": 0.680616, - "Separator porosity": 1.0, - } - param.update({key: "[input]" for key in input_parameters}) - param.process_model(model) - param.process_geometry(geometry) - var_pts = {"x_n": 50, "x_s": 50, "x_p": 50, "r_n": 5, "r_p": 5} - mesh = pybamm.Mesh(geometry, model.default_submesh_types, var_pts) - disc = pybamm.Discretisation(mesh, model.default_spatial_methods) - disc.process_model(model) - t_eval = np.linspace(0, 3600, 100) + for form in ["casadi", "iree"]: + if (form == "jax" or form == "iree") and not pybamm.have_jax(): + continue + if (form == "iree") and not pybamm.have_iree(): + continue + if form == "casadi": + root_method = "casadi" + else: + root_method = "lm" + input_parameters = { # Sensitivities dictionary + "Current function [A]": 0.222, + "Separator porosity": 0.3, + } - options = { - "linear_solver": "SUNLinSol_KLU", - "jacobian": "sparse", - "num_threads": 4, - } + # construct model + model = pybamm.lithium_ion.DFN() + model.convert_to_format = "jax" if form == "iree" else form + geometry = model.default_geometry + param = model.default_parameter_values + param.update({key: "[input]" for key in input_parameters}) + param.process_model(model) + param.process_geometry(geometry) + var_pts = {"x_n": 50, "x_s": 50, "x_p": 50, "r_n": 5, "r_p": 5} + mesh = pybamm.Mesh(geometry, model.default_submesh_types, var_pts) + disc = pybamm.Discretisation(mesh, model.default_spatial_methods) + disc.process_model(model) - # Use a selection of variables of different types - output_variables = [ - "Voltage [V]", - "Time [min]", - "x [m]", - "Negative particle flux [mol.m-2.s-1]", - "Throughput capacity [A.h]", # ExplicitTimeIntegral - ] + t_eval = np.linspace(0, 3600, 100) + + options = { + "linear_solver": "SUNLinSol_KLU", + "jacobian": "sparse", + "num_threads": 4, + } + if form == "iree": + options["jax_evaluator"] = "iree" + + # Use a selection of variables of different types + output_variables = [ + "Voltage [V]", + "Time [min]", + "x [m]", + "Negative particle flux [mol.m-2.s-1]", + "Throughput capacity [A.h]", # ExplicitTimeIntegral + ] - # Use the full model as comparison (tested separately) - solver_all = pybamm.IDAKLUSolver( - atol=1e-8, - rtol=1e-8, - options=options, - ) - sol_all = solver_all.solve( - model, - t_eval, - inputs=input_parameters, - calculate_sensitivities=True, - ) + # Use the full model as comparison (tested separately) + solver_all = pybamm.IDAKLUSolver( + root_method=root_method, + atol=1e-8 if form != "iree" else 1e-0, # iree has reduced precision + rtol=1e-8 if form != "iree" else 1e-0, # iree has reduced precision + options=options, + ) + sol_all = solver_all.solve( + model, + t_eval, + inputs=input_parameters, + calculate_sensitivities=True, + ) - # Solve for a subset of variables and compare results - solver = pybamm.IDAKLUSolver( - atol=1e-8, - rtol=1e-8, - options=options, - output_variables=output_variables, - ) - sol = solver.solve( - model, - t_eval, - inputs=input_parameters, - calculate_sensitivities=True, - ) + # Solve for a subset of variables and compare results + solver = pybamm.IDAKLUSolver( + root_method=root_method, + atol=1e-8 if form != "iree" else 1e-0, # iree has reduced precision + rtol=1e-8 if form != "iree" else 1e-0, # iree has reduced precision + options=options, + output_variables=output_variables, + ) + sol = solver.solve( + model, + t_eval, + inputs=input_parameters, + calculate_sensitivities=True, + ) - # Compare output to sol_all - for varname in output_variables: - self.assertTrue(np.allclose(sol[varname].data, sol_all[varname].data)) + # Compare output to sol_all + tol = 1e-5 if form != "iree" else 1e-2 # iree has reduced precision + for varname in output_variables: + np.testing.assert_array_almost_equal( + sol[varname].data, sol_all[varname].data, tol + ) - # Mock a 1D current collector and initialise (none in the model) - sol["x_s [m]"].domain = ["current collector"] - sol["x_s [m]"].initialise_1D() + # Mock a 1D current collector and initialise (none in the model) + sol["x_s [m]"].domain = ["current collector"] + sol["x_s [m]"].initialise_1D() + + def test_bad_jax_evaluator(self): + model = pybamm.lithium_ion.DFN() + model.convert_to_format = "jax" + with self.assertRaises(pybamm.SolverError): + pybamm.IDAKLUSolver(options={"jax_evaluator": "bad_evaluator"}) + + def test_bad_jax_evaluator_output_variables(self): + model = pybamm.lithium_ion.DFN() + model.convert_to_format = "jax" + with self.assertRaises(pybamm.SolverError): + pybamm.IDAKLUSolver( + options={"jax_evaluator": "bad_evaluator"}, + output_variables=["Terminal voltage [V]"], + ) if __name__ == "__main__": diff --git a/tests/unit/test_solvers/test_lrudict.py b/tests/unit/test_solvers/test_lrudict.py index a5378da786..ab38bddbc5 100644 --- a/tests/unit/test_solvers/test_lrudict.py +++ b/tests/unit/test_solvers/test_lrudict.py @@ -1,12 +1,12 @@ # # Tests for the LRUDict class # -import unittest +import pytest from pybamm.solvers.lrudict import LRUDict from collections import OrderedDict -class TestLRUDict(unittest.TestCase): +class TestLRUDict: def test_lrudict_defaultbehaviour(self): """Default behaviour [no LRU] mimics Dict""" d = LRUDict() @@ -20,27 +20,27 @@ def test_lrudict_defaultbehaviour(self): dd.get(count - 2) # assertCountEqual checks that the same elements are present in # both lists, not just that the lists are of equal count - self.assertCountEqual(set(d.keys()), set(dd.keys())) - self.assertCountEqual(set(d.values()), set(dd.values())) + assert set(d.keys()) == set(dd.keys()) + assert set(d.values()) == set(dd.values()) def test_lrudict_noitems(self): """Edge case: no items in LRU, raises KeyError on assignment""" d = LRUDict(maxsize=-1) - with self.assertRaises(KeyError): + with pytest.raises(KeyError): d["a"] = 1 def test_lrudict_singleitem(self): """Only the last added element should ever be present""" d = LRUDict(maxsize=1) item_list = range(1, 100) - self.assertEqual(len(d), 0) + assert len(d) == 0 for item in item_list: d[item] = item - self.assertEqual(len(d), 1) - self.assertIsNotNone(d[item]) + assert len(d) == 1 + assert d[item] is not None # Finally, pop the only item and check that the dictionary is empty d.popitem() - self.assertEqual(len(d), 0) + assert len(d) == 0 def test_lrudict_multiitem(self): """Check that the correctly ordered items are always present""" @@ -59,17 +59,17 @@ def test_lrudict_multiitem(self): expected = OrderedDict( (k, expected[k]) for k in list(expected.keys())[-maxsize:] ) - self.assertListEqual(list(d.keys()), list(expected.keys())) - self.assertListEqual(list(d.values()), list(expected.values())) + assert list(d.keys()) == list(expected.keys()) + assert list(d.values()) == list(expected.values()) def test_lrudict_invalidkey(self): d = LRUDict() value = 1 d["a"] = value # Access with valid key - self.assertEqual(d["a"], value) # checks getitem() - self.assertEqual(d.get("a"), value) # checks get() + assert d["a"] == value # checks getitem() + assert d.get("a") == value # checks get() # Access with invalid key - with self.assertRaises(KeyError): + with pytest.raises(KeyError): _ = d["b"] # checks getitem() - self.assertIsNone(d.get("b")) # checks get() + assert d.get("b") is None # checks get() diff --git a/tests/unit/test_spatial_methods/test_zero_dimensional_method.py b/tests/unit/test_spatial_methods/test_zero_dimensional_method.py index b3ec859412..1c620c7872 100644 --- a/tests/unit/test_spatial_methods/test_zero_dimensional_method.py +++ b/tests/unit/test_spatial_methods/test_zero_dimensional_method.py @@ -1,14 +1,12 @@ # # Test for the base Spatial Method class # -from tests import TestCase import numpy as np import pybamm -import unittest from tests import get_mesh_for_testing, get_discretisation_for_testing -class TestZeroDimensionalSpatialMethod(TestCase): +class TestZeroDimensionalSpatialMethod: def test_identity_ops(self): test_mesh = np.array([1, 2, 3]) spatial_method = pybamm.ZeroDimensionalSpatialMethod() @@ -16,14 +14,14 @@ def test_identity_ops(self): np.testing.assert_array_equal(spatial_method._mesh, test_mesh) a = pybamm.Symbol("a") - self.assertEqual(a, spatial_method.integral(None, a, "primary")) - self.assertEqual(a, spatial_method.indefinite_integral(None, a, "forward")) - self.assertEqual(a, spatial_method.boundary_value_or_flux(None, a)) - self.assertEqual((-a), spatial_method.indefinite_integral(None, a, "backward")) + assert a == spatial_method.integral(None, a, "primary") + assert a == spatial_method.indefinite_integral(None, a, "forward") + assert a == spatial_method.boundary_value_or_flux(None, a) + assert (-a) == spatial_method.indefinite_integral(None, a, "backward") mass_matrix = spatial_method.mass_matrix(None, None) - self.assertIsInstance(mass_matrix, pybamm.Matrix) - self.assertEqual(mass_matrix.shape, (1, 1)) + assert isinstance(mass_matrix, pybamm.Matrix) + assert mass_matrix.shape == (1, 1) np.testing.assert_array_equal(mass_matrix.entries, 1) def test_discretise_spatial_variable(self): @@ -38,7 +36,7 @@ def test_discretise_spatial_variable(self): r = pybamm.SpatialVariable("r", ["negative particle"]) for var in [x1, x2, r]: var_disc = spatial_method.spatial_variable(var) - self.assertIsInstance(var_disc, pybamm.Vector) + assert isinstance(var_disc, pybamm.Vector) np.testing.assert_array_equal( var_disc.evaluate()[:, 0], mesh[var.domain].nodes ) @@ -49,7 +47,7 @@ def test_discretise_spatial_variable(self): r_edge = pybamm.SpatialVariableEdge("r", ["negative particle"]) for var in [x1_edge, x2_edge, r_edge]: var_disc = spatial_method.spatial_variable(var) - self.assertIsInstance(var_disc, pybamm.Vector) + assert isinstance(var_disc, pybamm.Vector) np.testing.assert_array_equal( var_disc.evaluate()[:, 0], mesh[var.domain].edges ) @@ -70,13 +68,3 @@ def test_averages(self): np.testing.assert_array_equal( var_disc.evaluate(y=y), expr_disc.evaluate(y=y) ) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_timer.py b/tests/unit/test_timer.py index 228cdd5dce..6ef62f791e 100644 --- a/tests/unit/test_timer.py +++ b/tests/unit/test_timer.py @@ -5,11 +5,9 @@ # (see https://github.com/pints-team/pints) # import pybamm -import unittest -from tests import TestCase -class TestTimer(TestCase): +class TestTimer: """ Tests the basic methods of the Timer class. """ @@ -20,64 +18,54 @@ def __init__(self, name): def test_timing(self): t = pybamm.Timer() a = t.time().value - self.assertGreaterEqual(a, 0) + assert a >= 0 for _ in range(100): - self.assertGreater(t.time().value, a) + assert t.time().value > a a = t.time().value t.reset() b = t.time().value - self.assertGreaterEqual(b, 0) - self.assertLess(b, a) + assert b >= 0 + assert b < a def test_timer_format(self): - self.assertEqual(str(pybamm.TimerTime(1e-9)), "1.000 ns") - self.assertEqual(str(pybamm.TimerTime(0.000000123456789)), "123.457 ns") - self.assertEqual(str(pybamm.TimerTime(1e-6)), "1.000 us") - self.assertEqual(str(pybamm.TimerTime(0.000123456789)), "123.457 us") - self.assertEqual(str(pybamm.TimerTime(0.999e-3)), "999.000 us") - self.assertEqual(str(pybamm.TimerTime(1e-3)), "1.000 ms") - self.assertEqual(str(pybamm.TimerTime(0.123456789)), "123.457 ms") - self.assertEqual(str(pybamm.TimerTime(2)), "2.000 s") - self.assertEqual(str(pybamm.TimerTime(2.5)), "2.500 s") - self.assertEqual(str(pybamm.TimerTime(12.5)), "12.500 s") - self.assertEqual(str(pybamm.TimerTime(59.41)), "59.410 s") - self.assertEqual(str(pybamm.TimerTime(59.4126347547)), "59.413 s") - self.assertEqual(str(pybamm.TimerTime(60.2)), "1 minute, 0 seconds") - self.assertEqual(str(pybamm.TimerTime(61)), "1 minute, 1 second") - self.assertEqual(str(pybamm.TimerTime(121)), "2 minutes, 1 second") - self.assertEqual( - str(pybamm.TimerTime(604800)), - "1 week, 0 days, 0 hours, 0 minutes, 0 seconds", + assert str(pybamm.TimerTime(1e-9)) == "1.000 ns" + assert str(pybamm.TimerTime(0.000000123456789)) == "123.457 ns" + assert str(pybamm.TimerTime(1e-6)) == "1.000 us" + assert str(pybamm.TimerTime(0.000123456789)) == "123.457 us" + assert str(pybamm.TimerTime(0.999e-3)) == "999.000 us" + assert str(pybamm.TimerTime(1e-3)) == "1.000 ms" + assert str(pybamm.TimerTime(0.123456789)) == "123.457 ms" + assert str(pybamm.TimerTime(2)) == "2.000 s" + assert str(pybamm.TimerTime(2.5)) == "2.500 s" + assert str(pybamm.TimerTime(12.5)) == "12.500 s" + assert str(pybamm.TimerTime(59.41)) == "59.410 s" + assert str(pybamm.TimerTime(59.4126347547)) == "59.413 s" + assert str(pybamm.TimerTime(60.2)) == "1 minute, 0 seconds" + assert str(pybamm.TimerTime(61)) == "1 minute, 1 second" + assert str(pybamm.TimerTime(121)) == "2 minutes, 1 second" + assert ( + str(pybamm.TimerTime(604800)) + == "1 week, 0 days, 0 hours, 0 minutes, 0 seconds" ) - self.assertEqual( - str(pybamm.TimerTime(2 * 604800 + 3 * 3600 + 60 + 4)), - "2 weeks, 0 days, 3 hours, 1 minute, 4 seconds", + assert ( + str(pybamm.TimerTime(2 * 604800 + 3 * 3600 + 60 + 4)) + == "2 weeks, 0 days, 3 hours, 1 minute, 4 seconds" ) - self.assertEqual(repr(pybamm.TimerTime(1.5)), "pybamm.TimerTime(1.5)") + assert repr(pybamm.TimerTime(1.5)) == "pybamm.TimerTime(1.5)" def test_timer_operations(self): - self.assertEqual((pybamm.TimerTime(1) + 2).value, 3) - self.assertEqual((1 + pybamm.TimerTime(1)).value, 2) - self.assertEqual((pybamm.TimerTime(1) - 2).value, -1) - self.assertEqual((pybamm.TimerTime(1) - pybamm.TimerTime(2)).value, -1) - self.assertEqual((1 - pybamm.TimerTime(1)).value, 0) - self.assertEqual((pybamm.TimerTime(4) * 2).value, 8) - self.assertEqual((pybamm.TimerTime(4) * pybamm.TimerTime(2)).value, 8) - self.assertEqual((2 * pybamm.TimerTime(5)).value, 10) - self.assertEqual((pybamm.TimerTime(4) / 2).value, 2) - self.assertEqual((pybamm.TimerTime(4) / pybamm.TimerTime(2)).value, 2) - self.assertEqual((2 / pybamm.TimerTime(5)).value, 2 / 5) + assert (pybamm.TimerTime(1) + 2).value == 3 + assert (1 + pybamm.TimerTime(1)).value == 2 + assert (pybamm.TimerTime(1) - 2).value == -1 + assert (pybamm.TimerTime(1) - pybamm.TimerTime(2)).value == -1 + assert (1 - pybamm.TimerTime(1)).value == 0 + assert (pybamm.TimerTime(4) * 2).value == 8 + assert (pybamm.TimerTime(4) * pybamm.TimerTime(2)).value == 8 + assert (2 * pybamm.TimerTime(5)).value == 10 + assert (pybamm.TimerTime(4) / 2).value == 2 + assert (pybamm.TimerTime(4) / pybamm.TimerTime(2)).value == 2 + assert (2 / pybamm.TimerTime(5)).value == 2 / 5 - self.assertTrue(pybamm.TimerTime(1) == pybamm.TimerTime(1)) - self.assertTrue(pybamm.TimerTime(1) != pybamm.TimerTime(2)) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() + assert pybamm.TimerTime(1) == pybamm.TimerTime(1) + assert pybamm.TimerTime(1) != pybamm.TimerTime(2) diff --git a/vcpkg.json b/vcpkg.json index 4e2fb4fe7e..9134ac3fd9 100644 --- a/vcpkg.json +++ b/vcpkg.json @@ -1,6 +1,6 @@ { "name": "pybamm", - "version-string": "24.5rc0", + "version-string": "24.5rc2", "dependencies": [ "casadi", {