From f94bd1068ba605130043a96f395084e168906826 Mon Sep 17 00:00:00 2001 From: stijn Date: Mon, 2 Dec 2024 18:49:12 +0100 Subject: [PATCH] Fix several issues related to HIP compilation for bfloat16 --- include/kernel_float/bf16.h | 30 ++++++++++++++++++++++-- include/kernel_float/binops.h | 6 +++-- include/kernel_float/macros.h | 3 +-- single_include/kernel_float.h | 43 ++++++++++++++++++++++++++++------- 4 files changed, 68 insertions(+), 14 deletions(-) diff --git a/include/kernel_float/bf16.h b/include/kernel_float/bf16.h index 22ea8b8..f89fc3c 100644 --- a/include/kernel_float/bf16.h +++ b/include/kernel_float/bf16.h @@ -60,7 +60,6 @@ struct allow_float_fallback { }; }; // namespace detail -#if KERNEL_FLOAT_BF16_OPS_SUPPORTED #define KERNEL_FLOAT_BF16_UNARY_FUN(NAME, FUN1, FUN2) \ namespace ops { \ template<> \ @@ -81,6 +80,7 @@ struct allow_float_fallback { }; \ } +#if KERNEL_FLOAT_BF16_OPS_SUPPORTED KERNEL_FLOAT_BF16_UNARY_FUN(sin, ::hsin, ::h2sin) KERNEL_FLOAT_BF16_UNARY_FUN(cos, ::hcos, ::h2cos) @@ -101,9 +101,34 @@ KERNEL_FLOAT_BF16_UNARY_FUN(ceil, ::hceil, ::h2ceil) KERNEL_FLOAT_BF16_UNARY_FUN(rint, ::hrint, ::h2rint) KERNEL_FLOAT_BF16_UNARY_FUN(trunc, ::htrunc, ::h2trunc) KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2) + +// For some reason, HIP struggles with the functions `__habs` and `__hneg`. We define them here using bitwise ops. +// For CUDA, we can just use the regular bfloat16 functions (see above). +#elif KERNEL_FLOAT_IS_HIP +KERNEL_FLOAT_INLINE __hip_bfloat16 hip_habs(const __hip_bfloat16 a) { + __hip_bfloat16 res = a; + res.data &= 0x7FFF; + return res; +} + +KERNEL_FLOAT_INLINE __hip_bfloat16 hip_hneg(const __hip_bfloat16 a) { + __hip_bfloat16 res = a; + res.data ^= 0x8000; + return res; +} + +KERNEL_FLOAT_INLINE __hip_bfloat162 hip_habs2(const __hip_bfloat162 a) { + return {hip_habs(a.x), hip_habs(a.y)}; +} + +KERNEL_FLOAT_INLINE __hip_bfloat162 hip_hneg2(const __hip_bfloat162 a) { + return {hip_hneg(a.x), hip_hneg(a.y)}; +} + +KERNEL_FLOAT_BF16_UNARY_FUN(abs, hip_habs, hip_habs2) +KERNEL_FLOAT_BF16_UNARY_FUN(negate, hip_hneg, hip_hneg2) #endif -#if KERNEL_FLOAT_BF16_OPS_SUPPORTED #define KERNEL_FLOAT_BF16_BINARY_FUN(NAME, FUN1, FUN2) \ namespace ops { \ template<> \ @@ -133,6 +158,7 @@ KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2) }; \ } +#if KERNEL_FLOAT_BF16_OPS_SUPPORTED KERNEL_FLOAT_BF16_BINARY_FUN(add, __hadd, __hadd2) KERNEL_FLOAT_BF16_BINARY_FUN(subtract, __hsub, __hsub2) KERNEL_FLOAT_BF16_BINARY_FUN(multiply, __hmul, __hmul2) diff --git a/include/kernel_float/binops.h b/include/kernel_float/binops.h index 75de26a..2e4c149 100644 --- a/include/kernel_float/binops.h +++ b/include/kernel_float/binops.h @@ -189,14 +189,16 @@ namespace ops { template struct min { KERNEL_FLOAT_INLINE T operator()(T left, T right) { - return left < right ? left : right; + auto cond = less {}(left, right); + return cast {}(cond) ? left : right; } }; template struct max { KERNEL_FLOAT_INLINE T operator()(T left, T right) { - return left > right ? left : right; + auto cond = greater {}(left, right); + return cast {}(cond) ? left : right; } }; diff --git a/include/kernel_float/macros.h b/include/kernel_float/macros.h index 68be6e5..88bbdbc 100644 --- a/include/kernel_float/macros.h +++ b/include/kernel_float/macros.h @@ -20,12 +20,11 @@ #elif defined(__HIPCC__) #define KERNEL_FLOAT_IS_HIP (1) #define KERNEL_FLOAT_DEVICE __attribute__((always_inline)) __device__ + #define KERNEL_FLOAT_INLINE __attribute__((always_inline)) __host__ __device__ #ifdef __HIP_DEVICE_COMPILE__ - #define KERNEL_FLOAT_INLINE __attribute__((always_inline)) __host__ __device__ #define KERNEL_FLOAT_IS_DEVICE (1) #else - #define KERNEL_FLOAT_INLINE __attribute__((always_inline)) __host__ __device__ #define KERNEL_FLOAT_IS_HOST (1) #endif diff --git a/single_include/kernel_float.h b/single_include/kernel_float.h index f6b2493..c77c7e6 100644 --- a/single_include/kernel_float.h +++ b/single_include/kernel_float.h @@ -16,8 +16,8 @@ //================================================================================ // this file has been auto-generated, do not modify its contents! -// date: 2024-12-02 10:59:19.296684 -// git hash: a2b08a56e31d1c9a6302c8a49c740cf56fcc1607 +// date: 2024-12-02 18:48:50.243676 +// git hash: 846de1f9aefaef76da15ebb5474080d531efaf38 //================================================================================ #ifndef KERNEL_FLOAT_MACROS_H @@ -42,12 +42,11 @@ #elif defined(__HIPCC__) #define KERNEL_FLOAT_IS_HIP (1) #define KERNEL_FLOAT_DEVICE __attribute__((always_inline)) __device__ + #define KERNEL_FLOAT_INLINE __attribute__((always_inline)) __host__ __device__ #ifdef __HIP_DEVICE_COMPILE__ - #define KERNEL_FLOAT_INLINE __attribute__((always_inline)) __host__ __device__ #define KERNEL_FLOAT_IS_DEVICE (1) #else - #define KERNEL_FLOAT_INLINE __attribute__((always_inline)) __host__ __device__ #define KERNEL_FLOAT_IS_HOST (1) #endif @@ -1875,14 +1874,16 @@ namespace ops { template struct min { KERNEL_FLOAT_INLINE T operator()(T left, T right) { - return left < right ? left : right; + auto cond = less {}(left, right); + return cast {}(cond) ? left : right; } }; template struct max { KERNEL_FLOAT_INLINE T operator()(T left, T right) { - return left > right ? left : right; + auto cond = greater {}(left, right); + return cast {}(cond) ? left : right; } }; @@ -4307,7 +4308,6 @@ struct allow_float_fallback { }; }; // namespace detail -#if KERNEL_FLOAT_BF16_OPS_SUPPORTED #define KERNEL_FLOAT_BF16_UNARY_FUN(NAME, FUN1, FUN2) \ namespace ops { \ template<> \ @@ -4328,6 +4328,7 @@ struct allow_float_fallback { }; \ } +#if KERNEL_FLOAT_BF16_OPS_SUPPORTED KERNEL_FLOAT_BF16_UNARY_FUN(sin, ::hsin, ::h2sin) KERNEL_FLOAT_BF16_UNARY_FUN(cos, ::hcos, ::h2cos) @@ -4348,9 +4349,34 @@ KERNEL_FLOAT_BF16_UNARY_FUN(ceil, ::hceil, ::h2ceil) KERNEL_FLOAT_BF16_UNARY_FUN(rint, ::hrint, ::h2rint) KERNEL_FLOAT_BF16_UNARY_FUN(trunc, ::htrunc, ::h2trunc) KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2) + +// For some reason, HIP struggles with the functions `__habs` and `__hneg`. We define them here using bitwise ops. +// For CUDA, we can just use the regular bfloat16 functions (see above). +#elif KERNEL_FLOAT_IS_HIP +KERNEL_FLOAT_INLINE __hip_bfloat16 hip_habs(const __hip_bfloat16 a) { + __hip_bfloat16 res = a; + res.data &= 0x7FFF; + return res; +} + +KERNEL_FLOAT_INLINE __hip_bfloat16 hip_hneg(const __hip_bfloat16 a) { + __hip_bfloat16 res = a; + res.data ^= 0x8000; + return res; +} + +KERNEL_FLOAT_INLINE __hip_bfloat162 hip_habs2(const __hip_bfloat162 a) { + return {hip_habs(a.x), hip_habs(a.y)}; +} + +KERNEL_FLOAT_INLINE __hip_bfloat162 hip_hneg2(const __hip_bfloat162 a) { + return {hip_hneg(a.x), hip_hneg(a.y)}; +} + +KERNEL_FLOAT_BF16_UNARY_FUN(abs, hip_habs, hip_habs2) +KERNEL_FLOAT_BF16_UNARY_FUN(negate, hip_hneg, hip_hneg2) #endif -#if KERNEL_FLOAT_BF16_OPS_SUPPORTED #define KERNEL_FLOAT_BF16_BINARY_FUN(NAME, FUN1, FUN2) \ namespace ops { \ template<> \ @@ -4380,6 +4406,7 @@ KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2) }; \ } +#if KERNEL_FLOAT_BF16_OPS_SUPPORTED KERNEL_FLOAT_BF16_BINARY_FUN(add, __hadd, __hadd2) KERNEL_FLOAT_BF16_BINARY_FUN(subtract, __hsub, __hsub2) KERNEL_FLOAT_BF16_BINARY_FUN(multiply, __hmul, __hmul2)