Skip to content

Commit

Permalink
torch.save and torch.load tests for EdgeIndex (pyg-team#8451)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Nov 27, 2023
1 parent cea885e commit e8e39a2
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions test/data/test_edge_index.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os.path as osp

import torch

from torch_geometric.data.edge_index import EdgeIndex
Expand Down Expand Up @@ -161,3 +163,20 @@ def test_narrow():
out = adj.narrow(dim=0, start=0, length=1)
assert torch.equal(out, torch.tensor([[0, 1, 1, 2]]))
assert not isinstance(out, EdgeIndex)


def test_save_and_load(tmp_path):
adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row')
adj.fill_cache()

assert adj.sort_order == 'row'
assert torch.equal(adj._rowptr, torch.tensor([0, 1, 3, 4]))

path = osp.join(tmp_path, 'edge_index.pt')
torch.save(adj, path)
out = torch.load(path)

assert isinstance(out, EdgeIndex)
assert torch.equal(out, torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]))
assert out.sort_order == 'row'
assert torch.equal(out._rowptr, torch.tensor([0, 1, 3, 4]))

0 comments on commit e8e39a2

Please sign in to comment.