diff --git a/test/data/test_batch_manipulation.py b/test/data/test_batch_manipulation.py
index 43dacc7fb7c74..7a3fbb2ae7b34 100644
--- a/test/data/test_batch_manipulation.py
+++ b/test/data/test_batch_manipulation.py
@@ -150,8 +150,8 @@ def test_from_batch_list():
 
 
 def test_batch_slice():
-    batch_size = 12
-    node_range = (2, 5)
+    batch_size = 9
+    node_range = (3, 8)
     nodes_v = torch.randint(*node_range, (batch_size, )).to(device)
     edges_per_graph = torch.cat([
         torch.randint(1, num_nodes, size=(1, )).to(device)
@@ -174,9 +174,10 @@ def test_batch_slice():
             edge_attr=ea, ga=ga) for x, edges, ea, ga in zip(
                 x_list, edges_list, edge_attr_list, graph_attr_list)
     ]
-    bslice = slice(0, 6, 2)
+    bslice = torch.FloatTensor(batch_size).uniform_() > 0.4
     batch_full = Batch.from_data_list(batch_list)
-    batch_truth = Batch.from_data_list(batch_list[bslice])
+    batch_truth = Batch.from_data_list(
+        [batch_list[e] for e in bslice.nonzero().squeeze()])
     batch_new = batch_full[bslice]
     compare(batch_truth, batch_new)
 
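The test now draws a random boolean mask instead of a fixed `slice`, so mask-based selection is exercised as well. Below is a minimal sketch of the equivalence the updated test asserts, assuming this PR's `index_select`, which returns a `Batch` rather than a list of `Data` objects; the graph sizes and mask values are illustrative, not taken from the test.

```python
import torch
from torch_geometric.data import Batch, Data

# Three small graphs with 3, 5 and 4 nodes (illustrative sizes).
data_list = [Data(x=torch.randn(n, 4)) for n in (3, 5, 4)]
mask = torch.tensor([True, False, True])

# Ground truth: re-collate only the kept graphs.
truth = Batch.from_data_list([d for d, m in zip(data_list, mask) if m])

# Under this PR, boolean indexing of the full batch should match it.
sliced = Batch.from_data_list(data_list)[mask]  # dispatches to index_select
assert sliced.num_graphs == truth.num_graphs
assert torch.equal(sliced.x, truth.x)
```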
diff --git a/torch_geometric/data/batch.py b/torch_geometric/data/batch.py
index 4e32d183a95c7..449d0f305b447 100644
--- a/torch_geometric/data/batch.py
+++ b/torch_geometric/data/batch.py
@@ -6,6 +6,7 @@
 import numpy as np
 import torch
 from torch import Tensor
+from torch.nn.functional import pad
 from typing_extensions import Self
 
 from torch_geometric.data.collate import collate
@@ -223,82 +224,99 @@ def index_select(self, idx: IndexType) -> Self:
                 f"'{type(idx).__name__}')")
 
         dev = self.ptr.device
+        idx = torch.tensor(idx).long().to(dev)
+        subbatch = separate(
+            cls=self.__class__.__bases__[1],
+            batch=self,
+            idx=idx,
+            slice_dict=self._slice_dict,
+            inc_dict=self._inc_dict,
+            decrement=True,
+        )
+
         nodes_per_graph = self.ptr.diff()
         new_nodes_per_graph = nodes_per_graph[idx]
-        subbatch: Batch = Batch.from_batch_index(
-            torch.arange(len(idx)).to(dev).long().repeat_interleave(
-                new_nodes_per_graph))
-
-        sel_graphs_bool = torch.isin(
-            torch.arange(self.num_graphs).to(dev),
-            torch.tensor(idx).to(dev))
-        # Create a bool vector to select the tensors belonging
-        # to the selected nodes so we can index eg. batch.x
-        sel_nodes_bool = torch.isin(self.batch, torch.tensor(idx).to(dev))
-
-        # apply this filter to all stored tensors with the right dimensions
-        for k in set(
-                self.keys()) - {"ptr", "batch", 'edge_index', 'edge_attr'}:
-            v = self[k]
-            if not isinstance(v, torch.Tensor):
-                subbatch[k] = v
-            # check for node attrs
-            elif v.shape[0] == self.num_nodes:
-                subbatch.add_node_attr(k, v[sel_nodes_bool])
-            # Check for graph attrs
-            elif v.shape[0] == self.num_graphs:
-                subbatch.add_graph_attr(k, v[sel_graphs_bool])
-            else:
-                msg = f"""
-                Undefinied slicing for attribute {k} with dimensions
-                {v.shape}. Make sure that one of the following is true:
-                * the first dimension corresponds to the number of nodes
-                or graphs
-                * the attribute is 'edge_index' or 'edge_attr'
-                Otherwise, try `to_data_list` instead.
-                """
-                raise NotImplementedError(msg)
-
-        if 'edge_index' in self:
-            # to select the surviving edges, we create a tensor with
-            # the allowed node indeces and then filter the edge_index
-            sel_nodes_idx = torch.arange(len(
-                self["x"])).to(dev)[sel_nodes_bool]
-            sel_edges_bool = torch.isin(self.edge_index[0], sel_nodes_idx)
-            new_edge_index = self.edge_index.T[sel_edges_bool].T
-
-            # Now we have the edges, but we still need to remove the offset
-            # from the previous graphs in the batch, so that we can
-            # recalulate the edge index with the set_edge_index method.
-            # This offset is saved in `ptr`:
-            edge_offset_per_graph = self.ptr[:-1]
-
-            # Now that we have the offset for each graph, we need repeat this
-            # for the number of edges for each graph in the new batch.
-            # For this we need to calculate the number of edges per graph
-            # This can be computed by checking, when the edge_index
-            # becomes large then ptr
-            edges_per_graph = torch.max((self.edge_index[0]
-                                         > self.ptr.unsqueeze(1) - 1).long(),
-                                        dim=1).indices.diff()
-            edges_per_graph[
-                -1] = self.edge_index.shape[1] - edges_per_graph[:-1].sum()
-
-            shift_edges = edge_offset_per_graph[
-                sel_graphs_bool].repeat_interleave(
-                    edges_per_graph[sel_graphs_bool])
-            # Now the shift is applied
-            new_edge_index -= shift_edges
-            # Finally we need to provide the batch,
-            # that the edges belong to
-            batchidx_per_edge = torch.arange(
-                subbatch.num_graphs).to(dev).repeat_interleave(
-                    edges_per_graph[sel_graphs_bool])
-
-            subbatch.set_edge_index(new_edge_index, batchidx_per_edge)
-
-        if 'edge_attr' in self:
-            subbatch.set_edge_attr(self.edge_attr[sel_edges_bool])
+        subbatch.batch = torch.arange(
+            len(idx)).to(dev).long().repeat_interleave(new_nodes_per_graph)
+        subbatch.ptr = self.__ptr_from_batchidx(subbatch.batch)
+
+        for k in set(self.keys()) - {"ptr", "batch"}:
+            subbatch._slice_dict[k] = pad(self._slice_dict[k].diff()[idx],
+                                          (1, 0)).cumsum(0)
+            # subbatch._inc_dict[k]=pad(self._inc_dict[k].diff(),(1,0))[idx].cumsum(0)
+            subbatch._inc_dict[k] = pad(self._inc_dict[k].diff()[idx[:-1]],
+                                        (1, 0)).cumsum(0)
+
+        # sel_graphs_bool = torch.isin(
+        #     torch.arange(self.num_graphs).to(dev),
+        #     idx)
+        # # Create a bool vector to select the tensors belonging
+        # # to the selected nodes so we can index eg. batch.x
+        # sel_nodes_bool = torch.isin(self.batch, idx)
+
+        # # apply this filter to all stored tensors with the right dimensions
+        # for k in set(
+        #         self.keys()) - {"ptr", "batch", 'edge_index', 'edge_attr'}:
+        #     v = self[k]
+        #     if not isinstance(v, torch.Tensor):
+        #         subbatch[k] = v
+        #     # check for node attrs
+        #     elif v.shape[0] == self.num_nodes:
+        #         subbatch.add_node_attr(k, v[sel_nodes_bool])
+        #     # Check for graph attrs
+        #     elif v.shape[0] == self.num_graphs:
+        #         subbatch.add_graph_attr(k, v[sel_graphs_bool])
+        #     else:
+        #         msg = f"""
+        #         Undefined slicing for attribute {k} with dimensions
+        #         {v.shape}. Make sure that one of the following is true:
+        #         * the first dimension corresponds to the number of nodes
+        #         or graphs
+        #         * the attribute is 'edge_index' or 'edge_attr'
+        #         Otherwise, try `to_data_list` instead.
+        #         """
+        #         raise NotImplementedError(msg)
+
+        # if 'edge_index' in self:
+        #     # to select the surviving edges, we create a tensor with
+        #     # the allowed node indices and then filter the edge_index
+        #     sel_nodes_idx = torch.arange(len(
+        #         self["x"])).to(dev)[sel_nodes_bool]
+        #     sel_edges_bool = torch.isin(self.edge_index[0], sel_nodes_idx)
+        #     new_edge_index = self.edge_index.T[sel_edges_bool].T
+
+        #     # Now we have the edges, but we still need to remove the offset
+        #     # from the previous graphs in the batch, so that we can
+        #     # recalculate the edge index with the set_edge_index method.
+        #     # This offset is saved in `ptr`:
+        #     edge_offset_per_graph = self.ptr[:-1]
+
+        #     # Now that we have the offset for each graph, we need to repeat
+        #     # this for the number of edges for each graph in the new batch.
+        #     # For this we need to calculate the number of edges per graph.
+        #     # This can be computed by checking when the edge_index
+        #     # becomes larger than ptr.
+        #     edges_per_graph = torch.max((self.edge_index[0]
+        #                                  > self.ptr.unsqueeze(1) - 1).long(),
+        #                                 dim=1).indices.diff()
+        #     edges_per_graph[
+        #         -1] = self.edge_index.shape[1] - edges_per_graph[:-1].sum()
+
+        #     shift_edges = edge_offset_per_graph[
+        #         sel_graphs_bool].repeat_interleave(
+        #             edges_per_graph[sel_graphs_bool])
+        #     # Now the shift is applied
+        #     new_edge_index -= shift_edges
+        #     # Finally we need to provide the batch
+        #     # that the edges belong to
+        #     batchidx_per_edge = torch.arange(
+        #         subbatch.num_graphs).to(dev).repeat_interleave(
+        #             edges_per_graph[sel_graphs_bool])
+
+        #     subbatch.set_edge_index(new_edge_index, batchidx_per_edge)
+
+        # if 'edge_attr' in self:
+        #     subbatch.set_edge_attr(self.edge_attr[sel_edges_bool])
 
         return subbatch
 
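The rewritten `index_select` no longer filters node and edge tensors by hand; it delegates the selection to `separate` and then rebuilds the batch bookkeeping (`batch`, `ptr`, `_slice_dict`, `_inc_dict`) for the kept graphs. The `pad(...diff()[idx], (1, 0)).cumsum(0)` idiom used above turns per-graph sizes back into cumulative boundaries. A standalone sketch with made-up numbers, using only plain torch:

```python
import torch
from torch.nn.functional import pad

# Cumulative slice boundaries of 4 graphs with 3, 5, 4 and 2 nodes.
slices = torch.tensor([0, 3, 8, 12, 14])
idx = torch.tensor([0, 2, 3])  # keep graphs 0, 2 and 3

sizes = slices.diff()                            # tensor([3, 5, 4, 2])
new_slices = pad(sizes[idx], (1, 0)).cumsum(0)   # prepend 0, re-accumulate
print(new_slices)                                # tensor([0, 3, 7, 9])
```

`diff()` recovers the per-graph sizes, fancy indexing keeps only the selected ones, and padding a zero in front before `cumsum` restores the boundary format the rest of the code expects.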
diff --git a/torch_geometric/data/separate.py b/torch_geometric/data/separate.py
index 5412e7c133d6f..cd27b5bb9e823 100644
--- a/torch_geometric/data/separate.py
+++ b/torch_geometric/data/separate.py
@@ -1,15 +1,17 @@
 from collections.abc import Mapping, Sequence
 from typing import Any
 
+import torch
 from torch import Tensor
+from torch.nn.functional import pad
 
 from torch_geometric.data.data import BaseData
+from torch_geometric.data.dataset import IndexType
 from torch_geometric.data.storage import BaseStorage
 from torch_geometric.typing import SparseTensor, TensorFrame
-from torch_geometric.utils import narrow
 
 
-def separate(cls, batch: BaseData, idx: int, slice_dict: Any,
+def separate(cls, batch: BaseData, idx: IndexType, slice_dict: Any,
              inc_dict: Any = None, decrement: bool = True) -> BaseData:
     # Separates the individual element from a `batch` at index `idx`.
     # `separate` can handle both homogeneous and heterogeneous data objects by
@@ -17,7 +19,8 @@ def separate(cls, batch: BaseData, idx: int, slice_dict: Any,
     # In addition, `separate` can handle nested data structures such as
     # dictionaries and lists.
 
-    data = cls().stores_as(batch)
+    # data = cls().stores_as(batch)
+    data = batch
 
     # We iterate over each storage object and recursively separate all its
     # attributes:
@@ -49,7 +52,7 @@ def separate(cls, batch: BaseData, idx: int, slice_dict: Any,
 def _separate(
     key: str,
     value: Any,
-    idx: int,
+    idx: IndexType,
     slices: Any,
     incs: Any,
     batch: BaseData,
@@ -62,12 +65,25 @@ def _separate(
         # NOTE: We need to take care of decrementing elements appropriately.
         key = str(key)
         cat_dim = batch.__cat_dim__(key, value, store)
-        start, end = int(slices[idx]), int(slices[idx + 1])
-        value = narrow(value, cat_dim or 0, start, end - start)
+        value = value.swapaxes(0, cat_dim or 0)
+        graph_slice = torch.concat([
+            torch.arange(slices[istart], slices[istart + 1]) for istart in idx
+        ]).to(value.device)
+        value = value[graph_slice].swapaxes(0, cat_dim or 0)
         value = value.squeeze(0) if cat_dim is None else value
         if (decrement and incs is not None
-                and (incs.dim() > 1 or int(incs[idx]) != 0)):
-            value = value - incs[idx].to(value.device)
+                and (incs.dim() > 1 or any(incs[idx] != 0))):
+            # Remove the old offset.
+            nelem_new = slices.diff()[idx]
+            old_offset = incs[idx]
+
+            # Add the new offset. For this we compute the number of
+            # elements in the new batch before each kept graph.
+            nelem_previous = pad(incs.diff()[idx[:-1]], (1, 0)).cumsum(0)
+            # nelem_previous = (pad(incs.diff(),(1,0))[idx]).cumsum(0)
+            new_offset = nelem_previous
+            value = value + (-old_offset +
+                             new_offset).repeat_interleave(nelem_new)
         return value
 
     elif isinstance(value, SparseTensor) and decrement:
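`_separate` previously cut out a single contiguous block with `narrow`; it now gathers the (possibly non-contiguous) rows of every kept graph in one indexing pass and re-derives the offsets afterwards. A self-contained sketch of the gather step, with illustrative boundaries:

```python
import torch

# Node boundaries of 3 graphs (4, 4 and 4 nodes) and a node attribute.
slices = torch.tensor([0, 4, 8, 12])
idx = torch.tensor([0, 2])            # keep graphs 0 and 2
value = torch.arange(12)

# Concatenate the index range of each kept graph, then index once.
graph_slice = torch.concat(
    [torch.arange(slices[i], slices[i + 1]) for i in idx])
print(value[graph_slice])             # tensor([ 0,  1,  2,  3,  8,  9, 10, 11])
```

The surrounding `swapaxes(0, cat_dim or 0)` pair in the diff lets the same gather handle attributes that are concatenated along a non-zero dimension, such as `edge_index` with `cat_dim == 1`.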
diff --git a/torch_geometric/data/storage.py b/torch_geometric/data/storage.py
index c620cbedeb9ec..0517b7fae8f37 100644
--- a/torch_geometric/data/storage.py
+++ b/torch_geometric/data/storage.py
@@ -313,7 +313,7 @@ def can_infer_num_nodes(self):
     @property
     def num_nodes(self) -> Optional[int]:
         # We sequentially access attributes that reveal the number of nodes.
-        if 'num_nodes' in self:
+        if 'num_nodes' in self._mapping:
             return self['num_nodes']
         for key, value in self.items():
             if isinstance(value, Tensor) and key in N_KEYS:
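The `storage.py` tweak makes the `num_nodes` property check the raw `_mapping` dict directly, so only an explicitly stored `num_nodes` entry short-circuits the inference that follows. Taken together, the changes make plain indexing of a batch return a proper `Batch`. A hedged end-to-end sketch of the intended usage after this PR (previously, `index_select` produced a Python list of `Data` objects); the graph sizes and edge counts are made up:

```python
import torch
from torch_geometric.data import Batch, Data

batch = Batch.from_data_list([
    Data(x=torch.randn(n, 8), edge_index=torch.randint(n, (2, 2 * n)))
    for n in (3, 5, 4)
])

sub = batch[torch.tensor([0, 2])]   # returns a Batch under this PR
print(sub.num_graphs)               # 2
print(sub.x.shape[0])               # 7 nodes (3 + 4)
print([d.num_nodes for d in sub.to_data_list()])  # [3, 4]
```

Because `_slice_dict` and `_inc_dict` are rebuilt for the kept graphs, round-tripping through `to_data_list()` should keep working on the sliced batch.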