From 9fb3ea1cc53cacd41fb1d51eb94e30ab3dcc5ce8 Mon Sep 17 00:00:00 2001
From: mova
Date: Tue, 21 Nov 2023 14:19:26 +0100
Subject: [PATCH] Implementation of Batch.{from_batch_index,from_batch_list,add_node_attr,add_graph_attr,set_edge_index,set_edge_attr} #8414

---
 CHANGELOG.md                              |   4 +-
 test/data/test_batch.py                   |  64 +++---
 test/data/test_batch_manipulation.py      | 220 +++++++++++++++++++
 torch_geometric/data/batch.py             | 249 ++++++++++++++++++++--
 torch_geometric/data/in_memory_dataset.py |  12 +-
 torch_geometric/data/separate.py          | 106 ++++++---
 torch_geometric/data/storage.py           |   2 +-
 7 files changed, 571 insertions(+), 86 deletions(-)
 create mode 100644 test/data/test_batch_manipulation.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index e6e757d53f6a..d9c75964c4ae 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,7 +6,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ## \[2.6.0\] - 2024-MM-DD

 ### Added
-
+- Added `Batch.{from_batch_index,from_batch_list,add_node_attr,add_graph_attr,set_edge_index,set_edge_attr}` for constructing and extending batches directly from tensors ([#8414](https://github.com/pyg-team/pytorch_geometric/pull/8414))
 - Added the `GRetriever` model ([#9480](https://github.com/pyg-team/pytorch_geometric/pull/9480))
 - Added the `ClusterPooling` layer ([#9627](https://github.com/pyg-team/pytorch_geometric/pull/9627))
 - Added the `LinkPredMRR` metric ([#9632](https://github.com/pyg-team/pytorch_geometric/pull/9632))
@@ -44,7 +44,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added documentation on environment setup on XPU device ([#9407](https://github.com/pyg-team/pytorch_geometric/pull/9407))

 ### Changed
-
+- Converted `Batch.index_select` into a full slicing operation that returns a new `Batch` instead of a list of `Data` objects ([#8414](https://github.com/pyg-team/pytorch_geometric/pull/8414))
 - Use `torch.load(weights_only=True)` by default ([#9618](https://github.com/pyg-team/pytorch_geometric/pull/9618))
 - Adapt `cugraph` examples to its new API ([#9541](https://github.com/pyg-team/pytorch_geometric/pull/9541))
 - Allow optional but untyped tensors in `MessagePassing` ([#9494](https://github.com/pyg-team/pytorch_geometric/pull/9494))
diff --git a/test/data/test_batch.py b/test/data/test_batch.py
index b893c1e4dd11..75c71ecaa327 100644
--- a/test/data/test_batch.py
+++ b/test/data/test_batch.py
@@ -63,24 +63,27 @@ def test_batch_basic():
     assert batch.batch.tolist() == [0, 0, 0, 1, 1, 2, 2, 2, 2]
     assert batch.ptr.tolist() == [0, 3, 5, 9]

-    assert str(batch[0]) == ("Data(x=[3], edge_index=[2, 4], y=[1], "
-                             "string='1', array=[2], num_nodes=3)")
-    assert str(batch[1]) == ("Data(x=[2], edge_index=[2, 2], y=[1], "
-                             "string='2', array=[3], num_nodes=2)")
-    assert str(batch[2]) == ("Data(x=[4], edge_index=[2, 6], y=[1], "
-                             "string='3', array=[4], num_nodes=4)")
-
-    assert len(batch.index_select([1, 0])) == 2
-    assert len(batch.index_select(torch.tensor([1, 0]))) == 2
-    assert len(batch.index_select(torch.tensor([True, False]))) == 1
-    assert len(batch.index_select(np.array([1, 0], dtype=np.int64))) == 2
-    assert len(batch.index_select(np.array([True, False]))) == 1
-    assert len(batch[:2]) == 2
+    graphs = batch[0], batch[1], batch[2]
+    assert all(isinstance(e, Data) for e in graphs)
+    assert [e.x.shape[0] for e in graphs] == [3, 2, 4]
+    assert [e.y.shape[0] for e in graphs] == [1, 1, 1]
+    assert [list(e.edge_index.shape) for e in graphs] == [[2, 4], [2, 2],
+                                                          [2, 6]]
+    assert [e.string for e in graphs] == ["1", "2", "3"]
+    assert [e.num_nodes for e in graphs] == [3, 2, 4]
+
+    assert batch.index_select([1,
0]).num_graphs == 2 + assert batch.index_select(torch.tensor([1, 0])).num_graphs == 2 + assert batch.index_select(torch.tensor([True, False, + False])).num_graphs == 1 + assert batch.index_select(np.array([1, 0], dtype=np.int64)).num_graphs == 2 + assert batch.index_select(np.array([True, False, False])).num_graphs == 1 + assert batch[:2].num_graphs == 2 data_list = batch.to_data_list() assert len(data_list) == 3 - assert len(data_list[0]) == 6 + assert set(data_list[0].keys()) == set(data1.keys()) assert data_list[0].x.tolist() == [1, 2, 3] assert data_list[0].y.tolist() == [1] assert data_list[0].edge_index.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] @@ -111,21 +114,24 @@ def test_batch_basic(): def test_index(): index1 = Index([0, 1, 1, 2], dim_size=3, is_sorted=True) index2 = Index([0, 1, 1, 2, 2, 3], dim_size=4, is_sorted=True) + index3 = Index([0, 1, 1, 2], dim_size=3, is_sorted=True) data1 = Data(index=index1, num_nodes=3) data2 = Data(index=index2, num_nodes=4) + data3 = Data(index=index1, num_nodes=3) - batch = Batch.from_data_list([data1, data2]) + batch = Batch.from_data_list([data1, data2, data3]) - assert len(batch) == 2 - assert batch.batch.equal(torch.tensor([0, 0, 0, 1, 1, 1, 1])) - assert batch.ptr.equal(torch.tensor([0, 3, 7])) + assert len(batch) == 3 + assert batch.batch.equal(torch.tensor([0, 0, 0, 1, 1, 1, 1, 2, 2, 2])) + assert batch.ptr.equal(torch.tensor([0, 3, 7, 10])) assert isinstance(batch.index, Index) - assert batch.index.equal(torch.tensor([0, 1, 1, 2, 3, 4, 4, 5, 5, 6])) - assert batch.index.dim_size == 7 + assert batch.index.equal( + torch.tensor([0, 1, 1, 2, 3, 4, 4, 5, 5, 6, 7, 8, 8, 9])) + assert batch.index.dim_size == 10 assert batch.index.is_sorted - for i, index in enumerate([index1, index2]): + for i, index in enumerate([index1, index2, index3]): data = batch[i] assert isinstance(data.index, Index) assert data.index.equal(index) @@ -149,18 +155,16 @@ def test_edge_index(): data1 = Data(edge_index=edge_index1) data2 = Data(edge_index=edge_index2) - batch = Batch.from_data_list([data1, data2]) + batch = Batch.from_data_list([data1, data2, data1.clone()]) - assert len(batch) == 2 - assert batch.batch.equal(torch.tensor([0, 0, 0, 1, 1, 1, 1])) - assert batch.ptr.equal(torch.tensor([0, 3, 7])) + assert len(batch) == 3 + assert batch.batch.equal(torch.tensor([0, 0, 0, 1, 1, 1, 1, 2, 2, 2])) + assert batch.ptr.equal(torch.tensor([0, 3, 7, 10])) assert isinstance(batch.edge_index, EdgeIndex) assert batch.edge_index.equal( - torch.tensor([ - [0, 1, 1, 2, 4, 3, 5, 4, 6, 5], - [1, 0, 2, 1, 3, 4, 4, 5, 5, 6], - ])) - assert batch.edge_index.sparse_size() == (7, 7) + torch.tensor([[0, 1, 1, 2, 4, 3, 5, 4, 6, 5, 7, 8, 8, 9], + [1, 0, 2, 1, 3, 4, 4, 5, 5, 6, 8, 7, 9, 8]])) + assert batch.edge_index.sparse_size() == (10, 10) assert batch.edge_index.sort_order is None assert not batch.edge_index.is_undirected diff --git a/test/data/test_batch_manipulation.py b/test/data/test_batch_manipulation.py new file mode 100644 index 000000000000..eee6e720a7bf --- /dev/null +++ b/test/data/test_batch_manipulation.py @@ -0,0 +1,220 @@ +import torch + +from torch_geometric.data import Batch, Data + +device = torch.device("cpu") + + +def test_batch_add_node_attr(): + batch_size = 3 + node_range = (1, 10) + nodes_v = torch.randint(*node_range, (batch_size, )) + x_list = [torch.rand(n, 3).to(device) for n in nodes_v] + + batch_list = [Data(x=x) for x in x_list] + batch_truth = Batch.from_data_list(batch_list) + + batch_idx = batch_truth.batch + + batch = 
Batch.from_batch_index(batch_idx) + + batch.add_node_attr("x", torch.vstack(x_list)) + + compare(batch_truth, batch) + + +def test_batch_set_edge_index(): + batch_size = 4 + node_range = (2, 5) + nodes_v = torch.randint(*node_range, (batch_size, )).to(device) + edges_per_graph = torch.cat([ + torch.randint(1, num_nodes, size=(1, )).to(device) + for num_nodes in nodes_v + ]).to(device) + x_list = [torch.rand(num_nodes, 3).to(device) for num_nodes in nodes_v] + edges_list = [ + torch.vstack([ + torch.randint(0, num_nodes, size=(num_edges, )), + torch.randint(0, num_nodes, size=(num_edges, )), + ]).to(device) for num_nodes, num_edges in zip(nodes_v, edges_per_graph) + ] + + batch_list = [ + Data(x=x, edge_index=edges) for x, edges in zip(x_list, edges_list) + ] + batch_truth = Batch.from_data_list(batch_list) + + batch_idx = batch_truth.batch + batch = Batch.from_batch_index(batch_idx) + batch.add_node_attr("x", torch.vstack(x_list)) + + batchidx_per_edge = torch.cat([ + torch.ones(num_edges).long().to(device) * igraph + for igraph, num_edges in enumerate(edges_per_graph) + ]) + batch.set_edge_index(torch.hstack(edges_list), batchidx_per_edge) + compare(batch_truth, batch) + + +def test_batch_set_edge_attr(): + batch_size = 4 + node_range = (2, 5) + nodes_v = torch.randint(*node_range, (batch_size, )).to(device) + edges_per_graph = torch.cat([ + torch.randint(1, num_nodes, size=(1, )).to(device) + for num_nodes in nodes_v + ]) + x_list = [torch.rand(num_nodes, 3).to(device) for num_nodes in nodes_v] + edges_list = [ + torch.vstack([ + torch.randint(0, num_nodes, size=(num_edges, )), + torch.randint(0, num_nodes, size=(num_edges, )), + ]).to(device) for num_nodes, num_edges in zip(nodes_v, edges_per_graph) + ] + edge_attr_list = [ + torch.rand(num_edges).to(device) for num_edges in edges_per_graph + ] + + batch_list = [ + Data(x=x, edge_index=edges, edge_attr=ea) + for x, edges, ea in zip(x_list, edges_list, edge_attr_list) + ] + batch_truth = Batch.from_data_list(batch_list) + + batch_idx = batch_truth.batch + batch = Batch.from_batch_index(batch_idx) + batch.add_node_attr("x", torch.vstack(x_list)) + + batchidx_per_edge = torch.cat([ + torch.ones(num_edges).to(device).long() * igraph + for igraph, num_edges in enumerate(edges_per_graph) + ]) + batch.set_edge_index(torch.hstack(edges_list), batchidx_per_edge) + batch.set_edge_attr(torch.hstack(edge_attr_list)) + compare(batch_truth, batch) + + +def test_batch_add_graph_attr(): + batch_size = 3 + node_range = (1, 10) + nodes_v = torch.randint(*node_range, (batch_size, )).to(device) + x_list = [torch.rand(n, 3).to(device) for n in nodes_v] + graph_attr_list = torch.rand(batch_size).to(device) + + batch_list = [Data(x=x, ga=ga) for x, ga in zip(x_list, graph_attr_list)] + batch_truth = Batch.from_data_list(batch_list) + + batch_idx = batch_truth.batch + + batch = Batch.from_batch_index(batch_idx) + + batch.add_node_attr("x", torch.vstack(x_list)) + batch.add_graph_attr("ga", graph_attr_list) + compare(batch_truth, batch) + + +def test_from_batch_list(): + batch_size = 12 + node_range = (2, 5) + nodes_v = torch.randint(*node_range, (batch_size, )).to(device) + edges_per_graph = torch.cat([ + torch.randint(1, num_nodes, size=(1, )).to(device) + for num_nodes in nodes_v + ]) + x_list = [torch.rand(num_nodes, 3).to(device) for num_nodes in nodes_v] + edges_list = [ + torch.vstack([ + torch.randint(0, num_nodes, size=(num_edges, )), + torch.randint(0, num_nodes, size=(num_edges, )), + ]).to(device) for num_nodes, num_edges in zip(nodes_v, 
edges_per_graph) + ] + edge_attr_list = [ + torch.rand(num_edges).to(device) for num_edges in edges_per_graph + ] + graph_attr_list = torch.rand(batch_size).to(device) + + batch_list = [ + Data(x=x, edge_index=edges, + edge_attr=ea, ga=ga) for x, edges, ea, ga in zip( + x_list, edges_list, edge_attr_list, graph_attr_list) + ] + batch_truth = Batch.from_data_list(batch_list) + batch = Batch.from_batch_list([ + Batch.from_data_list(batch_list[:3]), + Batch.from_data_list(batch_list[3:5]), + Batch.from_data_list(batch_list[5:7]), + Batch.from_data_list(batch_list[7:]), + ]) + + compare(batch_truth, batch) + + +def test_batch_slice(): + batch_size = 9 + node_range = (3, 8) + nodes_v = torch.randint(*node_range, (batch_size, )).to(device) + edges_per_graph = torch.cat([ + torch.randint(1, num_nodes, size=(1, )).to(device) + for num_nodes in nodes_v + ]) + x_list = [torch.rand(num_nodes, 3).to(device) for num_nodes in nodes_v] + edges_list = [ + torch.vstack([ + torch.randint(0, num_nodes, size=(num_edges, )), + torch.randint(0, num_nodes, size=(num_edges, )), + ]).to(device) for num_nodes, num_edges in zip(nodes_v, edges_per_graph) + ] + edge_attr_list = [ + torch.rand(num_edges).to(device) for num_edges in edges_per_graph + ] + graph_attr_list = torch.rand(batch_size, 1, 5).to(device) + + batch_list = [ + Data(x=x, edge_index=edges, + edge_attr=ea, ga=ga) for x, edges, ea, ga in zip( + x_list, edges_list, edge_attr_list, graph_attr_list) + ] + bslice = torch.FloatTensor(batch_size).uniform_() > 0.4 + batch_full = Batch.from_data_list(batch_list) + batch_truth = Batch.from_data_list( + [batch_list[e] for e in bslice.nonzero().squeeze()]) + batch_new = batch_full[bslice] + compare(batch_truth, batch_new) + + +def compare(ba: Batch, bb: Batch): + if set(ba.keys()) != set(bb.keys()): + raise Exception() + assert (ba.batch == bb.batch).all() + assert (ba.ptr == bb.ptr).all() + for k in ba.keys(): + try: + rec_comp(ba[k], bb[k]) + except Exception as e: + raise Exception( + f"Batch comparison failed tensor for key {k}.") from e + if k in ba._slice_dict or k in bb._slice_dict: + try: + rec_comp(ba._slice_dict[k], bb._slice_dict[k]) + except Exception as e: + raise Exception( + f"Batch comparison failed _slice_dict for key {k}.") from e + if k in ba._inc_dict or k in bb._inc_dict: + try: + rec_comp(ba._inc_dict[k], bb._inc_dict[k]) + except Exception as e: + raise Exception( + f"Batch comparison failed _inc_dict for key {k}.") from e + + +def rec_comp(a, b): + if not type(a) is type(b): + raise Exception() + if isinstance(a, dict): + if not set(a.keys()) == set(b.keys()): + raise Exception() + for k in a: + rec_comp(a[k], b[k]) + if isinstance(a, torch.Tensor): + if not (a == b).all(): + raise Exception() diff --git a/torch_geometric/data/batch.py b/torch_geometric/data/batch.py index 411639a228d2..51e0e775fbdd 100644 --- a/torch_geometric/data/batch.py +++ b/torch_geometric/data/batch.py @@ -1,10 +1,12 @@ import inspect +from collections import defaultdict from collections.abc import Sequence from typing import Any, List, Optional, Type, Union import numpy as np import torch from torch import Tensor +from torch.nn.functional import pad from typing_extensions import Self from torch_geometric.data.collate import collate @@ -79,6 +81,25 @@ class Batch(metaclass=DynamicInheritance): Furthermore, :meth:`~Data.__cat_dim__` defines in which dimension graph tensors of the same attribute should be concatenated together. 
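+
+    In addition to :meth:`from_data_list`, a batch can be assembled
+    incrementally from plain tensors via :meth:`from_batch_index`,
+    :meth:`add_node_attr`, :meth:`set_edge_index`, :meth:`set_edge_attr` and
+    :meth:`add_graph_attr` (a minimal sketch; all tensor values below are
+    illustrative only):
+
+    .. code-block:: python
+
+        import torch
+
+        from torch_geometric.data import Batch
+
+        # Two graphs with three and two nodes, respectively:
+        batch = Batch.from_batch_index(torch.tensor([0, 0, 0, 1, 1]))
+        batch.add_node_attr('x', torch.randn(5, 16))
+
+        # Edges use node indices that are local to each graph and are
+        # assigned to their graph via `batchidx_per_edge`:
+        batch.set_edge_index(
+            torch.tensor([[0, 1, 0], [1, 2, 1]]),
+            batchidx_per_edge=torch.tensor([0, 0, 1]),
+        )
+        batch.set_edge_attr(torch.randn(3, 4))
+        batch.add_graph_attr('y', torch.randn(2))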
""" + @classmethod + def from_batch_index(cls, batch_idx: Tensor) -> Self: + batch = Batch() + + if not batch_idx.dtype == torch.long: + raise Exception("Batch index dtype must be torch.long") + if not (batch_idx.diff() >= 0).all(): + raise Exception("Batch index must be increasing") + if not batch_idx.dim() == 1: + raise Exception() + + batch.batch = batch_idx + batch.ptr = batch.__ptr_from_batchidx(batch_idx) + batch._num_graphs = int(batch.batch.max() + 1) + + batch._slice_dict = defaultdict(dict) + batch._inc_dict = defaultdict(dict) + return batch + @classmethod def from_data_list( cls, @@ -89,7 +110,7 @@ def from_data_list( r"""Constructs a :class:`~torch_geometric.data.Batch` object from a list of :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` objects. - The assignment vector :obj:`batch` is created on the fly. + The assignment vector :obj:`batch` is adjusted on the fly. In addition, creates assignment vectors for each key in :obj:`follow_batch`. Will exclude any keys given in :obj:`exclude_keys`. @@ -109,6 +130,48 @@ def from_data_list( return batch + @classmethod + def from_batch_list( + cls, + batches: List[Self], + follow_batch: Optional[List[str]] = None, + exclude_keys: Optional[List[str]] = None, + ) -> Self: + r"""Same as :meth:`~Batch.from_data_list```, + but for concatenating existing batches. + Constructs a :class:`~torch_geometric.data.Batch` object from a + list of :class:`~torch_geometric.data.Batch` objects. + The assignment vector :obj:`batch` is created on the fly. + In addition, creates assignment vectors for each key in + :obj:`follow_batch`. + Will exclude any keys given in :obj:`exclude_keys`. + """ + batch = cls.from_data_list(batches, follow_batch, exclude_keys) + + del batch._slice_dict["batch"], batch._inc_dict["batch"] + + batch.ptr = cls.__ptr_from_batchidx(cls, batch.batch) + batch._num_graphs = batch.ptr.numel() - 1 + + for k in set(batch.keys()) - {"batch", "ptr"}: + # slice_shift = [0] + [be._slice_dict[k][-1] for be in batches ] + batch._slice_dict[k] = batch._pad_zero( + torch.concat([be._slice_dict[k].diff() + for be in batches]).cumsum(0)) + if k != "edge_index": + inc_shift = batch._pad_zero( + torch.tensor([sum(be._inc_dict[k]) + for be in batches])).cumsum(0) + else: + inc_shift = batch._pad_zero( + torch.tensor([be.num_nodes for be in batches])).cumsum(0) + + batch._inc_dict[k] = torch.cat([ + be._inc_dict[k] + inc_shift[ibatch] + for ibatch, be in enumerate(batches) + ]) + return batch + def get_example(self, idx: int) -> BaseData: r"""Gets the :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` object at index :obj:`idx`. @@ -124,15 +187,16 @@ def get_example(self, idx: int) -> BaseData: data = separate( cls=self.__class__.__bases__[-1], batch=self, - idx=idx, + idx=torch.tensor([idx]).long(), slice_dict=getattr(self, '_slice_dict'), inc_dict=getattr(self, '_inc_dict'), decrement=True, + return_batch=False, ) return data - def index_select(self, idx: IndexType) -> List[BaseData]: + def index_select(self, idx: IndexType) -> Self: r"""Creates a subset of :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` objects from specified indices :obj:`idx`. @@ -143,42 +207,90 @@ def index_select(self, idx: IndexType) -> List[BaseData]: via :meth:`from_data_list` in order to be able to reconstruct the initial objects. 
""" - index: Sequence[int] - if isinstance(idx, slice): - index = list(range(self.num_graphs)[idx]) + if not isinstance(idx, Union[slice, Sequence, torch.Tensor, + np.ndarray]): + raise IndexError( + f"Only slices (':'), list, tuples, torch.tensor and " + f"np.ndarray of dtype long or bool are valid indices (got " + f"'{type(idx).__name__}')") - elif isinstance(idx, Tensor) and idx.dtype == torch.long: - index = idx.flatten().tolist() + index: torch.Tensor + + def _to_tidx(o): + return torch.tensor(o, device=self.ptr.device) + # convert numpt to torch tensors + if isinstance(idx, np.ndarray): + idx = _to_tidx(idx) + if isinstance(idx, slice): + index = _to_tidx(range(self.num_graphs)[idx]).long() + elif isinstance(idx, Sequence): + index = _to_tidx(idx).long() + elif isinstance(idx, Tensor) and idx.dtype == torch.long: + index = idx elif isinstance(idx, Tensor) and idx.dtype == torch.bool: - index = idx.flatten().nonzero(as_tuple=False).flatten().tolist() + if not len(idx) == self.num_graphs: + IndexError( + f"Boolen vector length does not match number of graphs" + f" (got {len(idx)} vector size " + f"vs. {self.num_graphs} graphs).") + index = idx.nonzero().flatten() + else: + raise IndexError( + f"Could not convert index (got '{type(idx).__name__}')") - elif isinstance(idx, np.ndarray) and idx.dtype == np.int64: - index = idx.flatten().tolist() + if index.dim() != 1: + raise IndexError( + f"Index must have a single dimension (got {index.dim()})") - elif isinstance(idx, np.ndarray) and idx.dtype == bool: - index = idx.flatten().nonzero()[0].flatten().tolist() + dev = self.ptr.device - elif isinstance(idx, Sequence) and not isinstance(idx, str): - index = idx + subbatch = separate( + cls=self.__class__, + batch=self, + idx=index, + slice_dict=self._slice_dict, + inc_dict=self._inc_dict, + decrement=True, + ) - else: - raise IndexError( - f"Only slices (':'), list, tuples, torch.tensor and " - f"np.ndarray of dtype long or bool are valid indices (got " - f"'{type(idx).__name__}')") + nodes_per_graph = self.ptr.diff() + new_nodes_per_graph = nodes_per_graph[index] + + # Construct batch index and ptr + subbatch.batch = torch.arange( + len(index)).to(dev).long().repeat_interleave(new_nodes_per_graph) + subbatch.ptr = self.__ptr_from_batchidx(subbatch.batch) - return [self.get_example(i) for i in index] + # fix the _slice_dict and _inc_dict + subbatch._slice_dict = defaultdict(dict) + subbatch._inc_dict = defaultdict(dict) + for k in set(self.keys()) - {"ptr", "batch"}: + if k not in self._slice_dict: + continue + subbatch._slice_dict[k] = pad(self._slice_dict[k].diff()[index], + (1, 0)).cumsum(0) + if k not in self._inc_dict: + continue + if self._inc_dict[k] is None: + subbatch._inc_dict[k] = None + continue + subbatch._inc_dict[k] = pad(self._inc_dict[k].diff()[index[:-1]], + (1, 0)).cumsum(0) + return subbatch def __getitem__(self, idx: Union[int, np.integer, str, IndexType]) -> Any: + # Return single Graph if (isinstance(idx, (int, np.integer)) or (isinstance(idx, Tensor) and idx.dim() == 0) or (isinstance(idx, np.ndarray) and np.isscalar(idx))): return self.get_example(idx) # type: ignore + # Return stored objects elif isinstance(idx, str) or (isinstance(idx, tuple) and isinstance(idx[0], str)): # Accessing attributes or node/edge types: return super().__getitem__(idx) # type: ignore + # Return subset of the batch else: return self.index_select(idx) @@ -215,3 +327,98 @@ def __len__(self) -> int: def __reduce__(self) -> Any: state = self.__dict__.copy() return DynamicInheritanceGetter(), 
self.__class__.__bases__, state
+
+    def add_node_attr(self, attrname: str, attr: Tensor) -> None:
+        r"""Adds a node-level attribute to an existing batch.
+        The first dimension of :obj:`attr` must match the number of nodes
+        in the batch. The existing assignment vector
+        :obj:`~torch_geometric.data.Batch.batch` is used to assign the
+        values to their respective graph.
+        """
+        assert attr.device == self.batch.device
+        batch_idxs = self.batch
+
+        self[attrname] = attr
+        # The number of nodes per graph determines the slices:
+        out = batch_idxs.unique(return_counts=True)[1]
+        out = out.cumsum(dim=0)
+        self._slice_dict[attrname] = self._pad_zero(out).cpu()
+
+        self._inc_dict[attrname] = torch.zeros(self._num_graphs,
+                                               dtype=torch.long)
+
+    def add_graph_attr(self, attrname: str, attr: Tensor) -> None:
+        r"""Adds a graph-level attribute to an existing batch.
+        The first dimension of :obj:`attr` must match the number of graphs
+        in the batch; element :obj:`i` is assigned to graph :obj:`i`.
+        """
+        assert attr.device == self.batch.device
+
+        self[attrname] = attr
+        self._slice_dict[attrname] = torch.arange(self.num_graphs + 1,
+                                                  dtype=torch.long)
+
+        self._inc_dict[attrname] = torch.zeros(self.num_graphs,
+                                               dtype=torch.long)
+
+    def set_edge_index(self, edge_index: Tensor,
+                       batchidx_per_edge: Tensor) -> None:
+        r"""Sets or overwrites :obj:`edge_index` in an existing batch.
+        :obj:`edge_index` must have shape :obj:`[2, num_edges]` and hold
+        node indices that are local to each graph, while
+        :obj:`batchidx_per_edge` must have shape :obj:`[num_edges]` and
+        assign each edge to its graph.
+        Both must be of type :obj:`torch.long`, and the edges must be
+        ordered by graph. The existing :obj:`~Batch.ptr` is used to shift
+        the node indices to their position in the batch.
+        """
+        assert edge_index.dtype == batchidx_per_edge.dtype == torch.long
+        assert (edge_index.device == batchidx_per_edge.device ==
+                self.batch.device)
+        assert (batchidx_per_edge.diff()
+                >= 0).all(), "Edges must be ordered by batch"
+        if 'edge_index' not in self._store.keys():
+            self.edge_index = torch.empty(2, 0, dtype=torch.long,
+                                          device=self.batch.device)
+        # Shift the node indices by the number of nodes in all
+        # previous graphs:
+        edge_index += self.ptr[batchidx_per_edge]
+        self.edge_index = torch.hstack((self.edge_index.clone(), edge_index))
+        # Update `_slice_dict` and `_inc_dict`:
+        edges_per_graph = batchidx_per_edge.unique(return_counts=True)[1]
+        self._slice_dict["edge_index"] = self._pad_zero(
+            edges_per_graph.cumsum(0)).cpu()
+        self._inc_dict["edge_index"] = self.ptr[:-1].cpu()
+
+    def set_edge_attr(self, edge_attr: Tensor) -> None:
+        r"""Sets or overwrites :obj:`edge_attr` in an existing batch.
+        The first dimension of :obj:`edge_attr` must match the number of
+        edges in the batch. The existing :obj:`edge_index` is used to
+        assign the values to the correct graph.
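+
+        A minimal usage sketch (shapes are illustrative;
+        :meth:`set_edge_index` must have been called beforehand):
+
+        .. code-block:: python
+
+            num_edges = batch.edge_index.size(1)
+            batch.set_edge_attr(torch.randn(num_edges, 8))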
+        """
+        assert (hasattr(self, "edge_index")
+                and self["edge_index"].dtype == torch.long), \
+            "'set_edge_index' must be called before 'set_edge_attr'"
+        self.edge_attr = edge_attr
+        self._slice_dict["edge_attr"] = self._slice_dict["edge_index"]
+        self._inc_dict["edge_attr"] = torch.zeros(self.num_graphs,
+                                                  dtype=torch.long)
+
+    def _pad_zero(self, arr: torch.Tensor) -> torch.Tensor:
+        return torch.cat([
+            torch.tensor(0, dtype=arr.dtype, device=arr.device).unsqueeze(0),
+            arr
+        ])
+
+    def __ptr_from_batchidx(self, batch_idx: Tensor) -> torch.Tensor:
+        # Construct the `ptr` vector used to address single graphs, i.e.:
+        # graph[i].x == batch.x[batch.ptr[i]:batch.ptr[i + 1]]
+        assert batch_idx.dtype == torch.long
+        # `ptr` starts at 0, has an entry at every position where the
+        # batch index increases, and ends with the total number of nodes:
+        dev = batch_idx.device
+        ptr = torch.concatenate((
+            torch.tensor(0).long().to(dev).unsqueeze(0),
+            (batch_idx.diff()).nonzero().reshape(-1) + 1,
+            torch.tensor(len(batch_idx)).long().to(dev).unsqueeze(0),
+        ))
+        return ptr
diff --git a/torch_geometric/data/in_memory_dataset.py b/torch_geometric/data/in_memory_dataset.py
index 9f307bf14d5d..3e54d71af89b 100644
--- a/torch_geometric/data/in_memory_dataset.py
+++ b/torch_geometric/data/in_memory_dataset.py
@@ -107,14 +107,11 @@ def get(self, idx: int) -> BaseData:
             self._data_list = self.len() * [None]
         elif self._data_list[idx] is not None:
             return copy.copy(self._data_list[idx])
+        self._data._num_graphs = self.len()

-        data = separate(
-            cls=self._data.__class__,
-            batch=self._data,
-            idx=idx,
-            slice_dict=self.slices,
-            decrement=False,
-        )
+        data = separate(
+            cls=self._data.__class__,
+            batch=self._data,
+            idx=torch.tensor([idx]).long(),
+            slice_dict=self.slices,
+            decrement=False,
+            return_batch=False,
+        )

         self._data_list[idx] = copy.copy(data)

@@ -158,6 +155,7 @@ def collate(
             increment=False,
             add_batch=False,
         )
+        data._num_graphs = len(data_list)

         return data, slices

diff --git a/torch_geometric/data/separate.py b/torch_geometric/data/separate.py
index 2910b6679f60..d3522941605e 100644
--- a/torch_geometric/data/separate.py
+++ b/torch_geometric/data/separate.py
@@ -1,24 +1,26 @@
 from collections.abc import Mapping, Sequence
 from typing import Any, Type, TypeVar

+import torch
 from torch import Tensor
+from torch.nn.functional import pad

 from torch_geometric import EdgeIndex, Index
 from torch_geometric.data.data import BaseData
 from torch_geometric.data.storage import BaseStorage
 from torch_geometric.typing import SparseTensor, TensorFrame
-from torch_geometric.utils import narrow

 T = TypeVar('T')


 def separate(
     cls: Type[T],
-    batch: Any,
-    idx: int,
+    batch: BaseData,
+    idx: torch.Tensor,
     slice_dict: Any,
     inc_dict: Any = None,
     decrement: bool = True,
+    return_batch: bool = True,
 ) -> T:
-    # Separates the individual element from a `batch` at index `idx`.
+    # Separates the individual element(s) from a `batch` at the given graph
+    # indices `idx`.
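+    # `idx` is a `torch.Tensor` of graph indices, so several graphs can be
+    # separated at once. If `return_batch` is set, per-graph attributes such
+    # as `_num_nodes` are kept as lists; otherwise, a single graph is
+    # separated and plain (scalar) values are restored.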
     # `separate` can handle both homogeneous and heterogeneous data objects by
@@ -50,8 +52,16 @@ def separate(

         # The `num_nodes` attribute needs special treatment, as we cannot infer
         # the real number of nodes from the total number of nodes alone:
+        # TODO: Simplify the special-casing of `num_nodes`.
         if hasattr(batch_store, '_num_nodes'):
-            data_store.num_nodes = batch_store._num_nodes[idx]
+            if return_batch:
+                data_store._num_nodes = [
+                    batch_store._num_nodes[i] for i in idx
+                ]
+            else:
+                data_store._num_nodes = batch_store._num_nodes[idx[0]]
+            if hasattr(batch_store, 'num_nodes'):
+                data_store.num_nodes = data_store._num_nodes

     return data

@@ -59,43 +69,82 @@ def _separate(
     key: str,
     values: Any,
-    idx: int,
+    idx: torch.Tensor,
     slices: Any,
     incs: Any,
     batch: BaseData,
     store: BaseStorage,
     decrement: bool,
 ) -> Any:
     if isinstance(values, Tensor):
+        idx = idx.to(values.device)
+        graph_slice = torch.concat([
+            torch.arange(int(slices[i]), int(slices[i + 1])) for i in idx
+        ]).to(values.device)
+        valid_inc = incs is not None and (incs.dim() > 1
+                                          or any(incs[idx] != 0))
+
         # Narrow a `torch.Tensor` based on `slices`.
         # NOTE: We need to take care of decrementing elements appropriately.
         key = str(key)
         cat_dim = batch.__cat_dim__(key, values, store)
-        start, end = int(slices[idx]), int(slices[idx + 1])
-        value = narrow(values, cat_dim or 0, start, end - start)
+        value = torch.index_select(values, cat_dim or 0, graph_slice)
         value = value.squeeze(0) if cat_dim is None else value

-        if isinstance(values, Index) and values._cat_metadata is not None:
-            # Reconstruct original `Index` metadata:
-            value._dim_size = values._cat_metadata.dim_size[idx]
-            value._is_sorted = values._cat_metadata.is_sorted[idx]
-
-        if isinstance(values, EdgeIndex) and values._cat_metadata is not None:
-            # Reconstruct original `EdgeIndex` metadata:
-            value._sparse_size = values._cat_metadata.sparse_size[idx]
-            value._sort_order = values._cat_metadata.sort_order[idx]
-            value._is_undirected = values._cat_metadata.is_undirected[idx]
-
-        if (decrement and incs is not None
-                and (incs.dim() > 1 or int(incs[idx]) != 0)):
-            value = value - incs[idx].to(value.device)
+        if (decrement and incs is not None and valid_inc):
+            # Remove the old per-graph offset:
+            nelem_new = slices.diff()[idx]
+            if len(idx) == 1:
+                old_offset = incs[idx[0]]
+                new_offset = torch.zeros_like(old_offset)
+                shift = torch.ones_like(value) * (-old_offset + new_offset)
+            else:
+                old_offset = incs[idx]
+                # Add the new offset, computed from the number of elements
+                # in the preceding selected graphs:
+                new_offset = pad(incs.diff()[idx[:-1]], (1, 0)).cumsum(0)
+                shift = (-old_offset + new_offset).repeat_interleave(
+                    nelem_new, dim=cat_dim or 0)
+            value = value + shift
+
+        if getattr(values, "_cat_metadata", None) is not None:
+
+            def _pad_diff(a):
+                a = torch.tensor(a)
+                return torch.concat([a[:1], a.diff(dim=0)])
+
+            if isinstance(values, Index):
+                # Reconstruct original `Index` metadata:
+                if decrement:
+                    value._dim_size = _pad_diff(
+                        values._cat_metadata.dim_size)[idx].squeeze()
+                else:
+                    value._dim_size = values._cat_metadata.dim_size[idx]
+                value._is_sorted = values._cat_metadata.is_sorted[idx]
+
+            if isinstance(values, EdgeIndex):
+                # Reconstruct original `EdgeIndex` metadata:
+                def _to_tup(a):
+                    return tuple(a.flatten().tolist())
+
+                if decrement:
+                    value._sparse_size = _to_tup(
+                        _pad_diff(values._cat_metadata.sparse_size)[idx])
+                else:
+                    value._sparse_size = values._cat_metadata.sparse_size[idx]
+                value._sort_order = values._cat_metadata.sort_order[idx]
+                value._is_undirected
= values._cat_metadata.is_undirected[idx] return value elif isinstance(values, SparseTensor) and decrement: # Narrow a `SparseTensor` based on `slices`. # NOTE: `cat_dim` may return a tuple to allow for diagonal stacking. + if len(idx) > 1: + raise NotImplementedError + idx = idx[0] + key = str(key) cat_dim = batch.__cat_dim__(key, values, store) cat_dims = (cat_dim, ) if isinstance(cat_dim, int) else cat_dim @@ -107,8 +156,8 @@ def _separate( elif isinstance(values, TensorFrame): key = str(key) start, end = int(slices[idx]), int(slices[idx + 1]) - value = values[start:end] - return value + values = values[start:end] + return values elif isinstance(values, Mapping): # Recursively separate elements of dictionaries. @@ -131,6 +180,9 @@ def _separate( and not isinstance(values[0], str) and len(values[0]) > 0 and isinstance(values[0][0], (Tensor, SparseTensor)) and isinstance(slices, Sequence)): + if len(idx) > 1: + raise NotImplementedError + idx = idx[0] # Recursively separate elements of lists of lists. return [value[idx] for value in values] @@ -150,6 +202,10 @@ def _separate( decrement=decrement, ) for i, value in enumerate(values) ] - + elif isinstance(values, list) and batch._num_graphs == len(values): + if len(idx) == 1: + return values[idx[0]] + else: + return [values[i] for i in idx] else: return values[idx] diff --git a/torch_geometric/data/storage.py b/torch_geometric/data/storage.py index 07a52fdfc21a..c71ba627ffdb 100644 --- a/torch_geometric/data/storage.py +++ b/torch_geometric/data/storage.py @@ -420,7 +420,7 @@ def can_infer_num_nodes(self) -> bool: @property def num_nodes(self) -> Optional[int]: # We sequentially access attributes that reveal the number of nodes. - if 'num_nodes' in self: + if 'num_nodes' in self._mapping: return self['num_nodes'] for key, value in self.items(): if isinstance(value, Tensor) and key in N_KEYS: