From 68ee924e52dc7ced2eb9ad4f52ea98525ef614f8 Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Wed, 15 Nov 2023 14:31:47 +0100 Subject: [PATCH] Enable deepcopy for ExpData(View) Fixes a bug in SwigPtrView.__deepcopy__ which did not produce a deep copy. Add SwigPtrView.__eq__ to allow for comparison. The view objects are considered equal if the underlying viewed objects are equal. Fixes #2189. --- python/sdist/amici/numpy.py | 18 +++++++++++++++++- python/tests/test_swig_interface.py | 19 +++++++++++++++++++ 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/python/sdist/amici/numpy.py b/python/sdist/amici/numpy.py index d9b34b6447..1566b7654c 100644 --- a/python/sdist/amici/numpy.py +++ b/python/sdist/amici/numpy.py @@ -137,7 +137,8 @@ def __deepcopy__(self, memo): :returns: SwigPtrView deep copy """ - other = SwigPtrView(self._swigptr) + # We assume we have a copy-ctor for the swigptr object + other = self.__class__(copy.deepcopy(self._swigptr)) other._field_names = copy.deepcopy(self._field_names) other._field_dimensions = copy.deepcopy(self._field_dimensions) other._cache = copy.deepcopy(self._cache) @@ -151,6 +152,18 @@ def __repr__(self): """ return f"<{self.__class__.__name__}({self._swigptr})>" + def __eq__(self, other): + """ + Equality check + + :param other: other object + + :returns: whether other object is equal to this object + """ + if not isinstance(other, self.__class__): + return False + return self._swigptr == other._swigptr + class ReturnDataView(SwigPtrView): """ @@ -340,6 +353,9 @@ class ExpDataView(SwigPtrView): """ Interface class for C++ Exp Data objects that avoids possibly costly copies of member data. + + NOTE: This currently assumes that the underlying :class:`ExpData` + does not change after instantiating an :class:`ExpDataView`. """ _field_names = [ diff --git a/python/tests/test_swig_interface.py b/python/tests/test_swig_interface.py index 8c895eb852..a746552b55 100644 --- a/python/tests/test_swig_interface.py +++ b/python/tests/test_swig_interface.py @@ -7,6 +7,7 @@ import numbers import amici +import numpy as np def test_version_number(pysb_example_presimulation_module): @@ -451,3 +452,21 @@ def test_edata_equality_operator(): # check that comparison with other types works # this is not implemented by swig by default assert e1 != 1 + + +def test_expdata_and_expdataview_are_deepcopyable(): + edata1 = amici.ExpData(3, 2, 3, range(4)) + edata1.setObservedData(np.zeros((3, 4)).flatten()) + + # ExpData + edata2 = copy.deepcopy(edata1) + assert edata1 == edata2 + assert edata1.this != edata2.this + edata2.setTimepoints([0]) + assert edata1 != edata2 + + # ExpDataView + ev1 = amici.ExpDataView(edata1) + ev2 = copy.deepcopy(ev1) + assert ev2._swigptr.this != ev1._swigptr.this + assert ev1 == ev2