Implementation of Batch.{from_batch_list,from_batch_index,add_graph_attr,set_edge_attr,set_edges} pyg-team#8414
mova committed Sep 12, 2024
1 parent 6d9e850 commit 9fb3ea1
Showing 7 changed files with 571 additions and 86 deletions.
4 changes: 2 additions & 2 deletions CHANGELOG.md
@@ -6,7 +6,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
## \[2.6.0\] - 2024-MM-DD

### Added

- Added implementation of `Batch.{from_batch_list,from_batch_index,add_graph_attr,set_edge_attr,set_edges}` ([#8414](https://github.com/pyg-team/pytorch_geometric/pull/8414))
- Added the `GRetriever` model ([#9480](https://github.com/pyg-team/pytorch_geometric/pull/9480))
- Added the `ClusterPooling` layer ([#9627](https://github.com/pyg-team/pytorch_geometric/pull/9627))
- Added the `LinkPredMRR` metric ([#9632](https://github.com/pyg-team/pytorch_geometric/pull/9632))
@@ -44,7 +44,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added documentation on environment setup on XPU device ([#9407](https://github.com/pyg-team/pytorch_geometric/pull/9407))

### Changed

- Convert `Batch.index_select` to a full slicing operation returning a new batch instead of a list of `Data` ([#8414](https://github.com/pyg-team/pytorch_geometric/pull/8414))
- Use `torch.load(weights_only=True)` by default ([#9618](https://github.com/pyg-team/pytorch_geometric/pull/9618))
- Adapt `cugraph` examples to its new API ([#9541](https://github.com/pyg-team/pytorch_geometric/pull/9541))
- Allow optional but untyped tensors in `MessagePassing` ([#9494](https://github.com/pyg-team/pytorch_geometric/pull/9494))
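Taken together, the changelog entries and the tests added below describe a new incremental construction API for `Batch`. A minimal usage sketch, with method names and argument shapes inferred from the test code in this commit rather than from any documentation:

```python
import torch

from torch_geometric.data import Batch, Data

# A reference batch built the established way.
data_list = [Data(x=torch.rand(n, 3)) for n in (3, 2, 4)]
batch_ref = Batch.from_data_list(data_list)

# Incremental construction, as exercised in test_batch_manipulation.py:
# start from the node-to-graph assignment vector, then attach attributes.
batch = Batch.from_batch_index(batch_ref.batch)
batch.add_node_attr("x", torch.vstack([d.x for d in data_list]))

# Edges use graph-local node indices plus an edge-to-graph assignment vector.
edge_index = torch.tensor([[0, 1, 0, 1, 0, 1],
                           [1, 2, 1, 0, 1, 2]])
edge_to_graph = torch.tensor([0, 0, 1, 1, 2, 2])
batch.set_edge_index(edge_index, edge_to_graph)
batch.set_edge_attr(torch.rand(6))

# One value per graph.
batch.add_graph_attr("ga", torch.rand(3))

# Merging existing (sub-)batches into a single batch.
merged = Batch.from_batch_list([
    Batch.from_data_list(data_list[:2]),
    Batch.from_data_list(data_list[2:]),
])
```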
64 changes: 34 additions & 30 deletions test/data/test_batch.py
@@ -63,24 +63,27 @@ def test_batch_basic():
assert batch.batch.tolist() == [0, 0, 0, 1, 1, 2, 2, 2, 2]
assert batch.ptr.tolist() == [0, 3, 5, 9]

assert str(batch[0]) == ("Data(x=[3], edge_index=[2, 4], y=[1], "
"string='1', array=[2], num_nodes=3)")
assert str(batch[1]) == ("Data(x=[2], edge_index=[2, 2], y=[1], "
"string='2', array=[3], num_nodes=2)")
assert str(batch[2]) == ("Data(x=[4], edge_index=[2, 6], y=[1], "
"string='3', array=[4], num_nodes=4)")

assert len(batch.index_select([1, 0])) == 2
assert len(batch.index_select(torch.tensor([1, 0]))) == 2
assert len(batch.index_select(torch.tensor([True, False]))) == 1
assert len(batch.index_select(np.array([1, 0], dtype=np.int64))) == 2
assert len(batch.index_select(np.array([True, False]))) == 1
assert len(batch[:2]) == 2
graphs = batch[0], batch[1], batch[2]
assert all([isinstance(e, Data) for e in graphs])
assert [e.x.shape[0] for e in graphs] == [3, 2, 4]
assert [e.y.shape[0] for e in graphs] == [1, 1, 1]
assert [list(e.edge_index.shape) for e in graphs] == [[2, 4], [2, 2],
[2, 6]]
assert [e.string for e in graphs] == ["1", "2", "3"]
assert [e.num_nodes for e in graphs] == [3, 2, 4]

assert batch.index_select([1, 0]).num_graphs == 2
assert batch.index_select(torch.tensor([1, 0])).num_graphs == 2
assert batch.index_select(torch.tensor([True, False,
False])).num_graphs == 1
assert batch.index_select(np.array([1, 0], dtype=np.int64)).num_graphs == 2
assert batch.index_select(np.array([True, False, False])).num_graphs == 1
assert batch[:2].num_graphs == 2

data_list = batch.to_data_list()
assert len(data_list) == 3

assert len(data_list[0]) == 6
assert set(data_list[0].keys()) == set(data1.keys())
assert data_list[0].x.tolist() == [1, 2, 3]
assert data_list[0].y.tolist() == [1]
assert data_list[0].edge_index.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]]
@@ -111,21 +114,24 @@ def test_batch_basic():
def test_index():
index1 = Index([0, 1, 1, 2], dim_size=3, is_sorted=True)
index2 = Index([0, 1, 1, 2, 2, 3], dim_size=4, is_sorted=True)
index3 = Index([0, 1, 1, 2], dim_size=3, is_sorted=True)

data1 = Data(index=index1, num_nodes=3)
data2 = Data(index=index2, num_nodes=4)
data3 = Data(index=index1, num_nodes=3)

batch = Batch.from_data_list([data1, data2])
batch = Batch.from_data_list([data1, data2, data3])

assert len(batch) == 2
assert batch.batch.equal(torch.tensor([0, 0, 0, 1, 1, 1, 1]))
assert batch.ptr.equal(torch.tensor([0, 3, 7]))
assert len(batch) == 3
assert batch.batch.equal(torch.tensor([0, 0, 0, 1, 1, 1, 1, 2, 2, 2]))
assert batch.ptr.equal(torch.tensor([0, 3, 7, 10]))
assert isinstance(batch.index, Index)
assert batch.index.equal(torch.tensor([0, 1, 1, 2, 3, 4, 4, 5, 5, 6]))
assert batch.index.dim_size == 7
assert batch.index.equal(
torch.tensor([0, 1, 1, 2, 3, 4, 4, 5, 5, 6, 7, 8, 8, 9]))
assert batch.index.dim_size == 10
assert batch.index.is_sorted

for i, index in enumerate([index1, index2]):
for i, index in enumerate([index1, index2, index3]):
data = batch[i]
assert isinstance(data.index, Index)
assert data.index.equal(index)
@@ -149,18 +155,16 @@ def test_edge_index():
data1 = Data(edge_index=edge_index1)
data2 = Data(edge_index=edge_index2)

batch = Batch.from_data_list([data1, data2])
batch = Batch.from_data_list([data1, data2, data1.clone()])

assert len(batch) == 2
assert batch.batch.equal(torch.tensor([0, 0, 0, 1, 1, 1, 1]))
assert batch.ptr.equal(torch.tensor([0, 3, 7]))
assert len(batch) == 3
assert batch.batch.equal(torch.tensor([0, 0, 0, 1, 1, 1, 1, 2, 2, 2]))
assert batch.ptr.equal(torch.tensor([0, 3, 7, 10]))
assert isinstance(batch.edge_index, EdgeIndex)
assert batch.edge_index.equal(
torch.tensor([
[0, 1, 1, 2, 4, 3, 5, 4, 6, 5],
[1, 0, 2, 1, 3, 4, 4, 5, 5, 6],
]))
assert batch.edge_index.sparse_size() == (7, 7)
torch.tensor([[0, 1, 1, 2, 4, 3, 5, 4, 6, 5, 7, 8, 8, 9],
[1, 0, 2, 1, 3, 4, 4, 5, 5, 6, 8, 7, 9, 8]]))
assert batch.edge_index.sparse_size() == (10, 10)
assert batch.edge_index.sort_order is None
assert not batch.edge_index.is_undirected

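The updated assertions above reflect the changelog's "Changed" entry: `Batch.index_select` and slicing now return a new `Batch` (hence the `.num_graphs` checks) instead of a Python list of `Data` objects, while indexing a single graph still yields `Data`. A short sketch restating those semantics, assuming the indexing forms exercised in the test:

```python
import numpy as np
import torch

from torch_geometric.data import Batch, Data

batch = Batch.from_data_list([Data(x=torch.rand(n, 3)) for n in (3, 2, 4)])

# Integer lists/tensors/arrays, boolean masks, and slices all yield a Batch.
assert batch.index_select([1, 0]).num_graphs == 2
assert batch.index_select(torch.tensor([True, False, False])).num_graphs == 1
assert batch.index_select(np.array([1, 0], dtype=np.int64)).num_graphs == 2
assert batch[:2].num_graphs == 2

# Indexing a single graph still returns a Data object.
assert isinstance(batch[0], Data)
```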
220 changes: 220 additions & 0 deletions test/data/test_batch_manipulation.py
@@ -0,0 +1,220 @@
import torch

from torch_geometric.data import Batch, Data

device = torch.device("cpu")


def test_batch_add_node_attr():
batch_size = 3
node_range = (1, 10)
nodes_v = torch.randint(*node_range, (batch_size, ))
x_list = [torch.rand(n, 3).to(device) for n in nodes_v]

batch_list = [Data(x=x) for x in x_list]
batch_truth = Batch.from_data_list(batch_list)

batch_idx = batch_truth.batch

batch = Batch.from_batch_index(batch_idx)

batch.add_node_attr("x", torch.vstack(x_list))

compare(batch_truth, batch)


def test_batch_set_edge_index():
batch_size = 4
node_range = (2, 5)
nodes_v = torch.randint(*node_range, (batch_size, )).to(device)
edges_per_graph = torch.cat([
torch.randint(1, num_nodes, size=(1, )).to(device)
for num_nodes in nodes_v
]).to(device)
x_list = [torch.rand(num_nodes, 3).to(device) for num_nodes in nodes_v]
edges_list = [
torch.vstack([
torch.randint(0, num_nodes, size=(num_edges, )),
torch.randint(0, num_nodes, size=(num_edges, )),
]).to(device) for num_nodes, num_edges in zip(nodes_v, edges_per_graph)
]

batch_list = [
Data(x=x, edge_index=edges) for x, edges in zip(x_list, edges_list)
]
batch_truth = Batch.from_data_list(batch_list)

batch_idx = batch_truth.batch
batch = Batch.from_batch_index(batch_idx)
batch.add_node_attr("x", torch.vstack(x_list))

batchidx_per_edge = torch.cat([
torch.ones(num_edges).long().to(device) * igraph
for igraph, num_edges in enumerate(edges_per_graph)
])
batch.set_edge_index(torch.hstack(edges_list), batchidx_per_edge)
compare(batch_truth, batch)


def test_batch_set_edge_attr():
batch_size = 4
node_range = (2, 5)
nodes_v = torch.randint(*node_range, (batch_size, )).to(device)
edges_per_graph = torch.cat([
torch.randint(1, num_nodes, size=(1, )).to(device)
for num_nodes in nodes_v
])
x_list = [torch.rand(num_nodes, 3).to(device) for num_nodes in nodes_v]
edges_list = [
torch.vstack([
torch.randint(0, num_nodes, size=(num_edges, )),
torch.randint(0, num_nodes, size=(num_edges, )),
]).to(device) for num_nodes, num_edges in zip(nodes_v, edges_per_graph)
]
edge_attr_list = [
torch.rand(num_edges).to(device) for num_edges in edges_per_graph
]

batch_list = [
Data(x=x, edge_index=edges, edge_attr=ea)
for x, edges, ea in zip(x_list, edges_list, edge_attr_list)
]
batch_truth = Batch.from_data_list(batch_list)

batch_idx = batch_truth.batch
batch = Batch.from_batch_index(batch_idx)
batch.add_node_attr("x", torch.vstack(x_list))

batchidx_per_edge = torch.cat([
torch.ones(num_edges).to(device).long() * igraph
for igraph, num_edges in enumerate(edges_per_graph)
])
batch.set_edge_index(torch.hstack(edges_list), batchidx_per_edge)
batch.set_edge_attr(torch.hstack(edge_attr_list))
compare(batch_truth, batch)


def test_batch_add_graph_attr():
batch_size = 3
node_range = (1, 10)
nodes_v = torch.randint(*node_range, (batch_size, )).to(device)
x_list = [torch.rand(n, 3).to(device) for n in nodes_v]
graph_attr_list = torch.rand(batch_size).to(device)

batch_list = [Data(x=x, ga=ga) for x, ga in zip(x_list, graph_attr_list)]
batch_truth = Batch.from_data_list(batch_list)

batch_idx = batch_truth.batch

batch = Batch.from_batch_index(batch_idx)

batch.add_node_attr("x", torch.vstack(x_list))
batch.add_graph_attr("ga", graph_attr_list)
compare(batch_truth, batch)


def test_from_batch_list():
batch_size = 12
node_range = (2, 5)
nodes_v = torch.randint(*node_range, (batch_size, )).to(device)
edges_per_graph = torch.cat([
torch.randint(1, num_nodes, size=(1, )).to(device)
for num_nodes in nodes_v
])
x_list = [torch.rand(num_nodes, 3).to(device) for num_nodes in nodes_v]
edges_list = [
torch.vstack([
torch.randint(0, num_nodes, size=(num_edges, )),
torch.randint(0, num_nodes, size=(num_edges, )),
]).to(device) for num_nodes, num_edges in zip(nodes_v, edges_per_graph)
]
edge_attr_list = [
torch.rand(num_edges).to(device) for num_edges in edges_per_graph
]
graph_attr_list = torch.rand(batch_size).to(device)

batch_list = [
Data(x=x, edge_index=edges,
edge_attr=ea, ga=ga) for x, edges, ea, ga in zip(
x_list, edges_list, edge_attr_list, graph_attr_list)
]
batch_truth = Batch.from_data_list(batch_list)
batch = Batch.from_batch_list([
Batch.from_data_list(batch_list[:3]),
Batch.from_data_list(batch_list[3:5]),
Batch.from_data_list(batch_list[5:7]),
Batch.from_data_list(batch_list[7:]),
])

compare(batch_truth, batch)


def test_batch_slice():
batch_size = 9
node_range = (3, 8)
nodes_v = torch.randint(*node_range, (batch_size, )).to(device)
edges_per_graph = torch.cat([
torch.randint(1, num_nodes, size=(1, )).to(device)
for num_nodes in nodes_v
])
x_list = [torch.rand(num_nodes, 3).to(device) for num_nodes in nodes_v]
edges_list = [
torch.vstack([
torch.randint(0, num_nodes, size=(num_edges, )),
torch.randint(0, num_nodes, size=(num_edges, )),
]).to(device) for num_nodes, num_edges in zip(nodes_v, edges_per_graph)
]
edge_attr_list = [
torch.rand(num_edges).to(device) for num_edges in edges_per_graph
]
graph_attr_list = torch.rand(batch_size, 1, 5).to(device)

batch_list = [
Data(x=x, edge_index=edges,
edge_attr=ea, ga=ga) for x, edges, ea, ga in zip(
x_list, edges_list, edge_attr_list, graph_attr_list)
]
bslice = torch.FloatTensor(batch_size).uniform_() > 0.4
batch_full = Batch.from_data_list(batch_list)
batch_truth = Batch.from_data_list(
[batch_list[e] for e in bslice.nonzero().squeeze()])
batch_new = batch_full[bslice]
compare(batch_truth, batch_new)


def compare(ba: Batch, bb: Batch):
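    # Check that the two batches agree on their keys, per-key tensors, and the
    # internal `_slice_dict`/`_inc_dict` bookkeeping used when unbatching.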
if set(ba.keys()) != set(bb.keys()):
raise Exception()
assert (ba.batch == bb.batch).all()
assert (ba.ptr == bb.ptr).all()
for k in ba.keys():
try:
rec_comp(ba[k], bb[k])
except Exception as e:
raise Exception(
f"Batch comparison failed tensor for key {k}.") from e
if k in ba._slice_dict or k in bb._slice_dict:
try:
rec_comp(ba._slice_dict[k], bb._slice_dict[k])
except Exception as e:
raise Exception(
f"Batch comparison failed _slice_dict for key {k}.") from e
if k in ba._inc_dict or k in bb._inc_dict:
try:
rec_comp(ba._inc_dict[k], bb._inc_dict[k])
except Exception as e:
raise Exception(
f"Batch comparison failed _inc_dict for key {k}.") from e


def rec_comp(a, b):
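    # Recursively compare nested dicts and tensors, raising on any mismatch.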
    if type(a) is not type(b):
raise Exception()
if isinstance(a, dict):
if not set(a.keys()) == set(b.keys()):
raise Exception()
for k in a:
rec_comp(a[k], b[k])
if isinstance(a, torch.Tensor):
if not (a == b).all():
raise Exception()
