diff --git a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.cuh b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.cuh index 8351e046c..9e6130f40 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.cuh @@ -59,9 +59,13 @@ transpose_embedding_input( int end_bit = sizeof(KeyT) * 8, \ cudaStream_t stream = 0) +DECL_RADIX_SORT_PAIRS_FN(int64_t, int32_t); +DECL_RADIX_SORT_PAIRS_FN(int64_t, int64_t); DECL_RADIX_SORT_PAIRS_FN(int64_t, float); DECL_RADIX_SORT_PAIRS_FN(int64_t, double); -DECL_RADIX_SORT_PAIRS_FN(int64_t, int64_t); -DECL_RADIX_SORT_PAIRS_FN(int64_t, int32_t); +DECL_RADIX_SORT_PAIRS_FN(int32_t, int32_t); +DECL_RADIX_SORT_PAIRS_FN(int32_t, int64_t); +DECL_RADIX_SORT_PAIRS_FN(int32_t, float); +DECL_RADIX_SORT_PAIRS_FN(int32_t, double); #undef DECL_RADIX_SORT_PAIRS_FN diff --git a/fbgemm_gpu/src/split_embeddings_utils/radix_sort_pairs.cu b/fbgemm_gpu/src/split_embeddings_utils/radix_sort_pairs.cu index 50d9757d2..93dab81a4 100644 --- a/fbgemm_gpu/src/split_embeddings_utils/radix_sort_pairs.cu +++ b/fbgemm_gpu/src/split_embeddings_utils/radix_sort_pairs.cu @@ -77,7 +77,11 @@ using namespace fbgemm_gpu; } #endif +DEF_RADIX_SORT_PAIRS_FN(int64_t, int32_t); +DEF_RADIX_SORT_PAIRS_FN(int64_t, int64_t); DEF_RADIX_SORT_PAIRS_FN(int64_t, float); DEF_RADIX_SORT_PAIRS_FN(int64_t, double); -DEF_RADIX_SORT_PAIRS_FN(int64_t, int64_t); -DEF_RADIX_SORT_PAIRS_FN(int64_t, int32_t); +DEF_RADIX_SORT_PAIRS_FN(int32_t, int32_t); +DEF_RADIX_SORT_PAIRS_FN(int32_t, int64_t); +DEF_RADIX_SORT_PAIRS_FN(int32_t, float); +DEF_RADIX_SORT_PAIRS_FN(int32_t, double); diff --git a/fbgemm_gpu/src/split_embeddings_utils/transpose_embedding_input.cu b/fbgemm_gpu/src/split_embeddings_utils/transpose_embedding_input.cu index 83a06d78a..ea2ea7d66 100644 --- a/fbgemm_gpu/src/split_embeddings_utils/transpose_embedding_input.cu +++ b/fbgemm_gpu/src/split_embeddings_utils/transpose_embedding_input.cu @@ -63,7 +63,7 @@ inline at::Tensor asynchronous_complete_cumsum(at::Tensor t_in) { template __global__ __launch_bounds__(kMaxThreads) void linearize_index_kernel( - const pta::PackedTensorAccessor32 + const pta::PackedTensorAccessor32 hash_size_cumsum, const pta::PackedTensorAccessor32 indices, @@ -79,7 +79,7 @@ __global__ __launch_bounds__(kMaxThreads) void linearize_index_kernel( // Use a raw pointer to avoid creating dummy PackedTensorAccessor const uint32_t* const __restrict__ vbe_b_t_map, FixedDivisor fd) { - const int32_t T = hash_size_cumsum.size(0) - 1; + const auto T = hash_size_cumsum.size(0) - 1; auto b_t = blockIdx.x * blockDim.x + threadIdx.x; int32_t b; int32_t t; @@ -97,21 +97,20 @@ __global__ __launch_bounds__(kMaxThreads) void linearize_index_kernel( } const index_t hash_offset = valid ? hash_size_cumsum[t] : -1; - const index_t indices_start = valid ? offsets[b_t] : -1; - const int32_t L = valid ? offsets[b_t + 1] - indices_start : 0; + const auto indices_start = valid ? offsets[b_t] : -1; + const auto L = valid ? offsets[b_t + 1] - indices_start : 0; const int32_t lane_id = threadIdx.x % fbgemm_gpu::kWarpSize; // Compile-time conditional if (nobag) { for (int32_t j = 0; j < fbgemm_gpu::kWarpSize; ++j) { - const index_t indices_start_warp = - fbgemm_gpu::shfl_sync(indices_start, j); - const int32_t t_warp = fbgemm_gpu::shfl_sync(t, j); - const int32_t L_warp = fbgemm_gpu::shfl_sync(L, j); + const auto indices_start_warp = fbgemm_gpu::shfl_sync(indices_start, j); + const auto t_warp = fbgemm_gpu::shfl_sync(t, j); + const auto L_warp = fbgemm_gpu::shfl_sync(L, j); const index_t hash_offset_warp = fbgemm_gpu::shfl_sync(hash_offset, j); for (int32_t i = lane_id; i < L_warp; i += fbgemm_gpu::kWarpSize) { const index_t idx = __ldg(&indices[indices_start_warp + i]); - const int64_t l_t = (indices_start_warp + i) * T + t_warp; + const auto l_t = (indices_start_warp + i) * T + t_warp; infos[indices_start_warp + i] = l_t; linear_indices[indices_start_warp + i] = hash_offset_warp + idx; } @@ -124,10 +123,9 @@ __global__ __launch_bounds__(kMaxThreads) void linearize_index_kernel( reinterpret_cast(&b)[0]; } for (int32_t j = 0; j < fbgemm_gpu::kWarpSize; ++j) { - const index_t indices_start_warp = - fbgemm_gpu::shfl_sync(indices_start, j); - const uint32_t info_warp = fbgemm_gpu::shfl_sync(info, j); - const int32_t L_warp = fbgemm_gpu::shfl_sync(L, j); + const auto indices_start_warp = fbgemm_gpu::shfl_sync(indices_start, j); + const auto info_warp = fbgemm_gpu::shfl_sync(info, j); + const auto L_warp = fbgemm_gpu::shfl_sync(L, j); const index_t hash_offset_warp = fbgemm_gpu::shfl_sync(hash_offset, j); for (int32_t i = lane_id; i < L_warp; i += fbgemm_gpu::kWarpSize) { const index_t idx = __ldg(&indices[indices_start_warp + i]); @@ -142,7 +140,7 @@ __global__ __launch_bounds__(kMaxThreads) void linearize_index_kernel( template __global__ __launch_bounds__(kMaxThreads) void linearize_index_index_select_kernel( - const pta::PackedTensorAccessor32 + const pta::PackedTensorAccessor32 hash_size_cumsum, const pta::PackedTensorAccessor32 indices, @@ -153,7 +151,7 @@ __launch_bounds__(kMaxThreads) void linearize_index_index_select_kernel( linear_indices, FixedDivisor fd, int32_t fixed_L_per_warp) { - const int32_t T = hash_size_cumsum.size(0) - 1; + const auto T = hash_size_cumsum.size(0) - 1; auto b_t = blockIdx.x * blockDim.x + threadIdx.x; int32_t b; int32_t t; @@ -258,7 +256,7 @@ transpose_embedding_input( kMaxThreads, \ 0, \ at::cuda::getCurrentCUDAStream()>>>( \ - MAKE_PTA_WITH_NAME(func_name, hash_size_cumsum, index_t, 1, 32), \ + MAKE_PTA_WITH_NAME(func_name, hash_size_cumsum, int64_t, 1, 32), \ MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32), \ MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32), \ MAKE_PTA_WITH_NAME(func_name, infos, INFO_ACC_T, 1, 32), \ @@ -296,7 +294,7 @@ transpose_embedding_input( 0, at::cuda::getCurrentCUDAStream()>>>( MAKE_PTA_WITH_NAME( - func_name, hash_size_cumsum, index_t, 1, 32), + func_name, hash_size_cumsum, int64_t, 1, 32), MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32), MAKE_PTA_WITH_NAME( func_name, total_L_offsets.value(), index_t, 1, 32),