Skip to content

Commit

Permalink
Ensure that the EGNN score network works in 1D, 2D and 3D.
Browse files Browse the repository at this point in the history
  • Loading branch information
rousseab committed Dec 26, 2024
1 parent 4071d3a commit 98afd87
Show file tree
Hide file tree
Showing 6 changed files with 143 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def get_edges_with_radial_cutoff(
unit_cell: torch.Tensor,
radial_cutoff: float = 4.0,
drop_duplicate_edges: bool = True,
spatial_dimension: int = 3
) -> torch.Tensor:
"""Get edges for a batch with a cutoff based on distance.
Expand All @@ -127,7 +128,7 @@ def get_edges_with_radial_cutoff(
relative_coordinates, unit_cell
)
adj_matrix, _, _, _ = get_adj_matrix(
cartesian_coordinates, unit_cell, radial_cutoff
cartesian_coordinates, unit_cell, radial_cutoff, spatial_dimension
)
# adj_matrix is a n_edges x 2 tensor with duplicates with different shifts.
# the uplifting in 2 x spatial_dimension manages the shifts in a natural way. This means we can ignore the shifts
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


def get_adj_matrix(
positions: torch.Tensor, basis_vectors: torch.Tensor, radial_cutoff: float = 4.0
positions: torch.Tensor, basis_vectors: torch.Tensor, radial_cutoff: float = 4.0, spatial_dimension: int = 3
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Create the adjacency and shift matrices.
Expand All @@ -32,7 +32,7 @@ def get_adj_matrix(
batch_size, number_of_atoms, spatial_dimensions = positions.shape

adjacency_info = get_periodic_adjacency_information(
positions, basis_vectors, radial_cutoff
positions, basis_vectors, radial_cutoff, spatial_dimension
)

# The indices in the adjacency matrix must be shifted to account for the batching
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ def _forward_unchecked(
batch[UNIT_CELL],
self.radial_cutoff,
drop_duplicate_edges=self.drop_duplicate_edges,
spatial_dimension=self.spatial_dimension
)

edges = edges.to(relative_coordinates.device)
Expand Down
77 changes: 63 additions & 14 deletions src/diffusion_for_multi_scale_molecular_dynamics/utils/neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import itertools
from collections import namedtuple

import einops
import numpy as np
import torch
from pykeops.torch import LazyTensor
Expand All @@ -32,7 +33,7 @@


def get_periodic_adjacency_information(
cartesian_positions: torch.Tensor, basis_vectors: torch.Tensor, radial_cutoff: float
cartesian_positions: torch.Tensor, basis_vectors: torch.Tensor, radial_cutoff: float, spatial_dimension: int = 3
) -> AdjacencyInfo:
"""Get periodic adjacency information.
Expand Down Expand Up @@ -61,14 +62,15 @@ def get_periodic_adjacency_information(
Args:
cartesian_positions : atomic positions, assumed to be within the unit cell, in Euclidean space, in Angstrom.
Dimension [batch_size, max_number_of_atoms, 3]
Dimension [batch_size, max_number_of_atoms, spatial_dimension]
basis_vectors : vectors that define the unit cell, (a1, a2, a3). The basis vectors are assumed
to be vertically stacked, namely
[-- a1 --]
[-- a2 --]
[-- a3 --]
Dimension [batch_size, 3, 3].
radial_cutoff : largest distance between neighbors, in Angstrom.
spatial_dimension: the dimension of space.
Returns:
adjacency_info: an AdjacencyInfo object that contains
Expand All @@ -83,7 +85,6 @@ def get_periodic_adjacency_information(
), "Wrong number of dimensions for relative_coordinates"
assert len(basis_vectors.shape) == 3, "Wrong number of dimensions for basis_vectors"

spatial_dimension = 3 # We define this to avoid "magic numbers" in the code below.
batch_size, max_natom, spatial_dimension_ = cartesian_positions.shape
assert (
spatial_dimension_ == spatial_dimension
Expand All @@ -103,7 +104,7 @@ def get_periodic_adjacency_information(

# Check that the radial cutoff does not lead to possible neighbors beyond the first shell.
shortest_cell_crossing_distances = _get_shortest_distance_that_crosses_unit_cell(
basis_vectors
basis_vectors, spatial_dimension=spatial_dimension
)
assert torch.all(shortest_cell_crossing_distances > radial_cutoff), (
"The radial cutoff is so large that neighbors could be located "
Expand All @@ -112,7 +113,7 @@ def get_periodic_adjacency_information(

# The relative coordinates lattice vectors have dimensions [number of lattice vectors, spatial_dimension]
relative_lattice_vectors = _get_relative_coordinates_lattice_vectors(
number_of_shells=1
number_of_shells=1, spatial_dimension=spatial_dimension
).to(device)
number_of_relative_lattice_vectors = len(relative_lattice_vectors)

Expand Down Expand Up @@ -266,7 +267,7 @@ def _get_shifted_positions(


def _get_shortest_distance_that_crosses_unit_cell(
basis_vectors: torch.Tensor,
basis_vectors: torch.Tensor, spatial_dimension: int = 3
) -> torch.Tensor:
"""Get the shortest distance that crosses unit cell.
Expand All @@ -280,28 +281,76 @@ def _get_shortest_distance_that_crosses_unit_cell(
/ v /
---------------------------
Args:
basis_vectors : basis vectors that define the unit cell.
Dimension [batch_size, spatial_dimension = 3]
Dimension [batch_size, spatial_dimension]
Returns:
shortest_distances: shortest distance that can cross the unit cell, from one side to its other parallel side.
Dimension [batch_size].
"""
# It is straightforward to show that the distance between two parallel planes,
# (say the plane spanned by (a1, a2) crossing the origin and the plane spanned by (a1, a2) crossing the point a3)
# is given by unit_normal DOT a3. The unit normal to the plane is proportional to the cross product of a1 and a2.
#
# This idea must be repeated for the three pairs of planes bounding the unit cell.
spatial_dimension = 3
assert spatial_dimension in {1, 2, 3}, "The spatial dimension must be 1, 2 or 3."
assert len(basis_vectors.shape) == 3, "basis_vectors has wrong shape."
assert (
basis_vectors.shape[1] == spatial_dimension
), "Basis vectors in wrong spatial dimension."
assert (
basis_vectors.shape[2] == spatial_dimension
), "Basis vectors in wrong spatial dimension."

match spatial_dimension:
case 1:
return _get_shortest_distance_that_crosses_unit_cell_1d(basis_vectors)
case 2:
return _get_shortest_distance_that_crosses_unit_cell_2d(basis_vectors)
case 3:
return _get_shortest_distance_that_crosses_unit_cell_3d(basis_vectors)
case _:
raise RuntimeError("Spatial dimension must be 1, 2 or 3.")


def _get_shortest_distance_that_crosses_unit_cell_1d(
basis_vectors: torch.Tensor,
) -> torch.Tensor:
"""Get the shortest distance that crosses unit cell in 1D."""
distances = basis_vectors.norm(dim=[-1, -2])
return distances


def _get_shortest_distance_that_crosses_unit_cell_2d(
basis_vectors: torch.Tensor,
) -> torch.Tensor:
"""Get the shortest distance that crosses unit cell in 2D."""
a1 = basis_vectors[:, 0, :]
a2 = basis_vectors[:, 1, :]

dot_product = einops.einsum(a1, a2, "b i, b i -> b")

norm_a1 = torch.norm(a1, dim=-1)
norm_a2 = torch.norm(a2, dim=-1)

orthogonal_a2 = a2 - (dot_product / norm_a1**2).unsqueeze(1) * a1
distances_1 = orthogonal_a2.norm(dim=-1)

orthogonal_a1 = a1 - (dot_product / norm_a2**2).unsqueeze(1) * a2
distances_2 = orthogonal_a1.norm(dim=-1)

distances = (
torch.stack([distances_1, distances_2], dim=1).min(dim=1).values
)

return distances


def _get_shortest_distance_that_crosses_unit_cell_3d(
basis_vectors: torch.Tensor,
) -> torch.Tensor:
"""Get the shortest distance that crosses unit cell in 3D."""
# It is straightforward to show that the distance between two parallel planes,
# (say the plane spanned by (a1, a2) crossing the origin and the plane spanned by (a1, a2) crossing the point a3)
# is given by unit_normal DOT a3. The unit normal to the plane is proportional to the cross product of a1 and a2.
#
# This idea must be repeated for the three pairs of planes bounding the unit cell.
a1 = basis_vectors[:, 0, :]
a2 = basis_vectors[:, 1, :]
a3 = basis_vectors[:, 2, :]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -319,10 +319,15 @@ def score_network(self, score_network_parameters):

class TestEGNNScoreNetwork(BaseScoreNetworkGeneralTests):

@pytest.fixture(params=[1, 2, 3])
def spatial_dimension(self, request):
return request.param

@pytest.fixture(params=[("fully_connected", None), ("radial_cutoff", 3.0)])
def score_network_parameters(self, request, num_atom_types):
def score_network_parameters(self, request, spatial_dimension, num_atom_types):
edges, radial_cutoff = request.param
return EGNNScoreNetworkParameters(
spatial_dimension=spatial_dimension,
edges=edges, radial_cutoff=radial_cutoff, num_atom_types=num_atom_types
)

Expand Down
82 changes: 69 additions & 13 deletions tests/utils/test_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,25 @@ def number_of_atoms():
return 32


@pytest.fixture()
def spatial_dimension(request):
return 3


@pytest.fixture
def basis_vectors(batch_size, spatial_dimension):
# orthogonal boxes with dimensions between 5 and 10.
orthogonal_boxes = torch.stack(
[torch.diag(5.0 + 5.0 * torch.rand(spatial_dimension)) for _ in range(batch_size)]
)
# add a bit of noise to make the vectors not quite orthogonal
basis_vectors = orthogonal_boxes + 0.1 * torch.randn(batch_size, spatial_dimension, spatial_dimension)
return basis_vectors


@pytest.fixture
def relative_coordinates(batch_size, number_of_atoms):
return torch.rand(batch_size, number_of_atoms, 3)
def relative_coordinates(batch_size, number_of_atoms, spatial_dimension):
return torch.rand(batch_size, number_of_atoms, spatial_dimension)


@pytest.fixture
Expand All @@ -47,9 +63,9 @@ def positions(relative_coordinates, basis_vectors):


@pytest.fixture
def lattice_vectors(batch_size, basis_vectors, number_of_shells):
def lattice_vectors(batch_size, basis_vectors, number_of_shells, spatial_dimension):
relative_lattice_vectors = _get_relative_coordinates_lattice_vectors(
number_of_shells
number_of_shells, spatial_dimension
)
batched_relative_lattice_vectors = relative_lattice_vectors.repeat(batch_size, 1, 1)
lattice_vectors = get_positions_from_coordinates(
Expand Down Expand Up @@ -218,31 +234,71 @@ def test_get_periodic_adjacency_information(
)


@pytest.mark.parametrize("spatial_dimension", [1, 2, 3])
def test_get_periodic_neighbour_indices_and_displacements_large_cutoff(
basis_vectors, relative_coordinates
basis_vectors, relative_coordinates, spatial_dimension
):
# Check that the code crashes if the radial cutoff is too big!
shortest_cell_crossing_distances = _get_shortest_distance_that_crosses_unit_cell(
basis_vectors
basis_vectors, spatial_dimension=spatial_dimension
).min()

large_radial_cutoff = shortest_cell_crossing_distances + 0.1
small_radial_cutoff = shortest_cell_crossing_distances - 0.1
large_radial_cutoff = (shortest_cell_crossing_distances + 0.1).item()
small_radial_cutoff = (shortest_cell_crossing_distances - 0.1).item()

# Should run
get_periodic_adjacency_information(
relative_coordinates, basis_vectors, small_radial_cutoff
relative_coordinates, basis_vectors, small_radial_cutoff, spatial_dimension=spatial_dimension
)

with pytest.raises(AssertionError):
# Should crash
get_periodic_adjacency_information(
relative_coordinates, basis_vectors, large_radial_cutoff
relative_coordinates, basis_vectors, large_radial_cutoff, spatial_dimension=spatial_dimension
)


@pytest.mark.parametrize("number_of_shells", [1, 2, 3])
def test_get_relative_coordinates_lattice_vectors(number_of_shells):
def test_get_relative_coordinates_lattice_vectors_1d(number_of_shells):

expected_lattice_vectors = []

for nx in torch.arange(-number_of_shells, number_of_shells + 1):
lattice_vector = torch.tensor([nx])
expected_lattice_vectors.append(lattice_vector)

expected_lattice_vectors = torch.stack(expected_lattice_vectors).to(
dtype=torch.float32
)
computed_lattice_vectors = _get_relative_coordinates_lattice_vectors(
number_of_shells, spatial_dimension=1
)

torch.testing.assert_close(expected_lattice_vectors, computed_lattice_vectors)


@pytest.mark.parametrize("number_of_shells", [1, 2, 3])
def test_get_relative_coordinates_lattice_vectors_2d(number_of_shells):

expected_lattice_vectors = []

for nx in torch.arange(-number_of_shells, number_of_shells + 1):
for ny in torch.arange(-number_of_shells, number_of_shells + 1):
lattice_vector = torch.tensor([nx, ny])
expected_lattice_vectors.append(lattice_vector)

expected_lattice_vectors = torch.stack(expected_lattice_vectors).to(
dtype=torch.float32
)
computed_lattice_vectors = _get_relative_coordinates_lattice_vectors(
number_of_shells, spatial_dimension=2
)

torch.testing.assert_close(expected_lattice_vectors, computed_lattice_vectors)


@pytest.mark.parametrize("number_of_shells", [1, 2, 3])
def test_get_relative_coordinates_lattice_vectors_3d(number_of_shells):

expected_lattice_vectors = []

Expand All @@ -256,7 +312,7 @@ def test_get_relative_coordinates_lattice_vectors(number_of_shells):
dtype=torch.float32
)
computed_lattice_vectors = _get_relative_coordinates_lattice_vectors(
number_of_shells
number_of_shells, spatial_dimension=3
)

torch.testing.assert_close(expected_lattice_vectors, computed_lattice_vectors)
Expand Down Expand Up @@ -289,7 +345,7 @@ def test_get_shifted_positions(positions, lattice_vectors):
)


def test_get_shortest_distance_that_crosses_unit_cell(basis_vectors):
def test_get_shortest_distance_that_crosses_unit_cell_3d(basis_vectors):
expected_shortest_distances = []
for matrix in basis_vectors.numpy():
a1, a2, a3 = matrix
Expand Down

0 comments on commit 98afd87

Please sign in to comment.