diff --git a/cpp/src/c_api/neighbor_sampling.cpp b/cpp/src/c_api/neighbor_sampling.cpp index 37982eab82..2cc9646309 100644 --- a/cpp/src/c_api/neighbor_sampling.cpp +++ b/cpp/src/c_api/neighbor_sampling.cpp @@ -880,7 +880,6 @@ struct neighbor_sampling_functor : public cugraph::c_api::abstract_functor { handle_.get_stream()); std::optional> start_vertex_labels{std::nullopt}; - std::optional> local_label_to_comm_rank{std::nullopt}; std::optional> label_to_comm_rank{ std::nullopt}; // global after allgatherv @@ -932,12 +931,13 @@ struct neighbor_sampling_functor : public cugraph::c_api::abstract_functor { handle_.get_stream(), raft::device_span{unique_labels.data(), unique_labels.size()}); - (*local_label_to_comm_rank).resize(num_unique_labels, handle_.get_stream()); + rmm::device_uvector local_label_to_comm_rank(num_unique_labels, + handle_.get_stream()); cugraph::detail::scalar_fill( handle_.get_stream(), - (*local_label_to_comm_rank).begin(), // This should be rename to rank - (*local_label_to_comm_rank).size(), + local_label_to_comm_rank.begin(), // This should be rename to rank + local_label_to_comm_rank.size(), label_t{handle_.get_comms().get_rank()}); // Perform allgather to get global_label_to_comm_rank_d_vector @@ -948,11 +948,13 @@ struct neighbor_sampling_functor : public cugraph::c_api::abstract_functor { std::exclusive_scan( recvcounts.begin(), recvcounts.end(), displacements.begin(), size_t{0}); - (*label_to_comm_rank) - .resize(displacements.back() + recvcounts.back(), handle_.get_stream()); + rmm::device_uvector tmp_label_to_comm_rank( + displacements.back() + recvcounts.back(), handle_.get_stream()); + + label_to_comm_rank = std::move(tmp_label_to_comm_rank); cugraph::device_allgatherv(handle_.get_comms(), - (*local_label_to_comm_rank).begin(), + local_label_to_comm_rank.begin(), (*label_to_comm_rank).begin(), recvcounts, displacements,