Skip to content

Commit

Permalink
Make sure all calculator support full removal of samples/properties w…
Browse files Browse the repository at this point in the history
…ith selection
  • Loading branch information
Luthaf committed Sep 19, 2023
1 parent 6bbcf0f commit 6970341
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 32 deletions.
89 changes: 57 additions & 32 deletions rascaline/src/calculators/neighbor_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ impl CalculatorBase for NeighborList {

fn properties(&self, keys: &Labels) -> Vec<Labels> {
let mut properties = LabelsBuilder::new(self.properties_names());
properties.add(&[LabelValue::new(0)]);
properties.add(&[LabelValue::new(1)]);
let properties = properties.finish();

return vec![properties; keys.count()];
Expand Down Expand Up @@ -321,9 +321,13 @@ impl HalfNeighborList {

if let Some(sample_i) = sample_i {
let array = block_data.values.to_array_mut();
array[[sample_i, 0, 0]] = pair_vector[0];
array[[sample_i, 1, 0]] = pair_vector[1];
array[[sample_i, 2, 0]] = pair_vector[2];
for (property_i, &[distance]) in block_data.properties.iter_fixed_size().enumerate() {
if distance == 1 {
array[[sample_i, 0, property_i]] = pair_vector[0];
array[[sample_i, 1, property_i]] = pair_vector[1];
array[[sample_i, 2, property_i]] = pair_vector[2];
}
}

if let Some(mut gradient) = block.gradient_mut("positions") {
let gradient = gradient.data_mut();
Expand All @@ -337,13 +341,17 @@ impl HalfNeighborList {

let array = gradient.values.to_array_mut();

array[[first_grad_sample_i, 0, 0, 0]] = -1.0;
array[[first_grad_sample_i, 1, 1, 0]] = -1.0;
array[[first_grad_sample_i, 2, 2, 0]] = -1.0;
for (property_i, &[distance]) in gradient.properties.iter_fixed_size().enumerate() {
if distance == 1 {
array[[first_grad_sample_i, 0, 0, property_i]] = -1.0;
array[[first_grad_sample_i, 1, 1, property_i]] = -1.0;
array[[first_grad_sample_i, 2, 2, property_i]] = -1.0;

array[[second_grad_sample_i, 0, 0, 0]] = 1.0;
array[[second_grad_sample_i, 1, 1, 0]] = 1.0;
array[[second_grad_sample_i, 2, 2, 0]] = 1.0;
array[[second_grad_sample_i, 0, 0, property_i]] = 1.0;
array[[second_grad_sample_i, 1, 1, property_i]] = 1.0;
array[[second_grad_sample_i, 2, 2, property_i]] = 1.0;
}
}
}
}
}
Expand Down Expand Up @@ -521,9 +529,14 @@ impl FullNeighborList {

if let Some(sample_i) = sample_i {
let array = block_data.values.to_array_mut();
array[[sample_i, 0, 0]] = pair.vector[0];
array[[sample_i, 1, 0]] = pair.vector[1];
array[[sample_i, 2, 0]] = pair.vector[2];

for (property_i, &[distance]) in block_data.properties.iter_fixed_size().enumerate() {
if distance == 1 {
array[[sample_i, 0, property_i]] = pair.vector[0];
array[[sample_i, 1, property_i]] = pair.vector[1];
array[[sample_i, 2, property_i]] = pair.vector[2];
}
}

if let Some(mut gradient) = block.gradient_mut("positions") {
let gradient = gradient.data_mut();
Expand All @@ -537,13 +550,17 @@ impl FullNeighborList {

let array = gradient.values.to_array_mut();

array[[first_grad_sample_i, 0, 0, 0]] = -1.0;
array[[first_grad_sample_i, 1, 1, 0]] = -1.0;
array[[first_grad_sample_i, 2, 2, 0]] = -1.0;
for (property_i, &[distance]) in gradient.properties.iter_fixed_size().enumerate() {
if distance == 1 {
array[[first_grad_sample_i, 0, 0, property_i]] = -1.0;
array[[first_grad_sample_i, 1, 1, property_i]] = -1.0;
array[[first_grad_sample_i, 2, 2, property_i]] = -1.0;

array[[second_grad_sample_i, 0, 0, 0]] = 1.0;
array[[second_grad_sample_i, 1, 1, 0]] = 1.0;
array[[second_grad_sample_i, 2, 2, 0]] = 1.0;
array[[second_grad_sample_i, 0, 0, property_i]] = 1.0;
array[[second_grad_sample_i, 1, 1, property_i]] = 1.0;
array[[second_grad_sample_i, 2, 2, property_i]] = 1.0;
}
}
}
}
}
Expand All @@ -569,9 +586,13 @@ impl FullNeighborList {

if let Some(sample_i) = sample_i {
let array = block_data.values.to_array_mut();
array[[sample_i, 0, 0]] = -pair.vector[0];
array[[sample_i, 1, 0]] = -pair.vector[1];
array[[sample_i, 2, 0]] = -pair.vector[2];
for (property_i, &[distance]) in block_data.properties.iter_fixed_size().enumerate() {
if distance == 1 {
array[[sample_i, 0, property_i]] = -pair.vector[0];
array[[sample_i, 1, property_i]] = -pair.vector[1];
array[[sample_i, 2, property_i]] = -pair.vector[2];
}
}

if let Some(mut gradient) = block.gradient_mut("positions") {
let gradient = gradient.data_mut();
Expand All @@ -585,13 +606,17 @@ impl FullNeighborList {

let array = gradient.values.to_array_mut();

array[[first_grad_sample_i, 0, 0, 0]] = -1.0;
array[[first_grad_sample_i, 1, 1, 0]] = -1.0;
array[[first_grad_sample_i, 2, 2, 0]] = -1.0;
for (property_i, &[distance]) in gradient.properties.iter_fixed_size().enumerate() {
if distance == 1 {
array[[first_grad_sample_i, 0, 0, property_i]] = -1.0;
array[[first_grad_sample_i, 1, 1, property_i]] = -1.0;
array[[first_grad_sample_i, 2, 2, property_i]] = -1.0;

array[[second_grad_sample_i, 0, 0, 0]] = 1.0;
array[[second_grad_sample_i, 1, 1, 0]] = 1.0;
array[[second_grad_sample_i, 2, 2, 0]] = 1.0;
array[[second_grad_sample_i, 0, 0, property_i]] = 1.0;
array[[second_grad_sample_i, 1, 1, property_i]] = 1.0;
array[[second_grad_sample_i, 2, 2, property_i]] = 1.0;
}
}
}
}
}
Expand Down Expand Up @@ -633,7 +658,7 @@ mod tests {

// O-H block
let block = descriptor.block_by_id(0);
assert_eq!(block.properties(), Labels::new(["distance"], &[[0]]));
assert_eq!(block.properties(), Labels::new(["distance"], &[[1]]));

assert_eq!(block.components().len(), 1);
assert_eq!(block.components()[0], Labels::new(["pair_direction"], &[[0], [1], [2]]));
Expand Down Expand Up @@ -685,7 +710,7 @@ mod tests {

// O-H block
let block = descriptor.block_by_id(0);
assert_eq!(block.properties(), Labels::new(["distance"], &[[0]]));
assert_eq!(block.properties(), Labels::new(["distance"], &[[1]]));

assert_eq!(block.components().len(), 1);
assert_eq!(block.components()[0], Labels::new(["pair_direction"], &[[0], [1], [2]]));
Expand All @@ -705,7 +730,7 @@ mod tests {

// H-O block
let block = descriptor.block_by_id(1);
assert_eq!(block.properties(), Labels::new(["distance"], &[[0]]));
assert_eq!(block.properties(), Labels::new(["distance"], &[[1]]));

assert_eq!(block.components().len(), 1);
assert_eq!(block.components()[0], Labels::new(["pair_direction"], &[[0], [1], [2]]));
Expand Down Expand Up @@ -782,7 +807,7 @@ mod tests {

let properties = Labels::new(
["distance"],
&[[0]],
&[[1]],
);

let keys = Labels::new(
Expand Down
8 changes: 8 additions & 0 deletions rascaline/src/calculators/tests_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,15 @@ pub fn compute_partial(
check_compute_partial_keys(&mut calculator, &mut *systems, &full, &subset_keys.finish());

check_compute_partial_properties(&mut calculator, &mut *systems, &full, properties);
// check we can remove all properties
let empty_properties = Labels::empty(properties.names());
check_compute_partial_properties(&mut calculator, &mut *systems, &full, &empty_properties);

check_compute_partial_samples(&mut calculator, &mut *systems, &full, samples);
// check we can remove all samples
let empty_samples = Labels::empty(samples.names());
check_compute_partial_samples(&mut calculator, &mut *systems, &full, &empty_samples);

check_compute_partial_both(&mut calculator, &mut *systems, &full, samples, properties);
}

Expand Down

0 comments on commit 6970341

Please sign in to comment.