diff --git a/tests/models/test_mace_utils.py b/tests/models/test_mace_utils.py index ce628f9b..ecf5613a 100644 --- a/tests/models/test_mace_utils.py +++ b/tests/models/test_mace_utils.py @@ -11,7 +11,7 @@ from crystal_diffusion.models.mace_utils import ( get_normalized_irreps_permutation_indices, get_pretrained_mace, - input_to_mace) + input_to_mace, reshape_from_e3nn_to_mace, reshape_from_mace_to_e3nn) from crystal_diffusion.namespace import NOISY_CARTESIAN_POSITIONS, UNIT_CELL from crystal_diffusion.utils.basis_transformations import \ get_positions_from_coordinates @@ -246,3 +246,56 @@ def test_download_pretrained_mace_invalid_model_name(self, mock_model_savedir): with pytest.raises(AssertionError) as e: get_pretrained_mace("invalid_name", mock_model_savedir) assert "Model name should be small, medium or large. Got invalid_name" in str(e.value) + + +class TestReshapes: + @pytest.fixture + def num_nodes(self): + return 5 + + @pytest.fixture + def num_channels(self): + return 3 + + @pytest.fixture + def ell_max(self): + return 2 + + @pytest.fixture + def irrep(self, num_channels, ell_max): + irrep_str = f"{num_channels}x0e" + for i in range(1, ell_max + 1): + parity = "e" if i % 2 == 0 else "o" + irrep_str += f"+ {num_channels}x{i}{parity}" + return o3.Irreps(irrep_str) + + @pytest.fixture + def mace_format_tensor(self, num_nodes, num_channels, ell_max): + return torch.rand(num_nodes, num_channels, (ell_max + 1) ** 2) + + @pytest.fixture + def e3nn_format_tensor(self, num_nodes, num_channels, ell_max): + return torch.rand(num_nodes, num_channels * (ell_max + 1) ** 2) + + def test_reshape_from_mace_to_e3nn(self, mace_format_tensor, irrep, ell_max, num_channels): + converted_tensor = reshape_from_mace_to_e3nn(mace_format_tensor, irrep) + # mace_format_tensor: (node, channel, (lmax + 1) ** 2) + # converted: (node, channel * (lmax +1)**2) + # check for each ell that the values match + for ell in range(ell_max + 1): + start_idx = ell ** 2 + end_idx = (ell + 1) ** 2 + expected_values = mace_format_tensor[:, :, start_idx:end_idx].reshape(-1, num_channels * (2 * ell + 1)) + assert torch.allclose(expected_values, + converted_tensor[:, num_channels * start_idx: num_channels * end_idx]) + + def test_reshape_from_e3nn_to_mace(self, e3nn_format_tensor, irrep, ell_max, num_channels): + converted_tensor = reshape_from_e3nn_to_mace(e3nn_format_tensor, irrep) + # e3nn_format_tensor: (node, channel * (lmax +1)**2) + # converted: (node, channel, (lmax + 1) ** 2) + for ell in range(ell_max + 1): + start_idx = num_channels * (ell ** 2) + end_idx = num_channels * ((ell + 1) ** 2) + expected_values = e3nn_format_tensor[:, start_idx:end_idx].reshape(-1, num_channels, 2 * ell + 1) + assert torch.allclose(expected_values, + converted_tensor[:, :, ell ** 2:(ell + 1) ** 2])