From 70d375b20bed74c950e73154f6e1ed814fdeae52 Mon Sep 17 00:00:00 2001 From: Charles Hastings Date: Thu, 16 Jan 2025 12:55:59 -0800 Subject: [PATCH] A few last changes requested --- .../cugraph/detail/shuffle_wrappers.hpp | 2 +- cpp/src/detail/groupby_and_count.cuh | 1 - cpp/src/sampling/neighbor_sampling_impl.hpp | 6 ++--- cpp/src/structure/remove_multi_edges_impl.cuh | 23 ++++++++++--------- 4 files changed, 16 insertions(+), 16 deletions(-) diff --git a/cpp/include/cugraph/detail/shuffle_wrappers.hpp b/cpp/include/cugraph/detail/shuffle_wrappers.hpp index b85302d92f7..33ae38b8119 100644 --- a/cpp/include/cugraph/detail/shuffle_wrappers.hpp +++ b/cpp/include/cugraph/detail/shuffle_wrappers.hpp @@ -40,7 +40,7 @@ namespace detail { * @tparam edge_t Type of edge identifiers. Needs to be an integral type. * @tparam weight_t Type of edge weights. Needs to be a floating point type. * @tparam edge_type_t Type of edge type identifiers. Needs to be an integral type. - * @tparam edge_time_t The type of the edge time stamp. Needts to be an integral type. + * @tparam edge_time_t The type of the edge time stamp. Needs to be an integral type. * * @param[in] handle RAFT handle object to encapsulate resources (e.g. CUDA stream, communicator, * and handles to various CUDA libraries) to run graph algorithms. diff --git a/cpp/src/detail/groupby_and_count.cuh b/cpp/src/detail/groupby_and_count.cuh index df2e5553b6f..2117e3142f2 100644 --- a/cpp/src/detail/groupby_and_count.cuh +++ b/cpp/src/detail/groupby_and_count.cuh @@ -126,7 +126,6 @@ rmm::device_uvector groupby_and_count_edgelist_by_local_partition_id( }; if (edge_property_count == 0) { - // TODO: Consider flipping the outer if and doing edge_property_count test first... if (groupby_and_count_local_partition_by_minor) { result = cugraph::groupby_and_count(pair_first, pair_first + d_edgelist_majors.size(), diff --git a/cpp/src/sampling/neighbor_sampling_impl.hpp b/cpp/src/sampling/neighbor_sampling_impl.hpp index c853768f714..bbc0fbc17af 100644 --- a/cpp/src/sampling/neighbor_sampling_impl.hpp +++ b/cpp/src/sampling/neighbor_sampling_impl.hpp @@ -202,8 +202,8 @@ neighbor_sample_impl(raft::handle_t const& handle, ? std::make_optional(rmm::device_uvector(0, handle.get_stream())) : std::nullopt; - for (edge_type_t edge_type_id = 0; edge_type_id < num_edge_types; edge_type_id++) { - auto k_level = fan_out[(hop * num_edge_types) + edge_type_id]; + for (edge_type_t edge_type = 0; edge_type < num_edge_types; edge_type++) { + auto k_level = fan_out[(hop * num_edge_types) + edge_type]; rmm::device_uvector srcs(0, handle.get_stream()); rmm::device_uvector dsts(0, handle.get_stream()); std::optional> weights{std::nullopt}; @@ -212,7 +212,7 @@ neighbor_sample_impl(raft::handle_t const& handle, std::optional> labels{std::nullopt}; if (num_edge_types > 1) { - modified_graph_view.attach_edge_mask(edge_masks_vector[edge_type_id].view()); + modified_graph_view.attach_edge_mask(edge_masks_vector[edge_type].view()); } if (k_level > 0) { diff --git a/cpp/src/structure/remove_multi_edges_impl.cuh b/cpp/src/structure/remove_multi_edges_impl.cuh index c90462343c1..23ce6055348 100644 --- a/cpp/src/structure/remove_multi_edges_impl.cuh +++ b/cpp/src/structure/remove_multi_edges_impl.cuh @@ -99,18 +99,19 @@ std::tuple, rmm::device_uvector> group_m return std::make_tuple(std::move(edgelist_srcs), std::move(edgelist_dsts)); } -template -std:: - tuple, rmm::device_uvector, rmm::device_uvector> - group_multi_edges(raft::handle_t const& handle, - rmm::device_uvector&& edgelist_srcs, - rmm::device_uvector&& edgelist_dsts, - rmm::device_uvector&& edgelist_values, - size_t mem_frugal_threshold, - bool keep_min_value_edge) +template +std::tuple, + rmm::device_uvector, + dataframe_buffer_type_t> +group_multi_edges(raft::handle_t const& handle, + rmm::device_uvector&& edgelist_srcs, + rmm::device_uvector&& edgelist_dsts, + dataframe_buffer_type_t&& edgelist_values, + size_t mem_frugal_threshold, + bool keep_min_value_edge) { auto pair_first = thrust::make_zip_iterator(edgelist_srcs.begin(), edgelist_dsts.begin()); - auto value_first = edgelist_values.begin(); + auto value_first = get_dataframe_buffer_begin(edgelist_values); auto edge_first = thrust::make_zip_iterator(pair_first, value_first); if (edgelist_srcs.size() > mem_frugal_threshold) { @@ -152,7 +153,7 @@ std:: thrust::sort_by_key(handle.get_thrust_policy(), pair_first, pair_first + edgelist_srcs.size(), - edgelist_values.begin()); + get_dataframe_buffer_begin(edgelist_values)); } }