Skip to content

Commit

Permalink
Enable deepcopy for ExpData(View)
Browse files Browse the repository at this point in the history
Fixes a bug in SwigPtrView.__deepcopy__ which did not produce a deep copy.

Add SwigPtrView.__eq__ to allow for comparison. The view objects are considered
equal if the underlying viewed objects are equal.

Fixes #2189.
  • Loading branch information
dweindl committed Nov 15, 2023
1 parent dcbd4a3 commit 68ee924
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 1 deletion.
18 changes: 17 additions & 1 deletion python/sdist/amici/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,8 @@ def __deepcopy__(self, memo):
:returns: SwigPtrView deep copy
"""
other = SwigPtrView(self._swigptr)
# We assume we have a copy-ctor for the swigptr object
other = self.__class__(copy.deepcopy(self._swigptr))
other._field_names = copy.deepcopy(self._field_names)
other._field_dimensions = copy.deepcopy(self._field_dimensions)
other._cache = copy.deepcopy(self._cache)
Expand All @@ -151,6 +152,18 @@ def __repr__(self):
"""
return f"<{self.__class__.__name__}({self._swigptr})>"

def __eq__(self, other):
"""
Equality check
:param other: other object
:returns: whether other object is equal to this object
"""
if not isinstance(other, self.__class__):
return False
return self._swigptr == other._swigptr


class ReturnDataView(SwigPtrView):
"""
Expand Down Expand Up @@ -340,6 +353,9 @@ class ExpDataView(SwigPtrView):
"""
Interface class for C++ Exp Data objects that avoids possibly costly
copies of member data.
NOTE: This currently assumes that the underlying :class:`ExpData`
does not change after instantiating an :class:`ExpDataView`.
"""

_field_names = [
Expand Down
19 changes: 19 additions & 0 deletions python/tests/test_swig_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numbers

import amici
import numpy as np


def test_version_number(pysb_example_presimulation_module):
Expand Down Expand Up @@ -451,3 +452,21 @@ def test_edata_equality_operator():
# check that comparison with other types works
# this is not implemented by swig by default
assert e1 != 1


def test_expdata_and_expdataview_are_deepcopyable():
edata1 = amici.ExpData(3, 2, 3, range(4))
edata1.setObservedData(np.zeros((3, 4)).flatten())

# ExpData
edata2 = copy.deepcopy(edata1)
assert edata1 == edata2
assert edata1.this != edata2.this
edata2.setTimepoints([0])
assert edata1 != edata2

# ExpDataView
ev1 = amici.ExpDataView(edata1)
ev2 = copy.deepcopy(ev1)
assert ev2._swigptr.this != ev1._swigptr.this
assert ev1 == ev2

0 comments on commit 68ee924

Please sign in to comment.