Skip to content

Commit

Permalink
Merge branch 'ccrouzet/vec-ops' into 'main'
Browse files Browse the repository at this point in the history
Add `wp.abs()`, `wp.clamp()`, and `wp.sign()` for Vectors

See merge request omniverse/warp!589
  • Loading branch information
christophercrouzet committed Jul 18, 2024
2 parents 2d4209c + ddaaa04 commit 3d9c55d
Show file tree
Hide file tree
Showing 7 changed files with 384 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
- Add documentation for dynamic loop autograd limitations
- Conform to Python's syntax for function arguments when calling built-ins inside of kernels, thus extending support for keyword arguments
- Implement the assignment operator for `wp.quat`
- Add `wp.abs()`, `wp.clamp()`, and `wp.sign()` for vectors

## [1.2.1] - 2024-06-14

Expand Down
21 changes: 21 additions & 0 deletions docs/modules/functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -145,16 +145,37 @@ Scalar Math
Clamp the value of ``x`` to the range [a, b].


.. py:function:: clamp(a: Vector[Any,Scalar], low: Vector[Any,Scalar], high: Vector[Any,Scalar]) -> Vector[Any,Scalar]
:noindex:
:nocontentsentry:

Clamp the elements of ``a`` to the elements from the range [low, high].


.. py:function:: abs(x: Scalar) -> Scalar
Return the absolute value of ``x``.


.. py:function:: abs(a: Vector[Any,Scalar]) -> Vector[Any,Scalar]
:noindex:
:nocontentsentry:

Return the absolute values of the elements of ``a``.


.. py:function:: sign(x: Scalar) -> Scalar
Return -1 if ``x`` < 0, return 1 otherwise.


.. py:function:: sign(a: Vector[Any,Scalar]) -> Scalar
:noindex:
:nocontentsentry:

Return -1 for the negative elements of ``a``, and 1 otherwise.


.. py:function:: step(x: Scalar) -> Scalar
Return 1.0 if ``x`` < 0.0, return 0.0 otherwise.
Expand Down
31 changes: 31 additions & 0 deletions warp/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,37 @@ def float_sametypes_value_func(arg_types: Mapping[str, type], arg_values: Mappin
missing_grad=True,
)

add_builtin(
"clamp",
input_types={
"a": vector(length=Any, dtype=Scalar),
"low": vector(length=Any, dtype=Scalar),
"high": vector(length=Any, dtype=Scalar),
},
constraint=sametypes,
value_func=sametypes_create_value_func(vector(length=Any, dtype=Scalar)),
doc="Clamp the elements of ``a`` to the elements from the range [low, high].",
group="Vector Math",
)

add_builtin(
"abs",
input_types={"a": vector(length=Any, dtype=Scalar)},
constraint=sametypes,
value_func=sametypes_create_value_func(vector(length=Any, dtype=Scalar)),
doc="Return the absolute values of the elements of ``a``.",
group="Vector Math",
)

add_builtin(
"sign",
input_types={"a": vector(length=Any, dtype=Scalar)},
constraint=sametypes,
value_func=sametypes_create_value_func(Scalar),
doc="Return -1 for the negative elements of ``a``, and 1 otherwise.",
group="Vector Math",
)


