Add readouts tests
vapavlo committed Dec 21, 2024 · 1 parent aba9d16 · commit e1e8c8c
Showing 3 changed files with 238 additions and 0 deletions.
Empty file added test/nn/readouts/__init__.py
107 changes: 107 additions & 0 deletions test/nn/readouts/test_identical.py
@@ -0,0 +1,107 @@
import pytest
import torch
import torch_geometric.data as tg_data
from topobenchmark.nn.readouts.base import AbstractZeroCellReadOut
from topobenchmark.nn.readouts.identical import NoReadOut


class TestNoReadOut:
@pytest.fixture
def base_kwargs(self):
"""Fixture providing the required base parameters."""
return {
'hidden_dim': 64,
'out_channels': 32,
'task_level': 'graph'
}

@pytest.fixture
def readout_layer(self, base_kwargs):
"""Fixture to create a NoReadOut instance for testing."""
return NoReadOut(**base_kwargs)

@pytest.fixture
def sample_model_output(self):
"""Fixture to create a sample model output dictionary."""
return {
'x_0': torch.randn(10, 64), # Required key for model output
'edge_indices': torch.randint(0, 10, (2, 15)),
'other_data': torch.randn(10, 32)
}

@pytest.fixture
def sample_batch(self):
"""Fixture to create a sample batch of graph data."""
return tg_data.Data(
x=torch.randn(10, 32),
edge_index=torch.randint(0, 10, (2, 15)),
batch_0=torch.zeros(10, dtype=torch.long) # Required key for batch data
)

def test_initialization(self, base_kwargs):
"""Test that NoReadOut initializes correctly with required parameters."""
readout = NoReadOut(**base_kwargs)
assert isinstance(readout, NoReadOut)
assert isinstance(readout, AbstractZeroCellReadOut)

def test_forward_pass_returns_unchanged_output(self, readout_layer, sample_model_output, sample_batch):
"""Test that forward pass returns the model output without modifications."""
original_output = sample_model_output.copy()
output = readout_layer(sample_model_output, sample_batch)

# The output should contain the original data plus the computed logits
for key in original_output:
assert key in output
assert torch.equal(output[key], original_output[key])
assert 'logits' in output

def test_invalid_task_level(self, base_kwargs):
"""Test that initialization fails with invalid task_level."""
invalid_kwargs = base_kwargs.copy()
invalid_kwargs['task_level'] = 'invalid_level'
with pytest.raises(AssertionError, match="Invalid task_level"):
NoReadOut(**invalid_kwargs)

def test_repr(self, readout_layer):
"""Test the string representation of the NoReadOut layer."""
assert str(readout_layer) == "NoReadOut()"
assert repr(readout_layer) == "NoReadOut()"

def test_forward_pass_with_different_batch_sizes(self, readout_layer):
"""Test that forward pass works with different batch sizes."""
# Test with single graph
single_batch = tg_data.Data(
x=torch.randn(5, 32),
edge_index=torch.randint(0, 5, (2, 8)),
batch_0=torch.zeros(5, dtype=torch.long)
)
single_output = {
'x_0': torch.randn(5, 64),
'embeddings': torch.randn(5, 64)
}
result = readout_layer(single_output, single_batch)
assert 'logits' in result

# Test with multiple graphs
multi_batch = tg_data.Data(
x=torch.randn(15, 32),
edge_index=torch.randint(0, 15, (2, 25)),
batch_0=torch.cat([torch.zeros(5), torch.ones(5), torch.ones(5) * 2]).long()
)
multi_output = {
'x_0': torch.randn(15, 64),
'embeddings': torch.randn(15, 64)
}
result = readout_layer(multi_output, multi_batch)
assert 'logits' in result

def test_kwargs_handling(self, base_kwargs):
"""Test that the layer correctly handles both required and additional keyword arguments."""
additional_kwargs = {
'random_param': 42,
'another_param': 'test',
'pooling_type': 'mean' # Valid additional parameter
}
kwargs = {**base_kwargs, **additional_kwargs}
readout = NoReadOut(**kwargs)
assert isinstance(readout, NoReadOut)
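
For context, here is a minimal sketch of the readout interface these tests appear to exercise. The class names, pooling choice, and logit computation are hypothetical, inferred from the assertions above rather than taken from the actual topobenchmark implementation; it assumes the abstract base class derives 'logits' from the zero-cell features 'x_0':

import torch
import torch.nn as nn
from torch_geometric.nn import global_mean_pool


class AbstractZeroCellReadOutSketch(nn.Module):
    """Hypothetical base: validates task_level and computes logits from x_0."""

    def __init__(self, hidden_dim, out_channels, task_level, **kwargs):
        super().__init__()
        assert task_level in ("graph", "node"), "Invalid task_level"
        self.task_level = task_level
        self.linear = nn.Linear(hidden_dim, out_channels)

    def forward(self, model_out, batch):
        model_out = self.readout(model_out, batch)
        x = model_out["x_0"]  # required key, as the fixtures note
        if self.task_level == "graph":
            x = global_mean_pool(x, batch["batch_0"])  # pool 0-cells per graph
        model_out["logits"] = self.linear(x)
        return model_out


class NoReadOutSketch(AbstractZeroCellReadOutSketch):
    """Pass-through readout: returns the model output unchanged."""

    def readout(self, model_out, batch):
        return model_out

    def __repr__(self):
        return "NoReadOut()"

Under this shape, the pass-through readout leaves every original key untouched while the base class still appends 'logits', which is exactly what test_forward_pass_returns_unchanged_output checks.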
131 changes: 131 additions & 0 deletions test/nn/readouts/test_propagate_signal_down.py
@@ -0,0 +1,131 @@
import pytest
import torch
import torch_geometric.data as tg_data
import topomodelx
from topobenchmark.nn.readouts.propagate_signal_down import PropagateSignalDown


class TestPropagateSignalDown:
@pytest.fixture
def base_kwargs(self):
"""Fixture providing the required base parameters."""
return {
'hidden_dim': 64,
'out_channels': 32,
'task_level': 'graph',
'num_cell_dimensions': 2, # Need at least 2 dimensions for signal propagation
'readout_name': 'test_readout'
}

@pytest.fixture
def readout_layer(self, base_kwargs):
"""Fixture to create a PropagateSignalDown instance for testing."""
layer = PropagateSignalDown(**base_kwargs)
layer.hidden_dim = base_kwargs['hidden_dim']
return layer

@pytest.fixture
def create_sparse_incidence_matrix(self):
"""Helper fixture to create sparse incidence matrices."""
def _create_matrix(num_source, num_target, sparsity=0.3):
num_entries = int(num_source * num_target * sparsity)
indices = torch.zeros((2, num_entries), dtype=torch.long)
values = torch.ones(num_entries)

for i in range(num_entries):
source = torch.randint(0, num_source, (1,))
target = torch.randint(0, num_target, (1,))
indices[0, i] = source
indices[1, i] = target
values[i] = torch.randint(0, 2, (1,)) * 2 - 1 # {-1, 1} values

            # Rows index target cells, columns index source cells; coalesce()
            # merges duplicate (target, source) entries by summing their values.
            sparse_matrix = torch.sparse_coo_tensor(
                indices=torch.stack([indices[1], indices[0]]),
                values=values,
                size=(num_target, num_source)
            ).coalesce()

return sparse_matrix
return _create_matrix

@pytest.fixture
def sample_batch(self, create_sparse_incidence_matrix):
"""Fixture to create a sample batch with required incidence matrices."""
num_nodes = 10
num_edges = 15

return tg_data.Data(
x=torch.randn(num_nodes, 64),
edge_index=torch.randint(0, num_nodes, (2, num_edges)),
batch_0=torch.zeros(num_nodes, dtype=torch.long),
incidence_1=create_sparse_incidence_matrix(num_edges, num_nodes)
)

@pytest.fixture
def sample_model_output(self, sample_batch):
"""Fixture to create a sample model output with cell embeddings."""
hidden_dim = 64

num_nodes = sample_batch.x.size(0)
num_edges = sample_batch.edge_index.size(1)

return {
'logits': torch.randn(num_nodes, hidden_dim),
'x_0': torch.randn(num_nodes, hidden_dim),
'x_1': torch.randn(num_edges, hidden_dim),
}

def test_forward_propagation(self, readout_layer, sample_model_output, sample_batch):
"""Test the forward pass with detailed assertions."""
initial_output = {k: v.clone() for k, v in sample_model_output.items()}
sample_model_output['x_0'] = sample_model_output['logits']

output = readout_layer(sample_model_output, sample_batch)

assert 'x_0' in output
assert output['x_0'].shape == initial_output['logits'].shape
assert output['x_0'].dtype == torch.float32

assert 'x_1' in output
assert output['x_1'].shape == initial_output['x_1'].shape
assert output['x_1'].dtype == torch.float32

@pytest.mark.parametrize('missing_key', ['incidence_1'])
def test_missing_incidence_matrix(self, readout_layer, sample_model_output, sample_batch, missing_key):
"""Test handling of missing incidence matrices."""
invalid_batch = tg_data.Data(**{k: v for k, v in sample_batch.items() if k != missing_key})
sample_model_output['x_0'] = sample_model_output['logits']

with pytest.raises(KeyError):
readout_layer(sample_model_output, invalid_batch)

    @pytest.mark.parametrize('missing_key', ['x_1'])  # x_0 is always remapped from logits below, so x_1 is the only feature that can be absent
def test_missing_cell_features(self, readout_layer, sample_model_output, sample_batch, missing_key):
"""Test handling of missing cell features."""
invalid_output = {k: v for k, v in sample_model_output.items() if k != missing_key}
invalid_output['x_0'] = invalid_output['logits'] # Always map logits to x_0

with pytest.raises(KeyError):
readout_layer(invalid_output, sample_batch)

def test_gradient_flow(self, readout_layer, sample_model_output, sample_batch):
"""Test gradient flow through the network."""
# Create a copy of logits tensor to track gradients properly
logits = sample_model_output['logits'].clone().detach().requires_grad_(True)
x_1 = sample_model_output['x_1'].clone().detach().requires_grad_(True)

model_output = {
'logits': logits,
'x_0': logits, # Share the same tensor
'x_1': x_1
}

output = readout_layer(model_output, sample_batch)
loss = output['x_0'].sum()
loss.backward()

# Check gradient flow
assert logits.grad is not None
assert not torch.allclose(logits.grad, torch.zeros_like(logits.grad))
assert x_1.grad is not None
assert not torch.allclose(x_1.grad, torch.zeros_like(x_1.grad))
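
For context, here is a minimal sketch of the downward propagation these tests appear to assume (hypothetical class and attribute names, not the actual topobenchmark implementation). Each step pushes the features of i-cells down to (i-1)-cells through the sparse matrix incidence_i, which explains why a missing incidence_1 or x_1 raises KeyError and why gradients reach both input tensors:

import torch
import torch.nn as nn


class PropagateSignalDownSketch(nn.Module):
    """Hypothetical sketch: propagate cell features from dimension i to i - 1."""

    def __init__(self, hidden_dim, num_cell_dimensions, **kwargs):
        super().__init__()
        # With num_cell_dimensions=2 this is just [1]: edges down to nodes.
        self.dims = list(range(num_cell_dimensions - 1, 0, -1))
        self.linears = nn.ModuleDict(
            {str(i): nn.Linear(2 * hidden_dim, hidden_dim) for i in self.dims}
        )

    def forward(self, model_out, batch):
        for i in self.dims:
            incidence = batch[f"incidence_{i}"]   # KeyError if missing
            x_low = model_out[f"x_{i - 1}"]
            x_high = model_out[f"x_{i}"]          # KeyError if missing
            # (n_{i-1} x n_i) sparse @ (n_i x hidden) -> (n_{i-1} x hidden)
            propagated = torch.sparse.mm(incidence, x_high)
            model_out[f"x_{i - 1}"] = self.linears[str(i)](
                torch.cat([x_low, propagated], dim=-1)
            )
        return model_out

Because the updated x_0 is a function of both its previous value and the propagated x_1, a loss on output['x_0'] produces nonzero gradients for both input tensors, matching test_gradient_flow.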
