From e1e8c8c600db58d1812a450322e61712dae59fd4 Mon Sep 17 00:00:00 2001 From: pojurer <56473157+vapavlo@users.noreply.github.com> Date: Sat, 21 Dec 2024 16:05:31 +0000 Subject: [PATCH] Add readouts tests --- test/nn/readouts/__init__.py | 0 test/nn/readouts/test_identical.py | 107 ++++++++++++++ .../nn/readouts/test_propagate_signal_down.py | 131 ++++++++++++++++++ 3 files changed, 238 insertions(+) create mode 100644 test/nn/readouts/__init__.py create mode 100644 test/nn/readouts/test_identical.py create mode 100644 test/nn/readouts/test_propagate_signal_down.py diff --git a/test/nn/readouts/__init__.py b/test/nn/readouts/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/nn/readouts/test_identical.py b/test/nn/readouts/test_identical.py new file mode 100644 index 00000000..51fc53e5 --- /dev/null +++ b/test/nn/readouts/test_identical.py @@ -0,0 +1,107 @@ +import pytest +import torch +import torch_geometric.data as tg_data +from topobenchmark.nn.readouts.base import AbstractZeroCellReadOut +from topobenchmark.nn.readouts.identical import NoReadOut + + +class TestNoReadOut: + @pytest.fixture + def base_kwargs(self): + """Fixture providing the required base parameters.""" + return { + 'hidden_dim': 64, + 'out_channels': 32, + 'task_level': 'graph' + } + + @pytest.fixture + def readout_layer(self, base_kwargs): + """Fixture to create a NoReadOut instance for testing.""" + return NoReadOut(**base_kwargs) + + @pytest.fixture + def sample_model_output(self): + """Fixture to create a sample model output dictionary.""" + return { + 'x_0': torch.randn(10, 64), # Required key for model output + 'edge_indices': torch.randint(0, 10, (2, 15)), + 'other_data': torch.randn(10, 32) + } + + @pytest.fixture + def sample_batch(self): + """Fixture to create a sample batch of graph data.""" + return tg_data.Data( + x=torch.randn(10, 32), + edge_index=torch.randint(0, 10, (2, 15)), + batch_0=torch.zeros(10, dtype=torch.long) # Required key for batch data + ) + + def test_initialization(self, base_kwargs): + """Test that NoReadOut initializes correctly with required parameters.""" + readout = NoReadOut(**base_kwargs) + assert isinstance(readout, NoReadOut) + assert isinstance(readout, AbstractZeroCellReadOut) + + def test_forward_pass_returns_unchanged_output(self, readout_layer, sample_model_output, sample_batch): + """Test that forward pass returns the model output without modifications.""" + original_output = sample_model_output.copy() + output = readout_layer(sample_model_output, sample_batch) + + # The output should contain the original data plus the computed logits + for key in original_output: + assert key in output + assert torch.equal(output[key], original_output[key]) + assert 'logits' in output + + def test_invalid_task_level(self, base_kwargs): + """Test that initialization fails with invalid task_level.""" + invalid_kwargs = base_kwargs.copy() + invalid_kwargs['task_level'] = 'invalid_level' + with pytest.raises(AssertionError, match="Invalid task_level"): + NoReadOut(**invalid_kwargs) + + def test_repr(self, readout_layer): + """Test the string representation of the NoReadOut layer.""" + assert str(readout_layer) == "NoReadOut()" + assert repr(readout_layer) == "NoReadOut()" + + def test_forward_pass_with_different_batch_sizes(self, readout_layer): + """Test that forward pass works with different batch sizes.""" + # Test with single graph + single_batch = tg_data.Data( + x=torch.randn(5, 32), + edge_index=torch.randint(0, 5, (2, 8)), + batch_0=torch.zeros(5, dtype=torch.long) + ) + single_output = { + 'x_0': torch.randn(5, 64), + 'embeddings': torch.randn(5, 64) + } + result = readout_layer(single_output, single_batch) + assert 'logits' in result + + # Test with multiple graphs + multi_batch = tg_data.Data( + x=torch.randn(15, 32), + edge_index=torch.randint(0, 15, (2, 25)), + batch_0=torch.cat([torch.zeros(5), torch.ones(5), torch.ones(5) * 2]).long() + ) + multi_output = { + 'x_0': torch.randn(15, 64), + 'embeddings': torch.randn(15, 64) + } + result = readout_layer(multi_output, multi_batch) + assert 'logits' in result + + def test_kwargs_handling(self, base_kwargs): + """Test that the layer correctly handles both required and additional keyword arguments.""" + additional_kwargs = { + 'random_param': 42, + 'another_param': 'test', + 'pooling_type': 'mean' # Valid additional parameter + } + kwargs = {**base_kwargs, **additional_kwargs} + readout = NoReadOut(**kwargs) + assert isinstance(readout, NoReadOut) \ No newline at end of file diff --git a/test/nn/readouts/test_propagate_signal_down.py b/test/nn/readouts/test_propagate_signal_down.py new file mode 100644 index 00000000..1f21de58 --- /dev/null +++ b/test/nn/readouts/test_propagate_signal_down.py @@ -0,0 +1,131 @@ +import pytest +import torch +import torch_geometric.data as tg_data +import topomodelx +from topobenchmark.nn.readouts.propagate_signal_down import PropagateSignalDown + + +class TestPropagateSignalDown: + @pytest.fixture + def base_kwargs(self): + """Fixture providing the required base parameters.""" + return { + 'hidden_dim': 64, + 'out_channels': 32, + 'task_level': 'graph', + 'num_cell_dimensions': 2, # Need at least 2 dimensions for signal propagation + 'readout_name': 'test_readout' + } + + @pytest.fixture + def readout_layer(self, base_kwargs): + """Fixture to create a PropagateSignalDown instance for testing.""" + layer = PropagateSignalDown(**base_kwargs) + layer.hidden_dim = base_kwargs['hidden_dim'] + return layer + + @pytest.fixture + def create_sparse_incidence_matrix(self): + """Helper fixture to create sparse incidence matrices.""" + def _create_matrix(num_source, num_target, sparsity=0.3): + num_entries = int(num_source * num_target * sparsity) + indices = torch.zeros((2, num_entries), dtype=torch.long) + values = torch.ones(num_entries) + + for i in range(num_entries): + source = torch.randint(0, num_source, (1,)) + target = torch.randint(0, num_target, (1,)) + indices[0, i] = source + indices[1, i] = target + values[i] = torch.randint(0, 2, (1,)) * 2 - 1 # {-1, 1} values + + sparse_matrix = torch.sparse_coo_tensor( + indices=torch.stack([indices[1], indices[0]]), + values=values, + size=(num_target, num_source) + ).coalesce() + + return sparse_matrix + return _create_matrix + + @pytest.fixture + def sample_batch(self, create_sparse_incidence_matrix): + """Fixture to create a sample batch with required incidence matrices.""" + num_nodes = 10 + num_edges = 15 + + return tg_data.Data( + x=torch.randn(num_nodes, 64), + edge_index=torch.randint(0, num_nodes, (2, num_edges)), + batch_0=torch.zeros(num_nodes, dtype=torch.long), + incidence_1=create_sparse_incidence_matrix(num_edges, num_nodes) + ) + + @pytest.fixture + def sample_model_output(self, sample_batch): + """Fixture to create a sample model output with cell embeddings.""" + hidden_dim = 64 + + num_nodes = sample_batch.x.size(0) + num_edges = sample_batch.edge_index.size(1) + + return { + 'logits': torch.randn(num_nodes, hidden_dim), + 'x_0': torch.randn(num_nodes, hidden_dim), + 'x_1': torch.randn(num_edges, hidden_dim), + } + + def test_forward_propagation(self, readout_layer, sample_model_output, sample_batch): + """Test the forward pass with detailed assertions.""" + initial_output = {k: v.clone() for k, v in sample_model_output.items()} + sample_model_output['x_0'] = sample_model_output['logits'] + + output = readout_layer(sample_model_output, sample_batch) + + assert 'x_0' in output + assert output['x_0'].shape == initial_output['logits'].shape + assert output['x_0'].dtype == torch.float32 + + assert 'x_1' in output + assert output['x_1'].shape == initial_output['x_1'].shape + assert output['x_1'].dtype == torch.float32 + + @pytest.mark.parametrize('missing_key', ['incidence_1']) + def test_missing_incidence_matrix(self, readout_layer, sample_model_output, sample_batch, missing_key): + """Test handling of missing incidence matrices.""" + invalid_batch = tg_data.Data(**{k: v for k, v in sample_batch.items() if k != missing_key}) + sample_model_output['x_0'] = sample_model_output['logits'] + + with pytest.raises(KeyError): + readout_layer(sample_model_output, invalid_batch) + + @pytest.mark.parametrize('missing_key', ['x_1']) # Changed to only test x_1 + def test_missing_cell_features(self, readout_layer, sample_model_output, sample_batch, missing_key): + """Test handling of missing cell features.""" + invalid_output = {k: v for k, v in sample_model_output.items() if k != missing_key} + invalid_output['x_0'] = invalid_output['logits'] # Always map logits to x_0 + + with pytest.raises(KeyError): + readout_layer(invalid_output, sample_batch) + + def test_gradient_flow(self, readout_layer, sample_model_output, sample_batch): + """Test gradient flow through the network.""" + # Create a copy of logits tensor to track gradients properly + logits = sample_model_output['logits'].clone().detach().requires_grad_(True) + x_1 = sample_model_output['x_1'].clone().detach().requires_grad_(True) + + model_output = { + 'logits': logits, + 'x_0': logits, # Share the same tensor + 'x_1': x_1 + } + + output = readout_layer(model_output, sample_batch) + loss = output['x_0'].sum() + loss.backward() + + # Check gradient flow + assert logits.grad is not None + assert not torch.allclose(logits.grad, torch.zeros_like(logits.grad)) + assert x_1.grad is not None + assert not torch.allclose(x_1.grad, torch.zeros_like(x_1.grad)) \ No newline at end of file