diff --git a/include/merlin_hashtable.cuh b/include/merlin_hashtable.cuh index 77b8867cc..9967a3863 100644 --- a/include/merlin_hashtable.cuh +++ b/include/merlin_hashtable.cuh @@ -153,12 +153,6 @@ class HashTable { using DeviceMemoryPool = MemoryPool>; using HostMemoryPool = MemoryPool>; -#if THRUST_VERSION >= 101600 - static constexpr auto thrust_par = thrust::cuda::par_nosync; -#else - static constexpr auto thrust_par = thrust::cuda::par; -#endif - public: /** * @brief Default constructor for the hash table class. @@ -323,7 +317,7 @@ class HashTable { reinterpret_cast(d_dst)); thrust::device_ptr d_src_offset_ptr(d_src_offset); - thrust::sort_by_key(thrust_par.on(stream), d_dst_ptr, d_dst_ptr + n, + thrust::sort_by_key(thrust::cuda::par_nosync.on(stream), d_dst_ptr, d_dst_ptr + n, d_src_offset_ptr, thrust::less()); } @@ -557,7 +551,7 @@ class HashTable { thrust::device_ptr dst_ptr(reinterpret_cast(dst)); thrust::device_ptr src_offset_ptr(src_offset); - thrust::sort_by_key(thrust_par.on(stream), dst_ptr, dst_ptr + n, + thrust::sort_by_key(thrust::cuda::par_nosync.on(stream), dst_ptr, dst_ptr + n, src_offset_ptr, thrust::less()); } @@ -651,7 +645,7 @@ class HashTable { reinterpret_cast(d_table_value_addrs)); thrust::device_ptr param_key_index_ptr(param_key_index); - thrust::sort_by_key(thrust_par.on(stream), table_value_ptr, + thrust::sort_by_key(thrust::cuda::par_nosync.on(stream), table_value_ptr, table_value_ptr + n, param_key_index_ptr, thrust::less()); } @@ -821,7 +815,7 @@ class HashTable { reinterpret_cast(d_dst)); thrust::device_ptr d_src_offset_ptr(d_src_offset); - thrust::sort_by_key(thrust_par.on(stream), d_dst_ptr, d_dst_ptr + n, + thrust::sort_by_key(thrust::cuda::par_nosync.on(stream), d_dst_ptr, d_dst_ptr + n, d_src_offset_ptr, thrust::less()); } @@ -922,7 +916,7 @@ class HashTable { reinterpret_cast(src)); thrust::device_ptr dst_offset_ptr(dst_offset); - thrust::sort_by_key(thrust_par.on(stream), src_ptr, src_ptr + n, + thrust::sort_by_key(thrust::cuda::par_nosync.on(stream), src_ptr, src_ptr + n, dst_offset_ptr, thrust::less()); } @@ -1275,7 +1269,7 @@ class HashTable { for (size_type start_i = 0; start_i < N; start_i += step) { size_type end_i = std::min(start_i + step, N); - h_size += thrust::reduce(thrust_par.on(stream), size_ptr + start_i, + h_size += thrust::reduce(thrust::cuda::par_nosync.on(stream), size_ptr + start_i, size_ptr + end_i, 0, thrust::plus()); } @@ -1589,7 +1583,7 @@ class HashTable { thrust::device_ptr size_ptr(table_->buckets_size); - int size = thrust::reduce(thrust_par.on(stream), size_ptr, size_ptr + N, 0, + int size = thrust::reduce(thrust::cuda::par_nosync.on(stream), size_ptr, size_ptr + N, 0, thrust::plus()); CudaCheckError();