
Commit

Merge branch 'GH-335' into 'main'
GH-335: Fix paddle backend bug

Closes GH-335

See merge request omniverse/warp!824
shi-eric committed Oct 31, 2024
2 parents 085d857 + 49acc63 commit f26347d
Showing 5 changed files with 301 additions and 99 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -13,6 +13,7 @@
 
 ### Fixed
 
+- Fix place setting of paddle backend and add 2 missing benchmark files.
 - Fix to relax the integer types expected when indexing arrays (regression in 1.3.0).
 - Fix printing vector and matrix adjoints in backward kernels.
 - Fix kernel compile error when printing structs.
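For context, the "place setting" entry above concerns how Paddle places/devices are mapped to Warp devices. A minimal round-trip sketch of the intended behavior (not part of the commit; assumes a CUDA-enabled build of both Paddle and Warp, and the shape and dtype are arbitrary):

import paddle

import warp as wp

wp.init()

# Paddle spells CUDA devices "gpu:0"; Warp spells them "cuda:0".
t = paddle.zeros((8,), dtype=paddle.float32).to(device="gpu:0")

a = wp.from_paddle(t)  # wraps the Paddle tensor; no copy is expected
print(a.device)        # should resolve to the Warp device "cuda:0"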
86 changes: 86 additions & 0 deletions warp/examples/benchmarks/benchmark_cloth_paddle.py
@@ -0,0 +1,86 @@
# Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import paddle


def eval_springs(x, v, indices, rest, ke, kd, f):
    i = indices[:, 0]
    j = indices[:, 1]

    xi = x[i]
    xj = x[j]

    vi = v[i]
    vj = v[j]

    xij = xi - xj
    vij = vi - vj

    l = paddle.linalg.norm(xij, axis=1)
    l_inv = 1.0 / l

    # normalized spring direction
    dir = (xij.T * l_inv).T

    c = l - rest
    dcdt = paddle.sum(dir * vij, axis=1)

    # damping based on relative velocity.
    fs = dir.T * (ke * c + kd * dcdt)

    f.index_add_(axis=0, index=i, value=-fs.T)
    f.index_add_(axis=0, index=j, value=fs.T)


def integrate_particles(x, v, f, g, w, dt):
    s = w > 0.0

    a_ext = g * s[:, None].astype(g.dtype)

    # simple semi-implicit Euler. v1 = v0 + a dt, x1 = x0 + v1 dt
    v += ((f.T * w).T + a_ext) * dt
    x += v * dt

    # clear forces
    f *= 0.0


class TrIntegrator:
    def __init__(self, cloth, device):
        self.cloth = cloth

        self.positions = paddle.to_tensor(self.cloth.positions, place=device)
        self.velocities = paddle.to_tensor(self.cloth.velocities, place=device)
        self.inv_mass = paddle.to_tensor(self.cloth.inv_masses, place=device)

        self.spring_indices = paddle.to_tensor(self.cloth.spring_indices, dtype=paddle.int64, place=device)
        self.spring_lengths = paddle.to_tensor(self.cloth.spring_lengths, place=device)
        self.spring_stiffness = paddle.to_tensor(self.cloth.spring_stiffness, place=device)
        self.spring_damping = paddle.to_tensor(self.cloth.spring_damping, place=device)

        self.forces = paddle.zeros((self.cloth.num_particles, 3), dtype=paddle.float32).to(device=device)
        self.gravity = paddle.to_tensor((0.0, 0.0 - 9.8, 0.0), dtype=paddle.float32, place=device)

    def simulate(self, dt, substeps):
        sim_dt = dt / substeps

        for _s in range(substeps):
            eval_springs(
                self.positions,
                self.velocities,
                self.spring_indices.reshape((self.cloth.num_springs, 2)),
                self.spring_lengths,
                self.spring_stiffness,
                self.spring_damping,
                self.forces,
            )

            # integrate
            integrate_particles(self.positions, self.velocities, self.forces, self.gravity, self.inv_mass, sim_dt)

        return self.positions.cpu().numpy()
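The integrator above expects a cloth object with NumPy-array attributes, which the shared cloth benchmark driver normally constructs. A hypothetical, minimal stand-in (not part of the commit) showing the attribute layout TrIntegrator relies on; the "cpu" device string and all values are illustrative assumptions:

import numpy as np


class TinyCloth:
    # two particles joined by one spring; particle 0 is pinned (inverse mass 0)
    num_particles = 2
    num_springs = 1
    positions = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], dtype=np.float32)
    velocities = np.zeros((2, 3), dtype=np.float32)
    inv_masses = np.array([0.0, 1.0], dtype=np.float32)
    spring_indices = np.array([0, 1], dtype=np.int32)
    spring_lengths = np.array([1.0], dtype=np.float32)
    spring_stiffness = np.array([1000.0], dtype=np.float32)
    spring_damping = np.array([10.0], dtype=np.float32)


integrator = TrIntegrator(TinyCloth(), "cpu")  # assumes "cpu" is accepted as a Paddle place
positions = integrator.simulate(dt=1.0 / 60.0, substeps=16)
print(positions.shape)  # (2, 3)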
158 changes: 158 additions & 0 deletions warp/examples/benchmarks/benchmark_interop_paddle.py
@@ -0,0 +1,158 @@
# Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import time

import paddle

import warp as wp


def create_simple_kernel(dtype):
    def simple_kernel(
        a: wp.array(dtype=dtype),
        b: wp.array(dtype=dtype),
        c: wp.array(dtype=dtype),
        d: wp.array(dtype=dtype),
        e: wp.array(dtype=dtype),
    ):
        pass

    return wp.Kernel(simple_kernel)


def test_from_paddle(kernel, num_iters, array_size, device, warp_dtype=None):
    warp_device = wp.get_device(device)
    paddle_device = wp.device_to_paddle(warp_device)

    if hasattr(warp_dtype, "_shape_"):
        paddle_shape = (array_size, *warp_dtype._shape_)
        paddle_dtype = wp.dtype_to_paddle(warp_dtype._wp_scalar_type_)
    else:
        paddle_shape = (array_size,)
        paddle_dtype = paddle.float32 if warp_dtype is None else wp.dtype_to_paddle(warp_dtype)

    _a = paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device)
    _b = paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device)
    _c = paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device)
    _d = paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device)
    _e = paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device)

    wp.synchronize()

    # profiler = Profiler(interval=0.000001)
    # profiler.start()

    t1 = time.time_ns()

    for _ in range(num_iters):
        a = wp.from_paddle(_a, dtype=warp_dtype)
        b = wp.from_paddle(_b, dtype=warp_dtype)
        c = wp.from_paddle(_c, dtype=warp_dtype)
        d = wp.from_paddle(_d, dtype=warp_dtype)
        e = wp.from_paddle(_e, dtype=warp_dtype)
        wp.launch(kernel, dim=array_size, inputs=[a, b, c, d, e])

    t2 = time.time_ns()
    print(f"{(t2 - t1) / 1_000_000 :8.0f} ms from_paddle(...)")

    # profiler.stop()
    # profiler.print()


def test_array_ctype_from_paddle(kernel, num_iters, array_size, device, warp_dtype=None):
    warp_device = wp.get_device(device)
    paddle_device = wp.device_to_paddle(warp_device)

    if hasattr(warp_dtype, "_shape_"):
        paddle_shape = (array_size, *warp_dtype._shape_)
        paddle_dtype = wp.dtype_to_paddle(warp_dtype._wp_scalar_type_)
    else:
        paddle_shape = (array_size,)
        paddle_dtype = paddle.float32 if warp_dtype is None else wp.dtype_to_paddle(warp_dtype)

    _a = paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device)
    _b = paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device)
    _c = paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device)
    _d = paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device)
    _e = paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device)

    wp.synchronize()

    # profiler = Profiler(interval=0.000001)
    # profiler.start()

    t1 = time.time_ns()

    for _ in range(num_iters):
        a = wp.from_paddle(_a, dtype=warp_dtype, return_ctype=True)
        b = wp.from_paddle(_b, dtype=warp_dtype, return_ctype=True)
        c = wp.from_paddle(_c, dtype=warp_dtype, return_ctype=True)
        d = wp.from_paddle(_d, dtype=warp_dtype, return_ctype=True)
        e = wp.from_paddle(_e, dtype=warp_dtype, return_ctype=True)
        wp.launch(kernel, dim=array_size, inputs=[a, b, c, d, e])

    t2 = time.time_ns()
    print(f"{(t2 - t1) / 1_000_000 :8.0f} ms from_paddle(..., return_ctype=True)")

    # profiler.stop()
    # profiler.print()


def test_direct_from_paddle(kernel, num_iters, array_size, device, warp_dtype=None):
    warp_device = wp.get_device(device)
    paddle_device = wp.device_to_paddle(warp_device)

    if hasattr(warp_dtype, "_shape_"):
        paddle_shape = (array_size, *warp_dtype._shape_)
        paddle_dtype = wp.dtype_to_paddle(warp_dtype._wp_scalar_type_)
    else:
        paddle_shape = (array_size,)
        paddle_dtype = paddle.float32 if warp_dtype is None else wp.dtype_to_paddle(warp_dtype)

    _a = paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device)
    _b = paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device)
    _c = paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device)
    _d = paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device)
    _e = paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device)

    wp.synchronize()

    # profiler = Profiler(interval=0.000001)
    # profiler.start()

    t1 = time.time_ns()

    for _ in range(num_iters):
        wp.launch(kernel, dim=array_size, inputs=[_a, _b, _c, _d, _e])

    t2 = time.time_ns()
    print(f"{(t2 - t1) / 1_000_000 :8.0f} ms direct from paddle")

    # profiler.stop()
    # profiler.print()


wp.init()

params = [
    # (warp_dtype arg, kernel)
    (None, create_simple_kernel(wp.float32)),
    (wp.float32, create_simple_kernel(wp.float32)),
    (wp.vec3f, create_simple_kernel(wp.vec3f)),
    (wp.mat22f, create_simple_kernel(wp.mat22f)),
]

wp.load_module()

num_iters = 100000

for warp_dtype, kernel in params:
    print(f"\ndtype={wp.context.type_str(warp_dtype)}")
    test_from_paddle(kernel, num_iters, 10, "cuda:0", warp_dtype=warp_dtype)
    test_array_ctype_from_paddle(kernel, num_iters, 10, "cuda:0", warp_dtype=warp_dtype)
    test_direct_from_paddle(kernel, num_iters, 10, "cuda:0", warp_dtype=warp_dtype)
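The three tests above launch the same empty kernel and differ only in how the Paddle tensors cross into Warp. A rough, self-contained sketch of the paths being compared (a sketch, not benchmark output; assumes an initialized Warp runtime):

import paddle

import warp as wp

wp.init()

t = paddle.zeros((10,), dtype=paddle.float32)

# 1) full conversion: builds a wp.array wrapper on every iteration (test_from_paddle)
a_full = wp.from_paddle(t)

# 2) lightweight conversion: return_ctype=True yields a low-level array descriptor
#    that wp.launch accepts, skipping most wrapper construction (test_array_ctype_from_paddle)
a_ctype = wp.from_paddle(t, return_ctype=True)

# 3) no explicit conversion: the Paddle tensor is passed straight to wp.launch,
#    which converts it internally (test_direct_from_paddle)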
41 changes: 29 additions & 12 deletions warp/paddle.py
@@ -17,41 +17,58 @@
 
 if TYPE_CHECKING:
     import paddle
+    from paddle.base.libpaddle import CPUPlace, CUDAPinnedPlace, CUDAPlace, Place
 
 
 # return the warp device corresponding to a paddle device
-def device_from_paddle(paddle_device: Union[paddle.base.libpaddle.Place, str]) -> warp.context.Device:
+def device_from_paddle(paddle_device: Union[Place, CPUPlace, CUDAPinnedPlace, CUDAPlace, str]) -> warp.context.Device:
     """Return the Warp device corresponding to a Paddle device.
 
     Args:
-        paddle_device (`paddle.base.libpaddle.Place` or `str`): Paddle device identifier
+        paddle_device (`Place`, `CPUPlace`, `CUDAPinnedPlace`, `CUDAPlace`, or `str`): Paddle device identifier
 
     Raises:
         RuntimeError: Paddle device does not have a corresponding Warp device
     """
     if type(paddle_device) is str:
+        if paddle_device.startswith("gpu:"):
+            paddle_device = paddle_device.replace("gpu:", "cuda:")
         warp_device = warp.context.runtime.device_map.get(paddle_device)
         if warp_device is not None:
             return warp_device
-        elif paddle_device.startswith("gpu"):
+        elif paddle_device == "gpu":
             return warp.context.runtime.get_current_cuda_device()
         else:
             raise RuntimeError(f"Unsupported Paddle device {paddle_device}")
     else:
-        import paddle
-
         try:
-            if paddle_device.is_gpu_place():
-                return warp.context.runtime.cuda_devices[paddle_device.gpu_device_id()]
-            elif paddle_device.is_cpu_place():
+            from paddle.base.libpaddle import CPUPlace, CUDAPinnedPlace, CUDAPlace, Place
+
+            if isinstance(paddle_device, Place):
+                if paddle_device.is_gpu_place():
+                    return warp.context.runtime.cuda_devices[paddle_device.gpu_device_id()]
+                elif paddle_device.is_cpu_place():
+                    return warp.context.runtime.cpu_device
+                else:
+                    raise RuntimeError(f"Unsupported Paddle device type {paddle_device}")
+            elif isinstance(paddle_device, (CPUPlace, CUDAPinnedPlace)):
                 return warp.context.runtime.cpu_device
+            elif isinstance(paddle_device, CUDAPlace):
+                return warp.context.runtime.cuda_devices[paddle_device.get_device_id()]
             else:
                 raise RuntimeError(f"Unsupported Paddle device type {paddle_device}")
+        except ModuleNotFoundError as e:
+            raise ModuleNotFoundError("Please install paddlepaddle first.") from e
         except Exception as e:
-            import paddle
-
-            if not isinstance(paddle_device, paddle.base.libpaddle.Place):
-                raise ValueError("Argument must be a paddle.base.libpaddle.Place object or a string") from e
+            if not isinstance(paddle_device, (Place, CPUPlace, CUDAPinnedPlace, CUDAPlace)):
+                raise TypeError(
+                    "device_from_paddle() received an invalid argument - "
+                    f"got {paddle_device}({type(paddle_device)}), but expected one of:\n"
+                    "* paddle.base.libpaddle.Place\n"
+                    "* paddle.CPUPlace\n"
+                    "* paddle.CUDAPinnedPlace\n"
+                    "* paddle.CUDAPlace or 'gpu' or 'gpu:x'(x means device id)"
+                ) from e
             raise


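Taken together, the updated branches are expected to map Paddle devices to Warp devices roughly as follows. A sketch, assuming an initialized Warp runtime and a CUDA-enabled Paddle build with device 0 present; warp.paddle is imported directly rather than relying on top-level re-exports:

import paddle

import warp as wp
from warp.paddle import device_from_paddle

wp.init()

device_from_paddle("cpu")                     # Warp "cpu" device
device_from_paddle("gpu")                     # current Warp CUDA device
device_from_paddle("gpu:0")                   # "cuda:0", via the new "gpu:x" -> "cuda:x" rewrite
device_from_paddle(paddle.CPUPlace())         # Warp "cpu" device
device_from_paddle(paddle.CUDAPlace(0))       # "cuda:0"
device_from_paddle(paddle.CUDAPinnedPlace())  # Warp "cpu" device (pinned host memory)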
