Skip to content

Commit

Permalink
Simplify binary operation definition for fp16 and bf16
Browse files Browse the repository at this point in the history
  • Loading branch information
stijnh committed Oct 12, 2023
1 parent 283edce commit 3da5ba0
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 38 deletions.
29 changes: 1 addition & 28 deletions include/kernel_float/bf16.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,16 +176,7 @@ KERNEL_FLOAT_BF16_UNARY_FUN(fast_sin, ::hsin, ::h2sin)
}; \
}
#else
#define KERNEL_FLOAT_BF16_BINARY_FUN(NAME, FUN1, FUN2) \
namespace ops { \
template<> \
struct NAME<__nv_bfloat16> { \
KERNEL_FLOAT_INLINE __nv_bfloat16 \
operator()(__nv_bfloat16 left, __nv_bfloat16 right) const { \
return __nv_bfloat16(ops::NAME<float> {}(float(left), float(right))); \
} \
}; \
}
#define KERNEL_FLOAT_BF16_BINARY_FUN(NAME, FUN1, FUN2)
#endif

KERNEL_FLOAT_BF16_BINARY_FUN(add, __hadd, __hadd2)
Expand All @@ -205,20 +196,6 @@ KERNEL_FLOAT_BF16_BINARY_FUN(greater, __hgt, __hgt2)
KERNEL_FLOAT_BF16_BINARY_FUN(greater_equal, __hge, __hgt2)

namespace ops {
template<typename T>
struct cast<T, __nv_bfloat16> {
KERNEL_FLOAT_INLINE __nv_bfloat16 operator()(T input) {
return __float2bfloat16(ops::cast<T, float> {}(input));
};
};

template<typename T>
struct cast<__nv_bfloat16, T> {
KERNEL_FLOAT_INLINE T operator()(__nv_bfloat16 input) {
return ops::cast<float, T> {}(__bfloat162float(input));
};
};

template<>
struct cast<double, __nv_bfloat16> {
KERNEL_FLOAT_INLINE __nv_bfloat16 operator()(double input) {
Expand Down Expand Up @@ -340,10 +317,6 @@ struct dot_impl<__nv_bfloat16, N> {
#include "fp16.h"

namespace kernel_float {
#if KERNEL_FLOAT_CUDA_ARCH >= 800
KERNEL_FLOAT_BF16_CAST(__half, __float2bfloat16(input), __bfloat162float(input));
#endif

template<>
struct promote_type<__nv_bfloat16, __half> {
using type = float;
Expand Down
2 changes: 1 addition & 1 deletion include/kernel_float/binops.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ KERNEL_FLOAT_INLINE zip_common_type<F, L, R> zip_common(F fun, const L& left, co
template<typename T> \
struct NAME { \
KERNEL_FLOAT_INLINE T operator()(T left, T right) { \
return T(EXPR); \
return ops::cast<decltype(EXPR), T> {}(EXPR); \
} \
}; \
} \
Expand Down
10 changes: 1 addition & 9 deletions include/kernel_float/fp16.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,15 +170,7 @@ KERNEL_FLOAT_FP16_UNARY_FUN(fast_sin, ::hsin, ::h2sin)
}; \
}
#else
#define KERNEL_FLOAT_FP16_BINARY_FUN(NAME, FUN1, FUN2) \
namespace ops { \
template<> \
struct NAME<__half> { \
KERNEL_FLOAT_INLINE __half operator()(__half left, __half right) const { \
return __half(ops::NAME<float> {}(float(left), float(right))); \
} \
}; \
}
#define KERNEL_FLOAT_FP16_BINARY_FUN(NAME, FUN1, FUN2)
#endif

KERNEL_FLOAT_FP16_BINARY_FUN(add, __hadd, __hadd2)
Expand Down

0 comments on commit 3da5ba0

Please sign in to comment.