[GraphBolt] Hetero CPU sampling bug fix. (#7369)
mfbalin authored Apr 29, 2024
1 parent 0d9a09d commit 6b140f2
Showing 1 changed file with 20 additions and 36 deletions.
56 changes: 20 additions & 36 deletions graphbolt/src/fused_csc_sampling_graph.cc
@@ -557,7 +557,8 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
   // it equals to `num_seeds`.
   const int64_t num_rows = etype_id_to_num_picked_offset[num_etypes];
   torch::Tensor num_picked_neighbors_per_node =
-      torch::empty({num_rows}, indptr_options);
+      // Need to use zeros because all nodes don't have all etypes.
+      torch::zeros({num_rows}, indptr_options);
 
   AT_DISPATCH_INDEX_TYPES(
       indptr_.scalar_type(), "SampleNeighborsImplWrappedWithIndptr", ([&] {
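
Why the `torch::zeros` change above is needed: with heterogeneous seeds, not every row of `num_picked_neighbors_per_node` is written during the pick step (a seed has no entry for edge types it does not touch), so rows left uninitialized by `torch::empty` would feed garbage into the prefix sum computed later. A minimal standalone sketch of that idea, with made-up sizes rather than the real graphbolt layout:

    // Illustrative only: rows that are never written must contribute 0 to
    // the later prefix sum, which torch::empty() does not guarantee.
    #include <torch/torch.h>

    #include <iostream>

    int main() {
      const int64_t num_rows = 6;  // made-up size
      auto opts = torch::TensorOptions().dtype(torch::kInt64);
      torch::Tensor picked = torch::zeros({num_rows}, opts);
      auto ptr = picked.accessor<int64_t, 1>();
      // Pretend the pick step only touches rows 1, 2 and 4; the remaining
      // rows (e.g. edge types a seed does not have) are never written.
      ptr[1] = 3;
      ptr[2] = 1;
      ptr[4] = 2;
      // With zeros the cumulative sum is a valid indptr: 0 3 4 4 6 6.
      // With torch::empty the untouched rows would hold arbitrary values
      // and corrupt every offset after them.
      std::cout << picked.cumsum(0) << std::endl;
      return 0;
    }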
@@ -571,14 +572,6 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
         num_picked_neighbors_data_ptr[0] = 0;
         const auto seeds_data_ptr = seeds.data_ptr<seeds_t>();
 
-        // Initialize the empty spots in `num_picked_neighbors_per_node`.
-        if (hetero_with_seed_offsets) {
-          for (auto i = 0; i < num_etypes; ++i) {
-            num_picked_neighbors_data_ptr
-                [etype_id_to_num_picked_offset[i]] = 0;
-          }
-        }
-
         // Step 1. Calculate pick number of each node.
         torch::parallel_for(
             0, num_seeds, grain_size, [&](int64_t begin, int64_t end) {
@@ -612,40 +605,36 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
               }
             });
 
+        // Step 2. Calculate prefix sum to get total length and offsets of
+        // each node. It's also the indptr of the generated subgraph.
+        subgraph_indptr = num_picked_neighbors_per_node.cumsum(
+            0, indptr_.scalar_type());
+        auto subgraph_indptr_data_ptr =
+            subgraph_indptr.data_ptr<indptr_t>();
+
         if (hetero_with_seed_offsets) {
           torch::Tensor num_picked_offset_tensor =
-              torch::zeros({num_etypes + 1}, indptr_options);
+              torch::empty({num_etypes + 1}, indptr_options);
+          const auto num_picked_offset_data_ptr =
+              num_picked_offset_tensor.data_ptr<indptr_t>();
+          std::copy(
+              etype_id_to_num_picked_offset.begin(),
+              etype_id_to_num_picked_offset.end(),
+              num_picked_offset_data_ptr);
           torch::Tensor substract_offset =
-              torch::zeros({num_etypes}, indptr_options);
+              torch::empty({num_etypes}, indptr_options);
           const auto substract_offset_data_ptr =
               substract_offset.data_ptr<indptr_t>();
-          const auto num_picked_offset_data_ptr =
-              num_picked_offset_tensor.data_ptr<indptr_t>();
           for (auto i = 0; i < num_etypes; ++i) {
-            num_picked_offset_data_ptr[i + 1] =
-                etype_id_to_num_picked_offset[i + 1];
-            // Collect the total pick number for each edge type.
-            if (i + 1 < num_etypes)
-              substract_offset_data_ptr[i + 1] =
-                  num_picked_neighbors_data_ptr
-                      [etype_id_to_num_picked_offset[i]];
-            num_picked_neighbors_data_ptr
-                [etype_id_to_num_picked_offset[i]] = 0;
+            // Collect the total pick number subtract offsets.
+            substract_offset_data_ptr[i] = subgraph_indptr_data_ptr
+                [etype_id_to_num_picked_offset[i]];
           }
-          substract_offset =
-              substract_offset.cumsum(0, indptr_.scalar_type());
           subgraph_indptr_substract = ops::ExpandIndptr(
               num_picked_offset_tensor, indptr_.scalar_type(),
               substract_offset);
         }
 
-        // Step 2. Calculate prefix sum to get total length and offsets of
-        // each node. It's also the indptr of the generated subgraph.
-        subgraph_indptr = num_picked_neighbors_per_node.cumsum(
-            0, indptr_.scalar_type());
-        auto subgraph_indptr_data_ptr =
-            subgraph_indptr.data_ptr<indptr_t>();
-
         // When doing non-temporal hetero sampling, we generate an
         // edge_offsets tensor.
         if (hetero_with_seed_offsets) {
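
The reordering above computes `subgraph_indptr` (Step 2) before the hetero-only block, so the per-edge-type subtract offsets can be read directly out of the finished prefix sum instead of being accumulated in reserved slots and prefix-summed a second time. A simplified standalone sketch of that offset logic, using plain STL with a hypothetical two-edge-type layout (not the actual graphbolt tensors or `ops::ExpandIndptr`):

    // Illustrative only: a global prefix sum over per-row pick counts gives
    // one indptr for the whole sampled subgraph; subtracting the indptr
    // value at each edge type's first row makes that edge type's slice
    // start at zero.
    #include <cstdint>
    #include <iostream>
    #include <numeric>
    #include <vector>

    int main() {
      // Hypothetical layout: 2 edge types owning 3 rows each.
      std::vector<int64_t> etype_row_offset = {0, 3, 6};
      std::vector<int64_t> picked = {2, 1, 0, 3, 0, 2};

      // Global indptr: prefix sum with a leading zero.
      std::vector<int64_t> indptr(picked.size() + 1, 0);
      std::partial_sum(picked.begin(), picked.end(), indptr.begin() + 1);
      // indptr = {0, 2, 3, 3, 6, 6, 8}

      for (size_t etype = 0; etype < 2; ++etype) {
        // Everything accumulated before this edge type's rows.
        const int64_t subtract = indptr[etype_row_offset[etype]];
        std::cout << "etype " << etype << " indptr:";
        for (int64_t row = etype_row_offset[etype];
             row < etype_row_offset[etype + 1]; ++row) {
          std::cout << " " << indptr[row + 1] - subtract;
        }
        std::cout << "\n";  // etype 0: 2 3 3 / etype 1: 3 3 5
      }
      return 0;
    }

In the commit itself, `ops::ExpandIndptr` appears to broadcast each edge type's subtract value across that type's rows so the subtraction can be done tensor-wise; that detail is omitted from the sketch.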
@@ -1277,11 +1266,6 @@ void NumPickByEtype(
       NumPick(
           fanouts[etype], replace, probs_or_mask, etype_begin,
           etype_end - etype_begin, num_picked_ptr + offset);
-      // Use the skipped position of each edge type in the
-      // num_picked_tensor to sum up the total pick number for each edge
-      // type.
-      num_picked_ptr[etype_id_to_num_picked_offset[etype] - 1] +=
-          num_picked_ptr[offset];
     } else {
       PickedNumType picked_count = 0;
       NumPick(