diff --git a/include/fp16/fp16.h b/include/fp16/fp16.h index ff8c970..efe0a0e 100644 --- a/include/fp16/fp16.h +++ b/include/fp16/fp16.h @@ -15,6 +15,7 @@ #endif #include +#include /* @@ -106,6 +107,13 @@ static inline uint32_t fp16_ieee_to_fp32_bits(uint16_t h) { * floating-point operations and bitcasts between integer and floating-point variables. */ static inline float fp16_ieee_to_fp32_value(uint16_t h) { +#if FP16_USE_FP16_TYPE + union { + uint16_t as_bits; + __fp16 as_value; + } fp16 = { h }; + return (float) fp16.as_value; +#else /* * Extend the half-precision floating-point number to 32 bits and shift to the upper part of the 32-bit word: * +---+-----+------------+-------------------+ @@ -211,6 +219,7 @@ static inline float fp16_ieee_to_fp32_value(uint16_t h) { const uint32_t result = sign | (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value)); return fp32_from_bits(result); +#endif } /* @@ -221,6 +230,13 @@ static inline float fp16_ieee_to_fp32_value(uint16_t h) { * floating-point operations and bitcasts between integer and floating-point variables. */ static inline uint16_t fp16_ieee_from_fp32_value(float f) { +#if FP16_USE_FP16_TYPE + union { + __fp16 as_value; + uint16_t as_bits; + } fp16 = { (__fp16) h }; + return fp16.as_bits; +#else #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) const float scale_to_inf = 0x1.0p+112f; const float scale_to_zero = 0x1.0p-110f; @@ -249,6 +265,7 @@ static inline uint16_t fp16_ieee_from_fp32_value(float f) { const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); const uint32_t nonsign = exp_bits + mantissa_bits; return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign); +#endif } /* diff --git a/include/fp16/macros.h b/include/fp16/macros.h new file mode 100644 index 0000000..4be4b0b --- /dev/null +++ b/include/fp16/macros.h @@ -0,0 +1,17 @@ +#pragma once +#ifndef FP16_MACROS_H +#define FP16_MACROS_H + + +#ifndef FP16_USE_FP16_TYPE + #if defined(__clang__) + #if defined(__F16C__) || defined(__aarch64__) + #define FP16_USE_FP16_TYPE 1 + #endif + #endif + #if !defined(FP16_USE_FP16_TYPE) + #define FP16_USE_FP16_TYPE 0 + #endif // !defined(FP16_USE_FP16_TYPE) +#endif // !defined(FP16_USE_FP16_TYPE) + +#endif /* FP16_MACROS_H */