diff --git a/src/slice/array.rs b/src/slice/array.rs index 6946b93db..be08b149b 100644 --- a/src/slice/array.rs +++ b/src/slice/array.rs @@ -1,7 +1,7 @@ use crate::iter::plumbing::*; use crate::iter::*; -use super::Iter; +use super::{Iter, IterMut}; /// Parallel iterator over immutable non-overlapping chunks of a slice #[derive(Debug)] @@ -78,3 +78,95 @@ impl<'data, T: Sync + 'data, const N: usize> IndexedParallelIterator for ArrayCh self.iter.with_producer(callback) } } + +/// Parallel iterator over mutable non-overlapping chunks of a slice +#[derive(Debug)] +pub struct ArrayChunksMut<'data, T: Send, const N: usize> { + iter: IterMut<'data, [T; N]>, + rem: &'data mut [T], +} + +impl<'data, T: Send, const N: usize> ArrayChunksMut<'data, T, N> { + pub(super) fn new(slice: &'data mut [T]) -> Self { + assert_ne!(N, 0); + let len = slice.len() / N; + let (fst, snd) = slice.split_at_mut(len * N); + // SAFETY: We cast a slice of `len * N` elements into + // a slice of `len` many `N`-element chunks. + let array_slice: &'data mut [[T; N]] = unsafe { + let ptr = fst.as_mut_ptr() as *mut [T; N]; + ::std::slice::from_raw_parts_mut(ptr, len) + }; + Self { + iter: array_slice.par_iter_mut(), + rem: snd, + } + } + + /// Return the remainder of the original slice that is not going to be + /// returned by the iterator. The returned slice has at most `N-1` + /// elements. + /// + /// Note that this has to consume `self` to return the original lifetime of + /// the data, which prevents this from actually being used as a parallel + /// iterator since that also consumes. This method is provided for parity + /// with `std::slice::ArrayChunksMut`, but consider calling `remainder()` or + /// `take_remainder()` as alternatives. + pub fn into_remainder(self) -> &'data mut [T] { + self.rem + } + + /// Return the remainder of the original slice that is not going to be + /// returned by the iterator. The returned slice has at most `N-1` + /// elements.
+ /// + /// Consider `take_remainder()` if you need access to the data with its + /// original lifetime, rather than borrowing through `&mut self` here. + pub fn remainder(&mut self) -> &mut [T] { + self.rem + } + + /// Return the remainder of the original slice that is not going to be + /// returned by the iterator. The returned slice has at most `N-1` + /// elements. Subsequent calls will return an empty slice. + pub fn take_remainder(&mut self) -> &'data mut [T] { + std::mem::replace(&mut self.rem, &mut []) + } +} + +impl<'data, T: Send + 'data, const N: usize> ParallelIterator for ArrayChunksMut<'data, T, N> { + type Item = &'data mut [T; N]; + + fn drive_unindexed<C>(self, consumer: C) -> C::Result + where + C: UnindexedConsumer<Self::Item>, + { + bridge(self, consumer) + } + + fn opt_len(&self) -> Option<usize> { + Some(self.len()) + } +} + +impl<'data, T: Send + 'data, const N: usize> IndexedParallelIterator + for ArrayChunksMut<'data, T, N> +{ + fn drive<C>(self, consumer: C) -> C::Result + where + C: Consumer<Self::Item>, + { + bridge(self, consumer) + } + + fn len(&self) -> usize { + self.iter.len() + } + + fn with_producer<CB>(self, callback: CB) -> CB::Output + where + CB: ProducerCallback<Self::Item>, + { + self.iter.with_producer(callback) + } +} diff --git a/src/slice/mod.rs b/src/slice/mod.rs index 1b52b7291..b64b81f15 100644 --- a/src/slice/mod.rs +++ b/src/slice/mod.rs @@ -14,7 +14,7 @@ mod rchunks; mod test; #[cfg(min_const_generics)] -pub use self::array::ArrayChunks; +pub use self::array::{ArrayChunks, ArrayChunksMut}; use self::mergesort::par_mergesort; use self::quicksort::par_quicksort; @@ -297,6 +297,26 @@ pub trait ParallelSliceMut<T: Send> { RChunksExactMut::new(chunk_size, self.as_parallel_slice_mut()) } + /// Returns a parallel iterator over `N`-element chunks of + /// `self` at a time. The chunks are mutable and do not overlap.
+ /// + /// If `N` does not divide the length of the slice, then the + /// last up to `N-1` elements will be omitted and can be + /// retrieved from the remainder function of the iterator. + /// + /// # Examples + /// + /// ``` + /// use rayon::prelude::*; + /// let mut array = [1, 2, 3, 4, 5]; + /// array.par_array_chunks_mut() + /// .for_each(|[a, _, b]| std::mem::swap(a, b)); + /// assert_eq!(array, [3, 2, 1, 4, 5]); + /// ``` + fn par_array_chunks_mut<const N: usize>(&mut self) -> ArrayChunksMut<'_, T, N> { + ArrayChunksMut::new(self.as_parallel_slice_mut()) + } + /// Sorts the slice in parallel. /// /// This sort is stable (i.e., does not reorder equal elements) and *O*(*n* \* log(*n*)) worst-case. diff --git a/src/slice/test.rs b/src/slice/test.rs index 997e148e8..b27b74842 100644 --- a/src/slice/test.rs +++ b/src/slice/test.rs @@ -176,3 +176,17 @@ fn test_par_array_chunks_remainder() { assert_eq!(c.remainder(), &[4]); assert_eq!(c.len(), 2); } + +#[test] +fn test_par_array_chunks_mut_remainder() { + let v: &mut [i32] = &mut [0, 1, 2, 3, 4]; + let mut c = v.par_array_chunks_mut::<2>(); + assert_eq!(c.remainder(), &[4]); + assert_eq!(c.len(), 2); + assert_eq!(c.into_remainder(), &[4]); + + let mut c = v.par_array_chunks_mut::<2>(); + assert_eq!(c.take_remainder(), &[4]); + assert_eq!(c.take_remainder(), &[]); + assert_eq!(c.len(), 2); +} diff --git a/tests/debug.rs b/tests/debug.rs index e87f1e658..022f25e90 100644 --- a/tests/debug.rs +++ b/tests/debug.rs @@ -124,6 +124,7 @@ fn debug_vec() { check(v.par_array_chunks::<42>()); check(v.par_chunks_mut(42)); check(v.par_chunks_exact_mut(42)); + check(v.par_array_chunks_mut::<42>()); check(v.par_rchunks(42)); check(v.par_rchunks_exact(42)); check(v.par_rchunks_mut(42)); diff --git a/tests/producer_split_at.rs b/tests/producer_split_at.rs index 41ef92947..592720d1f 100644 --- a/tests/producer_split_at.rs +++ b/tests/producer_split_at.rs @@ -243,6 +243,40 @@ fn slice_chunks_exact_mut() { } } +#[test] +fn slice_array_chunks_mut() {
+ use std::convert::{TryFrom, TryInto}; + fn check_len<const N: usize>(s: &mut [i32], v: &mut [i32]) + where + for<'a> &'a mut [i32; N]: PartialEq + TryFrom<&'a mut [i32]> + std::fmt::Debug, + { + // TODO: use https://github.com/rust-lang/rust/pull/74373 instead. + let expected: Vec<_> = v + .chunks_exact_mut(N) + .map(|s| s.try_into().ok().unwrap()) + .collect(); + map_triples(expected.len() + 1, |i, j, k| { + Split::forward(s.par_array_chunks_mut::<N>(), i, j, k, &expected); + Split::reverse(s.par_array_chunks_mut::<N>(), i, j, k, &expected); + }); + } + + let mut s: Vec<_> = (0..10).collect(); + let mut v: Vec<_> = s.clone(); + check_len::<1>(&mut s, &mut v); + check_len::<2>(&mut s, &mut v); + check_len::<3>(&mut s, &mut v); + check_len::<4>(&mut s, &mut v); + check_len::<5>(&mut s, &mut v); + check_len::<6>(&mut s, &mut v); + check_len::<7>(&mut s, &mut v); + check_len::<8>(&mut s, &mut v); + check_len::<9>(&mut s, &mut v); + check_len::<10>(&mut s, &mut v); + check_len::<11>(&mut s, &mut v); + check_len::<12>(&mut s, &mut v); +} + #[test] fn slice_rchunks() { let s: Vec<_> = (0..10).collect();