-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
238 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
import pytest | ||
import torch | ||
import torch_geometric.data as tg_data | ||
from topobenchmark.nn.readouts.base import AbstractZeroCellReadOut | ||
from topobenchmark.nn.readouts.identical import NoReadOut | ||
|
||
|
||
class TestNoReadOut: | ||
@pytest.fixture | ||
def base_kwargs(self): | ||
"""Fixture providing the required base parameters.""" | ||
return { | ||
'hidden_dim': 64, | ||
'out_channels': 32, | ||
'task_level': 'graph' | ||
} | ||
|
||
@pytest.fixture | ||
def readout_layer(self, base_kwargs): | ||
"""Fixture to create a NoReadOut instance for testing.""" | ||
return NoReadOut(**base_kwargs) | ||
|
||
@pytest.fixture | ||
def sample_model_output(self): | ||
"""Fixture to create a sample model output dictionary.""" | ||
return { | ||
'x_0': torch.randn(10, 64), # Required key for model output | ||
'edge_indices': torch.randint(0, 10, (2, 15)), | ||
'other_data': torch.randn(10, 32) | ||
} | ||
|
||
@pytest.fixture | ||
def sample_batch(self): | ||
"""Fixture to create a sample batch of graph data.""" | ||
return tg_data.Data( | ||
x=torch.randn(10, 32), | ||
edge_index=torch.randint(0, 10, (2, 15)), | ||
batch_0=torch.zeros(10, dtype=torch.long) # Required key for batch data | ||
) | ||
|
||
def test_initialization(self, base_kwargs): | ||
"""Test that NoReadOut initializes correctly with required parameters.""" | ||
readout = NoReadOut(**base_kwargs) | ||
assert isinstance(readout, NoReadOut) | ||
assert isinstance(readout, AbstractZeroCellReadOut) | ||
|
||
def test_forward_pass_returns_unchanged_output(self, readout_layer, sample_model_output, sample_batch): | ||
"""Test that forward pass returns the model output without modifications.""" | ||
original_output = sample_model_output.copy() | ||
output = readout_layer(sample_model_output, sample_batch) | ||
|
||
# The output should contain the original data plus the computed logits | ||
for key in original_output: | ||
assert key in output | ||
assert torch.equal(output[key], original_output[key]) | ||
assert 'logits' in output | ||
|
||
def test_invalid_task_level(self, base_kwargs): | ||
"""Test that initialization fails with invalid task_level.""" | ||
invalid_kwargs = base_kwargs.copy() | ||
invalid_kwargs['task_level'] = 'invalid_level' | ||
with pytest.raises(AssertionError, match="Invalid task_level"): | ||
NoReadOut(**invalid_kwargs) | ||
|
||
def test_repr(self, readout_layer): | ||
"""Test the string representation of the NoReadOut layer.""" | ||
assert str(readout_layer) == "NoReadOut()" | ||
assert repr(readout_layer) == "NoReadOut()" | ||
|
||
def test_forward_pass_with_different_batch_sizes(self, readout_layer): | ||
"""Test that forward pass works with different batch sizes.""" | ||
# Test with single graph | ||
single_batch = tg_data.Data( | ||
x=torch.randn(5, 32), | ||
edge_index=torch.randint(0, 5, (2, 8)), | ||
batch_0=torch.zeros(5, dtype=torch.long) | ||
) | ||
single_output = { | ||
'x_0': torch.randn(5, 64), | ||
'embeddings': torch.randn(5, 64) | ||
} | ||
result = readout_layer(single_output, single_batch) | ||
assert 'logits' in result | ||
|
||
# Test with multiple graphs | ||
multi_batch = tg_data.Data( | ||
x=torch.randn(15, 32), | ||
edge_index=torch.randint(0, 15, (2, 25)), | ||
batch_0=torch.cat([torch.zeros(5), torch.ones(5), torch.ones(5) * 2]).long() | ||
) | ||
multi_output = { | ||
'x_0': torch.randn(15, 64), | ||
'embeddings': torch.randn(15, 64) | ||
} | ||
result = readout_layer(multi_output, multi_batch) | ||
assert 'logits' in result | ||
|
||
def test_kwargs_handling(self, base_kwargs): | ||
"""Test that the layer correctly handles both required and additional keyword arguments.""" | ||
additional_kwargs = { | ||
'random_param': 42, | ||
'another_param': 'test', | ||
'pooling_type': 'mean' # Valid additional parameter | ||
} | ||
kwargs = {**base_kwargs, **additional_kwargs} | ||
readout = NoReadOut(**kwargs) | ||
assert isinstance(readout, NoReadOut) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
import pytest | ||
import torch | ||
import torch_geometric.data as tg_data | ||
import topomodelx | ||
from topobenchmark.nn.readouts.propagate_signal_down import PropagateSignalDown | ||
|
||
|
||
class TestPropagateSignalDown: | ||
@pytest.fixture | ||
def base_kwargs(self): | ||
"""Fixture providing the required base parameters.""" | ||
return { | ||
'hidden_dim': 64, | ||
'out_channels': 32, | ||
'task_level': 'graph', | ||
'num_cell_dimensions': 2, # Need at least 2 dimensions for signal propagation | ||
'readout_name': 'test_readout' | ||
} | ||
|
||
@pytest.fixture | ||
def readout_layer(self, base_kwargs): | ||
"""Fixture to create a PropagateSignalDown instance for testing.""" | ||
layer = PropagateSignalDown(**base_kwargs) | ||
layer.hidden_dim = base_kwargs['hidden_dim'] | ||
return layer | ||
|
||
@pytest.fixture | ||
def create_sparse_incidence_matrix(self): | ||
"""Helper fixture to create sparse incidence matrices.""" | ||
def _create_matrix(num_source, num_target, sparsity=0.3): | ||
num_entries = int(num_source * num_target * sparsity) | ||
indices = torch.zeros((2, num_entries), dtype=torch.long) | ||
values = torch.ones(num_entries) | ||
|
||
for i in range(num_entries): | ||
source = torch.randint(0, num_source, (1,)) | ||
target = torch.randint(0, num_target, (1,)) | ||
indices[0, i] = source | ||
indices[1, i] = target | ||
values[i] = torch.randint(0, 2, (1,)) * 2 - 1 # {-1, 1} values | ||
|
||
sparse_matrix = torch.sparse_coo_tensor( | ||
indices=torch.stack([indices[1], indices[0]]), | ||
values=values, | ||
size=(num_target, num_source) | ||
).coalesce() | ||
|
||
return sparse_matrix | ||
return _create_matrix | ||
|
||
@pytest.fixture | ||
def sample_batch(self, create_sparse_incidence_matrix): | ||
"""Fixture to create a sample batch with required incidence matrices.""" | ||
num_nodes = 10 | ||
num_edges = 15 | ||
|
||
return tg_data.Data( | ||
x=torch.randn(num_nodes, 64), | ||
edge_index=torch.randint(0, num_nodes, (2, num_edges)), | ||
batch_0=torch.zeros(num_nodes, dtype=torch.long), | ||
incidence_1=create_sparse_incidence_matrix(num_edges, num_nodes) | ||
) | ||
|
||
@pytest.fixture | ||
def sample_model_output(self, sample_batch): | ||
"""Fixture to create a sample model output with cell embeddings.""" | ||
hidden_dim = 64 | ||
|
||
num_nodes = sample_batch.x.size(0) | ||
num_edges = sample_batch.edge_index.size(1) | ||
|
||
return { | ||
'logits': torch.randn(num_nodes, hidden_dim), | ||
'x_0': torch.randn(num_nodes, hidden_dim), | ||
'x_1': torch.randn(num_edges, hidden_dim), | ||
} | ||
|
||
def test_forward_propagation(self, readout_layer, sample_model_output, sample_batch): | ||
"""Test the forward pass with detailed assertions.""" | ||
initial_output = {k: v.clone() for k, v in sample_model_output.items()} | ||
sample_model_output['x_0'] = sample_model_output['logits'] | ||
|
||
output = readout_layer(sample_model_output, sample_batch) | ||
|
||
assert 'x_0' in output | ||
assert output['x_0'].shape == initial_output['logits'].shape | ||
assert output['x_0'].dtype == torch.float32 | ||
|
||
assert 'x_1' in output | ||
assert output['x_1'].shape == initial_output['x_1'].shape | ||
assert output['x_1'].dtype == torch.float32 | ||
|
||
@pytest.mark.parametrize('missing_key', ['incidence_1']) | ||
def test_missing_incidence_matrix(self, readout_layer, sample_model_output, sample_batch, missing_key): | ||
"""Test handling of missing incidence matrices.""" | ||
invalid_batch = tg_data.Data(**{k: v for k, v in sample_batch.items() if k != missing_key}) | ||
sample_model_output['x_0'] = sample_model_output['logits'] | ||
|
||
with pytest.raises(KeyError): | ||
readout_layer(sample_model_output, invalid_batch) | ||
|
||
@pytest.mark.parametrize('missing_key', ['x_1']) # Changed to only test x_1 | ||
def test_missing_cell_features(self, readout_layer, sample_model_output, sample_batch, missing_key): | ||
"""Test handling of missing cell features.""" | ||
invalid_output = {k: v for k, v in sample_model_output.items() if k != missing_key} | ||
invalid_output['x_0'] = invalid_output['logits'] # Always map logits to x_0 | ||
|
||
with pytest.raises(KeyError): | ||
readout_layer(invalid_output, sample_batch) | ||
|
||
def test_gradient_flow(self, readout_layer, sample_model_output, sample_batch): | ||
"""Test gradient flow through the network.""" | ||
# Create a copy of logits tensor to track gradients properly | ||
logits = sample_model_output['logits'].clone().detach().requires_grad_(True) | ||
x_1 = sample_model_output['x_1'].clone().detach().requires_grad_(True) | ||
|
||
model_output = { | ||
'logits': logits, | ||
'x_0': logits, # Share the same tensor | ||
'x_1': x_1 | ||
} | ||
|
||
output = readout_layer(model_output, sample_batch) | ||
loss = output['x_0'].sum() | ||
loss.backward() | ||
|
||
# Check gradient flow | ||
assert logits.grad is not None | ||
assert not torch.allclose(logits.grad, torch.zeros_like(logits.grad)) | ||
assert x_1.grad is not None | ||
assert not torch.allclose(x_1.grad, torch.zeros_like(x_1.grad)) |