diff --git a/runtime/bindings/python/hal.cc b/runtime/bindings/python/hal.cc index fba5a73cf2e9..8ca1078a4d2c 100644 --- a/runtime/bindings/python/hal.cc +++ b/runtime/bindings/python/hal.cc @@ -57,6 +57,18 @@ static const char kHalDeviceQueueExecute[] = signal_semaphores: Semaphores/Fence to signal. )"; +static const char kHalDeviceQueueCopy[] = + R"(Copy data from a source buffer to destination buffer. + +Args: + source_buffer: `HalBuffer` that holds src data. + target_buffer: `HalBuffer` that will receive data. + wait_semaphores: `List[Tuple[HalSemaphore, int]]` of semaphore values or + a HalFence. The allocation will be made once these semaphores are + satisfied. + signal_semaphores: Semaphores/Fence to signal. +)"; + static const char kHalFenceWait[] = R"(Waits until the fence is signalled or errored. @@ -524,6 +536,69 @@ void HalDevice::QueueExecute(py::handle command_buffers, "executing command buffers"); } +void HalDevice::QueueCopy(HalBuffer& source_buffer, HalBuffer& target_buffer, + py::handle wait_semaphores, + py::handle signal_semaphores) { + iree_hal_semaphore_list_t wait_list; + iree_hal_semaphore_list_t signal_list; + + // Wait list. + if (py::isinstance(wait_semaphores)) { + wait_list = iree_hal_fence_semaphore_list( + py::cast(wait_semaphores)->raw_ptr()); + } else { + size_t wait_count = py::len(wait_semaphores); + wait_list = { + wait_count, + /*semaphores=*/ + static_cast( + alloca(sizeof(iree_hal_semaphore_t*) * wait_count)), + /*payload_values=*/ + static_cast(alloca(sizeof(uint64_t) * wait_count)), + }; + for (size_t i = 0; i < wait_count; ++i) { + py::tuple pair = wait_semaphores[i]; + wait_list.semaphores[i] = py::cast(pair[0])->raw_ptr(); + wait_list.payload_values[i] = py::cast(pair[1]); + } + } + + // Signal list. + if (py::isinstance(signal_semaphores)) { + signal_list = iree_hal_fence_semaphore_list( + py::cast(signal_semaphores)->raw_ptr()); + } else { + size_t signal_count = py::len(signal_semaphores); + signal_list = { + signal_count, + /*semaphores=*/ + static_cast( + alloca(sizeof(iree_hal_semaphore_t*) * signal_count)), + /*payload_values=*/ + static_cast(alloca(sizeof(uint64_t) * signal_count)), + }; + for (size_t i = 0; i < signal_count; ++i) { + py::tuple pair = signal_semaphores[i]; + signal_list.semaphores[i] = py::cast(pair[0])->raw_ptr(); + signal_list.payload_values[i] = py::cast(pair[1]); + } + } + + // TODO: Accept params for src_offset and target_offset. + iree_device_size_t source_length = + iree_hal_buffer_byte_length(source_buffer.raw_ptr()); + if (source_length != iree_hal_buffer_byte_length(target_buffer.raw_ptr())) { + throw std::invalid_argument( + "Source and target buffer length must match and it does not. Please " + "check allocations"); + } + CheckApiStatus(iree_hal_device_queue_copy( + raw_ptr(), IREE_HAL_QUEUE_AFFINITY_ANY, wait_list, + signal_list, source_buffer.raw_ptr(), 0, + target_buffer.raw_ptr(), 0, source_length), + "Copying buffer on queue"); +} + //------------------------------------------------------------------------------ // HalDriver //------------------------------------------------------------------------------ @@ -861,6 +936,9 @@ void SetupHalBindings(nanobind::module_ m) { .def("queue_execute", &HalDevice::QueueExecute, py::arg("command_buffers"), py::arg("wait_semaphores"), py::arg("signal_semaphores"), kHalDeviceQueueExecute) + .def("queue_copy", &HalDevice::QueueCopy, py::arg("source_buffer"), + py::arg("target_buffer"), py::arg("wait_semaphores"), + py::arg("signal_semaphores"), kHalDeviceQueueCopy) .def("__repr__", [](HalDevice& self) { auto id_sv = iree_hal_device_id(self.raw_ptr()); return std::string(id_sv.data, id_sv.size); @@ -963,6 +1041,9 @@ void SetupHalBindings(nanobind::module_ m) { py::class_(m, "HalBuffer") .def("fill_zero", &HalBuffer::FillZero, py::arg("byte_offset"), py::arg("byte_length")) + .def("byte_length", &HalBuffer::byte_length) + .def("memory_type", &HalBuffer::memory_type) + .def("allowed_usage", &HalBuffer::allowed_usage) .def("create_view", &HalBuffer::CreateView, py::arg("shape"), py::arg("element_size"), py::keep_alive<0, 1>()) .def("map", HalMappedMemory::CreateFromBuffer, py::keep_alive<0, 1>()) @@ -994,6 +1075,8 @@ void SetupHalBindings(nanobind::module_ m) { py::arg("buffer"), py::arg("shape"), py::arg("element_type")); hal_buffer_view .def("map", HalMappedMemory::CreateFromBufferView, py::keep_alive<0, 1>()) + .def("get_buffer", HalBuffer::CreateFromBufferView, + py::keep_alive<0, 1>()) .def_prop_ro("shape", [](HalBufferView& self) { iree_host_size_t rank = diff --git a/runtime/bindings/python/hal.h b/runtime/bindings/python/hal.h index 0c3bc63ea64e..5e18cfa71151 100644 --- a/runtime/bindings/python/hal.h +++ b/runtime/bindings/python/hal.h @@ -128,6 +128,8 @@ class HalDevice : public ApiRefCounted { py::handle signal_semaphores); void QueueExecute(py::handle command_buffers, py::handle wait_semaphores, py::handle signal_semaphores); + void QueueCopy(HalBuffer& src_buffer, HalBuffer& dst_buffer, + py::handle wait_semaphores, py::handle signal_semaphores); }; class HalDriver : public ApiRefCounted { @@ -176,6 +178,10 @@ class HalBuffer : public ApiRefCounted { return iree_hal_buffer_byte_length(raw_ptr()); } + int memory_type() const { return iree_hal_buffer_memory_type(raw_ptr()); } + + int allowed_usage() const { return iree_hal_buffer_allowed_usage(raw_ptr()); } + void FillZero(iree_device_size_t byte_offset, iree_device_size_t byte_length) { CheckApiStatus( @@ -197,6 +203,11 @@ class HalBuffer : public ApiRefCounted { return HalBufferView::StealFromRawPtr(bv); } + static HalBuffer CreateFromBufferView(HalBufferView& bv) { + return HalBuffer::BorrowFromRawPtr( + iree_hal_buffer_view_buffer(bv.raw_ptr())); + } + py::str Repr(); }; diff --git a/runtime/bindings/python/iree/runtime/array_interop.py b/runtime/bindings/python/iree/runtime/array_interop.py index 096fc9b04dda..fb67b21c7080 100644 --- a/runtime/bindings/python/iree/runtime/array_interop.py +++ b/runtime/bindings/python/iree/runtime/array_interop.py @@ -17,6 +17,7 @@ HalElementType, MappedMemory, MemoryType, + HalFence, ) __all__ = [ @@ -106,6 +107,20 @@ def to_host(self) -> np.ndarray: self._transfer_to_host(False) return self._host_array + def _is_mappable(self) -> bool: + buffer = self._buffer_view.get_buffer() + if ( + buffer.memory_type() & int(MemoryType.HOST_VISIBLE) + != MemoryType.HOST_VISIBLE + ): + return False + if ( + buffer.allowed_usage() & int(BufferUsage.MAPPING_SCOPED) + != BufferUsage.MAPPING_SCOPED + ): + return False + return True + def _transfer_to_host(self, implicit): if self._host_array is not None: return @@ -114,7 +129,10 @@ def _transfer_to_host(self, implicit): "DeviceArray cannot be implicitly transferred to the host: " "if necessary, do an explicit transfer via .to_host()" ) - self._mapped_memory, self._host_array = self._map_to_host() + if self._is_mappable(): + self._mapped_memory, self._host_array = self._map_to_host() + else: + self._host_array = self._copy_to_host() def _map_to_host(self) -> Tuple[MappedMemory, np.ndarray]: # TODO: When synchronization is enabled, need to block here. @@ -129,6 +147,35 @@ def _map_to_host(self) -> Tuple[MappedMemory, np.ndarray]: host_array = host_array.astype(self._override_dtype) return mapped_memory, host_array + def _copy_to_host(self) -> np.ndarray: + # TODO: When synchronization is enabled, need to block here. + source_buffer = self._buffer_view.get_buffer() + host_buffer = self._device.allocator.allocate_buffer( + memory_type=(MemoryType.HOST_LOCAL | MemoryType.DEVICE_VISIBLE), + allowed_usage=(BufferUsage.TRANSFER_TARGET | BufferUsage.MAPPING_SCOPED), + allocation_size=source_buffer.byte_length(), + ) + # Copy and wait for buffer to be copied from source buffer. + sem = self._device.create_semaphore(0) + self._device.queue_copy( + source_buffer, + host_buffer, + wait_semaphores=HalFence.create_at(sem, 0), + signal_semaphores=HalFence.create_at(sem, 1), + ) + HalFence.create_at(sem, 1).wait() + # Map and reformat buffer as np.array. + raw_dtype = self._get_raw_dtype() + mapped_memory = host_buffer.map() + host_array = mapped_memory.asarray(self._buffer_view.shape, raw_dtype) + # Detect if we need to force an explicit conversion. This happens when + # we were requested to pretend that the array is in a specific dtype, + # even if that is not representable on the device. You guessed it: + # this is to support bools. + if self._override_dtype is not None and self._override_dtype != raw_dtype: + host_array = host_array.astype(self._override_dtype) + return host_array + def _get_raw_dtype(self): return HalElementType.map_to_dtype(self._buffer_view.element_type) diff --git a/runtime/bindings/python/tests/hal_test.py b/runtime/bindings/python/tests/hal_test.py index 00076cf02a99..c7407bc58a94 100644 --- a/runtime/bindings/python/tests/hal_test.py +++ b/runtime/bindings/python/tests/hal_test.py @@ -265,6 +265,52 @@ def testFenceExtend(self): fence.extend(iree.runtime.HalFence.create_at(sem2, 2)) self.assertEqual(fence.timepoint_count, 2) + def testRoundTripQueueCopy(self): + original_ary = np.zeros([3, 4], dtype=np.int32) + 2 + source_bv = self.allocator.allocate_buffer_copy( + memory_type=iree.runtime.MemoryType.DEVICE_LOCAL, + allowed_usage=iree.runtime.BufferUsage.DEFAULT, + device=self.device, + buffer=original_ary, + element_type=iree.runtime.HalElementType.SINT_32, + ) + source_buffer = source_bv.get_buffer() + target_buffer = self.allocator.allocate_buffer( + memory_type=iree.runtime.MemoryType.DEVICE_LOCAL, + allowed_usage=iree.runtime.BufferUsage.DEFAULT, + allocation_size=source_buffer.byte_length(), + ) + sem = self.device.create_semaphore(0) + self.device.queue_copy( + source_buffer, + target_buffer, + wait_semaphores=iree.runtime.HalFence.create_at(sem, 0), + signal_semaphores=iree.runtime.HalFence.create_at(sem, 1), + ) + iree.runtime.HalFence.create_at(sem, 1).wait() + copy_ary = target_buffer.map().asarray(original_ary.shape, original_ary.dtype) + np.testing.assert_array_equal(original_ary, copy_ary) + + def testDifferentSizeQueueCopy(self): + source_buffer = self.allocator.allocate_buffer( + memory_type=iree.runtime.MemoryType.DEVICE_LOCAL, + allowed_usage=iree.runtime.BufferUsage.DEFAULT, + allocation_size=12, + ) + target_buffer = self.allocator.allocate_buffer( + memory_type=iree.runtime.MemoryType.DEVICE_LOCAL, + allowed_usage=iree.runtime.BufferUsage.DEFAULT, + allocation_size=13, + ) + sem = self.device.create_semaphore(0) + with self.assertRaisesRegex(ValueError, "length must match"): + self.device.queue_copy( + source_buffer, + target_buffer, + wait_semaphores=iree.runtime.HalFence.create_at(sem, 0), + signal_semaphores=iree.runtime.HalFence.create_at(sem, 1), + ) + def testCommandBufferStartsByDefault(self): cb = iree.runtime.HalCommandBuffer(self.device) with self.assertRaisesRegex(RuntimeError, "FAILED_PRECONDITION"): diff --git a/runtime/src/iree/hal/drivers/cuda/cuda_allocator.c b/runtime/src/iree/hal/drivers/cuda/cuda_allocator.c index d010be080e65..fcbb2e464f15 100644 --- a/runtime/src/iree/hal/drivers/cuda/cuda_allocator.c +++ b/runtime/src/iree/hal/drivers/cuda/cuda_allocator.c @@ -236,12 +236,12 @@ iree_hal_cuda_allocator_query_buffer_compatibility( if (iree_all_bits_set(params->type, IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE)) { compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_IMPORTABLE; } + if (iree_any_bit_set(params->usage, IREE_HAL_BUFFER_USAGE_TRANSFER)) { + compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_TRANSFER; + } // Buffers can only be used on the queue if they are device visible. if (iree_all_bits_set(params->type, IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE)) { - if (iree_any_bit_set(params->usage, IREE_HAL_BUFFER_USAGE_TRANSFER)) { - compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_TRANSFER; - } if (iree_any_bit_set(params->usage, IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE)) { compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_DISPATCH;