move splitting into the separate method
mova committed Nov 24, 2023
1 parent 8854100 commit 039c074
Showing 4 changed files with 122 additions and 87 deletions.
9 changes: 5 additions & 4 deletions test/data/test_batch_manipulation.py
@@ -150,8 +150,8 @@ def test_from_batch_list():


def test_batch_slice():
-batch_size = 12
-node_range = (2, 5)
+batch_size = 9
+node_range = (3, 8)
nodes_v = torch.randint(*node_range, (batch_size, )).to(device)
edges_per_graph = torch.cat([
torch.randint(1, num_nodes, size=(1, )).to(device)
@@ -174,9 +174,10 @@ def test_batch_slice():
edge_attr=ea, ga=ga) for x, edges, ea, ga in zip(
x_list, edges_list, edge_attr_list, graph_attr_list)
]
-bslice = slice(0, 6, 2)
+bslice = torch.FloatTensor(batch_size).uniform_() > 0.4
batch_full = Batch.from_data_list(batch_list)
-batch_truth = Batch.from_data_list(batch_list[bslice])
+batch_truth = Batch.from_data_list(
+    [batch_list[e] for e in bslice.nonzero().squeeze()])
batch_new = batch_full[bslice]
compare(batch_truth, batch_new)
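
A minimal sketch of the behaviour the updated test pins down (toy values, assuming upstream torch_geometric `Batch`/`Data` semantics): indexing a `Batch` with a boolean mask should match rebuilding the batch from the masked data list.

import torch
from torch_geometric.data import Batch, Data

# Two toy graphs with 2 and 3 nodes.
data_list = [
    Data(x=torch.randn(2, 4), edge_index=torch.tensor([[0, 1], [1, 0]])),
    Data(x=torch.randn(3, 4), edge_index=torch.tensor([[0, 1, 2], [1, 2, 0]])),
]
batch = Batch.from_data_list(data_list)

mask = torch.tensor([True, False])  # keep only the first graph
sub = batch[mask]

# Ground truth: rebuild the batch from the masked list.
truth = Batch.from_data_list([d for d, m in zip(data_list, mask) if m])
assert torch.equal(sub.x, truth.x)
assert torch.equal(sub.edge_index, truth.edge_index)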

166 changes: 92 additions & 74 deletions torch_geometric/data/batch.py
@@ -6,6 +6,7 @@
import numpy as np
import torch
from torch import Tensor
+from torch.nn.functional import pad
from typing_extensions import Self

from torch_geometric.data.collate import collate
@@ -223,82 +224,99 @@ def index_select(self, idx: IndexType) -> Self:
f"'{type(idx).__name__}')")

dev = self.ptr.device
idx = torch.tensor(idx).long().to(dev)
+subbatch = separate(
+    cls=self.__class__.__bases__[1],
+    batch=self,
+    idx=idx,
+    slice_dict=self._slice_dict,
+    inc_dict=self._inc_dict,
+    decrement=True,
+)

nodes_per_graph = self.ptr.diff()
new_nodes_per_graph = nodes_per_graph[idx]
-subbatch: Batch = Batch.from_batch_index(
-    torch.arange(len(idx)).to(dev).long().repeat_interleave(
-        new_nodes_per_graph))

-sel_graphs_bool = torch.isin(
-    torch.arange(self.num_graphs).to(dev),
-    torch.tensor(idx).to(dev))
-# Create a bool vector to select the tensors belonging
-# to the selected nodes so we can index e.g. batch.x
-sel_nodes_bool = torch.isin(self.batch, torch.tensor(idx).to(dev))
-
-# apply this filter to all stored tensors with the right dimensions
-for k in set(
-        self.keys()) - {"ptr", "batch", 'edge_index', 'edge_attr'}:
-    v = self[k]
-    if not isinstance(v, torch.Tensor):
-        subbatch[k] = v
-    # check for node attrs
-    elif v.shape[0] == self.num_nodes:
-        subbatch.add_node_attr(k, v[sel_nodes_bool])
-    # Check for graph attrs
-    elif v.shape[0] == self.num_graphs:
-        subbatch.add_graph_attr(k, v[sel_graphs_bool])
-    else:
-        msg = f"""
-        Undefined slicing for attribute {k} with dimensions
-        {v.shape}. Make sure that one of the following is true:
-        * the first dimension corresponds to the number of nodes
-          or graphs
-        * the attribute is 'edge_index' or 'edge_attr'
-        Otherwise, try `to_data_list` instead.
-        """
-        raise NotImplementedError(msg)
-
-if 'edge_index' in self:
-    # to select the surviving edges, we create a tensor with
-    # the allowed node indices and then filter the edge_index
-    sel_nodes_idx = torch.arange(len(
-        self["x"])).to(dev)[sel_nodes_bool]
-    sel_edges_bool = torch.isin(self.edge_index[0], sel_nodes_idx)
-    new_edge_index = self.edge_index.T[sel_edges_bool].T
-
-    # Now we have the edges, but we still need to remove the offset
-    # from the previous graphs in the batch, so that we can
-    # recalculate the edge index with the set_edge_index method.
-    # This offset is saved in `ptr`:
-    edge_offset_per_graph = self.ptr[:-1]
-
-    # Now that we have the offset for each graph, we need to repeat
-    # it for the number of edges of each graph in the new batch.
-    # For this we need to calculate the number of edges per graph.
-    # This can be computed by checking when the edge_index
-    # becomes larger than ptr.
-    edges_per_graph = torch.max((self.edge_index[0]
-                                 > self.ptr.unsqueeze(1) - 1).long(),
-                                dim=1).indices.diff()
-    edges_per_graph[
-        -1] = self.edge_index.shape[1] - edges_per_graph[:-1].sum()
-
-    shift_edges = edge_offset_per_graph[
-        sel_graphs_bool].repeat_interleave(
-            edges_per_graph[sel_graphs_bool])
-    # Now the shift is applied
-    new_edge_index -= shift_edges
-    # Finally we need to provide the batch
-    # that the edges belong to
-    batchidx_per_edge = torch.arange(
-        subbatch.num_graphs).to(dev).repeat_interleave(
-            edges_per_graph[sel_graphs_bool])
-
-    subbatch.set_edge_index(new_edge_index, batchidx_per_edge)
-
-if 'edge_attr' in self:
-    subbatch.set_edge_attr(self.edge_attr[sel_edges_bool])
+subbatch.batch = torch.arange(
+    len(idx)).to(dev).long().repeat_interleave(new_nodes_per_graph)
+subbatch.ptr = self.__ptr_from_batchidx(subbatch.batch)
+
+for k in set(self.keys()) - {"ptr", "batch"}:
+    subbatch._slice_dict[k] = pad(self._slice_dict[k].diff()[idx],
+                                  (1, 0)).cumsum(0)
+    # subbatch._inc_dict[k]=pad(self._inc_dict[k].diff(),(1,0))[idx].cumsum(0)
+    subbatch._inc_dict[k] = pad(self._inc_dict[k].diff()[idx[:-1]],
+                                (1, 0)).cumsum(0)

+# sel_graphs_bool = torch.isin(
+#     torch.arange(self.num_graphs).to(dev),
+#     idx)
+# # Create a bool vector to select the tensors belonging
+# # to the selected nodes so we can index e.g. batch.x
+# sel_nodes_bool = torch.isin(self.batch, idx)
+
+# # apply this filter to all stored tensors with the right dimensions
+# for k in set(
+#         self.keys()) - {"ptr", "batch", 'edge_index', 'edge_attr'}:
+#     v = self[k]
+#     if not isinstance(v, torch.Tensor):
+#         subbatch[k] = v
+#     # check for node attrs
+#     elif v.shape[0] == self.num_nodes:
+#         subbatch.add_node_attr(k, v[sel_nodes_bool])
+#     # Check for graph attrs
+#     elif v.shape[0] == self.num_graphs:
+#         subbatch.add_graph_attr(k, v[sel_graphs_bool])
+#     else:
+#         msg = f"""
+#         Undefined slicing for attribute {k} with dimensions
+#         {v.shape}. Make sure that one of the following is true:
+#         * the first dimension corresponds to the number of nodes
+#           or graphs
+#         * the attribute is 'edge_index' or 'edge_attr'
+#         Otherwise, try `to_data_list` instead.
+#         """
+#         raise NotImplementedError(msg)
+
+# if 'edge_index' in self:
+#     # to select the surviving edges, we create a tensor with
+#     # the allowed node indices and then filter the edge_index
+#     sel_nodes_idx = torch.arange(len(
+#         self["x"])).to(dev)[sel_nodes_bool]
+#     sel_edges_bool = torch.isin(self.edge_index[0], sel_nodes_idx)
+#     new_edge_index = self.edge_index.T[sel_edges_bool].T
+
+#     # Now we have the edges, but we still need to remove the offset
+#     # from the previous graphs in the batch, so that we can
+#     # recalculate the edge index with the set_edge_index method.
+#     # This offset is saved in `ptr`:
+#     edge_offset_per_graph = self.ptr[:-1]
+
+#     # Now that we have the offset for each graph, we need to repeat
+#     # it for the number of edges of each graph in the new batch.
+#     # For this we need to calculate the number of edges per graph.
+#     # This can be computed by checking when the edge_index
+#     # becomes larger than ptr.
+#     edges_per_graph = torch.max((self.edge_index[0]
+#                                  > self.ptr.unsqueeze(1) - 1).long(),
+#                                 dim=1).indices.diff()
+#     edges_per_graph[
+#         -1] = self.edge_index.shape[1] - edges_per_graph[:-1].sum()
+
+#     shift_edges = edge_offset_per_graph[
+#         sel_graphs_bool].repeat_interleave(
+#             edges_per_graph[sel_graphs_bool])
+#     # Now the shift is applied
+#     new_edge_index -= shift_edges
+#     # Finally we need to provide the batch
+#     # that the edges belong to
+#     batchidx_per_edge = torch.arange(
+#         subbatch.num_graphs).to(dev).repeat_interleave(
+#             edges_per_graph[sel_graphs_bool])
+
+#     subbatch.set_edge_index(new_edge_index, batchidx_per_edge)
+
+#     if 'edge_attr' in self:
+#         subbatch.set_edge_attr(self.edge_attr[sel_edges_bool])

return subbatch
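
The `pad(...diff()[idx], (1, 0)).cumsum(0)` pattern above re-derives cumulative slice boundaries for the selected graphs: `diff()` recovers per-graph lengths, indexing reorders them, and the zero-padded cumsum rebuilds the boundaries. A self-contained sketch with invented values:

import torch
from torch.nn.functional import pad

# Node-slice boundaries of 4 graphs with 2, 3, 1 and 4 nodes:
slices = torch.tensor([0, 2, 5, 6, 10])
idx = torch.tensor([2, 0])  # select graphs 2 and 0, in that order

new_slices = pad(slices.diff()[idx], (1, 0)).cumsum(0)
print(new_slices)  # tensor([0, 1, 3]): graph 2 contributes 1 node, graph 0 two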

32 changes: 24 additions & 8 deletions torch_geometric/data/separate.py
@@ -1,23 +1,26 @@
from collections.abc import Mapping, Sequence
from typing import Any

+import torch
from torch import Tensor
+from torch.nn.functional import pad

from torch_geometric.data.data import BaseData
+from torch_geometric.data.dataset import IndexType
from torch_geometric.data.storage import BaseStorage
from torch_geometric.typing import SparseTensor, TensorFrame
from torch_geometric.utils import narrow


-def separate(cls, batch: BaseData, idx: int, slice_dict: Any,
+def separate(cls, batch: BaseData, idx: IndexType, slice_dict: Any,
inc_dict: Any = None, decrement: bool = True) -> BaseData:
# Separates the individual element from a `batch` at index `idx`.
# `separate` can handle both homogeneous and heterogeneous data objects by
# individually separating all their stores.
# In addition, `separate` can handle nested data structures such as
# dictionaries and lists.

-    data = cls().stores_as(batch)
+    # data = cls().stores_as(batch)
+    data = batch

# We iterate over each storage object and recursively separate all its
# attributes:
@@ -49,7 +52,7 @@ def separate(cls, batch: BaseData, idx: int, slice_dict: Any,
def _separate(
key: str,
value: Any,
-    idx: int,
+    idx: IndexType,
slices: Any,
incs: Any,
batch: BaseData,
@@ -62,12 +65,25 @@ def _separate(
# NOTE: We need to take care of decrementing elements appropriately.
key = str(key)
cat_dim = batch.__cat_dim__(key, value, store)
-        start, end = int(slices[idx]), int(slices[idx + 1])
-        value = narrow(value, cat_dim or 0, start, end - start)
+        value = value.swapaxes(0, cat_dim or 0)
+        graph_slice = torch.concat([
+            torch.arange(slices[istart], slices[istart + 1]) for istart in idx
+        ]).to(value.device)
+        value = value[graph_slice].swapaxes(0, cat_dim or 0)
value = value.squeeze(0) if cat_dim is None else value
if (decrement and incs is not None
-                and (incs.dim() > 1 or int(incs[idx]) != 0)):
-            value = value - incs[idx].to(value.device)
+                and (incs.dim() > 1 or any(incs[idx] != 0))):
+            # remove the old offset
+            nelem_new = slices.diff()[idx]
+            old_offset = incs[idx]
+
+            # add the new offset: for this we compute the number of
+            # elements preceding each selected graph in the new batch
+            nelem_previous = pad(incs.diff()[idx[:-1]], (1, 0)).cumsum(0)
+            # nelem_previous = (pad(incs.diff(),(1,0))[idx]).cumsum(0)
+            new_offset = nelem_previous
+            value = value + (-old_offset +
+                             new_offset).repeat_interleave(nelem_new)
return value

elif isinstance(value, SparseTensor) and decrement:
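
The gather-and-reoffset logic added to `_separate` can be checked in isolation. A sketch with invented numbers, treating `value` as an increment-carrying attribute (e.g. flattened edge_index entries already shifted by `incs`):

import torch
from torch.nn.functional import pad

slices = torch.tensor([0, 2, 5, 6])        # graphs with 2, 3 and 1 elements
incs = torch.tensor([0, 2, 5])             # old per-graph offsets
value = torch.tensor([0, 1, 2, 3, 4, 5])   # already offset by `incs`
idx = torch.tensor([1, 2])                 # select graphs 1 and 2

# Gather the rows of all selected graphs in one shot:
graph_slice = torch.concat(
    [torch.arange(slices[i], slices[i + 1]) for i in idx])
sub = value[graph_slice]  # tensor([2, 3, 4, 5])

# Swap the old offsets for offsets recomputed on the subset:
nelem_new = slices.diff()[idx]                             # tensor([3, 1])
old_offset = incs[idx]                                     # tensor([2, 5])
new_offset = pad(incs.diff()[idx[:-1]], (1, 0)).cumsum(0)  # tensor([0, 3])
sub = sub + (-old_offset + new_offset).repeat_interleave(nelem_new)
print(sub)  # tensor([0, 1, 2, 3])
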
2 changes: 1 addition & 1 deletion torch_geometric/data/storage.py
@@ -313,7 +313,7 @@ def can_infer_num_nodes(self):
@property
def num_nodes(self) -> Optional[int]:
# We sequentially access attributes that reveal the number of nodes.
-        if 'num_nodes' in self:
+        if 'num_nodes' in self._mapping:
return self['num_nodes']
for key, value in self.items():
if isinstance(value, Tensor) and key in N_KEYS:
