Skip to content

Commit

Permalink
[runtime][python] Fix device array deepcopy when not mappable
Browse files Browse the repository at this point in the history
If the device array is not mappable to the host the deepcopy would fail
and result in an exception.

To mitigate this in case it is not mappable we copy.

Implement the __deepcopy__ to avoid double copy when not mappable.
If deepcopy is done through __reduce__ then once we would copy form
device to host and another time inside the deepcopy implementation.

Signed-off-by: Boian Petkantchin <[email protected]>
  • Loading branch information
sogartar committed Nov 19, 2024
1 parent b68c535 commit b9bc221
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
11 changes: 6 additions & 5 deletions runtime/bindings/python/iree/runtime/array_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,12 +209,13 @@ def __getitem__(self, index):
host_ary = self.to_host()
return host_ary.__getitem__(index)

def __deepcopy__(self, memo):
return self.to_host()

def __reduce__(self):
# Since this is used for making deep copies and pickling, we map
# separately from any interactive state. We just reduce to the actual
# host ndarray, which supports the necessary serialization protocols.
_, host_array = self._map_to_host()
return _restore_reduced_array, (host_array,)
# We just reduce to the actual host ndarray, which supports the necessary
# serialization protocols.
return _restore_reduced_array, (self.to_host(),)


def _restore_reduced_array(ary):
Expand Down
11 changes: 11 additions & 0 deletions runtime/bindings/python/tests/array_interop_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import copy
import pickle
import gc
import numpy as np
import unittest
Expand Down Expand Up @@ -137,6 +138,16 @@ def testDeepcopy(self):
self.assertIsNot(orig_ary, copy_ary)
np.testing.assert_array_equal(orig_ary, copy_ary)

def testPickle(self):
init_ary = np.arange(3 * 4, dtype=np.float32).reshape([3, 4])
orig_ary = iree.runtime.asdevicearray(
self.device, init_ary, implicit_host_transfer=True
)
serialized_ary = pickle.dumps(orig_ary)
copy_ary = pickle.loads(serialized_ary)
self.assertIsNot(orig_ary, copy_ary)
np.testing.assert_array_equal(orig_ary, copy_ary)

def testAsType(self):
init_ary = np.zeros([3, 4], dtype=np.int32) + 2
orig_ary = iree.runtime.asdevicearray(
Expand Down

0 comments on commit b9bc221

Please sign in to comment.