From 6b140f28ed847b1426560ded26a4245627e0a528 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih BALIN Date: Mon, 29 Apr 2024 00:29:41 -0400 Subject: [PATCH] [GraphBolt] Hetero CPU sampling bug fix. (#7369) --- graphbolt/src/fused_csc_sampling_graph.cc | 56 ++++++++--------------- 1 file changed, 20 insertions(+), 36 deletions(-) diff --git a/graphbolt/src/fused_csc_sampling_graph.cc b/graphbolt/src/fused_csc_sampling_graph.cc index 18629b05e76a..b5c2587f693a 100644 --- a/graphbolt/src/fused_csc_sampling_graph.cc +++ b/graphbolt/src/fused_csc_sampling_graph.cc @@ -557,7 +557,8 @@ FusedCSCSamplingGraph::SampleNeighborsImpl( // it equals to `num_seeds`. const int64_t num_rows = etype_id_to_num_picked_offset[num_etypes]; torch::Tensor num_picked_neighbors_per_node = - torch::empty({num_rows}, indptr_options); + // Need to use zeros because all nodes don't have all etypes. + torch::zeros({num_rows}, indptr_options); AT_DISPATCH_INDEX_TYPES( indptr_.scalar_type(), "SampleNeighborsImplWrappedWithIndptr", ([&] { @@ -571,14 +572,6 @@ FusedCSCSamplingGraph::SampleNeighborsImpl( num_picked_neighbors_data_ptr[0] = 0; const auto seeds_data_ptr = seeds.data_ptr(); - // Initialize the empty spots in `num_picked_neighbors_per_node`. - if (hetero_with_seed_offsets) { - for (auto i = 0; i < num_etypes; ++i) { - num_picked_neighbors_data_ptr - [etype_id_to_num_picked_offset[i]] = 0; - } - } - // Step 1. Calculate pick number of each node. torch::parallel_for( 0, num_seeds, grain_size, [&](int64_t begin, int64_t end) { @@ -612,40 +605,36 @@ FusedCSCSamplingGraph::SampleNeighborsImpl( } }); + // Step 2. Calculate prefix sum to get total length and offsets of + // each node. It's also the indptr of the generated subgraph. + subgraph_indptr = num_picked_neighbors_per_node.cumsum( + 0, indptr_.scalar_type()); + auto subgraph_indptr_data_ptr = + subgraph_indptr.data_ptr(); + if (hetero_with_seed_offsets) { torch::Tensor num_picked_offset_tensor = - torch::zeros({num_etypes + 1}, indptr_options); + torch::empty({num_etypes + 1}, indptr_options); + const auto num_picked_offset_data_ptr = + num_picked_offset_tensor.data_ptr(); + std::copy( + etype_id_to_num_picked_offset.begin(), + etype_id_to_num_picked_offset.end(), + num_picked_offset_data_ptr); torch::Tensor substract_offset = - torch::zeros({num_etypes}, indptr_options); + torch::empty({num_etypes}, indptr_options); const auto substract_offset_data_ptr = substract_offset.data_ptr(); - const auto num_picked_offset_data_ptr = - num_picked_offset_tensor.data_ptr(); for (auto i = 0; i < num_etypes; ++i) { - num_picked_offset_data_ptr[i + 1] = - etype_id_to_num_picked_offset[i + 1]; - // Collect the total pick number for each edge type. - if (i + 1 < num_etypes) - substract_offset_data_ptr[i + 1] = - num_picked_neighbors_data_ptr - [etype_id_to_num_picked_offset[i]]; - num_picked_neighbors_data_ptr - [etype_id_to_num_picked_offset[i]] = 0; + // Collect the total pick number subtract offsets. + substract_offset_data_ptr[i] = subgraph_indptr_data_ptr + [etype_id_to_num_picked_offset[i]]; } - substract_offset = - substract_offset.cumsum(0, indptr_.scalar_type()); subgraph_indptr_substract = ops::ExpandIndptr( num_picked_offset_tensor, indptr_.scalar_type(), substract_offset); } - // Step 2. Calculate prefix sum to get total length and offsets of - // each node. It's also the indptr of the generated subgraph. - subgraph_indptr = num_picked_neighbors_per_node.cumsum( - 0, indptr_.scalar_type()); - auto subgraph_indptr_data_ptr = - subgraph_indptr.data_ptr(); - // When doing non-temporal hetero sampling, we generate an // edge_offsets tensor. if (hetero_with_seed_offsets) { @@ -1277,11 +1266,6 @@ void NumPickByEtype( NumPick( fanouts[etype], replace, probs_or_mask, etype_begin, etype_end - etype_begin, num_picked_ptr + offset); - // Use the skipped position of each edge type in the - // num_picked_tensor to sum up the total pick number for each edge - // type. - num_picked_ptr[etype_id_to_num_picked_offset[etype] - 1] += - num_picked_ptr[offset]; } else { PickedNumType picked_count = 0; NumPick(