Skip to content

Commit

Permalink
fix issue when computing new edge index
Browse files Browse the repository at this point in the history
  • Loading branch information
leonardcaquot94 committed Jan 3, 2025
1 parent 7b9ec12 commit b1f3e84
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions torch_geometric/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,14 +268,13 @@ def filter(self, idx: torch.Tensor) -> Self:
# We assume x node attributes to be changed before edge attributes
# so that mapping_idx_dict is already available.
old_inc = self._inc_dict[key][attr].squeeze(-1).T
old_inc_diff = old_inc.diff(
prepend=torch.zeros((2, 1), dtype=torch.int))[:, mask]
old_inc_diff[:, 0] = 0
new_inc = old_inc_diff.cumsum(1)
shift_inc = new_inc - old_inc[:, mask]
new_inc = old_inc.diff()[:, mask[:-1]].cumsum(1)
new_inc = torch.cat((torch.zeros((2, 1), dtype=torch.int), new_inc), dim=1)

shift = new_inc - old_inc[:, mask]

edge_index_batch = torch.repeat_interleave(sizes_masked)
batch[key].edge_index += shift_inc[:, edge_index_batch]
batch[key].edge_index += shift[:, edge_index_batch]

new_inc = new_inc.T.unsqueeze(-1)

Expand Down

0 comments on commit b1f3e84

Please sign in to comment.