From d3fe8be6637360054c1968a0457c57e0adfbedd0 Mon Sep 17 00:00:00 2001
From: levtelyatnikov
Date: Mon, 25 Nov 2024 22:50:23 +0100
Subject: [PATCH] Add tests

---
 test/nn/backbones/graph/test_graph_dgm.py |  37 +++++
 test/nn/encoders/test_dgm.py              | 172 ++++++++++++++++++++++
 topobenchmarkx/nn/encoders/alpha_dgm.py   | 118 ---------------
 topobenchmarkx/nn/encoders/dgm_encoder.py |   4 +-
 4 files changed, 211 insertions(+), 120 deletions(-)
 create mode 100644 test/nn/backbones/graph/test_graph_dgm.py
 create mode 100644 test/nn/encoders/test_dgm.py
 delete mode 100644 topobenchmarkx/nn/encoders/alpha_dgm.py

diff --git a/test/nn/backbones/graph/test_graph_dgm.py b/test/nn/backbones/graph/test_graph_dgm.py
new file mode 100644
index 00000000..5810414d
--- /dev/null
+++ b/test/nn/backbones/graph/test_graph_dgm.py
@@ -0,0 +1,37 @@
+"""Unit tests for GraphMLP."""
+
+import torch
+import torch_geometric
+from topobenchmarkx.nn.backbones.graph import GraphMLP
+from topobenchmarkx.nn.wrappers.graph import GraphMLPWrapper
+from topobenchmarkx.loss.model import GraphMLPLoss
+
+def testGraphMLP(random_graph_input):
+    """Unit test for GraphMLP.
+
+    Parameters
+    ----------
+    random_graph_input : Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]
+        A tuple of input tensors for testing GraphMLP.
+    """
+    x, x_1, x_2, edges_1, edges_2 = random_graph_input
+    batch = torch_geometric.data.Data(x_0=x, y=x, edge_index=edges_1, batch_0=torch.zeros(x.shape[0], dtype=torch.long))
+    model = GraphMLP(x.shape[1], x.shape[1])
+    wrapper = GraphMLPWrapper(model, out_channels=x.shape[1], num_cell_dimensions=1)
+    loss_fn = GraphMLPLoss()
+
+    _ = wrapper.__repr__()
+    _ = loss_fn.__repr__()
+
+    model_out = wrapper(batch)
+    assert model_out["x_0"].shape == x.shape
+    assert list(model_out["x_dis"].shape) == [8, 8]
+
+    loss = loss_fn(model_out, batch)
+    assert loss.item() >= 0
+
+    model_out["x_dis"] = None
+    loss = loss_fn(model_out, batch)
+    assert loss == torch.tensor(0.0)
+
+
diff --git a/test/nn/encoders/test_dgm.py b/test/nn/encoders/test_dgm.py
new file mode 100644
index 00000000..4e03d046
--- /dev/null
+++ b/test/nn/encoders/test_dgm.py
@@ -0,0 +1,172 @@
+"""Unit tests for the DGMStructureFeatureEncoder module."""
+
+import pytest
+import torch
+import torch_geometric
+import numpy as np
+
+from topobenchmarkx.nn.encoders import DGMStructureFeatureEncoder
+from topobenchmarkx.nn.encoders.kdgm import DGM_d
+
+class TestDGMStructureFeatureEncoder:
+    """Test suite for the DGMStructureFeatureEncoder class.
+
+    This test class covers various aspects of the DGMStructureFeatureEncoder,
+    including initialization, forward pass, selective encoding, and
+    configuration settings.
+    """
+
+    @pytest.fixture
+    def sample_data(self):
+        """Create a sample PyG Data object for testing.
+
+        Returns
+        -------
+        torch_geometric.data.Data
+            A data object with simulated multi-dimensional features and batch information.
+        """
+        data = torch_geometric.data.Data()
+
+        # Simulate multi-dimensional features
+        data.x_0 = torch.randn(10, 5)  # 10 nodes, 5 features
+        data.x_1 = torch.randn(10, 7)  # 10 nodes, 7 features
+        data.x_2 = torch.randn(10, 9)  # 10 nodes, 9 features
+
+        # Add batch information
+        data.batch_0 = torch.zeros(10, dtype=torch.long)
+        data.batch_1 = torch.zeros(10, dtype=torch.long)
+        data.batch_2 = torch.zeros(10, dtype=torch.long)
+
+        return data
+
+    def test_initialization(self, sample_data):
+        """Test encoder initialization with different configurations.
+
+        Parameters
+        ----------
+        sample_data : torch_geometric.data.Data
+            Fixture providing sample graph data for testing.
+        """
+        # Test with default settings
+        encoder = DGMStructureFeatureEncoder(
+            in_channels=[5, 7, 9],
+            out_channels=64
+        )
+
+        # Test __repr__ method
+        _ = encoder.__repr__()
+
+        # Check basic attributes
+        assert encoder.in_channels == [5, 7, 9]
+        assert encoder.out_channels == 64
+        assert len(encoder.dimensions) == 3
+
+    def test_forward_pass(self, sample_data):
+        """Test forward pass of the encoder.
+
+        Parameters
+        ----------
+        sample_data : torch_geometric.data.Data
+            Fixture providing sample graph data for testing.
+        """
+        encoder = DGMStructureFeatureEncoder(
+            in_channels=[5, 7, 9],
+            out_channels=64,
+            selected_dimensions=[0, 1, 2]
+        )
+
+        # Perform forward pass
+        output_data = encoder(sample_data)
+
+        # Check output attributes
+        for i in [0, 1, 2]:
+            # Check encoded features exist
+            assert hasattr(output_data, f'x_{i}')
+            assert output_data[f'x_{i}'].shape[1] == 64
+
+            # Check auxiliary attributes
+            assert hasattr(output_data, f'x_aux_{i}')
+            assert hasattr(output_data, f'logprobs_{i}')
+
+        # Check the edge index exists
+        assert 'edges_index' in output_data
+
+    def test_selective_encoding(self, sample_data):
+        """Test encoding only specific dimensions.
+
+        Parameters
+        ----------
+        sample_data : torch_geometric.data.Data
+            Fixture providing sample graph data for testing.
+        """
+        encoder = DGMStructureFeatureEncoder(
+            in_channels=[5, 7, 9],
+            out_channels=64,
+            selected_dimensions=[0, 1]  # Only encode the first two dimensions
+        )
+
+        # Perform forward pass
+        output_data = encoder(sample_data)
+
+        # Verify encoding for selected dimensions
+        assert hasattr(output_data, 'x_1')
+        assert output_data['x_0'].shape[1] == 64
+        assert output_data['x_1'].shape[1] == 64
+        assert output_data['x_2'].shape[1] == 9
+
+    def test_dropout_configuration(self):
+        """Test dropout configuration for the encoder."""
+        # Test with non-zero dropout
+        encoder = DGMStructureFeatureEncoder(
+            in_channels=[5, 7, 9],
+            out_channels=64,
+            proj_dropout=0.5
+        )
+
+        # Check dropout value
+        for i in encoder.dimensions:
+            encoder_module = getattr(encoder, f'encoder_{i}')
+            assert encoder_module.base_enc.dropout.p == 0.5
+            assert encoder_module.embed_f.dropout.p == 0.5
+
+    @pytest.mark.parametrize("in_channels", [
+        [5],              # Single dimension
+        [5, 7, 9],        # Multiple dimensions
+        [10, 20, 30, 40]  # Four dimensions
+    ])
+    def test_variable_input_dimensions(self, sample_data, in_channels):
+        """Test encoder with varying input dimensions.
+
+        Parameters
+        ----------
+        sample_data : torch_geometric.data.Data
+            Fixture providing sample graph data for testing.
+        in_channels : list
+            List of input channel dimensions to test.
+        """
+        encoder = DGMStructureFeatureEncoder(
+            in_channels=in_channels,
+            out_channels=64
+        )
+
+        # Prepare data dynamically
+        data = torch_geometric.data.Data()
+        for i, channel in enumerate(in_channels):
+            setattr(data, f'x_{i}', torch.randn(10, channel))
+            setattr(data, f'batch_{i}', torch.zeros(10, dtype=torch.long))
+
+        # Perform forward pass
+        output_data = encoder(data)
+
+        # Verify encoding for each dimension
+        for i in range(len(in_channels)):
+            assert hasattr(output_data, f'x_{i}')
+            assert output_data[f'x_{i}'].shape[1] == 64
+
+def pytest_configure():
+    """Custom pytest configuration.
+
+    Sets up default configuration values for testing.
+    """
+    pytest.in_channels = [5, 7, 9]
+    pytest.out_channels = 64
\ No newline at end of file
diff --git a/topobenchmarkx/nn/encoders/alpha_dgm.py b/topobenchmarkx/nn/encoders/alpha_dgm.py
deleted file mode 100644
index 67d1544e..00000000
--- a/topobenchmarkx/nn/encoders/alpha_dgm.py
+++ /dev/null
@@ -1,118 +0,0 @@
-"""Encoder class to apply BaseEncoder."""
-
-import torch
-import torch.nn as nn
-from entmax import entmax15
-
-
-class AlphaDGM(nn.Module):
-    """DGM.
-
-    Parameters
-    ----------
-    base_enc : nn.Module
-        Base encoder.
-    embed_f : nn.Module
-        Embedding function.
-    gamma : float, optional
-        Gamma parameter for the LayerNorm.
-    std : float, optional
-        Standard deviation for the normal distribution.
-    """
-
-    def __init__(
-        self, base_enc: nn.Module, embed_f: nn.Module, gamma=10, std=0
-    ):
-        super().__init__()
-        self.ln = LayerNorm(gamma)
-        self.std = std
-        self.base_enc = base_enc
-        self.embed_f = embed_f
-
-    def forward(self, x: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
-        """Forward pass.
-
-        Parameters
-        ----------
-        x : torch.Tensor
-            Input tensor.
-
-        batch : torch.Tensor
-            Batch tensor.
-
-        Returns
-        -------
-        torch.Tensor
-            Output tensor.
-        """
-        # Input embedding
-        x_ = self.base_enc(x, batch)
-        x_aux = self.embed_f(x, batch)
-        edges_hat, logprobs = entmax(x=x_aux, ln=self.ln, std=self.std)
-
-        return x_, x_aux, edges_hat, logprobs
-
-
-class LayerNorm(nn.Module):
-    """LayerNorm with gamma and beta parameters.
-
-    Parameters
-    ----------
-    gamma : torch.tensor
-        Gamma parameter for the LayerNorm.
-    """
-
-    def __init__(self, gamma):
-        super().__init__()
-        self.gamma = nn.Parameter(gamma * torch.ones(1))
-        self.beta = nn.Parameter(torch.zeros(1))
-        self.eps = 1e-6
-
-    def forward(self, x):
-        """LayerNorm with gamma and beta parameters.
-
-        Parameters
-        ----------
-        x : torch.tensor
-            Input tensor.
-
-        Returns
-        -------
-        torch.tensor
-            Output tensor.
-        """
-        mean = x.mean(-1, keepdim=True)
-        std = x.std(-1, keepdim=True)
-        if x.size(-1) == 1:
-            std = 1
-        return self.gamma * (x - mean) / (std + self.eps) + self.beta
-
-
-def entmax(x: torch.tensor, ln, std=0):
-    """Entmax function.
-
-    Parameters
-    ----------
-    x : torch.tensor
-        Input tensor.
-
-    ln : torch.tensor
-        Layer normalization.
-
-    std : float, optional
-        Standard deviation for the normal distribution.
-
-    Returns
-    -------
-    torch.tensor
-        Output tensor.
-    """
-    probs = -torch.cdist(x, x)
-    probs = probs + torch.empty(probs.size(), device=probs.device).normal_(
-        mean=0, std=std
-    )
-    vprobs = entmax15(ln(probs).fill_diagonal_(-1e-6), dim=-1)
-    res = (((vprobs + vprobs.t()) / 2) > 0) * 1
-    edges = res.nonzero().t_()
-    logprobs = res.sum(dim=1)
-    return edges, logprobs
diff --git a/topobenchmarkx/nn/encoders/dgm_encoder.py b/topobenchmarkx/nn/encoders/dgm_encoder.py
index aad40378..85e7101c 100644
--- a/topobenchmarkx/nn/encoders/dgm_encoder.py
+++ b/topobenchmarkx/nn/encoders/dgm_encoder.py
@@ -5,7 +5,7 @@
 from topobenchmarkx.nn.encoders.all_cell_encoder import BaseEncoder
 from topobenchmarkx.nn.encoders.base import AbstractFeatureEncoder
 
-from .kdgm import DGM_d # . AlphaDGM
+from .kdgm import DGM_d
 
 
 class DGMStructureFeatureEncoder(AbstractFeatureEncoder):
@@ -42,6 +42,7 @@ def __init__(
 
         self.in_channels = in_channels
         self.out_channels = out_channels
+
         self.dimensions = (
             selected_dimensions
             if (
@@ -64,7 +65,6 @@ def __init__(
             setattr(
                 self,
                 f"encoder_{i}",
-                # AlphaDGM(base_enc=base_enc, embed_f=embed_f),
                 DGM_d(base_enc=base_enc, embed_f=embed_f),
             )
 
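The new tests double as a usage reference for DGMStructureFeatureEncoder. The
minimal sketch below mirrors test_variable_input_dimensions with
in_channels=[5]; the node count, feature size, and the x_0/batch_0 attribute
names are taken from the test fixtures above and are illustrative rather than
a documented API.

    import torch
    import torch_geometric

    from topobenchmarkx.nn.encoders import DGMStructureFeatureEncoder

    # Single-graph batch with one feature rank: 10 nodes with
    # 5-dimensional features, all assigned to graph 0.
    data = torch_geometric.data.Data()
    data.x_0 = torch.randn(10, 5)
    data.batch_0 = torch.zeros(10, dtype=torch.long)

    # Encode the rank-0 features into a 64-channel space.
    encoder = DGMStructureFeatureEncoder(in_channels=[5], out_channels=64)
    out = encoder(data)

    # As asserted in test_forward_pass, the encoder rewrites x_0 to
    # out_channels and attaches x_aux_0, logprobs_0, and edges_index.
    assert out["x_0"].shape[1] == 64
    assert "edges_index" in out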