Skip to content

Commit

Permalink
skip useless repeat_interleave computation in some specific but commo…
Browse files Browse the repository at this point in the history
…n scenarios
  • Loading branch information
leonardcaquot94 committed Jan 7, 2025
1 parent bd296a3 commit 06088d5
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions torch_geometric/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,9 @@ def filter(self, idx: torch.Tensor) -> Self:
for attr, slc in slices.items():
slice_diff = slc.diff()

# Reshape mask to align it with attribute shape
attr_mask = mask[torch.repeat_interleave(slice_diff)]
# Reshape mask to align it with attribute shape.
# Since slice_diff often contains only ones, skip useless computation in such cases
attr_mask = mask[torch.repeat_interleave(slice_diff)] if torch.any(slice_diff != 1) else mask

# Apply mask to attribute
if attr == 'edge_index':
Expand Down

0 comments on commit 06088d5

Please sign in to comment.