Skip to content

Commit

Permalink
Merge pull request #252 from pyt-team/simplicial_checks
Browse files Browse the repository at this point in the history
Simplicial checks
  • Loading branch information
ninamiolane authored Nov 15, 2023
2 parents c75ddd1 + d25a35b commit 9586e2d
Show file tree
Hide file tree
Showing 10 changed files with 1,119 additions and 624 deletions.
8 changes: 4 additions & 4 deletions test/nn/simplicial/test_sca_cmps.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,8 @@ def test_forward(self):
channels_list = [x_0.shape[-1], x_1.shape[-1], x_2.shape[-1]]
complex_dim = 3
model = SCACMPS(
channels_list=channels_list,
in_channels_all=channels_list,
complex_dim=complex_dim,
n_classes=1,
n_layers=3,
att=False,
)
Expand All @@ -57,15 +56,16 @@ def test_forward(self):
incidence_t_list = [incidence_1t, incidence_2t]
forward_pass = model(x_list, down_lap_list, incidence_t_list)
assert torch.any(
torch.isclose(forward_pass, torch.tensor([-4.8042]), rtol=1e-02)
torch.isclose(
forward_pass[0][0], torch.tensor([1.9269, 1.4873]), rtol=1e-02
)
)

def test_reset_parameters(self):
"""Test the reset_parameters method of SCA."""
model = SCACMPS(
[2, 2, 2],
2,
n_classes=1,
n_layers=3,
att=False,
)
Expand Down
90 changes: 90 additions & 0 deletions test/nn/simplicial/test_scconv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
"""Unit tests for SCCNN Model."""
import itertools
import random

import torch
from toponetx.classes import SimplicialComplex

from topomodelx.nn.simplicial.scconv import SCConv
from topomodelx.utils.sparse import from_sparse


class TestSCConv:
"""Unit tests for the SCConv model class."""

def test_forward(self):
"""Test the forward method of SCConv."""
faces = 14
node_creation = 17
nodes_per_face = 3
seed_value = 42
random.seed(seed_value)
torch.manual_seed(seed_value)
# Create a random cell complex of cells with length 3
all_combinations = list(
itertools.combinations(
[x for x in range(1, node_creation + 1)], nodes_per_face
)
)
random.shuffle(all_combinations)
selected_combinations = all_combinations[:faces]
simplicial_complex = SimplicialComplex()
for simplex in selected_combinations:
simplicial_complex.add_simplex(simplex)
# Some nodes might not be selected at all in the combinations above
x_0 = torch.randn(simplicial_complex.shape[0], 2)
x_1 = torch.randn(simplicial_complex.shape[1], 2)
x_2 = torch.randn(faces, 2)

incidence_1_norm = from_sparse(simplicial_complex.incidence_matrix(1))
incidence_1 = from_sparse(simplicial_complex.coincidence_matrix(1))
incidence_2_norm = from_sparse(simplicial_complex.incidence_matrix(2))
incidence_2 = from_sparse(simplicial_complex.coincidence_matrix(2))
adjacency_up_0_norm = from_sparse(simplicial_complex.up_laplacian_matrix(0))
adjacency_up_1_norm = from_sparse(simplicial_complex.up_laplacian_matrix(1))
adjacency_down_1_norm = from_sparse(simplicial_complex.down_laplacian_matrix(1))
adjacency_down_2_norm = from_sparse(simplicial_complex.down_laplacian_matrix(2))

in_channels = x_0.shape[1]
n_layers = 2
model = SCConv(
node_channels=in_channels,
n_layers=n_layers,
)

with torch.no_grad():
forward_pass = model(
x_0,
x_1,
x_2,
incidence_1,
incidence_1_norm,
incidence_2,
incidence_2_norm,
adjacency_up_0_norm,
adjacency_up_1_norm,
adjacency_down_1_norm,
adjacency_down_2_norm,
)
assert torch.any(
torch.isclose(
forward_pass[0][0],
torch.tensor(
[
0.8847,
0.9963,
]
),
rtol=1e-02,
)
)

def test_reset_parameters(self):
"""Test the reset_parameters method of SCConv."""
model = SCConv(4, 2)
for layer in model.children():
if hasattr(layer, "reset_parameters"):
layer.reset_parameters()
for module in model.modules():
if hasattr(module, "reset_parameters"):
module.reset_parameters()
8 changes: 4 additions & 4 deletions test/nn/simplicial/test_scconv_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ def test_forward(self):
x_0 = torch.randn(n_nodes, node_channels)
x_1 = torch.randn(n_edges, edge_channels)
x_2 = torch.randn(n_faces, face_channels)
incidence_1 = torch.randint(0, 2, (n_nodes, n_edges)).float()
incidence_1 = torch.randint(0, 2, (n_nodes, n_edges)).float().T
# incidence_1_norm = torch.randint(0, 2, (n_nodes, n_edges)).float()
incidence_1_norm = torch.randint(0, 2, (n_edges, n_nodes)).float()
incidence_1_norm = torch.randint(0, 2, (n_edges, n_nodes)).float().T

