Skip to content

Commit

Permalink
Add test_data_manipulations.py
Browse files Browse the repository at this point in the history
  • Loading branch information
vapavlo committed May 26, 2024
1 parent 539a76a commit c8f6373
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 6 deletions.
1 change: 0 additions & 1 deletion test/data/test_Dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ def check_shape(batch, elems, key):
rows += elems[i][f'x_{n}'].shape[0]
assert batch[key].shape[0] == rows
elif key in elems[0].keys():
#assert 0
for i in range(len(batch[key].shape)):
i_elems = 0
for j in range(len(elems)):
Expand Down
95 changes: 95 additions & 0 deletions test/transforms/data_manipulations/test_data_manipulations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""Test the collate function."""
import hydra
from hydra import compose, initialize
from omegaconf import OmegaConf

import torch
import torch_geometric


from topobenchmarkx.transforms.data_manipulations import (
InfereKNNConnectivity,
InfereRadiusConnectivity,
KeepSelectedDataFields
)

from topobenchmarkx.utils.config_resolvers import (
get_default_transform,
get_monitor_metric,
get_monitor_mode,
infer_in_channels,
)

import rootutils

rootutils.setup_root("./", indicator=".project-root", pythonpath=True)

class TestCollateFunction:
"""Test collate_fn."""

def setup_method(self):
"""Setup the test.
For this test we load the MUTAG dataset.
Parameters
----------
None
"""
"""
OmegaConf.register_new_resolver("get_default_transform", get_default_transform)
OmegaConf.register_new_resolver("get_monitor_metric", get_monitor_metric)
OmegaConf.register_new_resolver("get_monitor_mode", get_monitor_mode)
OmegaConf.register_new_resolver("infer_in_channels", infer_in_channels)
OmegaConf.register_new_resolver(
"parameter_multiplication", lambda x, y: int(int(x) * int(y))
)
initialize(version_base="1.3", config_path="../../configs", job_name="job")
cfg = compose(config_name="train.yaml", overrides=["dataset=PROTEINS_TU"])
graph_loader = hydra.utils.instantiate(cfg.dataset, _recursive_=False)
datasets = graph_loader.load()
self.batch_size = 2
datamodule = DefaultDataModule(
dataset_train=datasets[0],
dataset_val=datasets[1],
dataset_test=datasets[2],
batch_size=self.batch_size
)
self.val_dataloader = datamodule.val_dataloader()
self.val_dataset = datasets[1]
"""
x = torch.tensor([
[2, 2], [2.2, 2], [2.1, 1.5],
[-3, 2], [-2.7, 2], [-2.5, 1.5],
[-3, -2], [-2.7, -2], [-2.5, -1.5],
])
self.data = torch_geometric.data.Data(
x=x,
num_nodes=len(x),
field_1 = "some text",
field_2 = x.clone(),
preserve_1 = 123,
preserve_2 = torch.tensor((1, 2, 3))
)
# Data Manipulations
self.infere_by_knn = InfereKNNConnectivity(args={"k":3})
self.infere_by_radius = InfereRadiusConnectivity(args={"r":1.})
self.keep_selected_fields = KeepSelectedDataFields(base_fields=["x", "num_nodes"], preserved_fields=["preserve_1", "preserve_2"])


def test_infere_connectivity(self):
data = self.infere_by_knn(self.data.clone())
assert "edge_index" in data, "No edges in Data object"


def test_radius_connectivity(self):
data = self.infere_by_radius(self.data.clone())
assert "edge_index" in data, "No edges in Data object"

#def test_keep_selected_data_fields(self):
# orig_data = self.data.clone()
# data = self.keep_selected_fields(orig_data)
# assert 0
5 changes: 0 additions & 5 deletions test/transforms/feature_liftings/test_ConcatenationLifting.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,3 @@ def test_lift_features(self):
expected_x3 == lifted_data.x_3
).all(), "Something is wrong with the lifted features x_3."


if __name__ == "__main__":
t = TestConcatentionLifting()
t.setup_method()
t.test_lift_features()

0 comments on commit c8f6373

Please sign in to comment.