Skip to content

Commit

Permalink
support non tensor attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
leonardcaquot94 committed Jan 3, 2025
1 parent 007baa8 commit 7b9ec12
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions torch_geometric/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,9 +243,12 @@ def filter(self, idx: torch.Tensor) -> Self:
attr_mask = mask[torch.repeat_interleave(slice_diff)]

# Apply mask to attribute
new_store[attr] = old_store[
attr][:, attr_mask] if attr == 'edge_index' else old_store[
attr][attr_mask]
if attr == 'edge_index':
new_store[attr] = old_store[attr][:, attr_mask]
elif isinstance(old_store[attr], list):
new_store[attr] = [item for item, m in zip(old_store[attr], attr_mask) if m]
else:
new_store[attr] = old_store[attr][attr_mask]

# Compute masked version of slice tensor
sizes_masked = slice_diff[mask]
Expand Down

0 comments on commit 7b9ec12

Please sign in to comment.