incidence_2 = torch.randint(0, 2, (n_edges, n_faces)).float()
incidence_2 = torch.randint(0, 2, (n_edges, n_faces)).float().T
# incidence_2_norm = torch.randint(0, 2, (n_edges, n_faces)).float()
incidence_2_norm = torch.randint(0, 2, (n_faces, n_edges)).float()
incidence_2_norm = torch.randint(0, 2, (n_faces, n_edges)).float().T

adjacency_up_0_norm = torch.randint(0, 2, (n_nodes, n_nodes)).float()
adjacency_up_1_norm = torch.randint(0, 2, (n_edges, n_edges)).float()
Expand Down
45 changes: 9 additions & 36 deletions topomodelx/nn/simplicial/sca_cmps.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""SCA with CMPS."""
import torch

from topomodelx.base.aggregation import Aggregation
from topomodelx.nn.simplicial.sca_cmps_layer import SCACMPSLayer


Expand All @@ -10,12 +9,10 @@ class SCACMPS(torch.nn.Module):
Parameters
----------
channels_list : list[int]
in_channels_all : list[int]
Dimension of features on each node, edge, simplex, tetahedron,... respectively
complex_dim : int
Highest dimension of simplicial complex feature being trained on.
n_classes : int
Dimension to which the complex embeddings will be projected.
n_layers : int, default = 2
Amount of message passing layers.
att : bool
Expand All @@ -24,37 +21,28 @@ class SCACMPS(torch.nn.Module):

def __init__(
self,
channels_list,
in_channels_all,
complex_dim,
n_classes,
n_layers=2,
att=False,
):
super().__init__()
self.n_layers = n_layers
self.channels_list = channels_list
self.n_classes = n_classes
self.in_channels_all = in_channels_all

layers = []
for _ in range(n_layers):
layers.append(SCACMPSLayer(channels_list, complex_dim, att))
layers.append(SCACMPSLayer(in_channels_all, complex_dim, att))

self.layers = torch.nn.ModuleList(layers)
self.lin0 = torch.nn.Linear(channels_list[0], n_classes)
self.lin1 = torch.nn.Linear(channels_list[1], n_classes)
self.lin2 = torch.nn.Linear(channels_list[2], n_classes)
self.aggr = Aggregation(
aggr_func="mean",
update_func="sigmoid",
)

def forward(self, x_list, laplacian_down_list, incidence_t_list):
def forward(self, x, laplacian_down_list, incidence_t_list):
"""Forward computation through layers, then linear layers, then avg pooling.
Parameters
----------
x_list : list[torch.Tensor]
List of tensor inputs for each dimension of the complex (nodes, edges, etc.).
x : list[torch.Tensor]
Tensor inputs for each dimension of the complex (nodes, edges, etc.).
laplacian_down_list : list[torch.Tensor]
List of the down laplacian matrix for each dimension in the complex starting at edges.
incidence_t_list : list[torch.Tensor]
Expand All @@ -66,21 +54,6 @@ def forward(self, x_list, laplacian_down_list, incidence_t_list):
Label assigned to whole complex.
"""
for i in range(self.n_layers):
x_list = self.layers[i](x_list, laplacian_down_list, incidence_t_list)
x = self.layers[i](x, laplacian_down_list, incidence_t_list)

x_0 = self.lin0(x_list[0])
x_1 = self.lin1(x_list[1])
x_2 = self.lin2(x_list[2])

two_dimensional_cells_mean = torch.nanmean(x_2, dim=0)
two_dimensional_cells_mean[torch.isnan(two_dimensional_cells_mean)] = 0
one_dimensional_cells_mean = torch.nanmean(x_1, dim=0)
one_dimensional_cells_mean[torch.isnan(one_dimensional_cells_mean)] = 0
zero_dimensional_cells_mean = torch.nanmean(x_0, dim=0)
zero_dimensional_cells_mean[torch.isnan(zero_dimensional_cells_mean)] = 0

x_2f = torch.flatten(two_dimensional_cells_mean)
x_1f = torch.flatten(one_dimensional_cells_mean)
x_0f = torch.flatten(zero_dimensional_cells_mean)

return x_0f + x_1f + x_2f
return x
41 changes: 8 additions & 33 deletions topomodelx/nn/simplicial/scconv.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Simplicial 2-Complex Convolutional Network Implementation for binary node classification."""
import torch