def outer_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
if arg_types is None:
Expand Down
108 changes: 108 additions & 0 deletions warp/native/exports.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,42 @@ WP_API void builtin_clamp_uint16_uint16_uint16(uint16 x, uint16 a, uint16 b, uin
WP_API void builtin_clamp_uint32_uint32_uint32(uint32 x, uint32 a, uint32 b, uint32* ret) { *ret = wp::clamp(x, a, b); }
WP_API void builtin_clamp_uint64_uint64_uint64(uint64 x, uint64 a, uint64 b, uint64* ret) { *ret = wp::clamp(x, a, b); }
WP_API void builtin_clamp_uint8_uint8_uint8(uint8 x, uint8 a, uint8 b, uint8* ret) { *ret = wp::clamp(x, a, b); }
WP_API void builtin_clamp_vec2h_vec2h_vec2h(vec2h& a, vec2h& low, vec2h& high, vec2h* ret) { *ret = wp::clamp(a, low, high); }
WP_API void builtin_clamp_vec3h_vec3h_vec3h(vec3h& a, vec3h& low, vec3h& high, vec3h* ret) { *ret = wp::clamp(a, low, high); }
WP_API void builtin_clamp_vec4h_vec4h_vec4h(vec4h& a, vec4h& low, vec4h& high, vec4h* ret) { *ret = wp::clamp(a, low, high); }
WP_API void builtin_clamp_spatial_vectorh_spatial_vectorh_spatial_vectorh(spatial_vectorh& a, spatial_vectorh& low, spatial_vectorh& high, spatial_vectorh* ret) { *ret = wp::clamp(a, low, high); }
WP_API void builtin_clamp_vec2f_vec2f_vec2f(vec2f& a, vec2f& low, vec2f& high, vec2f* ret) { *ret = wp::clamp(a, low, high); }
WP_API void builtin_clamp_vec3f_vec3f_vec3f(vec3f& a, vec3f& low, vec3f& high, vec3f* ret) { *ret = wp::clamp(a, low, high); }
WP_API void builtin_clamp_vec4f_vec4f_vec4f(vec4f& a, vec4f& low, vec4f& high, vec4f* ret) { *ret = wp::clamp(a, low, high); }
WP_API void builtin_clamp_spatial_vectorf_spatial_vectorf_spatial_vectorf(spatial_vectorf& a, spatial_vectorf& low, spatial_vectorf& high, spatial_vectorf* ret) { *ret = wp::clamp(a, low, high); }
WP_API void builtin_clamp_vec2d_vec2d_vec2d(vec2d& a, vec2d& low, vec2d& high, vec2d* ret) { *ret = wp::clamp(a, low, high); }
WP_API void builtin_clamp_vec3d_vec3d_vec3d(vec3d& a, vec3d& low, vec3d& high, vec3d* ret) { *ret = wp::clamp(a, low, high); }
WP_API void builtin_clamp_vec4d_vec4d_vec4d(vec4d& a, vec4d& low, vec4d& high, vec4d* ret) { *ret = wp::clamp(a, low, high); }
WP_API void builtin_clamp_spatial_vectord_spatial_vectord_spatial_vectord(spatial_vectord& a, spatial_vectord& low, spatial_vectord& high, spatial_vectord* ret) { *ret = wp::clamp(a, low, high); }
WP_API void builtin_clamp_vec2s_vec2s_vec2s(vec2s& a, vec2s& low, vec2s& high, vec2s* ret) { *ret = wp::clamp(a, low, high); }
WP_API void builtin_clamp_vec3s_vec3s_vec3s(vec3s& a, vec3s& low, vec3s& high, vec3s* ret) { *ret = wp::clamp(a, low, high); }
WP_API void builtin_clamp_vec4s_vec4s_vec4s(vec4s& a, vec4s& low, vec4s& high, vec4s* ret) { *ret = wp::clamp(a, low, high); }
WP_API void builtin_clamp_vec2i_vec2i_vec2i(vec2i& a, vec2i& low, vec2i& high, vec2i* ret) { *ret = wp::clamp(a, low, high); }
WP_API void builtin_clamp_vec3i_vec3i_vec3i(vec3i& a, vec3i& low, vec3i& high, vec3i* ret) { *ret = wp::clamp(a, low, high); }
WP_API void builtin_clamp_vec4i_vec4i_vec4i(vec4i& a, vec4i& low, vec4i& high, vec4i* ret) { *ret = wp::clamp(a, low, high); }
WP_API void builtin_clamp_vec2l_vec2l_vec2l(vec2l& a, vec2l& low, vec2l& high, vec2l* ret) { *ret = wp::clamp(a, low, high); }
WP_API void builtin_clamp_vec3l_vec3l_vec3l(vec3l& a, vec3l& low, vec3l& high, vec3l* ret) { *ret = wp::clamp(a, low, high); }
WP_API void builtin_clamp_vec4l_vec4l_vec4l(vec4l& a, vec4l& low, vec4l& high, vec4l* ret) { *ret = wp::clamp(a, low, high); }
WP_API void builtin_clamp_vec2b_vec2b_vec2b(vec2b& a, vec2b& low, vec2b& high, vec2b* ret) { *ret = wp::clamp(a, low, high); }
WP_API void builtin_clamp_vec3b_vec3b_vec3b(vec3b& a, vec3b& low, vec3b& high, vec3b* ret) { *ret = wp::clamp(a, low, high); }
WP_API void builtin_clamp_vec4b_vec4b_vec4b(vec4b& a, vec4b& low, vec4b& high, vec4b* ret) { *ret = wp::clamp(a, low, high); }
WP_API void builtin_clamp_vec2us_vec2us_vec2us(vec2us& a, vec2us& low, vec2us& high, vec2us* ret) { *ret = wp::clamp(a, low, high); }
WP_API void builtin_clamp_vec3us_vec3us_vec3us(vec3us& a, vec3us& low, vec3us& high, vec3us* ret) { *ret = wp::clamp(a, low, high); }
WP_API void builtin_clamp_vec4us_vec4us_vec4us(vec4us& a, vec4us& low, vec4us& high, vec4us* ret) { *ret = wp::clamp(a, low, high); }
WP_API void builtin_clamp_vec2ui_vec2ui_vec2ui(vec2ui& a, vec2ui& low, vec2ui& high, vec2ui* ret) { *ret = wp::clamp(a, low, high); }
WP_API void builtin_clamp_vec3ui_vec3ui_vec3ui(vec3ui& a, vec3ui& low, vec3ui& high, vec3ui* ret) { *ret = wp::clamp(a, low, high); }
WP_API void builtin_clamp_vec4ui_vec4ui_vec4ui(vec4ui& a, vec4ui& low, vec4ui& high, vec4ui* ret) { *ret = wp::clamp(a, low, high); }
WP_API void builtin_clamp_vec2ul_vec2ul_vec2ul(vec2ul& a, vec2ul& low, vec2ul& high, vec2ul* ret) { *ret = wp::clamp(a, low, high); }
WP_API void builtin_clamp_vec3ul_vec3ul_vec3ul(vec3ul& a, vec3ul& low, vec3ul& high, vec3ul* ret) { *ret = wp::clamp(a, low, high); }
WP_API void builtin_clamp_vec4ul_vec4ul_vec4ul(vec4ul& a, vec4ul& low, vec4ul& high, vec4ul* ret) { *ret = wp::clamp(a, low, high); }
WP_API void builtin_clamp_vec2ub_vec2ub_vec2ub(vec2ub& a, vec2ub& low, vec2ub& high, vec2ub* ret) { *ret = wp::clamp(a, low, high); }
WP_API void builtin_clamp_vec3ub_vec3ub_vec3ub(vec3ub& a, vec3ub& low, vec3ub& high, vec3ub* ret) { *ret = wp::clamp(a, low, high); }
WP_API void builtin_clamp_vec4ub_vec4ub_vec4ub(vec4ub& a, vec4ub& low, vec4ub& high, vec4ub* ret) { *ret = wp::clamp(a, low, high); }
WP_API void builtin_abs_float16(float16 x, float16* ret) { *ret = wp::abs(x); }
WP_API void builtin_abs_float32(float32 x, float32* ret) { *ret = wp::abs(x); }
WP_API void builtin_abs_float64(float64 x, float64* ret) { *ret = wp::abs(x); }
Expand All @@ -190,6 +226,42 @@ WP_API void builtin_abs_uint16(uint16 x, uint16* ret) { *ret = wp::abs(x); }
WP_API void builtin_abs_uint32(uint32 x, uint32* ret) { *ret = wp::abs(x); }
WP_API void builtin_abs_uint64(uint64 x, uint64* ret) { *ret = wp::abs(x); }
WP_API void builtin_abs_uint8(uint8 x, uint8* ret) { *ret = wp::abs(x); }
WP_API void builtin_abs_vec2h(vec2h& a, vec2h* ret) { *ret = wp::abs(a); }
WP_API void builtin_abs_vec3h(vec3h& a, vec3h* ret) { *ret = wp::abs(a); }
WP_API void builtin_abs_vec4h(vec4h& a, vec4h* ret) { *ret = wp::abs(a); }
WP_API void builtin_abs_spatial_vectorh(spatial_vectorh& a, spatial_vectorh* ret) { *ret = wp::abs(a); }
WP_API void builtin_abs_vec2f(vec2f& a, vec2f* ret) { *ret = wp::abs(a); }
WP_API void builtin_abs_vec3f(vec3f& a, vec3f* ret) { *ret = wp::abs(a); }
WP_API void builtin_abs_vec4f(vec4f& a, vec4f* ret) { *ret = wp::abs(a); }
WP_API void builtin_abs_spatial_vectorf(spatial_vectorf& a, spatial_vectorf* ret) { *ret = wp::abs(a); }
WP_API void builtin_abs_vec2d(vec2d& a, vec2d* ret) { *ret = wp::abs(a); }
WP_API void builtin_abs_vec3d(vec3d& a, vec3d* ret) { *ret = wp::abs(a); }
WP_API void builtin_abs_vec4d(vec4d& a, vec4d* ret) { *ret = wp::abs(a); }
WP_API void builtin_abs_spatial_vectord(spatial_vectord& a, spatial_vectord* ret) { *ret = wp::abs(a); }
WP_API void builtin_abs_vec2s(vec2s& a, vec2s* ret) { *ret = wp::abs(a); }
WP_API void builtin_abs_vec3s(vec3s& a, vec3s* ret) { *ret = wp::abs(a); }
WP_API void builtin_abs_vec4s(vec4s& a, vec4s* ret) { *ret = wp::abs(a); }
WP_API void builtin_abs_vec2i(vec2i& a, vec2i* ret) { *ret = wp::abs(a); }
WP_API void builtin_abs_vec3i(vec3i& a, vec3i* ret) { *ret = wp::abs(a); }
WP_API void builtin_abs_vec4i(vec4i& a, vec4i* ret) { *ret = wp::abs(a); }
WP_API void builtin_abs_vec2l(vec2l& a, vec2l* ret) { *ret = wp::abs(a); }
WP_API void builtin_abs_vec3l(vec3l& a, vec3l* ret) { *ret = wp::abs(a); }
WP_API void builtin_abs_vec4l(vec4l& a, vec4l* ret) { *ret = wp::abs(a); }
WP_API void builtin_abs_vec2b(vec2b& a, vec2b* ret) { *ret = wp::abs(a); }
WP_API void builtin_abs_vec3b(vec3b& a, vec3b* ret) { *ret = wp::abs(a); }
WP_API void builtin_abs_vec4b(vec4b& a, vec4b* ret) { *ret = wp::abs(a); }
WP_API void builtin_abs_vec2us(vec2us& a, vec2us* ret) { *ret = wp::abs(a); }
WP_API void builtin_abs_vec3us(vec3us& a, vec3us* ret) { *ret = wp::abs(a); }
WP_API void builtin_abs_vec4us(vec4us& a, vec4us* ret) { *ret = wp::abs(a); }
WP_API void builtin_abs_vec2ui(vec2ui& a, vec2ui* ret) { *ret = wp::abs(a); }
WP_API void builtin_abs_vec3ui(vec3ui& a, vec3ui* ret) { *ret = wp::abs(a); }
WP_API void builtin_abs_vec4ui(vec4ui& a, vec4ui* ret) { *ret = wp::abs(a); }
WP_API void builtin_abs_vec2ul(vec2ul& a, vec2ul* ret) { *ret = wp::abs(a); }
WP_API void builtin_abs_vec3ul(vec3ul& a, vec3ul* ret) { *ret = wp::abs(a); }
WP_API void builtin_abs_vec4ul(vec4ul& a, vec4ul* ret) { *ret = wp::abs(a); }
WP_API void builtin_abs_vec2ub(vec2ub& a, vec2ub* ret) { *ret = wp::abs(a); }
WP_API void builtin_abs_vec3ub(vec3ub& a, vec3ub* ret) { *ret = wp::abs(a); }
WP_API void builtin_abs_vec4ub(vec4ub& a, vec4ub* ret) { *ret = wp::abs(a); }
WP_API void builtin_sign_float16(float16 x, float16* ret) { *ret = wp::sign(x); }
WP_API void builtin_sign_float32(float32 x, float32* ret) { *ret = wp::sign(x); }
WP_API void builtin_sign_float64(float64 x, float64* ret) { *ret = wp::sign(x); }
Expand All @@ -201,6 +273,42 @@ WP_API void builtin_sign_uint16(uint16 x, uint16* ret) { *ret = wp::sign(x); }
WP_API void builtin_sign_uint32(uint32 x, uint32* ret) { *ret = wp::sign(x); }
WP_API void builtin_sign_uint64(uint64 x, uint64* ret) { *ret = wp::sign(x); }
WP_API void builtin_sign_uint8(uint8 x, uint8* ret) { *ret = wp::sign(x); }
WP_API void builtin_sign_vec2h(vec2h& a, vec2h* ret) { *ret = wp::sign(a); }
WP_API void builtin_sign_vec3h(vec3h& a, vec3h* ret) { *ret = wp::sign(a); }
WP_API void builtin_sign_vec4h(vec4h& a, vec4h* ret) { *ret = wp::sign(a); }
WP_API void builtin_sign_spatial_vectorh(spatial_vectorh& a, spatial_vectorh* ret) { *ret = wp::sign(a); }
WP_API void builtin_sign_vec2f(vec2f& a, vec2f* ret) { *ret = wp::sign(a); }
WP_API void builtin_sign_vec3f(vec3f& a, vec3f* ret) { *ret = wp::sign(a); }
WP_API void builtin_sign_vec4f(vec4f& a, vec4f* ret) { *ret = wp::sign(a); }
WP_API void builtin_sign_spatial_vectorf(spatial_vectorf& a, spatial_vectorf* ret) { *ret = wp::sign(a); }
WP_API void builtin_sign_vec2d(vec2d& a, vec2d* ret) { *ret = wp::sign(a); }
WP_API void builtin_sign_vec3d(vec3d& a, vec3d* ret) { *ret = wp::sign(a); }
WP_API void builtin_sign_vec4d(vec4d& a, vec4d* ret) { *ret = wp::sign(a); }
WP_API void builtin_sign_spatial_vectord(spatial_vectord& a, spatial_vectord* ret) { *ret = wp::sign(a); }
WP_API void builtin_sign_vec2s(vec2s& a, vec2s* ret) { *ret = wp::sign(a); }
WP_API void builtin_sign_vec3s(vec3s& a, vec3s* ret) { *ret = wp::sign(a); }
WP_API void builtin_sign_vec4s(vec4s& a, vec4s* ret) { *ret = wp::sign(a); }
WP_API void builtin_sign_vec2i(vec2i& a, vec2i* ret) { *ret = wp::sign(a); }
WP_API void builtin_sign_vec3i(vec3i& a, vec3i* ret) { *ret = wp::sign(a); }
WP_API void builtin_sign_vec4i(vec4i& a, vec4i* ret) { *ret = wp::sign(a); }
WP_API void builtin_sign_vec2l(vec2l& a, vec2l* ret) { *ret = wp::sign(a); }
WP_API void builtin_sign_vec3l(vec3l& a, vec3l* ret) { *ret = wp::sign(a); }
WP_API void builtin_sign_vec4l(vec4l& a, vec4l* ret) { *ret = wp::sign(a); }
WP_API void builtin_sign_vec2b(vec2b& a, vec2b* ret) { *ret = wp::sign(a); }
WP_API void builtin_sign_vec3b(vec3b& a, vec3b* ret) { *ret = wp::sign(a); }
WP_API void builtin_sign_vec4b(vec4b& a, vec4b* ret) { *ret = wp::sign(a); }
WP_API void builtin_sign_vec2us(vec2us& a, vec2us* ret) { *ret = wp::sign(a); }
WP_API void builtin_sign_vec3us(vec3us& a, vec3us* ret) { *ret = wp::sign(a); }
WP_API void builtin_sign_vec4us(vec4us& a, vec4us* ret) { *ret = wp::sign(a); }
WP_API void builtin_sign_vec2ui(vec2ui& a, vec2ui* ret) { *ret = wp::sign(a); }
WP_API void builtin_sign_vec3ui(vec3ui& a, vec3ui* ret) { *ret = wp::sign(a); }
WP_API void builtin_sign_vec4ui(vec4ui& a, vec4ui* ret) { *ret = wp::sign(a); }
WP_API void builtin_sign_vec2ul(vec2ul& a, vec2ul* ret) { *ret = wp::sign(a); }
WP_API void builtin_sign_vec3ul(vec3ul& a, vec3ul* ret) { *ret = wp::sign(a); }
WP_API void builtin_sign_vec4ul(vec4ul& a, vec4ul* ret) { *ret = wp::sign(a); }
WP_API void builtin_sign_vec2ub(vec2ub& a, vec2ub* ret) { *ret = wp::sign(a); }
WP_API void builtin_sign_vec3ub(vec3ub& a, vec3ub* ret) { *ret = wp::sign(a); }
WP_API void builtin_sign_vec4ub(vec4ub& a, vec4ub* ret) { *ret = wp::sign(a); }
WP_API void builtin_step_float16(float16 x, float16* ret) { *ret = wp::step(x); }
WP_API void builtin_step_float32(float32 x, float32* ret) { *ret = wp::step(x); }
WP_API void builtin_step_float64(float64 x, float64* ret) { *ret = wp::step(x); }
Expand Down
100 changes: 100 additions & 0 deletions warp/native/vec.h
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,42 @@ inline CUDA_CALLABLE unsigned argmax(vec_t<Length,Type> v)
return ret;
}

