Commit

Fix several issues related to HIP compilation for bfloat16

stijnh committed Dec 2, 2024
1 parent 846de1f commit f94bd10
Showing 4 changed files with 68 additions and 14 deletions.
30 changes: 28 additions & 2 deletions include/kernel_float/bf16.h
@@ -60,7 +60,6 @@ struct allow_float_fallback<bfloat16_t> {
};
}; // namespace detail

#if KERNEL_FLOAT_BF16_OPS_SUPPORTED
#define KERNEL_FLOAT_BF16_UNARY_FUN(NAME, FUN1, FUN2) \
namespace ops { \
template<> \
@@ -81,6 +80,7 @@ struct allow_float_fallback<bfloat16_t> {
}; \
}

#if KERNEL_FLOAT_BF16_OPS_SUPPORTED
KERNEL_FLOAT_BF16_UNARY_FUN(sin, ::hsin, ::h2sin)
KERNEL_FLOAT_BF16_UNARY_FUN(cos, ::hcos, ::h2cos)

@@ -101,9 +101,34 @@ KERNEL_FLOAT_BF16_UNARY_FUN(ceil, ::hceil, ::h2ceil)
KERNEL_FLOAT_BF16_UNARY_FUN(rint, ::hrint, ::h2rint)
KERNEL_FLOAT_BF16_UNARY_FUN(trunc, ::htrunc, ::h2trunc)
KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2)

// For some reason, HIP struggles with the functions `__habs` and `__hneg`. We define them here using bitwise ops.
// For CUDA, we can just use the regular bfloat16 functions (see above).
#elif KERNEL_FLOAT_IS_HIP
KERNEL_FLOAT_INLINE __hip_bfloat16 hip_habs(const __hip_bfloat16 a) {
__hip_bfloat16 res = a;
res.data &= 0x7FFF;
return res;
}

KERNEL_FLOAT_INLINE __hip_bfloat16 hip_hneg(const __hip_bfloat16 a) {
__hip_bfloat16 res = a;
res.data ^= 0x8000;
return res;
}

KERNEL_FLOAT_INLINE __hip_bfloat162 hip_habs2(const __hip_bfloat162 a) {
return {hip_habs(a.x), hip_habs(a.y)};
}

KERNEL_FLOAT_INLINE __hip_bfloat162 hip_hneg2(const __hip_bfloat162 a) {
return {hip_hneg(a.x), hip_hneg(a.y)};
}

KERNEL_FLOAT_BF16_UNARY_FUN(abs, hip_habs, hip_habs2)
KERNEL_FLOAT_BF16_UNARY_FUN(negate, hip_hneg, hip_hneg2)
#endif

#if KERNEL_FLOAT_BF16_OPS_SUPPORTED
#define KERNEL_FLOAT_BF16_BINARY_FUN(NAME, FUN1, FUN2) \
namespace ops { \
template<> \
@@ -133,6 +158,7 @@ KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2)
}; \
}

#if KERNEL_FLOAT_BF16_OPS_SUPPORTED
KERNEL_FLOAT_BF16_BINARY_FUN(add, __hadd, __hadd2)
KERNEL_FLOAT_BF16_BINARY_FUN(subtract, __hsub, __hsub2)
KERNEL_FLOAT_BF16_BINARY_FUN(multiply, __hmul, __hmul2)
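Two things are worth noting in the bf16.h hunks above. First, the #if KERNEL_FLOAT_BF16_OPS_SUPPORTED guard now sits below the KERNEL_FLOAT_BF16_UNARY_FUN macro definition, so the new HIP #elif branch can reuse the macro. Second, the HIP workaround relies on the bfloat16 bit layout: bit 15 holds the sign, so masking it off yields the absolute value and flipping it yields negation, with no floating-point arithmetic involved. A standalone sketch of that bit trick in plain C++ (the float_to_bf16_bits and bf16_bits_to_float helpers are illustrative, not part of the library):

#include <cstdint>
#include <cstdio>
#include <cstring>

// Illustrative helpers: bfloat16 is the top 16 bits of an IEEE-754 float32,
// so truncating the mantissa is good enough for a demonstration.
static uint16_t float_to_bf16_bits(float f) {
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    return static_cast<uint16_t>(u >> 16);
}

static float bf16_bits_to_float(uint16_t b) {
    uint32_t u = static_cast<uint32_t>(b) << 16;
    float f;
    std::memcpy(&f, &u, sizeof(f));
    return f;
}

int main() {
    uint16_t x = float_to_bf16_bits(-1.5f);
    uint16_t abs_x = x & 0x7FFF;  // clear sign bit: |x|
    uint16_t neg_x = x ^ 0x8000;  // flip sign bit: -x
    std::printf("x=%g  |x|=%g  -x=%g\n",
                bf16_bits_to_float(x),
                bf16_bits_to_float(abs_x),
                bf16_bits_to_float(neg_x));  // prints x=-1.5 |x|=1.5 -x=1.5
    return 0;
}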
6 changes: 4 additions & 2 deletions include/kernel_float/binops.h
@@ -189,14 +189,16 @@ namespace ops {
template<typename T>
struct min {
KERNEL_FLOAT_INLINE T operator()(T left, T right) {
return left < right ? left : right;
auto cond = less<T> {}(left, right);
return cast<decltype(cond), bool> {}(cond) ? left : right;
}
};

template<typename T>
struct max {
KERNEL_FLOAT_INLINE T operator()(T left, T right) {
return left > right ? left : right;
auto cond = greater<T> {}(left, right);
return cast<decltype(cond), bool> {}(cond) ? left : right;
}
};

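The rewritten min and max route the comparison through the library's own less and greater functors and then cast the result to bool, rather than applying < and > to T directly. A plausible motivation (an inference, not stated in the commit) is that for half- and bfloat16-style types the comparison may not return bool, so the explicit cast keeps the ternary condition well-formed. A minimal sketch with a hypothetical Toy type whose operator< returns the type itself:

#include <cstdio>

// Toy numeric type (hypothetical, not from kernel_float) whose operator<
// returns the type itself rather than bool, mimicking fp16-style intrinsics
// that yield 1.0/0.0 instead of a boolean.
struct Toy {
    float v;
    Toy operator<(Toy other) const {
        return Toy {v < other.v ? 1.0f : 0.0f};
    }
};

Toy toy_min(Toy left, Toy right) {
    auto cond = left < right;                // type is Toy here, not bool
    return (cond.v != 0.0f) ? left : right;  // explicit conversion to bool
}

int main() {
    Toy a {2.0f};
    Toy b {5.0f};
    std::printf("min = %g\n", toy_min(a, b).v);  // prints min = 2
    return 0;
}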
3 changes: 1 addition & 2 deletions include/kernel_float/macros.h
@@ -20,12 +20,11 @@
#elif defined(__HIPCC__)
#define KERNEL_FLOAT_IS_HIP (1)
#define KERNEL_FLOAT_DEVICE __attribute__((always_inline)) __device__
#define KERNEL_FLOAT_INLINE __attribute__((always_inline)) __host__ __device__

#ifdef __HIP_DEVICE_COMPILE__
#define KERNEL_FLOAT_INLINE __attribute__((always_inline)) __host__ __device__
#define KERNEL_FLOAT_IS_DEVICE (1)
#else
#define KERNEL_FLOAT_INLINE __attribute__((always_inline)) __host__ __device__
#define KERNEL_FLOAT_IS_HOST (1)
#endif

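Both branches of the __HIP_DEVICE_COMPILE__ check defined KERNEL_FLOAT_INLINE identically, so the definition is hoisted above the #ifdef and only the KERNEL_FLOAT_IS_DEVICE / KERNEL_FLOAT_IS_HOST discriminators remain inside. An illustrative use (hypothetical function, assumes kernel_float's macros header is included) showing why the expansion must be the same in hipcc's host and device passes:

// A function marked KERNEL_FLOAT_INLINE is compiled in both hipcc passes
// (host and device), so the macro must expand to the same
// __attribute__((always_inline)) __host__ __device__ annotation in each pass.
KERNEL_FLOAT_INLINE float lerp(float a, float b, float t) {
    return a + t * (b - a);
}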
43 changes: 35 additions & 8 deletions single_include/kernel_float.h
@@ -16,8 +16,8 @@

//================================================================================
// this file has been auto-generated, do not modify its contents!
// date: 2024-12-02 10:59:19.296684
// git hash: a2b08a56e31d1c9a6302c8a49c740cf56fcc1607
// date: 2024-12-02 18:48:50.243676
// git hash: 846de1f9aefaef76da15ebb5474080d531efaf38
//================================================================================

#ifndef KERNEL_FLOAT_MACROS_H
@@ -42,12 +42,11 @@
#elif defined(__HIPCC__)
#define KERNEL_FLOAT_IS_HIP (1)
#define KERNEL_FLOAT_DEVICE __attribute__((always_inline)) __device__
#define KERNEL_FLOAT_INLINE __attribute__((always_inline)) __host__ __device__

#ifdef __HIP_DEVICE_COMPILE__
#define KERNEL_FLOAT_INLINE __attribute__((always_inline)) __host__ __device__
#define KERNEL_FLOAT_IS_DEVICE (1)
#else
#define KERNEL_FLOAT_INLINE __attribute__((always_inline)) __host__ __device__
#define KERNEL_FLOAT_IS_HOST (1)
#endif

@@ -1875,14 +1874,16 @@ namespace ops {
template<typename T>
struct min {
KERNEL_FLOAT_INLINE T operator()(T left, T right) {
return left < right ? left : right;
auto cond = less<T> {}(left, right);
return cast<decltype(cond), bool> {}(cond) ? left : right;
}
};

template<typename T>
struct max {
KERNEL_FLOAT_INLINE T operator()(T left, T right) {
return left > right ? left : right;
auto cond = greater<T> {}(left, right);
return cast<decltype(cond), bool> {}(cond) ? left : right;
}
};

@@ -4307,7 +4308,6 @@ struct allow_float_fallback<bfloat16_t> {
};
}; // namespace detail

#if KERNEL_FLOAT_BF16_OPS_SUPPORTED
#define KERNEL_FLOAT_BF16_UNARY_FUN(NAME, FUN1, FUN2) \
namespace ops { \
template<> \
@@ -4328,6 +4328,7 @@ struct allow_float_fallback<bfloat16_t> {
}; \
}

#if KERNEL_FLOAT_BF16_OPS_SUPPORTED
KERNEL_FLOAT_BF16_UNARY_FUN(sin, ::hsin, ::h2sin)
KERNEL_FLOAT_BF16_UNARY_FUN(cos, ::hcos, ::h2cos)

@@ -4348,9 +4349,34 @@ KERNEL_FLOAT_BF16_UNARY_FUN(ceil, ::hceil, ::h2ceil)
KERNEL_FLOAT_BF16_UNARY_FUN(rint, ::hrint, ::h2rint)
KERNEL_FLOAT_BF16_UNARY_FUN(trunc, ::htrunc, ::h2trunc)
KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2)

// For some reason, HIP struggles with the functions `__habs` and `__hneg`. We define them here using bitwise ops.
// For CUDA, we can just use the regular bfloat16 functions (see above).
#elif KERNEL_FLOAT_IS_HIP
KERNEL_FLOAT_INLINE __hip_bfloat16 hip_habs(const __hip_bfloat16 a) {
__hip_bfloat16 res = a;
res.data &= 0x7FFF;
return res;
}

KERNEL_FLOAT_INLINE __hip_bfloat16 hip_hneg(const __hip_bfloat16 a) {
__hip_bfloat16 res = a;
res.data ^= 0x8000;
return res;
}

KERNEL_FLOAT_INLINE __hip_bfloat162 hip_habs2(const __hip_bfloat162 a) {
return {hip_habs(a.x), hip_habs(a.y)};
}

KERNEL_FLOAT_INLINE __hip_bfloat162 hip_hneg2(const __hip_bfloat162 a) {
return {hip_hneg(a.x), hip_hneg(a.y)};
}

KERNEL_FLOAT_BF16_UNARY_FUN(abs, hip_habs, hip_habs2)
KERNEL_FLOAT_BF16_UNARY_FUN(negate, hip_hneg, hip_hneg2)
#endif

#if KERNEL_FLOAT_BF16_OPS_SUPPORTED
#define KERNEL_FLOAT_BF16_BINARY_FUN(NAME, FUN1, FUN2) \
namespace ops { \
template<> \
@@ -4380,6 +4406,7 @@ KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2)
}; \
}

#if KERNEL_FLOAT_BF16_OPS_SUPPORTED
KERNEL_FLOAT_BF16_BINARY_FUN(add, __hadd, __hadd2)
KERNEL_FLOAT_BF16_BINARY_FUN(subtract, __hsub, __hsub2)
KERNEL_FLOAT_BF16_BINARY_FUN(multiply, __hmul, __hmul2)
