Commit

Merge branch 'GH-318' into 'main'

Add paddle backend to warp

Closes GH-318

See merge request omniverse/warp!762
shi-eric committed Sep 30, 2024
2 parents 7b96a06 + ab147a0 commit a7ecac0
Showing 15 changed files with 1,566 additions and 3 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -13,7 +13,7 @@ regular Python functions and JIT compiles them to efficient kernel code that can
Warp is designed for [spatial computing](https://en.wikipedia.org/wiki/Spatial_computing)
and comes with a rich set of primitives that make it easy to write
programs for physics simulation, perception, robotics, and geometry processing. In addition, Warp kernels
are differentiable and can be used as part of machine-learning pipelines with frameworks such as PyTorch and JAX.
are differentiable and can be used as part of machine-learning pipelines with frameworks such as PyTorch, JAX and Paddle.

Please refer to the project [Documentation](https://nvidia.github.io/warp/) for API and language reference and [CHANGELOG.md](./CHANGELOG.md) for release history.

2 changes: 1 addition & 1 deletion docs/index.rst
@@ -7,7 +7,7 @@ regular Python functions and JIT compiles them to efficient kernel code that can
Warp is designed for `spatial computing <https://en.wikipedia.org/wiki/Spatial_computing>`_
and comes with a rich set of primitives that make it easy to write
programs for physics simulation, perception, robotics, and geometry processing. In addition, Warp kernels
are differentiable and can be used as part of machine-learning pipelines with frameworks such as PyTorch and JAX.
are differentiable and can be used as part of machine-learning pipelines with frameworks such as PyTorch, JAX and Paddle.

Below are some examples of simulations implemented using Warp:

1 change: 1 addition & 0 deletions docs/installation.rst
@@ -76,6 +76,7 @@ The following optional dependencies are required to support certain features:
* `usd-core <https://pypi.org/project/usd-core>`_: Required for some Warp examples, ``warp.sim.parse_usd()``, and ``warp.render.UsdRenderer``.
* `JAX <https://jax.readthedocs.io/en/latest/installation.html>`_: Required for JAX interoperability (see :ref:`jax-interop`).
* `PyTorch <https://pytorch.org/get-started/locally/>`_: Required for PyTorch interoperability (see :ref:`pytorch-interop`).
* `Paddle <https://github.com/PaddlePaddle/Paddle>`_: Required for Paddle interoperability (see :ref:`paddle-interop`).
* `NVTX for Python <https://github.com/NVIDIA/NVTX#python>`_: Required to use :class:`wp.ScopedTimer(use_nvtx=True) <warp.ScopedTimer>`.

Building the Warp documentation requires:
181 changes: 181 additions & 0 deletions docs/modules/interoperability.rst
@@ -709,6 +709,7 @@ The canonical way to export a Warp array to an external framework is to use the

jax_array = jax.dlpack.from_dlpack(warp_array)
torch_tensor = torch.utils.dlpack.from_dlpack(warp_array)
paddle_tensor = paddle.utils.dlpack.from_dlpack(warp_array)

For CUDA arrays, this will synchronize the current stream of the consumer framework with the current Warp stream on the array's device.
Thus it should be safe to use the wrapped array in the consumer framework, even if the array was previously used in a Warp kernel
@@ -719,9 +720,11 @@ This approach may be used for older versions of frameworks that do not support t

warp_array1 = wp.from_dlpack(jax.dlpack.to_dlpack(jax_array))
warp_array2 = wp.from_dlpack(torch.utils.dlpack.to_dlpack(torch_tensor))
warp_array3 = wp.from_dlpack(paddle.utils.dlpack.to_dlpack(paddle_tensor))

jax_array = jax.dlpack.from_dlpack(wp.to_dlpack(warp_array))
torch_tensor = torch.utils.dlpack.from_dlpack(wp.to_dlpack(warp_array))
paddle_tensor = paddle.utils.dlpack.from_dlpack(wp.to_dlpack(warp_array))

This approach is generally faster because it skips any stream synchronization, but another solution must be used to ensure correct
ordering of operations. In situations where no synchronization is required, using this approach can yield better performance.
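
For example, correct ordering can be enforced manually by synchronizing the producing device before the consumer reads the array (a minimal sketch; ``my_kernel``, ``n``, and ``warp_array`` are placeholders, and a finer-grained stream synchronization may be preferable in practice)::

    wp.launch(my_kernel, dim=n, inputs=[warp_array], device="cuda:0")

    # block until Warp's work on the array has completed before handing
    # the capsule to the consumer framework
    wp.synchronize_device("cuda:0")

    paddle_tensor = paddle.utils.dlpack.from_dlpack(wp.to_dlpack(warp_array))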
@@ -733,3 +736,181 @@ This may be a good choice in situations like these:

.. autofunction:: warp.from_dlpack
.. autofunction:: warp.to_dlpack

.. _paddle-interop:

Paddle
------

Warp provides helper functions to convert arrays to/from Paddle::

w = wp.array([1.0, 2.0, 3.0], dtype=float, device="cpu")

# convert to Paddle tensor
t = wp.to_paddle(w)

# convert from Paddle tensor
w = wp.from_paddle(t)

These helper functions convert Warp arrays to and from Paddle tensors without copying the underlying data.
If gradients are present, the associated gradient arrays and tensors are converted as well, allowing Warp arrays
to participate in Paddle autograd computations.
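
For example, because no copy is made, writes performed by a Warp kernel are immediately visible through the original Paddle tensor (a minimal sketch)::

    @wp.kernel
    def scale(a: wp.array(dtype=float), s: float):
        tid = wp.tid()
        a[tid] = a[tid] * s

    t = paddle.to_tensor([1.0, 2.0, 3.0], dtype="float32")
    w = wp.from_paddle(t)  # shares memory with t

    wp.launch(scale, dim=len(w), inputs=[w, 2.0], device=w.device)
    wp.synchronize_device(w.device)  # ensure the kernel has finished
    print(t.numpy())  # [2. 4. 6.]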

.. autofunction:: warp.from_paddle
.. autofunction:: warp.to_paddle
.. autofunction:: warp.device_from_paddle
.. autofunction:: warp.device_to_paddle
.. autofunction:: warp.dtype_from_paddle
.. autofunction:: warp.dtype_to_paddle
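
For example, the device and dtype helpers make it easy to allocate a Paddle tensor that matches an existing Warp array (a minimal sketch assuming a CUDA-capable device; the values shown in the comments are illustrative)::

    w = wp.zeros(10, dtype=wp.float32, device="cuda:0")

    # map the Warp device and dtype to their Paddle equivalents
    paddle_device = wp.device_to_paddle(w.device)  # e.g. "gpu:0"
    paddle_dtype = wp.dtype_to_paddle(w.dtype)     # e.g. paddle.float32

    t = paddle.zeros([10], dtype=paddle_dtype).to(device=paddle_device)

    # the reverse mappings recover the Warp device and dtype
    assert wp.device_from_paddle(paddle_device) == w.device
    assert wp.dtype_from_paddle(paddle_dtype) == w.dtype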

To convert a Paddle CUDA stream to a Warp CUDA stream, Warp provides the following function:

.. autofunction:: warp.stream_from_paddle
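
For example, a Warp stream that wraps Paddle's current CUDA stream can be used to order Warp kernel launches after pending Paddle work (a minimal sketch; ``my_kernel`` and its inputs are placeholders, and ``paddle.device.cuda.current_stream()`` is assumed to return Paddle's current stream, so consult the Paddle documentation for the exact call)::

    # query Paddle's current CUDA stream (assumed API)
    paddle_stream = paddle.device.cuda.current_stream()

    # wrap it as a Warp stream
    warp_stream = wp.stream_from_paddle(paddle_stream)

    # kernels launched in this scope run on the shared stream,
    # after any Paddle work already queued on it
    with wp.ScopedStream(warp_stream):
        wp.launch(my_kernel, dim=n, inputs=[my_array])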

Example: Optimization using ``warp.from_paddle()``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The following example minimizes a loss function over an array of 2D points. The loss is computed by a Warp kernel,
while Paddle's Adam optimizer drives the updates; :func:`warp.from_paddle` is used to share the data between the two frameworks::

import warp as wp
import paddle

# init warp context at beginning
wp.context.init()

@wp.kernel()
def loss(xs: wp.array(dtype=float, ndim=2), l: wp.array(dtype=float)):
tid = wp.tid()
wp.atomic_add(l, 0, xs[tid, 0] ** 2.0 + xs[tid, 1] ** 2.0)

    # indicate that gradients are required (stop_gradient=False in Paddle) so that Warp can accumulate gradients in the grad buffers
xs = paddle.randn([100, 2])
xs.stop_gradient = False
l = paddle.zeros([1])
l.stop_gradient = False
opt = paddle.optimizer.Adam(learning_rate=0.1, parameters=[xs])

wp_xs = wp.from_paddle(xs)
wp_l = wp.from_paddle(l)

tape = wp.Tape()
with tape:
# record the loss function kernel launch on the tape
wp.launch(loss, dim=len(xs), inputs=[wp_xs], outputs=[wp_l], device=wp_xs.device)

for i in range(500):
tape.zero()
tape.backward(loss=wp_l) # compute gradients
# now xs.grad will be populated with the gradients computed by Warp
opt.step() # update xs (and thereby wp_xs)

# these lines are only needed for evaluating the loss
# (the optimization just needs the gradient, not the loss value)
wp_l.zero_()
wp.launch(loss, dim=len(xs), inputs=[wp_xs], outputs=[wp_l], device=wp_xs.device)
print(f"{i}\tloss: {l.item()}")

Example: Optimization using ``warp.to_paddle``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Less code is needed when we declare the optimization variables directly in Warp and use :func:`warp.to_paddle` to convert them to Paddle tensors.
Here, we revisit the example from above; now only a single conversion to a Paddle tensor is needed to supply Adam with the optimization variables::

import warp as wp
import numpy as np
import paddle

# init warp context at beginning
wp.context.init()

@wp.kernel()
def loss(xs: wp.array(dtype=float, ndim=2), l: wp.array(dtype=float)):
tid = wp.tid()
wp.atomic_add(l, 0, xs[tid, 0] ** 2.0 + xs[tid, 1] ** 2.0)

# initialize the optimization variables in Warp
xs = wp.array(np.random.randn(100, 2), dtype=wp.float32, requires_grad=True)
l = wp.zeros(1, dtype=wp.float32, requires_grad=True)
# just a single wp.to_paddle call is needed, Adam optimizes using the Warp array gradients
opt = paddle.optimizer.Adam(learning_rate=0.1, parameters=[wp.to_paddle(xs)])

tape = wp.Tape()
with tape:
wp.launch(loss, dim=len(xs), inputs=[xs], outputs=[l], device=xs.device)

for i in range(500):
tape.zero()
tape.backward(loss=l)
opt.step()

l.zero_()
wp.launch(loss, dim=len(xs), inputs=[xs], outputs=[l], device=xs.device)
print(f"{i}\tloss: {l.numpy()[0]}")

Performance Notes
^^^^^^^^^^^^^^^^^
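
The snippets below assume a simple ``saxpy`` kernel; a minimal definition is sketched here for reference (the kernel body is an assumption, not part of the original documentation):

.. code:: python

    import warp as wp

    @wp.kernel
    def saxpy(x: wp.array(dtype=float), y: wp.array(dtype=float), a: float):
        tid = wp.tid()
        y[tid] = a * x[tid] + y[tid]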

The ``wp.from_paddle()`` function creates a Warp array object that shares data with a Paddle tensor. Although this function does not copy the data, there is always some CPU overhead during the conversion. If these conversions happen frequently, the overall program performance may suffer. As a general rule, it's good to avoid repeated conversions of the same tensor. Instead of:

.. code:: python

    x_t = paddle.arange(n, dtype=paddle.float32).to(device=wp.device_to_paddle(device))
    y_t = paddle.ones([n], dtype=paddle.float32).to(device=wp.device_to_paddle(device))
    for i in range(10):
        x_w = wp.from_paddle(x_t)
        y_w = wp.from_paddle(y_t)
        wp.launch(saxpy, dim=n, inputs=[x_w, y_w, 1.0], device=device)

Try converting the arrays only once and reuse them:

.. code:: python

    x_t = paddle.arange(n, dtype=paddle.float32).to(device=wp.device_to_paddle(device))
    y_t = paddle.ones([n], dtype=paddle.float32).to(device=wp.device_to_paddle(device))
    x_w = wp.from_paddle(x_t)
    y_w = wp.from_paddle(y_t)
    for i in range(10):
        wp.launch(saxpy, dim=n, inputs=[x_w, y_w, 1.0], device=device)

If reusing arrays is not possible (e.g., a new Paddle tensor is constructed on every iteration), passing ``return_ctype=True`` to ``wp.from_paddle()`` should yield faster performance. Setting this argument to ``True`` avoids constructing a ``wp.array`` object and instead returns a low-level array descriptor. This descriptor is a simple C structure that can be passed to Warp kernels instead of a ``wp.array``, but cannot be used in other places that require a ``wp.array``.

.. code:: python

    for n in range(1, 10):
        # get Paddle tensors for this iteration
        x_t = paddle.arange(n, dtype=paddle.float32).to(device=wp.device_to_paddle(device))
        y_t = paddle.ones([n], dtype=paddle.float32).to(device=wp.device_to_paddle(device))
        # get Warp array descriptors
        x_ctype = wp.from_paddle(x_t, return_ctype=True)
        y_ctype = wp.from_paddle(y_t, return_ctype=True)
        wp.launch(saxpy, dim=n, inputs=[x_ctype, y_ctype, 1.0], device=device)

An alternative approach is to pass the Paddle tensors to Warp kernels directly. This avoids constructing temporary Warp arrays by leveraging standard array interfaces (like ``__cuda_array_interface__``) supported by both Paddle and Warp. The main advantage of this approach is convenience, since there is no need to call any conversion functions. The main limitation is that it does not handle gradients, because gradient information is not included in the standard array interfaces. This technique is therefore most suitable for algorithms that do not involve differentiation.

.. code:: python

    x = paddle.arange(n, dtype=paddle.float32).to(device=wp.device_to_paddle(device))
    y = paddle.ones([n], dtype=paddle.float32).to(device=wp.device_to_paddle(device))
    for i in range(10):
        wp.launch(saxpy, dim=n, inputs=[x, y, 1.0], device=device)

These approaches can be compared using the interop benchmark script:

.. code:: shell

    python -m warp.examples.benchmarks.benchmark_interop_paddle

Sample output:

.. code::

    13990 ms  from_paddle(...)
     5990 ms  from_paddle(..., return_ctype=True)
    35167 ms  direct from paddle

Passing ``return_ctype=True`` is the fastest, because it skips creating temporary Warp array objects. The default ``wp.from_paddle()`` conversion falls somewhere in between. Passing Paddle tensors to Warp kernels directly is the slowest in this sample run: it also skips creating temporary Warp arrays, but accessing the ``__cuda_array_interface__`` attributes of Paddle tensors adds overhead because they are initialized on-demand.
1 change: 1 addition & 0 deletions exts/omni.warp.core/config/extension.toml
@@ -38,6 +38,7 @@ pyCoverageOmit = [
"warp/stubs.py",
"warp/jax.py",
"warp/torch.py",
"warp/paddle.py",
"warp/build.py",
"warp/build_dll.py",
"warp/sim/**",
5 changes: 5 additions & 0 deletions warp/__init__.py
@@ -99,6 +99,11 @@

from warp.dlpack import from_dlpack, to_dlpack

from warp.paddle import from_paddle, to_paddle
from warp.paddle import dtype_from_paddle, dtype_to_paddle
from warp.paddle import device_from_paddle, device_to_paddle
from warp.paddle import stream_from_paddle

from warp.build import clear_kernel_cache

from warp.constants import *
2 changes: 2 additions & 0 deletions warp/dlpack.py
@@ -124,6 +124,8 @@ def device_to_dlpack(wp_device) -> DLDevice:


def dtype_to_dlpack(wp_dtype) -> DLDataType:
if wp_dtype == warp.bool:
return (DLDataTypeCode.kDLBool, 8, 1)
if wp_dtype == warp.int8:
return (DLDataTypeCode.kDLInt, 8, 1)
elif wp_dtype == warp.uint8:
2 changes: 2 additions & 0 deletions warp/examples/benchmarks/benchmark.bat
@@ -11,3 +11,5 @@ python benchmark_cloth.py numpy
@REM python benchmark_cloth.py numba
@REM python benchmark_cloth.py jax_cpu
@REM python benchmark_cloth.py jax_gpu
@REM python benchmark_cloth.py paddle_cpu
@REM python benchmark_cloth.py paddle_gpu
2 changes: 2 additions & 0 deletions warp/examples/benchmarks/benchmark.sh
@@ -11,3 +11,5 @@ python3 benchmark_cloth.py numpy
# python3 benchmark_cloth.py jax_cpu
# python3 benchmark_cloth.py jax_gpu
# python3 benchmark_cloth.py numba
# python3 benchmark_cloth.py paddle_cpu
# python3 benchmark_cloth.py paddle_gpu
10 changes: 10 additions & 0 deletions warp/examples/benchmarks/benchmark_cloth.py
@@ -219,6 +219,16 @@ def run_benchmark(mode, dim, timers, render=False):

integrator = benchmark_cloth_jax.JxIntegrator(cloth)

elif mode == "paddle_cpu":
import benchmark_cloth_paddle

integrator = benchmark_cloth_paddle.TrIntegrator(cloth, "cpu")

elif mode == "paddle_gpu":
import benchmark_cloth_paddle

integrator = benchmark_cloth_paddle.TrIntegrator(cloth, "gpu")

else:
raise RuntimeError("Unknown simulation backend")

