Serialisation of models #3397

Merged Nov 28, 2023 (32 commits)

Commits
a4fedae
Draft a serialisation method
pipliggins Aug 11, 2023
70b765d
Move deserialisation functions to Symbol classes
pipliggins Aug 18, 2023
4ea8108
Create Serialise class
pipliggins Aug 24, 2023
6694cb1
Serialised models can be plotted.
pipliggins Sep 1, 2023
7fadee3
Add unit tests for to_json()
pipliggins Sep 19, 2023
25cb002
Allow saving of geometry where symbols are dict keys
pipliggins Sep 21, 2023
efa7888
Add _from_json tests for symbols without children
pipliggins Sep 21, 2023
fbc8f6f
(wip) testing: add draft de/serialisation tests
pipliggins Sep 22, 2023
4745484
(wip) tests: add _from_json tests with children
pipliggins Sep 22, 2023
80fc250
testing: add unit tests for Serialise() functions
pipliggins Sep 27, 2023
ac928ab
testing: save/load model tests
pipliggins Sep 29, 2023
9e323d9
testing: Add integration tests
pipliggins Sep 29, 2023
2934df4
Add docs for serialisation
pipliggins Oct 2, 2023
66d8045
Increase test coverage
pipliggins Oct 2, 2023
d5dd21d
Fix minor style issues
pipliggins Oct 2, 2023
6d63732
Remove accidental SpatialOperator.diff() addition
pipliggins Oct 5, 2023
0cc0aee
Edits after review
pipliggins Oct 19, 2023
2a72cf8
Merge branch 'develop' into serialisation
pipliggins Oct 19, 2023
616c0d8
Serialisation: fix integration tests
pipliggins Oct 20, 2023
8e32718
Reduce test tolerance of sei_asymmetric_ec_reaction_limited
pipliggins Oct 20, 2023
1e16b92
fix: change serialisation test accuracy
pipliggins Oct 21, 2023
62a46ef
Additional tests for codecov
pipliggins Nov 8, 2023
5211233
More coverage updates to serialise and 1D meshes
pipliggins Nov 10, 2023
a1ac313
Merge branch 'develop' into serialisation
pipliggins Nov 16, 2023
afa187e
Update CHANGELOG
pipliggins Nov 16, 2023
b745317
style: pre-commit fixes
pre-commit-ci[bot] Nov 16, 2023
3fec37e
Add error message for experiment
pipliggins Nov 24, 2023
92e7c90
Update notebook to suggest build() not solve()
pipliggins Nov 24, 2023
95935a0
Merge branch 'develop' into serialisation
pipliggins Nov 27, 2023
04f4230
Add outputs to example notebook
pipliggins Nov 28, 2023
ca63509
style: pre-commit fixes
pre-commit-ci[bot] Nov 28, 2023
df35b91
Fix ruff errors
pipliggins Nov 28, 2023
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,9 @@
# [Unreleased](https://github.com/pybamm-team/PyBaMM/)

## Features

- Serialisation added so models can be written to/read from JSON ([#3397](https://github.com/pybamm-team/PyBaMM/pull/3397))

## Bug fixes

- Fixed a bug where simulations using the CasADi-based solvers would fail randomly with the half-cell model ([#3494](https://github.com/pybamm-team/PyBaMM/pull/3494))
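A hedged sketch of the new feature for orientation (the save_model/load_model call signatures here are assumptions drawn from the saving_models notebook and the pybamm.load_model import added in this PR):

```python
# Sketch only: exact keyword names for save_model() are assumed, not confirmed.
import pybamm

model = pybamm.lithium_ion.SPM()
sim = pybamm.Simulation(model)
sim.build()  # the notebook recommends build() rather than solve() before saving

sim.built_model.save_model(
    filename="spm", mesh=sim.mesh, variables=sim.built_model.variables
)
loaded_model = pybamm.load_model("spm.json")
```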
1 change: 1 addition & 0 deletions docs/source/api/expression_tree/operations/index.rst
@@ -8,4 +8,5 @@ Classes and functions that operate on the expression tree
evaluate
jacobian
convert_to_casadi
serialise
unpack_symbol
5 changes: 5 additions & 0 deletions docs/source/api/expression_tree/operations/serialise.rst
@@ -0,0 +1,5 @@
Serialise
=========

.. autoclass:: pybamm.expression_tree.operations.serialise.Serialise
:members:
1 change: 1 addition & 0 deletions docs/source/examples/index.rst
@@ -63,6 +63,7 @@ The notebooks are organised into subfolders, and can be viewed in the galleries
notebooks/models/MSMR.ipynb
notebooks/models/pouch-cell-model.ipynb
notebooks/models/rate-capability.ipynb
notebooks/models/saving_models.ipynb
notebooks/models/SEI-on-cracks.ipynb
notebooks/models/simulating-ORegan-2022-parameter-set.ipynb
notebooks/models/SPM.ipynb
376 changes: 376 additions & 0 deletions docs/source/examples/notebooks/models/saving_models.ipynb

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions pybamm/__init__.py
@@ -188,6 +188,11 @@
UserSupplied2DSubMesh,
)

#
# Serialisation
#
from .models.base_model import load_model

#
# Spatial Methods
#
48 changes: 48 additions & 0 deletions pybamm/expression_tree/array.py
@@ -57,6 +57,30 @@ def __init__(
name, domain=domain, auxiliary_domains=auxiliary_domains, domains=domains
)

@classmethod
def _from_json(cls, snippet: dict):
Review comment from @Saransh-cpp (Member), Nov 24, 2023:

Given that PyBaMM supports Python 3.8+, each file using the newer type hints must import

from __future__ import annotations

at the top to ensure backward compatibility. This can also be automated using the isort rules in ruff.
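A minimal illustration of the point above (an illustrative snippet, not part of the diff): with the future import, annotations are not evaluated at import time, so new-style hints still load on Python 3.8.

```python
# Illustrative only: the future import turns annotations into strings,
# so the PEP 585/604 hints below do not raise at import time on Python 3.8.
from __future__ import annotations


def example_from_json(snippet: dict[str, list[int] | None]) -> dict:
    # hypothetical helper, included only to show the hint style
    return snippet
```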

instance = cls.__new__(cls)

if isinstance(snippet["entries"], dict):
matrix = csr_matrix(
(
snippet["entries"]["data"],
snippet["entries"]["row_indices"],
snippet["entries"]["column_pointers"],
),
shape=snippet["entries"]["shape"],
)
else:
matrix = snippet["entries"]

instance.__init__(
matrix,
name=snippet["name"],
domains=snippet["domains"],
)

return instance

@property
def entries(self):
return self._entries
@@ -129,6 +153,30 @@ def to_equation(self):
entries_list = self.entries.tolist()
return sympy.Array(entries_list)

def to_json(self):
"""
Method to serialise an Array object into JSON.
"""

if isinstance(self.entries, np.ndarray):
matrix = self.entries.tolist()
elif isinstance(self.entries, csr_matrix):
matrix = {
"shape": self.entries.shape,
"data": self.entries.data.tolist(),
"row_indices": self.entries.indices.tolist(),
"column_pointers": self.entries.indptr.tolist(),
}

json_dict = {
"name": self.name,
"id": self.id,
"domains": self.domains,
"entries": matrix,
}

return json_dict


def linspace(start, stop, num=50, **kwargs):
"""
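A hedged round-trip sketch of the sparse handling in array.py above (pybamm.Matrix is the existing Array subclass commonly used to hold a scipy csr_matrix; the snippet simply exercises the to_json()/_from_json() pair from this diff):

```python
# Illustrative round trip of a sparse Array through to_json()/_from_json(),
# following the dict layout above (shape/data/row_indices/column_pointers).
import numpy as np
from scipy.sparse import csr_matrix

import pybamm

original = pybamm.Matrix(csr_matrix(np.eye(3)))
snippet = original.to_json()            # sparse entries become a plain dict
restored = pybamm.Matrix._from_json(snippet)

assert np.allclose(restored.entries.toarray(), np.eye(3))
```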
26 changes: 26 additions & 0 deletions pybamm/expression_tree/binary_operators.py
@@ -68,6 +68,23 @@ def __init__(self, name, left, right):
self.left = self.children[0]
self.right = self.children[1]

@classmethod
def _from_json(cls, snippet: dict):
"""Use to instantiate when deserialising; discretisation has
already occurred so pre-processing of binaries is not necessary."""

instance = cls.__new__(cls)

super(BinaryOperator, instance).__init__(
snippet["name"],
children=[snippet["children"][0], snippet["children"][1]],
domains=snippet["domains"],
)
instance.left = instance.children[0]
instance.right = instance.children[1]

return instance

def __str__(self):
"""See :meth:`pybamm.Symbol.__str__()`."""
# Possibly add brackets for clarity
@@ -156,6 +173,15 @@ def to_equation(self):
eq2 = child2.to_equation()
return self._sympy_operator(eq1, eq2)

def to_json(self):
"""
Method to serialise a BinaryOperator object into JSON.
"""

json_dict = {"name": self.name, "id": self.id, "domains": self.domains}

return json_dict


class Power(BinaryOperator):
"""
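A small illustration of the binary_operators.py changes above: a binary node serialises only itself (name, id, domains); its children are expected to be reconstructed separately and passed back in through snippet["children"] on load.

```python
# Illustrative: to_json() on a BinaryOperator carries no child entries.
import pybamm

expr = pybamm.Variable("a") + pybamm.Variable("b")
print(expr.to_json())
# e.g. {'name': '+', 'id': ..., 'domains': {'primary': [], ...}}
```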
11 changes: 11 additions & 0 deletions pybamm/expression_tree/broadcasts.py
@@ -50,6 +50,17 @@ def _diff(self, variable):
# Differentiate the child and broadcast the result in the same way
return self._unary_new_copy(self.child.diff(variable))

def to_json(self):
raise NotImplementedError(
"pybamm.Broadcast: Serialisation is only implemented for discretised models"
)

@classmethod
def _from_json(cls, snippet):
raise NotImplementedError(
"pybamm.Broadcast: Please use a discretised model when reading in from JSON"
)


class PrimaryBroadcast(Broadcast):
"""
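A short illustration of the guard in broadcasts.py above: broadcasts only exist before discretisation, so serialising one raises and points the user towards a discretised model.

```python
# Illustrative: serialising a Broadcast is explicitly unsupported.
import pybamm

broad = pybamm.PrimaryBroadcast(pybamm.Scalar(1), "negative particle")
try:
    broad.to_json()
except NotImplementedError as error:
    print(error)
```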
74 changes: 74 additions & 0 deletions pybamm/expression_tree/concatenations.py
@@ -43,6 +43,17 @@ def __init__(self, *children, name=None, check_domain=True, concat_fun=None):

super().__init__(name, children, domains=domains)

@classmethod
def _from_json(cls, *children, name, domains, concat_fun=None):
"""Creates a new Concatenation instance from a json object"""
instance = cls.__new__(cls)

instance.concatenation_function = concat_fun

super(Concatenation, instance).__init__(name, children, domains=domains)

return instance

def __str__(self):
"""See :meth:`pybamm.Symbol.__str__()`."""
out = self.name + "("
@@ -183,6 +194,18 @@ def __init__(self, *children):
concat_fun=np.concatenate
)

@classmethod
def _from_json(cls, snippet: dict):
"""See :meth:`pybamm.Concatenation._from_json()`."""
instance = super()._from_json(
*snippet["children"],
name="numpy_concatenation",
domains=snippet["domains"],
concat_fun=np.concatenate
)

return instance

def _concatenation_jac(self, children_jacs):
"""See :meth:`pybamm.Concatenation.concatenation_jac()`."""
children = self.children
@@ -251,6 +274,31 @@ def __init__(self, children, full_mesh, copy_this=None):
self._children_slices = copy.copy(copy_this._children_slices)
self.secondary_dimensions_npts = copy_this.secondary_dimensions_npts

@classmethod
def _from_json(cls, snippet: dict):
"""See :meth:`pybamm.Concatenation._from_json()`."""
instance = super()._from_json(
*snippet["children"],
name="domain_concatenation",
domains=snippet["domains"]
)

def repack_defaultDict(slices):
slices = defaultdict(list, slices)
for domain, sls in slices.items():
sls = [slice(s["start"], s["stop"], s["step"]) for s in sls]
slices[domain] = sls
return slices

instance._size = snippet["size"]
instance._slices = repack_defaultDict(snippet["slices"])
instance._children_slices = [
repack_defaultDict(s) for s in snippet["children_slices"]
]
instance.secondary_dimensions_npts = snippet["secondary_dimensions_npts"]

return instance

def _get_auxiliary_domain_repeats(self, auxiliary_domains):
"""Helper method to read the 'auxiliary_domain' meshes."""
mesh_pts = 1
@@ -316,6 +364,32 @@ def _concatenation_new_copy(self, children):
)
return new_symbol

def to_json(self):
"""
Method to serialise a DomainConcatenation object into JSON.
"""

def unpack_defaultDict(slices):
slices = dict(slices)
for domain, sls in slices.items():
sls = [{"start": s.start, "stop": s.stop, "step": s.step} for s in sls]
slices[domain] = sls
return slices

json_dict = {
"name": self.name,
"id": self.id,
"domains": self.domains,
"slices": unpack_defaultDict(self._slices),
"size": self._size,
"children_slices": [
unpack_defaultDict(child_slice) for child_slice in self._children_slices
],
"secondary_dimensions_npts": self.secondary_dimensions_npts,
}

return json_dict


class SparseStack(Concatenation):
"""
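The slice handling in DomainConcatenation above, in isolation (a plain-Python sketch of what unpack_defaultDict/repack_defaultDict do, since slice objects are not JSON-serialisable):

```python
# Sketch: a python slice is stored as a start/stop/step dict for JSON,
# then rebuilt on load, mirroring unpack_defaultDict/repack_defaultDict.
s = slice(0, 100, None)
packed = {"start": s.start, "stop": s.stop, "step": s.step}
rebuilt = slice(packed["start"], packed["stop"], packed["step"])
assert rebuilt == s
```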