Skip to content

Commit

Permalink
Simplify thrust invocation.
Browse files (browse the repository at this point in the history)
  • Loading branch information
bashimao committed May 18, 2023
1 parent a0de3ac commit a3b4ab4
Showing 1 changed file with 7 additions and 13 deletions.
20 changes: 7 additions & 13 deletions include/merlin_hashtable.cuh
Original file line number | Diff line number | Diff line change
Expand Up @@ -153,12 +153,6 @@ class HashTable {
using DeviceMemoryPool = MemoryPool<DeviceAllocator<char>>;
using HostMemoryPool = MemoryPool<HostAllocator<char>>;

- #if THRUST_VERSION >= 101600
- static constexpr auto thrust_par = thrust::cuda::par_nosync;
- #else
- static constexpr auto thrust_par = thrust::cuda::par;
- #endif

public:
/**
* @brief Default constructor for the hash table class.
Expand Down Expand Up @@ -323,7 +317,7 @@ class HashTable {
reinterpret_cast<uintptr_t*>(d_dst));
thrust::device_ptr<int> d_src_offset_ptr(d_src_offset);

- thrust::sort_by_key(thrust_par.on(stream), d_dst_ptr, d_dst_ptr + n,
+ thrust::sort_by_key(thrust::cuda::par_nosync.on(stream), d_dst_ptr, d_dst_ptr + n,
d_src_offset_ptr, thrust::less<uintptr_t>());
}

Expand Down Expand Up @@ -557,7 +551,7 @@ class HashTable {
thrust::device_ptr<uintptr_t> dst_ptr(reinterpret_cast<uintptr_t*>(dst));
thrust::device_ptr<int> src_offset_ptr(src_offset);

- thrust::sort_by_key(thrust_par.on(stream), dst_ptr, dst_ptr + n,
+ thrust::sort_by_key(thrust::cuda::par_nosync.on(stream), dst_ptr, dst_ptr + n,
src_offset_ptr, thrust::less<uintptr_t>());
}

Expand Down Expand Up @@ -651,7 +645,7 @@ class HashTable {
reinterpret_cast<uintptr_t*>(d_table_value_addrs));
thrust::device_ptr<int> param_key_index_ptr(param_key_index);

- thrust::sort_by_key(thrust_par.on(stream), table_value_ptr,
+ thrust::sort_by_key(thrust::cuda::par_nosync.on(stream), table_value_ptr,
table_value_ptr + n, param_key_index_ptr,
thrust::less<uintptr_t>());
}
Expand Down Expand Up @@ -821,7 +815,7 @@ class HashTable {
reinterpret_cast<uintptr_t*>(d_dst));
thrust::device_ptr<int> d_src_offset_ptr(d_src_offset);

- thrust::sort_by_key(thrust_par.on(stream), d_dst_ptr, d_dst_ptr + n,
+ thrust::sort_by_key(thrust::cuda::par_nosync.on(stream), d_dst_ptr, d_dst_ptr + n,
d_src_offset_ptr, thrust::less<uintptr_t>());
}

Expand Down Expand Up @@ -922,7 +916,7 @@ class HashTable {
reinterpret_cast<uintptr_t*>(src));
thrust::device_ptr<int> dst_offset_ptr(dst_offset);

- thrust::sort_by_key(thrust_par.on(stream), src_ptr, src_ptr + n,
+ thrust::sort_by_key(thrust::cuda::par_nosync.on(stream), src_ptr, src_ptr + n,
dst_offset_ptr, thrust::less<uintptr_t>());
}

Expand Down Expand Up @@ -1275,7 +1269,7 @@ class HashTable {

for (size_type start_i = 0; start_i < N; start_i += step) {
size_type end_i = std::min(start_i + step, N);
- h_size += thrust::reduce(thrust_par.on(stream), size_ptr + start_i,
+ h_size += thrust::reduce(thrust::cuda::par_nosync.on(stream), size_ptr + start_i,
size_ptr + end_i, 0, thrust::plus<int>());
}

Expand Down Expand Up @@ -1589,7 +1583,7 @@ class HashTable {

thrust::device_ptr<int> size_ptr(table_->buckets_size);

- int size = thrust::reduce(thrust_par.on(stream), size_ptr, size_ptr + N, 0,
+ int size = thrust::reduce(thrust::cuda::par_nosync.on(stream), size_ptr, size_ptr + N, 0,
thrust::plus<int>());

CudaCheckError();
Expand Down

0 comments on commit a3b4ab4

Please sign in to comment.