diff --git a/src/torch_ppr/utils.py b/src/torch_ppr/utils.py index 398d5cf..ca8cd75 100644 --- a/src/torch_ppr/utils.py +++ b/src/torch_ppr/utils.py @@ -95,7 +95,7 @@ def edge_index_to_sparse_matrix( ) -def validate_adjacency(adj: torch.Tensor, n: Optional[int] = None): +def validate_adjacency(adj: torch.Tensor, n: Optional[int] = None, rtol: float = 1.0e-04): """ Validate the page-rank adjacency matrix. @@ -108,6 +108,8 @@ def validate_adjacency(adj: torch.Tensor, n: Optional[int] = None): the adjacency matrix :param n: the number of nodes + :param rtol: + the tolerance for checking the sum is close to 1.0 :raises ValueError: if the adjacency matrix is invalid @@ -143,8 +145,16 @@ def validate_adjacency(adj: torch.Tensor, n: Optional[int] = None): else: # hotfix until torch.sparse.sum is implemented adj_sum = adj.t() @ torch.ones(adj.shape[0]) - if not torch.allclose(adj_sum, torch.ones_like(adj_sum), rtol=1.0e-04): - raise ValueError(f"Invalid column sum: {adj_sum}. expected 1.0") + exp_sum = torch.ones_like(adj_sum) + mask = adj_sum == 0 + if mask.any(): + logger.warning(f"Adjacency contains {mask.sum().item()} isolated nodes.") + exp_sum[mask] = 0.0 + if not torch.allclose(adj_sum, exp_sum, rtol=rtol): + raise ValueError( + f"Invalid column sum: {adj_sum} (min: {adj_sum.min().item()}, max: {adj_sum.max().item()}). " + f"Expected 1.0 with a relative tolerance of {rtol}.", + ) def prepare_page_rank_adjacency( @@ -187,8 +197,11 @@ def prepare_page_rank_adjacency( adj = edge_index_to_sparse_matrix(edge_index=edge_index, num_nodes=num_nodes) # symmetrize adj = adj + adj.t() + # TODO: should we add an identity matrix here? # adjacency normalization: normalize to col-sum = 1 - degree_inv = torch.reciprocal(torch.sparse.sum(adj, dim=0).to_dense()) + degree_inv = torch.reciprocal( + torch.sparse.sum(adj, dim=0).to_dense().clamp_min(min=torch.finfo(adj.dtype).eps) + ) degree_inv = torch.sparse_coo_tensor( indices=torch.arange(degree_inv.shape[0], device=adj.device).unsqueeze(dim=0).repeat(2, 1), values=degree_inv,