Skip to content

Commit 4962be5

Browse files
Update rocPrim usage for ROCm7 (#3100)
Co-authored-by: Jason Furmanek <[email protected]>
1 parent 98fbb59 commit 4962be5

File tree

1 file changed

+22
-1
lines changed

1 file changed

+22
-1
lines changed

tensorflow/core/kernels/gpu_prim.h

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,8 @@ namespace gpuprim = ::hipcub;
8282

8383
// Required for sorting Eigen::half and bfloat16.
8484
namespace rocprim {
85+
#if (TF_ROCM_VERSION >= 50200 && TF_ROCM_VERSION < 70000)
8586
namespace detail {
86-
#if (TF_ROCM_VERSION >= 50200)
8787
template <>
8888
struct float_bit_mask<Eigen::half> {
8989
static constexpr uint16_t sign_bit = 0x8000;
@@ -99,14 +99,35 @@ struct float_bit_mask<Eigen::bfloat16> {
9999
static constexpr uint16_t mantissa = 0x007F;
100100
using bit_type = uint16_t;
101101
};
102+
}; // namespace detail
103+
104+
#else
105+
namespace traits {
106+
template<>
107+
struct rocprim::traits::define<Eigen::half> {
108+
using float_bit_mask = rocprim::traits::float_bit_mask::values<uint16_t, 0x8000, 0x7C00, 0x03FF>;
109+
using is_arithmetic = rocprim::traits::is_arithmetic::values<true>;
110+
using number_format = rocprim::traits::number_format::values<traits::number_format::kind::floating_point_type>;
111+
};
112+
113+
template<>
114+
struct rocprim::traits::define<tsl::bfloat16> {
115+
using float_bit_mask = rocprim::traits::float_bit_mask::values<uint16_t, 0x8000, 0x7F80, 0x007F>;
116+
using is_arithmetic = rocprim::traits::is_arithmetic::values<true>;
117+
using number_format = rocprim::traits::number_format::values<traits::number_format::kind::floating_point_type>;
118+
};
119+
}; // namespace traits
102120
#endif
121+
#if (TF_ROCM_VERSION < 70000)
122+
namespace detail {
103123
template <>
104124
struct radix_key_codec_base<Eigen::half>
105125
: radix_key_codec_floating<Eigen::half, uint16_t> {};
106126
template <>
107127
struct radix_key_codec_base<tensorflow::bfloat16>
108128
: radix_key_codec_floating<tensorflow::bfloat16, uint16_t> {};
109129
}; // namespace detail
130+
#endif
110131
}; // namespace rocprim
111132

112133
#endif // TENSORFLOW_USE_ROCM

0 commit comments

Comments
 (0)