Skip to content

Commit

Permalink
Update all calculators to remove neighbors list workaround
Browse files Browse the repository at this point in the history
  • Loading branch information
Luthaf committed Feb 8, 2024
1 parent 8cd5bb6 commit 2845e65
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 34 deletions.
104 changes: 85 additions & 19 deletions rascaline/src/calculators/neighbor_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ pub struct NeighborList {
pub full_neighbor_list: bool,
/// Should individual atoms be considered their own neighbor? Setting this
/// to `true` will add "self pairs", i.e. pairs between an atom and itself,
/// with the distance 0. The `pair_id` of such pairs is set to -1.
/// with the distance 0.
pub self_pairs: bool,
}

Expand Down Expand Up @@ -423,7 +423,8 @@ impl FullNeighborList {
let cell_c = pair.cell_shift_indices[2];

if species_first == species_second {
// same species for both atoms in the pair
// same species for both atoms in the pair, add the pair
// twice in both directions.
if species[pair.first] == species_first.i32() && species[pair.second] == species_second.i32() {
builder.add(&[
LabelValue::from(system_i),
Expand All @@ -434,18 +435,14 @@ impl FullNeighborList {
LabelValue::from(cell_c),
]);

if pair.first != pair.second {
// if the pair is between two different atoms,
// also add the reversed (second -> first) pair.
builder.add(&[
LabelValue::from(system_i),
LabelValue::from(pair.second),
LabelValue::from(pair.first),
LabelValue::from(-cell_a),
LabelValue::from(-cell_b),
LabelValue::from(-cell_c),
]);
}
builder.add(&[
LabelValue::from(system_i),
LabelValue::from(pair.second),
LabelValue::from(pair.first),
LabelValue::from(-cell_a),
LabelValue::from(-cell_b),
LabelValue::from(-cell_c),
]);
}
} else {
// different species, find the right order for the pair
Expand Down Expand Up @@ -501,6 +498,11 @@ impl FullNeighborList {
let species = system.species()?;

for pair in system.pairs()? {
if pair.first == pair.second {
// self pairs should not be part of the neighbors list
assert_ne!(pair.cell_shift_indices, [0, 0, 0]);
}

let first_block_i = descriptor.keys().position(&[
species[pair.first].into(), species[pair.second].into()
]);
Expand Down Expand Up @@ -565,11 +567,6 @@ impl FullNeighborList {
}
}

if pair.first == pair.second {
// do not duplicate self pairs
continue;
}

// then the pair second -> first
if let Some(second_block_i) = second_block_i {
let mut block = descriptor.block_mut_by_id(second_block_i);
Expand Down Expand Up @@ -764,6 +761,75 @@ mod tests {
assert_relative_eq!(array, expected, max_relative=1e-6);
}

#[test]
fn periodic_neighbor_list() {
let mut calculator = Calculator::from(Box::new(NeighborList{
cutoff: 12.0,
full_neighbor_list: false,
self_pairs: false,
}) as Box<dyn CalculatorBase>);

let mut systems = test_systems(&["CH"]);

let descriptor = calculator.compute(&mut systems, Default::default()).unwrap();
assert_eq!(*descriptor.keys(), Labels::new(
["species_first_atom", "species_second_atom"],
&[[1, 1], [1, 6], [6, 6]]
));

// H-H block
let block = descriptor.block_by_id(0);
assert_eq!(block.samples(), Labels::new(
["structure", "first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c"],
// the pairs only differ in cell shifts
&[[0, 1, 1, 0, 0, 1], [0, 1, 1, 0, 1, 0], [0, 1, 1, 1, 0, 0]]
));

let array = block.values().to_array();
let expected = &ndarray::arr3(&[
[[0.0], [0.0], [10.0]],
[[0.0], [10.0], [0.0]],
[[10.0], [0.0], [0.0]],
]).into_dyn();
assert_relative_eq!(array, expected, max_relative=1e-6);

// now a full NL
let mut calculator = Calculator::from(Box::new(NeighborList{
cutoff: 12.0,
full_neighbor_list: true,
self_pairs: false,
}) as Box<dyn CalculatorBase>);

let descriptor = calculator.compute(&mut systems, Default::default()).unwrap();
assert_eq!(*descriptor.keys(), Labels::new(
["species_first_atom", "species_second_atom"],
&[[1, 1], [1, 6], [6, 1], [6, 6]]
));

// H-H block
let block = descriptor.block_by_id(0);
assert_eq!(block.samples(), Labels::new(
["structure", "first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c"],
// twice as many pairs
&[
[0, 1, 1, 0, 0, 1], [0, 1, 1, 0, 0, -1],
[0, 1, 1, 0, 1, 0], [0, 1, 1, 0, -1, 0],
[0, 1, 1, 1, 0, 0], [0, 1, 1, -1, 0, 0],
]
));

let array = block.values().to_array();
let expected = &ndarray::arr3(&[
[[0.0], [0.0], [10.0]],
[[0.0], [0.0], [-10.0]],
[[0.0], [10.0], [0.0]],
[[0.0], [-10.0], [0.0]],
[[10.0], [0.0], [0.0]],
[[-10.0], [0.0], [0.0]],
]).into_dyn();
assert_relative_eq!(array, expected, max_relative=1e-6);
}

#[test]
fn finite_differences_positions() {
// half neighbor list
Expand Down
8 changes: 1 addition & 7 deletions rascaline/src/calculators/soap/spherical_expansion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,12 +241,6 @@ impl SphericalExpansion {
}
}

if pair.first == pair.second {
// do not compute for the reversed pair if the pair is
// between an atom and its image
continue;
}

if let Some(mapped_center) = result.centers_mapping[pair.second] {
// add the pair contribution to the atomic environnement
// corresponding to the **second** atom in the pair
Expand Down Expand Up @@ -778,7 +772,7 @@ mod tests {

fn parameters() -> SphericalExpansionParameters {
SphericalExpansionParameters {
cutoff: 3.5,
cutoff: 7.8,
max_radial: 6,
max_angular: 6,
atomic_gaussian_width: 0.3,
Expand Down
10 changes: 2 additions & 8 deletions rascaline/src/calculators/soap/spherical_expansion_pair.rs
Original file line number Diff line number Diff line change
Expand Up @@ -755,13 +755,7 @@ impl CalculatorBase for SphericalExpansionByPair {
}
}

// also check for the block with a reversed pair, except if
// we are handling a pair between an atom and it's own
// periodic image
if pair.first == pair.second {
continue;
}

// also check for the block with a reversed pair
contribution.inverse_pair(&self.m_1_pow_l);

for spherical_harmonics_l in 0..=self.parameters.max_angular {
Expand Down Expand Up @@ -817,7 +811,7 @@ mod tests {

fn parameters() -> SphericalExpansionParameters {
SphericalExpansionParameters {
cutoff: 3.5,
cutoff: 7.3,
max_radial: 6,
max_angular: 6,
atomic_gaussian_width: 0.3,
Expand Down

0 comments on commit 2845e65

Please sign in to comment.