Skip to content

Commit

Permalink
Add option to add identity matrix (#20)
Browse files Browse the repository at this point in the history
* add test for isolated vertex handling
* fix seed for reproducible tests

Co-authored-by: Charles Tapley Hoyt <[email protected]>
  • Loading branch information
mberr and cthoyt committed Jul 20, 2022
1 parent 4399b34 commit 946faa6
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 15 deletions.
16 changes: 12 additions & 4 deletions src/torch_ppr/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def page_rank(
adj: Optional[torch.Tensor] = None,
edge_index: Optional[torch.LongTensor] = None,
num_nodes: Optional[int] = None,
add_identity: bool = False,
max_iter: int = 1_000,
alpha: float = 0.05,
epsilon: float = 1.0e-04,
Expand All @@ -48,6 +49,8 @@ def page_rank(
:param num_nodes:
the number of nodes used to determine the shape of the adjacency matrix.
If ``None``, and ``adj`` is not already provided, it is inferred from ``edge_index``.
:param add_identity:
whether to add an identity matrix to ``A`` to ensure that each node has a degree of at least one.
:param max_iter: ``max_iter > 0``
the maximum number of iterations
Expand All @@ -69,7 +72,9 @@ def page_rank(
the page-rank vector, i.e., a score between 0 and 1 for each node.
"""
# normalize inputs
adj = prepare_page_rank_adjacency(adj=adj, edge_index=edge_index, num_nodes=num_nodes)
adj = prepare_page_rank_adjacency(
adj=adj, edge_index=edge_index, num_nodes=num_nodes, add_identity=add_identity
)
validate_adjacency(adj=adj)

x0 = prepare_x0(x0=x0, n=adj.shape[0])
Expand All @@ -96,6 +101,7 @@ def personalized_page_rank(
*,
adj: Optional[torch.Tensor] = None,
edge_index: Optional[torch.LongTensor] = None,
add_identity: bool = False,
num_nodes: Optional[int] = None,
indices: Optional[torch.Tensor] = None,
device: DeviceHint = None,
Expand All @@ -115,6 +121,8 @@ def personalized_page_rank(
:param num_nodes:
the number of nodes used to determine the shape of the adjacency matrix.
If ``None``, and ``adj`` is not already provided, it is inferred from ``edge_index``.
:param add_identity:
whether to add an identity matrix to ``A`` to ensure that each node has a degree of at least one.
:param indices: shape: ``(k,)``
the node indices for which to calculate the PPR. Defaults to all nodes.
Expand All @@ -131,9 +139,9 @@ def personalized_page_rank(
# resolve device first
device = resolve_device(device=device)
# prepare adjacency and indices only once
adj = prepare_page_rank_adjacency(adj=adj, edge_index=edge_index, num_nodes=num_nodes).to(
device=device
)
adj = prepare_page_rank_adjacency(
adj=adj, edge_index=edge_index, num_nodes=num_nodes, add_identity=add_identity
).to(device=device)
validate_adjacency(adj=adj)

if indices is None:
Expand Down
28 changes: 23 additions & 5 deletions src/torch_ppr/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,10 +157,26 @@ def validate_adjacency(adj: torch.Tensor, n: Optional[int] = None, rtol: float =
)


def sparse_diagonal(values: torch.Tensor) -> torch.Tensor:
"""Create a sparse diagonal matrix with the given values.
:param values: shape: ``(n,)``
the values
:return: shape: ``(n, n)``
a sparse diagonal matrix
"""
return torch.sparse_coo_tensor(
indices=torch.arange(values.shape[0], device=values.device).unsqueeze(dim=0).repeat(2, 1),
values=values,
)


def prepare_page_rank_adjacency(
adj: Optional[torch.Tensor] = None,
edge_index: Optional[torch.LongTensor] = None,
num_nodes: Optional[int] = None,
add_identity: bool = False,
) -> torch.Tensor:
"""
Prepare the page-rank adjacency matrix.
Expand All @@ -180,6 +196,8 @@ def prepare_page_rank_adjacency(
:param num_nodes:
the number of nodes used to determine the shape of the adjacency matrix.
If ``None``, and ``adj`` is not already provided, it is inferred from ``edge_index``.
:param add_identity:
whether to add an identity matrix to ``A`` to ensure that each node has a degree of at least one.
:raises ValueError:
if neither is provided, or the adjacency matrix is invalid
Expand All @@ -197,15 +215,15 @@ def prepare_page_rank_adjacency(
adj = edge_index_to_sparse_matrix(edge_index=edge_index, num_nodes=num_nodes)
# symmetrize
adj = adj + adj.t()
# TODO: should we add an identity matrix here?
# add identity matrix if requested
if add_identity:
adj = adj + sparse_diagonal(torch.ones(adj.shape[0], dtype=adj.dtype, device=adj.device))

# adjacency normalization: normalize to col-sum = 1
degree_inv = torch.reciprocal(
torch.sparse.sum(adj, dim=0).to_dense().clamp_min(min=torch.finfo(adj.dtype).eps)
)
degree_inv = torch.sparse_coo_tensor(
indices=torch.arange(degree_inv.shape[0], device=adj.device).unsqueeze(dim=0).repeat(2, 1),
values=degree_inv,
)
degree_inv = sparse_diagonal(values=degree_inv)
return torch.sparse.mm(adj, degree_inv)


Expand Down
18 changes: 17 additions & 1 deletion tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@ class APITest(unittest.TestCase):

def setUp(self) -> None:
"""Prepare data."""
generator = torch.manual_seed(42)
self.edge_index = torch.cat(
[
torch.randint(self.num_nodes, size=(2, self.num_edges - self.num_nodes)),
torch.randint(
self.num_nodes, size=(2, self.num_edges - self.num_nodes), generator=generator
),
# ensure connectivity
torch.arange(self.num_nodes).unsqueeze(0).repeat(2, 1),
],
Expand Down Expand Up @@ -50,3 +53,16 @@ def test_page_rank_manual(self):
x = page_rank(edge_index=edge_index)
# verify that central node has the largest PR value
assert x.argmax() == 1

def test_page_rank_isolated_vertices(self):
"""Test Page-Rank with isolated vertices."""
# create isolated node, ID=0
edge_index = self.edge_index + 1
x = page_rank(edge_index=edge_index, add_identity=True)
# isolated node has only one self-loop -> no change in mass to initial mass
self.assertAlmostEqual(x[0].item(), 1 / (self.num_nodes + 1))
# verify that other nodes are unaffected
x2 = page_rank(edge_index=self.edge_index)
# rescale
x2 = x2 * (self.num_nodes / (self.num_nodes + 1))
assert torch.allclose(x2, x[1:], atol=1.0e-02)
26 changes: 21 additions & 5 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import unittest
from typing import Counter, Optional, Tuple

import pytest
import torch
from torch.nn import functional

Expand Down Expand Up @@ -99,14 +100,18 @@ def test_validate_adjacancy(self):

def test_prepare_page_rank_adjacency(self):
"""Test adjacency preparation."""
for (adj, edge_index) in (
for (adj, edge_index, add_identity) in (
# from edge index
(None, self.edge_index),
(None, self.edge_index, False),
# passing through adjacency matrix
(self.adj, None),
(self.adj, self.edge_index),
(self.adj, None, False),
(self.adj, self.edge_index, False),
# add identity
(None, self.edge_index, True),
):
adj2 = utils.prepare_page_rank_adjacency(adj=adj, edge_index=edge_index)
adj2 = utils.prepare_page_rank_adjacency(
adj=adj, edge_index=edge_index, add_identity=add_identity
)
utils.validate_adjacency(adj=adj2, n=self.num_nodes)
if adj is not None:
assert adj is adj2
Expand Down Expand Up @@ -165,3 +170,14 @@ def test_batched_personalized_page_rank(self):
adj=self.adj, indices=torch.arange(self.num_nodes), batch_size=self.num_nodes // 3
)
utils.validate_x(x)


@pytest.mark.parametrize("n", [8, 16])
def test_sparse_diagonal(n: int):
"""Test for sparse diagonal matrix creation."""
values = torch.rand(n)
matrix = utils.sparse_diagonal(values=values)
assert torch.is_tensor(matrix)
assert matrix.shape == (n, n)
assert matrix.is_sparse
assert torch.allclose(matrix.to_dense(), torch.diag(values))

0 comments on commit 946faa6

Please sign in to comment.