template<unsigned Length, typename Type>
inline CUDA_CALLABLE vec_t<Length,Type> clamp(vec_t<Length,Type> v, vec_t<Length,Type> a, vec_t<Length,Type> b)
{
vec_t<Length,Type> ret;
for (unsigned i=0; i < Length; ++i)
{
ret[i] = v[i] < a[i] ? a[i] : v[i] > b[i] ? b[i] : v[i];
}

return ret;
}

template<unsigned Length, typename Type>
inline CUDA_CALLABLE vec_t<Length,Type> abs(vec_t<Length,Type> v)
{
vec_t<Length,Type> ret;
for (unsigned i=0; i < Length; ++i)
{
ret[i] = abs(v[i]);
}

return ret;
}

template<unsigned Length, typename Type>
inline CUDA_CALLABLE vec_t<Length,Type> sign(vec_t<Length,Type> v)
{
vec_t<Length,Type> ret;
for (unsigned i=0; i < Length; ++i)
{
ret[i] = v[i] < Type(0) ? Type(-1) : Type(1);
}

return ret;
}

template<unsigned Length, typename Type>
inline CUDA_CALLABLE void expect_near(const vec_t<Length, Type>& actual, const vec_t<Length, Type>& expected, const Type& tolerance)
{
Expand Down Expand Up @@ -1046,6 +1082,70 @@ inline CUDA_CALLABLE void adj_max(const vec_t<Length,Type> &v, vec_t<Length,Type
adj_v[i] += adj_ret;
}

template<unsigned Length, typename Type>
inline CUDA_CALLABLE void adj_clamp(
const vec_t<Length,Type>& v, const vec_t<Length,Type>& a, const vec_t<Length,Type>& b,
vec_t<Length,Type>& adj_v, vec_t<Length,Type>& adj_a, vec_t<Length,Type>& adj_b,
const vec_t<Length,Type>& adj_ret
)
{
for (unsigned i=0; i < Length; ++i)
{
if (v[i] < a[i])
{
adj_a[i] += adj_ret[i];
}
else if (v[i] > b[i])
{
adj_b[i] += adj_ret[i];
}
else
{
adj_v[i] += adj_ret[i];
}
}
}

template<unsigned Length, typename Type>
inline CUDA_CALLABLE void adj_abs(
const vec_t<Length,Type>& v,
vec_t<Length,Type>& adj_v,
const vec_t<Length,Type>& adj_ret
)
{
for (unsigned i=0; i < Length; ++i)
{
if (v[i] < Type(0))
{
adj_v[i] -= adj_ret[i];
}
else
{
adj_v[i] += adj_ret[i];
}
}
}

template<unsigned Length, typename Type>
inline CUDA_CALLABLE void adj_sign(
const vec_t<Length,Type>& v,
vec_t<Length,Type>& adj_v,
const vec_t<Length,Type>& adj_ret
)
{
for (unsigned i=0; i < Length; ++i)
{
if (v[i] < Type(0))
{
adj_v[i] -= adj_ret[i];
}
else
{
adj_v[i] += adj_ret[i];
}
}
}

// Do I need to specialize these for different lengths?
template<unsigned Length, typename Type>
inline CUDA_CALLABLE vec_t<Length, Type> atomic_add(vec_t<Length, Type> * addr, vec_t<Length, Type> value)
Expand Down
18 changes: 18 additions & 0 deletions warp/stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,18 +161,36 @@ def clamp(x: Scalar, a: Scalar, b: Scalar) -> Scalar:
...


@over
def clamp(a: Vector[Any, Scalar], low: Vector[Any, Scalar], high: Vector[Any, Scalar]) -> Vector[Any, Scalar]:
"""Clamp the elements of ``a`` to the elements from the range [low, high]."""
...


@over
def abs(x: Scalar) -> Scalar:
"""Return the absolute value of ``x``."""
...


@over
def abs(a: Vector[Any, Scalar]) -> Vector[Any, Scalar]:
"""Return the absolute values of the elements of ``a``."""
...


@over
def sign(x: Scalar) -> Scalar:
"""Return -1 if ``x`` < 0, return 1 otherwise."""
...


@over
def sign(a: Vector[Any, Scalar]) -> Scalar:
"""Return -1 for the negative elements of ``a``, and 1 otherwise."""
...


@over
def step(x: Scalar) -> Scalar:
"""Return 1.0 if ``x`` < 0.0, return 0.0 otherwise."""
Expand Down
Loading

0 comments on commit 3d9c55d

Please sign in to comment.