Skip to content

Commit

Permalink
Remove more static specifications of the dimension
Browse files Browse the repository at this point in the history
  • Loading branch information
cschwan committed Sep 17, 2024
1 parent c437e0c commit 315e285
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 10 deletions.
15 changes: 6 additions & 9 deletions pineappl/src/interpolation.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
//! Interpolation module.

use super::convert;
use super::packed_array::PackedArray;
use arrayvec::ArrayVec;
use serde::{Deserialize, Serialize};
use std::mem;
use std::ops::IndexMut;

const MAX_INTERP_ORDER_PLUS_ONE: usize = 8;

Expand Down Expand Up @@ -252,9 +252,8 @@ pub fn interpolate<const D: usize>(
interps: &[Interp],
ntuple: &[f64],
weight: f64,
array: &mut impl IndexMut<[usize; D], Output = f64>,
array: &mut PackedArray<f64>,
) -> bool {
use super::packed_array;
use itertools::Itertools;

if weight == 0.0 {
Expand Down Expand Up @@ -291,17 +290,15 @@ pub fn interpolate<const D: usize>(

let shape: ArrayVec<_, D> = interps.iter().map(|interp| interp.order() + 1).collect();

for (i, node_weights) in node_weights
for (i, node_weight) in node_weights
.into_iter()
// TODO: replace this with something else to avoid allocating memory
.multi_cartesian_product()
.map(|weights| weights.iter().product::<f64>())
.enumerate()
{
let mut index = packed_array::unravel_index::<D>(i, &shape);
for (entry, start_index) in index.iter_mut().zip(&indices) {
*entry += start_index;
}
array[index] += weight * node_weights.iter().product::<f64>();
let idx = array.sub_block_idx(&indices, i, &shape);
array[idx] += weight * node_weight;
}

true
Expand Down
146 changes: 145 additions & 1 deletion pineappl/src/packed_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,33 @@ impl<T: Copy + Default + PartialEq> PackedArray<T> {
.filter(|&(_, entry)| *entry != Default::default())
.map(|(indices, entry)| (indices, *entry))
}

/// TODO
// TODO: rewrite this method into `sub_block_iter_mut() -> impl Iterator<Item = &mut f64>`
pub fn sub_block_idx(
&self,
start_index: &[usize],
mut i: usize,
fill_shape: &[usize],
) -> usize {
use super::packed_array;

assert_eq!(start_index.len(), fill_shape.len());

let mut index = {
assert!(i < fill_shape.iter().product());
let mut indices = vec![0; start_index.len()];
for (j, d) in indices.iter_mut().zip(fill_shape).rev() {
*j = i % d;
i /= d;
}
indices
};
for (entry, start_index) in index.iter_mut().zip(start_index) {
*entry += start_index;
}
packed_array::ravel_multi_index(&index, &self.shape)
}
}

impl<T: Copy + MulAssign<T>> MulAssign<T> for PackedArray<T> {
Expand Down Expand Up @@ -132,7 +159,7 @@ impl<T: Copy + Default + PartialEq> PackedArray<T> {
}

/// Converts a `multi_index` into a flat index.
fn ravel_multi_index(multi_index: &[usize], shape: &[usize]) -> usize {
pub fn ravel_multi_index(multi_index: &[usize], shape: &[usize]) -> usize {
assert_eq!(multi_index.len(), shape.len());

multi_index
Expand Down Expand Up @@ -236,6 +263,123 @@ impl<T: Copy + Default + PartialEq> Index<usize> for PackedArray<T> {
}
}

impl<T: Clone + Copy + Default + PartialEq> IndexMut<usize> for PackedArray<T> {
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
// assert_eq!(index.len(), self.shape.len());

// // Panic if the index value for any dimension is greater or equal than the length of this
// // dimension.
// assert!(
// index.iter().zip(self.shape.iter()).all(|(&i, &d)| i < d),
// "index {:?} is out of bounds for array of shape {:?}",
// index,
// self.shape
// );

// // The insertion cases are:
// // 1. this array already stores an element at `index`:
// // -> we just have to update this element
// // 2. this array does not store an element at `index`:
// // a. the distance of the (raveled) `index` is `threshold_distance` away from the next
// // or previous element that is already stored:
// // -> we can merge the new element into already stored groups, potentially padding
// // with `T::default()` elements
// // b. the distance of the (raveled) `index` from the existing elements is greater than
// // `threshold_distance`:
// // -> we insert the element as a new group

// let raveled_index = ravel_multi_index(&index, &self.shape);
let raveled_index = index;

// To determine which groups the new element is close to, `point` is the index of the
// start_index of the first group after the new element. `point` is 0 if no elements before
// the new element are stored, and point is `self.start_indices.len()` if no elements after
// the new element are stored.
let point = self.start_indices.partition_point(|&i| i <= raveled_index);

// `point_entries` is the index of the first element of the next group, given in
// `self.entries`, i.e. the element at index `self.start_indices[point]`.
let point_entries = self.lengths.iter().take(point).sum::<usize>();

// Maximum distance for merging groups. If the new element is within `threshold_distance`
// of an existing group (i.e. there are `threshold_distance - 1` implicit elements
// between them), we merge the new element into the existing group. We choose 2 as the
// `threshold_distance` based on memory: in the case of `T` = `f64`, it is more economical
// to store one zero explicitly than to store the start_index and length of a new group.
let threshold_distance = 2;

// If `point > 0`, there is at least one group preceding the new element. Thus, in the
// following we determine if we can insert the new element into this group.
if point > 0 {
// start_index and length of the group before the new element, i.e. the group
// (potentially) getting the new element
let start_index = self.start_indices[point - 1];
let length = self.lengths[point - 1];

// Case 1: an element is already stored at this `index`
if raveled_index < start_index + length {
return &mut self.entries[point_entries - length + raveled_index - start_index];
// Case 2a: the new element can be merged into the preceding group
} else if raveled_index < start_index + length + threshold_distance {
let distance = raveled_index - (start_index + length) + 1;
// Merging happens by increasing the length of the group
self.lengths[point - 1] += distance;
// and inserting the necessary number of default elements.
self.entries.splice(
point_entries..point_entries,
iter::repeat(Default::default()).take(distance),
);

// If the new element is within `threshold_distance` of the *next* group, we merge
// the next group into this group.
if let Some(start_index_next) = self.start_indices.get(point) {
if raveled_index + threshold_distance >= *start_index_next {
let distance_next = start_index_next - raveled_index;

// Increase the length of this group
self.lengths[point - 1] += distance_next - 1 + self.lengths[point];
// and remove the next group. we don't have to manipulate `self.entries`,
// since the grouping of the elements is handled only by
// `self.start_indices` and `self.lengths`
self.lengths.remove(point);
self.start_indices.remove(point);
// Insert the default elements between the groups.
self.entries.splice(
point_entries..point_entries,
iter::repeat(Default::default()).take(distance_next - 1),
);
}
}

return &mut self.entries[point_entries - 1 + distance];
}
}

// Case 2a: the new element can be merged into the next group. No `self.lengths.remove` and
// `self.start_indices.remove` here, since we are not merging two groups.
if let Some(start_index_next) = self.start_indices.get(point) {
if raveled_index + threshold_distance >= *start_index_next {
let distance = start_index_next - raveled_index;

self.start_indices[point] = raveled_index;
self.lengths[point] += distance;
self.entries.splice(
point_entries..point_entries,
iter::repeat(Default::default()).take(distance),
);
return &mut self.entries[point_entries];
}
}

// Case 2b: we insert a new group of length 1
self.start_indices.insert(point, raveled_index);
self.lengths.insert(point, 1);
self.entries.insert(point_entries, Default::default());

&mut self.entries[point_entries]
}
}

impl<T: Clone + Copy + Default + PartialEq> IndexMut<&[usize]> for PackedArray<T> {
fn index_mut(&mut self, index: &[usize]) -> &mut Self::Output {
assert_eq!(index.len(), self.shape.len());
Expand Down

0 comments on commit 315e285

Please sign in to comment.