Skip to content

Commit

Permalink
fix style
Browse files Browse the repository at this point in the history
  • Loading branch information
jnke2016 committed Nov 27, 2024
1 parent 88b35ba commit f8c576a
Showing 1 changed file with 102 additions and 109 deletions.
211 changes: 102 additions & 109 deletions cpp/src/sampling/neighbor_sampling_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,30 +106,27 @@ neighbor_sample_impl(raft::handle_t const& handle,
edge_masks_vector{};
graph_view_t<vertex_t, edge_t, false, multi_gpu> modified_graph_view = graph_view;
edge_masks_vector.reserve(num_edge_types);

label_t num_unique_labels = 0;

if (starting_vertex_labels) {
// Find the number of unique lables
std::optional<rmm::device_uvector<label_t>> cp_starting_vertex_labels{std::nullopt};
cp_starting_vertex_labels = rmm::device_uvector<label_t>(starting_vertex_labels->size(), handle.get_stream());

thrust::copy(
handle.get_thrust_policy(),
starting_vertex_labels->begin(),
starting_vertex_labels->end(),
cp_starting_vertex_labels->begin());

thrust::sort(
handle.get_thrust_policy(),
cp_starting_vertex_labels->begin(),
cp_starting_vertex_labels->end());

cp_starting_vertex_labels =
rmm::device_uvector<label_t>(starting_vertex_labels->size(), handle.get_stream());

thrust::copy(handle.get_thrust_policy(),
starting_vertex_labels->begin(),
starting_vertex_labels->end(),
cp_starting_vertex_labels->begin());

thrust::sort(handle.get_thrust_policy(),
cp_starting_vertex_labels->begin(),
cp_starting_vertex_labels->end());

num_unique_labels = thrust::unique_count(handle.get_thrust_policy(),
cp_starting_vertex_labels->begin(),
cp_starting_vertex_labels->end());


}

if (num_edge_types > 1) {
Expand Down Expand Up @@ -171,7 +168,6 @@ neighbor_sample_impl(raft::handle_t const& handle,
? (fan_out.size() / num_edge_types)
: ((fan_out.size() / num_edge_types) + 1);


auto level_result_weight_vectors =
edge_weight_view ? std::make_optional(std::vector<rmm::device_uvector<weight_t>>{})
: std::nullopt;
Expand All @@ -192,13 +188,14 @@ neighbor_sample_impl(raft::handle_t const& handle,
: std::nullopt;
auto level_result_edge_id =
edge_id_view ? std::make_optional(rmm::device_uvector<edge_t>(0, handle.get_stream()))
: std::nullopt;
: std::nullopt;
auto level_result_edge_type =
edge_type_view ? std::make_optional(rmm::device_uvector<edge_type_t>(0, handle.get_stream()))
: std::nullopt;
: std::nullopt;
auto level_result_label =
starting_vertex_labels ? std::make_optional(rmm::device_uvector<label_t>(0, handle.get_stream()))
: std::nullopt;
starting_vertex_labels
? std::make_optional(rmm::device_uvector<label_t>(0, handle.get_stream()))
: std::nullopt;

if (level_result_weight_vectors) { (*level_result_weight_vectors).reserve(num_hops); }
if (level_result_edge_id_vectors) { (*level_result_edge_id_vectors).reserve(num_hops); }
Expand Down Expand Up @@ -267,52 +264,44 @@ neighbor_sample_impl(raft::handle_t const& handle,
level_result_src.resize(old_size + srcs.size(), handle.get_stream());
level_result_dst.resize(old_size + srcs.size(), handle.get_stream());

raft::copy(
level_result_src.begin() + old_size, srcs.begin(), srcs.size(), handle.get_stream());

raft::copy(level_result_src.begin() + old_size,
srcs.begin(),
srcs.size(),
handle.get_stream());

raft::copy(level_result_dst.begin() + old_size,
dsts.begin(),
srcs.size(),
handle.get_stream());
raft::copy(
level_result_dst.begin() + old_size, dsts.begin(), srcs.size(), handle.get_stream());

if (weights) {
(*level_result_weight).resize(old_size + srcs.size(), handle.get_stream());

raft::copy(level_result_weight->begin() + old_size,
weights->begin(),
srcs.size(),
handle.get_stream());
weights->begin(),
srcs.size(),
handle.get_stream());
}




if (edge_ids) {
(*level_result_edge_id).resize(old_size + srcs.size(), handle.get_stream());
raft::copy(level_result_edge_id->begin() + old_size,
edge_ids->begin(),
srcs.size(),
handle.get_stream());
edge_ids->begin(),
srcs.size(),
handle.get_stream());
}
if (edge_types) {
(*level_result_edge_type).resize(old_size + srcs.size(), handle.get_stream());


raft::copy(level_result_edge_type->begin() + old_size,
edge_types->begin(),
srcs.size(),
handle.get_stream());
edge_types->begin(),
srcs.size(),
handle.get_stream());
}

if (labels) {
(*level_result_label).resize(old_size + srcs.size(), handle.get_stream());

raft::copy(level_result_label->begin() + old_size,
labels->begin(),
srcs.size(),
handle.get_stream());
labels->begin(),
srcs.size(),
handle.get_stream());
}

if (num_edge_types > 1) { modified_graph_view.clear_edge_mask(); }
Expand All @@ -322,10 +311,18 @@ neighbor_sample_impl(raft::handle_t const& handle,
level_result_src_vectors.push_back(std::move(level_result_src));
level_result_dst_vectors.push_back(std::move(level_result_dst));

if (level_result_weight) { (*level_result_weight_vectors).push_back(std::move(*level_result_weight)); }
if (level_result_edge_id) { (*level_result_edge_id_vectors).push_back(std::move(*level_result_edge_id)); }
if (level_result_edge_type) { (*level_result_edge_type_vectors).push_back(std::move(*level_result_edge_type)); }
if (level_result_label) { (*level_result_label_vectors).push_back(std::move(*level_result_label)); }
if (level_result_weight) {
(*level_result_weight_vectors).push_back(std::move(*level_result_weight));
}
if (level_result_edge_id) {
(*level_result_edge_id_vectors).push_back(std::move(*level_result_edge_id));
}
if (level_result_edge_type) {
(*level_result_edge_type_vectors).push_back(std::move(*level_result_edge_type));
}
if (level_result_label) {
(*level_result_label_vectors).push_back(std::move(*level_result_label));
}

// FIXME: We should modify vertex_partition_range_lasts to return a raft::host_span
// rather than making a copy.
Expand All @@ -337,10 +334,9 @@ neighbor_sample_impl(raft::handle_t const& handle,
starting_vertex_labels,
raft::device_span<vertex_t const>{level_result_dst_vectors.back().data(),
level_result_dst_vectors.back().size()},
frontier_vertex_labels
? std::make_optional(raft::device_span<label_t const>(
level_result_label->data(), level_result_label->size()))
: std::nullopt,
frontier_vertex_labels ? std::make_optional(raft::device_span<label_t const>(
level_result_label->data(), level_result_label->size()))
: std::nullopt,
std::move(vertex_used_as_source),
modified_graph_view.local_vertex_partition_view(),
vertex_partition_range_lasts,
Expand Down Expand Up @@ -465,62 +461,59 @@ neighbor_sample_impl(raft::handle_t const& handle,
if (result_labels) {
cp_result_labels = rmm::device_uvector<label_t>(result_labels->size(), handle.get_stream());

thrust::copy(
handle.get_thrust_policy(),
result_labels->begin(),
result_labels->end(),
cp_result_labels->begin());
thrust::copy(handle.get_thrust_policy(),
result_labels->begin(),
result_labels->end(),
cp_result_labels->begin());
}

std::tie(result_srcs,
result_dsts,
result_weights,
result_edge_ids,
result_edge_types,
result_hops,
result_labels,
result_offsets) = detail::shuffle_and_organize_output(handle,
std::move(result_srcs),
std::move(result_dsts),
std::move(result_weights),
std::move(result_edge_ids),
std::move(result_edge_types),
std::move(result_hops),
std::move(result_labels),
label_to_output_comm_rank);

if (result_labels && (result_offsets->size() != num_unique_labels + 1)) {
result_offsets = rmm::device_uvector<size_t>(num_unique_labels + 1, handle.get_stream());

// Sort labels
thrust::sort(handle.get_thrust_policy(), cp_result_labels->begin(), cp_result_labels->end());

thrust::transform(handle.get_thrust_policy(),
thrust::make_counting_iterator<edge_t>(0),
thrust::make_counting_iterator<edge_t>(result_offsets->size() - 1),
result_offsets->begin() + 1,
[result_labels = raft::device_span<label_t const>(
cp_result_labels->data(), cp_result_labels->size())] __device__(auto idx) {
auto itr_lower = thrust::lower_bound(
thrust::seq, result_labels.begin(), result_labels.end(), idx);

auto itr_upper = thrust::upper_bound(
thrust::seq, result_labels.begin(), result_labels.end(), idx);

auto sampled_label_size = thrust::distance(itr_lower, itr_upper);

return sampled_label_size;
});

// Run inclusive scan
thrust::inclusive_scan(handle.get_thrust_policy(),
result_offsets->begin() + 1,
result_offsets->end(),
result_offsets->begin() + 1);
}

std::tie(result_srcs, result_dsts, result_weights, result_edge_ids,
result_edge_types, result_hops, result_labels, result_offsets)
= detail::shuffle_and_organize_output(handle,
std::move(result_srcs),
std::move(result_dsts),
std::move(result_weights),
std::move(result_edge_ids),
std::move(result_edge_types),
std::move(result_hops),
std::move(result_labels),
label_to_output_comm_rank);

if (result_labels && (result_offsets->size() != num_unique_labels + 1)) {
result_offsets = rmm::device_uvector<size_t>(num_unique_labels + 1, handle.get_stream());

// Sort labels
thrust::sort(
handle.get_thrust_policy(),
cp_result_labels->begin(),
cp_result_labels->end());

thrust::transform(
handle.get_thrust_policy(),
thrust::make_counting_iterator<edge_t>(0),
thrust::make_counting_iterator<edge_t>(result_offsets->size() - 1),
result_offsets->begin() + 1,
[
result_labels = raft::device_span<label_t const>(
cp_result_labels->data(),
cp_result_labels->size())
] __device__(auto idx) {
auto itr_lower = thrust::lower_bound(
thrust::seq, result_labels.begin(), result_labels.end(), idx);

auto itr_upper = thrust::upper_bound(
thrust::seq, result_labels.begin(), result_labels.end(), idx);

auto sampled_label_size = thrust::distance(itr_lower, itr_upper);

return sampled_label_size;
});

// Run inclusive scan
thrust::inclusive_scan(handle.get_thrust_policy(),
result_offsets->begin() + 1,
result_offsets->end(),
result_offsets->begin() + 1);
}

return std::make_tuple(std::move(result_srcs),
std::move(result_dsts),
std::move(result_weights),
Expand Down

0 comments on commit f8c576a

Please sign in to comment.