Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Sep 6, 2024
1 parent 01defeb commit ae907b7
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 4 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
## \[2.6.0\] - 2024-MM-DD

### Added

- Added implemenation of `Batch.{from_batch_list,from_batch_index,add_graph_attr,set_edge_attr,set_edges}` ([#8414](https://github.com/pyg-team/pytorch_geometric/pull/8414))
- Added the `LinkPredMRR` metric ([#9632](https://github.com/pyg-team/pytorch_geometric/pull/9632))
- Added PyTorch 2.4 support ([#9594](https://github.com/pyg-team/pytorch_geometric/pull/9594))
Expand Down Expand Up @@ -42,6 +43,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added documentation on environment setup on XPU device ([#9407](https://github.com/pyg-team/pytorch_geometric/pull/9407))

### Changed

- Convert `Batch.index_select` to a full slicing operation returning a new batch instead of a list of `Data` ([#8414](https://github.com/pyg-team/pytorch_geometric/pull/8414))
- Use `torch.load(weights_only=True)` by default ([#9618](https://github.com/pyg-team/pytorch_geometric/pull/9618))
- Adapt `cugraph` examples to its new API ([#9541](https://github.com/pyg-team/pytorch_geometric/pull/9541))
Expand Down
6 changes: 3 additions & 3 deletions torch_geometric/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,13 +217,13 @@ def index_select(self, idx: IndexType) -> Self:
index = idx.flatten().nonzero(as_tuple=False).flatten().tolist()

elif isinstance(idx, np.ndarray) and idx.dtype == np.int64:
index = idx.flatten().tolist()
idx.flatten().tolist()

elif isinstance(idx, np.ndarray) and idx.dtype == bool:
index = idx.flatten().nonzero()[0].flatten().tolist()
idx.flatten().nonzero()[0].flatten().tolist()

elif isinstance(idx, Sequence) and not isinstance(idx, str):
index = idx
pass

else:
raise IndexError(
Expand Down
1 change: 0 additions & 1 deletion torch_geometric/data/separate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from torch import Tensor
from torch.nn.functional import pad

from torch_geometric import EdgeIndex, Index
from torch_geometric.data.data import BaseData
from torch_geometric.data.storage import BaseStorage
from torch_geometric.typing import SparseTensor, TensorFrame
Expand Down

0 comments on commit ae907b7

Please sign in to comment.