from topomodelx.base.aggregation import Aggregation
from topomodelx.nn.simplicial.scconv_layer import SCConvLayer


Expand All @@ -26,32 +25,25 @@ class SCConv(torch.nn.Module):
"""

def __init__(
self, node_channels, edge_channels, face_channels, n_classes, n_layers=2
self, node_channels, edge_channels=None, face_channels=None, n_layers=2
):
super().__init__()
self.node_channels = node_channels
self.edge_channels = edge_channels
self.face_channels = face_channels
self.n_classes = n_classes
self.edge_channels = node_channels if edge_channels is None else edge_channels
self.face_channels = node_channels if face_channels is None else face_channels
self.n_layers = n_layers

layers = []
for _ in range(n_layers):
layers.append(
SCConvLayer(
node_channels=node_channels,
edge_channels=edge_channels,
face_channels=face_channels,
node_channels=self.node_channels,
edge_channels=self.edge_channels,
face_channels=self.face_channels,
)
)

self.layers = torch.nn.ModuleList(layers)
self.linear_x0 = torch.nn.Linear(node_channels, self.n_classes)
self.linear_x1 = torch.nn.Linear(edge_channels, self.n_classes)
self.linear_x2 = torch.nn.Linear(face_channels, self.n_classes)
self.aggr = Aggregation(
aggr_func="mean",
update_func="sigmoid",
)

def forward(
self,
Expand Down Expand Up @@ -115,21 +107,4 @@ def forward(
adjacency_down_2_norm,
)

x_0 = self.linear_x0(x_0)
x_1 = self.linear_x1(x_1)
x_2 = self.linear_x2(x_2)

node_mean = torch.nanmean(x_0, dim=0)
node_mean[torch.isnan(node_mean)] = 0

edge_mean = torch.nanmean(x_1, dim=0)
edge_mean[torch.isnan(edge_mean)] = 0

face_mean = torch.nanmean(x_2, dim=0)
face_mean[torch.isnan(face_mean)] = 0

return (
torch.flatten(node_mean)
+ torch.flatten(edge_mean)
+ torch.flatten(face_mean)
)
return x_0, x_1, x_2
8 changes: 4 additions & 4 deletions topomodelx/nn/simplicial/scconv_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,16 +161,16 @@ def forward(
"""
x0_level_0_0 = self.conv_0_to_0(x_0, adjacency_up_0_norm)

x0_level_1_0 = self.conv_1_to_0(x_1, incidence_1)
x0_level_1_0 = self.conv_1_to_0(x_1, incidence_1.T)

x0_level_0_1 = self.conv_0_to_1(x_0, incidence_1_norm)
x0_level_0_1 = self.conv_0_to_1(x_0, incidence_1_norm.T)

adj_norm = adjacency_down_1_norm.add(adjacency_up_1_norm)
x1_level_1_1 = self.conv_1_to_1(x_1, adj_norm)

x2_level_2_1 = self.conv_2_to_1(x_2, incidence_2)
x2_level_2_1 = self.conv_2_to_1(x_2, incidence_2.T)

x1_level_1_2 = self.conv_1_to_2(x_1, incidence_2_norm)
x1_level_1_2 = self.conv_1_to_2(x_1, incidence_2_norm.T)

x_2_level_2_2 = self.conv_2_to_2(x_2, adjacency_down_2_norm)

Expand Down
3 changes: 1 addition & 2 deletions tutorials/simplicial/san_train.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@
"x_2 = []\n",
"for k, v in dataset.get_simplex_attributes(\"face_feat\").items():\n",
" x_2.append(v)\n",
"x_2 = np.stack(x_2)\n",
"x_2 = torch.tensor(np.stack(x_2))\n",
"print(f\"There are {x_2.shape[0]} faces with features of dimension {x_2.shape[1]}.\")"
]
},
Expand Down Expand Up @@ -775,7 +775,6 @@
" y_pred_test = torch.where(\n",
" y_hat_test > 0.5, torch.tensor(1), torch.tensor(0)\n",
" )\n",
" # _pred_test = torch.softmax(y_hat_test,dim=1).ge(0.5).float()\n",
" test_accuracy = (\n",
" torch.eq(y_pred_test[-len(y_test) :], y_test)\n",
" .all(dim=1)\n",
Expand Down
Loading

0 comments on commit 9586e2d

Please sign in to comment.