Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master'
Browse files Browse the repository at this point in the history
  • Loading branch information
leonardcaquot94 committed Jan 7, 2025
2 parents 66b48d2 + 080e23c commit bd296a3
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions torch_geometric/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,9 @@ def filter(self, idx: torch.Tensor) -> Self:
if attr == 'edge_index':
new_store[attr] = old_store[attr][:, attr_mask]
elif isinstance(old_store[attr], list):
new_store[attr] = [x for x, m in zip(old_store[attr], attr_mask) if m]
new_store[attr] = [
x for x, m in zip(old_store[attr], attr_mask) if m
]
else:
new_store[attr] = old_store[attr][attr_mask]

Expand Down Expand Up @@ -281,7 +283,8 @@ def filter(self, idx: torch.Tensor) -> Self:
new_inc = new_inc_tmp.roll(1, dims=1)

# Map each edge_index element to its batch position
edge_index_batch_map = torch.repeat_interleave(sizes_masked)
edge_index_batch_map = torch.repeat_interleave(
sizes_masked)
# Remove old_inc and add new_inc to each edge_index element using shift tensor
shift = new_inc - old_inc[:, mask]
batch[key].edge_index += shift[:, edge_index_batch_map]
Expand Down

0 comments on commit bd296a3

Please sign in to comment.