From fb7f07b5dd465ab6be0dbc6d2f9221afaa4bf48f Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Mon, 14 Oct 2024 10:41:58 +0800 Subject: [PATCH 1/5] fix device setting bug and upload missing paddle benchmark files --- .../benchmarks/benchmark_cloth_paddle.py | 86 ++++++++++ .../benchmarks/benchmark_interop_paddle.py | 158 ++++++++++++++++++ warp/paddle.py | 38 +++-- warp/tests/test_paddle.py | 115 ++++++++----- 4 files changed, 344 insertions(+), 53 deletions(-) create mode 100644 warp/examples/benchmarks/benchmark_cloth_paddle.py create mode 100644 warp/examples/benchmarks/benchmark_interop_paddle.py diff --git a/warp/examples/benchmarks/benchmark_cloth_paddle.py b/warp/examples/benchmarks/benchmark_cloth_paddle.py new file mode 100644 index 00000000..52987739 --- /dev/null +++ b/warp/examples/benchmarks/benchmark_cloth_paddle.py @@ -0,0 +1,86 @@ +# Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved. +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import paddle + + +def eval_springs(x, v, indices, rest, ke, kd, f): + i = indices[:, 0] + j = indices[:, 1] + + xi = x[i] + xj = x[j] + + vi = v[i] + vj = v[j] + + xij = xi - xj + vij = vi - vj + + l = paddle.linalg.norm(xij, axis=1) + l_inv = 1.0 / l + + # normalized spring direction + dir = (xij.T * l_inv).T + + c = l - rest + dcdt = paddle.sum(dir * vij, axis=1) + + # damping based on relative velocity. + fs = dir.T * (ke * c + kd * dcdt) + + f.index_add_(axis=0, index=i, value=-fs.T) + f.index_add_(axis=0, index=j, value=fs.T) + + +def integrate_particles(x, v, f, g, w, dt): + s = w > 0.0 + + a_ext = g * s[:, None].astype(g.dtype) + + # simple semi-implicit Euler. 
v1 = v0 + a dt, x1 = x0 + v1 dt + v += ((f.T * w).T + a_ext) * dt + x += v * dt + + # clear forces + f *= 0.0 + + +class TrIntegrator: + def __init__(self, cloth, device): + self.cloth = cloth + + self.positions = paddle.to_tensor(self.cloth.positions, place=device) + self.velocities = paddle.to_tensor(self.cloth.velocities, place=device) + self.inv_mass = paddle.to_tensor(self.cloth.inv_masses, place=device) + + self.spring_indices = paddle.to_tensor(self.cloth.spring_indices, dtype=paddle.int64, place=device) + self.spring_lengths = paddle.to_tensor(self.cloth.spring_lengths, place=device) + self.spring_stiffness = paddle.to_tensor(self.cloth.spring_stiffness, place=device) + self.spring_damping = paddle.to_tensor(self.cloth.spring_damping, place=device) + + self.forces = paddle.zeros((self.cloth.num_particles, 3), dtype=paddle.float32).to(device=device) + self.gravity = paddle.to_tensor((0.0, 0.0 - 9.8, 0.0), dtype=paddle.float32, place=device) + + def simulate(self, dt, substeps): + sim_dt = dt / substeps + + for _s in range(substeps): + eval_springs( + self.positions, + self.velocities, + self.spring_indices.reshape((self.cloth.num_springs, 2)), + self.spring_lengths, + self.spring_stiffness, + self.spring_damping, + self.forces, + ) + + # integrate + integrate_particles(self.positions, self.velocities, self.forces, self.gravity, self.inv_mass, sim_dt) + + return self.positions.cpu().numpy() diff --git a/warp/examples/benchmarks/benchmark_interop_paddle.py b/warp/examples/benchmarks/benchmark_interop_paddle.py new file mode 100644 index 00000000..5377e67b --- /dev/null +++ b/warp/examples/benchmarks/benchmark_interop_paddle.py @@ -0,0 +1,158 @@ +# Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved. +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. 
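+
+# This benchmark compares the per-launch overhead of three ways of handing
+# Paddle tensors to Warp kernels:
+#   1. wrapping them with wp.from_paddle()
+#   2. wrapping them with wp.from_paddle(..., return_ctype=True), which skips
+#      building full wp.array objects
+#   3. passing the Paddle tensors directly to wp.launch()
+# The kernel bodies are empty, so the timings printed below measure only the
+# conversion and launch overhead.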
+ +import time + +import paddle + +import warp as wp + + +def create_simple_kernel(dtype): + def simple_kernel( + a: wp.array(dtype=dtype), + b: wp.array(dtype=dtype), + c: wp.array(dtype=dtype), + d: wp.array(dtype=dtype), + e: wp.array(dtype=dtype), + ): + pass + + return wp.Kernel(simple_kernel) + + +def test_from_paddle(kernel, num_iters, array_size, device, warp_dtype=None): + warp_device = wp.get_device(device) + paddle_device = wp.device_to_paddle(warp_device) + + if hasattr(warp_dtype, "_shape_"): + paddle_shape = (array_size, *warp_dtype._shape_) + paddle_dtype = wp.dtype_to_paddle(warp_dtype._wp_scalar_type_) + else: + paddle_shape = (array_size,) + paddle_dtype = paddle.float32 if warp_dtype is None else wp.dtype_to_paddle(warp_dtype) + + _a = paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device) + _b = paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device) + _c = paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device) + _d = paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device) + _e = paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device) + + wp.synchronize() + + # profiler = Profiler(interval=0.000001) + # profiler.start() + + t1 = time.time_ns() + + for _ in range(num_iters): + a = wp.from_paddle(_a, dtype=warp_dtype) + b = wp.from_paddle(_b, dtype=warp_dtype) + c = wp.from_paddle(_c, dtype=warp_dtype) + d = wp.from_paddle(_d, dtype=warp_dtype) + e = wp.from_paddle(_e, dtype=warp_dtype) + wp.launch(kernel, dim=array_size, inputs=[a, b, c, d, e]) + + t2 = time.time_ns() + print(f"{(t2 - t1) / 1_000_000 :8.0f} ms from_paddle(...)") + + # profiler.stop() + # profiler.print() + + +def test_array_ctype_from_paddle(kernel, num_iters, array_size, device, warp_dtype=None): + warp_device = wp.get_device(device) + paddle_device = wp.device_to_paddle(warp_device) + + if hasattr(warp_dtype, "_shape_"): + paddle_shape = (array_size, *warp_dtype._shape_) + paddle_dtype = wp.dtype_to_paddle(warp_dtype._wp_scalar_type_) + else: + paddle_shape = (array_size,) + paddle_dtype = paddle.float32 if warp_dtype is None else wp.dtype_to_paddle(warp_dtype) + + _a = paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device) + _b = paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device) + _c = paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device) + _d = paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device) + _e = paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device) + + wp.synchronize() + + # profiler = Profiler(interval=0.000001) + # profiler.start() + + t1 = time.time_ns() + + for _ in range(num_iters): + a = wp.from_paddle(_a, dtype=warp_dtype, return_ctype=True) + b = wp.from_paddle(_b, dtype=warp_dtype, return_ctype=True) + c = wp.from_paddle(_c, dtype=warp_dtype, return_ctype=True) + d = wp.from_paddle(_d, dtype=warp_dtype, return_ctype=True) + e = wp.from_paddle(_e, dtype=warp_dtype, return_ctype=True) + wp.launch(kernel, dim=array_size, inputs=[a, b, c, d, e]) + + t2 = time.time_ns() + print(f"{(t2 - t1) / 1_000_000 :8.0f} ms from_paddle(..., return_ctype=True)") + + # profiler.stop() + # profiler.print() + + +def test_direct_from_paddle(kernel, num_iters, array_size, device, warp_dtype=None): + warp_device = wp.get_device(device) + paddle_device = wp.device_to_paddle(warp_device) + + if hasattr(warp_dtype, "_shape_"): + paddle_shape = (array_size, *warp_dtype._shape_) + paddle_dtype = 
wp.dtype_to_paddle(warp_dtype._wp_scalar_type_) + else: + paddle_shape = (array_size,) + paddle_dtype = paddle.float32 if warp_dtype is None else wp.dtype_to_paddle(warp_dtype) + + _a = paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device) + _b = paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device) + _c = paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device) + _d = paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device) + _e = paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device) + + wp.synchronize() + + # profiler = Profiler(interval=0.000001) + # profiler.start() + + t1 = time.time_ns() + + for _ in range(num_iters): + wp.launch(kernel, dim=array_size, inputs=[_a, _b, _c, _d, _e]) + + t2 = time.time_ns() + print(f"{(t2 - t1) / 1_000_000 :8.0f} ms direct from paddle") + + # profiler.stop() + # profiler.print() + + +wp.init() + +params = [ + # (warp_dtype arg, kernel) + (None, create_simple_kernel(wp.float32)), + (wp.float32, create_simple_kernel(wp.float32)), + (wp.vec3f, create_simple_kernel(wp.vec3f)), + (wp.mat22f, create_simple_kernel(wp.mat22f)), +] + +wp.load_module() + +num_iters = 100000 + +for warp_dtype, kernel in params: + print(f"\ndtype={wp.context.type_str(warp_dtype)}") + test_from_paddle(kernel, num_iters, 10, "cuda:0", warp_dtype=warp_dtype) + test_array_ctype_from_paddle(kernel, num_iters, 10, "cuda:0", warp_dtype=warp_dtype) + test_direct_from_paddle(kernel, num_iters, 10, "cuda:0", warp_dtype=warp_dtype) diff --git a/warp/paddle.py b/warp/paddle.py index 65dcf17f..198bbab4 100644 --- a/warp/paddle.py +++ b/warp/paddle.py @@ -11,6 +11,7 @@ from typing import TYPE_CHECKING, Optional, Union import numpy +from paddle.base.libpaddle import CPUPlace, CUDAPinnedPlace, CUDAPlace, Place import warp import warp.context @@ -20,38 +21,51 @@ # return the warp device corresponding to a paddle device -def device_from_paddle(paddle_device: Union[paddle.base.libpaddle.Place, str]) -> warp.context.Device: +def device_from_paddle(paddle_device: Union[Place, CPUPlace, CUDAPinnedPlace, CUDAPlace, str]) -> warp.context.Device: """Return the Warp device corresponding to a Paddle device. 
Args: - paddle_device (`paddle.base.libpaddle.Place` or `str`): Paddle device identifier + paddle_device (`Place`, `CPUPlace`, `CUDAPinnedPlace`, `CUDAPlace`, or `str`): Paddle device identifier Raises: RuntimeError: Paddle device does not have a corresponding Warp device """ + print(paddle_device) if type(paddle_device) is str: + if paddle_device.startswith("gpu:"): + paddle_device = paddle_device.replace("gpu:", "cuda:") warp_device = warp.context.runtime.device_map.get(paddle_device) if warp_device is not None: return warp_device - elif paddle_device.startswith("gpu"): + elif paddle_device == "gpu": return warp.context.runtime.get_current_cuda_device() else: raise RuntimeError(f"Unsupported Paddle device {paddle_device}") else: - import paddle - try: - if paddle_device.is_gpu_place(): - return warp.context.runtime.cuda_devices[paddle_device.gpu_device_id()] - elif paddle_device.is_cpu_place(): + if isinstance(paddle_device, Place): + if paddle_device.is_gpu_place(): + return warp.context.runtime.cuda_devices[paddle_device.gpu_device_id()] + elif paddle_device.is_cpu_place(): + return warp.context.runtime.cpu_device + else: + raise RuntimeError(f"Unsupported Paddle device type {paddle_device}") + elif isinstance(paddle_device, (CPUPlace, CUDAPinnedPlace)): return warp.context.runtime.cpu_device + elif isinstance(paddle_device, CUDAPlace): + return warp.context.runtime.cuda_devices[paddle_device.get_device_id()] else: raise RuntimeError(f"Unsupported Paddle device type {paddle_device}") except Exception as e: - import paddle - - if not isinstance(paddle_device, paddle.base.libpaddle.Place): - raise ValueError("Argument must be a paddle.base.libpaddle.Place object or a string") from e + if not isinstance(paddle_device, (Place, CPUPlace, CUDAPinnedPlace, CUDAPlace)): + raise TypeError( + "device_from_paddle() received an invalid argument - " + f"got {paddle_device}({type(paddle_device)}), but expected one of:\n" + "* paddle.base.libpaddle.Place\n" + "* paddle.CPUPlace\n" + "* paddle.CUDAPinnedPlace\n" + "* paddle.CUDAPlace or 'gpu' or 'gpu:x'(x means device id)" + ) from e raise diff --git a/warp/tests/test_paddle.py b/warp/tests/test_paddle.py index 53db028e..a187e4c7 100644 --- a/warp/tests/test_paddle.py +++ b/warp/tests/test_paddle.py @@ -8,10 +8,35 @@ import unittest import numpy as np +import paddle import warp as wp from warp.tests.unittest_utils import * +device_stack = [] +# push global device into device stack +device_stack.append(paddle.device.get_device()) + + +class PaddleDevice: + def __init__(self, device: str): + if device == "cpu": + self.device_new = device + elif device.startswith("gpu:"): + self.device_new = device + elif device == "gpu": + self.device_new = "gpu" + else: + raise NotImplementedError(f"Unsupported device type {device}") + + def __enter__(self) -> bool: + device_stack.append(self.device_new) + paddle.device.set_device(self.device_new) + + def __exit__(self, exc_type, exc_value, traceback): + device_stack.pop() + paddle.device.set_device(device_stack[-1]) + @wp.kernel def op_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)): @@ -444,7 +469,7 @@ def test_from_paddle_slices(test, device): assert a.ptr == t.data_ptr() assert a.is_contiguous assert a.shape == tuple(t.shape) - assert_np_equal(a.numpy(), t.cpu().numpy()) + assert_np_equal(a.numpy(), t.numpy()) # 1D slice with non-contiguous stride t_base = paddle.arange(10, dtype=paddle.float32).to(device=paddle_device) @@ -456,7 +481,7 @@ def test_from_paddle_slices(test, device): # copy contents to 
contiguous array a_contiguous = wp.empty_like(a) wp.launch(copy1d_float_kernel, dim=a.shape, inputs=[a_contiguous, a], device=device) - assert_np_equal(a_contiguous.numpy(), t.cpu().numpy()) + assert_np_equal(a_contiguous.numpy(), t.numpy()) # 2D slices (non-contiguous) t_base = paddle.arange(24, dtype=paddle.float32).to(device=paddle_device).reshape((4, 6)) @@ -468,7 +493,7 @@ def test_from_paddle_slices(test, device): # copy contents to contiguous array a_contiguous = wp.empty_like(a) wp.launch(copy2d_float_kernel, dim=a.shape, inputs=[a_contiguous, a], device=device) - assert_np_equal(a_contiguous.numpy(), t.cpu().numpy()) + assert_np_equal(a_contiguous.numpy(), t.numpy()) # 3D slices (non-contiguous) t_base = paddle.arange(36, dtype=paddle.float32).to(device=paddle_device).reshape((4, 3, 3)) @@ -480,7 +505,7 @@ def test_from_paddle_slices(test, device): # copy contents to contiguous array a_contiguous = wp.empty_like(a) wp.launch(copy3d_float_kernel, dim=a.shape, inputs=[a_contiguous, a], device=device) - assert_np_equal(a_contiguous.numpy(), t.cpu().numpy()) + assert_np_equal(a_contiguous.numpy(), t.numpy()) # 2D slices of vec3 (inner contiguous, outer non-contiguous) t_base = paddle.arange(150, dtype=paddle.float32).to(device=paddle_device).reshape((10, 5, 3)) @@ -492,7 +517,7 @@ def test_from_paddle_slices(test, device): # copy contents to contiguous array a_contiguous = wp.empty_like(a) wp.launch(copy2d_vec3_kernel, dim=a.shape, inputs=[a_contiguous, a], device=device) - assert_np_equal(a_contiguous.numpy(), t.cpu().numpy()) + assert_np_equal(a_contiguous.numpy(), t.numpy()) # 2D slices of mat22 (inner contiguous, outer non-contiguous) t_base = paddle.arange(200, dtype=paddle.float32).to(device=paddle_device).reshape((10, 5, 2, 2)) @@ -504,7 +529,7 @@ def test_from_paddle_slices(test, device): # copy contents to contiguous array a_contiguous = wp.empty_like(a) wp.launch(copy2d_mat22_kernel, dim=a.shape, inputs=[a_contiguous, a], device=device) - assert_np_equal(a_contiguous.numpy(), t.cpu().numpy()) + assert_np_equal(a_contiguous.numpy(), t.numpy()) def test_from_paddle_zero_strides(test, device): @@ -522,7 +547,7 @@ def test_from_paddle_zero_strides(test, device): assert a.shape == tuple(t.shape) a_contiguous = wp.empty_like(a) wp.launch(copy3d_float_kernel, dim=a.shape, inputs=[a_contiguous, a], device=device) - assert_np_equal(a_contiguous.numpy(), t.cpu().numpy()) + assert_np_equal(a_contiguous.numpy(), t.numpy()) # expand middle dimension t = t_base.unsqueeze(1).expand([-1, 3, -1]) @@ -532,7 +557,7 @@ def test_from_paddle_zero_strides(test, device): assert a.shape == tuple(t.shape) a_contiguous = wp.empty_like(a) wp.launch(copy3d_float_kernel, dim=a.shape, inputs=[a_contiguous, a], device=device) - assert_np_equal(a_contiguous.numpy(), t.cpu().numpy()) + assert_np_equal(a_contiguous.numpy(), t.numpy()) # expand innermost dimension t = t_base.unsqueeze(2).expand([-1, -1, 3]) @@ -542,7 +567,7 @@ def test_from_paddle_zero_strides(test, device): assert a.shape == tuple(t.shape) a_contiguous = wp.empty_like(a) wp.launch(copy3d_float_kernel, dim=a.shape, inputs=[a_contiguous, a], device=device) - assert_np_equal(a_contiguous.numpy(), t.cpu().numpy()) + assert_np_equal(a_contiguous.numpy(), t.numpy()) def test_paddle_mgpu_from_paddle(test, device): @@ -550,44 +575,40 @@ def test_paddle_mgpu_from_paddle(test, device): n = 32 - t0 = paddle.arange(0, n, 1, dtype=paddle.int32).to(device="gpu:0") - t1 = paddle.arange(0, n * 2, 2, dtype=paddle.int32).to(device="gpu:1") - - a0 = 
wp.from_paddle(t0, dtype=wp.int32) - a1 = wp.from_paddle(t1, dtype=wp.int32) - - assert a0.device == "gpu:0" - assert a1.device == "gpu:1" - expected0 = np.arange(0, n, 1) expected1 = np.arange(0, n * 2, 2) + t0 = paddle.arange(0, n, 1, dtype=paddle.int32).to(device="gpu:0") + a0 = wp.from_paddle(t0, dtype=wp.int32) + assert a0.device == "cuda:0" assert_np_equal(a0.numpy(), expected0) + + t1 = paddle.arange(0, n * 2, 2, dtype=paddle.int32).to(device="gpu:1") + a1 = wp.from_paddle(t1, dtype=wp.int32) + assert a1.device == "cuda:1" assert_np_equal(a1.numpy(), expected1) def test_paddle_mgpu_to_paddle(test, device): n = 32 - with wp.ScopedDevice("gpu:0"): + with wp.ScopedDevice("cuda:0"): a0 = wp.empty(n, dtype=wp.int32) wp.launch(arange, dim=a0.size, inputs=[0, 1, a0]) - with wp.ScopedDevice("gpu:1"): + t0 = wp.to_paddle(a0) + assert str(t0.place) == "Place(gpu:0)" + expected0 = np.arange(0, n, 1, dtype=np.int32) + assert_np_equal(t0.numpy(), expected0) + + with wp.ScopedDevice("cuda:1"): a1 = wp.empty(n, dtype=wp.int32) wp.launch(arange, dim=a1.size, inputs=[0, 2, a1]) - t0 = wp.to_paddle(a0) - t1 = wp.to_paddle(a1) - - assert str(t0.device) == "gpu:0" - assert str(t1.device) == "gpu:1" - - expected0 = np.arange(0, n, 1, dtype=np.int32) - expected1 = np.arange(0, n * 2, 2, dtype=np.int32) - - assert_np_equal(t0.cpu().numpy(), expected0) - assert_np_equal(t1.cpu().numpy(), expected1) + t1 = wp.to_paddle(a1) + assert str(t1.place) == "Place(gpu:1)" + expected1 = np.arange(0, n * 2, 2, dtype=np.int32) + assert_np_equal(t1.numpy(), expected1) def test_paddle_mgpu_interop(test, device): @@ -595,24 +616,24 @@ def test_paddle_mgpu_interop(test, device): n = 1024 * 1024 - with paddle.cuda.device(0): - t0 = paddle.arange(n, dtype=paddle.float32).to(device="gpu") + with PaddleDevice("gpu:0"): + t0 = paddle.arange(n, dtype=paddle.float32).to(device="gpu:0") a0 = wp.from_paddle(t0) wp.launch(inc, dim=a0.size, inputs=[a0], stream=wp.stream_from_paddle()) - with paddle.cuda.device(1): - t1 = paddle.arange(n, dtype=paddle.float32).to(device="gpu") + with PaddleDevice("gpu:1"): + t1 = paddle.arange(n, dtype=paddle.float32).to(device="gpu:1") a1 = wp.from_paddle(t1) wp.launch(inc, dim=a1.size, inputs=[a1], stream=wp.stream_from_paddle()) - assert a0.device == "gpu:0" - assert a1.device == "gpu:1" + # ensure the paddle tensors were modified by warp + assert a0.device == "cuda:0" + assert a1.device == "cuda:1" expected = np.arange(n, dtype=int) + 1 - # ensure the paddle tensors were modified by warp - assert_np_equal(t0.cpu().numpy(), expected) - assert_np_equal(t1.cpu().numpy(), expected) + assert_np_equal(t0.numpy(), expected) + assert_np_equal(t1.numpy(), expected) def test_paddle_autograd(test, device): @@ -624,6 +645,9 @@ def test_paddle_autograd(test, device): class TestFunc(paddle.autograd.PyLayer): @staticmethod def forward(ctx, x): + # ensure Paddle operations complete before running Warp + wp.synchronize_device() + # allocate output array y = paddle.empty_like(x) @@ -632,10 +656,16 @@ def forward(ctx, x): wp.launch(kernel=op_kernel, dim=len(x), inputs=[wp.from_paddle(x)], outputs=[wp.from_paddle(y)]) + # ensure Warp operations complete before returning data to Paddle + wp.synchronize_device() + return y @staticmethod def backward(ctx, adj_y): + # ensure Paddle operations complete before running Warp + wp.synchronize_device() + # adjoints should be allocated as zero initialized adj_x = paddle.zeros_like(ctx.x).contiguous() adj_y = adj_y.contiguous() @@ -655,6 +685,9 @@ def backward(ctx, adj_y): 
adjoint=True, ) + # ensure Warp operations complete before returning data to Paddle + wp.synchronize_device() + return adj_x # run autograd on given device @@ -691,7 +724,7 @@ def test_warp_graph_warp_stream(test, device): paddle_stream = wp.stream_to_paddle(device) # capture graph - with wp.ScopedDevice(device), paddle.device.stream(paddle_stream): + with wp.ScopedDevice(device), paddle.device.stream_guard(paddle.device.Stream(paddle_stream)): wp.capture_begin(force_module_load=False) try: t += 1.0 From ade6b80793aab057fe4e92006ede993577d84f57 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Tue, 29 Oct 2024 19:48:42 +0800 Subject: [PATCH 2/5] remove print --- warp/paddle.py | 1 - 1 file changed, 1 deletion(-) diff --git a/warp/paddle.py b/warp/paddle.py index 198bbab4..c02f76e8 100644 --- a/warp/paddle.py +++ b/warp/paddle.py @@ -30,7 +30,6 @@ def device_from_paddle(paddle_device: Union[Place, CPUPlace, CUDAPinnedPlace, CU Raises: RuntimeError: Paddle device does not have a corresponding Warp device """ - print(paddle_device) if type(paddle_device) is str: if paddle_device.startswith("gpu:"): paddle_device = paddle_device.replace("gpu:", "cuda:") From 084d0e10cef1bd613da26af21d0ef2f5f73b6be0 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Tue, 29 Oct 2024 20:11:03 +0800 Subject: [PATCH 3/5] update doc and test_paddle.py --- CHANGELOG.md | 1 + warp/tests/test_paddle.py | 12 ++++++------ 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 24987e0d..c88e7414 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ ### Fixed +- Fix place setting of paddle backend and add 2 missing benchmark files. - Fix to relax the integer types expected when indexing arrays (regression in 1.3.0). - Fix printing vector and matrix adjoints in backward kernels. - Fix kernel compile error when printing structs. diff --git a/warp/tests/test_paddle.py b/warp/tests/test_paddle.py index a187e4c7..58adea8e 100644 --- a/warp/tests/test_paddle.py +++ b/warp/tests/test_paddle.py @@ -25,7 +25,7 @@ def __init__(self, device: str): elif device.startswith("gpu:"): self.device_new = device elif device == "gpu": - self.device_new = "gpu" + self.device_new = device else: raise NotImplementedError(f"Unsupported device type {device}") @@ -870,11 +870,11 @@ class TestPaddle(unittest.TestCase): # devices=paddle_compatible_cuda_devices, # ) - # multi-GPU tests - if len(paddle_compatible_cuda_devices) > 1: - add_function_test(TestPaddle, "test_paddle_mgpu_from_paddle", test_paddle_mgpu_from_paddle) - add_function_test(TestPaddle, "test_paddle_mgpu_to_paddle", test_paddle_mgpu_to_paddle) - add_function_test(TestPaddle, "test_paddle_mgpu_interop", test_paddle_mgpu_interop) + # multi-GPU not supported yet. 
+ # if len(paddle_compatible_cuda_devices) > 1: + # add_function_test(TestPaddle, "test_paddle_mgpu_from_paddle", test_paddle_mgpu_from_paddle) + # add_function_test(TestPaddle, "test_paddle_mgpu_to_paddle", test_paddle_mgpu_to_paddle) + # add_function_test(TestPaddle, "test_paddle_mgpu_interop", test_paddle_mgpu_interop) except Exception as e: print(f"Skipping Paddle tests due to exception: {e}") From 822adca3223a197ddd7f96ca8a36c638cde71d6b Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Wed, 30 Oct 2024 10:47:58 +0800 Subject: [PATCH 4/5] use inline import instead of importing at top --- warp/paddle.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/warp/paddle.py b/warp/paddle.py index c02f76e8..139d0494 100644 --- a/warp/paddle.py +++ b/warp/paddle.py @@ -11,13 +11,13 @@ from typing import TYPE_CHECKING, Optional, Union import numpy -from paddle.base.libpaddle import CPUPlace, CUDAPinnedPlace, CUDAPlace, Place import warp import warp.context if TYPE_CHECKING: import paddle + from paddle.base.libpaddle import CPUPlace, CUDAPinnedPlace, CUDAPlace, Place # return the warp device corresponding to a paddle device @@ -42,6 +42,8 @@ def device_from_paddle(paddle_device: Union[Place, CPUPlace, CUDAPinnedPlace, CU raise RuntimeError(f"Unsupported Paddle device {paddle_device}") else: try: + from paddle.base.libpaddle import CPUPlace, CUDAPinnedPlace, CUDAPlace, Place + if isinstance(paddle_device, Place): if paddle_device.is_gpu_place(): return warp.context.runtime.cuda_devices[paddle_device.gpu_device_id()] @@ -55,6 +57,8 @@ def device_from_paddle(paddle_device: Union[Place, CPUPlace, CUDAPinnedPlace, CU return warp.context.runtime.cuda_devices[paddle_device.get_device_id()] else: raise RuntimeError(f"Unsupported Paddle device type {paddle_device}") + except ModuleNotFoundError as e: + raise ModuleNotFoundError("Please install paddlepaddle first.") from e except Exception as e: if not isinstance(paddle_device, (Place, CPUPlace, CUDAPinnedPlace, CUDAPlace)): raise TypeError( From 49acc63e3e5508dbd120087498ae043badcf72b6 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Thu, 31 Oct 2024 10:51:47 +0800 Subject: [PATCH 5/5] remove mgpu test and top-import of paddle --- warp/tests/test_paddle.py | 93 --------------------------------------- 1 file changed, 93 deletions(-) diff --git a/warp/tests/test_paddle.py b/warp/tests/test_paddle.py index 58adea8e..42eeca24 100644 --- a/warp/tests/test_paddle.py +++ b/warp/tests/test_paddle.py @@ -7,36 +7,9 @@ import unittest -import numpy as np -import paddle - import warp as wp from warp.tests.unittest_utils import * -device_stack = [] -# push global device into device stack -device_stack.append(paddle.device.get_device()) - - -class PaddleDevice: - def __init__(self, device: str): - if device == "cpu": - self.device_new = device - elif device.startswith("gpu:"): - self.device_new = device - elif device == "gpu": - self.device_new = device - else: - raise NotImplementedError(f"Unsupported device type {device}") - - def __enter__(self) -> bool: - device_stack.append(self.device_new) - paddle.device.set_device(self.device_new) - - def __exit__(self, exc_type, exc_value, traceback): - device_stack.pop() - paddle.device.set_device(device_stack[-1]) - @wp.kernel def op_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)): @@ -570,72 +543,6 @@ def test_from_paddle_zero_strides(test, device): assert_np_equal(a_contiguous.numpy(), t.numpy()) -def test_paddle_mgpu_from_paddle(test, device): - 
import paddle - - n = 32 - - expected0 = np.arange(0, n, 1) - expected1 = np.arange(0, n * 2, 2) - - t0 = paddle.arange(0, n, 1, dtype=paddle.int32).to(device="gpu:0") - a0 = wp.from_paddle(t0, dtype=wp.int32) - assert a0.device == "cuda:0" - assert_np_equal(a0.numpy(), expected0) - - t1 = paddle.arange(0, n * 2, 2, dtype=paddle.int32).to(device="gpu:1") - a1 = wp.from_paddle(t1, dtype=wp.int32) - assert a1.device == "cuda:1" - assert_np_equal(a1.numpy(), expected1) - - -def test_paddle_mgpu_to_paddle(test, device): - n = 32 - - with wp.ScopedDevice("cuda:0"): - a0 = wp.empty(n, dtype=wp.int32) - wp.launch(arange, dim=a0.size, inputs=[0, 1, a0]) - - t0 = wp.to_paddle(a0) - assert str(t0.place) == "Place(gpu:0)" - expected0 = np.arange(0, n, 1, dtype=np.int32) - assert_np_equal(t0.numpy(), expected0) - - with wp.ScopedDevice("cuda:1"): - a1 = wp.empty(n, dtype=wp.int32) - wp.launch(arange, dim=a1.size, inputs=[0, 2, a1]) - - t1 = wp.to_paddle(a1) - assert str(t1.place) == "Place(gpu:1)" - expected1 = np.arange(0, n * 2, 2, dtype=np.int32) - assert_np_equal(t1.numpy(), expected1) - - -def test_paddle_mgpu_interop(test, device): - import paddle - - n = 1024 * 1024 - - with PaddleDevice("gpu:0"): - t0 = paddle.arange(n, dtype=paddle.float32).to(device="gpu:0") - a0 = wp.from_paddle(t0) - wp.launch(inc, dim=a0.size, inputs=[a0], stream=wp.stream_from_paddle()) - - with PaddleDevice("gpu:1"): - t1 = paddle.arange(n, dtype=paddle.float32).to(device="gpu:1") - a1 = wp.from_paddle(t1) - wp.launch(inc, dim=a1.size, inputs=[a1], stream=wp.stream_from_paddle()) - - # ensure the paddle tensors were modified by warp - assert a0.device == "cuda:0" - assert a1.device == "cuda:1" - - expected = np.arange(n, dtype=int) + 1 - - assert_np_equal(t0.numpy(), expected) - assert_np_equal(t1.numpy(), expected) - - def test_paddle_autograd(test, device): """Test paddle autograd with a custom Warp op"""