diff --git a/docs/extensions/rascaline_json_schema.py b/docs/extensions/rascaline_json_schema.py index 3ec969533..bb5f87159 100644 --- a/docs/extensions/rascaline_json_schema.py +++ b/docs/extensions/rascaline_json_schema.py @@ -3,9 +3,8 @@ from docutils import nodes from docutils.parsers.rst import Directive -from markdown_it import MarkdownIt - from html_hidden import html_hidden +from markdown_it import MarkdownIt from myst_parser.config.main import MdParserConfig from myst_parser.mdit_to_docutils.base import DocutilsRenderer @@ -121,7 +120,7 @@ def _json_schema_to_nodes(self, schema, inline=False): if isinstance(subfields, nodes.literal): subfields = [subfields] - for (i, sf) in enumerate(subfields): + for i, sf in enumerate(subfields): field += sf if isinstance(sf, nodes.inline): @@ -152,8 +151,11 @@ def _json_schema_to_nodes(self, schema, inline=False): raise Exception(f"unknown integer format: {schema['format']}") elif schema["type"] == "string": - # TODO enums? - return nodes.literal(text="string") + if "enum" in schema: + values = [f'"{v}"' for v in schema["enum"]] + return nodes.literal(text=" | ".join(values)) + else: + return nodes.literal(text="string") elif schema["type"] == "boolean": return nodes.literal(text="boolean") diff --git a/python/rascaline/rascaline/_c_api.py b/python/rascaline/rascaline/_c_api.py index 3149df312..6d7b43666 100644 --- a/python/rascaline/rascaline/_c_api.py +++ b/python/rascaline/rascaline/_c_api.py @@ -46,6 +46,7 @@ class rascal_pair_t(ctypes.Structure): ("second", c_uintptr_t), ("distance", ctypes.c_double), ("vector", ctypes.c_double * 3), + ("cell_shift_indices", ctypes.c_int32 * 3), ] diff --git a/python/rascaline/rascaline/calculators.py b/python/rascaline/rascaline/calculators.py index 88ddb589d..a3122df18 100644 --- a/python/rascaline/rascaline/calculators.py +++ b/python/rascaline/rascaline/calculators.py @@ -37,29 +37,48 @@ def __init__(self, cutoff, delta, name): class NeighborList(CalculatorBase): """ - This calculator computes the neighbor list for a given spherical cutoff, and - returns the list of distance vectors between all pairs of atoms strictly - inside the cutoff. - - Users can request either a "full" neighbor list (including an entry for both - ``i - j`` pairs and ``j - i`` pairs) or save memory/computational by only - working with "half" neighbor list (only including one entry for each ``i/j`` - pair) - - Self pairs (pairs between an atom and periodic copy itself) can appear when - the cutoff is larger than the cell under periodic boundary conditions. Self - pairs with a distance of 0 are only included when the user passes - ``self_pairs=True``, which is not the default behavior. - - This sample produces a single property (``"distance"``) with three - components (``"pair_direction"``) containing the x, y, and z component of - the vector from the first atom in the pair to the second. In addition to the - atom indexes, the samples also contain a pair index, to be able to - distinguish between multiple pairs between the same atom (if the cutoff is - larger than the cell). + This calculator computes the neighbor list for a given spherical cutoff, and returns + the list of distance vectors between all pairs of atoms strictly inside the cutoff. + + Users can request either a "full" neighbor list (including an entry for both ``i - + j`` pairs and ``j - i`` pairs) or save memory/computational by only working with + "half" neighbor list (only including one entry for each ``i/j`` pair) + + Pairs between an atom and it's own periodic copy can appear when the cutoff is + larger than the cell under periodic boundary conditions. Self pairs with a distance + of 0 (i.e. self pairs inside the original unit cell) are only included when using + ``self_pairs=True``. + + The ``quantity`` parameter determine what will be included in the output. It can + take one of three values: + + - ``"Distance"``, to get the distance between the atoms, accounting for periodic + boundary conditions. This is the default. + - ``"CellShiftVector"``, to get the cell shift vector, which can then be used to + apply periodic boundary conditions and compute the distance vector. + + If ``S`` is the cell shift vector, ``rij`` the pair distance vector, ``ri`` and + ``rj`` the position of the atoms, ``rij = rj - ri + S``. + - ``"CellShiftIndices"``, to get three integers indicating the number of cell + vectors (``A``, ``B``, and ``C``) entering the cell shift. + + If the cell vectors are ``A``, ``B``, and ``C``, this returns three integers + ``sa``, ``sb``, and ``sc`` such that the cell shift ``S = sa * A + sb * B + sc * + C``. + + This calculator produces a single property (``"distance"``, ``"cell_shift_vector"``, + or ``"cell_shift_indices"``) with three components (``"pair_direction"``) containing + the x, y, and z component of the requested vector. In addition to the atom indexes, + the samples also contain a pair index, to be able to distinguish between multiple + pairs between the same atom (if the cutoff is larger than the cell). """ - def __init__(self, cutoff, full_neighbor_list, self_pairs=False): + def __init__( + self, + cutoff: float, + full_neighbor_list: bool, + self_pairs: bool = False, + ): parameters = { "cutoff": cutoff, "full_neighbor_list": full_neighbor_list, diff --git a/python/rascaline/rascaline/systems/ase.py b/python/rascaline/rascaline/systems/ase.py index c51e4aa2d..ef8ca1c70 100644 --- a/python/rascaline/rascaline/systems/ase.py +++ b/python/rascaline/rascaline/systems/ase.py @@ -89,21 +89,21 @@ def compute_neighbors(self, cutoff): self._pairs = [] - nl_result = neighborlist.neighbor_list("ijdD", self._atoms, cutoff) - for i, j, d, D in zip(*nl_result): + nl_result = neighborlist.neighbor_list("ijdDS", self._atoms, cutoff) + for i, j, d, D, S in zip(*nl_result): if j < i: # we want a half neighbor list, so drop all duplicated # neighbors continue - self._pairs.append((i, j, d, D)) + self._pairs.append((i, j, d, D, S)) self._pairs_by_center = [] for _ in range(self.size()): self._pairs_by_center.append([]) - for i, j, d, D in self._pairs: - self._pairs_by_center[i].append((i, j, d, D)) - self._pairs_by_center[j].append((i, j, d, D)) + for i, j, d, D, S in self._pairs: + self._pairs_by_center[i].append((i, j, d, D, S)) + self._pairs_by_center[j].append((i, j, d, D, S)) def pairs(self): return self._pairs diff --git a/python/rascaline/rascaline/systems/base.py b/python/rascaline/rascaline/systems/base.py index 16162a27a..71bc3ce98 100644 --- a/python/rascaline/rascaline/systems/base.py +++ b/python/rascaline/rascaline/systems/base.py @@ -314,17 +314,18 @@ def pairs(self): The pairs are those which were computed by the last call :py:func:`SystemBase.compute_neighbors` - Get all neighbor pairs in this system as a list of tuples ``(int, int, - float, (float, float, float))`` containing the indexes of the first and - second atom in the pair, the distance between the atoms, and the wrapped - between them. Alternatively, this function can return a 1D numpy array - with ``dtype=rascal_pair_t``. - - The list of pair should only contain each pair once (and not twice as - ``i-j`` and ``j-i``), should not contain self pairs (``i-i``); and - should only contains pairs where the distance between atoms is actually - bellow the cutoff passed in the last call to - :py:func:`rascaline.SystemBase.compute_neighbors`. + Get all neighbor pairs in this system as a list of tuples ``(int, int, float, + (float, float, float), (int, int, int))`` containing the indexes of the first + and second atom in the pair, the distance between the atoms, the vector between + them, and the cell shift. The vector should be ``position[first] - + position[second] * + H * cell_shift`` where ``H`` is the cell matrix. + Alternatively, this function can return a 1D numpy array with + ``dtype=rascal_pair_t``. + + The list of pair should only contain each pair once (and not twice as ``i-j`` + and ``j-i``), should not contain self pairs (``i-i``); and should only contains + pairs where the distance between atoms is actually bellow the cutoff passed in + the last call to :py:func:`rascaline.SystemBase.compute_neighbors`. This function is only valid to call after a call to :py:func:`rascaline.SystemBase.compute_neighbors` to set the cutoff. diff --git a/python/rascaline/tests/calculators/dummy_calculator.py b/python/rascaline/tests/calculators/dummy_calculator.py index 827cfb352..826786f8b 100644 --- a/python/rascaline/tests/calculators/dummy_calculator.py +++ b/python/rascaline/tests/calculators/dummy_calculator.py @@ -51,8 +51,8 @@ def test_compute(): H_block = descriptor.block(species_center=1) assert H_block.values.shape == (2, 2) - assert np.all(H_block.values[0] == (2, 1)) - assert np.all(H_block.values[1] == (3, 3)) + assert np.all(H_block.values[0] == (2, 11)) + assert np.all(H_block.values[1] == (3, 13)) assert len(H_block.samples) == 2 assert H_block.samples.names == ["structure", "center"] diff --git a/python/rascaline/tests/calculators/sample_selection.py b/python/rascaline/tests/calculators/sample_selection.py index 3c7cf9b78..97abc5ccc 100644 --- a/python/rascaline/tests/calculators/sample_selection.py +++ b/python/rascaline/tests/calculators/sample_selection.py @@ -49,8 +49,8 @@ def test_selection(): H_block = descriptor.block(species_center=1) assert H_block.values.shape == (2, 2) - assert np.all(H_block.values[0] == (2, 1)) - assert np.all(H_block.values[1] == (3, 3)) + assert np.all(H_block.values[0] == (2, 11)) + assert np.all(H_block.values[1] == (3, 13)) O_block = descriptor.block(species_center=8) assert O_block.values.shape == (1, 2) @@ -72,8 +72,8 @@ def test_subset_variables(): H_block = descriptor.block(species_center=1) assert H_block.values.shape == (2, 2) - assert np.all(H_block.values[0] == (2, 1)) - assert np.all(H_block.values[1] == (3, 3)) + assert np.all(H_block.values[0] == (2, 11)) + assert np.all(H_block.values[1] == (3, 13)) O_block = descriptor.block(species_center=8) assert O_block.values.shape == (1, 2) @@ -128,7 +128,7 @@ def test_predefined_selection(): H_block = descriptor.block(species_center=1) assert H_block.values.shape == (1, 2) - assert np.all(H_block.values[0] == (3, 3)) + assert np.all(H_block.values[0] == (3, 13)) O_block = descriptor.block(species_center=8) assert O_block.values.shape == (1, 2) diff --git a/python/rascaline/tests/test_systems.py b/python/rascaline/tests/test_systems.py index b57c784aa..699a70808 100644 --- a/python/rascaline/tests/test_systems.py +++ b/python/rascaline/tests/test_systems.py @@ -9,7 +9,7 @@ def species(self): return [1, 1, 8, 8] def positions(self): - return [[0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 0, 3]] + return [[0, 0, 10], [0, 0, 1], [0, 0, 2], [0, 0, 3]] def cell(self): return [[10, 0, 0], [0, 10, 0], [0, 0, 10]] @@ -19,29 +19,29 @@ def compute_neighbors(self, cutoff): def pairs(self): return [ - (0, 1, 1.0, (0.0, 0.0, 1.0)), - (1, 2, 1.0, (0.0, 0.0, 1.0)), - (2, 3, 1.0, (0.0, 0.0, 1.0)), + (0, 1, 1.0, (0.0, 0.0, 1.0), (0, 0, 1)), + (1, 2, 1.0, (0.0, 0.0, 1.0), (0, 0, 0)), + (2, 3, 1.0, (0.0, 0.0, 1.0), (0, 0, 0)), ] def pairs_containing(self, center): if center == 0: return [ - (0, 1, 1.0, (0.0, 0.0, 1.0)), + (0, 1, 1.0, (0.0, 0.0, 1.0), (0, 0, 1)), ] elif center == 1: return [ - (0, 1, 1.0, (0.0, 0.0, 1.0)), - (1, 2, 1.0, (0.0, 0.0, 1.0)), + (0, 1, 1.0, (0.0, 0.0, 1.0), (0, 0, 1)), + (1, 2, 1.0, (0.0, 0.0, 1.0), (0, 0, 0)), ] elif center == 2: return [ - (1, 2, 1.0, (0.0, 0.0, 1.0)), - (2, 3, 1.0, (0.0, 0.0, 1.0)), + (1, 2, 1.0, (0.0, 0.0, 1.0), (0, 0, 0)), + (2, 3, 1.0, (0.0, 0.0, 1.0), (0, 0, 0)), ] elif center == 3: return [ - (2, 3, 1.0, (0.0, 0.0, 1.0)), + (2, 3, 1.0, (0.0, 0.0, 1.0), (0, 0, 0)), ] else: raise Exception("got invalid center") diff --git a/rascaline-c-api/include/rascaline.h b/rascaline-c-api/include/rascaline.h index 782aed64d..efcbfbf5a 100644 --- a/rascaline-c-api/include/rascaline.h +++ b/rascaline-c-api/include/rascaline.h @@ -131,10 +131,17 @@ typedef struct rascal_pair_t { */ double distance; /** - * vector from the first atom to the second atom, wrapped inside the unit - * cell as required by periodic boundary conditions. + * vector from the first atom to the second atom, accounting for periodic + * boundary conditions. This should be + * `position[second] - position[first] + H * cell_shift` + * where `H` is the cell matrix. */ double vector[3]; + /** + * How many cell shift where applied to the `second` atom to create this + * pair. + */ + int32_t cell_shift_indices[3]; } rascal_pair_t; /** diff --git a/rascaline-c-api/src/system.rs b/rascaline-c-api/src/system.rs index 41ecce8fc..66a991344 100644 --- a/rascaline-c-api/src/system.rs +++ b/rascaline-c-api/src/system.rs @@ -19,9 +19,14 @@ pub struct rascal_pair_t { pub second: usize, /// distance between the two atoms pub distance: f64, - /// vector from the first atom to the second atom, wrapped inside the unit - /// cell as required by periodic boundary conditions. + /// vector from the first atom to the second atom, accounting for periodic + /// boundary conditions. This should be + /// `position[second] - position[first] + H * cell_shift` + /// where `H` is the cell matrix. pub vector: [f64; 3], + /// How many cell shift where applied to the `second` atom to create this + /// pair. + pub cell_shift_indices: [i32; 3], } /// A `rascal_system_t` deals with the storage of atoms and related information, diff --git a/rascaline-c-api/tests/calculator.cpp b/rascaline-c-api/tests/calculator.cpp index 41e077e84..a1709896a 100644 --- a/rascaline-c-api/tests/calculator.cpp +++ b/rascaline-c-api/tests/calculator.cpp @@ -169,7 +169,7 @@ TEST_CASE("Compute descriptor") { 1, 0, /**/ 0, 1, }; auto values = std::vector{ - 5, 9, /**/ 6, 18, /**/ 7, 15, + 5, 39, /**/ 6, 18, /**/ 7, 15, }; auto gradient_samples = std::vector{ 0, 0, 0, /**/ 0, 0, 1, /**/ 0, 0, 2, @@ -194,7 +194,7 @@ TEST_CASE("Compute descriptor") { 0, 0, }; values = std::vector{ - 4, 3, + 4, 33, }; gradient_samples = std::vector{ 0, 0, 0, /**/ 0, 0, 1, @@ -260,7 +260,7 @@ TEST_CASE("Compute descriptor") { 1, 0, /**/ 0, 1, }; auto values = std::vector{ - 5, 9, /**/ 7, 15, + 5, 39, /**/ 7, 15, }; auto gradient_samples = std::vector{ 0, 0, 0, /**/ 0, 0, 1, /**/ 0, 0, 2, @@ -337,7 +337,7 @@ TEST_CASE("Compute descriptor") { 0, 1, }; auto values = std::vector{ - 9, /**/ 18, /**/ 15, + 39, /**/ 18, /**/ 15, }; auto gradient_samples = std::vector{ 0, 0, 0, /**/ 0, 0, 1, /**/ 0, 0, 2, @@ -362,7 +362,7 @@ TEST_CASE("Compute descriptor") { 0, 0, }; values = std::vector{ - 3, + 33, }; gradient_samples = std::vector{ 0, 0, 0, /**/ 0, 0, 1, @@ -575,7 +575,7 @@ TEST_CASE("Compute descriptor") { 1, 0, 0, 1 }; values = std::vector{ - 4, 3 + 4, 33 }; gradient_samples = std::vector{ 0, 0, 0, /**/ 0, 0, 1, diff --git a/rascaline-c-api/tests/cxx/calculator.cpp b/rascaline-c-api/tests/cxx/calculator.cpp index ce11ddbca..1ddca8b48 100644 --- a/rascaline-c-api/tests/cxx/calculator.cpp +++ b/rascaline-c-api/tests/cxx/calculator.cpp @@ -112,7 +112,7 @@ TEST_CASE("Compute descriptor") { {{1, 0}, {0, 1}} )); CHECK(block.values() == metatensor::NDArray( - {5.0, 9.0, 6.0, 18.0, 7.0, 15.0}, + {5.0, 39.0, 6.0, 18.0, 7.0, 15.0}, {3, 2} )); @@ -150,7 +150,7 @@ TEST_CASE("Compute descriptor") { {{1, 0}, {0, 1}} )); CHECK(block.values() == metatensor::NDArray( - {4.0, 3.0}, + {4.0, 33.0}, {1, 2} )); @@ -192,7 +192,7 @@ TEST_CASE("Compute descriptor") { {{1, 0}, {0, 1}} )); CHECK(block.values() == metatensor::NDArray( - {5.0, 9.0, 7.0, 15.0}, + {5.0, 39.0, 7.0, 15.0}, {2, 2} )); @@ -265,7 +265,7 @@ TEST_CASE("Compute descriptor") { {{0, 1}} )); CHECK(block.values() == metatensor::NDArray( - {9.0, 18.0, 15.0}, + {39.0, 18.0, 15.0}, {3, 1} )); @@ -303,7 +303,7 @@ TEST_CASE("Compute descriptor") { {{0, 1}} )); CHECK(block.values() == metatensor::NDArray( - {3.0}, + {33.0}, {1, 1} )); @@ -460,7 +460,7 @@ TEST_CASE("Compute descriptor") { {{1, 0}, {0, 1}} )); CHECK(block.values() == metatensor::NDArray( - {4.0, 3.0}, + {4.0, 33.0}, {1, 2} )); } diff --git a/rascaline-c-api/tests/cxx/test_system.hpp b/rascaline-c-api/tests/cxx/test_system.hpp index bcff03b9b..5adc13ed0 100644 --- a/rascaline-c-api/tests/cxx/test_system.hpp +++ b/rascaline-c-api/tests/cxx/test_system.hpp @@ -17,7 +17,7 @@ class TestSystem: public rascaline::System { const double* positions() const override { static double POSITIONS[4][3] = { - {0, 0, 0}, + {10, 10, 10}, {1, 1, 1}, {2, 2, 2}, {3, 3, 3}, @@ -40,30 +40,30 @@ class TestSystem: public rascaline::System { const std::vector& pairs() const override { static std::vector PAIRS = { - {0, 1, SQRT_3, {1, 1, 1}}, - {1, 2, SQRT_3, {1, 1, 1}}, - {2, 3, SQRT_3, {1, 1, 1}}, + {0, 1, SQRT_3, {1.0, 1.0, 1.0}, {1, 1, 1}}, + {1, 2, SQRT_3, {1.0, 1.0, 1.0}, {0, 0, 0}}, + {2, 3, SQRT_3, {1.0, 1.0, 1.0}, {0, 0, 0}}, }; return PAIRS; } const std::vector& pairs_containing(uintptr_t center) const override { static std::vector PAIRS_0 = { - {0, 1, SQRT_3, {1, 1, 1}}, + {0, 1, SQRT_3, {1.0, 1.0, 1.0}, {1, 1, 1}}, }; static std::vector PAIRS_1 = { - {0, 1, SQRT_3, {1, 1, 1}}, - {1, 2, SQRT_3, {1, 1, 1}}, + {0, 1, SQRT_3, {1.0, 1.0, 1.0}, {1, 1, 1}}, + {1, 2, SQRT_3, {1.0, 1.0, 1.0}, {0, 0, 0}}, }; static std::vector PAIRS_2 = { - {1, 2, SQRT_3, {1, 1, 1}}, - {2, 3, SQRT_3, {1, 1, 1}}, + {1, 2, SQRT_3, {1.0, 1.0, 1.0}, {0, 0, 0}}, + {2, 3, SQRT_3, {1.0, 1.0, 1.0}, {0, 0, 0}}, }; static std::vector PAIRS_3 = { - {2, 3, SQRT_3, {1, 1, 1}}, + {2, 3, SQRT_3, {1.0, 1.0, 1.0}, {0, 0, 0}}, }; if (center == 0) { diff --git a/rascaline-c-api/tests/helpers.cpp b/rascaline-c-api/tests/helpers.cpp index 63c640394..69fa6710f 100644 --- a/rascaline-c-api/tests/helpers.cpp +++ b/rascaline-c-api/tests/helpers.cpp @@ -17,7 +17,7 @@ rascal_system_t simple_system() { system.positions = [](const void* _, const double** positions) { static double POSITIONS[4][3] = { - {0, 0, 0}, + {10, 10, 10}, {1, 1, 1}, {2, 2, 2}, {3, 3, 3}, @@ -53,9 +53,9 @@ rascal_system_t simple_system() { system.pairs = [](const void* _, const rascal_pair_t** pairs, uintptr_t* count) { static rascal_pair_t PAIRS[] = { - {0, 1, SQRT_3, {1, 1, 1}}, - {1, 2, SQRT_3, {1, 1, 1}}, - {2, 3, SQRT_3, {1, 1, 1}}, + {0, 1, SQRT_3, {1.0, 1.0, 1.0}, {1, 1, 1}}, + {1, 2, SQRT_3, {1.0, 1.0, 1.0}, {0, 0, 0}}, + {2, 3, SQRT_3, {1.0, 1.0, 1.0}, {0, 0, 0}}, }; *pairs = PAIRS; @@ -65,21 +65,21 @@ rascal_system_t simple_system() { system.pairs_containing = [](const void* _, uintptr_t center, const rascal_pair_t** pairs, uintptr_t* count){ static rascal_pair_t PAIRS_0[] = { - {0, 1, SQRT_3, {1, 1, 1}}, + {0, 1, SQRT_3, {1.0, 1.0, 1.0}, {1, 1, 1}}, }; static rascal_pair_t PAIRS_1[] = { - {0, 1, SQRT_3, {1, 1, 1}}, - {1, 2, SQRT_3, {1, 1, 1}}, + {0, 1, SQRT_3, {1.0, 1.0, 1.0}, {1, 1, 1}}, + {1, 2, SQRT_3, {1.0, 1.0, 1.0}, {0, 0, 0}}, }; static rascal_pair_t PAIRS_2[] = { - {1, 2, SQRT_3, {1, 1, 1}}, - {2, 3, SQRT_3, {1, 1, 1}}, + {1, 2, SQRT_3, {1.0, 1.0, 1.0}, {0, 0, 0}}, + {2, 3, SQRT_3, {1.0, 1.0, 1.0}, {0, 0, 0}}, }; static rascal_pair_t PAIRS_3[] = { - {2, 3, SQRT_3, {1, 1, 1}}, + {2, 3, SQRT_3, {1.0, 1.0, 1.0}, {0, 0, 0}}, }; if (center == 0) { diff --git a/rascaline-torch/tests/system.cpp b/rascaline-torch/tests/system.cpp index ee594b4f6..07f7ce58a 100644 --- a/rascaline-torch/tests/system.cpp +++ b/rascaline-torch/tests/system.cpp @@ -17,8 +17,8 @@ TEST_CASE("Systems") { CHECK(system.use_native_system() == true); - system.set_precomputed_pairs(3.2, {{0, 1, 0.0, {0.0, 0.0, 0.0}}}); - system.set_precomputed_pairs(4.5, {{3, 2, 0.0, {0.0, 0.0, 0.0}}}); + system.set_precomputed_pairs(3.2, {{0, 1, 0.0, {0.0, 0.0, 0.0}, {0, 1, 0}}}); + system.set_precomputed_pairs(4.5, {{3, 2, 0.0, {0.0, 0.0, 0.0}, {0, 1, 0}}}); CHECK(system.use_native_system() == false); CHECK_THROWS_WITH( diff --git a/rascaline/src/calculators/dummy_calculator.rs b/rascaline/src/calculators/dummy_calculator.rs index 85f1bc6f2..0674db16b 100644 --- a/rascaline/src/calculators/dummy_calculator.rs +++ b/rascaline/src/calculators/dummy_calculator.rs @@ -8,7 +8,7 @@ use crate::labels::{SpeciesFilter, SamplesBuilder}; use crate::labels::AtomCenteredSamples; use crate::labels::{CenterSpeciesKeys, KeysBuilder}; -use crate::{Error, System}; +use crate::{Error, System, Vector3D}; /// A stupid calculator implementation used to test the API, and API binding to /// C/Python/etc. @@ -137,8 +137,51 @@ impl CalculatorBase for DummyCalculator { system.compute_neighbors(self.cutoff)?; let positions = system.positions()?; + let cell = system.cell()?.matrix(); let mut sum = positions[center_i][0] + positions[center_i][1] + positions[center_i][2]; for pair in system.pairs()? { + // this code just check for consistency in the + // neighbor list + let shift = pair.cell_shift_indices[0] as f64 * Vector3D::from(cell[0]) + + pair.cell_shift_indices[1] as f64 * Vector3D::from(cell[1]) + + pair.cell_shift_indices[2] as f64 * Vector3D::from(cell[2]); + let from_shift = positions[pair.second] - positions[pair.first] + shift; + if !approx::relative_eq!(from_shift, pair.vector, max_relative=1e-6) { + return Err(Error::InvalidParameter(format!( + "system implementation returned inconsistent neighbors list:\ + pair.vector is {:?}, but the cell shift give {:?} for atoms {}-{}", + pair.vector, from_shift, pair.first, pair.second + ))); + } + + if !approx::relative_eq!(pair.vector.norm(), pair.distance, max_relative=1e-6) { + return Err(Error::InvalidParameter(format!( + "system implementation returned inconsistent neighbors list:\ + pair.vector norm is {}, but pair.distance is {} for atoms {}-{}", + pair.vector.norm(), pair.distance, pair.first, pair.second + ))); + } + + let pairs_by_center = system.pairs_containing(pair.first)?; + if !pairs_by_center.iter().any(|p| p == pair) { + return Err(Error::InvalidParameter(format!( + "system implementation returned inconsistent neighbors list:\ + pairs_containing({}) does not contains a pair for atoms {}-{}", + pair.first, pair.first, pair.second + ))); + } + + let pairs_by_center = system.pairs_containing(pair.second)?; + if !pairs_by_center.iter().any(|p| p == pair) { + return Err(Error::InvalidParameter(format!( + "system implementation returned inconsistent neighbors list:\ + pairs_containing({}) does not contains a pair for atoms {}-{}", + pair.second, pair.first, pair.second + ))); + } + // end of neighbors list consistency check + + // actual values calculation if pair.first == center_i { sum += positions[pair.second][0] + positions[pair.second][1] + positions[pair.second][2]; } diff --git a/rascaline/src/calculators/neighbor_list.rs b/rascaline/src/calculators/neighbor_list.rs index cd2b7a0c6..c092b4753 100644 --- a/rascaline/src/calculators/neighbor_list.rs +++ b/rascaline/src/calculators/neighbor_list.rs @@ -17,17 +17,17 @@ use crate::{Error, System}; /// working with "half" neighbor list (only including one entry for each `i/j` /// pair) /// -/// Self pairs (pairs between an atom and periodic copy itself) can appear when -/// the cutoff is larger than the cell under periodic boundary conditions. Self -/// pairs with a distance of 0 are not included in this calculator, even though -/// they are required when computing SOAP. +/// Pairs between an atom and it's own periodic copy can appear when the cutoff +/// is larger than the cell under periodic boundary conditions. Self pairs with +/// a distance of 0 (i.e. self pairs inside the original unit cell) are only +/// included when using `self_pairs = true`. /// -/// This sample produces a single property (`"distance"`) with three components -/// (`"pair_direction"`) containing the x, y, and z component of the vector from -/// the first atom in the pair to the second. In addition to the atom indexes, -/// the samples also contain a pair index, to be able to distinguish between -/// multiple pairs between the same atom (if the cutoff is larger than the -/// cell). +/// This calculator produces a single property (``"distance"``) with three +/// components (``"pair_direction"``) containing the x, y, and z component of +/// the distance vector of the pair. +/// +/// The samples also contain the two atoms indexes, as well as the number of +/// cell boundaries crossed to create this pair. #[derive(Debug, Clone)] #[derive(serde::Deserialize, serde::Serialize, schemars::JsonSchema)] pub struct NeighborList { @@ -69,23 +69,35 @@ impl CalculatorBase for NeighborList { assert!(self.cutoff > 0.0 && self.cutoff.is_finite()); if self.full_neighbor_list { - FullNeighborList { cutoff: self.cutoff, self_pairs: self.self_pairs }.keys(systems) + FullNeighborList { + cutoff: self.cutoff, + self_pairs: self.self_pairs, + }.keys(systems) } else { - HalfNeighborList { cutoff: self.cutoff, self_pairs: self.self_pairs }.keys(systems) + HalfNeighborList { + cutoff: self.cutoff, + self_pairs: self.self_pairs, + }.keys(systems) } } fn samples_names(&self) -> Vec<&str> { - return vec!["structure", "pair_id", "first_atom", "second_atom"]; + return vec!["structure", "first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c"]; } fn samples(&self, keys: &Labels, systems: &mut [Box]) -> Result, Error> { assert!(self.cutoff > 0.0 && self.cutoff.is_finite()); if self.full_neighbor_list { - FullNeighborList { cutoff: self.cutoff, self_pairs: self.self_pairs }.samples(keys, systems) + FullNeighborList { + cutoff: self.cutoff, + self_pairs: self.self_pairs, + }.samples(keys, systems) } else { - HalfNeighborList { cutoff: self.cutoff, self_pairs: self.self_pairs }.samples(keys, systems) + HalfNeighborList { + cutoff: self.cutoff, + self_pairs: self.self_pairs, + }.samples(keys, systems) } } @@ -102,9 +114,9 @@ impl CalculatorBase for NeighborList { for block_samples in samples { let mut builder = LabelsBuilder::new(vec!["sample", "structure", "atom"]); - for (sample_i, &[system_i, pair_id, first, second]) in block_samples.iter_fixed_size().enumerate() { + for (sample_i, &[system_i, first, second, cell_a, cell_b, cell_c]) in block_samples.iter_fixed_size().enumerate() { // self pairs do not contribute to gradients - if pair_id == -1 { + if first == second && cell_a == 0 && cell_b == 0 && cell_c == 0 { continue; } builder.add(&[sample_i.into(), system_i, first]); @@ -128,7 +140,7 @@ impl CalculatorBase for NeighborList { fn properties(&self, keys: &Labels) -> Vec { let mut properties = LabelsBuilder::new(self.properties_names()); - properties.add(&[LabelValue::new(0)]); + properties.add(&[LabelValue::new(1)]); let properties = properties.finish(); return vec![properties; keys.count()]; @@ -137,9 +149,15 @@ impl CalculatorBase for NeighborList { #[time_graph::instrument(name = "NeighborList::compute")] fn compute(&mut self, systems: &mut [Box], descriptor: &mut TensorMap) -> Result<(), Error> { if self.full_neighbor_list { - FullNeighborList { cutoff: self.cutoff, self_pairs: self.self_pairs }.compute(systems, descriptor) + FullNeighborList { + cutoff: self.cutoff, + self_pairs: self.self_pairs, + }.compute(systems, descriptor) } else { - HalfNeighborList { cutoff: self.cutoff, self_pairs: self.self_pairs }.compute(systems, descriptor) + HalfNeighborList { + cutoff: self.cutoff, + self_pairs: self.self_pairs, + }.compute(systems, descriptor) } } } @@ -149,7 +167,7 @@ impl CalculatorBase for NeighborList { #[derive(Debug, Clone)] struct HalfNeighborList { cutoff: f64, - self_pairs: bool + self_pairs: bool, } impl HalfNeighborList { @@ -185,15 +203,29 @@ impl HalfNeighborList { let mut results = Vec::new(); for [species_first, species_second] in keys.iter_fixed_size() { - let mut builder = LabelsBuilder::new( - vec!["structure", "pair_id", "first_atom", "second_atom"] - ); + let mut builder = LabelsBuilder::new(vec![ + "structure", + "first_atom", + "second_atom", + "cell_shift_a", + "cell_shift_b", + "cell_shift_c" + ]); + for (system_i, system) in systems.iter_mut().enumerate() { system.compute_neighbors(self.cutoff)?; let species = system.species()?; - for (pair_id, pair) in system.pairs()?.iter().enumerate() { + for pair in system.pairs()? { let ((species_i, species_j), invert) = sort_pair((species[pair.first], species[pair.second])); + + let shifts = pair.cell_shift_indices; + let (cell_a, cell_b, cell_c) = if invert { + (-shifts[0], -shifts[1], -shifts[2]) + } else { + (shifts[0], shifts[1], shifts[2]) + }; + let (atom_i, atom_j) = if invert { (pair.second, pair.first) } else { @@ -201,7 +233,14 @@ impl HalfNeighborList { }; if species_i == species_first.i32() && species_j == species_second.i32() { - builder.add(&[system_i, pair_id, atom_i, atom_j]); + builder.add(&[ + LabelValue::from(system_i), + LabelValue::from(atom_i), + LabelValue::from(atom_j), + LabelValue::from(cell_a), + LabelValue::from(cell_b), + LabelValue::from(cell_c), + ]); } } @@ -211,17 +250,17 @@ impl HalfNeighborList { if species[center_i] == species_first.i32() { builder.add(&[ system_i.into(), - // set pair_id as -1 for self pairs - LabelValue::new(-1), center_i.into(), center_i.into(), + LabelValue::from(0), + LabelValue::from(0), + LabelValue::from(0), ]); } } } } - results.push(builder.finish()); } @@ -233,7 +272,7 @@ impl HalfNeighborList { system.compute_neighbors(self.cutoff)?; let species = system.species()?; - for (pair_id, pair) in system.pairs()?.iter().enumerate() { + for pair in system.pairs()? { // Sort the species in the pair to ensure a canonical order of // the atoms in it. This guarantee that multiple call to this // calculator always returns pairs in the same order, even if @@ -250,49 +289,70 @@ impl HalfNeighborList { pair.vector }; + let shifts = pair.cell_shift_indices; + let (cell_a, cell_b, cell_c) = if invert { + (-shifts[0], -shifts[1], -shifts[2]) + } else { + (shifts[0], shifts[1], shifts[2]) + }; + let (atom_i, atom_j) = if invert { (pair.second, pair.first) } else { (pair.first, pair.second) }; - let block_id = descriptor.keys().position(&[ + let block_i = descriptor.keys().position(&[ species_i.into(), species_j.into() - ]).expect("missing block"); - - let mut block = descriptor.block_mut_by_id(block_id); - let block_data = block.data_mut(); - - let sample_i = block_data.samples.position(&[ - system_i.into(), pair_id.into(), atom_i.into(), atom_j.into() ]); - if let Some(sample_i) = sample_i { - let array = block_data.values.to_array_mut(); - - array[[sample_i, 0, 0]] = pair_vector[0]; - array[[sample_i, 1, 0]] = pair_vector[1]; - array[[sample_i, 2, 0]] = pair_vector[2]; + if let Some(block_i) = block_i { + let mut block = descriptor.block_mut_by_id(block_i); + let block_data = block.data_mut(); + + let sample_i = block_data.samples.position(&[ + LabelValue::from(system_i), + LabelValue::from(atom_i), + LabelValue::from(atom_j), + LabelValue::from(cell_a), + LabelValue::from(cell_b), + LabelValue::from(cell_c), + ]); + + if let Some(sample_i) = sample_i { + let array = block_data.values.to_array_mut(); + for (property_i, &[distance]) in block_data.properties.iter_fixed_size().enumerate() { + if distance == 1 { + array[[sample_i, 0, property_i]] = pair_vector[0]; + array[[sample_i, 1, property_i]] = pair_vector[1]; + array[[sample_i, 2, property_i]] = pair_vector[2]; + } + } - if let Some(mut gradient) = block.gradient_mut("positions") { - let gradient = gradient.data_mut(); + if let Some(mut gradient) = block.gradient_mut("positions") { + let gradient = gradient.data_mut(); - let first_grad_sample_i = gradient.samples.position(&[ - sample_i.into(), system_i.into(), atom_i.into() - ]).expect("missing gradient sample"); - let second_grad_sample_i = gradient.samples.position(&[ - sample_i.into(), system_i.into(), atom_j.into() - ]).expect("missing gradient sample"); + let first_grad_sample_i = gradient.samples.position(&[ + sample_i.into(), system_i.into(), atom_i.into() + ]).expect("missing gradient sample"); + let second_grad_sample_i = gradient.samples.position(&[ + sample_i.into(), system_i.into(), atom_j.into() + ]).expect("missing gradient sample"); - let array = gradient.values.to_array_mut(); + let array = gradient.values.to_array_mut(); - array[[first_grad_sample_i, 0, 0, 0]] = -1.0; - array[[first_grad_sample_i, 1, 1, 0]] = -1.0; - array[[first_grad_sample_i, 2, 2, 0]] = -1.0; + for (property_i, &[distance]) in gradient.properties.iter_fixed_size().enumerate() { + if distance == 1 { + array[[first_grad_sample_i, 0, 0, property_i]] = -1.0; + array[[first_grad_sample_i, 1, 1, property_i]] = -1.0; + array[[first_grad_sample_i, 2, 2, property_i]] = -1.0; - array[[second_grad_sample_i, 0, 0, 0]] = 1.0; - array[[second_grad_sample_i, 1, 1, 0]] = 1.0; - array[[second_grad_sample_i, 2, 2, 0]] = 1.0; + array[[second_grad_sample_i, 0, 0, property_i]] = 1.0; + array[[second_grad_sample_i, 1, 1, property_i]] = 1.0; + array[[second_grad_sample_i, 2, 2, property_i]] = 1.0; + } + } + } } } } @@ -307,7 +367,7 @@ impl HalfNeighborList { #[derive(Debug, Clone)] pub struct FullNeighborList { pub cutoff: f64, - pub self_pairs: bool + pub self_pairs: bool, } impl FullNeighborList { @@ -340,33 +400,73 @@ impl FullNeighborList { return Ok(keys.finish()); } - fn samples(&self, keys: &Labels, systems: &mut [Box]) -> Result, Error> { + pub(crate) fn samples(&self, keys: &Labels, systems: &mut [Box]) -> Result, Error> { let mut results = Vec::new(); for &[species_first, species_second] in keys.iter_fixed_size() { - let mut builder = LabelsBuilder::new( - vec!["structure", "pair_id", "first_atom", "second_atom"] - ); + let mut builder = LabelsBuilder::new(vec![ + "structure", + "first_atom", + "second_atom", + "cell_shift_a", + "cell_shift_b", + "cell_shift_c" + ]); + for (system_i, system) in systems.iter_mut().enumerate() { system.compute_neighbors(self.cutoff)?; let species = system.species()?; - for (pair_id, pair) in system.pairs()?.iter().enumerate() { + for pair in system.pairs()? { + let cell_a = pair.cell_shift_indices[0]; + let cell_b = pair.cell_shift_indices[1]; + let cell_c = pair.cell_shift_indices[2]; + if species_first == species_second { // same species for both atoms in the pair if species[pair.first] == species_first.i32() && species[pair.second] == species_second.i32() { - builder.add(&[system_i, pair_id, pair.first, pair.second]); + builder.add(&[ + LabelValue::from(system_i), + LabelValue::from(pair.first), + LabelValue::from(pair.second), + LabelValue::from(cell_a), + LabelValue::from(cell_b), + LabelValue::from(cell_c), + ]); + if pair.first != pair.second { - // do not duplicate self pairs - builder.add(&[system_i, pair_id, pair.second, pair.first]); + // if the pair is between two different atoms, + // also add the reversed (second -> first) pair. + builder.add(&[ + LabelValue::from(system_i), + LabelValue::from(pair.second), + LabelValue::from(pair.first), + LabelValue::from(-cell_a), + LabelValue::from(-cell_b), + LabelValue::from(-cell_c), + ]); } } } else { - // different species + // different species, find the right order for the pair if species[pair.first] == species_first.i32() && species[pair.second] == species_second.i32() { - builder.add(&[system_i, pair_id, pair.first, pair.second]); + builder.add(&[ + LabelValue::from(system_i), + LabelValue::from(pair.first), + LabelValue::from(pair.second), + LabelValue::from(cell_a), + LabelValue::from(cell_b), + LabelValue::from(cell_c), + ]); } else if species[pair.second] == species_first.i32() && species[pair.first] == species_second.i32() { - builder.add(&[system_i, pair_id, pair.second, pair.first]); + builder.add(&[ + LabelValue::from(system_i), + LabelValue::from(pair.second), + LabelValue::from(pair.first), + LabelValue::from(-cell_a), + LabelValue::from(-cell_b), + LabelValue::from(-cell_c), + ]); } } } @@ -377,10 +477,11 @@ impl FullNeighborList { if species[center_i] == species_first.i32() { builder.add(&[ system_i.into(), - // set pair_id as -1 for self pairs - LabelValue::new(-1), center_i.into(), center_i.into(), + LabelValue::from(0), + LabelValue::from(0), + LabelValue::from(0), ]); } } @@ -393,104 +494,130 @@ impl FullNeighborList { return Ok(results); } + #[allow(clippy::too_many_lines)] fn compute(&mut self, systems: &mut [Box], descriptor: &mut TensorMap) -> Result<(), Error> { for (system_i, system) in systems.iter_mut().enumerate() { system.compute_neighbors(self.cutoff)?; let species = system.species()?; - for (pair_id, pair) in system.pairs()?.iter().enumerate() { - let first_block_id = descriptor.keys().position(&[ + for pair in system.pairs()?.iter() { + let first_block_i = descriptor.keys().position(&[ species[pair.first].into(), species[pair.second].into() - ]).expect("missing block"); - - let second_block_id = if species[pair.first] == species[pair.second] { - None - } else { - Some(descriptor.keys().position(&[ - species[pair.second].into(), species[pair.first].into() - ]).expect("missing block")) - }; - - // first, the pair first -> second - let mut block = descriptor.block_mut_by_id(first_block_id); - let block_data = block.data_mut(); + ]); - let sample_i = block_data.samples.position(&[ - system_i.into(), pair_id.into(), pair.first.into(), pair.second.into() + let second_block_i = descriptor.keys().position(&[ + species[pair.second].into(), species[pair.first].into() ]); - if let Some(sample_i) = sample_i { - let array = block_data.values.to_array_mut(); + let cell_a = pair.cell_shift_indices[0]; + let cell_b = pair.cell_shift_indices[1]; + let cell_c = pair.cell_shift_indices[2]; - array[[sample_i, 0, 0]] = pair.vector[0]; - array[[sample_i, 1, 0]] = pair.vector[1]; - array[[sample_i, 2, 0]] = pair.vector[2]; + // first, the pair first -> second + if let Some(first_block_i) = first_block_i { + let mut block = descriptor.block_mut_by_id(first_block_i); + let block_data = block.data_mut(); + + let sample_i = block_data.samples.position(&[ + LabelValue::from(system_i), + LabelValue::from(pair.first), + LabelValue::from(pair.second), + LabelValue::from(cell_a), + LabelValue::from(cell_b), + LabelValue::from(cell_c), + ]); + + if let Some(sample_i) = sample_i { + let array = block_data.values.to_array_mut(); + + for (property_i, &[distance]) in block_data.properties.iter_fixed_size().enumerate() { + if distance == 1 { + array[[sample_i, 0, property_i]] = pair.vector[0]; + array[[sample_i, 1, property_i]] = pair.vector[1]; + array[[sample_i, 2, property_i]] = pair.vector[2]; + } + } - if let Some(mut gradient) = block.gradient_mut("positions") { - let gradient = gradient.data_mut(); + if let Some(mut gradient) = block.gradient_mut("positions") { + let gradient = gradient.data_mut(); - let first_grad_sample_i = gradient.samples.position(&[ - sample_i.into(), system_i.into(), pair.first.into() - ]).expect("missing gradient sample"); - let second_grad_sample_i = gradient.samples.position(&[ - sample_i.into(), system_i.into(), pair.second.into() - ]).expect("missing gradient sample"); + let first_grad_sample_i = gradient.samples.position(&[ + sample_i.into(), system_i.into(), pair.first.into() + ]).expect("missing gradient sample"); + let second_grad_sample_i = gradient.samples.position(&[ + sample_i.into(), system_i.into(), pair.second.into() + ]).expect("missing gradient sample"); - let array = gradient.values.to_array_mut(); + let array = gradient.values.to_array_mut(); - array[[first_grad_sample_i, 0, 0, 0]] = -1.0; - array[[first_grad_sample_i, 1, 1, 0]] = -1.0; - array[[first_grad_sample_i, 2, 2, 0]] = -1.0; + for (property_i, &[distance]) in gradient.properties.iter_fixed_size().enumerate() { + if distance == 1 { + array[[first_grad_sample_i, 0, 0, property_i]] = -1.0; + array[[first_grad_sample_i, 1, 1, property_i]] = -1.0; + array[[first_grad_sample_i, 2, 2, property_i]] = -1.0; - array[[second_grad_sample_i, 0, 0, 0]] = 1.0; - array[[second_grad_sample_i, 1, 1, 0]] = 1.0; - array[[second_grad_sample_i, 2, 2, 0]] = 1.0; + array[[second_grad_sample_i, 0, 0, property_i]] = 1.0; + array[[second_grad_sample_i, 1, 1, property_i]] = 1.0; + array[[second_grad_sample_i, 2, 2, property_i]] = 1.0; + } + } + } } } - // then the pair second -> first - let mut block = if let Some(second_block_id) = second_block_id { - descriptor.block_mut_by_id(second_block_id) - } else { - if pair.first == pair.second { - // do not duplicate self pairs - continue - } - // same species for both atoms in the pair, keep the same block - block - }; - - let block_data = block.data_mut(); - let sample_i = block_data.samples.position(&[ - system_i.into(), pair_id.into(), pair.second.into(), pair.first.into() - ]); - - if let Some(sample_i) = sample_i { - let array = block_data.values.to_array_mut(); + if pair.first == pair.second { + // do not duplicate self pairs + continue; + } - array[[sample_i, 0, 0]] = -pair.vector[0]; - array[[sample_i, 1, 0]] = -pair.vector[1]; - array[[sample_i, 2, 0]] = -pair.vector[2]; + // then the pair second -> first + if let Some(second_block_i) = second_block_i { + let mut block = descriptor.block_mut_by_id(second_block_i); + + let block_data = block.data_mut(); + let sample_i = block_data.samples.position(&[ + LabelValue::from(system_i), + LabelValue::from(pair.second), + LabelValue::from(pair.first), + LabelValue::from(-cell_a), + LabelValue::from(-cell_b), + LabelValue::from(-cell_c), + ]); + + if let Some(sample_i) = sample_i { + let array = block_data.values.to_array_mut(); + for (property_i, &[distance]) in block_data.properties.iter_fixed_size().enumerate() { + if distance == 1 { + array[[sample_i, 0, property_i]] = -pair.vector[0]; + array[[sample_i, 1, property_i]] = -pair.vector[1]; + array[[sample_i, 2, property_i]] = -pair.vector[2]; + } + } - if let Some(mut gradient) = block.gradient_mut("positions") { - let gradient = gradient.data_mut(); + if let Some(mut gradient) = block.gradient_mut("positions") { + let gradient = gradient.data_mut(); - let first_grad_sample_i = gradient.samples.position(&[ - sample_i.into(), system_i.into(), pair.second.into() - ]).expect("missing gradient sample"); - let second_grad_sample_i = gradient.samples.position(&[ - sample_i.into(), system_i.into(), pair.first.into() - ]).expect("missing gradient sample"); + let first_grad_sample_i = gradient.samples.position(&[ + sample_i.into(), system_i.into(), pair.second.into() + ]).expect("missing gradient sample"); + let second_grad_sample_i = gradient.samples.position(&[ + sample_i.into(), system_i.into(), pair.first.into() + ]).expect("missing gradient sample"); - let array = gradient.values.to_array_mut(); + let array = gradient.values.to_array_mut(); - array[[first_grad_sample_i, 0, 0, 0]] = -1.0; - array[[first_grad_sample_i, 1, 1, 0]] = -1.0; - array[[first_grad_sample_i, 2, 2, 0]] = -1.0; + for (property_i, &[distance]) in gradient.properties.iter_fixed_size().enumerate() { + if distance == 1 { + array[[first_grad_sample_i, 0, 0, property_i]] = -1.0; + array[[first_grad_sample_i, 1, 1, property_i]] = -1.0; + array[[first_grad_sample_i, 2, 2, property_i]] = -1.0; - array[[second_grad_sample_i, 0, 0, 0]] = 1.0; - array[[second_grad_sample_i, 1, 1, 0]] = 1.0; - array[[second_grad_sample_i, 2, 2, 0]] = 1.0; + array[[second_grad_sample_i, 0, 0, property_i]] = 1.0; + array[[second_grad_sample_i, 1, 1, property_i]] = 1.0; + array[[second_grad_sample_i, 2, 2, property_i]] = 1.0; + } + } + } } } } @@ -531,15 +658,15 @@ mod tests { // O-H block let block = descriptor.block_by_id(0); - assert_eq!(block.properties(), Labels::new(["distance"], &[[0]])); + assert_eq!(block.properties(), Labels::new(["distance"], &[[1]])); assert_eq!(block.components().len(), 1); assert_eq!(block.components()[0], Labels::new(["pair_direction"], &[[0], [1], [2]])); assert_eq!(block.samples(), Labels::new( - ["structure", "pair_id", "first_atom", "second_atom"], + ["structure", "first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c"], // we have two O-H pairs - &[[0, 0, 0, 1], [0, 1, 0, 2]] + &[[0, 0, 1, 0, 0, 0], [0, 0, 2, 0, 0, 0]] )); let array = block.values().to_array(); @@ -552,9 +679,9 @@ mod tests { // H-H block let block = descriptor.block_by_id(1); assert_eq!(block.samples(), Labels::new( - ["structure", "pair_id", "first_atom", "second_atom"], + ["structure", "first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c"], // we have one H-H pair - &[[0, 2, 1, 2]] + &[[0, 1, 2, 0, 0, 0]] )); let array = block.values().to_array(); @@ -583,15 +710,15 @@ mod tests { // O-H block let block = descriptor.block_by_id(0); - assert_eq!(block.properties(), Labels::new(["distance"], &[[0]])); + assert_eq!(block.properties(), Labels::new(["distance"], &[[1]])); assert_eq!(block.components().len(), 1); assert_eq!(block.components()[0], Labels::new(["pair_direction"], &[[0], [1], [2]])); assert_eq!(block.samples(), Labels::new( - ["structure", "pair_id", "first_atom", "second_atom"], + ["structure", "first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c"], // we have two O-H pairs - &[[0, 0, 0, 1], [0, 1, 0, 2]] + &[[0, 0, 1, 0, 0, 0], [0, 0, 2, 0, 0, 0]] )); let array = block.values().to_array(); @@ -603,15 +730,15 @@ mod tests { // H-O block let block = descriptor.block_by_id(1); - assert_eq!(block.properties(), Labels::new(["distance"], &[[0]])); + assert_eq!(block.properties(), Labels::new(["distance"], &[[1]])); assert_eq!(block.components().len(), 1); assert_eq!(block.components()[0], Labels::new(["pair_direction"], &[[0], [1], [2]])); assert_eq!(block.samples(), Labels::new( - ["structure", "pair_id", "first_atom", "second_atom"], + ["structure", "first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c"], // we have two H-O pairs - &[[0, 0, 1, 0], [0, 1, 2, 0]] + &[[0, 1, 0, 0, 0, 0], [0, 2, 0, 0, 0, 0]] )); let array = block.values().to_array(); @@ -624,9 +751,9 @@ mod tests { // H-H block let block = descriptor.block_by_id(2); assert_eq!(block.samples(), Labels::new( - ["structure", "pair_id", "first_atom", "second_atom"], + ["structure", "first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c"], // we have one H-H pair, twice - &[[0, 2, 1, 2], [0, 2, 2, 1]] + &[[0, 1, 2, 0, 0, 0], [0, 2, 1, 0, 0, 0]] )); let array = block.values().to_array(); @@ -640,7 +767,7 @@ mod tests { #[test] fn finite_differences_positions() { // half neighbor list - let calculator = Calculator::from(Box::new(NeighborList{ + let calculator = Calculator::from(Box::new(NeighborList { cutoff: 1.0, full_neighbor_list: false, self_pairs: false, @@ -655,7 +782,7 @@ mod tests { crate::calculators::tests_utils::finite_differences_positions(calculator, &system, options); // full neighbor list - let calculator = Calculator::from(Box::new(NeighborList{ + let calculator = Calculator::from(Box::new(NeighborList { cutoff: 1.0, full_neighbor_list: true, self_pairs: false, @@ -666,8 +793,8 @@ mod tests { #[test] fn compute_partial() { // half neighbor list - let calculator = Calculator::from(Box::new(NeighborList{ - cutoff: 1.0, + let calculator = Calculator::from(Box::new(NeighborList { + cutoff: 3.0, full_neighbor_list: false, self_pairs: false, }) as Box); @@ -680,12 +807,12 @@ mod tests { let properties = Labels::new( ["distance"], - &[[0]], + &[[1]], ); let keys = Labels::new( ["species_first_atom", "species_second_atom"], - &[[-42, 1], [1, -42], [1, 1], [6, 6]] + &[[-42, 1], [1, -42], [1, 1], [1, 6], [6, 1], [6, 6]] ); crate::calculators::tests_utils::compute_partial( @@ -693,8 +820,8 @@ mod tests { ); // full neighbor list - let calculator = Calculator::from(Box::new(NeighborList{ - cutoff: 1.0, + let calculator = Calculator::from(Box::new(NeighborList { + cutoff: 3.0, full_neighbor_list: true, self_pairs: false, }) as Box); @@ -705,7 +832,7 @@ mod tests { #[test] fn check_self_pairs() { - let mut calculator = Calculator::from(Box::new(NeighborList{ + let mut calculator = Calculator::from(Box::new(NeighborList { cutoff: 2.0, full_neighbor_list: true, self_pairs: true, @@ -724,9 +851,14 @@ mod tests { let block = descriptor.block_by_id(3); let block = block.data(); assert_eq!(*block.samples, Labels::new( - ["structure", "pair_id", "first_atom", "second_atom"], - // we have one H-H pair and two self-pairs - &[[0, 2, 1, 2], [0, 2, 2, 1], [0, -1, 1, 1], [0, -1, 2, 2]] + ["structure", "first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c"], + &[ + // we have one H-H pair and two self-pairs + [0, 1, 2, 0, 0, 0], + [0, 2, 1, 0, 0, 0], + [0, 1, 1, 0, 0, 0], + [0, 2, 2, 0, 0, 0], + ] )); } diff --git a/rascaline/src/calculators/soap/spherical_expansion_pair.rs b/rascaline/src/calculators/soap/spherical_expansion_pair.rs index 33420a93f..897ecd7cf 100644 --- a/rascaline/src/calculators/soap/spherical_expansion_pair.rs +++ b/rascaline/src/calculators/soap/spherical_expansion_pair.rs @@ -281,12 +281,13 @@ impl SphericalExpansionByPair { // loop over all samples in this block, find self pairs // (`pair_id` is -1), and fill the data using `self_contribution` - for (sample_i, &[structure, pair_id, atom_1, atom_2]) in data.samples.iter_fixed_size().enumerate() { + for (sample_i, &[structure, atom_1, atom_2, cell_a, cell_b, cell_c]) in data.samples.iter_fixed_size().enumerate() { // it is possible that the samples from values.samples are not // part of the systems (the user requested extra samples). In // that case, we need to skip anything that does not exist, or // with a different species center - if structure.usize() >= systems.len() || pair_id != -1 { + let is_self_pair = atom_1 == atom_2 && cell_a == 0 && cell_b == 0 && cell_c == 0; + if structure.usize() >= systems.len() || !is_self_pair { continue; } @@ -452,7 +453,7 @@ impl SphericalExpansionByPair { // gradient of the pair contribution w.r.t. the position of // the first atom let first_grad_sample_i = gradient.samples.position(&[ - sample_i.into(), /* structure */ sample[0], /* pair.first */ sample[2] + sample_i.into(), /* structure */ sample[0], /* pair.first */ sample[1] ]).expect("missing first gradient sample"); for spatial in 0..3 { @@ -469,7 +470,7 @@ impl SphericalExpansionByPair { // gradient of the pair contribution w.r.t. the position of // the second atom let second_grad_sample_i = gradient.samples.position(&[ - sample_i.into(), /* structure */ sample[0], /* pair.second */ sample[3] + sample_i.into(), /* structure */ sample[0], /* pair.second */ sample[2] ]).expect("missing second gradient sample"); for spatial in 0..3 { @@ -527,22 +528,12 @@ impl CalculatorBase for SphericalExpansionByPair { } fn keys(&self, systems: &mut [Box]) -> Result { - // the species part of the keys is the same for all l - let species_keys = FullNeighborList { cutoff: self.parameters.cutoff, self_pairs: false }.keys(systems)?; - let mut all_species_pairs = species_keys.iter().map(|p| (p[0], p[1])).collect::>(); - - // also include self-pairs in case they are missing from species_keys - let mut all_species = BTreeSet::new(); - for system in systems { - let species = system.species()?; - for &species in species { - all_species.insert(species); - } - } - - for species in all_species { - all_species_pairs.insert((species.into(), species.into())); - } + // the species part of the keys is the same for all l, and the same as + // what a FullNeighborList with `self_pairs=True` produces. + let full_neighbors_list_keys = FullNeighborList { + cutoff: self.parameters.cutoff, + self_pairs: true, + }.keys(systems)?; let mut keys = LabelsBuilder::new(vec![ "spherical_harmonics_l", @@ -550,7 +541,7 @@ impl CalculatorBase for SphericalExpansionByPair { "species_atom_2" ]); - for (s1, s2) in all_species_pairs { + for &[s1, s2] in full_neighbors_list_keys.iter_fixed_size() { for l in 0..=self.parameters.max_angular { keys.add(&[l.into(), s1, s2]); } @@ -561,70 +552,58 @@ impl CalculatorBase for SphericalExpansionByPair { } fn samples_names(&self) -> Vec<&str> { - return vec!["structure", "pair_id", "first_atom", "second_atom"]; + return vec!["structure", "first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c"]; } fn samples(&self, keys: &Labels, systems: &mut [Box]) -> Result, Error> { - let mut result = Vec::new(); - - // we only need to compute samples once for each l, the cache stores the - // ones we have already computed - let mut cache: BTreeMap<_, Labels> = BTreeMap::new(); + // get all species pairs in keys as a new set of Labels + let mut species_keys = BTreeSet::new(); for &[_, s1, s2] in keys.iter_fixed_size() { - let samples = match cache.entry((s1, s2)) { - Entry::Occupied(entry) => entry.get().clone(), - Entry::Vacant(entry) => { - let mut builder = LabelsBuilder::new(self.samples_names()); - - for (system_i, system) in systems.iter_mut().enumerate() { - system.compute_neighbors(self.parameters.cutoff)?; - let species = system.species()?; - - if s1 == s2 { - // same species for the two atoms, we need to insert - // the self pairs - for center_i in 0..system.size()? { - if species[center_i] == s1.i32() { - builder.add(&[ - system_i.into(), - // set pair_id as -1 for self pairs - LabelValue::new(-1), - center_i.into(), - center_i.into(), - ]); - } - } + species_keys.insert((s1, s2)); + } + let mut builder = LabelsBuilder::new(vec!["species_atom_1", "species_atom_2"]); + for (s1, s2) in species_keys { + builder.add(&[s1, s2]); + } + let species_keys = builder.finish(); + + // for l=0, we want to include self pairs in the samples + let mut samples_by_species_l0: BTreeMap<_, Labels> = BTreeMap::new(); + let full_neighbors_list_samples = FullNeighborList { + cutoff: self.parameters.cutoff, + self_pairs: true, + }.samples(&species_keys, systems)?; + + debug_assert_eq!(species_keys.count(), full_neighbors_list_samples.len()); + for (&[s1, s2], samples) in species_keys.iter_fixed_size().zip(full_neighbors_list_samples) { + samples_by_species_l0.insert((s1, s2), samples); + } - for (pair_id, pair) in system.pairs()?.iter().enumerate() { - if species[pair.first] == s1.i32() && species[pair.second] == s2.i32() { - builder.add(&[system_i, pair_id, pair.first, pair.second]); - if pair.first != pair.second { - // do not duplicate actual pairs between - // an atom and itself (these can exist - // when the cutoff is larger than the - // cell in periodic boundary conditions) - builder.add(&[system_i, pair_id, pair.second, pair.first]); - } - } - } - } else { - // different species for the two atoms - for (pair_id, pair) in system.pairs()?.iter().enumerate() { - if species[pair.first] == s1.i32() && species[pair.second] == s2.i32() { - builder.add(&[system_i, pair_id, pair.first, pair.second]); - } else if species[pair.second] == s1.i32() && species[pair.first] == s2.i32() { - builder.add(&[system_i, pair_id, pair.second, pair.first]); - } - } - } - } + // we only need to compute samples once for each l>0, so we compute them + // using FullNeighborList::samples, store them in a (species, species) + // => Labels map and then re-use them from this map as needed. + let mut samples_by_species: BTreeMap<_, Labels> = BTreeMap::new(); + if self.parameters.max_angular > 0 { + let full_neighbors_list_samples = FullNeighborList { + cutoff: self.parameters.cutoff, + self_pairs: false, + }.samples(&species_keys, systems)?; + + debug_assert_eq!(species_keys.count(), full_neighbors_list_samples.len()); + for (&[s1, s2], samples) in species_keys.iter_fixed_size().zip(full_neighbors_list_samples) { + samples_by_species.insert((s1, s2), samples); + } + } - let samples = builder.finish(); - entry.insert(samples).clone() - } + let mut result = Vec::new(); + for &[l, s1, s2] in keys.iter_fixed_size() { + let samples = if l.i32() == 0 { + samples_by_species_l0.get(&(s1, s2)).expect("missing samples for one species pair") + } else { + samples_by_species.get(&(s1, s2)).expect("missing samples for one species pair") }; - result.push(samples); + result.push(samples.clone()); } return Ok(result); @@ -643,12 +622,11 @@ impl CalculatorBase for SphericalExpansionByPair { for block_samples in samples { let mut builder = LabelsBuilder::new(vec!["sample", "structure", "atom"]); - for (sample_i, &[system_i, pair_id, first, second]) in block_samples.iter_fixed_size().enumerate() { + for (sample_i, &[system_i, first, second, cell_a, cell_b, cell_c]) in block_samples.iter_fixed_size().enumerate() { // self pairs do not contribute to gradients - if pair_id == -1 { + if first == second && cell_a == 0 && cell_b == 0 && cell_c == 0 { continue; } - builder.add(&[sample_i.into(), system_i, first]); builder.add(&[sample_i.into(), system_i, second]); } @@ -733,7 +711,7 @@ impl CalculatorBase for SphericalExpansionByPair { Matrix3::zero() }; - for (pair_id, pair) in system.pairs()?.iter().enumerate() { + for pair in system.pairs()? { let direction = pair.vector / pair.distance; self.compute_for_pair(pair.distance, direction, do_gradients, &mut contribution); @@ -743,6 +721,10 @@ impl CalculatorBase for SphericalExpansionByPair { pair.vector[0] * inverse_cell[0][2] + pair.vector[1] * inverse_cell[1][2] + pair.vector[2] * inverse_cell[2][2], ); + let cell_shift_a = pair.cell_shift_indices[0]; + let cell_shift_b = pair.cell_shift_indices[1]; + let cell_shift_c = pair.cell_shift_indices[2]; + let species_first = species[pair.first]; let species_second = species[pair.second]; for spherical_harmonics_l in 0..=self.parameters.max_angular { @@ -754,10 +736,12 @@ impl CalculatorBase for SphericalExpansionByPair { if let Some(block_i) = block_i { let sample = &[ - system_i.into(), - pair_id.into(), - pair.first.into(), - pair.second.into(), + LabelValue::from(system_i), + LabelValue::from(pair.first), + LabelValue::from(pair.second), + LabelValue::from(cell_shift_a), + LabelValue::from(cell_shift_b), + LabelValue::from(cell_shift_c), ]; SphericalExpansionByPair::accumulate_in_block( @@ -789,10 +773,12 @@ impl CalculatorBase for SphericalExpansionByPair { if let Some(block_i) = block_i { let sample = &[ - system_i.into(), - pair_id.into(), - pair.second.into(), - pair.first.into(), + LabelValue::from(system_i), + LabelValue::from(pair.second), + LabelValue::from(pair.first), + LabelValue::from(-cell_shift_a), + LabelValue::from(-cell_shift_b), + LabelValue::from(-cell_shift_c), ]; SphericalExpansionByPair::accumulate_in_block( @@ -943,10 +929,11 @@ mod tests { let block = block.data(); let values = block.values.as_array(); + assert_eq!(spx.samples.names(), ["structure", "center"]); for (spx_sample, expected) in spx.samples.iter().zip(spx_values.axis_iter(Axis(0))) { let mut sum = ndarray::Array::zeros(expected.raw_dim()); - for (sample_i, &[structure, _, center, _]) in block.samples.iter_fixed_size().enumerate() { + for (sample_i, &[structure, center, _, _, _, _]) in block.samples.iter_fixed_size().enumerate() { if spx_sample[0] == structure && spx_sample[1] == center { sum += &values.slice(s![sample_i, .., ..]); } diff --git a/rascaline/src/calculators/tests_utils.rs b/rascaline/src/calculators/tests_utils.rs index b1379da7f..329892e6e 100644 --- a/rascaline/src/calculators/tests_utils.rs +++ b/rascaline/src/calculators/tests_utils.rs @@ -22,7 +22,11 @@ pub fn compute_partial( ) { let full = calculator.compute(systems, Default::default()).unwrap(); - assert!(full.keys().count() < keys.count(), "selected keys should be a superset of the keys"); + assert_eq!( + full.keys().intersection(keys, None, None).unwrap().count(), + full.keys().count(), + "selected keys should be a superset of the keys, a subset will be created manually" + ); check_compute_partial_keys(&mut calculator, &mut *systems, &full, keys); assert!(keys.count() > 3, "selected keys should have more than 3 keys"); @@ -33,7 +37,15 @@ pub fn compute_partial( check_compute_partial_keys(&mut calculator, &mut *systems, &full, &subset_keys.finish()); check_compute_partial_properties(&mut calculator, &mut *systems, &full, properties); + // check we can remove all properties + let empty_properties = Labels::empty(properties.names()); + check_compute_partial_properties(&mut calculator, &mut *systems, &full, &empty_properties); + check_compute_partial_samples(&mut calculator, &mut *systems, &full, samples); + // check we can remove all samples + let empty_samples = Labels::empty(samples.names()); + check_compute_partial_samples(&mut calculator, &mut *systems, &full, &empty_samples); + check_compute_partial_both(&mut calculator, &mut *systems, &full, samples, properties); } diff --git a/rascaline/src/systems/mod.rs b/rascaline/src/systems/mod.rs index 8a60f50cd..3e173fd6b 100644 --- a/rascaline/src/systems/mod.rs +++ b/rascaline/src/systems/mod.rs @@ -19,7 +19,7 @@ pub(crate) mod test_utils; // WARNING: any change to this definition MUST be reflected in rascal_pair_t as // well #[repr(C)] -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq)] pub struct Pair { /// index of the first atom in the pair pub first: usize, @@ -27,9 +27,13 @@ pub struct Pair { pub second: usize, /// distance between the two atoms pub distance: f64, - /// vector from the first atom to the second atom, wrapped inside the unit - /// cell as required + /// vector from the first atom to the second atom, accounting for periodic + /// boundary conditions. This should be `position\[second\] - + /// position\[first\] + H * cell_shift` where `H` is the cell matrix. pub vector: Vector3D, + /// How many cell shift where applied to the `second` atom to create this + /// pair. + pub cell_shift_indices: [i32; 3], } /// A `System` deals with the storage of atoms and related information, as well diff --git a/rascaline/src/systems/neighbors.rs b/rascaline/src/systems/neighbors.rs index c50e81df1..fca5f24b8 100644 --- a/rascaline/src/systems/neighbors.rs +++ b/rascaline/src/systems/neighbors.rs @@ -14,7 +14,7 @@ const MAX_NUMBER_OF_CELLS: f64 = 1e5; /// The cell shift can be used to reconstruct the vector between two points, /// wrapped inside the unit cell. #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] -pub struct CellShift([isize; 3]); +pub struct CellShift([i32; 3]); impl std::ops::Add for CellShift { type Output = CellShift; @@ -39,7 +39,7 @@ impl std::ops::Sub for CellShift { } impl std::ops::Index for CellShift { - type Output = isize; + type Output = i32; fn index(&self, index: usize) -> &Self::Output { &self.0[index] @@ -88,7 +88,7 @@ pub struct AtomData { pub struct CellList { /// How many cells do we need to look at when searching neighbors to include /// all neighbors below cutoff - n_search: [isize; 3], + n_search: [i32; 3], /// the cells themselves cells: ndarray::Array3>, /// Unit cell defining periodic boundary conditions @@ -132,9 +132,9 @@ impl CellList { // number of cells to search in each direction to make sure all possible // pairs below the cutoff are accounted for. let mut n_search = [ - f64::trunc(cutoff * n_cells[0] / distances_between_faces[0]) as isize, - f64::trunc(cutoff * n_cells[1] / distances_between_faces[1]) as isize, - f64::trunc(cutoff * n_cells[2] / distances_between_faces[2]) as isize, + f64::trunc(cutoff * n_cells[0] / distances_between_faces[0]) as i32, + f64::trunc(cutoff * n_cells[1] / distances_between_faces[1]) as i32, + f64::trunc(cutoff * n_cells[2] / distances_between_faces[2]) as i32, ]; let n_cells = [ @@ -176,9 +176,9 @@ impl CellList { // find the cell in which this atom should go let cell_index = [ - f64::floor(fractional[0] * n_cells[0] as f64) as isize, - f64::floor(fractional[1] * n_cells[1] as f64) as isize, - f64::floor(fractional[2] * n_cells[2] as f64) as isize, + f64::floor(fractional[0] * n_cells[0] as f64) as i32, + f64::floor(fractional[1] * n_cells[1] as f64) as i32, + f64::floor(fractional[2] * n_cells[2] as f64) as i32, ]; // deal with pbc by wrapping the atom inside if it was outside of the @@ -231,9 +231,9 @@ impl CellList { for delta_y in search_y.clone() { for delta_z in search_z.clone() { let cell_i = [ - cell_i_x as isize + delta_x, - cell_i_y as isize + delta_y, - cell_i_z as isize + delta_z, + cell_i_x as i32 + delta_x, + cell_i_y as i32 + delta_y, + cell_i_z as i32 + delta_z, ]; // shift vector from one cell to the other and index of @@ -284,8 +284,9 @@ impl CellList { /// Function to compute both quotient and remainder of the division of a by b. /// This function follows Python convention, making sure the remainder have the /// same sign as `b`. -fn divmod(a: isize, b: usize) -> (isize, usize) { - let b = b as isize; +fn divmod(a: i32, b: usize) -> (i32, usize) { + debug_assert!(b < (i32::MAX as usize)); + let b = b as i32; let mut quotient = a / b; let mut remainder = a % b; if remainder < 0 { @@ -296,7 +297,7 @@ fn divmod(a: isize, b: usize) -> (isize, usize) { } /// Apply the [`divmod`] function to three components at the time -fn divmod_vec(a: [isize; 3], b: [usize; 3]) -> ([isize; 3], [usize; 3]) { +fn divmod_vec(a: [i32; 3], b: [usize; 3]) -> ([i32; 3], [usize; 3]) { let (qx, rx) = divmod(a[0], b[0]); let (qy, ry) = divmod(a[1], b[1]); let (qz, rz) = divmod(a[2], b[2]); @@ -349,6 +350,7 @@ impl NeighborsList { second: pair.second, distance: distance2.sqrt(), vector: vector, + cell_shift_indices: pair.shift.0 }; pairs.push(pair); @@ -424,26 +426,27 @@ mod tests { let neighbors = NeighborsList::new(&positions, cell, 3.0); let expected = [ - Vector3D::new(0.0, -1.0, -1.0), - Vector3D::new(1.0, 0.0, -1.0), - Vector3D::new(1.0, -1.0, 0.0), - Vector3D::new(-1.0, 0.0, -1.0), - Vector3D::new(0.0, 1.0, -1.0), - Vector3D::new(-1.0, -1.0, 0.0), - Vector3D::new(1.0, 1.0, 0.0), - Vector3D::new(0.0, -1.0, 1.0), - Vector3D::new(1.0, 0.0, 1.0), - Vector3D::new(-1.0, 1.0, 0.0), - Vector3D::new(-1.0, 0.0, 1.0), - Vector3D::new(0.0, 1.0, 1.0), + (Vector3D::new(0.0, -1.0, -1.0), [-1, 0, 0]), + (Vector3D::new(1.0, 0.0, -1.0), [-1, 0, 1]), + (Vector3D::new(1.0, -1.0, 0.0), [-1, 1, 0]), + (Vector3D::new(-1.0, 0.0, -1.0), [0, -1, 0]), + (Vector3D::new(0.0, 1.0, -1.0), [0, -1, 1]), + (Vector3D::new(-1.0, -1.0, 0.0), [0, 0, -1]), + (Vector3D::new(1.0, 1.0, 0.0), [0, 0, 1]), + (Vector3D::new(0.0, -1.0, 1.0), [0, 1, -1]), + (Vector3D::new(1.0, 0.0, 1.0), [0, 1, 0]), + (Vector3D::new(-1.0, 1.0, 0.0), [1, -1, 0]), + (Vector3D::new(-1.0, 0.0, 1.0), [1, 0, -1]), + (Vector3D::new(0.0, 1.0, 1.0), [1, 0, 0]), ]; assert_eq!(neighbors.pairs.len(), 12); - for (pair, vector) in neighbors.pairs.iter().zip(&expected) { + for (pair, (vector, shifts)) in neighbors.pairs.iter().zip(&expected) { assert_eq!(pair.first, 0); assert_eq!(pair.second, 0); assert_ulps_eq!(pair.distance, 2.1213203435596424); assert_ulps_eq!(pair.vector / 1.5, vector); + assert_eq!(&pair.cell_shift_indices, shifts); } } @@ -473,6 +476,7 @@ mod tests { for (pair, expected) in neighbors.pairs.iter().zip(&expected) { assert_eq!(pair.first, expected.0); assert_eq!(pair.second, expected.1); + assert_eq!(pair.cell_shift_indices, [0, 0, 0]); assert_ulps_eq!(pair.distance, 2.0); } } diff --git a/rascaline/src/types/matrix.rs b/rascaline/src/types/matrix.rs index e5ec14c7b..066a0f747 100644 --- a/rascaline/src/types/matrix.rs +++ b/rascaline/src/types/matrix.rs @@ -497,74 +497,74 @@ impl From<[[f64; 3]; 3]> for Matrix3 { } } -#[cfg(test)] -mod tests { - use super::*; - use crate::Vector3D; - - use approx::{AbsDiffEq, RelativeEq, UlpsEq, assert_ulps_eq}; - - impl AbsDiffEq for Matrix3 { - type Epsilon = ::Epsilon; +impl approx::AbsDiffEq for Matrix3 { + type Epsilon = ::Epsilon; - fn default_epsilon() -> Self::Epsilon { - f64::default_epsilon() - } + fn default_epsilon() -> Self::Epsilon { + f64::default_epsilon() + } - fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool { - f64::abs_diff_eq(&self[0][0], &other[0][0], epsilon) && - f64::abs_diff_eq(&self[0][1], &other[0][1], epsilon) && - f64::abs_diff_eq(&self[0][2], &other[0][2], epsilon) && + fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool { + f64::abs_diff_eq(&self[0][0], &other[0][0], epsilon) && + f64::abs_diff_eq(&self[0][1], &other[0][1], epsilon) && + f64::abs_diff_eq(&self[0][2], &other[0][2], epsilon) && - f64::abs_diff_eq(&self[1][0], &other[1][0], epsilon) && - f64::abs_diff_eq(&self[1][1], &other[1][1], epsilon) && - f64::abs_diff_eq(&self[1][2], &other[1][2], epsilon) && + f64::abs_diff_eq(&self[1][0], &other[1][0], epsilon) && + f64::abs_diff_eq(&self[1][1], &other[1][1], epsilon) && + f64::abs_diff_eq(&self[1][2], &other[1][2], epsilon) && - f64::abs_diff_eq(&self[2][0], &other[2][0], epsilon) && - f64::abs_diff_eq(&self[2][1], &other[2][1], epsilon) && - f64::abs_diff_eq(&self[2][2], &other[2][2], epsilon) - } + f64::abs_diff_eq(&self[2][0], &other[2][0], epsilon) && + f64::abs_diff_eq(&self[2][1], &other[2][1], epsilon) && + f64::abs_diff_eq(&self[2][2], &other[2][2], epsilon) } +} - impl RelativeEq for Matrix3 { - fn default_max_relative() -> Self::Epsilon { - f64::default_max_relative() - } +impl approx::RelativeEq for Matrix3 { + fn default_max_relative() -> Self::Epsilon { + f64::default_max_relative() + } - fn relative_eq(&self, other: &Self, epsilon: Self::Epsilon, max_relative: Self::Epsilon) -> bool { - f64::relative_eq(&self[0][0], &other[0][0], epsilon, max_relative) && - f64::relative_eq(&self[0][1], &other[0][1], epsilon, max_relative) && - f64::relative_eq(&self[0][2], &other[0][2], epsilon, max_relative) && + fn relative_eq(&self, other: &Self, epsilon: Self::Epsilon, max_relative: Self::Epsilon) -> bool { + f64::relative_eq(&self[0][0], &other[0][0], epsilon, max_relative) && + f64::relative_eq(&self[0][1], &other[0][1], epsilon, max_relative) && + f64::relative_eq(&self[0][2], &other[0][2], epsilon, max_relative) && - f64::relative_eq(&self[1][0], &other[1][0], epsilon, max_relative) && - f64::relative_eq(&self[1][1], &other[1][1], epsilon, max_relative) && - f64::relative_eq(&self[1][2], &other[1][2], epsilon, max_relative) && + f64::relative_eq(&self[1][0], &other[1][0], epsilon, max_relative) && + f64::relative_eq(&self[1][1], &other[1][1], epsilon, max_relative) && + f64::relative_eq(&self[1][2], &other[1][2], epsilon, max_relative) && - f64::relative_eq(&self[2][0], &other[2][0], epsilon, max_relative) && - f64::relative_eq(&self[2][1], &other[2][1], epsilon, max_relative) && - f64::relative_eq(&self[2][2], &other[2][2], epsilon, max_relative) - } + f64::relative_eq(&self[2][0], &other[2][0], epsilon, max_relative) && + f64::relative_eq(&self[2][1], &other[2][1], epsilon, max_relative) && + f64::relative_eq(&self[2][2], &other[2][2], epsilon, max_relative) } +} - impl UlpsEq for Matrix3 { - fn default_max_ulps() -> u32 { - f64::default_max_ulps() - } +impl approx::UlpsEq for Matrix3 { + fn default_max_ulps() -> u32 { + f64::default_max_ulps() + } - fn ulps_eq(&self, other: &Self, epsilon: Self::Epsilon, max_ulps: u32) -> bool { - f64::ulps_eq(&self[0][0], &other[0][0], epsilon, max_ulps) && - f64::ulps_eq(&self[0][1], &other[0][1], epsilon, max_ulps) && - f64::ulps_eq(&self[0][2], &other[0][2], epsilon, max_ulps) && + fn ulps_eq(&self, other: &Self, epsilon: Self::Epsilon, max_ulps: u32) -> bool { + f64::ulps_eq(&self[0][0], &other[0][0], epsilon, max_ulps) && + f64::ulps_eq(&self[0][1], &other[0][1], epsilon, max_ulps) && + f64::ulps_eq(&self[0][2], &other[0][2], epsilon, max_ulps) && - f64::ulps_eq(&self[1][0], &other[1][0], epsilon, max_ulps) && - f64::ulps_eq(&self[1][1], &other[1][1], epsilon, max_ulps) && - f64::ulps_eq(&self[1][2], &other[1][2], epsilon, max_ulps) && + f64::ulps_eq(&self[1][0], &other[1][0], epsilon, max_ulps) && + f64::ulps_eq(&self[1][1], &other[1][1], epsilon, max_ulps) && + f64::ulps_eq(&self[1][2], &other[1][2], epsilon, max_ulps) && - f64::ulps_eq(&self[2][0], &other[2][0], epsilon, max_ulps) && - f64::ulps_eq(&self[2][1], &other[2][1], epsilon, max_ulps) && - f64::ulps_eq(&self[2][2], &other[2][2], epsilon, max_ulps) - } + f64::ulps_eq(&self[2][0], &other[2][0], epsilon, max_ulps) && + f64::ulps_eq(&self[2][1], &other[2][1], epsilon, max_ulps) && + f64::ulps_eq(&self[2][2], &other[2][2], epsilon, max_ulps) } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::Vector3D; + + use approx::assert_ulps_eq; #[test] fn specials_matrix() { diff --git a/rascaline/src/types/vectors.rs b/rascaline/src/types/vectors.rs index 04eaa5acb..a470d431b 100644 --- a/rascaline/src/types/vectors.rs +++ b/rascaline/src/types/vectors.rs @@ -329,51 +329,49 @@ impl Default for Vector3D { } } -#[cfg(test)] -#[allow(clippy::op_ref, clippy::let_underscore_untyped)] -mod tests { - use crate::{Matrix3, Vector3D}; - use std::f64; +impl approx::AbsDiffEq for Vector3D { + type Epsilon = ::Epsilon; - use approx::{AbsDiffEq, RelativeEq, UlpsEq}; - - impl AbsDiffEq for Vector3D { - type Epsilon = ::Epsilon; - - fn default_epsilon() -> Self::Epsilon { - f64::default_epsilon() - } + fn default_epsilon() -> Self::Epsilon { + f64::default_epsilon() + } - fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool { - f64::abs_diff_eq(&self[0], &other[0], epsilon) && - f64::abs_diff_eq(&self[1], &other[1], epsilon) && - f64::abs_diff_eq(&self[2], &other[2], epsilon) - } + fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool { + f64::abs_diff_eq(&self[0], &other[0], epsilon) && + f64::abs_diff_eq(&self[1], &other[1], epsilon) && + f64::abs_diff_eq(&self[2], &other[2], epsilon) } +} - impl RelativeEq for Vector3D { - fn default_max_relative() -> Self::Epsilon { - f64::default_max_relative() - } +impl approx::RelativeEq for Vector3D { + fn default_max_relative() -> Self::Epsilon { + f64::default_max_relative() + } - fn relative_eq(&self, other: &Self, epsilon: Self::Epsilon, max_relative: Self::Epsilon) -> bool { - f64::relative_eq(&self[0], &other[0], epsilon, max_relative) && - f64::relative_eq(&self[1], &other[1], epsilon, max_relative) && - f64::relative_eq(&self[2], &other[2], epsilon, max_relative) - } + fn relative_eq(&self, other: &Self, epsilon: Self::Epsilon, max_relative: Self::Epsilon) -> bool { + f64::relative_eq(&self[0], &other[0], epsilon, max_relative) && + f64::relative_eq(&self[1], &other[1], epsilon, max_relative) && + f64::relative_eq(&self[2], &other[2], epsilon, max_relative) } +} - impl UlpsEq for Vector3D { - fn default_max_ulps() -> u32 { - f64::default_max_ulps() - } +impl approx::UlpsEq for Vector3D { + fn default_max_ulps() -> u32 { + f64::default_max_ulps() + } - fn ulps_eq(&self, other: &Self, epsilon: Self::Epsilon, max_ulps: u32) -> bool { - f64::ulps_eq(&self[0], &other[0], epsilon, max_ulps) && - f64::ulps_eq(&self[1], &other[1], epsilon, max_ulps) && - f64::ulps_eq(&self[2], &other[2], epsilon, max_ulps) - } + fn ulps_eq(&self, other: &Self, epsilon: Self::Epsilon, max_ulps: u32) -> bool { + f64::ulps_eq(&self[0], &other[0], epsilon, max_ulps) && + f64::ulps_eq(&self[1], &other[1], epsilon, max_ulps) && + f64::ulps_eq(&self[2], &other[2], epsilon, max_ulps) } +} + +#[cfg(test)] +#[allow(clippy::op_ref, clippy::let_underscore_untyped)] +mod tests { + use crate::{Matrix3, Vector3D}; + use std::f64; #[test] fn add() {