Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Neighbourhood Complex Lifting (Graph to Simplicial) #41

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions configs/transforms/feature_liftings/element_wise_mean.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
transform_name: "ElementwiseMean"
transform_type: null
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
transform_type: 'lifting'
transform_name: "NeighborhoodComplexLifting"
preserve_edge_attr: False
signed: True
feature_lifting: ElementwiseMean
complex_dim: 5
10 changes: 9 additions & 1 deletion modules/transforms/data_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,32 @@
NodeFeaturesToFloat,
OneHotDegreeFeatures,
)
from modules.transforms.feature_liftings.feature_liftings import ProjectionSum
from modules.transforms.feature_liftings.feature_liftings import (
ElementwiseMean,
ProjectionSum,
)
from modules.transforms.liftings.graph2cell.cycle_lifting import CellCycleLifting
from modules.transforms.liftings.graph2hypergraph.knn_lifting import (
HypergraphKNNLifting,
)
from modules.transforms.liftings.graph2simplicial.clique_lifting import (
SimplicialCliqueLifting,
)
from modules.transforms.liftings.graph2simplicial.neighborhood_complex_lifting import (
NeighborhoodComplexLifting,
)

TRANSFORMS = {
# Graph -> Hypergraph
"HypergraphKNNLifting": HypergraphKNNLifting,
# Graph -> Simplicial Complex
"SimplicialCliqueLifting": SimplicialCliqueLifting,
"NeighborhoodComplexLifting": NeighborhoodComplexLifting,
# Graph -> Cell Complex
"CellCycleLifting": CellCycleLifting,
# Feature Liftings
"ProjectionSum": ProjectionSum,
"ElementwiseMean": ElementwiseMean,
# Data Manipulations
"Identity": IdentityTransform,
"NodeDegrees": NodeDegrees,
Expand Down
89 changes: 89 additions & 0 deletions modules/transforms/feature_liftings/feature_liftings.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
import torch.nn.functional as F
import torch_geometric


Expand Down Expand Up @@ -54,3 +55,91 @@ def forward(
The lifted data.
"""
return self.lift_features(data)

class ElementwiseMean(torch_geometric.transforms.BaseTransform):
r"""Lifts r-cell features to r+1-cells by taking the mean of the lower
dimensional features.

Parameters
----------
**kwargs : optional
Additional arguments for the class.
"""

def __init__(self, **kwargs):
super().__init__()

def lift_features(
self, data: torch_geometric.data.Data | dict
) -> torch_geometric.data.Data | dict:
r"""Projects r-cell features of a graph to r+1-cell structures using the incidence matrix.

Parameters
----------
data : torch_geometric.data.Data | dict
The input data to be lifted.

Returns
-------
torch_geometric.data.Data | dict
The lifted data."""

# Find the maximum dimension of the input data
max_dim = max([int(key.split("_")[-1]) for key in data if "x_idx" in key])

# Create a list of all x_idx tensors
x_idx_tensors = [data[f"x_idx_{i}"] for i in range(max_dim + 1)]

# Find the maximum sizes
max_simplices = max(tensor.size(0) for tensor in x_idx_tensors)
max_nodes = max(tensor.size(1) for tensor in x_idx_tensors)

# Pad tensors to have the same size
padded_tensors = [F.pad(tensor, (0, max_nodes - tensor.size(1), 0, max_simplices - tensor.size(0)))
for tensor in x_idx_tensors]

# Stack all x_idx tensors
all_indices = torch.stack(padded_tensors)

# Create a mask for valid indices
mask = all_indices != 0

# Replace 0s with a valid index (e.g., 0) to avoid indexing errors
all_indices = all_indices.clamp(min=0)

# Get all embeddings at once
all_embeddings = data["x_0"][all_indices]

# Apply mask to set padded embeddings to 0
all_embeddings = all_embeddings * mask.unsqueeze(-1).float()

# Compute sum and count of non-zero elements
embedding_sum = all_embeddings.sum(dim=2)
count = mask.sum(dim=2).clamp(min=1) # Avoid division by zero

# Compute mean
mean_embeddings = embedding_sum / count.unsqueeze(-1)

# Assign results back to data dictionary
for i in range(1, max_dim + 1):
original_size = x_idx_tensors[i].size(0)
data[f"x_{i}"] = mean_embeddings[i, :original_size]

return data

def forward(
self, data: torch_geometric.data.Data | dict
) -> torch_geometric.data.Data | dict:
r"""Applies the lifting to the input data.

Parameters
----------
data : torch_geometric.data.Data | dict
The input data to be lifted.

Returns
-------
torch_geometric.data.Data | dict
The lifted data.
"""
return self.lift_features(data)
2 changes: 1 addition & 1 deletion modules/transforms/liftings/graph2simplicial/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def _get_lifted_topology(
list(simplicial_complex.get_simplex_attributes("features", 0).values())
)
# If new edges have been added during the lifting process, we discard the edge attributes
if self.contains_edge_attr and simplicial_complex.shape[1] == (
if self.preserve_edge_attr and simplicial_complex.shape[1] == (
graph.number_of_edges()
):
lifted_topology["x_1"] = torch.stack(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import networkx as nx
import torch
from toponetx.classes import SimplicialComplex
from torch_geometric.data import Data
from torch_geometric.utils.convert import to_networkx

from modules.transforms.liftings.graph2simplicial.base import Graph2SimplicialLifting


class NeighborhoodComplexLifting(Graph2SimplicialLifting):
""" Lifts graphs to a simplicial complex domain by identifying the neighborhood complex as k-simplices.
The neighborhood complex of a node u is the set of nodes that share a neighbor with u.

"""
def __init__(self, **kwargs):
super().__init__(**kwargs)

def lift_topology(self, data: Data) -> dict:
graph: nx.Graph = to_networkx(data, to_undirected=True)
simplicial_complex = SimplicialComplex(simplices=graph)

# For every node u
for u in graph.nodes:
neighbourhood_complex = set()
neighbourhood_complex.add(u)
# Check it's neighbours
for v in graph.neighbors(u):
# For every other node w != u ^ w != v
for w in graph.nodes:
# w == u
if w == u:
continue
# w == v
if w == v:
continue

# w and u share v as it's neighbour
if v in graph.neighbors(w):
neighbourhood_complex.add(w)
# Do not add 0-simplices
if len(neighbourhood_complex) < 2:
continue
# Do not add i-simplices if the maximum dimension is lower
if len(neighbourhood_complex) > self.complex_dim + 1:
continue
simplicial_complex.add_simplex(neighbourhood_complex)

feature_dict = {
i: f for i, f in enumerate(data["x"])
}

simplicial_complex.set_simplex_attributes(feature_dict, name="features")

return self._get_lifted_topology(simplicial_complex, graph)

def _get_lifted_topology(self, simplicial_complex: SimplicialComplex, graph: nx.Graph) -> dict:
data = super()._get_lifted_topology(simplicial_complex, graph)


for r in range(simplicial_complex.dim+1):
data[f"x_idx_{r}"] = torch.tensor(simplicial_complex.skeleton(r))

return data
6 changes: 5 additions & 1 deletion modules/transforms/liftings/lifting.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@
from torch_geometric.utils.undirected import is_undirected, to_undirected

from modules.transforms.data_manipulations.manipulations import IdentityTransform
from modules.transforms.feature_liftings.feature_liftings import ProjectionSum
from modules.transforms.feature_liftings.feature_liftings import (
ElementwiseMean,
ProjectionSum,
)

# Implemented Feature Liftings
FEATURE_LIFTINGS = {
"ProjectionSum": ProjectionSum,
"ElementwiseMean": ElementwiseMean,
None: IdentityTransform,
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@

import networkx as nx
import torch
from torch_geometric.utils.convert import from_networkx

from modules.data.utils.utils import load_manual_graph
from modules.transforms.liftings.graph2simplicial.neighborhood_complex_lifting import (
NeighborhoodComplexLifting,
)


class TestNeighborhoodComplexLifting:
"""Test the NeighborhoodComplexLifting class."""

def setup_method(self):
# Load the graph
self.data = load_manual_graph()

# Initialize the NeighborhoodComplexLifting class for dim=3
self.lifting_signed = NeighborhoodComplexLifting(complex_dim=3, signed=True)
self.lifting_unsigned = NeighborhoodComplexLifting(complex_dim=3, signed=False)
self.lifting_high = NeighborhoodComplexLifting(complex_dim=7, signed=False)

# Intialize an empty graph for testing purpouses
self.empty_graph = nx.empty_graph(10)
self.empty_data = from_networkx(self.empty_graph)
self.empty_data["x"] = torch.rand((10, 10))

# Intialize a start graph for testing
self.star_graph = nx.star_graph(5)
self.star_data = from_networkx(self.star_graph)
self.star_data["x"] = torch.rand((6, 1))

# Intialize a random graph for testing purpouses
self.random_graph = nx.fast_gnp_random_graph(5, 0.5)
self.random_data = from_networkx(self.random_graph)
self.random_data["x"] = torch.rand((5, 1))


def has_neighbour(self, simplex_points: list[set]) -> tuple[bool, set[int]]:
""" Verifies that the maximal simplices
of Data representation of a simplicial complex
share a neighbour.
"""
for simplex_point_a in simplex_points:
for simplex_point_b in simplex_points:
# Same point
if simplex_point_a == simplex_point_b:
continue
# Search all nodes to check if they are c such that a and b share c as a neighbour
for node in self.random_graph.nodes:
# They share a neighbour
if self.random_graph.has_edge(simplex_point_a.item(), node) and self.random_graph.has_edge(simplex_point_b.item(), node):
return True
return False

def test_lift_topology_random_graph(self):
""" Verifies that the lifting procedure works on
a random graph, that is, checks that the simplices
generated share a neighbour.
"""
lifted_data = self.lifting_high.forward(self.random_data)
# For each set of simplices
r = max(int(key.split("_")[-1]) for key in list(lifted_data.keys()) if "x_idx_" in key)
idx_str = f"x_idx_{r}"

# Go over each (max_dim)-simplex
for simplex_points in lifted_data[idx_str]:
share_neighbour = self.has_neighbour(simplex_points)
assert share_neighbour, f"The simplex {simplex_points} does not have a common neighbour with all the nodes."

def test_lift_topology_star_graph(self):
""" Verifies that the lifting procedure works on
a small star graph, that is, checks that the simplices
generated share a neighbour.
"""
lifted_data = self.lifting_high.forward(self.star_data)
# For each set of simplices
r = max(int(key.split("_")[-1]) for key in list(lifted_data.keys()) if "x_idx_" in key)
idx_str = f"x_idx_{r}"

# Go over each (max_dim)-simplex
for simplex_points in lifted_data[idx_str]:
share_neighbour = self.has_neighbour(simplex_points)
assert share_neighbour, f"The simplex {simplex_points} does not have a common neighbour with all the nodes."



def test_lift_topology_empty_graph(self):
""" Test the lift_topology method with an empty graph.
"""

lifted_data_signed = self.lifting_signed.forward(self.empty_data)

assert lifted_data_signed.incidence_1.shape[1] == 0, "Something is wrong with signed incidence_1 (nodes to edges)."

assert lifted_data_signed.incidence_2.shape[1] == 0, "Something is wrong with signed incidence_2 (edges to triangles)."

def test_lift_topology(self):
"""Test the lift_topology method."""

# Test the lift_topology method
lifted_data_signed = self.lifting_signed.forward(self.data.clone())
lifted_data_unsigned = self.lifting_unsigned.forward(self.data.clone())

expected_incidence_1 = torch.tensor(
[
[-1., -1., -1., -1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[ 1., 0., 0., 0., -1., -1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[ 0., 1., 0., 0., 1., 0., -1., -1., -1., -1., -1., 0., 0., 0., 0.],
[ 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., -1., 0., 0., 0.],
[ 0., 0., 1., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
[ 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., -1., -1., 0.],
[ 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 1., 0., -1.],
[ 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 1.]
]
)
assert (
abs(expected_incidence_1) == lifted_data_unsigned.incidence_1.to_dense()
).all(), "Something is wrong with unsigned incidence_1 (nodes to edges)."
assert (
expected_incidence_1 == lifted_data_signed.incidence_1.to_dense()
).all(), "Something is wrong with signed incidence_1 (nodes to edges)."

expected_incidence_2 = torch.tensor(
[
[ 0.],
[ 0.],
[ 0.],
[ 0.],
[ 0.],
[ 0.],
[ 0.],
[ 0.],
[ 0.],
[ 1.],
[-1.],
[ 0.],
[ 0.],
[ 0.],
[ 1.]
]
)

assert (
abs(expected_incidence_2) == lifted_data_unsigned.incidence_2.to_dense()
).all(), "Something is wrong with unsigned incidence_2 (edges to triangles)."
assert (
expected_incidence_2 == lifted_data_signed.incidence_2.to_dense()
).all(), "Something is wrong with signed incidence_2 (edges to triangles)."
Loading
Loading