Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vapavlo committed May 25, 2024
1 parent d57daff commit cfe150a
Show file tree
Hide file tree
Showing 6 changed files with 143 additions and 1,601 deletions.
2 changes: 1 addition & 1 deletion test/data/test_Dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def setup_method(self):
"parameter_multiplication", lambda x, y: int(int(x) * int(y))
)

initialize(version_base="1.3", config_path="../../configs", job_name="job")
initialize(version_base="1.3", config_path="../../configs") #, job_name="job")
cfg = compose(config_name="train.yaml")

graph_loader = hydra.utils.instantiate(cfg.dataset, _recursive_=False)
Expand Down
17 changes: 10 additions & 7 deletions test/transforms/feature_liftings/test_ConcatenationLifting.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
"""Test the message passing module."""
import sys
sys.path.insert(0, '/Users/psh_1/Documents/[ Research ]/TopoBenchmarkX')


import torch

Expand All @@ -19,17 +22,11 @@ def setup_method(self):
self.data = manual_simple_graph()

# Initialize a lifting class
self.lifting = SimplicialCliqueLifting(complex_dim=3)
# Initialize the ConcatentionLifting class
self.feature_lifting = ConcatentionLifting()
self.lifting = SimplicialCliqueLifting(feature_lifting="concatenation", complex_dim=3)

def test_lift_features(self):
# Test the lift_features method
lifted_data = self.lifting.forward(self.data.clone())
del lifted_data.x_1
del lifted_data.x_2
del lifted_data.x_3
lifted_data = self.feature_lifting.forward(lifted_data)

expected_x1 = torch.tensor(
[
Expand Down Expand Up @@ -100,3 +97,9 @@ def test_lift_features(self):
assert (
expected_x3 == lifted_data.x_3
).all(), "Something is wrong with the lifted features x_3."


if __name__ == "__main__":
t = TestConcatentionLifting()
t.setup_method()
t.test_lift_features()
9 changes: 2 additions & 7 deletions test/transforms/feature_liftings/test_SetLifting.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,12 @@ def setup_method(self):
self.data = manual_simple_graph()

# Initialize a lifting class
self.lifting = SimplicialCliqueLifting(complex_dim=3)
# Initialize the SetLifting class
self.feature_lifting = SetLifting()
self.lifting = SimplicialCliqueLifting(feature_lifting="set", complex_dim=3)


def test_lift_features(self):
# Test the lift_features method
lifted_data = self.lifting.forward(self.data.clone())
del lifted_data.x_1
del lifted_data.x_2
del lifted_data.x_3
lifted_data = self.feature_lifting.forward(lifted_data)

expected_x1 = torch.tensor(
[
Expand Down
128 changes: 8 additions & 120 deletions test/transforms/liftings/cell/test_CellCyclesLifting.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,126 +22,14 @@ def test_lift_topology(self):

expected_incidence_1 = torch.tensor(
[
[
1.0,
1.0,
1.0,
1.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
],
[
1.0,
0.0,
0.0,
0.0,
1.0,
1.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
],
[
0.0,
1.0,
0.0,
0.0,
1.0,
0.0,
1.0,
1.0,
1.0,
1.0,
0.0,
0.0,
0.0,
],
[
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
1.0,
0.0,
0.0,
0.0,
1.0,
0.0,
0.0,
],
[
0.0,
0.0,
1.0,
0.0,
0.0,
1.0,
0.0,
1.0,
0.0,
0.0,
0.0,
0.0,
0.0,
],
[
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
1.0,
0.0,
0.0,
1.0,
1.0,
],
[
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
1.0,
1.0,
0.0,
],
[
0.0,
0.0,
0.0,
1.0,
0.0,
0.0,
0.0,
0.0,
0.0,
1.0,
0.0,
0.0,
1.0,
],
[ 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ],
[ 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ],
[ 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, ],
[ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, ],
[ 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, ],
[ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, ],
[ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, ],
[ 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, ],
]
)

Expand Down
159 changes: 24 additions & 135 deletions test/transforms/liftings/simplicial/test_SimplicialCliqueLifting.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,131 +27,20 @@ def test_lift_topology(self):
"""Test the lift_topology method."""

# Test the lift_topology method
print(self.data)
lifted_data_signed = self.lifting_signed.forward(self.data.clone())
lifted_data_unsigned = self.lifting_unsigned.forward(self.data.clone())

expected_incidence_1 = torch.tensor(
[
[
-1.0,
-1.0,
-1.0,
-1.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
],
[
1.0,
0.0,
0.0,
0.0,
-1.0,
-1.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
],
[
0.0,
1.0,
0.0,
0.0,
1.0,
0.0,
-1.0,
-1.0,
-1.0,
-1.0,
0.0,
0.0,
0.0,
],
[
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
1.0,
0.0,
0.0,
0.0,
-1.0,
0.0,
0.0,
],
[
0.0,
0.0,
1.0,
0.0,
0.0,
1.0,
0.0,
1.0,
0.0,
0.0,
0.0,
0.0,
0.0,
],
[
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
1.0,
0.0,
0.0,
-1.0,
-1.0,
],
[
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
1.0,
1.0,
0.0,
],
[
0.0,
0.0,
0.0,
1.0,
0.0,
0.0,
0.0,
0.0,
0.0,
1.0,
0.0,
0.0,
1.0,
],
[ -1.0, -1.0, -1.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ],
[ 1.0, 0.0, 0.0, 0.0, -1.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ],
[ 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, -1.0, -1.0, -1.0, -1.0, 0.0, 0.0, 0.0, ],
[ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0, ],
[ 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, ],
[ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, -1.0, -1.0, ],
[ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, ],
[ 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, ],
]
)

Expand Down Expand Up @@ -213,19 +102,19 @@ def test_lifted_features_signed(self):

expected_features_1 = torch.tensor(
[
[4],
[9],
[99],
[4999],
[5],
[95],
[40],
[90],
[490],
[4990],
[950],
[500],
[4500],
[6.0],
[11.0],
[101.0],
[5001.0],
[15.0],
[105.0],
[60.0],
[110.0],
[510.0],
[5010.0],
[1050.0],
[1500.0],
[5500.0],
]
)

Expand All @@ -234,14 +123,14 @@ def test_lifted_features_signed(self):
).all(), "Something is wrong with x_1 features."

expected_features_2 = torch.tensor(
[[0.0], [0.0], [0.0], [0.0], [0.0], [0.0]]
[[32.0], [212.0], [222.0], [10022.0], [230.0], [11020.0]]
)

assert (
expected_features_2 == lifted_data.x_2
).all(), "Something is wrong with x_2 features."

excepted_features_3 = torch.tensor([[0.0]])
excepted_features_3 = torch.tensor([[696.0]])

assert (
excepted_features_3 == lifted_data.x_3
Expand Down
Loading

0 comments on commit cfe150a

Please sign in to comment.