Skip to content

Commit

Permalink
A few last changes requested
Browse files Browse the repository at this point in the history
  • Loading branch information
ChuckHastings committed Jan 16, 2025
1 parent 1bb002d commit 70d375b
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 16 deletions.
2 changes: 1 addition & 1 deletion cpp/include/cugraph/detail/shuffle_wrappers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ namespace detail {
* @tparam edge_t Type of edge identifiers. Needs to be an integral type.
* @tparam weight_t Type of edge weights. Needs to be a floating point type.
* @tparam edge_type_t Type of edge type identifiers. Needs to be an integral type.
* @tparam edge_time_t The type of the edge time stamp. Needts to be an integral type.
* @tparam edge_time_t The type of the edge time stamp. Needs to be an integral type.
*
* @param[in] handle RAFT handle object to encapsulate resources (e.g. CUDA stream, communicator,
* and handles to various CUDA libraries) to run graph algorithms.
Expand Down
1 change: 0 additions & 1 deletion cpp/src/detail/groupby_and_count.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ rmm::device_uvector<size_t> groupby_and_count_edgelist_by_local_partition_id(
};

if (edge_property_count == 0) {
// TODO: Consider flipping the outer if and doing edge_property_count test first...
if (groupby_and_count_local_partition_by_minor) {
result = cugraph::groupby_and_count(pair_first,
pair_first + d_edgelist_majors.size(),
Expand Down
6 changes: 3 additions & 3 deletions cpp/src/sampling/neighbor_sampling_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,8 @@ neighbor_sample_impl(raft::handle_t const& handle,
? std::make_optional(rmm::device_uvector<label_t>(0, handle.get_stream()))
: std::nullopt;

for (edge_type_t edge_type_id = 0; edge_type_id < num_edge_types; edge_type_id++) {
auto k_level = fan_out[(hop * num_edge_types) + edge_type_id];
for (edge_type_t edge_type = 0; edge_type < num_edge_types; edge_type++) {
auto k_level = fan_out[(hop * num_edge_types) + edge_type];
rmm::device_uvector<vertex_t> srcs(0, handle.get_stream());
rmm::device_uvector<vertex_t> dsts(0, handle.get_stream());
std::optional<rmm::device_uvector<weight_t>> weights{std::nullopt};
Expand All @@ -212,7 +212,7 @@ neighbor_sample_impl(raft::handle_t const& handle,
std::optional<rmm::device_uvector<int32_t>> labels{std::nullopt};

if (num_edge_types > 1) {
modified_graph_view.attach_edge_mask(edge_masks_vector[edge_type_id].view());
modified_graph_view.attach_edge_mask(edge_masks_vector[edge_type].view());
}

if (k_level > 0) {
Expand Down
23 changes: 12 additions & 11 deletions cpp/src/structure/remove_multi_edges_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -99,18 +99,19 @@ std::tuple<rmm::device_uvector<vertex_t>, rmm::device_uvector<vertex_t>> group_m
return std::make_tuple(std::move(edgelist_srcs), std::move(edgelist_dsts));
}

template <typename vertex_t, typename value_t>
std::
tuple<rmm::device_uvector<vertex_t>, rmm::device_uvector<vertex_t>, rmm::device_uvector<value_t>>
group_multi_edges(raft::handle_t const& handle,
rmm::device_uvector<vertex_t>&& edgelist_srcs,
rmm::device_uvector<vertex_t>&& edgelist_dsts,
rmm::device_uvector<value_t>&& edgelist_values,
size_t mem_frugal_threshold,
bool keep_min_value_edge)
template <typename vertex_t, typename edge_value_t>
std::tuple<rmm::device_uvector<vertex_t>,
rmm::device_uvector<vertex_t>,
dataframe_buffer_type_t<edge_value_t>>
group_multi_edges(raft::handle_t const& handle,
rmm::device_uvector<vertex_t>&& edgelist_srcs,
rmm::device_uvector<vertex_t>&& edgelist_dsts,
dataframe_buffer_type_t<edge_value_t>&& edgelist_values,
size_t mem_frugal_threshold,
bool keep_min_value_edge)
{
auto pair_first = thrust::make_zip_iterator(edgelist_srcs.begin(), edgelist_dsts.begin());
auto value_first = edgelist_values.begin();
auto value_first = get_dataframe_buffer_begin(edgelist_values);
auto edge_first = thrust::make_zip_iterator(pair_first, value_first);

if (edgelist_srcs.size() > mem_frugal_threshold) {
Expand Down Expand Up @@ -152,7 +153,7 @@ std::
thrust::sort_by_key(handle.get_thrust_policy(),
pair_first,
pair_first + edgelist_srcs.size(),
edgelist_values.begin());
get_dataframe_buffer_begin(edgelist_values));
}
}

Expand Down

0 comments on commit 70d375b

Please sign in to comment.