@@ -82,8 +82,8 @@ namespace gpuprim = ::hipcub;
8282
8383// Required for sorting Eigen::half and bfloat16.
8484namespace rocprim {
85+ #if (TF_ROCM_VERSION >= 50200 && TF_ROCM_VERSION < 70000)
8586namespace detail {
86- #if (TF_ROCM_VERSION >= 50200)
8787template <>
8888struct float_bit_mask <Eigen::half> {
8989 static constexpr uint16_t sign_bit = 0x8000 ;
@@ -99,14 +99,35 @@ struct float_bit_mask<Eigen::bfloat16> {
9999 static constexpr uint16_t mantissa = 0x007F ;
100100 using bit_type = uint16_t ;
101101};
102+ }; // namespace detail
103+
104+ #else
105+ namespace traits {
106+ template <>
107+ struct rocprim ::traits::define<Eigen::half> {
108+ using float_bit_mask = rocprim::traits::float_bit_mask::values<uint16_t , 0x8000 , 0x7C00 , 0x03FF >;
109+ using is_arithmetic = rocprim::traits::is_arithmetic::values<true >;
110+ using number_format = rocprim::traits::number_format::values<traits::number_format::kind::floating_point_type>;
111+ };
112+
113+ template <>
114+ struct rocprim ::traits::define<tsl::bfloat16> {
115+ using float_bit_mask = rocprim::traits::float_bit_mask::values<uint16_t , 0x8000 , 0x7F80 , 0x007F >;
116+ using is_arithmetic = rocprim::traits::is_arithmetic::values<true >;
117+ using number_format = rocprim::traits::number_format::values<traits::number_format::kind::floating_point_type>;
118+ };
119+ }; // namespace traits
102120#endif
121+ #if (TF_ROCM_VERSION < 70000)
122+ namespace detail {
103123template <>
104124struct radix_key_codec_base <Eigen::half>
105125 : radix_key_codec_floating<Eigen::half, uint16_t > {};
106126template <>
107127struct radix_key_codec_base <tensorflow::bfloat16>
108128 : radix_key_codec_floating<tensorflow::bfloat16, uint16_t > {};
109129}; // namespace detail
130+ #endif
110131}; // namespace rocprim
111132
112133#endif // TENSORFLOW_USE_ROCM
0 commit comments