Skip to content

Commit

Permalink
unit tests for mace <-> e3nn reshapes
Browse files Browse the repository at this point in the history
  • Loading branch information
sblackburn-mila committed Jun 25, 2024
1 parent 1983ea0 commit c04353c
Showing 1 changed file with 54 additions and 1 deletion.
55 changes: 54 additions & 1 deletion tests/models/test_mace_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from crystal_diffusion.models.mace_utils import (
get_normalized_irreps_permutation_indices, get_pretrained_mace,
input_to_mace)
input_to_mace, reshape_from_e3nn_to_mace, reshape_from_mace_to_e3nn)
from crystal_diffusion.namespace import NOISY_CARTESIAN_POSITIONS, UNIT_CELL
from crystal_diffusion.utils.basis_transformations import \
get_positions_from_coordinates
Expand Down Expand Up @@ -246,3 +246,56 @@ def test_download_pretrained_mace_invalid_model_name(self, mock_model_savedir):
with pytest.raises(AssertionError) as e:
get_pretrained_mace("invalid_name", mock_model_savedir)
assert "Model name should be small, medium or large. Got invalid_name" in str(e.value)


class TestReshapes:
@pytest.fixture
def num_nodes(self):
return 5

@pytest.fixture
def num_channels(self):
return 3

@pytest.fixture
def ell_max(self):
return 2

@pytest.fixture
def irrep(self, num_channels, ell_max):
irrep_str = f"{num_channels}x0e"
for i in range(1, ell_max + 1):
parity = "e" if i % 2 == 0 else "o"
irrep_str += f"+ {num_channels}x{i}{parity}"
return o3.Irreps(irrep_str)

@pytest.fixture
def mace_format_tensor(self, num_nodes, num_channels, ell_max):
return torch.rand(num_nodes, num_channels, (ell_max + 1) ** 2)

@pytest.fixture
def e3nn_format_tensor(self, num_nodes, num_channels, ell_max):
return torch.rand(num_nodes, num_channels * (ell_max + 1) ** 2)

def test_reshape_from_mace_to_e3nn(self, mace_format_tensor, irrep, ell_max, num_channels):
converted_tensor = reshape_from_mace_to_e3nn(mace_format_tensor, irrep)
# mace_format_tensor: (node, channel, (lmax + 1) ** 2)
# converted: (node, channel * (lmax +1)**2)
# check for each ell that the values match
for ell in range(ell_max + 1):
start_idx = ell ** 2
end_idx = (ell + 1) ** 2
expected_values = mace_format_tensor[:, :, start_idx:end_idx].reshape(-1, num_channels * (2 * ell + 1))
assert torch.allclose(expected_values,
converted_tensor[:, num_channels * start_idx: num_channels * end_idx])

def test_reshape_from_e3nn_to_mace(self, e3nn_format_tensor, irrep, ell_max, num_channels):
converted_tensor = reshape_from_e3nn_to_mace(e3nn_format_tensor, irrep)
# e3nn_format_tensor: (node, channel * (lmax +1)**2)
# converted: (node, channel, (lmax + 1) ** 2)
for ell in range(ell_max + 1):
start_idx = num_channels * (ell ** 2)
end_idx = num_channels * ((ell + 1) ** 2)
expected_values = e3nn_format_tensor[:, start_idx:end_idx].reshape(-1, num_channels, 2 * ell + 1)
assert torch.allclose(expected_values,
converted_tensor[:, :, ell ** 2:(ell + 1) ** 2])

0 comments on commit c04353c

Please sign in to comment.