Skip to content

Commit

Permalink
Add gcnn tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vapavlo committed Dec 21, 2024
1 parent e1e8c8c commit 38fbf32
Show file tree
Hide file tree
Showing 2 changed files with 493 additions and 1 deletion.
148 changes: 147 additions & 1 deletion test/nn/backbones/combinatorial/test_gccn.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,4 +230,150 @@ def test_get_activation():
assert issubclass(relu_module, torch.nn.Module)

with pytest.raises(NotImplementedError):
get_activation("invalid_activation")
get_activation("invalid_activation")


@pytest.mark.parametrize("activation", ["relu", "elu", "tanh", "id"])
def test_topotune_different_activations(activation):
"""
Test TopoTune with multiple activations to improve coverage of get_activation.
Parameters
----------
activation : str
Activation function.
"""
batch = create_mock_complex_batch()
gnn = MockGNN(16, 32, 16)

neighborhoods = OmegaConf.create(["up_adjacency-0", "down_incidence-1"])
model = TopoTune(
GNN=gnn,
neighborhoods=neighborhoods,
layers=1, # single layer to keep test simpler
use_edge_attr=False,
activation=activation,
)

output = model(batch)
# We expect a dict of updated features for each rank in the batch
assert isinstance(output, dict)
for rank, feat in output.items():
assert isinstance(feat, torch.Tensor)
# The shape should match the original x_rank shape
original_feat = getattr(batch, f"x_{rank}")
assert feat.shape == original_feat.shape


def test_topotune_use_edge_attr_true():
"""
Test TopoTune with use_edge_attr=True to ensure that edge attributes flow through properly.
"""
batch = create_mock_complex_batch()
gnn = MockGNN(16, 32, 16)

# Add more complex neighborhoods to ensure both interrank and intrarank expansions
neighborhoods = OmegaConf.create([
"up_adjacency-0", # intrarank route rank=0->0
"up_adjacency-1", # intrarank route rank=1->1
"down_incidence-1", # interrank route rank=1->0
"down_incidence-2", # interrank route rank=2->1
])
model = TopoTune(
GNN=gnn,
neighborhoods=neighborhoods,
layers=2,
use_edge_attr=True,
activation="relu",
)

output = model(batch)
assert isinstance(output, dict)
# Check that each rank in [0,1,2] got updated
for rank in range(3):
assert rank in output
assert isinstance(output[rank], torch.Tensor)
# The shape should match the original x_rank shape
original_feat = getattr(batch, f"x_{rank}")
assert output[rank].shape == original_feat.shape


def test_topotune_single_node_per_rank():
"""
Test corner case: each rank has only 1 cell, ensuring the path that returns early in intrarank_gnn_forward (x.shape[0] < 2).
"""
# Create a batch with just 1 node, 1 edge, 1 face
batch = create_mock_complex_batch()
gnn = MockGNN(16, 32, 16)

neighborhoods = OmegaConf.create(["up_adjacency-0", "down_incidence-1"])
model = TopoTune(
GNN=gnn,
neighborhoods=neighborhoods,
layers=1,
use_edge_attr=False,
activation="relu",
)
output = model(batch)
# Since we have exactly 1 cell in each rank, intrarank_gnn_forward
# should skip the GNN pass and return the original features
assert isinstance(output, dict)
for rank, feat in output.items():
# Should remain the same as the input
assert torch.allclose(feat, getattr(batch, f"x_{rank}"), atol=1e-6)


def test_topotune_multiple_layers():
"""
Test TopoTune with multiple layers > 2 to ensure repeated forward passes.
"""
batch = create_mock_complex_batch()
gnn = MockGNN(16, 32, 16)

neighborhoods = OmegaConf.create(["up_adjacency-0", "down_incidence-1"])
model = TopoTune(
GNN=gnn,
neighborhoods=neighborhoods,
layers=3, # more than 2
use_edge_attr=False,
activation="relu",
)

output = model(batch)
assert isinstance(output, dict)
# By default, the final shape should still be (N, 16) per rank
for rank, feat in output.items():
original_feat = getattr(batch, f"x_{rank}")
assert feat.shape == original_feat.shape


def test_topotune_src_rank_larger_than_dst_rank():
"""
Test a scenario where src_rank > dst_rank for an interrank route.
"""
batch = create_mock_complex_batch()
gnn = MockGNN(16, 32, 16)
# Force a route from rank=2 -> rank=0, for instance
neighborhoods = OmegaConf.create(["down_incidence-1", "down_incidence-2"])
# topotune will interpret these strings as routes:
# (1->0) from down_incidence-1
# (2->1) from down_incidence-2
# Let's force an additional route from 2->0 by customizing the route logic if you want
# but as is, 2->0 won't happen automatically unless your `get_routes_from_neighborhoods`
# is coded that way. We'll just rely on existing logic for (2->1).

model = TopoTune(
GNN=gnn,
neighborhoods=neighborhoods,
layers=1,
use_edge_attr=False,
activation="relu",
)

output = model(batch)
assert isinstance(output, dict)
# Ranks 0, 1, 2 should exist in the final output dictionary
for rank in [0, 1, 2]:
assert rank in output
assert output[rank].shape == getattr(batch, f"x_{rank}").shape

Loading

0 comments on commit 38fbf32

Please sign in to comment.