diff --git a/README.md b/README.md index 7ccb3e61..4f5be963 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ regular Python functions and JIT compiles them to efficient kernel code that can Warp is designed for [spatial computing](https://en.wikipedia.org/wiki/Spatial_computing) and comes with a rich set of primitives that make it easy to write programs for physics simulation, perception, robotics, and geometry processing. In addition, Warp kernels -are differentiable and can be used as part of machine-learning pipelines with frameworks such as PyTorch and JAX. +are differentiable and can be used as part of machine-learning pipelines with frameworks such as PyTorch, JAX and Paddle. Please refer to the project [Documentation](https://nvidia.github.io/warp/) for API and language reference and [CHANGELOG.md](./CHANGELOG.md) for release history. diff --git a/docs/index.rst b/docs/index.rst index dd03d06f..5ffef7f2 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -7,7 +7,7 @@ regular Python functions and JIT compiles them to efficient kernel code that can Warp is designed for `spatial computing `_ and comes with a rich set of primitives that make it easy to write programs for physics simulation, perception, robotics, and geometry processing. In addition, Warp kernels -are differentiable and can be used as part of machine-learning pipelines with frameworks such as PyTorch and JAX. +are differentiable and can be used as part of machine-learning pipelines with frameworks such as PyTorch, JAX and Paddle. Below are some examples of simulations implemented using Warp: diff --git a/docs/installation.rst b/docs/installation.rst index 016109f0..f218a64e 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -76,6 +76,7 @@ The following optional dependencies are required to support certain features: * `usd-core `_: Required for some Warp examples, ``warp.sim.parse_usd()``, and ``warp.render.UsdRenderer``. * `JAX `_: Required for JAX interoperability (see :ref:`jax-interop`). * `PyTorch `_: Required for PyTorch interoperability (see :ref:`pytorch-interop`). +* `Paddle `_: Required for Paddle interoperability (see :ref:`paddle-interop`). * `NVTX for Python `_: Required to use :class:`wp.ScopedTimer(use_nvtx=True) `. Building the Warp documentation requires: diff --git a/docs/modules/interoperability.rst b/docs/modules/interoperability.rst index 800d4e79..ef215f7c 100644 --- a/docs/modules/interoperability.rst +++ b/docs/modules/interoperability.rst @@ -709,6 +709,7 @@ The canonical way to export a Warp array to an external framework is to use the jax_array = jax.dlpack.from_dlpack(warp_array) torch_tensor = torch.utils.dlpack.from_dlpack(warp_array) + paddle_tensor = paddle.utils.dlpack.from_dlpack(warp_array) For CUDA arrays, this will synchronize the current stream of the consumer framework with the current Warp stream on the array's device. Thus it should be safe to use the wrapped array in the consumer framework, even if the array was previously used in a Warp kernel @@ -719,9 +720,11 @@ This approach may be used for older versions of frameworks that do not support t warp_array1 = wp.from_dlpack(jax.dlpack.to_dlpack(jax_array)) warp_array2 = wp.from_dlpack(torch.utils.dlpack.to_dlpack(torch_tensor)) + warp_array3 = wp.from_dlpack(paddle.utils.dlpack.to_dlpack(paddle_tensor)) jax_array = jax.dlpack.from_dlpack(wp.to_dlpack(warp_array)) torch_tensor = torch.utils.dlpack.from_dlpack(wp.to_dlpack(warp_array)) + paddle_tensor = paddle.utils.dlpack.from_dlpack(wp.to_dlpack(warp_array)) This approach is generally faster because it skips any stream synchronization, but another solution must be used to ensure correct ordering of operations. In situations where no synchronization is required, using this approach can yield better performance. @@ -733,3 +736,181 @@ This may be a good choice in situations like these: .. autofunction:: warp.from_dlpack .. autofunction:: warp.to_dlpack + +.. _paddle-interop: + +Paddle +------ + +Warp provides helper functions to convert arrays to/from Paddle:: + + w = wp.array([1.0, 2.0, 3.0], dtype=float, device="cpu") + + # convert to Paddle tensor + t = wp.to_paddle(w) + + # convert from Paddle tensor + w = wp.from_paddle(t) + +These helper functions allow the conversion of Warp arrays to/from Paddle tensors without copying the underlying data. +At the same time, if available, gradient arrays and tensors are converted to/from Paddle autograd tensors, allowing the use of Warp arrays +in Paddle autograd computations. + +.. autofunction:: warp.from_paddle +.. autofunction:: warp.to_paddle +.. autofunction:: warp.device_from_paddle +.. autofunction:: warp.device_to_paddle +.. autofunction:: warp.dtype_from_paddle +.. autofunction:: warp.dtype_to_paddle + +To convert a Paddle CUDA stream to a Warp CUDA stream and vice versa, Warp provides the following functions: + +.. autofunction:: warp.stream_from_paddle + +Example: Optimization using ``warp.from_paddle()`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +An example usage of minimizing a loss function over an array of 2D points written in Warp via Paddle's Adam optimizer +using :func:`warp.from_paddle` is as follows:: + + import warp as wp + import paddle + + # init warp context at beginning + wp.context.init() + + @wp.kernel() + def loss(xs: wp.array(dtype=float, ndim=2), l: wp.array(dtype=float)): + tid = wp.tid() + wp.atomic_add(l, 0, xs[tid, 0] ** 2.0 + xs[tid, 1] ** 2.0) + + # indicate requires_grad so that Warp can accumulate gradients in the grad buffers + xs = paddle.randn([100, 2]) + xs.stop_gradient = False + l = paddle.zeros([1]) + l.stop_gradient = False + opt = paddle.optimizer.Adam(learning_rate=0.1, parameters=[xs]) + + wp_xs = wp.from_paddle(xs) + wp_l = wp.from_paddle(l) + + tape = wp.Tape() + with tape: + # record the loss function kernel launch on the tape + wp.launch(loss, dim=len(xs), inputs=[wp_xs], outputs=[wp_l], device=wp_xs.device) + + for i in range(500): + tape.zero() + tape.backward(loss=wp_l) # compute gradients + # now xs.grad will be populated with the gradients computed by Warp + opt.step() # update xs (and thereby wp_xs) + + # these lines are only needed for evaluating the loss + # (the optimization just needs the gradient, not the loss value) + wp_l.zero_() + wp.launch(loss, dim=len(xs), inputs=[wp_xs], outputs=[wp_l], device=wp_xs.device) + print(f"{i}\tloss: {l.item()}") + +Example: Optimization using ``warp.to_paddle`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Less code is needed when we declare the optimization variables directly in Warp and use :func:`warp.to_paddle` to convert them to Paddle tensors. +Here, we revisit the same example from above where now only a single conversion to a paddle tensor is needed to supply Adam with the optimization variables:: + + import warp as wp + import numpy as np + import paddle + + # init warp context at beginning + wp.context.init() + + @wp.kernel() + def loss(xs: wp.array(dtype=float, ndim=2), l: wp.array(dtype=float)): + tid = wp.tid() + wp.atomic_add(l, 0, xs[tid, 0] ** 2.0 + xs[tid, 1] ** 2.0) + + # initialize the optimization variables in Warp + xs = wp.array(np.random.randn(100, 2), dtype=wp.float32, requires_grad=True) + l = wp.zeros(1, dtype=wp.float32, requires_grad=True) + # just a single wp.to_paddle call is needed, Adam optimizes using the Warp array gradients + opt = paddle.optimizer.Adam(learning_rate=0.1, parameters=[wp.to_paddle(xs)]) + + tape = wp.Tape() + with tape: + wp.launch(loss, dim=len(xs), inputs=[xs], outputs=[l], device=xs.device) + + for i in range(500): + tape.zero() + tape.backward(loss=l) + opt.step() + + l.zero_() + wp.launch(loss, dim=len(xs), inputs=[xs], outputs=[l], device=xs.device) + print(f"{i}\tloss: {l.numpy()[0]}") + +Performance Notes +^^^^^^^^^^^^^^^^^ + +The ``wp.from_paddle()`` function creates a Warp array object that shares data with a Paddle tensor. Although this function does not copy the data, there is always some CPU overhead during the conversion. If these conversions happen frequently, the overall program performance may suffer. As a general rule, it's good to avoid repeated conversions of the same tensor. Instead of: + +.. code:: python + + x_t = paddle.arange(n, dtype=paddle.float32).to(device=wp.device_to_paddle(device)) + y_t = paddle.ones([n], dtype=paddle.float32).to(device=wp.device_to_paddle(device)) + + for i in range(10): + x_w = wp.from_paddle(x_t) + y_w = wp.from_paddle(y_t) + wp.launch(saxpy, dim=n, inputs=[x_w, y_w, 1.0], device=device) + +Try converting the arrays only once and reuse them: + +.. code:: python + + x_t = paddle.arange(n, dtype=paddle.float32).to(device=wp.device_to_paddle(device)) + y_t = paddle.ones([n], dtype=paddle.float32).to(device=wp.device_to_paddle(device)) + + x_w = wp.from_paddle(x_t) + y_w = wp.from_paddle(y_t) + + for i in range(10): + wp.launch(saxpy, dim=n, inputs=[x_w, y_w, 1.0], device=device) + +If reusing arrays is not possible (e.g., a new Paddle tensor is constructed on every iteration), passing ``return_ctype=True`` to ``wp.from_paddle()`` should yield faster performance. Setting this argument to True avoids constructing a ``wp.array`` object and instead returns a low-level array descriptor. This descriptor is a simple C structure that can be passed to Warp kernels instead of a ``wp.array``, but cannot be used in other places that require a ``wp.array``. + +.. code:: python + + for n in range(1, 10): + # get Paddle tensors for this iteration + x_t = paddle.arange(n, dtype=paddle.float32).to(device=wp.device_to_paddle(device)) + y_t = paddle.ones([n], dtype=paddle.float32).to(device=wp.device_to_paddle(device)) + + # get Warp array descriptors + x_ctype = wp.from_paddle(x_t, return_ctype=True) + y_ctype = wp.from_paddle(y_t, return_ctype=True) + + wp.launch(saxpy, dim=n, inputs=[x_ctype, y_ctype, 1.0], device=device) + +An alternative approach is to pass the Paddle tensors to Warp kernels directly. This avoids constructing temporary Warp arrays by leveraging standard array interfaces (like ``__cuda_array_interface__``) supported by both Paddle and Warp. The main advantage of this approach is convenience, since there is no need to call any conversion functions. The main limitation is that it does not handle gradients, because gradient information is not included in the standard array interfaces. This technique is therefore most suitable for algorithms that do not involve differentiation. + +.. code:: python + + x = paddle.arange(n, dtype=paddle.float32).to(device=wp.device_to_paddle(device)) + y = paddle.ones([n], dtype=paddle.float32).to(device=wp.device_to_paddle(device)) + + for i in range(10): + wp.launch(saxpy, dim=n, inputs=[x, y, 1.0], device=device) + +.. code:: shell + + python -m warp.examples.benchmarks.benchmark_interop_paddle + +Sample output: + +.. code:: + + 13990 ms from_paddle(...) + 5990 ms from_paddle(..., return_ctype=True) + 35167 ms direct from paddle + +The default ``wp.from_paddle()`` conversion is the slowest. Passing ``return_ctype=True`` is the fastest, because it skips creating temporary Warp array objects. Passing Paddle tensors to Warp kernels directly falls somewhere in between. It skips creating temporary Warp arrays, but accessing the ``__cuda_array_interface__`` attributes of Paddle tensors adds overhead because they are initialized on-demand. diff --git a/exts/omni.warp.core/config/extension.toml b/exts/omni.warp.core/config/extension.toml index 8b80e91f..e39e4dd2 100644 --- a/exts/omni.warp.core/config/extension.toml +++ b/exts/omni.warp.core/config/extension.toml @@ -38,6 +38,7 @@ pyCoverageOmit = [ "warp/stubs.py", "warp/jax.py", "warp/torch.py", + "warp/paddle.py", "warp/build.py", "warp/build_dll.py", "warp/sim/**", diff --git a/warp/__init__.py b/warp/__init__.py index 28df7a37..dc77075c 100644 --- a/warp/__init__.py +++ b/warp/__init__.py @@ -99,6 +99,11 @@ from warp.dlpack import from_dlpack, to_dlpack +from warp.paddle import from_paddle, to_paddle +from warp.paddle import dtype_from_paddle, dtype_to_paddle +from warp.paddle import device_from_paddle, device_to_paddle +from warp.paddle import stream_from_paddle + from warp.build import clear_kernel_cache from warp.constants import * diff --git a/warp/dlpack.py b/warp/dlpack.py index 34de4264..20860c6e 100644 --- a/warp/dlpack.py +++ b/warp/dlpack.py @@ -124,6 +124,8 @@ def device_to_dlpack(wp_device) -> DLDevice: def dtype_to_dlpack(wp_dtype) -> DLDataType: + if wp_dtype == warp.bool: + return (DLDataTypeCode.kDLBool, 8, 1) if wp_dtype == warp.int8: return (DLDataTypeCode.kDLInt, 8, 1) elif wp_dtype == warp.uint8: diff --git a/warp/examples/benchmarks/benchmark.bat b/warp/examples/benchmarks/benchmark.bat index 9edec17d..66a5dab3 100644 --- a/warp/examples/benchmarks/benchmark.bat +++ b/warp/examples/benchmarks/benchmark.bat @@ -11,3 +11,5 @@ python benchmark_cloth.py numpy @REM python benchmark_cloth.py numba @REM python benchmark_cloth.py jax_cpu @REM python benchmark_cloth.py jax_gpu +@REM python benchmark_cloth.py paddle_cpu +@REM python benchmark_cloth.py paddle_gpu diff --git a/warp/examples/benchmarks/benchmark.sh b/warp/examples/benchmarks/benchmark.sh index f82289a6..a4d54386 100755 --- a/warp/examples/benchmarks/benchmark.sh +++ b/warp/examples/benchmarks/benchmark.sh @@ -11,3 +11,5 @@ python3 benchmark_cloth.py numpy # python3 benchmark_cloth.py jax_cpu # python3 benchmark_cloth.py jax_gpu # python3 benchmark_cloth.py numba +# python3 benchmark_cloth.py paddle_cpu +# python3 benchmark_cloth.py paddle_gpu diff --git a/warp/examples/benchmarks/benchmark_cloth.py b/warp/examples/benchmarks/benchmark_cloth.py index d28213da..3fc6a740 100644 --- a/warp/examples/benchmarks/benchmark_cloth.py +++ b/warp/examples/benchmarks/benchmark_cloth.py @@ -219,6 +219,16 @@ def run_benchmark(mode, dim, timers, render=False): integrator = benchmark_cloth_jax.JxIntegrator(cloth) + elif mode == "paddle_cpu": + import benchmark_cloth_paddle + + integrator = benchmark_cloth_paddle.TrIntegrator(cloth, "cpu") + + elif mode == "paddle_gpu": + import benchmark_cloth_paddle + + integrator = benchmark_cloth_paddle.TrIntegrator(cloth, "gpu") + else: raise RuntimeError("Unknown simulation backend") diff --git a/warp/paddle.py b/warp/paddle.py new file mode 100644 index 00000000..65dcf17f --- /dev/null +++ b/warp/paddle.py @@ -0,0 +1,382 @@ +# Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved. +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +from __future__ import annotations + +import ctypes +from typing import TYPE_CHECKING, Optional, Union + +import numpy + +import warp +import warp.context + +if TYPE_CHECKING: + import paddle + + +# return the warp device corresponding to a paddle device +def device_from_paddle(paddle_device: Union[paddle.base.libpaddle.Place, str]) -> warp.context.Device: + """Return the Warp device corresponding to a Paddle device. + + Args: + paddle_device (`paddle.base.libpaddle.Place` or `str`): Paddle device identifier + + Raises: + RuntimeError: Paddle device does not have a corresponding Warp device + """ + if type(paddle_device) is str: + warp_device = warp.context.runtime.device_map.get(paddle_device) + if warp_device is not None: + return warp_device + elif paddle_device.startswith("gpu"): + return warp.context.runtime.get_current_cuda_device() + else: + raise RuntimeError(f"Unsupported Paddle device {paddle_device}") + else: + import paddle + + try: + if paddle_device.is_gpu_place(): + return warp.context.runtime.cuda_devices[paddle_device.gpu_device_id()] + elif paddle_device.is_cpu_place(): + return warp.context.runtime.cpu_device + else: + raise RuntimeError(f"Unsupported Paddle device type {paddle_device}") + except Exception as e: + import paddle + + if not isinstance(paddle_device, paddle.base.libpaddle.Place): + raise ValueError("Argument must be a paddle.base.libpaddle.Place object or a string") from e + raise + + +def device_to_paddle(warp_device: warp.context.Devicelike) -> str: + """Return the Paddle device string corresponding to a Warp device. + + Args: + warp_device: An identifier that can be resolved to a :class:`warp.context.Device`. + + Raises: + RuntimeError: The Warp device is not compatible with PyPaddle. + """ + device = warp.get_device(warp_device) + if device.is_cpu or device.is_primary: + return str(device).replace("cuda", "gpu") + elif device.is_cuda and device.is_uva: + # it's not a primary context, but paddle can access the data ptr directly thanks to UVA + return f"gpu:{device.ordinal}" + raise RuntimeError(f"Warp device {device} is not compatible with paddle") + + +def dtype_to_paddle(warp_dtype): + """Return the Paddle dtype corresponding to a Warp dtype. + + Args: + warp_dtype: A Warp data type that has a corresponding ``paddle.dtype``. + ``warp.uint16``, ``warp.uint32``, and ``warp.uint64`` are mapped + to the signed integer ``paddle.dtype`` of the same width. + Raises: + TypeError: Unable to find a corresponding PyPaddle data type. + """ + # initialize lookup table on first call to defer paddle import + if dtype_to_paddle.type_map is None: + import paddle + + dtype_to_paddle.type_map = { + warp.float16: paddle.float16, + warp.float32: paddle.float32, + warp.float64: paddle.float64, + warp.int8: paddle.int8, + warp.int16: paddle.int16, + warp.int32: paddle.int32, + warp.int64: paddle.int64, + warp.uint8: paddle.uint8, + warp.bool: paddle.bool, + # paddle doesn't support unsigned ints bigger than 8 bits + warp.uint16: paddle.int16, + warp.uint32: paddle.int32, + warp.uint64: paddle.int64, + } + + paddle_dtype = dtype_to_paddle.type_map.get(warp_dtype) + if paddle_dtype is not None: + return paddle_dtype + else: + raise TypeError(f"Cannot convert {warp_dtype} to a Paddle type") + + +def dtype_from_paddle(paddle_dtype): + """Return the Warp dtype corresponding to a Paddle dtype. + + Args: + paddle_dtype: A ``paddle.dtype`` that has a corresponding Warp data type. + Currently ``paddle.bfloat16``, ``paddle.complex64``, and + ``paddle.complex128`` are not supported. + + Raises: + TypeError: Unable to find a corresponding Warp data type. + """ + # initialize lookup table on first call to defer paddle import + if dtype_from_paddle.type_map is None: + import paddle + + dtype_from_paddle.type_map = { + paddle.float16: warp.float16, + paddle.float32: warp.float32, + paddle.float64: warp.float64, + paddle.int8: warp.int8, + paddle.int16: warp.int16, + paddle.int32: warp.int32, + paddle.int64: warp.int64, + paddle.uint8: warp.uint8, + paddle.bool: warp.bool, + # currently unsupported by Warp + # paddle.bfloat16: + # paddle.complex64: + # paddle.complex128: + } + + warp_dtype = dtype_from_paddle.type_map.get(paddle_dtype) + + if warp_dtype is not None: + return warp_dtype + else: + raise TypeError(f"Cannot convert {paddle_dtype} to a Warp type") + + +def dtype_is_compatible(paddle_dtype: paddle.dtype, warp_dtype) -> bool: + """Evaluates whether the given paddle dtype is compatible with the given Warp dtype.""" + # initialize lookup table on first call to defer paddle import + if dtype_is_compatible.compatible_sets is None: + import paddle + + dtype_is_compatible.compatible_sets = { + paddle.float64: {warp.float64}, + paddle.float32: {warp.float32}, + paddle.float16: {warp.float16}, + # allow aliasing integer tensors as signed or unsigned integer arrays + paddle.int64: {warp.int64, warp.uint64}, + paddle.int32: {warp.int32, warp.uint32}, + paddle.int16: {warp.int16, warp.uint16}, + paddle.int8: {warp.int8, warp.uint8}, + paddle.uint8: {warp.uint8, warp.int8}, + paddle.bool: {warp.bool, warp.uint8, warp.int8}, + # currently unsupported by Warp + # paddle.bfloat16: + # paddle.complex64: + # paddle.complex128: + } + + compatible_set = dtype_is_compatible.compatible_sets.get(paddle_dtype) + + if compatible_set is not None: + if warp_dtype in compatible_set: + return True + # check if it's a vector or matrix type + if hasattr(warp_dtype, "_wp_scalar_type_"): + return warp_dtype._wp_scalar_type_ in compatible_set + + return False + + +# lookup tables initialized when needed +dtype_from_paddle.type_map = None +dtype_to_paddle.type_map = None +dtype_is_compatible.compatible_sets = None + + +# wrap a paddle tensor to a wp array, data is not copied +def from_paddle( + t: paddle.Tensor, + dtype: Optional[paddle.dtype] = None, + requires_grad: Optional[bool] = None, + grad: Optional[paddle.Tensor] = None, + return_ctype: bool = False, +) -> warp.array: + """Convert a Paddle tensor to a Warp array without copying the data. + + Args: + t (paddle.Tensor): The paddle tensor to wrap. + dtype (warp.dtype, optional): The target data type of the resulting Warp array. Defaults to the tensor value type mapped to a Warp array value type. + requires_grad (bool, optional): Whether the resulting array should wrap the tensor's gradient, if it exists (the grad tensor will be allocated otherwise). Defaults to the tensor's `requires_grad` value. + grad (paddle.Tensor, optional): The grad attached to given tensor. Defaults to None. + return_ctype (bool, optional): Whether to return a low-level array descriptor instead of a ``wp.array`` object (faster). The descriptor can be passed to Warp kernels. + + Returns: + warp.array: The wrapped array or array descriptor. + """ + if dtype is None: + dtype = dtype_from_paddle(t.dtype) + elif not dtype_is_compatible(t.dtype, dtype): + raise RuntimeError(f"Cannot convert Paddle type {t.dtype} to Warp type {dtype}") + + # get size of underlying data type to compute strides + ctype_size = ctypes.sizeof(dtype._type_) + + shape = tuple(t.shape) + strides = tuple(s * ctype_size for s in t.strides) + + # if target is a vector or matrix type + # then check if trailing dimensions match + # the target type and update the shape + if hasattr(dtype, "_shape_"): + dtype_shape = dtype._shape_ + dtype_dims = len(dtype._shape_) + # ensure inner shape matches + if dtype_dims > len(shape) or dtype_shape != shape[-dtype_dims:]: + raise RuntimeError( + f"Could not convert Paddle tensor with shape {shape} to Warp array with dtype={dtype}, ensure that source inner shape is {dtype_shape}" + ) + # ensure inner strides are contiguous + if strides[-1] != ctype_size or (dtype_dims > 1 and strides[-2] != ctype_size * dtype_shape[-1]): + raise RuntimeError( + f"Could not convert Paddle tensor with shape {shape} to Warp array with dtype={dtype}, because the source inner strides are not contiguous" + ) + # trim shape and strides + shape = tuple(shape[:-dtype_dims]) or (1,) + strides = tuple(strides[:-dtype_dims]) or (ctype_size,) + + # gradient + # - if return_ctype is False, we set `grad` to a wp.array or None + # - if return_ctype is True, we set `grad_ptr` and set `grad` as the owner (wp.array or paddle.Tensor) + requires_grad = (not t.stop_gradient) if requires_grad is None else requires_grad + grad_ptr = 0 + if grad is not None: + if isinstance(grad, warp.array): + if return_ctype: + if grad.strides != strides: + raise RuntimeError( + f"Gradient strides must match array strides, expected {strides} but got {grad.strides}" + ) + grad_ptr = grad.ptr + else: + # assume grad is a paddle.Tensor + if return_ctype: + if t.strides != grad.strides: + raise RuntimeError( + f"Gradient strides must match array strides, expected {t.strides} but got {grad.strides}" + ) + grad_ptr = grad.data_ptr() + else: + grad = from_paddle(grad, dtype=dtype, requires_grad=False) + elif requires_grad: + # wrap the tensor gradient, allocate if necessary + if t.grad is not None: + if return_ctype: + grad = t.grad + if t.strides != grad.strides: + raise RuntimeError( + f"Gradient strides must match array strides, expected {t.strides} but got {grad.strides}" + ) + grad_ptr = grad.data_ptr() + else: + grad = from_paddle(t.grad, dtype=dtype, requires_grad=False) + else: + # allocate a zero-filled gradient if it doesn't exist + # Note: we use Warp to allocate the shared gradient with compatible strides + grad = warp.zeros(dtype=dtype, shape=shape, strides=strides, device=device_from_paddle(t.place)) + # use .grad_ for zero-copy + t.grad_ = to_paddle(grad, requires_grad=False) + grad_ptr = grad.ptr + + if return_ctype: + ptr = t.data_ptr() + + # create array descriptor + array_ctype = warp.types.array_t(ptr, grad_ptr, len(shape), shape, strides) + + # keep data and gradient alive + array_ctype._ref = t + array_ctype._gradref = grad + + return array_ctype + + else: + a = warp.array( + ptr=t.data_ptr(), + dtype=dtype, + shape=shape, + strides=strides, + device=device_from_paddle(t.place), + copy=False, + grad=grad, + requires_grad=requires_grad, + ) + + # save a reference to the source tensor, otherwise it may get deallocated + a._tensor = t + + return a + + +def to_paddle(a: warp.array, requires_grad: bool = None) -> paddle.Tensor: + """ + Convert a Warp array to a Paddle tensor without copying the data. + + Args: + a (warp.array): The Warp array to convert. + requires_grad (bool, optional): Whether the resulting tensor should convert the array's gradient, if it exists, to a grad tensor. Defaults to the array's `requires_grad` value. + + Returns: + paddle.Tensor: The converted tensor. + """ + import paddle + import paddle.utils.dlpack + + if requires_grad is None: + requires_grad = a.requires_grad + + # Paddle does not support structured arrays + if isinstance(a.dtype, warp.codegen.Struct): + raise RuntimeError("Cannot convert structured Warp arrays to Paddle.") + + if a.device.is_cpu: + # Paddle has an issue wrapping CPU objects + # that support the __array_interface__ protocol + # in this case we need to workaround by going + # to an ndarray first, see https://pearu.github.io/array_interface_pypaddle.html + t = paddle.to_tensor(numpy.asarray(a), place="cpu") + t.stop_gradient = not requires_grad + if requires_grad and a.requires_grad: + # use .grad_ for zero-copy + t.grad_ = paddle.to_tensor(numpy.asarray(a.grad), place="cpu") + return t + + elif a.device.is_cuda: + # Paddle does support the __cuda_array_interface__ + # correctly, but we must be sure to maintain a reference + # to the owning object to prevent memory allocs going out of scope + t = paddle.utils.dlpack.from_dlpack(warp.to_dlpack(a)).to(device=device_to_paddle(a.device)) + t.stop_gradient = not requires_grad + if requires_grad and a.requires_grad: + # use .grad_ for zero-copy + t.grad_ = paddle.utils.dlpack.from_dlpack(warp.to_dlpack(a.grad)).to(device=device_to_paddle(a.device)) + return t + + else: + raise RuntimeError("Unsupported device") + + +def stream_from_paddle(stream_or_device=None): + """Convert from a Paddle CUDA stream to a Warp CUDA stream.""" + import paddle + + if isinstance(stream_or_device, paddle.device.Stream): + stream = stream_or_device + else: + # assume arg is a paddle device + stream = paddle.device.current_stream(stream_or_device) + + device = device_from_paddle(stream.device) + + warp_stream = warp.Stream(device, cuda_stream=stream.stream_base.cuda_stream) + + # save a reference to the source stream, otherwise it may be destroyed + warp_stream._paddle_stream = stream + + return warp_stream diff --git a/warp/stubs.py b/warp/stubs.py index 834e7328..2810e7c6 100644 --- a/warp/stubs.py +++ b/warp/stubs.py @@ -108,6 +108,11 @@ from warp.dlpack import from_dlpack, to_dlpack +from warp.paddle import from_paddle, to_paddle +from warp.paddle import dtype_from_paddle, dtype_to_paddle +from warp.paddle import device_from_paddle, device_to_paddle +from warp.paddle import stream_from_paddle + from warp.build import clear_kernel_cache from warp.constants import * diff --git a/warp/tests/test_dlpack.py b/warp/tests/test_dlpack.py index 45fbef13..30ef693a 100644 --- a/warp/tests/test_dlpack.py +++ b/warp/tests/test_dlpack.py @@ -350,6 +350,34 @@ def test_dlpack_torch_to_warp_v2(test, device): assert_np_equal(a.numpy(), t.cpu().numpy()) +def test_dlpack_paddle_to_warp(test, device): + import paddle + import paddle.utils.dlpack + + t = paddle.arange(N, dtype=paddle.float32).to(device=wp.device_to_paddle(device)) + + # paddle do not implement __dlpack__ yet, so only test to_dlpack here + a = wp.from_dlpack(paddle.utils.dlpack.to_dlpack(t)) + + item_size = wp.types.type_size_in_bytes(a.dtype) + + test.assertEqual(a.ptr, t.data_ptr()) + test.assertEqual(a.device, wp.device_from_paddle(t.place)) + test.assertEqual(a.dtype, wp.dtype_from_paddle(t.dtype)) + test.assertEqual(a.shape, tuple(t.shape)) + test.assertEqual(a.strides, tuple(s * item_size for s in t.strides)) + + assert_np_equal(a.numpy(), t.numpy()) + + wp.launch(inc, dim=a.size, inputs=[a], device=device) + + assert_np_equal(a.numpy(), t.numpy()) + + paddle.assign(t + 1, t) + + assert_np_equal(a.numpy(), t.numpy()) + + def test_dlpack_warp_to_jax(test, device): import jax import jax.dlpack @@ -421,6 +449,61 @@ def test_dlpack_warp_to_jax_v2(test, device): assert_np_equal(a.numpy(), np.asarray(j2)) +def test_dlpack_warp_to_paddle(test, device): + import paddle.utils.dlpack + + a = wp.array(data=np.arange(N, dtype=np.float32), device=device) + + t = paddle.utils.dlpack.from_dlpack(wp.to_dlpack(a)) + + item_size = wp.types.type_size_in_bytes(a.dtype) + + test.assertEqual(a.ptr, t.data_ptr()) + test.assertEqual(a.device, wp.device_from_paddle(t.place)) + test.assertEqual(a.dtype, wp.dtype_from_paddle(t.dtype)) + test.assertEqual(a.shape, tuple(t.shape)) + test.assertEqual(a.strides, tuple(s * item_size for s in t.strides)) + + assert_np_equal(a.numpy(), t.cpu().numpy()) + + wp.launch(inc, dim=a.size, inputs=[a], device=device) + + assert_np_equal(a.numpy(), t.cpu().numpy()) + + paddle.assign(t + 1, t) + + assert_np_equal(a.numpy(), t.cpu().numpy()) + + +def test_dlpack_warp_to_paddle_v2(test, device): + # same as original test, but uses newer __dlpack__() method + + import paddle.utils.dlpack + + a = wp.array(data=np.arange(N, dtype=np.float32), device=device) + + # pass the array directly + t = paddle.utils.dlpack.from_dlpack(a) + + item_size = wp.types.type_size_in_bytes(a.dtype) + + test.assertEqual(a.ptr, t.data_ptr()) + test.assertEqual(a.device, wp.device_from_paddle(t.place)) + test.assertEqual(a.dtype, wp.dtype_from_paddle(t.dtype)) + test.assertEqual(a.shape, tuple(t.shape)) + test.assertEqual(a.strides, tuple(s * item_size for s in t.strides)) + + assert_np_equal(a.numpy(), t.numpy()) + + wp.launch(inc, dim=a.size, inputs=[a], device=device) + + assert_np_equal(a.numpy(), t.numpy()) + + paddle.assign(t + 1, t) + + assert_np_equal(a.numpy(), t.numpy()) + + def test_dlpack_jax_to_warp(test, device): import jax import jax.dlpack @@ -575,6 +658,41 @@ class TestDLPack(unittest.TestCase): print(f"Skipping Jax DLPack tests due to exception: {e}") +# paddle interop via dlpack +try: + import paddle + import paddle.utils.dlpack + + # check which Warp devices work with paddle + # CUDA devices may fail if paddle was not compiled with CUDA support + test_devices = get_test_devices() + paddle_compatible_devices = [] + for d in test_devices: + try: + t = paddle.arange(10).to(device=wp.device_to_paddle(d)) + paddle.assign(t + 1, t) + paddle_compatible_devices.append(d) + except Exception as e: + print(f"Skipping paddle DLPack tests on device '{d}' due to exception: {e}") + + if paddle_compatible_devices: + add_function_test( + TestDLPack, "test_dlpack_warp_to_paddle", test_dlpack_warp_to_paddle, devices=paddle_compatible_devices + ) + add_function_test( + TestDLPack, + "test_dlpack_warp_to_paddle_v2", + test_dlpack_warp_to_paddle_v2, + devices=paddle_compatible_devices, + ) + add_function_test( + TestDLPack, "test_dlpack_paddle_to_warp", test_dlpack_paddle_to_warp, devices=paddle_compatible_devices + ) + +except Exception as e: + print(f"Skipping Paddle DLPack tests due to exception: {e}") + + if __name__ == "__main__": wp.clear_kernel_cache() unittest.main(verbosity=2) diff --git a/warp/tests/test_paddle.py b/warp/tests/test_paddle.py new file mode 100644 index 00000000..53db028e --- /dev/null +++ b/warp/tests/test_paddle.py @@ -0,0 +1,852 @@ +# Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved. +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import unittest + +import numpy as np + +import warp as wp +from warp.tests.unittest_utils import * + + +@wp.kernel +def op_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)): + tid = wp.tid() + y[tid] = 0.5 - x[tid] * 2.0 + + +@wp.kernel +def inc(a: wp.array(dtype=float)): + tid = wp.tid() + a[tid] = a[tid] + 1.0 + + +@wp.kernel +def inc_vector(a: wp.array(dtype=wp.vec3f)): + tid = wp.tid() + a[tid] = a[tid] + wp.vec3f(1.0) + + +@wp.kernel +def inc_matrix(a: wp.array(dtype=wp.mat22f)): + tid = wp.tid() + a[tid] = a[tid] + wp.mat22f(1.0) + + +@wp.kernel +def arange(start: int, step: int, a: wp.array(dtype=int)): + tid = wp.tid() + a[tid] = start + step * tid + + +# copy elements between non-contiguous 1d arrays of float +@wp.kernel +def copy1d_float_kernel(dst: wp.array(dtype=float), src: wp.array(dtype=float)): + i = wp.tid() + dst[i] = src[i] + + +# copy elements between non-contiguous 2d arrays of float +@wp.kernel +def copy2d_float_kernel(dst: wp.array2d(dtype=float), src: wp.array2d(dtype=float)): + i, j = wp.tid() + dst[i, j] = src[i, j] + + +# copy elements between non-contiguous 3d arrays of float +@wp.kernel +def copy3d_float_kernel(dst: wp.array3d(dtype=float), src: wp.array3d(dtype=float)): + i, j, k = wp.tid() + dst[i, j, k] = src[i, j, k] + + +# copy elements between non-contiguous 2d arrays of vec3 +@wp.kernel +def copy2d_vec3_kernel(dst: wp.array2d(dtype=wp.vec3), src: wp.array2d(dtype=wp.vec3)): + i, j = wp.tid() + dst[i, j] = src[i, j] + + +# copy elements between non-contiguous 2d arrays of mat22 +@wp.kernel +def copy2d_mat22_kernel(dst: wp.array2d(dtype=wp.mat22), src: wp.array2d(dtype=wp.mat22)): + i, j = wp.tid() + dst[i, j] = src[i, j] + + +def test_dtype_from_paddle(test, device): + import paddle + + def test_conversions(paddle_type, warp_type): + test.assertEqual(wp.dtype_from_paddle(paddle_type), warp_type) + + test_conversions(paddle.float16, wp.float16) + test_conversions(paddle.float32, wp.float32) + test_conversions(paddle.float64, wp.float64) + test_conversions(paddle.int8, wp.int8) + test_conversions(paddle.int16, wp.int16) + test_conversions(paddle.int32, wp.int32) + test_conversions(paddle.int64, wp.int64) + test_conversions(paddle.uint8, wp.uint8) + test_conversions(paddle.bool, wp.bool) + + +def test_dtype_to_paddle(test, device): + import paddle + + def test_conversions(warp_type, paddle_type): + test.assertEqual(wp.dtype_to_paddle(warp_type), paddle_type) + + test_conversions(wp.float16, paddle.float16) + test_conversions(wp.float32, paddle.float32) + test_conversions(wp.float64, paddle.float64) + test_conversions(wp.int8, paddle.int8) + test_conversions(wp.int16, paddle.int16) + test_conversions(wp.int32, paddle.int32) + test_conversions(wp.int64, paddle.int64) + test_conversions(wp.uint8, paddle.uint8) + test_conversions(wp.uint16, paddle.int16) + test_conversions(wp.uint32, paddle.int32) + test_conversions(wp.uint64, paddle.int64) + test_conversions(wp.bool, paddle.bool) + + +def test_device_conversion(test, device): + paddle_device = wp.device_to_paddle(device) + warp_device = wp.device_from_paddle(paddle_device) + test.assertEqual(warp_device, device) + + +def test_paddle_zerocopy(test, device): + import paddle + + a = wp.zeros(10, dtype=wp.float32, device=device) + t = wp.to_paddle(a) + assert a.ptr == t.data_ptr() + + paddle_device = wp.device_to_paddle(device) + + t = paddle.zeros([10], dtype=paddle.float32).to(device=paddle_device) + a = wp.from_paddle(t) + assert a.ptr == t.data_ptr() + + +def test_from_paddle(test, device): + import paddle + + paddle_device = wp.device_to_paddle(device) + + # automatically determine warp dtype + def wrap_scalar_tensor_implicit(paddle_dtype, expected_warp_dtype): + t = paddle.zeros([10], dtype=paddle_dtype).to(device=paddle_device) + a = wp.from_paddle(t) + assert a.dtype == expected_warp_dtype + assert a.shape == tuple(t.shape) + + wrap_scalar_tensor_implicit(paddle.float64, wp.float64) + wrap_scalar_tensor_implicit(paddle.float32, wp.float32) + wrap_scalar_tensor_implicit(paddle.float16, wp.float16) + wrap_scalar_tensor_implicit(paddle.int64, wp.int64) + wrap_scalar_tensor_implicit(paddle.int32, wp.int32) + wrap_scalar_tensor_implicit(paddle.int16, wp.int16) + wrap_scalar_tensor_implicit(paddle.int8, wp.int8) + wrap_scalar_tensor_implicit(paddle.uint8, wp.uint8) + wrap_scalar_tensor_implicit(paddle.bool, wp.bool) + + # explicitly specify warp dtype + def wrap_scalar_tensor_explicit(paddle_dtype, expected_warp_dtype): + t = paddle.zeros([10], dtype=paddle_dtype).to(device=paddle_device) + a = wp.from_paddle(t, expected_warp_dtype) + assert a.dtype == expected_warp_dtype + assert a.shape == tuple(t.shape) + + wrap_scalar_tensor_explicit(paddle.float64, wp.float64) + wrap_scalar_tensor_explicit(paddle.float32, wp.float32) + wrap_scalar_tensor_explicit(paddle.float16, wp.float16) + wrap_scalar_tensor_explicit(paddle.int64, wp.int64) + wrap_scalar_tensor_explicit(paddle.int64, wp.uint64) + wrap_scalar_tensor_explicit(paddle.int32, wp.int32) + wrap_scalar_tensor_explicit(paddle.int32, wp.uint32) + wrap_scalar_tensor_explicit(paddle.int16, wp.int16) + wrap_scalar_tensor_explicit(paddle.int16, wp.uint16) + wrap_scalar_tensor_explicit(paddle.int8, wp.int8) + wrap_scalar_tensor_explicit(paddle.int8, wp.uint8) + wrap_scalar_tensor_explicit(paddle.uint8, wp.uint8) + wrap_scalar_tensor_explicit(paddle.uint8, wp.int8) + wrap_scalar_tensor_explicit(paddle.bool, wp.uint8) + wrap_scalar_tensor_explicit(paddle.bool, wp.int8) + wrap_scalar_tensor_explicit(paddle.bool, wp.bool) + + def wrap_vec_tensor(n, desired_warp_dtype): + t = paddle.zeros((10, n), dtype=paddle.float32).to(device=paddle_device) + a = wp.from_paddle(t, desired_warp_dtype) + assert a.dtype == desired_warp_dtype + assert a.shape == (10,) + + wrap_vec_tensor(2, wp.vec2) + wrap_vec_tensor(3, wp.vec3) + wrap_vec_tensor(4, wp.vec4) + wrap_vec_tensor(6, wp.spatial_vector) + wrap_vec_tensor(7, wp.transform) + + def wrap_mat_tensor(n, m, desired_warp_dtype): + t = paddle.zeros((10, n, m), dtype=paddle.float32).to(device=paddle_device) + a = wp.from_paddle(t, desired_warp_dtype) + assert a.dtype == desired_warp_dtype + assert a.shape == (10,) + + wrap_mat_tensor(2, 2, wp.mat22) + wrap_mat_tensor(3, 3, wp.mat33) + wrap_mat_tensor(4, 4, wp.mat44) + wrap_mat_tensor(6, 6, wp.spatial_matrix) + + def wrap_vec_tensor_with_grad(n, desired_warp_dtype): + t = paddle.zeros((10, n), dtype=paddle.float32).to(device=paddle_device) + a = wp.from_paddle(t, desired_warp_dtype) + a.reuqires_grad = True + assert a.dtype == desired_warp_dtype + assert a.shape == (10,) + + wrap_vec_tensor_with_grad(2, wp.vec2) + wrap_vec_tensor_with_grad(3, wp.vec3) + wrap_vec_tensor_with_grad(4, wp.vec4) + wrap_vec_tensor_with_grad(6, wp.spatial_vector) + wrap_vec_tensor_with_grad(7, wp.transform) + + def wrap_mat_tensor_with_grad(n, m, desired_warp_dtype): + t = paddle.zeros((10, n, m), dtype=paddle.float32).to(device=paddle_device) + a = wp.from_paddle(t, desired_warp_dtype, requires_grad=True) + assert a.dtype == desired_warp_dtype + assert a.shape == (10,) + + wrap_mat_tensor_with_grad(2, 2, wp.mat22) + wrap_mat_tensor_with_grad(3, 3, wp.mat33) + wrap_mat_tensor_with_grad(4, 4, wp.mat44) + wrap_mat_tensor_with_grad(6, 6, wp.spatial_matrix) + + +def test_array_ctype_from_paddle(test, device): + import paddle + + paddle_device = wp.device_to_paddle(device) + + # automatically determine warp dtype + def wrap_scalar_tensor_implicit(paddle_dtype): + t = paddle.zeros([10], dtype=paddle_dtype).to(device=paddle_device) + a = wp.from_paddle(t, return_ctype=True) + warp_dtype = wp.dtype_from_paddle(paddle_dtype) + ctype_size = ctypes.sizeof(warp_dtype._type_) + assert a.data == t.data_ptr() + assert a.grad == 0 + assert a.ndim == 1 + assert a.shape[0] == t.shape[0] + assert a.strides[0] == t.strides[0] * ctype_size + + wrap_scalar_tensor_implicit(paddle.float64) + wrap_scalar_tensor_implicit(paddle.float32) + wrap_scalar_tensor_implicit(paddle.float16) + wrap_scalar_tensor_implicit(paddle.int64) + wrap_scalar_tensor_implicit(paddle.int32) + wrap_scalar_tensor_implicit(paddle.int16) + wrap_scalar_tensor_implicit(paddle.int8) + wrap_scalar_tensor_implicit(paddle.uint8) + wrap_scalar_tensor_implicit(paddle.bool) + + # explicitly specify warp dtype + def wrap_scalar_tensor_explicit(paddle_dtype, warp_dtype): + t = paddle.zeros([10], dtype=paddle_dtype).to(device=paddle_device) + a = wp.from_paddle(t, dtype=warp_dtype, return_ctype=True) + ctype_size = ctypes.sizeof(warp_dtype._type_) + assert a.data == t.data_ptr() + assert a.grad == 0 + assert a.ndim == 1 + assert a.shape[0] == t.shape[0] + assert a.strides[0] == t.strides[0] * ctype_size + + wrap_scalar_tensor_explicit(paddle.float64, wp.float64) + wrap_scalar_tensor_explicit(paddle.float32, wp.float32) + wrap_scalar_tensor_explicit(paddle.float16, wp.float16) + wrap_scalar_tensor_explicit(paddle.int64, wp.int64) + wrap_scalar_tensor_explicit(paddle.int64, wp.uint64) + wrap_scalar_tensor_explicit(paddle.int32, wp.int32) + wrap_scalar_tensor_explicit(paddle.int32, wp.uint32) + wrap_scalar_tensor_explicit(paddle.int16, wp.int16) + wrap_scalar_tensor_explicit(paddle.int16, wp.uint16) + wrap_scalar_tensor_explicit(paddle.int8, wp.int8) + wrap_scalar_tensor_explicit(paddle.int8, wp.uint8) + wrap_scalar_tensor_explicit(paddle.uint8, wp.uint8) + wrap_scalar_tensor_explicit(paddle.uint8, wp.int8) + wrap_scalar_tensor_explicit(paddle.bool, wp.uint8) + wrap_scalar_tensor_explicit(paddle.bool, wp.int8) + wrap_scalar_tensor_explicit(paddle.bool, wp.bool) + + def wrap_vec_tensor(vec_dtype): + t = paddle.zeros((10, vec_dtype._length_), dtype=paddle.float32).to(device=paddle_device) + a = wp.from_paddle(t, dtype=vec_dtype, return_ctype=True) + ctype_size = ctypes.sizeof(vec_dtype._type_) + assert a.data == t.data_ptr() + assert a.grad == 0 + assert a.ndim == 1 + assert a.shape[0] == t.shape[0] + assert a.strides[0] == t.strides[0] * ctype_size + + wrap_vec_tensor(wp.vec2) + wrap_vec_tensor(wp.vec3) + wrap_vec_tensor(wp.vec4) + wrap_vec_tensor(wp.spatial_vector) + wrap_vec_tensor(wp.transform) + + def wrap_mat_tensor(mat_dtype): + t = paddle.zeros((10, *mat_dtype._shape_), dtype=paddle.float32).to(device=paddle_device) + a = wp.from_paddle(t, dtype=mat_dtype, return_ctype=True) + ctype_size = ctypes.sizeof(mat_dtype._type_) + assert a.data == t.data_ptr() + assert a.grad == 0 + assert a.ndim == 1 + assert a.shape[0] == t.shape[0] + assert a.strides[0] == t.strides[0] * ctype_size + + wrap_mat_tensor(wp.mat22) + wrap_mat_tensor(wp.mat33) + wrap_mat_tensor(wp.mat44) + wrap_mat_tensor(wp.spatial_matrix) + + def wrap_vec_tensor_with_existing_grad(vec_dtype): + t = paddle.zeros((10, vec_dtype._length_), dtype=paddle.float32).to(device=paddle_device) + t.stop_gradient = False + t.grad_ = paddle.zeros((10, vec_dtype._length_), dtype=paddle.float32).to(device=paddle_device) + a = wp.from_paddle(t, dtype=vec_dtype, return_ctype=True) + ctype_size = ctypes.sizeof(vec_dtype._type_) + assert a.data == t.data_ptr() + assert a.grad == t.grad.data_ptr() + assert a.ndim == 1 + assert a.shape[0] == t.shape[0] + assert a.strides[0] == t.strides[0] * ctype_size + + wrap_vec_tensor_with_existing_grad(wp.vec2) + wrap_vec_tensor_with_existing_grad(wp.vec3) + wrap_vec_tensor_with_existing_grad(wp.vec4) + wrap_vec_tensor_with_existing_grad(wp.spatial_vector) + wrap_vec_tensor_with_existing_grad(wp.transform) + + def wrap_vec_tensor_with_new_grad(vec_dtype): + t = paddle.zeros((10, vec_dtype._length_), dtype=paddle.float32).to(device=paddle_device) + a = wp.from_paddle(t, dtype=vec_dtype, requires_grad=True, return_ctype=True) + ctype_size = ctypes.sizeof(vec_dtype._type_) + assert a.data == t.data_ptr() + assert a.grad == t.grad.data_ptr() + assert a.ndim == 1 + assert a.shape[0] == t.shape[0] + assert a.strides[0] == t.strides[0] * ctype_size + + wrap_vec_tensor_with_new_grad(wp.vec2) + wrap_vec_tensor_with_new_grad(wp.vec3) + wrap_vec_tensor_with_new_grad(wp.vec4) + wrap_vec_tensor_with_new_grad(wp.spatial_vector) + wrap_vec_tensor_with_new_grad(wp.transform) + + def wrap_vec_tensor_with_paddle_grad(vec_dtype): + t = paddle.zeros((10, vec_dtype._length_), dtype=paddle.float32).to(device=paddle_device) + grad = paddle.zeros((10, vec_dtype._length_), dtype=paddle.float32).to(device=paddle_device) + a = wp.from_paddle(t, dtype=vec_dtype, grad=grad, return_ctype=True) + ctype_size = ctypes.sizeof(vec_dtype._type_) + assert a.data == t.data_ptr() + assert a.grad == grad.data_ptr() + assert a.ndim == 1 + assert a.shape[0] == t.shape[0] + assert a.strides[0] == t.strides[0] * ctype_size + + wrap_vec_tensor_with_paddle_grad(wp.vec2) + wrap_vec_tensor_with_paddle_grad(wp.vec3) + wrap_vec_tensor_with_paddle_grad(wp.vec4) + wrap_vec_tensor_with_paddle_grad(wp.spatial_vector) + wrap_vec_tensor_with_paddle_grad(wp.transform) + + def wrap_vec_tensor_with_warp_grad(vec_dtype): + t = paddle.zeros((10, vec_dtype._length_), dtype=paddle.float32).to(device=paddle_device) + grad = wp.zeros(10, dtype=vec_dtype, device=device) + a = wp.from_paddle(t, dtype=vec_dtype, grad=grad, return_ctype=True) + ctype_size = ctypes.sizeof(vec_dtype._type_) + assert a.data == t.data_ptr() + assert a.grad == grad.ptr + assert a.ndim == 1 + assert a.shape[0] == t.shape[0] + assert a.strides[0] == t.strides[0] * ctype_size + + wrap_vec_tensor_with_warp_grad(wp.vec2) + wrap_vec_tensor_with_warp_grad(wp.vec3) + wrap_vec_tensor_with_warp_grad(wp.vec4) + wrap_vec_tensor_with_warp_grad(wp.spatial_vector) + wrap_vec_tensor_with_warp_grad(wp.transform) + + +def test_to_paddle(test, device): + import paddle + + def wrap_scalar_array(warp_dtype, expected_paddle_dtype): + a = wp.zeros(10, dtype=warp_dtype, device=device) + t = wp.to_paddle(a) + assert t.dtype == expected_paddle_dtype + assert tuple(t.shape) == a.shape + + wrap_scalar_array(wp.float64, paddle.float64) + wrap_scalar_array(wp.float32, paddle.float32) + wrap_scalar_array(wp.float16, paddle.float16) + wrap_scalar_array(wp.int64, paddle.int64) + wrap_scalar_array(wp.int32, paddle.int32) + wrap_scalar_array(wp.int16, paddle.int16) + wrap_scalar_array(wp.int8, paddle.int8) + wrap_scalar_array(wp.uint8, paddle.uint8) + wrap_scalar_array(wp.bool, paddle.bool) + + # not supported by paddle + # wrap_scalar_array(wp.uint64, paddle.int64) + # wrap_scalar_array(wp.uint32, paddle.int32) + # wrap_scalar_array(wp.uint16, paddle.int16) + + def wrap_vec_array(n, warp_dtype): + a = wp.zeros(10, dtype=warp_dtype, device=device) + t = wp.to_paddle(a) + assert t.dtype == paddle.float32 + assert tuple(t.shape) == (10, n) + + wrap_vec_array(2, wp.vec2) + wrap_vec_array(3, wp.vec3) + wrap_vec_array(4, wp.vec4) + wrap_vec_array(6, wp.spatial_vector) + wrap_vec_array(7, wp.transform) + + def wrap_mat_array(n, m, warp_dtype): + a = wp.zeros(10, dtype=warp_dtype, device=device) + t = wp.to_paddle(a) + assert t.dtype == paddle.float32 + assert tuple(t.shape) == (10, n, m) + + wrap_mat_array(2, 2, wp.mat22) + wrap_mat_array(3, 3, wp.mat33) + wrap_mat_array(4, 4, wp.mat44) + wrap_mat_array(6, 6, wp.spatial_matrix) + + +def test_from_paddle_slices(test, device): + import paddle + + paddle_device = wp.device_to_paddle(device) + + # 1D slice, contiguous + t_base = paddle.arange(10, dtype=paddle.float32).to(device=paddle_device) + t = t_base[2:9] + a = wp.from_paddle(t) + assert a.ptr == t.data_ptr() + assert a.is_contiguous + assert a.shape == tuple(t.shape) + assert_np_equal(a.numpy(), t.cpu().numpy()) + + # 1D slice with non-contiguous stride + t_base = paddle.arange(10, dtype=paddle.float32).to(device=paddle_device) + t = t_base[2:9:2] + a = wp.from_paddle(t) + assert a.ptr == t.data_ptr() + assert not a.is_contiguous + assert a.shape == tuple(t.shape) + # copy contents to contiguous array + a_contiguous = wp.empty_like(a) + wp.launch(copy1d_float_kernel, dim=a.shape, inputs=[a_contiguous, a], device=device) + assert_np_equal(a_contiguous.numpy(), t.cpu().numpy()) + + # 2D slices (non-contiguous) + t_base = paddle.arange(24, dtype=paddle.float32).to(device=paddle_device).reshape((4, 6)) + t = t_base[1:3, 2:5] + a = wp.from_paddle(t) + assert a.ptr == t.data_ptr() + assert not a.is_contiguous + assert a.shape == tuple(t.shape) + # copy contents to contiguous array + a_contiguous = wp.empty_like(a) + wp.launch(copy2d_float_kernel, dim=a.shape, inputs=[a_contiguous, a], device=device) + assert_np_equal(a_contiguous.numpy(), t.cpu().numpy()) + + # 3D slices (non-contiguous) + t_base = paddle.arange(36, dtype=paddle.float32).to(device=paddle_device).reshape((4, 3, 3)) + t = t_base[::2, 0:1, 1:2] + a = wp.from_paddle(t) + assert a.ptr == t.data_ptr() + assert not a.is_contiguous + assert a.shape == tuple(t.shape) + # copy contents to contiguous array + a_contiguous = wp.empty_like(a) + wp.launch(copy3d_float_kernel, dim=a.shape, inputs=[a_contiguous, a], device=device) + assert_np_equal(a_contiguous.numpy(), t.cpu().numpy()) + + # 2D slices of vec3 (inner contiguous, outer non-contiguous) + t_base = paddle.arange(150, dtype=paddle.float32).to(device=paddle_device).reshape((10, 5, 3)) + t = t_base[1:7:2, 2:5] + a = wp.from_paddle(t, dtype=wp.vec3) + assert a.ptr == t.data_ptr() + assert not a.is_contiguous + assert a.shape == tuple(t.shape[:-1]) + # copy contents to contiguous array + a_contiguous = wp.empty_like(a) + wp.launch(copy2d_vec3_kernel, dim=a.shape, inputs=[a_contiguous, a], device=device) + assert_np_equal(a_contiguous.numpy(), t.cpu().numpy()) + + # 2D slices of mat22 (inner contiguous, outer non-contiguous) + t_base = paddle.arange(200, dtype=paddle.float32).to(device=paddle_device).reshape((10, 5, 2, 2)) + t = t_base[1:7:2, 2:5] + a = wp.from_paddle(t, dtype=wp.mat22) + assert a.ptr == t.data_ptr() + assert not a.is_contiguous + assert a.shape == tuple(t.shape[:-2]) + # copy contents to contiguous array + a_contiguous = wp.empty_like(a) + wp.launch(copy2d_mat22_kernel, dim=a.shape, inputs=[a_contiguous, a], device=device) + assert_np_equal(a_contiguous.numpy(), t.cpu().numpy()) + + +def test_from_paddle_zero_strides(test, device): + import paddle + + paddle_device = wp.device_to_paddle(device) + + t_base = paddle.arange(9, dtype=paddle.float32).to(device=paddle_device).reshape((3, 3)) + + # expand outermost dimension + t = t_base.unsqueeze(0).expand([3, -1, -1]) + a = wp.from_paddle(t) + assert a.ptr == t.data_ptr() + assert a.is_contiguous + assert a.shape == tuple(t.shape) + a_contiguous = wp.empty_like(a) + wp.launch(copy3d_float_kernel, dim=a.shape, inputs=[a_contiguous, a], device=device) + assert_np_equal(a_contiguous.numpy(), t.cpu().numpy()) + + # expand middle dimension + t = t_base.unsqueeze(1).expand([-1, 3, -1]) + a = wp.from_paddle(t) + assert a.ptr == t.data_ptr() + assert a.is_contiguous + assert a.shape == tuple(t.shape) + a_contiguous = wp.empty_like(a) + wp.launch(copy3d_float_kernel, dim=a.shape, inputs=[a_contiguous, a], device=device) + assert_np_equal(a_contiguous.numpy(), t.cpu().numpy()) + + # expand innermost dimension + t = t_base.unsqueeze(2).expand([-1, -1, 3]) + a = wp.from_paddle(t) + assert a.ptr == t.data_ptr() + assert a.is_contiguous + assert a.shape == tuple(t.shape) + a_contiguous = wp.empty_like(a) + wp.launch(copy3d_float_kernel, dim=a.shape, inputs=[a_contiguous, a], device=device) + assert_np_equal(a_contiguous.numpy(), t.cpu().numpy()) + + +def test_paddle_mgpu_from_paddle(test, device): + import paddle + + n = 32 + + t0 = paddle.arange(0, n, 1, dtype=paddle.int32).to(device="gpu:0") + t1 = paddle.arange(0, n * 2, 2, dtype=paddle.int32).to(device="gpu:1") + + a0 = wp.from_paddle(t0, dtype=wp.int32) + a1 = wp.from_paddle(t1, dtype=wp.int32) + + assert a0.device == "gpu:0" + assert a1.device == "gpu:1" + + expected0 = np.arange(0, n, 1) + expected1 = np.arange(0, n * 2, 2) + + assert_np_equal(a0.numpy(), expected0) + assert_np_equal(a1.numpy(), expected1) + + +def test_paddle_mgpu_to_paddle(test, device): + n = 32 + + with wp.ScopedDevice("gpu:0"): + a0 = wp.empty(n, dtype=wp.int32) + wp.launch(arange, dim=a0.size, inputs=[0, 1, a0]) + + with wp.ScopedDevice("gpu:1"): + a1 = wp.empty(n, dtype=wp.int32) + wp.launch(arange, dim=a1.size, inputs=[0, 2, a1]) + + t0 = wp.to_paddle(a0) + t1 = wp.to_paddle(a1) + + assert str(t0.device) == "gpu:0" + assert str(t1.device) == "gpu:1" + + expected0 = np.arange(0, n, 1, dtype=np.int32) + expected1 = np.arange(0, n * 2, 2, dtype=np.int32) + + assert_np_equal(t0.cpu().numpy(), expected0) + assert_np_equal(t1.cpu().numpy(), expected1) + + +def test_paddle_mgpu_interop(test, device): + import paddle + + n = 1024 * 1024 + + with paddle.cuda.device(0): + t0 = paddle.arange(n, dtype=paddle.float32).to(device="gpu") + a0 = wp.from_paddle(t0) + wp.launch(inc, dim=a0.size, inputs=[a0], stream=wp.stream_from_paddle()) + + with paddle.cuda.device(1): + t1 = paddle.arange(n, dtype=paddle.float32).to(device="gpu") + a1 = wp.from_paddle(t1) + wp.launch(inc, dim=a1.size, inputs=[a1], stream=wp.stream_from_paddle()) + + assert a0.device == "gpu:0" + assert a1.device == "gpu:1" + + expected = np.arange(n, dtype=int) + 1 + + # ensure the paddle tensors were modified by warp + assert_np_equal(t0.cpu().numpy(), expected) + assert_np_equal(t1.cpu().numpy(), expected) + + +def test_paddle_autograd(test, device): + """Test paddle autograd with a custom Warp op""" + + import paddle + + # custom autograd op + class TestFunc(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, x): + # allocate output array + y = paddle.empty_like(x) + + ctx.x = x + ctx.y = y + + wp.launch(kernel=op_kernel, dim=len(x), inputs=[wp.from_paddle(x)], outputs=[wp.from_paddle(y)]) + + return y + + @staticmethod + def backward(ctx, adj_y): + # adjoints should be allocated as zero initialized + adj_x = paddle.zeros_like(ctx.x).contiguous() + adj_y = adj_y.contiguous() + + wp_x = wp.from_paddle(ctx.x, grad=adj_x) + wp_y = wp.from_paddle(ctx.y, grad=adj_y) + + wp.launch( + kernel=op_kernel, + dim=len(ctx.x), + # fwd inputs + inputs=[wp_x], + outputs=[wp_y], + # adj inputs (already stored in input/output arrays, passing null pointers) + adj_inputs=[None], + adj_outputs=[None], + adjoint=True, + ) + + return adj_x + + # run autograd on given device + with wp.ScopedDevice(device): + paddle_device = wp.device_to_paddle(device) + + # input data + x = paddle.ones(16, dtype=paddle.float32).to(device=paddle_device) + x.stop_gradient = False + + # execute op + y = TestFunc.apply(x) + + # compute grads + l = y.sum() + l.backward() + + passed = (x.grad == -2.0).all() + assert passed.item() + + +def test_warp_graph_warp_stream(test, device): + """Capture Warp graph on Warp stream""" + + import paddle + + paddle_device = wp.device_to_paddle(device) + + n = 1024 * 1024 + t = paddle.zeros(n, dtype=paddle.float32).to(device=paddle_device) + a = wp.from_paddle(t) + + # make paddle use the warp stream from the given device + paddle_stream = wp.stream_to_paddle(device) + + # capture graph + with wp.ScopedDevice(device), paddle.device.stream(paddle_stream): + wp.capture_begin(force_module_load=False) + try: + t += 1.0 + wp.launch(inc, dim=n, inputs=[a]) + t += 1.0 + wp.launch(inc, dim=n, inputs=[a]) + finally: + g = wp.capture_end() + + # replay graph + num_iters = 10 + for _i in range(num_iters): + wp.capture_launch(g) + + passed = (t == num_iters * 4.0).all() + assert passed.item() + + +def test_warp_graph_paddle_stream(test, device): + """Capture Warp graph on Paddle stream""" + + wp.load_module(device=device) + + import paddle + + paddle_device = wp.device_to_paddle(device) + + n = 1024 * 1024 + t = paddle.zeros(n, dtype=paddle.float32).to(device=paddle_device) + a = wp.from_paddle(t) + + # create a device-specific paddle stream to use for capture + # (the default paddle stream is not suitable for graph capture) + paddle_stream = paddle.device.Stream(device=paddle_device) + + # make warp use the same stream + warp_stream = wp.stream_from_paddle(paddle_stream) + + # capture graph + with wp.ScopedStream(warp_stream): + wp.capture_begin(force_module_load=False) + try: + t += 1.0 + wp.launch(inc, dim=n, inputs=[a]) + t += 1.0 + wp.launch(inc, dim=n, inputs=[a]) + finally: + g = wp.capture_end() + + # replay graph + num_iters = 10 + for _i in range(num_iters): + wp.capture_launch(g) + + passed = (t == num_iters * 4.0).all() + assert passed.item() + + +def test_direct(test, device): + """Pass Paddle tensors to Warp kernels directly""" + + import paddle + + paddle_device = wp.device_to_paddle(device) + n = 12 + + s = paddle.arange(n, dtype=paddle.float32).to(device=paddle_device) + v = paddle.arange(n, dtype=paddle.float32).to(device=paddle_device).reshape((n // 3, 3)) + m = paddle.arange(n, dtype=paddle.float32).to(device=paddle_device).reshape((n // 4, 2, 2)) + + wp.launch(inc, dim=n, inputs=[s], device=device) + wp.launch(inc_vector, dim=n // 3, inputs=[v], device=device) + wp.launch(inc_matrix, dim=n // 4, inputs=[m], device=device) + + expected = paddle.arange(1, n + 1, dtype=paddle.float32).to(device=paddle_device) + + assert paddle.equal_all(s, expected).item() + assert paddle.equal_all(v.reshape([n]), expected).item() + assert paddle.equal_all(m.reshape([n]), expected).item() + + +class TestPaddle(unittest.TestCase): + pass + + +test_devices = get_test_devices() + +try: + import paddle + + # check which Warp devices work with Paddle + # CUDA devices may fail if Paddle was not compiled with CUDA support + paddle_compatible_devices = [] + paddle_compatible_cuda_devices = [] + + for d in test_devices: + try: + t = paddle.arange(10).to(device=wp.device_to_paddle(d)) + t += 1 + paddle_compatible_devices.append(d) + if d.is_cuda: + paddle_compatible_cuda_devices.append(d) + except Exception as e: + print(f"Skipping Paddle tests on device '{d}' due to exception: {e}") + + add_function_test(TestPaddle, "test_dtype_from_paddle", test_dtype_from_paddle, devices=None) + add_function_test(TestPaddle, "test_dtype_to_paddle", test_dtype_to_paddle, devices=None) + + if paddle_compatible_devices: + add_function_test( + TestPaddle, "test_device_conversion", test_device_conversion, devices=paddle_compatible_devices + ) + add_function_test(TestPaddle, "test_from_paddle", test_from_paddle, devices=paddle_compatible_devices) + add_function_test( + TestPaddle, "test_from_paddle_slices", test_from_paddle_slices, devices=paddle_compatible_devices + ) + add_function_test( + TestPaddle, "test_array_ctype_from_paddle", test_array_ctype_from_paddle, devices=paddle_compatible_devices + ) + add_function_test( + TestPaddle, + "test_from_paddle_zero_strides", + test_from_paddle_zero_strides, + devices=paddle_compatible_devices, + ) + add_function_test(TestPaddle, "test_to_paddle", test_to_paddle, devices=paddle_compatible_devices) + add_function_test(TestPaddle, "test_paddle_zerocopy", test_paddle_zerocopy, devices=paddle_compatible_devices) + add_function_test(TestPaddle, "test_paddle_autograd", test_paddle_autograd, devices=paddle_compatible_devices) + add_function_test(TestPaddle, "test_direct", test_direct, devices=paddle_compatible_devices) + + # NOTE: Graph not supported now + # if paddle_compatible_cuda_devices: + # add_function_test( + # TestPaddle, + # "test_warp_graph_warp_stream", + # test_warp_graph_warp_stream, + # devices=paddle_compatible_cuda_devices, + # ) + # add_function_test( + # TestPaddle, + # "test_warp_graph_paddle_stream", + # test_warp_graph_paddle_stream, + # devices=paddle_compatible_cuda_devices, + # ) + + # multi-GPU tests + if len(paddle_compatible_cuda_devices) > 1: + add_function_test(TestPaddle, "test_paddle_mgpu_from_paddle", test_paddle_mgpu_from_paddle) + add_function_test(TestPaddle, "test_paddle_mgpu_to_paddle", test_paddle_mgpu_to_paddle) + add_function_test(TestPaddle, "test_paddle_mgpu_interop", test_paddle_mgpu_interop) + +except Exception as e: + print(f"Skipping Paddle tests due to exception: {e}") + + +if __name__ == "__main__": + wp.clear_kernel_cache() + unittest.main(verbosity=2) diff --git a/warp/thirdparty/dlpack.py b/warp/thirdparty/dlpack.py index 0634474b..399e0002 100644 --- a/warp/thirdparty/dlpack.py +++ b/warp/thirdparty/dlpack.py @@ -58,6 +58,7 @@ class DLDataTypeCode(ctypes.c_uint8): kDLOpaquePointer = 3 kDLBfloat = 4 kDLComplex = 5 + kDLBool = 6 def __str__(self): return { @@ -66,6 +67,7 @@ def __str__(self): self.kDLFloat: "float", self.kDLBfloat: "bfloat", self.kDLComplex: "complex", + self.kDLBool: "bool", self.kDLOpaquePointer: "void_p", }[self.value] @@ -85,7 +87,7 @@ class DLDataType(ctypes.Structure): ("lanes", ctypes.c_uint16), ] TYPE_MAP = { - "bool": (DLDataTypeCode.kDLUInt, 1, 1), + "bool": (DLDataTypeCode.kDLBool, 8, 1), "int8": (DLDataTypeCode.kDLInt, 8, 1), "int16": (DLDataTypeCode.kDLInt, 16, 1), "int32": (DLDataTypeCode.kDLInt, 32, 1),